Skip to content

Commit fb85445

Browse files
thomasywangfacebook-github-bot
authored andcommitted
Use tokio::oneshot for sim clock (#822)
Summary: Pull Request resolved: #822. There was an open TODO to remove the global mailbox for SimClock. We don't actually need mailboxes for the sim clock; a oneshot channel works just fine. Reviewed By: pablorfb-meta. Differential Revision: D80029571. fbshipit-source-id: be78bfaa2135475b4715ace6fa21cb28f2d81ace
1 parent eaa3853 commit fb85445

File tree

3 files changed

+66
-98
lines changed

3 files changed

+66
-98
lines changed

hyperactor/src/channel/sim.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ impl MessageDeliveryEvent {
145145

146146
#[async_trait]
147147
impl Event for MessageDeliveryEvent {
148-
async fn handle(&self) -> Result<(), SimNetError> {
148+
async fn handle(&mut self) -> Result<(), SimNetError> {
149149
// Send the message to the correct receiver.
150150
SENDER
151151
.send(

hyperactor/src/clock.rs

Lines changed: 53 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,17 @@ use std::error::Error;
1212
use std::fmt;
1313
use std::sync::LazyLock;
1414
use std::sync::Mutex;
15-
use std::sync::OnceLock;
1615
use std::time::SystemTime;
1716

17+
use async_trait::async_trait;
1818
use futures::pin_mut;
1919
use hyperactor_telemetry::TelemetryClock;
2020
use serde::Deserialize;
2121
use serde::Serialize;
2222

23-
use crate::Mailbox;
2423
use crate::channel::ChannelAddr;
25-
use crate::data::Named;
26-
use crate::id;
27-
use crate::mailbox::DeliveryError;
28-
use crate::mailbox::MailboxSender;
29-
use crate::mailbox::MessageEnvelope;
30-
use crate::mailbox::Undeliverable;
31-
use crate::mailbox::UndeliverableMailboxSender;
32-
use crate::mailbox::monitored_return_handle;
33-
use crate::simnet::SleepEvent;
24+
use crate::simnet::Event;
25+
use crate::simnet::SimNetError;
3426
use crate::simnet::simnet_handle;
3527

3628
struct SimTime {
@@ -183,6 +175,45 @@ impl ClockKind {
183175
}
184176
}
185177

178+
#[derive(Debug)]
179+
struct SleepEvent {
180+
done_tx: Option<tokio::sync::oneshot::Sender<()>>,
181+
duration: tokio::time::Duration,
182+
}
183+
184+
impl SleepEvent {
185+
pub(crate) fn new(
186+
done_tx: tokio::sync::oneshot::Sender<()>,
187+
duration: tokio::time::Duration,
188+
) -> Box<Self> {
189+
Box::new(Self {
190+
done_tx: Some(done_tx),
191+
duration,
192+
})
193+
}
194+
}
195+
196+
#[async_trait]
197+
impl Event for SleepEvent {
198+
async fn handle(&mut self) -> Result<(), SimNetError> {
199+
self.done_tx
200+
.take()
201+
.unwrap()
202+
.send(())
203+
.map_err(|_| SimNetError::PanickedTask)?;
204+
205+
Ok(())
206+
}
207+
208+
fn duration(&self) -> tokio::time::Duration {
209+
self.duration
210+
}
211+
212+
fn summary(&self) -> String {
213+
format!("Sleeping for {} ms", self.duration.as_millis())
214+
}
215+
}
216+
186217
/// Clock to be used in simulator runs that allows the simnet to create a scheduled event for.
187218
/// When the wakeup event becomes the next earliest scheduled event, the simnet will advance its
188219
/// time to the wakeup time and use the transmitter to wake up this green thread
@@ -192,25 +223,25 @@ pub struct SimClock;
192223
impl Clock for SimClock {
193224
/// Tell the simnet to wake up this green thread after the specified duration has passed on the simnet
194225
async fn sleep(&self, duration: tokio::time::Duration) {
195-
let mailbox = SimClock::mailbox().clone();
196-
let (tx, rx) = mailbox.open_once_port::<()>();
226+
let (tx, rx) = tokio::sync::oneshot::channel::<()>();
197227

198228
simnet_handle()
199229
.unwrap()
200-
.send_event(SleepEvent::new(tx.bind(), mailbox, duration))
230+
.send_event(SleepEvent::new(tx, duration))
201231
.unwrap();
202-
rx.recv().await.unwrap();
232+
233+
rx.await.unwrap();
203234
}
204235

205236
async fn non_advancing_sleep(&self, duration: tokio::time::Duration) {
206-
let mailbox = SimClock::mailbox().clone();
207-
let (tx, rx) = mailbox.open_once_port::<()>();
237+
let (tx, rx) = tokio::sync::oneshot::channel::<()>();
208238

209239
simnet_handle()
210240
.unwrap()
211-
.send_nonadvanceable_event(SleepEvent::new(tx.bind(), mailbox, duration))
241+
.send_nonadvanceable_event(SleepEvent::new(tx, duration))
212242
.unwrap();
213-
rx.recv().await.unwrap();
243+
244+
rx.await.unwrap();
214245
}
215246

216247
async fn sleep_until(&self, deadline: tokio::time::Instant) {
@@ -234,19 +265,18 @@ impl Clock for SimClock {
234265
where
235266
F: std::future::Future<Output = T>,
236267
{
237-
let mailbox = SimClock::mailbox().clone();
238-
let (tx, deadline_rx) = mailbox.open_once_port::<()>();
268+
let (tx, deadline_rx) = tokio::sync::oneshot::channel::<()>();
239269

240270
simnet_handle()
241271
.unwrap()
242-
.send_event(SleepEvent::new(tx.bind(), mailbox, duration))
272+
.send_event(SleepEvent::new(tx, duration))
243273
.unwrap();
244274

245275
let fut = f;
246276
pin_mut!(fut);
247277

248278
tokio::select! {
249-
_ = deadline_rx.recv() => {
279+
_ = deadline_rx => {
250280
Err(TimeoutError)
251281
}
252282
res = &mut fut => Ok(res)
@@ -255,28 +285,6 @@ impl Clock for SimClock {
255285
}
256286

257287
impl SimClock {
258-
// TODO (SF, 2025-07-11): Remove this global, thread through a mailbox
259-
// from upstack and handle undeliverable messages properly.
260-
fn mailbox() -> &'static Mailbox {
261-
static SIMCLOCK_MAILBOX: OnceLock<Mailbox> = OnceLock::new();
262-
SIMCLOCK_MAILBOX.get_or_init(|| {
263-
let mailbox = Mailbox::new_detached(id!(proc[0].proc).clone());
264-
let (undeliverable_messages, mut rx) =
265-
mailbox.open_port::<Undeliverable<MessageEnvelope>>();
266-
undeliverable_messages.bind_to(Undeliverable::<MessageEnvelope>::port());
267-
tokio::spawn(async move {
268-
while let Ok(Undeliverable(mut envelope)) = rx.recv().await {
269-
envelope.try_set_error(DeliveryError::BrokenLink(
270-
"message returned to undeliverable port".to_string(),
271-
));
272-
UndeliverableMailboxSender
273-
.post(envelope, /*unused */ monitored_return_handle())
274-
}
275-
});
276-
mailbox
277-
})
278-
}
279-
280288
/// Advance the simulator's time to the specified instant
281289
pub fn advance_to(&self, time: tokio::time::Instant) {
282290
let mut guard = SIM_TIME.now.lock().unwrap();

hyperactor/src/simnet.rs

Lines changed: 12 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,12 @@ pub trait Event: Send + Sync + Debug {
8888
/// For a proc spawn, it will be creating the proc object and instantiating it.
8989
/// For any event that manipulates the network (like adding/removing nodes etc.)
9090
/// implement handle_network().
91-
async fn handle(&self) -> Result<(), SimNetError>;
91+
async fn handle(&mut self) -> Result<(), SimNetError>;
9292

9393
/// This is the method that will be called when the simulator fires the event
9494
/// Unless you need to make changes to the network, you do not have to implement this.
9595
/// Only implement handle() method for all non-simnet requirements.
96-
async fn handle_network(&self, _phantom: &SimNet) -> Result<(), SimNetError> {
96+
async fn handle_network(&mut self, _phantom: &SimNet) -> Result<(), SimNetError> {
9797
self.handle().await
9898
}
9999

@@ -117,11 +117,11 @@ struct NodeJoinEvent {
117117

118118
#[async_trait]
119119
impl Event for NodeJoinEvent {
120-
async fn handle(&self) -> Result<(), SimNetError> {
120+
async fn handle(&mut self) -> Result<(), SimNetError> {
121121
Ok(())
122122
}
123123

124-
async fn handle_network(&self, simnet: &SimNet) -> Result<(), SimNetError> {
124+
async fn handle_network(&mut self, simnet: &SimNet) -> Result<(), SimNetError> {
125125
simnet.bind(self.channel_addr.clone()).await;
126126
self.handle().await
127127
}
@@ -135,50 +135,6 @@ impl Event for NodeJoinEvent {
135135
}
136136
}
137137

138-
#[derive(Debug)]
139-
pub(crate) struct SleepEvent {
140-
done_tx: OncePortRef<()>,
141-
mailbox: Mailbox,
142-
duration: tokio::time::Duration,
143-
}
144-
145-
impl SleepEvent {
146-
pub(crate) fn new(
147-
done_tx: OncePortRef<()>,
148-
mailbox: Mailbox,
149-
duration: tokio::time::Duration,
150-
) -> Box<Self> {
151-
Box::new(Self {
152-
done_tx,
153-
mailbox,
154-
duration,
155-
})
156-
}
157-
}
158-
159-
#[async_trait]
160-
impl Event for SleepEvent {
161-
async fn handle(&self) -> Result<(), SimNetError> {
162-
Ok(())
163-
}
164-
165-
async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
166-
self.done_tx
167-
.clone()
168-
.send(&self.mailbox, ())
169-
.map_err(|_err| SimNetError::Closed("TODO".to_string()))?;
170-
Ok(())
171-
}
172-
173-
fn duration(&self) -> tokio::time::Duration {
174-
self.duration
175-
}
176-
177-
fn summary(&self) -> String {
178-
format!("Sleeping for {} ms", self.duration.as_millis())
179-
}
180-
}
181-
182138
#[derive(Debug)]
183139
/// A pytorch operation
184140
pub struct TorchOpEvent {
@@ -192,11 +148,11 @@ pub struct TorchOpEvent {
192148

193149
#[async_trait]
194150
impl Event for TorchOpEvent {
195-
async fn handle(&self) -> Result<(), SimNetError> {
151+
async fn handle(&mut self) -> Result<(), SimNetError> {
196152
Ok(())
197153
}
198154

199-
async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
155+
async fn handle_network(&mut self, _simnet: &SimNet) -> Result<(), SimNetError> {
200156
self.done_tx
201157
.clone()
202158
.send(&self.mailbox, ())
@@ -308,6 +264,10 @@ pub enum SimNetError {
308264
/// SimnetHandle being accessed without starting simnet
309265
#[error("simnet not started")]
310266
NotStarted,
267+
268+
/// A task has panicked.
269+
#[error("panicked task")]
270+
PanickedTask,
311271
}
312272

313273
struct State {
@@ -709,7 +669,7 @@ impl SimNet {
709669
training_script_waiting_time += advanced_time;
710670
}
711671
SimClock.advance_to(scheduled_time);
712-
for scheduled_event in scheduled_events {
672+
for mut scheduled_event in scheduled_events {
713673
self.pending_event_count
714674
.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
715675
if scheduled_event.event.handle_network(self).await.is_err() {
@@ -810,7 +770,7 @@ mod tests {
810770

811771
#[async_trait]
812772
impl Event for MessageDeliveryEvent {
813-
async fn handle(&self) -> Result<(), simnet::SimNetError> {
773+
async fn handle(&mut self) -> Result<(), simnet::SimNetError> {
814774
if let Some(dispatcher) = &self.dispatcher {
815775
dispatcher
816776
.send(

0 commit comments

Comments
 (0)