Skip to content

Commit ebbde13

Browse files
author
Andrew J Westlake
committed
Merged invalid-state-fix changes
2 parents d800b90 + 6a2bec9 commit ebbde13

File tree

4 files changed

+100
-24
lines changed

4 files changed

+100
-24
lines changed

pytests/test_async_std_asyncio.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ mod common;
33
use std::{rc::Rc, time::Duration};
44

55
use async_std::task;
6-
use pyo3::{prelude::*, wrap_pyfunction};
6+
use pyo3::{prelude::*, types::PyType, wrap_pyfunction};
77

88
#[pyfunction]
99
#[allow(deprecated)]
@@ -157,6 +157,38 @@ async fn test_local_future_into_py() -> PyResult<()> {
157157
Ok(())
158158
}
159159

160+
#[pyo3_asyncio::async_std::test]
161+
async fn test_cancel() -> PyResult<()> {
162+
let py_future = Python::with_gil(|py| -> PyResult<PyObject> {
163+
Ok(pyo3_asyncio::async_std::future_into_py(py, async {
164+
async_std::task::sleep(Duration::from_secs(1)).await;
165+
Ok(Python::with_gil(|py| py.None()))
166+
})?
167+
.into())
168+
})?;
169+
170+
if let Err(e) = Python::with_gil(|py| -> PyResult<_> {
171+
py_future.as_ref(py).call_method0("cancel")?;
172+
pyo3_asyncio::async_std::into_future(py_future.as_ref(py))
173+
})?
174+
.await
175+
{
176+
Python::with_gil(|py| -> PyResult<()> {
177+
assert!(py
178+
.import("asyncio.exceptions")?
179+
.getattr("CancelledError")?
180+
.downcast::<PyType>()
181+
.unwrap()
182+
.is_instance(e.pvalue(py))?);
183+
Ok(())
184+
})?;
185+
} else {
186+
panic!("expected CancelledError");
187+
}
188+
189+
Ok(())
190+
}
191+
160192
#[allow(deprecated)]
161193
fn main() -> pyo3::PyResult<()> {
162194
pyo3::prepare_freethreaded_python();

pytests/tokio_asyncio/mod.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::{rc::Rc, time::Duration};
22

3-
use pyo3::{prelude::*, wrap_pyfunction};
3+
use pyo3::{prelude::*, types::PyType, wrap_pyfunction};
44

55
use crate::common;
66

@@ -166,3 +166,35 @@ async fn test_panic() -> PyResult<()> {
166166
}),
167167
}
168168
}
169+
170+
#[pyo3_asyncio::tokio::test]
171+
async fn test_cancel() -> PyResult<()> {
172+
let py_future = Python::with_gil(|py| -> PyResult<PyObject> {
173+
Ok(pyo3_asyncio::tokio::future_into_py(py, async {
174+
tokio::time::sleep(Duration::from_secs(1)).await;
175+
Ok(Python::with_gil(|py| py.None()))
176+
})?
177+
.into())
178+
})?;
179+
180+
if let Err(e) = Python::with_gil(|py| -> PyResult<_> {
181+
py_future.as_ref(py).call_method0("cancel")?;
182+
pyo3_asyncio::tokio::into_future(py_future.as_ref(py))
183+
})?
184+
.await
185+
{
186+
Python::with_gil(|py| -> PyResult<()> {
187+
assert!(py
188+
.import("asyncio.exceptions")?
189+
.getattr("CancelledError")?
190+
.downcast::<PyType>()
191+
.unwrap()
192+
.is_instance(e.pvalue(py))?);
193+
Ok(())
194+
})?;
195+
} else {
196+
panic!("expected CancelledError");
197+
}
198+
199+
Ok(())
200+
}

src/generic.rs

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,10 @@ where
244244
result
245245
}
246246

247+
fn cancelled(future: &PyAny) -> PyResult<bool> {
248+
future.getattr("cancelled")?.call0()?.is_true()
249+
}
250+
247251
fn set_result(event_loop: &PyAny, future: &PyAny, result: PyResult<PyObject>) -> PyResult<()> {
248252
match result {
249253
Ok(val) => {
@@ -461,29 +465,34 @@ where
461465
let result = R::scope(event_loop2.clone(), fut).await;
462466

463467
Python::with_gil(move |py| {
464-
if set_result(event_loop2.as_ref(py), future_tx1.as_ref(py), result)
468+
if cancelled(future_tx1.as_ref(py))
465469
.map_err(dump_err(py))
466-
.is_err()
470+
.unwrap_or(false)
467471
{
468-
469-
// Cancelled
472+
return;
470473
}
474+
475+
let _ = set_result(event_loop2.as_ref(py), future_tx1.as_ref(py), result)
476+
.map_err(dump_err(py));
471477
});
472478
})
473479
.await
474480
{
475481
if e.is_panic() {
476482
Python::with_gil(move |py| {
477-
if set_result(
483+
if cancelled(future_tx2.as_ref(py))
484+
.map_err(dump_err(py))
485+
.unwrap_or(false)
486+
{
487+
return;
488+
}
489+
490+
let _ = set_result(
478491
event_loop.as_ref(py),
479492
future_tx2.as_ref(py),
480493
Err(RustPanic::new_err("rust future panicked")),
481494
)
482-
.map_err(dump_err(py))
483-
.is_err()
484-
{
485-
// Cancelled
486-
}
495+
.map_err(dump_err(py));
487496
});
488497
}
489498
}
@@ -777,28 +786,34 @@ where
777786
let result = R::scope_local(event_loop2.clone(), fut).await;
778787

779788
Python::with_gil(move |py| {
780-
if set_result(event_loop2.as_ref(py), future_tx1.as_ref(py), result)
789+
if cancelled(future_tx1.as_ref(py))
781790
.map_err(dump_err(py))
782-
.is_err()
791+
.unwrap_or(false)
783792
{
784-
// Cancelled
793+
return;
785794
}
795+
796+
let _ = set_result(event_loop2.as_ref(py), future_tx1.as_ref(py), result)
797+
.map_err(dump_err(py));
786798
});
787799
})
788800
.await
789801
{
790802
if e.is_panic() {
791803
Python::with_gil(move |py| {
792-
if set_result(
804+
if cancelled(future_tx2.as_ref(py))
805+
.map_err(dump_err(py))
806+
.unwrap_or(false)
807+
{
808+
return;
809+
}
810+
811+
let _ = set_result(
793812
event_loop.as_ref(py),
794813
future_tx2.as_ref(py),
795814
Err(RustPanic::new_err("Rust future panicked")),
796815
)
797-
.map_err(dump_err(py))
798-
.is_err()
799-
{
800-
// Cancelled
801-
}
816+
.map_err(dump_err(py));
802817
});
803818
}
804819
}

src/lib.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,17 +90,14 @@
9090
9191
/// <span class="module-item stab portability" style="display: inline; border-radius: 3px; padding: 2px; font-size: 80%; line-height: 1.2;"><code>testing</code></span> Utilities for writing PyO3 Asyncio tests
9292
#[cfg(feature = "testing")]
93-
#[doc(inline)]
9493
pub mod testing;
9594

9695
/// <span class="module-item stab portability" style="display: inline; border-radius: 3px; padding: 2px; font-size: 80%; line-height: 1.2;"><code>async-std-runtime</code></span> PyO3 Asyncio functions specific to the async-std runtime
9796
#[cfg(feature = "async-std")]
98-
#[doc(inline)]
9997
pub mod async_std;
10098

10199
/// <span class="module-item stab portability" style="display: inline; border-radius: 3px; padding: 2px; font-size: 80%; line-height: 1.2;"><code>tokio-runtime</code></span> PyO3 Asyncio functions specific to the tokio runtime
102100
#[cfg(feature = "tokio-runtime")]
103-
#[doc(inline)]
104101
pub mod tokio;
105102

106103
/// Errors and exceptions related to PyO3 Asyncio

0 commit comments

Comments
 (0)