From a823969dbc973b841f6d88c4f8bd5c467adfe2fc Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Thu, 8 Jan 2026 13:56:01 +0900 Subject: [PATCH] Check for awaitable instead of coroutine in stream glue --- CHANGELOG.md | 2 ++ pytests/test_async_std_asyncio.rs | 32 ++++++++++++++++++++++++++++++ pytests/tokio_asyncio/mod.rs | 33 +++++++++++++++++++++++++++++++ src/generic.rs | 4 ++-- 4 files changed, 69 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 96ef345..0e923b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ To see unreleased changes, please see the CHANGELOG on the main branch. +- Fix handling of full buffer in `into_stream` functions + ## [0.27.0] - 2025-10-20 - Avoid attaching to the runtime when cloning TaskLocals by using std::sync::Arc. [#62](https://github.com/PyO3/pyo3-async-runtimes/pull/62) diff --git a/pytests/test_async_std_asyncio.rs b/pytests/test_async_std_asyncio.rs index 8319c29..39fcdb3 100644 --- a/pytests/test_async_std_asyncio.rs +++ b/pytests/test_async_std_asyncio.rs @@ -350,6 +350,38 @@ async fn test_async_gen_v2() -> PyResult<()> { Ok(()) } +#[cfg(feature = "unstable-streams")] +const ASYNC_STD_TEST_MOD_FASTGEN: &str = r#" + +async def gen(): + for i in range(1000): + yield i +"#; + +#[cfg(feature = "unstable-streams")] +#[pyo3_async_runtimes::async_std::test] +async fn test_async_gen_full_buffer() -> PyResult<()> { + let stream = Python::attach(|py| { + let test_mod = PyModule::from_code( + py, + &CString::new(ASYNC_STD_TEST_MOD_FASTGEN).unwrap(), + &CString::new("test_rust_coroutine/async_std_test_mod.py").unwrap(), + &CString::new("async_std_test_mod").unwrap(), + )?; + + pyo3_async_runtimes::async_std::into_stream_v2(test_mod.call_method0("gen")?) + })?; + + let vals = stream + .map(|item| Python::attach(|py| -> PyResult { item.bind(py).extract() })) + .try_collect::>() + .await?; + + assert_eq!((0..1000).collect::>(), vals); + + Ok(()) +} + const CONTEXTVARS_CODE: &str = r#" cx = contextvars.ContextVar("cx") diff --git a/pytests/tokio_asyncio/mod.rs b/pytests/tokio_asyncio/mod.rs index 1a3b9c8..5942fb6 100644 --- a/pytests/tokio_asyncio/mod.rs +++ b/pytests/tokio_asyncio/mod.rs @@ -362,6 +362,39 @@ async fn test_async_gen_v2() -> PyResult<()> { Ok(()) } +#[cfg(feature = "unstable-streams")] +const TOKIO_TEST_MOD_FASTGEN: &str = r#" +import asyncio + +async def gen(): + for i in range(1000): + yield i +"#; + +#[cfg(feature = "unstable-streams")] +#[pyo3_async_runtimes::tokio::test] +async fn test_async_gen_full_buffer() -> PyResult<()> { + let stream = Python::attach(|py| { + let test_mod = PyModule::from_code( + py, + &CString::new(TOKIO_TEST_MOD_FASTGEN).unwrap(), + &CString::new("test_rust_coroutine/tokio_test_mod.py").unwrap(), + &CString::new("tokio_test_mod").unwrap(), + )?; + + pyo3_async_runtimes::tokio::into_stream_v2(test_mod.call_method0("gen")?) + })?; + + let vals = stream + .map(|item| Python::attach(|py| -> PyResult { item.bind(py).extract() })) + .try_collect::>() + .await?; + + assert_eq!((0..1000).collect::>(), vals); + + Ok(()) +} + const CONTEXTVARS_CODE: &str = r#" cx = contextvars.ContextVar("cx") diff --git a/src/generic.rs b/src/generic.rs index df141fd..c393fa2 100644 --- a/src/generic.rs +++ b/src/generic.rs @@ -1572,13 +1572,13 @@ impl SenderGlue { #[cfg(feature = "unstable-streams")] const STREAM_GLUE: &str = r#" -import asyncio +import inspect async def forward(gen, sender): async for item in gen: should_continue = sender.send(item) - if asyncio.iscoroutine(should_continue): + if inspect.isawaitable(should_continue): should_continue = await should_continue if should_continue: