Skip to content

Commit 0d738ce

Browse files
authored
feat: propagate cancellations for async contexts (#1228)
1 parent 6674dda commit 0d738ce

File tree

8 files changed

+116
-50
lines changed

8 files changed

+116
-50
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,4 @@ serde_path_to_error = "0.1.17"
156156
redis = { version = "0.31.0", features = ["tokio-comp", "connection-manager"] }
157157
expect-test = "1.5.0"
158158
encoding_rs = "0.8.35"
159+
tokio-util = { version = "0.7.16", features = ["rt"] }

python/cocoindex/runtime.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,13 @@ def run(self, coro: Coroutine[Any, Any, T]) -> T:
5656
RuntimeWarning,
5757
stacklevel=2,
5858
)
59-
fut = asyncio.run_coroutine_threadsafe(coro, loop)
60-
return fut.result()
6159

6260
fut = asyncio.run_coroutine_threadsafe(coro, loop)
63-
return fut.result()
61+
try:
62+
return fut.result()
63+
except KeyboardInterrupt:
64+
fut.cancel()
65+
raise
6466

6567

6668
execution_context = _ExecutionContext()

python/cocoindex/subprocess_exec.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import asyncio
2020
import os
2121
import time
22-
import atexit
2322
from .user_app_loader import load_user_app
2423
from .runtime import execution_context
2524
import logging
@@ -32,39 +31,14 @@
3231
# ---------------------------------------------
3332
_pool_lock = threading.Lock()
3433
_pool: ProcessPoolExecutor | None = None
35-
_pool_cleanup_registered = False
3634
_user_apps: list[str] = []
3735
_logger = logging.getLogger(__name__)
3836

3937

40-
def shutdown_pool_at_exit() -> None:
41-
"""Best-effort shutdown of the global ProcessPoolExecutor on interpreter exit."""
42-
global _pool, _pool_cleanup_registered # pylint: disable=global-statement
43-
with _pool_lock:
44-
if _pool is not None:
45-
try:
46-
_pool.shutdown(wait=True, cancel_futures=True)
47-
except Exception as e:
48-
_logger.error(
49-
"Error during ProcessPoolExecutor shutdown at exit: %s",
50-
e,
51-
exc_info=True,
52-
)
53-
finally:
54-
_pool = None
55-
_pool_cleanup_registered = False
56-
57-
5838
def _get_pool() -> ProcessPoolExecutor:
59-
global _pool, _pool_cleanup_registered # pylint: disable=global-statement
39+
global _pool # pylint: disable=global-statement
6040
with _pool_lock:
6141
if _pool is None:
62-
if not _pool_cleanup_registered:
63-
# Register the shutdown at exit at creation time (rather than at import time)
64-
# to make sure it's executed earlier in the shutdown sequence.
65-
atexit.register(shutdown_pool_at_exit)
66-
_pool_cleanup_registered = True
67-
6842
# Single worker process as requested
6943
_pool = ProcessPoolExecutor(
7044
max_workers=1,

src/execution/live_updater.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use futures::future::try_join_all;
1111
use indicatif::ProgressBar;
1212
use sqlx::PgPool;
1313
use tokio::{sync::watch, task::JoinSet, time::MissedTickBehavior};
14+
use tokio_util::task::AbortOnDropHandle;
1415

1516
pub struct FlowLiveUpdaterUpdates {
1617
pub active_sources: Vec<String>,
@@ -364,7 +365,10 @@ impl SourceUpdateTask {
364365
}
365366
}
366367
};
367-
(Some(tokio::spawn(report_task)), Some(pb))
368+
(
369+
Some(AbortOnDropHandle::new(tokio::spawn(report_task))),
370+
Some(pb),
371+
)
368372
} else {
369373
(None, None)
370374
};

src/ops/py_factory.rs

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::prelude::*;
1+
use crate::{prelude::*, py::future::from_py_future};
22

33
use pyo3::{
44
Bound, IntoPyObjectExt, Py, PyAny, Python, pyclass, pymethods,
@@ -93,10 +93,7 @@ impl interface::SimpleFunctionExecutor for Arc<PyFunctionExecutor> {
9393
let result_coro = self.call_py_fn(py, input)?;
9494
let task_locals =
9595
pyo3_async_runtimes::TaskLocals::new(self.py_exec_ctx.event_loop.bind(py).clone());
96-
Ok(pyo3_async_runtimes::into_future_with_locals(
97-
&task_locals,
98-
result_coro,
99-
)?)
96+
Ok(from_py_future(py, &task_locals, result_coro)?)
10097
})?;
10198
let result = result_fut.await;
10299
Python::with_gil(|py| -> Result<_> {
@@ -188,7 +185,8 @@ impl interface::SimpleFunctionFactory for PyFunctionFactory {
188185
let prepare_coro = executor
189186
.call_method(py, "prepare", (), None)
190187
.to_result_with_py_trace(py)?;
191-
let prepare_fut = pyo3_async_runtimes::into_future_with_locals(
188+
let prepare_fut = from_py_future(
189+
py,
192190
&pyo3_async_runtimes::TaskLocals::new(
193191
py_exec_ctx.event_loop.bind(py).clone(),
194192
),
@@ -294,7 +292,8 @@ impl interface::SourceExecutor for PySourceExecutor {
294292
.to_result_with_py_trace(py)?;
295293
let task_locals =
296294
pyo3_async_runtimes::TaskLocals::new(py_exec_ctx.event_loop.bind(py).clone());
297-
Ok(pyo3_async_runtimes::into_future_with_locals(
295+
Ok(from_py_future(
296+
py,
298297
&task_locals,
299298
result_coro.into_bound(py),
300299
)?)
@@ -333,10 +332,7 @@ impl PySourceExecutor {
333332
.to_result_with_py_trace(py)?;
334333
let task_locals =
335334
pyo3_async_runtimes::TaskLocals::new(py_exec_ctx.event_loop.bind(py).clone());
336-
Ok(pyo3_async_runtimes::into_future_with_locals(
337-
&task_locals,
338-
coro.into_bound(py),
339-
)?)
335+
Ok(from_py_future(py, &task_locals, coro.into_bound(py))?)
340336
})?;
341337

342338
// Await the future to get the next item
@@ -523,10 +519,7 @@ impl interface::SourceFactory for PySourceConnectorFactory {
523519
.to_result_with_py_trace(py)?;
524520
let task_locals =
525521
pyo3_async_runtimes::TaskLocals::new(py_exec_ctx.event_loop.bind(py).clone());
526-
let create_future = pyo3_async_runtimes::into_future_with_locals(
527-
&task_locals,
528-
create_coro.into_bound(py),
529-
)?;
522+
let create_future = from_py_future(py, &task_locals, create_coro.into_bound(py))?;
530523
Ok(create_future)
531524
})?;
532525

@@ -671,7 +664,8 @@ impl interface::TargetFactory for PyExportTargetFactory {
671664
let task_locals = pyo3_async_runtimes::TaskLocals::new(
672665
py_exec_ctx.event_loop.bind(py).clone(),
673666
);
674-
anyhow::Ok(pyo3_async_runtimes::into_future_with_locals(
667+
anyhow::Ok(from_py_future(
668+
py,
675669
&task_locals,
676670
prepare_coro.into_bound(py),
677671
)?)
@@ -807,7 +801,8 @@ impl interface::TargetFactory for PyExportTargetFactory {
807801
.to_result_with_py_trace(py)?;
808802
let task_locals =
809803
pyo3_async_runtimes::TaskLocals::new(py_exec_ctx.event_loop.bind(py).clone());
810-
Ok(pyo3_async_runtimes::into_future_with_locals(
804+
Ok(from_py_future(
805+
py,
811806
&task_locals,
812807
result_coro.into_bound(py),
813808
)?)
@@ -867,7 +862,8 @@ impl interface::TargetFactory for PyExportTargetFactory {
867862
.to_result_with_py_trace(py)?;
868863
let task_locals =
869864
pyo3_async_runtimes::TaskLocals::new(py_exec_ctx.event_loop.bind(py).clone());
870-
Ok(pyo3_async_runtimes::into_future_with_locals(
865+
Ok(from_py_future(
866+
py,
871867
&task_locals,
872868
result_coro.into_bound(py),
873869
)?)

src/py/future.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
use crate::prelude::*;
2+
3+
use pyo3::prelude::*;
4+
use pyo3::types::PyDict;
5+
use pyo3_async_runtimes::TaskLocals;
6+
use std::sync::atomic::{AtomicBool, Ordering};
7+
use std::{
8+
future::Future,
9+
pin::Pin,
10+
task::{Context, Poll},
11+
};
12+
13+
struct CancelOnDropPy {
14+
inner: BoxFuture<'static, pyo3::PyResult<pyo3::PyObject>>,
15+
task: Py<PyAny>,
16+
event_loop: Py<PyAny>,
17+
ctx: Py<PyAny>,
18+
done: AtomicBool,
19+
}
20+
21+
impl Future for CancelOnDropPy {
22+
type Output = pyo3::PyResult<pyo3::PyObject>;
23+
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
24+
match Pin::new(&mut self.inner).poll(cx) {
25+
Poll::Ready(out) => {
26+
self.done.store(true, Ordering::SeqCst);
27+
Poll::Ready(out)
28+
}
29+
Poll::Pending => Poll::Pending,
30+
}
31+
}
32+
}
33+
34+
impl Drop for CancelOnDropPy {
35+
fn drop(&mut self) {
36+
if self.done.load(Ordering::SeqCst) {
37+
return;
38+
}
39+
Python::with_gil(|py| {
40+
let kwargs = PyDict::new(py);
41+
let result = || -> PyResult<()> {
42+
// pass context so cancellation runs under the right contextvars
43+
kwargs.set_item("context", self.ctx.bind(py))?;
44+
self.event_loop.bind(py).call_method(
45+
"call_soon_threadsafe",
46+
(self.task.bind(py).getattr("cancel")?,),
47+
Some(&kwargs),
48+
)?;
49+
// self.task.bind(py).call_method0("cancel")?;
50+
Ok(())
51+
}();
52+
if let Err(e) = result {
53+
error!("Error cancelling task: {e:?}");
54+
}
55+
});
56+
}
57+
}
58+
59+
pub fn from_py_future<'py, 'fut>(
60+
py: Python<'py>,
61+
locals: &TaskLocals,
62+
awaitable: Bound<'py, PyAny>,
63+
) -> pyo3::PyResult<impl Future<Output = pyo3::PyResult<pyo3::PyObject>> + Send + use<'fut>> {
64+
// 1) Capture loop + context from TaskLocals for thread-safe cancellation
65+
let event_loop: Bound<'py, PyAny> = locals.event_loop(py).into();
66+
let ctx: Bound<'py, PyAny> = locals.context(py);
67+
68+
// 2) Create a Task so we own a handle we can cancel later
69+
let kwarg = PyDict::new(py);
70+
kwarg.set_item("context", &ctx)?;
71+
let task: Bound<'py, PyAny> = event_loop
72+
.call_method("create_task", (awaitable,), Some(&kwarg))?
73+
.into();
74+
75+
// 3) Bridge it to a Rust Future as usual
76+
let fut = pyo3_async_runtimes::into_future_with_locals(locals, task.clone())?.boxed();
77+
78+
Ok(CancelOnDropPy {
79+
inner: fut,
80+
task: task.unbind(),
81+
event_loop: event_loop.unbind(),
82+
ctx: ctx.unbind(),
83+
done: AtomicBool::new(false),
84+
})
85+
}

src/py/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use std::sync::Arc;
2121

2222
mod convert;
2323
pub(crate) use convert::*;
24+
pub mod future;
2425

2526
pub struct PythonExecutionContext {
2627
pub event_loop: Py<PyAny>,
@@ -478,7 +479,8 @@ impl Flow {
478479
let task_locals = pyo3_async_runtimes::TaskLocals::new(
479480
py_exec_ctx.event_loop.bind(py).clone(),
480481
);
481-
Ok(pyo3_async_runtimes::into_future_with_locals(
482+
Ok(future::from_py_future(
483+
py,
482484
&task_locals,
483485
result_coro.into_bound(py),
484486
)?)

0 commit comments

Comments
 (0)