Skip to content

Commit cfb5c5f

Browse files
thomasywangfacebook-github-bot
authored andcommitted
Use tokio::oneshot for sim clock (#822)
Summary: There was an open TODO to remove the global mailbox for SimClock. We don't actually even need mailboxes for sim clock and a oneshot works just fine Differential Revision: D80029571
1 parent 3db8984 commit cfb5c5f

File tree

3 files changed

+56
-106
lines changed

3 files changed

+56
-106
lines changed

hyperactor/src/channel/sim.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ impl MessageDeliveryEvent {
146146

147147
#[async_trait]
148148
impl Event for MessageDeliveryEvent {
149-
async fn handle(&self) -> Result<(), SimNetError> {
149+
async fn handle(&mut self) -> Result<(), SimNetError> {
150150
// Send the message to the correct receiver.
151151
SENDER
152152
.send(

hyperactor/src/clock.rs

Lines changed: 47 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,17 @@ use std::error::Error;
1212
use std::fmt;
1313
use std::sync::LazyLock;
1414
use std::sync::Mutex;
15-
use std::sync::OnceLock;
1615
use std::time::SystemTime;
1716

17+
use async_trait::async_trait;
1818
use futures::pin_mut;
1919
use hyperactor_telemetry::TelemetryClock;
2020
use serde::Deserialize;
2121
use serde::Serialize;
2222

23-
use crate::Mailbox;
2423
use crate::channel::ChannelAddr;
25-
use crate::data::Named;
26-
use crate::id;
27-
use crate::mailbox::DeliveryError;
28-
use crate::mailbox::MailboxSender;
29-
use crate::mailbox::MessageEnvelope;
30-
use crate::mailbox::Undeliverable;
31-
use crate::mailbox::UndeliverableMailboxSender;
32-
use crate::mailbox::monitored_return_handle;
33-
use crate::simnet::SleepEvent;
24+
use crate::simnet::Event;
25+
use crate::simnet::SimNetError;
3426
use crate::simnet::simnet_handle;
3527

3628
struct SimTime {
@@ -183,6 +175,39 @@ impl ClockKind {
183175
}
184176
}
185177

178+
#[derive(Debug)]
179+
struct SleepEvent {
180+
done_tx: Option<tokio::sync::oneshot::Sender<()>>,
181+
duration_ms: u64,
182+
}
183+
184+
impl SleepEvent {
185+
pub(crate) fn new(done_tx: tokio::sync::oneshot::Sender<()>, duration_ms: u64) -> Box<Self> {
186+
Box::new(Self {
187+
done_tx: Some(done_tx),
188+
duration_ms,
189+
})
190+
}
191+
}
192+
193+
#[async_trait]
194+
impl Event for SleepEvent {
195+
async fn handle(&mut self) -> Result<(), SimNetError> {
196+
if self.done_tx.take().unwrap().send(()).is_err() {
197+
tracing::error!("Failed to send wakeup event");
198+
}
199+
Ok(())
200+
}
201+
202+
fn duration_ms(&self) -> u64 {
203+
self.duration_ms
204+
}
205+
206+
fn summary(&self) -> String {
207+
format!("Sleeping for {} ms", self.duration_ms)
208+
}
209+
}
210+
186211
/// Clock to be used in simulator runs that allows the simnet to create a scheduled event for.
187212
/// When the wakeup event becomes the next earliest scheduled event, the simnet will advance it's
188213
/// time to the wakeup time and use the transmitter to wake up this green thread
@@ -192,33 +217,25 @@ pub struct SimClock;
192217
impl Clock for SimClock {
193218
/// Tell the simnet to wake up this green thread after the specified duration has pass on the simnet
194219
async fn sleep(&self, duration: tokio::time::Duration) {
195-
let mailbox = SimClock::mailbox().clone();
196-
let (tx, rx) = mailbox.open_once_port::<()>();
220+
let (tx, rx) = tokio::sync::oneshot::channel::<()>();
197221

198222
simnet_handle()
199223
.unwrap()
200-
.send_event(SleepEvent::new(
201-
tx.bind(),
202-
mailbox,
203-
duration.as_millis() as u64,
204-
))
224+
.send_event(SleepEvent::new(tx, duration.as_millis() as u64))
205225
.unwrap();
206-
rx.recv().await.unwrap();
226+
227+
rx.await.unwrap();
207228
}
208229

209230
async fn non_advancing_sleep(&self, duration: tokio::time::Duration) {
210-
let mailbox = SimClock::mailbox().clone();
211-
let (tx, rx) = mailbox.open_once_port::<()>();
231+
let (tx, rx) = tokio::sync::oneshot::channel::<()>();
212232

213233
simnet_handle()
214234
.unwrap()
215-
.send_nonadvanceable_event(SleepEvent::new(
216-
tx.bind(),
217-
mailbox,
218-
duration.as_millis() as u64,
219-
))
235+
.send_nonadvanceable_event(SleepEvent::new(tx, duration.as_millis() as u64))
220236
.unwrap();
221-
rx.recv().await.unwrap();
237+
238+
rx.await.unwrap();
222239
}
223240

224241
async fn sleep_until(&self, deadline: tokio::time::Instant) {
@@ -242,23 +259,18 @@ impl Clock for SimClock {
242259
where
243260
F: std::future::Future<Output = T>,
244261
{
245-
let mailbox = SimClock::mailbox().clone();
246-
let (tx, deadline_rx) = mailbox.open_once_port::<()>();
262+
let (tx, deadline_rx) = tokio::sync::oneshot::channel::<()>();
247263

248264
simnet_handle()
249265
.unwrap()
250-
.send_event(SleepEvent::new(
251-
tx.bind(),
252-
mailbox,
253-
duration.as_millis() as u64,
254-
))
266+
.send_event(SleepEvent::new(tx, duration.as_millis() as u64))
255267
.unwrap();
256268

257269
let fut = f;
258270
pin_mut!(fut);
259271

260272
tokio::select! {
261-
_ = deadline_rx.recv() => {
273+
_ = deadline_rx => {
262274
Err(TimeoutError)
263275
}
264276
res = &mut fut => Ok(res)
@@ -267,28 +279,6 @@ impl Clock for SimClock {
267279
}
268280

269281
impl SimClock {
270-
// TODO (SF, 2025-07-11): Remove this global, thread through a mailbox
271-
// from upstack and handle undeliverable messages properly.
272-
fn mailbox() -> &'static Mailbox {
273-
static SIMCLOCK_MAILBOX: OnceLock<Mailbox> = OnceLock::new();
274-
SIMCLOCK_MAILBOX.get_or_init(|| {
275-
let mailbox = Mailbox::new_detached(id!(proc[0].proc).clone());
276-
let (undeliverable_messages, mut rx) =
277-
mailbox.open_port::<Undeliverable<MessageEnvelope>>();
278-
undeliverable_messages.bind_to(Undeliverable::<MessageEnvelope>::port());
279-
tokio::spawn(async move {
280-
while let Ok(Undeliverable(mut envelope)) = rx.recv().await {
281-
envelope.try_set_error(DeliveryError::BrokenLink(
282-
"message returned to undeliverable port".to_string(),
283-
));
284-
UndeliverableMailboxSender
285-
.post(envelope, /*unused */ monitored_return_handle())
286-
}
287-
});
288-
mailbox
289-
})
290-
}
291-
292282
/// Advance the sumulator's time to the specified instant
293283
pub fn advance_to(&self, millis: u64) {
294284
let mut guard = SIM_TIME.now.lock().unwrap();

hyperactor/src/simnet.rs

Lines changed: 8 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,12 @@ pub trait Event: Send + Sync + Debug {
8888
/// For a proc spawn, it will be creating the proc object and instantiating it.
8989
/// For any event that manipulates the network (like adding/removing nodes etc.)
9090
/// implement handle_network().
91-
async fn handle(&self) -> Result<(), SimNetError>;
91+
async fn handle(&mut self) -> Result<(), SimNetError>;
9292

9393
/// This is the method that will be called when the simulator fires the event
9494
/// Unless you need to make changes to the network, you do not have to implement this.
9595
/// Only implement handle() method for all non-simnet requirements.
96-
async fn handle_network(&self, _phantom: &SimNet) -> Result<(), SimNetError> {
96+
async fn handle_network(&mut self, _phantom: &SimNet) -> Result<(), SimNetError> {
9797
self.handle().await
9898
}
9999

@@ -117,11 +117,11 @@ struct NodeJoinEvent {
117117

118118
#[async_trait]
119119
impl Event for NodeJoinEvent {
120-
async fn handle(&self) -> Result<(), SimNetError> {
120+
async fn handle(&mut self) -> Result<(), SimNetError> {
121121
Ok(())
122122
}
123123

124-
async fn handle_network(&self, simnet: &SimNet) -> Result<(), SimNetError> {
124+
async fn handle_network(&mut self, simnet: &SimNet) -> Result<(), SimNetError> {
125125
simnet.bind(self.channel_addr.clone()).await;
126126
self.handle().await
127127
}
@@ -135,46 +135,6 @@ impl Event for NodeJoinEvent {
135135
}
136136
}
137137

138-
#[derive(Debug)]
139-
pub(crate) struct SleepEvent {
140-
done_tx: OncePortRef<()>,
141-
mailbox: Mailbox,
142-
duration_ms: u64,
143-
}
144-
145-
impl SleepEvent {
146-
pub(crate) fn new(done_tx: OncePortRef<()>, mailbox: Mailbox, duration_ms: u64) -> Box<Self> {
147-
Box::new(Self {
148-
done_tx,
149-
mailbox,
150-
duration_ms,
151-
})
152-
}
153-
}
154-
155-
#[async_trait]
156-
impl Event for SleepEvent {
157-
async fn handle(&self) -> Result<(), SimNetError> {
158-
Ok(())
159-
}
160-
161-
async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
162-
self.done_tx
163-
.clone()
164-
.send(&self.mailbox, ())
165-
.map_err(|_err| SimNetError::Closed("TODO".to_string()))?;
166-
Ok(())
167-
}
168-
169-
fn duration_ms(&self) -> u64 {
170-
self.duration_ms
171-
}
172-
173-
fn summary(&self) -> String {
174-
format!("Sleeping for {} ms", self.duration_ms)
175-
}
176-
}
177-
178138
#[derive(Debug)]
179139
/// A pytorch operation
180140
pub struct TorchOpEvent {
@@ -188,11 +148,11 @@ pub struct TorchOpEvent {
188148

189149
#[async_trait]
190150
impl Event for TorchOpEvent {
191-
async fn handle(&self) -> Result<(), SimNetError> {
151+
async fn handle(&mut self) -> Result<(), SimNetError> {
192152
Ok(())
193153
}
194154

195-
async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
155+
async fn handle_network(&mut self, _simnet: &SimNet) -> Result<(), SimNetError> {
196156
self.done_tx
197157
.clone()
198158
.send(&self.mailbox, ())
@@ -710,7 +670,7 @@ impl SimNet {
710670
training_script_waiting_time += advanced_time;
711671
}
712672
SimClock.advance_to(scheduled_time);
713-
for scheduled_event in scheduled_events {
673+
for mut scheduled_event in scheduled_events {
714674
self.pending_event_count
715675
.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
716676
if scheduled_event.event.handle_network(self).await.is_err() {
@@ -811,7 +771,7 @@ mod tests {
811771

812772
#[async_trait]
813773
impl Event for MessageDeliveryEvent {
814-
async fn handle(&self) -> Result<(), simnet::SimNetError> {
774+
async fn handle(&mut self) -> Result<(), simnet::SimNetError> {
815775
if let Some(dispatcher) = &self.dispatcher {
816776
dispatcher
817777
.send(

0 commit comments

Comments
 (0)