Skip to content

Commit 23d48a1

Browse files
committed
drop clients on drop packets
1 parent 365dde2 commit 23d48a1

File tree

6 files changed

+144
-78
lines changed

6 files changed

+144
-78
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pulsebeam-runtime/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ event-listener = "5.4.1"
3838
parking_lot = "0.12.5"
3939
dashmap = "6.1.0"
4040
async-channel = "2.5.0"
41+
tokio-util = "0.7.17"
4142

4243
[dev-dependencies]
4344
criterion = { version = "0.8.1", features = ["async", "async_tokio"] }

pulsebeam-runtime/src/net/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,12 @@ pub enum UnifiedSocketReader {
130130
}
131131

132132
impl UnifiedSocketReader {
133+
pub fn close_peer(&mut self, peer_addr: &SocketAddr) {
134+
if let Self::Tcp(inner) = self {
135+
inner.close_peer(peer_addr);
136+
}
137+
}
138+
133139
pub fn local_addr(&self) -> SocketAddr {
134140
match self {
135141
Self::Udp(inner) => inner.local_addr(),

pulsebeam-runtime/src/net/tcp.rs

Lines changed: 131 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,82 @@
1-
use super::{BATCH_SIZE, CHUNK_SIZE, RecvPacketBatch, SendPacketBatch};
1+
use super::{BATCH_SIZE, RecvPacketBatch, SendPacketBatch};
22
use crate::net::Transport;
33
use bytes::{Buf, BufMut, Bytes, BytesMut};
44
use dashmap::DashMap;
55
use std::sync::Arc;
6-
use std::{io, net::SocketAddr};
6+
use std::{
7+
io,
8+
net::{IpAddr, SocketAddr},
9+
time::Duration,
10+
};
711
use tokio::io::{AsyncReadExt, AsyncWriteExt};
812
use tokio::net::{TcpListener, TcpStream};
13+
use tokio_util::sync::CancellationToken;
14+
15+
const MAX_FRAME_SIZE: usize = 1500;
16+
const MAX_CONNECTIONS: usize = 10_000;
17+
const MAX_CONNS_PER_IP: usize = 20;
18+
const READ_TIMEOUT: Duration = Duration::from_secs(30);
19+
20+
/// Shared state for connection management
21+
type ConnMap = Arc<DashMap<SocketAddr, (async_channel::Sender<Bytes>, CancellationToken)>>;
922

10-
/// Creates a TCP Reader and Writer pair.
1123
pub async fn bind(
1224
addr: SocketAddr,
1325
external_addr: Option<SocketAddr>,
1426
) -> io::Result<(TcpTransportReader, TcpTransportWriter)> {
1527
let listener = TcpListener::bind(addr).await?;
1628
let local_addr = external_addr.unwrap_or(listener.local_addr()?);
1729

18-
let (packet_tx, packet_rx) = async_channel::bounded(BATCH_SIZE * 64);
19-
let conns = Arc::new(DashMap::new());
30+
let (packet_tx, packet_rx) = async_channel::bounded(BATCH_SIZE * 128);
31+
let conns: ConnMap = Arc::new(DashMap::new());
32+
let ip_counts = Arc::new(DashMap::<IpAddr, usize>::new());
33+
2034
let readable_notifier = Arc::new(tokio::sync::Notify::new());
2135
let writable_notifier = Arc::new(tokio::sync::Notify::new());
36+
let conn_semaphore = Arc::new(tokio::sync::Semaphore::new(MAX_CONNECTIONS));
2237

2338
let conns_clone = conns.clone();
2439
let r_notify = readable_notifier.clone();
2540
let w_notify = writable_notifier.clone();
41+
let ip_counts_clone = ip_counts.clone();
42+
let semaphore_clone = conn_semaphore.clone();
2643

27-
// Passive Listener Task (RFC 6544)
44+
// Passive Listener Task
2845
tokio::spawn(async move {
2946
while let Ok((stream, peer_addr)) = listener.accept().await {
47+
let ip = peer_addr.ip();
48+
49+
// 1. Global Connection Limit
50+
let permit = match semaphore_clone.clone().try_acquire_owned() {
51+
Ok(p) => p,
52+
Err(_) => {
53+
tracing::warn!(%peer_addr, "Rejecting: Global limit reached");
54+
continue;
55+
}
56+
};
57+
58+
// 2. Per-IP Limit
59+
{
60+
let mut count = ip_counts_clone.entry(ip).or_insert(0);
61+
if *count >= MAX_CONNS_PER_IP {
62+
tracing::warn!(%peer_addr, "Rejecting: Per-IP limit reached");
63+
continue;
64+
}
65+
*count += 1;
66+
}
67+
3068
let _ = stream.set_nodelay(true);
31-
let _ = stream.set_linger(None);
3269

3370
handle_new_connection(
3471
stream,
3572
peer_addr,
3673
local_addr,
3774
packet_tx.clone(),
3875
conns_clone.clone(),
76+
ip_counts_clone.clone(),
3977
r_notify.clone(),
4078
w_notify.clone(),
79+
permit,
4180
);
4281
w_notify.notify_waiters();
4382
}
@@ -47,6 +86,7 @@ pub async fn bind(
4786
local_addr,
4887
packet_rx,
4988
readable_notifier,
89+
conns: conns.clone(), // Reader now has access to conns
5090
};
5191

5292
let writer = TcpTransportWriter {
@@ -62,14 +102,25 @@ pub struct TcpTransportReader {
62102
packet_rx: async_channel::Receiver<RecvPacketBatch>,
63103
readable_notifier: Arc<tokio::sync::Notify>,
64104
local_addr: SocketAddr,
105+
/// Shared connection map to allow forced disconnects
106+
conns: ConnMap,
65107
}
66108

67109
impl TcpTransportReader {
110+
/// Closes the connection for a specific peer.
111+
/// Useful when high-level demux or auth fails.
112+
pub fn close_peer(&self, peer_addr: &SocketAddr) {
113+
if let Some((_, (tx, cancel))) = self.conns.remove(peer_addr) {
114+
cancel.cancel(); // Stops the background reader task
115+
tx.close(); // Stops the background writer task
116+
tracing::info!(%peer_addr, "Reader forced connection close");
117+
}
118+
}
119+
68120
pub fn local_addr(&self) -> SocketAddr {
69121
self.local_addr
70122
}
71123

72-
/// Waits until the socket is readable.
73124
pub async fn readable(&self) -> io::Result<()> {
74125
loop {
75126
if !self.packet_rx.is_empty() {
@@ -83,8 +134,6 @@ impl TcpTransportReader {
83134
}
84135
}
85136

86-
/// Pulls packets from the internal channel into the provided vector.
87-
/// Takes `&mut self` as requested for the Reader API.
88137
#[inline]
89138
pub fn try_recv_batch(&mut self, out: &mut Vec<RecvPacketBatch>) -> io::Result<()> {
90139
let mut count = 0;
@@ -105,7 +154,7 @@ impl TcpTransportReader {
105154
#[derive(Clone)]
106155
pub struct TcpTransportWriter {
107156
local_addr: SocketAddr,
108-
conns: Arc<DashMap<SocketAddr, async_channel::Sender<Bytes>>>,
157+
conns: ConnMap,
109158
writable_notifier: Arc<tokio::sync::Notify>,
110159
}
111160

@@ -126,7 +175,7 @@ impl TcpTransportWriter {
126175
}
127176
let mut any_available = false;
128177
for c in self.conns.iter() {
129-
if !c.is_full() {
178+
if !c.value().0.is_full() {
130179
any_available = true;
131180
break;
132181
}
@@ -135,99 +184,112 @@ impl TcpTransportWriter {
135184
return Ok(());
136185
}
137186
let wait = self.writable_notifier.notified();
138-
139-
// Re-check logic to prevent race conditions
140-
let mut any_available = false;
141-
for c in self.conns.iter() {
142-
if !c.is_full() {
143-
any_available = true;
144-
break;
145-
}
146-
}
147-
if any_available {
148-
return Ok(());
149-
}
150187
wait.await;
151188
}
152189
}
153190

154191
#[inline]
155192
pub fn try_send_batch(&self, batch: &SendPacketBatch) -> io::Result<bool> {
156-
let Some(peer_tx) = self.conns.get(&batch.dst) else {
157-
// If the peer is gone, we drop the packet (consistent with UDP behavior)
193+
let Some(peer_entry) = self.conns.get(&batch.dst) else {
158194
return Ok(true);
159195
};
196+
let (peer_tx, _) = peer_entry.value();
160197

161-
let required_slots = (batch.buf.len() + batch.segment_size - 1) / batch.segment_size;
198+
let required_slots = batch.buf.len().div_ceil(batch.segment_size);
162199
if peer_tx.capacity().unwrap() - peer_tx.len() < required_slots {
163200
return Ok(false);
164201
}
165202

166203
let mut offset = 0;
167-
let total_len = batch.buf.len();
168-
while offset < total_len {
169-
let end = std::cmp::min(offset + batch.segment_size, total_len);
170-
let segment = &batch.buf[offset..end];
171-
let _ = peer_tx.try_send(Bytes::copy_from_slice(segment));
204+
while offset < batch.buf.len() {
205+
let end = std::cmp::min(offset + batch.segment_size, batch.buf.len());
206+
let _ = peer_tx.try_send(Bytes::copy_from_slice(&batch.buf[offset..end]));
172207
offset = end;
173208
}
174209
Ok(true)
175210
}
176211
}
177212

178-
/// Background connection handler
179213
fn handle_new_connection(
180214
stream: TcpStream,
181215
peer_addr: SocketAddr,
182216
local_addr: SocketAddr,
183217
packet_tx: async_channel::Sender<RecvPacketBatch>,
184-
conns: Arc<DashMap<SocketAddr, async_channel::Sender<Bytes>>>,
218+
conns: ConnMap,
219+
ip_counts: Arc<DashMap<IpAddr, usize>>,
185220
r_notify: Arc<tokio::sync::Notify>,
186221
w_notify: Arc<tokio::sync::Notify>,
222+
permit: tokio::sync::OwnedSemaphorePermit,
187223
) {
188-
let (send_tx, send_rx) = async_channel::bounded::<Bytes>(8192);
189-
conns.insert(peer_addr, send_tx);
224+
let (send_tx, send_rx) = async_channel::bounded::<Bytes>(1024);
225+
let cancel_token = CancellationToken::new();
190226

191-
let (mut reader, writer) = stream.into_split();
227+
conns.insert(peer_addr, (send_tx, cancel_token.clone()));
228+
229+
let (mut tcp_reader, tcp_writer) = stream.into_split();
230+
let peer_ip = peer_addr.ip();
192231

193232
// Task: Receiver (RFC 4571 Un-framing)
233+
let r_cancel = cancel_token.clone();
234+
let r_conns = conns.clone();
235+
let r_ip_counts = ip_counts.clone();
194236
tokio::spawn(async move {
195-
let mut recv_buf = BytesMut::with_capacity(CHUNK_SIZE * 4);
196-
while let Ok(n) = reader.read_buf(&mut recv_buf).await {
197-
if n == 0 {
198-
break;
199-
}
200-
let mut added = false;
201-
while recv_buf.len() >= 2 {
202-
let len = u16::from_be_bytes([recv_buf[0], recv_buf[1]]) as usize;
203-
if recv_buf.len() < 2 + len {
204-
break;
237+
// Guard to release semaphore and cleanup DashMap on task exit
238+
let _guard = (permit, r_cancel);
239+
let mut recv_buf = BytesMut::with_capacity(MAX_FRAME_SIZE + 2);
240+
241+
loop {
242+
tokio::select! {
243+
_ = _guard.1.cancelled() => break,
244+
res = tokio::time::timeout(READ_TIMEOUT, tcp_reader.read_buf(&mut recv_buf)) => {
245+
let n = match res {
246+
Ok(Ok(n)) if n > 0 => n,
247+
_ => break, // Timeout, Error, or EOF
248+
};
249+
250+
while recv_buf.len() >= 2 {
251+
let len = u16::from_be_bytes([recv_buf[0], recv_buf[1]]) as usize;
252+
253+
if len > MAX_FRAME_SIZE || len == 0 {
254+
tracing::warn!(%peer_addr, len, "Invalid TCP frame size, dropping connection");
255+
return;
256+
}
257+
258+
if recv_buf.len() < 2 + len { break; }
259+
260+
recv_buf.advance(2);
261+
let data = recv_buf.split_to(len).freeze();
262+
263+
// Use try_send to prevent reader task from blocking if SFU logic lags
264+
if let Err(_) = packet_tx.try_send(RecvPacketBatch {
265+
src: peer_addr,
266+
dst: local_addr,
267+
buf: data,
268+
stride: len,
269+
len,
270+
transport: Transport::Tcp,
271+
}) {
272+
tracing::debug!("TCP packet dropped: Global queue full");
273+
} else {
274+
r_notify.notify_waiters();
275+
}
276+
}
205277
}
206-
recv_buf.advance(2);
207-
let data = recv_buf.split_to(len).freeze();
208-
let _ = packet_tx
209-
.send(RecvPacketBatch {
210-
src: peer_addr,
211-
dst: local_addr,
212-
buf: data,
213-
stride: len,
214-
len,
215-
transport: Transport::Tcp,
216-
})
217-
.await;
218-
added = true;
219-
}
220-
if added {
221-
r_notify.notify_waiters();
222278
}
223279
}
224-
conns.remove(&peer_addr);
280+
281+
// Final cleanup
282+
r_conns.remove(&peer_addr);
283+
if let Some(mut count) = r_ip_counts.get_mut(&peer_ip) {
284+
*count = count.saturating_sub(1);
285+
}
225286
});
226287

227-
// Task: Sender (RFC 4571 Framing + Syscall Batching)
288+
// Task: Sender
228289
tokio::spawn(async move {
229-
let mut write_buf = Vec::with_capacity(CHUNK_SIZE * 2);
230-
let mut writer = writer;
290+
let mut write_buf = Vec::with_capacity(MAX_FRAME_SIZE + 2);
291+
let mut writer = tcp_writer;
292+
231293
while let Ok(first) = send_rx.recv().await {
232294
write_buf.clear();
233295
write_buf.put_u16(first.len() as u16);
@@ -236,7 +298,7 @@ fn handle_new_connection(
236298
while let Ok(next) = send_rx.try_recv() {
237299
write_buf.put_u16(next.len() as u16);
238300
write_buf.put_slice(&next);
239-
if write_buf.len() > 65535 {
301+
if write_buf.len() > 16384 {
240302
break;
241303
}
242304
}
@@ -248,6 +310,7 @@ fn handle_new_connection(
248310
}
249311
});
250312
}
313+
251314
#[cfg(test)]
252315
mod tests {
253316
use super::*;

pulsebeam/src/gateway/actor.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,10 @@ impl GatewayWorkerActor {
175175
1
176176
};
177177

178-
self.demuxer.demux(batch).await;
178+
let src = batch.src;
179+
if !self.demuxer.demux(batch).await {
180+
self.socket.close_peer(&src);
181+
}
179182

180183
spent_budget += cost;
181184

0 commit comments

Comments
 (0)