Skip to content

Commit 42d2d50

Browse files
author
Andrew J Westlake
committed
Added both versions of into_stream + tests for both - need benchmarks for both
1 parent bf903ad commit 42d2d50

File tree

6 files changed

+321
-6
lines changed

6 files changed

+321
-6
lines changed

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ harness = false
8484
required-features = ["tokio-runtime", "testing"]
8585

8686
[dependencies]
87+
async-channel = "1.6"
8788
clap = { version = "2.33", optional = true }
8889
futures = "0.3"
8990
inventory = "0.1"
@@ -100,4 +101,4 @@ optional = true
100101
[dependencies.tokio]
101102
version = "1.4"
102103
features = ["full"]
103-
optional = true
104+
optional = true

pytests/test_async_std_asyncio.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ mod common;
33
use std::time::Duration;
44

55
use async_std::task;
6+
use futures::prelude::*;
67
use pyo3::{prelude::*, wrap_pyfunction};
78

89
#[pyfunction]
@@ -74,6 +75,61 @@ fn test_init_twice() -> PyResult<()> {
7475
common::test_init_twice()
7576
}
7677

78+
const ASYNC_STD_TEST_MOD: &str = r#"
79+
import asyncio
80+
81+
async def gen():
82+
for i in range(10):
83+
await asyncio.sleep(0.1)
84+
yield i
85+
"#;
86+
87+
#[pyo3_asyncio::async_std::test]
88+
async fn test_async_gen_v1() -> PyResult<()> {
89+
let stream = Python::with_gil(|py| {
90+
let test_mod = PyModule::from_code(
91+
py,
92+
ASYNC_STD_TEST_MOD,
93+
"test_rust_coroutine/async_std_test_mod.py",
94+
"async_std_test_mod",
95+
)?;
96+
97+
pyo3_asyncio::async_std::into_stream_v1(test_mod.call_method0("gen")?)
98+
})?;
99+
100+
let vals = stream
101+
.map(|item| Python::with_gil(|py| -> PyResult<i32> { Ok(item?.as_ref(py).extract()?) }))
102+
.try_collect::<Vec<i32>>()
103+
.await?;
104+
105+
assert_eq!((0..10).collect::<Vec<i32>>(), vals);
106+
107+
Ok(())
108+
}
109+
110+
#[pyo3_asyncio::tokio::test]
111+
async fn test_async_gen_v2() -> PyResult<()> {
112+
let stream = Python::with_gil(|py| {
113+
let test_mod = PyModule::from_code(
114+
py,
115+
ASYNC_STD_TEST_MOD,
116+
"test_rust_coroutine/async_std_test_mod.py",
117+
"async_std_test_mod",
118+
)?;
119+
120+
pyo3_asyncio::async_std::into_stream_v2(test_mod.call_method0("gen")?)
121+
})?;
122+
123+
let vals = stream
124+
.map(|item| Python::with_gil(|py| -> PyResult<i32> { Ok(item.as_ref(py).extract()?) }))
125+
.try_collect::<Vec<i32>>()
126+
.await?;
127+
128+
assert_eq!((0..10).collect::<Vec<i32>>(), vals);
129+
130+
Ok(())
131+
}
132+
77133
#[pyo3_asyncio::async_std::main]
78134
async fn main() -> pyo3::PyResult<()> {
79135
pyo3_asyncio::testing::main().await

pytests/tokio_asyncio/mod.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::time::Duration;
22

3+
use futures::prelude::*;
34
use pyo3::{prelude::*, wrap_pyfunction};
45

56
use crate::common;
@@ -82,3 +83,58 @@ fn test_init_tokio_twice() -> PyResult<()> {
8283

8384
Ok(())
8485
}
86+
87+
const TOKIO_TEST_MOD: &str = r#"
88+
import asyncio
89+
90+
async def gen():
91+
for i in range(10):
92+
await asyncio.sleep(0.1)
93+
yield i
94+
"#;
95+
96+
#[pyo3_asyncio::tokio::test]
97+
async fn test_async_gen_v1() -> PyResult<()> {
98+
let stream = Python::with_gil(|py| {
99+
let test_mod = PyModule::from_code(
100+
py,
101+
TOKIO_TEST_MOD,
102+
"test_rust_coroutine/tokio_test_mod.py",
103+
"tokio_test_mod",
104+
)?;
105+
106+
pyo3_asyncio::tokio::into_stream_v1(test_mod.call_method0("gen")?)
107+
})?;
108+
109+
let vals = stream
110+
.map(|item| Python::with_gil(|py| -> PyResult<i32> { Ok(item?.as_ref(py).extract()?) }))
111+
.try_collect::<Vec<i32>>()
112+
.await?;
113+
114+
assert_eq!((0..10).collect::<Vec<i32>>(), vals);
115+
116+
Ok(())
117+
}
118+
119+
#[pyo3_asyncio::tokio::test]
120+
async fn test_async_gen_v2() -> PyResult<()> {
121+
let stream = Python::with_gil(|py| {
122+
let test_mod = PyModule::from_code(
123+
py,
124+
TOKIO_TEST_MOD,
125+
"test_rust_coroutine/tokio_test_mod.py",
126+
"tokio_test_mod",
127+
)?;
128+
129+
pyo3_asyncio::tokio::into_stream_v2(test_mod.call_method0("gen")?)
130+
})?;
131+
132+
let vals = stream
133+
.map(|item| Python::with_gil(|py| -> PyResult<i32> { Ok(item.as_ref(py).extract()?) }))
134+
.try_collect::<Vec<i32>>()
135+
.await?;
136+
137+
assert_eq!((0..10).collect::<Vec<i32>>(), vals);
138+
139+
Ok(())
140+
}

src/async_std.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::future::Future;
22

33
use async_std::task;
4+
use futures::prelude::*;
45
use pyo3::prelude::*;
56

67
use crate::generic::{self, JoinError, Runtime};
@@ -117,3 +118,15 @@ where
117118
{
118119
generic::into_coroutine::<AsyncStdRuntime, _>(py, fut)
119120
}
121+
122+
/// Convert async generator into a stream
123+
pub fn into_stream_v1<'p>(
124+
gen: &'p PyAny,
125+
) -> PyResult<impl Stream<Item = PyResult<PyObject>> + 'static> {
126+
generic::into_stream_v1::<AsyncStdRuntime>(gen)
127+
}
128+
129+
/// Convert async generator into a stream
130+
pub fn into_stream_v2<'p>(gen: &'p PyAny) -> PyResult<impl Stream<Item = PyObject> + 'static> {
131+
generic::into_stream_v2::<AsyncStdRuntime>(gen)
132+
}

src/generic.rs

Lines changed: 181 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
use std::future::Future;
1+
use std::{future::Future, marker::PhantomData};
22

3-
use pyo3::{exceptions::PyException, prelude::*};
3+
use futures::{channel::mpsc, prelude::*};
4+
use once_cell::sync::OnceCell;
5+
use pyo3::{exceptions::PyException, prelude::*, PyNativeType};
46

5-
use crate::{dump_err, get_event_loop, CALL_SOON, CREATE_FUTURE, EXPECT_INIT};
7+
use crate::{dump_err, get_event_loop, into_future, CALL_SOON, CREATE_FUTURE, EXPECT_INIT};
68

79
/// Generic utilities for a JoinError
810
pub trait JoinError {
@@ -11,7 +13,7 @@ pub trait JoinError {
1113
}
1214

1315
/// Generic Rust async/await runtime
14-
pub trait Runtime {
16+
pub trait Runtime: Send + 'static {
1517
/// The error returned by a JoinHandle after being awaited
1618
type JoinError: JoinError + Send;
1719
/// A future that completes with the result of the spawned task
@@ -236,3 +238,178 @@ where
236238

237239
Ok(future_rx)
238240
}
241+
242+
/// Convert an async generator into a Stream
243+
pub fn into_stream_v1<'p, R>(
244+
gen: &'p PyAny,
245+
) -> PyResult<impl Stream<Item = PyResult<PyObject>> + 'static>
246+
where
247+
R: Runtime,
248+
{
249+
let (tx, rx) = async_channel::bounded(1);
250+
let anext = PyObject::from(gen.getattr("__anext__")?);
251+
252+
R::spawn(async move {
253+
loop {
254+
let fut =
255+
Python::with_gil(|py| -> PyResult<_> { into_future(anext.as_ref(py).call0()?) });
256+
let item = match fut {
257+
Ok(fut) => match fut.await {
258+
Ok(item) => Ok(item),
259+
Err(e) => {
260+
if Python::with_gil(|py| {
261+
e.is_instance::<pyo3::exceptions::PyStopAsyncIteration>(py)
262+
}) {
263+
// end the iteration
264+
break;
265+
} else {
266+
Err(e)
267+
}
268+
}
269+
},
270+
Err(e) => Err(e),
271+
};
272+
273+
if let Err(_) = tx.send(item).await {
274+
// receiving side was dropped
275+
break;
276+
}
277+
}
278+
});
279+
280+
Ok(rx)
281+
}
282+
283+
fn py_true() -> PyObject {
284+
static TRUE: OnceCell<PyObject> = OnceCell::new();
285+
TRUE.get_or_init(|| Python::with_gil(|py| true.into_py(py)))
286+
.clone()
287+
}
288+
fn py_false() -> PyObject {
289+
static FALSE: OnceCell<PyObject> = OnceCell::new();
290+
FALSE
291+
.get_or_init(|| Python::with_gil(|py| false.into_py(py)))
292+
.clone()
293+
}
294+
295+
trait Sender: Send + 'static {
296+
fn send(&mut self, item: PyObject) -> PyResult<PyObject>;
297+
fn close(&mut self) -> PyResult<()>;
298+
}
299+
300+
struct GenericSender<R>
301+
where
302+
R: Runtime,
303+
{
304+
runtime: PhantomData<R>,
305+
tx: mpsc::Sender<PyObject>,
306+
}
307+
308+
impl<R> Sender for GenericSender<R>
309+
where
310+
R: Runtime,
311+
{
312+
fn send(&mut self, item: PyObject) -> PyResult<PyObject> {
313+
match self.tx.try_send(item.clone()) {
314+
Ok(_) => Ok(py_true()),
315+
Err(e) => {
316+
if e.is_full() {
317+
let mut tx = self.tx.clone();
318+
Python::with_gil(move |py| {
319+
into_coroutine::<R, _>(py, async move {
320+
if tx.flush().await.is_err() {
321+
// receiving side disconnected
322+
return Ok(py_false());
323+
}
324+
if tx.send(item).await.is_err() {
325+
// receiving side disconnected
326+
return Ok(py_false());
327+
}
328+
Ok(py_true())
329+
})
330+
})
331+
} else {
332+
Ok(py_false())
333+
}
334+
}
335+
}
336+
}
337+
fn close(&mut self) -> PyResult<()> {
338+
self.tx.close_channel();
339+
Ok(())
340+
}
341+
}
342+
343+
#[pyclass]
344+
struct SenderGlue {
345+
tx: Box<dyn Sender>,
346+
}
347+
#[pymethods]
348+
impl SenderGlue {
349+
pub fn send(&mut self, item: PyObject) -> PyResult<PyObject> {
350+
self.tx.send(item)
351+
}
352+
pub fn close(&mut self) -> PyResult<()> {
353+
self.tx.close()
354+
}
355+
}
356+
357+
const STREAM_GLUE: &str = r#"
358+
import asyncio
359+
360+
async def forward(gen, sender):
361+
async for item in gen:
362+
should_continue = sender.send(item)
363+
364+
if asyncio.iscoroutine(should_continue):
365+
should_continue = await should_continue
366+
367+
if should_continue:
368+
continue
369+
else:
370+
break
371+
372+
sender.close()
373+
"#;
374+
375+
/// Convert an async generator into a stream
376+
pub fn into_stream_v2<'p, R>(gen: &'p PyAny) -> PyResult<impl Stream<Item = PyObject> + 'static>
377+
where
378+
R: Runtime,
379+
{
380+
static GLUE_MOD: OnceCell<PyObject> = OnceCell::new();
381+
let py = gen.py();
382+
let glue = GLUE_MOD
383+
.get_or_try_init(|| -> PyResult<PyObject> {
384+
Ok(PyModule::from_code(
385+
py,
386+
STREAM_GLUE,
387+
"pyo3_asyncio/pyo3_asyncio_glue.py",
388+
"pyo3_asyncio_glue",
389+
)?
390+
.into())
391+
})?
392+
.as_ref(py);
393+
394+
let (tx, rx) = mpsc::channel(10);
395+
396+
crate::get_event_loop(py).call_method1(
397+
"call_soon_threadsafe",
398+
(
399+
crate::get_event_loop(py).getattr("create_task")?,
400+
glue.call_method1(
401+
"forward",
402+
(
403+
gen,
404+
SenderGlue {
405+
tx: Box::new(GenericSender {
406+
runtime: PhantomData::<R>,
407+
tx,
408+
}),
409+
},
410+
),
411+
)?,
412+
),
413+
)?;
414+
Ok(rx)
415+
}

0 commit comments

Comments
 (0)