Skip to content

Commit 58baa05

Browse files
authored
Merge pull request #24 from awestlake87/local-tasks
Add support for spawn_local
2 parents e78aaa8 + 2ad589d commit 58baa05

File tree

5 files changed

+288
-3
lines changed

5 files changed

+288
-3
lines changed

pytests/test_async_std_asyncio.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
mod common;
22

3-
use std::time::Duration;
3+
use std::{rc::Rc, time::Duration};
44

55
use async_std::task;
66
use pyo3::{prelude::*, wrap_pyfunction};
@@ -78,3 +78,20 @@ fn test_init_twice() -> PyResult<()> {
7878
async fn main() -> pyo3::PyResult<()> {
7979
pyo3_asyncio::testing::main().await
8080
}
81+
82+
#[pyo3_asyncio::async_std::test]
83+
async fn test_local_coroutine() -> PyResult<()> {
84+
Python::with_gil(|py| {
85+
let non_send_secs = Rc::new(1);
86+
87+
let py_future = pyo3_asyncio::async_std::into_local_py_future(py, async move {
88+
async_std::task::sleep(Duration::from_secs(*non_send_secs)).await;
89+
Ok(Python::with_gil(|py| py.None()))
90+
})?;
91+
92+
pyo3_asyncio::into_future(py_future.as_ref(py))
93+
})?
94+
.await?;
95+
96+
Ok(())
97+
}

pytests/tokio_asyncio/mod.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::time::Duration;
1+
use std::{rc::Rc, time::Duration};
22

33
use pyo3::{prelude::*, wrap_pyfunction};
44

@@ -82,3 +82,22 @@ fn test_init_tokio_twice() -> PyResult<()> {
8282

8383
Ok(())
8484
}
85+
86+
#[pyo3_asyncio::tokio::test]
87+
fn test_local_set_coroutine() -> PyResult<()> {
88+
tokio::task::LocalSet::new().block_on(pyo3_asyncio::tokio::get_runtime(), async {
89+
Python::with_gil(|py| {
90+
let non_send_secs = Rc::new(1);
91+
92+
let py_future = pyo3_asyncio::tokio::into_local_py_future(py, async move {
93+
tokio::time::sleep(Duration::from_secs(*non_send_secs)).await;
94+
Ok(Python::with_gil(|py| py.None()))
95+
})?;
96+
97+
pyo3_asyncio::into_future(py_future.as_ref(py))
98+
})?
99+
.await?;
100+
101+
Ok(())
102+
})
103+
}

src/async_std.rs

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::future::Future;
33
use async_std::task;
44
use pyo3::prelude::*;
55

6-
use crate::generic::{self, JoinError, Runtime};
6+
use crate::generic::{self, JoinError, Runtime, SpawnLocalExt};
77

88
/// <span class="module-item stab portability" style="display: inline; border-radius: 3px; padding: 2px; font-size: 80%; line-height: 1.2;"><code>attributes</code></span>
99
/// re-exports for macros
@@ -48,6 +48,18 @@ impl Runtime for AsyncStdRuntime {
4848
}
4949
}
5050

51+
impl SpawnLocalExt for AsyncStdRuntime {
52+
fn spawn_local<F>(fut: F) -> Self::JoinHandle
53+
where
54+
F: Future<Output = ()> + 'static,
55+
{
56+
task::spawn_local(async move {
57+
fut.await;
58+
Ok(())
59+
})
60+
}
61+
}
62+
5163
/// Run the event loop until the given Future completes
5264
///
5365
/// The event loop runs until the given future is complete.
@@ -117,3 +129,49 @@ where
117129
{
118130
generic::into_coroutine::<AsyncStdRuntime, _>(py, fut)
119131
}
132+
133+
/// Convert a `!Send` Rust Future into a Python coroutine
134+
///
135+
/// # Arguments
136+
/// * `py` - The current PyO3 GIL guard
137+
/// * `fut` - The Rust future to be converted
138+
///
139+
/// # Examples
140+
///
141+
/// ```
142+
/// use std::{rc::Rc, time::Duration};
143+
///
144+
/// use pyo3::prelude::*;
145+
///
146+
/// /// Awaitable non-send sleep function
147+
/// #[pyfunction]
148+
/// fn sleep_for(py: Python, secs: u64) -> PyResult<PyObject> {
149+
/// // Rc is non-send so it cannot be passed into pyo3_asyncio::tokio::into_coroutine
150+
/// let secs = Rc::new(secs);
151+
///
152+
/// pyo3_asyncio::async_std::into_local_py_future(py, async move {
153+
/// async_std::task::sleep(Duration::from_secs(*secs)).await;
154+
/// Python::with_gil(|py| Ok(py.None()))
155+
/// })
156+
/// }
157+
///
158+
/// # #[cfg(all(feature = "async-std-runtime", feature = "attributes"))]
159+
/// #[pyo3_asyncio::async_std::main]
160+
/// async fn main() -> PyResult<()> {
161+
/// Python::with_gil(|py| {
162+
/// let py_future = sleep_for(py, 1)?;
163+
/// pyo3_asyncio::into_future(py_future.as_ref(py))
164+
/// })?
165+
/// .await?;
166+
///
167+
/// Ok(())
168+
/// }
169+
/// # #[cfg(not(all(feature = "async-std-runtime", feature = "attributes")))]
170+
/// # fn main() {}
171+
/// ```
172+
pub fn into_local_py_future<F>(py: Python, fut: F) -> PyResult<PyObject>
173+
where
174+
F: Future<Output = PyResult<PyObject>> + 'static,
175+
{
176+
generic::into_local_py_future::<AsyncStdRuntime, _>(py, fut)
177+
}

src/generic.rs

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ pub trait Runtime {
2323
F: Future<Output = ()> + Send + 'static;
2424
}
2525

26+
/// Extension trait for async/await runtimes that support spawning local tasks
27+
pub trait SpawnLocalExt: Runtime {
28+
/// Spawn a !Send future onto this runtime's event loop
29+
fn spawn_local<F>(fut: F) -> Self::JoinHandle
30+
where
31+
F: Future<Output = ()> + 'static;
32+
}
33+
2634
/// Run the event loop until the given Future completes
2735
///
2836
/// After this function returns, the event loop can be resumed with either [`run_until_complete`] or
@@ -236,3 +244,123 @@ where
236244

237245
Ok(future_rx)
238246
}
247+
248+
/// Convert a `!Send` Rust Future into a Python coroutine with a generic runtime
249+
///
250+
/// # Arguments
251+
/// * `py` - The current PyO3 GIL guard
252+
/// * `fut` - The Rust future to be converted
253+
///
254+
/// # Examples
255+
///
256+
/// ```no_run
257+
/// # use std::{task::{Context, Poll}, pin::Pin, future::Future};
258+
/// #
259+
/// # use pyo3_asyncio::generic::{JoinError, SpawnLocalExt, Runtime};
260+
/// #
261+
/// # struct MyCustomJoinError;
262+
/// #
263+
/// # impl JoinError for MyCustomJoinError {
264+
/// # fn is_panic(&self) -> bool {
265+
/// # unreachable!()
266+
/// # }
267+
/// # }
268+
/// #
269+
/// # struct MyCustomJoinHandle;
270+
/// #
271+
/// # impl Future for MyCustomJoinHandle {
272+
/// # type Output = Result<(), MyCustomJoinError>;
273+
/// #
274+
/// # fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
275+
/// # unreachable!()
276+
/// # }
277+
/// # }
278+
/// #
279+
/// # struct MyCustomRuntime;
280+
/// #
281+
/// # impl MyCustomRuntime {
282+
/// # async fn sleep(_: Duration) {
283+
/// # unreachable!()
284+
/// # }
285+
/// # }
286+
/// #
287+
/// # impl Runtime for MyCustomRuntime {
288+
/// # type JoinError = MyCustomJoinError;
289+
/// # type JoinHandle = MyCustomJoinHandle;
290+
/// #
291+
/// # fn spawn<F>(fut: F) -> Self::JoinHandle
292+
/// # where
293+
/// # F: Future<Output = ()> + Send + 'static
294+
/// # {
295+
/// # unreachable!()
296+
/// # }
297+
/// # }
298+
/// #
299+
/// # impl SpawnLocalExt for MyCustomRuntime {
300+
/// # fn spawn_local<F>(fut: F) -> Self::JoinHandle
301+
/// # where
302+
/// # F: Future<Output = ()> + 'static
303+
/// # {
304+
/// # unreachable!()
305+
/// # }
306+
/// # }
307+
/// #
308+
/// use std::time::Duration;
309+
///
310+
/// use pyo3::prelude::*;
311+
///
312+
/// /// Awaitable sleep function
313+
/// #[pyfunction]
314+
/// fn sleep_for(py: Python, secs: &PyAny) -> PyResult<PyObject> {
315+
/// let secs = secs.extract()?;
316+
///
317+
/// pyo3_asyncio::generic::into_local_py_future::<MyCustomRuntime, _>(py, async move {
318+
/// MyCustomRuntime::sleep(Duration::from_secs(secs)).await;
319+
/// Python::with_gil(|py| Ok(py.None()))
320+
/// })
321+
/// }
322+
/// ```
323+
pub fn into_local_py_future<R, F>(py: Python, fut: F) -> PyResult<PyObject>
324+
where
325+
R: SpawnLocalExt,
326+
F: Future<Output = PyResult<PyObject>> + 'static,
327+
{
328+
let future_rx = CREATE_FUTURE.get().expect(EXPECT_INIT).call0(py)?;
329+
let future_tx1 = future_rx.clone();
330+
let future_tx2 = future_rx.clone();
331+
332+
R::spawn_local(async move {
333+
if let Err(e) = R::spawn_local(async move {
334+
let result = fut.await;
335+
336+
Python::with_gil(move |py| {
337+
if set_result(py, future_tx1.as_ref(py), result)
338+
.map_err(dump_err(py))
339+
.is_err()
340+
{
341+
342+
// Cancelled
343+
}
344+
});
345+
})
346+
.await
347+
{
348+
if e.is_panic() {
349+
Python::with_gil(move |py| {
350+
if set_result(
351+
py,
352+
future_tx2.as_ref(py),
353+
Err(PyException::new_err("rust future panicked")),
354+
)
355+
.map_err(dump_err(py))
356+
.is_err()
357+
{
358+
// Cancelled
359+
}
360+
});
361+
}
362+
}
363+
});
364+
365+
Ok(future_rx)
366+
}

src/tokio.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,15 @@ impl generic::Runtime for TokioRuntime {
5656
}
5757
}
5858

59+
impl generic::SpawnLocalExt for TokioRuntime {
60+
fn spawn_local<F>(fut: F) -> Self::JoinHandle
61+
where
62+
F: Future<Output = ()> + 'static,
63+
{
64+
tokio::task::spawn_local(fut)
65+
}
66+
}
67+
5968
/// Initialize the Tokio Runtime with a custom build
6069
pub fn init(runtime: Runtime) {
6170
TOKIO_RUNTIME
@@ -207,3 +216,57 @@ where
207216
{
208217
generic::into_coroutine::<TokioRuntime, _>(py, fut)
209218
}
219+
220+
/// Convert a `!Send` Rust Future into a Python coroutine
221+
///
222+
/// # Arguments
223+
/// * `py` - The current PyO3 GIL guard
224+
/// * `fut` - The Rust future to be converted
225+
///
226+
/// # Examples
227+
///
228+
/// ```
229+
/// use std::{rc::Rc, time::Duration};
230+
///
231+
/// use pyo3::prelude::*;
232+
///
233+
/// /// Awaitable non-send sleep function
234+
/// #[pyfunction]
235+
/// fn sleep_for(py: Python, secs: u64) -> PyResult<PyObject> {
236+
/// // Rc is non-send so it cannot be passed into pyo3_asyncio::tokio::into_coroutine
237+
/// let secs = Rc::new(secs);
238+
///
239+
/// pyo3_asyncio::tokio::into_local_py_future(py, async move {
240+
/// tokio::time::sleep(Duration::from_secs(*secs)).await;
241+
/// Python::with_gil(|py| Ok(py.None()))
242+
/// })
243+
/// }
244+
///
245+
/// # #[cfg(all(feature = "tokio-runtime", feature = "attributes"))]
246+
/// #[pyo3_asyncio::tokio::main]
247+
/// async fn main() -> PyResult<()> {
248+
/// // the main coroutine is running in a Send context, so we cannot use LocalSet here. Instead
249+
/// // we use spawn_blocking in order to use LocalSet::block_on
250+
/// tokio::task::spawn_blocking(|| {
251+
/// // LocalSet allows us to work with !Send futures within tokio. Without it, any calls to
252+
/// // pyo3_asyncio::tokio::into_local_py_future will panic.
253+
/// tokio::task::LocalSet::new().block_on(pyo3_asyncio::tokio::get_runtime(), async {
254+
/// Python::with_gil(|py| {
255+
/// let py_future = sleep_for(py, 1)?;
256+
/// pyo3_asyncio::into_future(py_future.as_ref(py))
257+
/// })?
258+
/// .await?;
259+
///
260+
/// Ok(())
261+
/// })
262+
/// }).await.unwrap()
263+
/// }
264+
/// # #[cfg(not(all(feature = "tokio-runtime", feature = "attributes")))]
265+
/// # fn main() {}
266+
/// ```
267+
pub fn into_local_py_future<F>(py: Python, fut: F) -> PyResult<PyObject>
268+
where
269+
F: Future<Output = PyResult<PyObject>> + 'static,
270+
{
271+
generic::into_local_py_future::<TokioRuntime, _>(py, fut)
272+
}

0 commit comments

Comments
 (0)