Skip to content

Commit bf91cf8

Browse files
committed
refactor wait_for_future to avoid spawn
1 parent 268a855 commit bf91cf8

File tree

1 file changed

+21
-25
lines changed

1 file changed

+21
-25
lines changed

src/utils.rs

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ use std::future::Future;
2828
use std::sync::OnceLock;
2929
use std::time::Duration;
3030
use tokio::runtime::Runtime;
31-
use tokio::time::timeout;
3231

3332
/// Utility to get the Tokio Runtime from Python
3433
#[inline]
@@ -68,35 +67,32 @@ where
6867
let (runtime, _enter_guard) = get_and_enter_tokio_runtime();
6968
// Define the interval for checking Python signals
7069
const SIGNAL_CHECK_INTERVAL_MS: u64 = 1000;
71-
// Spawn the task so we can poll it with timeouts
72-
let mut handle = runtime.spawn(f);
7370

74-
// Release the GIL and poll the future with periodic signal checks
71+
// Release the GIL and directly block on the future with periodic signal checks
7572
py.allow_threads(|| {
76-
loop {
77-
// Poll the future with a timeout to allow periodic signal checking
78-
match runtime.block_on(timeout(
79-
Duration::from_millis(SIGNAL_CHECK_INTERVAL_MS),
80-
&mut handle,
81-
)) {
82-
Ok(join_result) => {
83-
// The inner task has completed before timeout
84-
return join_result.map_err(|e| {
85-
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
86-
"Task failed: {}",
87-
e
88-
))
89-
});
90-
}
91-
Err(_elapsed) => {
92-
// SIGNAL_CHECK_INTERVAL_MS elapsed without task completion → check Python signals
93-
if let Err(py_exc) = Python::with_gil(|py| py.check_signals()) {
94-
return Err(py_exc);
73+
runtime.block_on(async {
74+
let mut interval =
75+
tokio::time::interval(Duration::from_millis(SIGNAL_CHECK_INTERVAL_MS));
76+
77+
// Pin the future to the stack
78+
tokio::pin!(f);
79+
80+
loop {
81+
tokio::select! {
82+
result = &mut f => {
83+
// Future completed
84+
return Ok(result);
85+
}
86+
_ = interval.tick() => {
87+
// Time to check Python signals
88+
if let Err(py_exc) = Python::with_gil(|py| py.check_signals()) {
89+
return Err(py_exc);
90+
}
91+
// Continue waiting for the future
9592
}
96-
// Loop again, reintroducing another timeout slice
9793
}
9894
}
99-
}
95+
})
10096
})
10197
}
10298

0 commit comments

Comments
 (0)