Skip to content

Commit a823969

Browse files
committed
Check for awaitable instead of coroutine in stream glue
1 parent 22e0ec1 commit a823969

File tree

4 files changed

+69
-2
lines changed

4 files changed

+69
-2
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ To see unreleased changes, please see the CHANGELOG on the main branch.
1010

1111
<!-- towncrier release notes start -->
1212

13+
- Fix handling of full buffer in `into_stream` functions
14+
1315
## [0.27.0] - 2025-10-20
1416

1517
- Avoid attaching to the runtime when cloning TaskLocals by using std::sync::Arc. [#62](https://github.com/PyO3/pyo3-async-runtimes/pull/62)

pytests/test_async_std_asyncio.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,38 @@ async fn test_async_gen_v2() -> PyResult<()> {
350350
Ok(())
351351
}
352352

353+
#[cfg(feature = "unstable-streams")]
354+
const ASYNC_STD_TEST_MOD_FASTGEN: &str = r#"
355+
356+
async def gen():
357+
for i in range(1000):
358+
yield i
359+
"#;
360+
361+
#[cfg(feature = "unstable-streams")]
362+
#[pyo3_async_runtimes::async_std::test]
363+
async fn test_async_gen_full_buffer() -> PyResult<()> {
364+
let stream = Python::attach(|py| {
365+
let test_mod = PyModule::from_code(
366+
py,
367+
&CString::new(ASYNC_STD_TEST_MOD_FASTGEN).unwrap(),
368+
&CString::new("test_rust_coroutine/async_std_test_mod.py").unwrap(),
369+
&CString::new("async_std_test_mod").unwrap(),
370+
)?;
371+
372+
pyo3_async_runtimes::async_std::into_stream_v2(test_mod.call_method0("gen")?)
373+
})?;
374+
375+
let vals = stream
376+
.map(|item| Python::attach(|py| -> PyResult<i32> { item.bind(py).extract() }))
377+
.try_collect::<Vec<i32>>()
378+
.await?;
379+
380+
assert_eq!((0..1000).collect::<Vec<i32>>(), vals);
381+
382+
Ok(())
383+
}
384+
353385
const CONTEXTVARS_CODE: &str = r#"
354386
cx = contextvars.ContextVar("cx")
355387

pytests/tokio_asyncio/mod.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,39 @@ async fn test_async_gen_v2() -> PyResult<()> {
362362
Ok(())
363363
}
364364

365+
#[cfg(feature = "unstable-streams")]
366+
const TOKIO_TEST_MOD_FASTGEN: &str = r#"
367+
import asyncio
368+
369+
async def gen():
370+
for i in range(1000):
371+
yield i
372+
"#;
373+
374+
#[cfg(feature = "unstable-streams")]
375+
#[pyo3_async_runtimes::tokio::test]
376+
async fn test_async_gen_full_buffer() -> PyResult<()> {
377+
let stream = Python::attach(|py| {
378+
let test_mod = PyModule::from_code(
379+
py,
380+
&CString::new(TOKIO_TEST_MOD_FASTGEN).unwrap(),
381+
&CString::new("test_rust_coroutine/tokio_test_mod.py").unwrap(),
382+
&CString::new("tokio_test_mod").unwrap(),
383+
)?;
384+
385+
pyo3_async_runtimes::tokio::into_stream_v2(test_mod.call_method0("gen")?)
386+
})?;
387+
388+
let vals = stream
389+
.map(|item| Python::attach(|py| -> PyResult<i32> { item.bind(py).extract() }))
390+
.try_collect::<Vec<i32>>()
391+
.await?;
392+
393+
assert_eq!((0..1000).collect::<Vec<i32>>(), vals);
394+
395+
Ok(())
396+
}
397+
365398
const CONTEXTVARS_CODE: &str = r#"
366399
cx = contextvars.ContextVar("cx")
367400

src/generic.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,13 +1572,13 @@ impl SenderGlue {
15721572

15731573
#[cfg(feature = "unstable-streams")]
15741574
const STREAM_GLUE: &str = r#"
1575-
import asyncio
1575+
import inspect
15761576
15771577
async def forward(gen, sender):
15781578
async for item in gen:
15791579
should_continue = sender.send(item)
15801580
1581-
if asyncio.iscoroutine(should_continue):
1581+
if inspect.isawaitable(should_continue):
15821582
should_continue = await should_continue
15831583
15841584
if should_continue:

0 commit comments

Comments
 (0)