Skip to content

Commit eaea68c

Browse files
committed
init mpsc
1 parent 66654a3 commit eaea68c

File tree

6 files changed

+306
-17
lines changed

6 files changed

+306
-17
lines changed

pulsebeam-runtime/src/sync/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
pub mod mpsc;
12
pub mod spmc;
23

34
#[cfg(not(feature = "loom"))]

pulsebeam-runtime/src/sync/mpsc.rs

Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,293 @@
1+
use event_listener::{Event, EventListener};
2+
use futures_lite::Stream;
3+
use parking_lot::Mutex;
4+
use std::pin::Pin;
5+
use std::sync::Arc;
6+
use std::sync::atomic::{AtomicU64, Ordering};
7+
use std::task::{Context, Poll, ready};
8+
9+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
10+
pub enum RecvError {
11+
Lagged(u64),
12+
Closed,
13+
}
14+
15+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
16+
pub enum StreamRecvError {
17+
Lagged(u64),
18+
}
19+
20+
#[derive(Debug)]
21+
struct Slot<T> {
22+
seq: u64,
23+
val: Option<T>,
24+
}
25+
26+
#[derive(Debug)]
27+
struct Ring<T> {
28+
slots: Vec<Mutex<Slot<T>>>, // per-slot mutex
29+
mask: usize,
30+
head: AtomicU64, // next free sequence number
31+
event: Event,
32+
closed: AtomicU64,
33+
}
34+
35+
impl<T> Ring<T> {
36+
fn new(mut capacity: usize) -> Arc<Self> {
37+
if capacity == 0 {
38+
capacity = 1;
39+
} else if !capacity.is_power_of_two() {
40+
capacity = capacity.next_power_of_two();
41+
}
42+
43+
let mut slots = Vec::with_capacity(capacity);
44+
for _ in 0..capacity {
45+
slots.push(Mutex::new(Slot { seq: 0, val: None }));
46+
}
47+
48+
Arc::new(Self {
49+
slots,
50+
mask: capacity - 1,
51+
head: AtomicU64::new(0),
52+
event: Event::new(),
53+
closed: AtomicU64::new(0),
54+
})
55+
}
56+
}
57+
58+
#[derive(Debug)]
59+
pub enum TrySendError<T> {
60+
Closed(T),
61+
}
62+
63+
#[derive(Debug, Clone)]
64+
pub struct Sender<T> {
65+
ring: Arc<Ring<T>>,
66+
}
67+
68+
impl<T> Sender<T> {
69+
pub fn try_send(&self, val: T) -> Result<(), TrySendError<T>> {
70+
if self.ring.closed.load(Ordering::Relaxed) == 1 {
71+
return Err(TrySendError::Closed(val));
72+
}
73+
74+
// Atomically claim slot
75+
let seq = self.ring.head.fetch_add(1, Ordering::AcqRel);
76+
let idx = (seq as usize) & self.ring.mask;
77+
78+
let mut slot = self.ring.slots[idx].lock();
79+
slot.val = Some(val);
80+
slot.seq = seq;
81+
82+
// there's only 1 consumer
83+
self.ring.event.notify(1);
84+
Ok(())
85+
}
86+
}
87+
88+
impl<T> Drop for Sender<T> {
89+
fn drop(&mut self) {
90+
// TODO: there's no notification for receiver when no senders left.
91+
// self.ring.closed.store(1, Ordering::Release);
92+
// self.ring.event.notify(usize::MAX);
93+
}
94+
}
95+
96+
#[derive(Debug)]
97+
pub struct Receiver<T> {
98+
ring: Arc<Ring<T>>,
99+
next_seq: u64,
100+
local_head: u64,
101+
listener: Option<EventListener>,
102+
}
103+
104+
impl<T> Drop for Receiver<T> {
105+
fn drop(&mut self) {
106+
self.ring.closed.store(1, Ordering::Release);
107+
}
108+
}
109+
110+
impl<T: Clone> Receiver<T> {
111+
pub async fn recv(&mut self) -> Result<T, RecvError> {
112+
std::future::poll_fn(|cx| self.poll_recv(cx)).await
113+
}
114+
115+
pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
116+
loop {
117+
let coop = std::task::ready!(tokio::task::coop::poll_proceed(cx));
118+
119+
// Refresh local snapshot of head
120+
if self.next_seq == self.local_head {
121+
self.local_head = self.ring.head.load(Ordering::Acquire);
122+
}
123+
124+
let capacity = (self.ring.mask as u64) + 1;
125+
let earliest = self.local_head.saturating_sub(capacity);
126+
127+
// Closed + nothing left
128+
if self.ring.closed.load(Ordering::Acquire) == 1 && self.next_seq >= self.local_head {
129+
return Poll::Ready(Err(RecvError::Closed));
130+
}
131+
132+
// No new items
133+
if self.next_seq >= self.local_head {
134+
match &mut self.listener {
135+
Some(l) => {
136+
if Pin::new(l).poll(cx).is_pending() {
137+
return Poll::Pending;
138+
}
139+
self.listener = None;
140+
continue;
141+
}
142+
None => {
143+
self.listener = Some(self.ring.event.listen());
144+
continue;
145+
}
146+
}
147+
}
148+
149+
let idx = (self.next_seq as usize) & self.ring.mask;
150+
let slot = self.ring.slots[idx].lock();
151+
let slot_seq = slot.seq;
152+
153+
if slot_seq < earliest {
154+
self.next_seq = self.local_head;
155+
return Poll::Ready(Err(RecvError::Lagged(self.local_head)));
156+
}
157+
158+
if slot_seq != self.next_seq {
159+
self.next_seq = self.local_head;
160+
return Poll::Ready(Err(RecvError::Lagged(self.local_head)));
161+
}
162+
163+
if let Some(ref v) = slot.val {
164+
coop.made_progress();
165+
let out = v.clone();
166+
self.next_seq += 1;
167+
return Poll::Ready(Ok(out));
168+
}
169+
170+
self.next_seq = self.local_head;
171+
return Poll::Ready(Err(RecvError::Lagged(self.local_head)));
172+
}
173+
}
174+
}
175+
176+
impl<T: Clone> Stream for Receiver<T> {
177+
type Item = Result<T, StreamRecvError>;
178+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
179+
let this = self.get_mut();
180+
let res = match ready!(this.poll_recv(cx)) {
181+
Ok(item) => Some(Ok(item)),
182+
Err(RecvError::Lagged(n)) => Some(Err(StreamRecvError::Lagged(n))),
183+
Err(RecvError::Closed) => None,
184+
};
185+
Poll::Ready(res)
186+
}
187+
}
188+
189+
impl<T: Clone> Clone for Receiver<T> {
190+
fn clone(&self) -> Self {
191+
Self {
192+
ring: self.ring.clone(),
193+
next_seq: self.next_seq,
194+
local_head: self.local_head,
195+
listener: None,
196+
}
197+
}
198+
}
199+
200+
pub fn channel<T: Send + Sync + Clone + 'static>(capacity: usize) -> (Sender<T>, Receiver<T>) {
201+
let ring = Ring::new(capacity);
202+
(
203+
Sender { ring: ring.clone() },
204+
Receiver {
205+
ring,
206+
next_seq: 0,
207+
local_head: 0,
208+
listener: None,
209+
},
210+
)
211+
}
212+
213+
#[cfg(test)]
214+
mod tests {
215+
use super::*;
216+
use tokio::time::Duration;
217+
218+
#[tokio::test]
219+
async fn basic_send_recv() {
220+
let (tx, mut rx) = channel::<u64>(8);
221+
222+
tx.try_send(42).unwrap();
223+
assert_eq!(rx.recv().await, Ok(42));
224+
225+
tx.try_send(123).unwrap();
226+
assert_eq!(rx.recv().await, Ok(123));
227+
}
228+
229+
#[tokio::test]
230+
async fn lagging_consumer_jumps_to_head() {
231+
let (tx, mut rx) = channel::<u64>(4);
232+
233+
for i in 0..6 {
234+
tx.try_send(i).unwrap();
235+
}
236+
237+
match rx.recv().await {
238+
Err(RecvError::Lagged(seq)) => assert_eq!(seq, 6),
239+
_ => panic!("Expected lag error"),
240+
}
241+
242+
tx.try_send(6).unwrap();
243+
assert_eq!(rx.recv().await, Ok(6));
244+
}
245+
246+
#[tokio::test]
247+
async fn close_signal_drains_then_stops() {
248+
let (tx, mut rx) = channel::<u64>(4);
249+
250+
tx.try_send(1).unwrap();
251+
tx.try_send(2).unwrap();
252+
253+
drop(tx);
254+
255+
assert_eq!(rx.recv().await, Ok(1));
256+
assert_eq!(rx.recv().await, Ok(2));
257+
assert_eq!(rx.recv().await, Err(RecvError::Closed));
258+
assert_eq!(rx.recv().await, Err(RecvError::Closed));
259+
}
260+
261+
#[tokio::test]
262+
async fn async_waker_notification() {
263+
let (tx, mut rx) = channel::<u64>(4);
264+
265+
let handle = tokio::spawn(async move { rx.recv().await });
266+
267+
tokio::time::sleep(Duration::from_millis(10)).await;
268+
tx.try_send(99).unwrap();
269+
270+
let result = handle.await.unwrap();
271+
assert_eq!(result, Ok(99));
272+
}
273+
274+
#[tokio::test]
275+
async fn buffer_wrap_around_behavior() {
276+
let (tx, mut rx) = channel::<u64>(4);
277+
278+
for i in 0..4 {
279+
tx.try_send(i).unwrap();
280+
}
281+
282+
assert_eq!(rx.recv().await, Ok(0));
283+
assert_eq!(rx.recv().await, Ok(1));
284+
285+
tx.try_send(4).unwrap();
286+
tx.try_send(5).unwrap();
287+
288+
assert_eq!(rx.recv().await, Ok(2));
289+
assert_eq!(rx.recv().await, Ok(3));
290+
assert_eq!(rx.recv().await, Ok(4));
291+
assert_eq!(rx.recv().await, Ok(5));
292+
}
293+
}

pulsebeam-runtime/src/sync/spmc.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,30 +168,26 @@ impl<T: Clone> Receiver<T> {
168168

169169
// Slot overwritten before we got here
170170
if slot_seq < earliest {
171-
drop(slot);
172171
self.next_seq = self.local_head;
173172
return Poll::Ready(Err(RecvError::Lagged(self.local_head)));
174173
}
175174

176175
// Seq mismatch — producer overwrote after head snapshot
177176
if slot_seq != self.next_seq {
178-
drop(slot);
179177
self.next_seq = self.local_head;
180178
return Poll::Ready(Err(RecvError::Lagged(self.local_head)));
181179
}
182180

183181
// Valid message
184182
if let Some(v) = &slot.val {
185183
let out = v.clone();
186-
drop(slot);
187184
coop.made_progress();
188185
self.next_seq += 1;
189186
return Poll::Ready(Ok(out));
190187
}
191188

192189
// This shouldn't never happen, but just in case..
193190
// Seq was correct but value missing — treat as lag
194-
drop(slot);
195191
self.next_seq = self.local_head;
196192
return Poll::Ready(Err(RecvError::Lagged(self.local_head)));
197193
}

pulsebeam/src/gateway/actor.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
use crate::gateway::demux::Demuxer;
22
use pulsebeam_runtime::actor::{ActorKind, ActorStatus, RunnerConfig};
33
use pulsebeam_runtime::prelude::*;
4-
use pulsebeam_runtime::{actor, mailbox, net};
4+
use pulsebeam_runtime::{actor, net};
55
use std::{io, sync::Arc};
66
use tokio::task::JoinSet;
77

88
#[derive(Clone)]
99
pub enum GatewayControlMessage {
10-
AddParticipant(String, mailbox::Sender<net::RecvPacketBatch>),
10+
AddParticipant(
11+
String,
12+
pulsebeam_runtime::sync::mpsc::Sender<net::RecvPacketBatch>,
13+
),
1114
RemoveParticipant(String),
1215
}
1316

@@ -171,7 +174,7 @@ impl GatewayWorkerActor {
171174
};
172175

173176
let src = batch.src;
174-
if !self.demuxer.demux(&mut self.socket, batch).await {
177+
if !self.demuxer.demux(&mut self.socket, batch) {
175178
// In case there's a malicious actor, close immediately as there's no
176179
// associated participant.
177180
self.socket.close_peer(&src);

pulsebeam/src/gateway/demux.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
use pulsebeam_runtime::net;
12
use pulsebeam_runtime::net::UnifiedSocketReader;
2-
use pulsebeam_runtime::{mailbox, net};
33

44
use crate::gateway::ice;
55
use std::collections::HashMap;
66
use std::net::SocketAddr;
77

8-
pub type ParticipantHandle = mailbox::Sender<net::RecvPacketBatch>;
8+
pub type ParticipantHandle = pulsebeam_runtime::sync::mpsc::Sender<net::RecvPacketBatch>;
99

1010
/// A UDP demuxer that maps packets to participants based on source address and STUN ufrag.
1111
///
@@ -57,11 +57,7 @@ impl Demuxer {
5757

5858
/// Routes a packet to the correct participant.
5959
/// Returns `true` if sent, `false` if dropped
60-
pub async fn demux(
61-
&mut self,
62-
socket: &mut UnifiedSocketReader,
63-
batch: net::RecvPacketBatch,
64-
) -> bool {
60+
pub fn demux(&mut self, socket: &mut UnifiedSocketReader, batch: net::RecvPacketBatch) -> bool {
6561
let src = batch.src;
6662

6763
let handle = if let Some(h) = self.addr_map.get_mut(&src) {
@@ -83,7 +79,7 @@ impl Demuxer {
8379
return false;
8480
};
8581

86-
if let Err(_) = handle.send(batch).await {
82+
if let Err(_) = handle.try_send(batch) {
8783
// Handle is closed! Clean up everything related to this participant.
8884
if let Some(ufrag) = self.addr_to_ufrag.get(&src).cloned() {
8985
tracing::info!("Participant handle closed, cleaning up ufrag: {:?}", ufrag);

0 commit comments

Comments
 (0)