Skip to content
This repository was archived by the owner on Sep 15, 2021. It is now read-only.

Commit b952861

Browse files
authored
feat: add 5 min timeout for buckets' comm op (#5)
1 parent da8c59d commit b952861

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

bagua-core-internal/src/lib.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ use crate::telemetry::{SCHEDULED_THREAD_POOL, TELEMETRY};
1616
use cpp::cpp;
1717
use datatypes::{BaguaBucket, BaguaTensor};
1818
use events::BaguaEventChannel;
19+
use flume::RecvTimeoutError;
1920
use hashbrown::{HashMap, HashSet};
2021
use std::collections::VecDeque;
2122
use std::fmt::Debug;
2223
use std::sync::Arc;
24+
use std::time::Duration;
2325
use thiserror::Error;
2426

2527
cpp! {{
@@ -120,6 +122,7 @@ pub struct BaguaCommBackend {
120122
channels: Arc<BaguaCommOpChannels>,
121123
managed_ptrs: HashSet<u64>,
122124
comm_worker: std::thread::JoinHandle<()>,
125+
comm_monitor: std::thread::JoinHandle<()>,
123126
}
124127

125128
impl BaguaCommBackend {
@@ -168,6 +171,10 @@ impl BaguaCommBackend {
168171

169172
let channels = Arc::new(BaguaCommOpChannels::new(schedule_channel_cap));
170173
let channels_clone = channels.clone();
174+
let (monitor_op_start_channel_sender, monitor_op_start_channel_receiver) =
175+
flume::unbounded();
176+
let (monitor_op_finish_channel_sender, monitor_op_finish_channel_receiver) =
177+
flume::unbounded();
171178

172179
BaguaCommBackend {
173180
ordered_buckets: Default::default(),
@@ -190,6 +197,7 @@ impl BaguaCommBackend {
190197
"worker received scheduled communication operation {:?}",
191198
comm_op
192199
);
200+
monitor_op_start_channel_sender.send(comm_op.bucket.clone());
193201
for op in &comm_op.ops {
194202
op.execute_background_communication(
195203
comm_op.bucket.clone(),
@@ -199,6 +207,18 @@ impl BaguaCommBackend {
199207
tracing::debug!("comm op executed: {:?}", comm_op);
200208
comm_op.event_channel.finish();
201209
tracing::debug!("comm op marked finished: {:?}", comm_op);
210+
monitor_op_finish_channel_sender.send(());
211+
}
212+
}),
213+
comm_monitor: std::thread::spawn(move || loop {
214+
let op_bucket = monitor_op_start_channel_receiver
215+
.recv()
216+
.expect("monitor cannot receive next comm op bucket");
217+
match monitor_op_finish_channel_receiver.recv_timeout(Duration::from_secs(300)) {
218+
Ok(_) => {}
219+
Err(_) => {
220+
panic!("{:?} comm op has not finished for 5 min, panic", op_bucket);
221+
}
202222
}
203223
}),
204224
}

bagua-core-py/src/lib.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,14 @@ fn bagua_core(_py: Python, m: &PyModule) -> PyResult<()> {
331331
.init();
332332
color_eyre::install().unwrap();
333333

334+
// terminate the whole process when any thread panics (otherwise only the panicking thread stops)
335+
let orig_hook = std::panic::take_hook();
336+
std::panic::set_hook(Box::new(move |panic_info| {
337+
// run the previously installed panic hook so the panic is still reported, then exit with status 1
338+
orig_hook(panic_info);
339+
std::process::exit(1);
340+
}));
341+
334342
m.add_class::<BaguaCommBackendPy>()?;
335343
m.add_class::<BaguaTensorPy>()?;
336344
m.add_class::<BaguaBucketPy>()?;

0 commit comments

Comments
 (0)