Skip to content

Commit 4fffe01

Browse files
author
Andrew J Westlake
committed
Minor naming change, fixed some docs / generic impls, made async-std scope remember old value (actually acts as a scope now)
1 parent 8de1197 commit 4fffe01

File tree

7 files changed

+249
-62
lines changed

7 files changed

+249
-62
lines changed

pyo3-asyncio-macros/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ pub fn async_std_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
153153
} else {
154154
quote! {
155155
let event_loop = Python::with_gil(|py| {
156-
pyo3_asyncio::async_std::task_event_loop(py).unwrap().into()
156+
pyo3_asyncio::async_std::get_current_loop(py).unwrap().into()
157157
});
158158
Box::pin(pyo3_asyncio::async_std::re_exports::spawn_blocking(move || {
159159
#name(event_loop)
@@ -258,7 +258,7 @@ pub fn tokio_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
258258
} else {
259259
quote! {
260260
let event_loop = Python::with_gil(|py| {
261-
pyo3_asyncio::tokio::task_event_loop(py).unwrap().into()
261+
pyo3_asyncio::tokio::get_current_loop(py).unwrap().into()
262262
});
263263
Box::pin(async move {
264264
match pyo3_asyncio::tokio::get_runtime().spawn_blocking(move || #name(event_loop)).await {

pytests/test_async_std_asyncio.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,9 @@ fn test_blocking_sleep() -> PyResult<()> {
9898
#[pyo3_asyncio::async_std::test]
9999
async fn test_into_future() -> PyResult<()> {
100100
common::test_into_future(Python::with_gil(|py| {
101-
pyo3_asyncio::async_std::task_event_loop(py).unwrap().into()
101+
pyo3_asyncio::async_std::get_current_loop(py)
102+
.unwrap()
103+
.into()
102104
}))
103105
.await
104106
}
@@ -111,7 +113,9 @@ async fn test_into_future_0_13() -> PyResult<()> {
111113
#[pyo3_asyncio::async_std::test]
112114
async fn test_other_awaitables() -> PyResult<()> {
113115
common::test_other_awaitables(Python::with_gil(|py| {
114-
pyo3_asyncio::async_std::task_event_loop(py).unwrap().into()
116+
pyo3_asyncio::async_std::get_current_loop(py)
117+
.unwrap()
118+
.into()
115119
}))
116120
.await
117121
}

pytests/tokio_asyncio/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ fn test_blocking_sleep() -> PyResult<()> {
9797
#[pyo3_asyncio::tokio::test]
9898
async fn test_into_future() -> PyResult<()> {
9999
common::test_into_future(Python::with_gil(|py| {
100-
pyo3_asyncio::tokio::task_event_loop(py).unwrap().into()
100+
pyo3_asyncio::tokio::get_current_loop(py).unwrap().into()
101101
}))
102102
.await
103103
}
@@ -110,7 +110,7 @@ async fn test_into_future_0_13() -> PyResult<()> {
110110
#[pyo3_asyncio::tokio::test]
111111
async fn test_other_awaitables() -> PyResult<()> {
112112
common::test_other_awaitables(Python::with_gil(|py| {
113-
pyo3_asyncio::tokio::task_event_loop(py).unwrap().into()
113+
pyo3_asyncio::tokio::get_current_loop(py).unwrap().into()
114114
}))
115115
.await
116116
}

src/async_std.rs

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
1-
use std::{any::Any, future::Future, panic::AssertUnwindSafe, pin::Pin};
1+
use std::{any::Any, cell::RefCell, future::Future, panic::AssertUnwindSafe, pin::Pin};
22

33
use async_std::task;
44
use futures::prelude::*;
5-
use once_cell::unsync::OnceCell;
6-
use pyo3::{prelude::*, PyNativeType};
5+
use pyo3::prelude::*;
76

8-
use crate::{
9-
generic::{self, JoinError, Runtime, SpawnLocalExt},
10-
into_future_with_loop,
11-
};
7+
use crate::generic::{self, JoinError, Runtime, SpawnLocalExt};
128

139
/// <span class="module-item stab portability" style="display: inline; border-radius: 3px; padding: 2px; font-size: 80%; line-height: 1.2;"><code>attributes</code></span>
1410
/// re-exports for macros
@@ -37,7 +33,7 @@ impl JoinError for AsyncStdJoinErr {
3733
}
3834

3935
async_std::task_local! {
40-
static EVENT_LOOP: OnceCell<PyObject> = OnceCell::new()
36+
static EVENT_LOOP: RefCell<Option<PyObject>> = RefCell::new(None);
4137
}
4238

4339
struct AsyncStdRuntime;
@@ -50,11 +46,19 @@ impl Runtime for AsyncStdRuntime {
5046
where
5147
F: Future<Output = R> + Send + 'static,
5248
{
53-
EVENT_LOOP.with(|c| c.set(event_loop).unwrap());
54-
Box::pin(fut)
49+
let old = EVENT_LOOP.with(|c| c.replace(Some(event_loop)));
50+
Box::pin(async move {
51+
let result = fut.await;
52+
EVENT_LOOP.with(|c| c.replace(old));
53+
result
54+
})
5555
}
5656
fn get_task_event_loop(py: Python) -> Option<&PyAny> {
57-
match EVENT_LOOP.try_with(|c| c.get().map(|event_loop| event_loop.clone().into_ref(py))) {
57+
match EVENT_LOOP.try_with(|c| {
58+
c.borrow()
59+
.as_ref()
60+
.map(|event_loop| event_loop.clone().into_ref(py))
61+
}) {
5862
Ok(event_loop) => event_loop,
5963
Err(_) => None,
6064
}
@@ -78,8 +82,12 @@ impl SpawnLocalExt for AsyncStdRuntime {
7882
where
7983
F: Future<Output = R> + 'static,
8084
{
81-
EVENT_LOOP.with(|c| c.set(event_loop).unwrap());
82-
Box::pin(fut)
85+
let old = EVENT_LOOP.with(|c| c.replace(Some(event_loop)));
86+
Box::pin(async move {
87+
let result = fut.await;
88+
EVENT_LOOP.with(|c| c.replace(old));
89+
result
90+
})
8391
}
8492

8593
fn spawn_local<F>(fut: F) -> Self::JoinHandle
@@ -110,21 +118,20 @@ where
110118
}
111119

112120
/// Get the current event loop from either Python or Rust async task local context
121+
///
122+
/// This function first checks if the runtime has a task-local reference to the Python event loop.
123+
/// If not, it calls [`get_running_loop`](`crate::get_running_loop`) to get the event loop
124+
/// associated with the current OS thread.
113125
pub fn get_current_loop(py: Python) -> PyResult<&PyAny> {
114126
generic::get_current_loop::<AsyncStdRuntime>(py)
115127
}
116128

117-
/// Get the task local event loop for the current async_std task
118-
pub fn task_event_loop(py: Python) -> Option<&PyAny> {
119-
AsyncStdRuntime::get_task_event_loop(py)
120-
}
121-
122129
/// Run the event loop until the given Future completes
123130
///
124131
/// The event loop runs until the given future is complete.
125132
///
126133
/// After this function returns, the event loop can be resumed with either [`run_until_complete`] or
127-
/// [`crate::run_forever`]
134+
/// [`run_forever`](`crate::run_forever`)
128135
///
129136
/// # Arguments
130137
/// * `py` - The current PyO3 GIL guard
@@ -227,6 +234,39 @@ where
227234
generic::into_coroutine::<AsyncStdRuntime, _>(py, fut)
228235
}
229236

237+
/// Convert a Rust Future into a Python awaitable
238+
///
239+
/// # Arguments
240+
/// * `event_loop` - The Python event loop that the awaitable should be attached to
241+
/// * `fut` - The Rust future to be converted
242+
///
243+
/// # Examples
244+
///
245+
/// ```
246+
/// use std::time::Duration;
247+
///
248+
/// use pyo3::prelude::*;
249+
///
250+
/// /// Awaitable sleep function
251+
/// #[pyfunction]
252+
/// fn sleep_for<'p>(py: Python<'p>, secs: &'p PyAny) -> PyResult<&'p PyAny> {
253+
/// let secs = secs.extract()?;
254+
/// pyo3_asyncio::async_std::future_into_py_with_loop(
255+
/// pyo3_asyncio::async_std::get_current_loop(py)?,
256+
/// async move {
257+
/// async_std::task::sleep(Duration::from_secs(secs)).await;
258+
/// Python::with_gil(|py| Ok(py.None()))
259+
/// }
260+
/// )
261+
/// }
262+
/// ```
263+
pub fn future_into_py_with_loop<F>(event_loop: &PyAny, fut: F) -> PyResult<&PyAny>
264+
where
265+
F: Future<Output = PyResult<PyObject>> + Send + 'static,
266+
{
267+
generic::future_into_py_with_loop::<AsyncStdRuntime, F>(event_loop, fut)
268+
}
269+
230270
/// Convert a Rust Future into a Python awaitable
231271
///
232272
/// # Arguments
@@ -260,7 +300,7 @@ where
260300
/// Convert a `!Send` Rust Future into a Python awaitable
261301
///
262302
/// # Arguments
263-
/// * `py` - The current PyO3 GIL guard
303+
/// * `event_loop` - The Python event loop that the awaitable should be attached to
264304
/// * `fut` - The Rust future to be converted
265305
///
266306
/// # Examples
@@ -273,7 +313,7 @@ where
273313
/// /// Awaitable non-send sleep function
274314
/// #[pyfunction]
275315
/// fn sleep_for(py: Python, secs: u64) -> PyResult<&PyAny> {
276-
/// // Rc is non-send so it cannot be passed into pyo3_asyncio::async_std::into_coroutine
316+
/// // Rc is non-send so it cannot be passed into pyo3_asyncio::async_std::future_into_py
277317
/// let secs = Rc::new(secs);
278318
/// Ok(pyo3_asyncio::async_std::local_future_into_py_with_loop(
279319
/// pyo3_asyncio::async_std::get_current_loop(py)?,
@@ -321,7 +361,7 @@ where
321361
/// /// Awaitable non-send sleep function
322362
/// #[pyfunction]
323363
/// fn sleep_for(py: Python, secs: u64) -> PyResult<&PyAny> {
324-
/// // Rc is non-send so it cannot be passed into pyo3_asyncio::async_std::into_coroutine
364+
/// // Rc is non-send so it cannot be passed into pyo3_asyncio::async_std::future_into_py
325365
/// let secs = Rc::new(secs);
326366
/// pyo3_asyncio::async_std::local_future_into_py(py, async move {
327367
/// async_std::task::sleep(Duration::from_secs(*secs)).await;
@@ -399,5 +439,5 @@ where
399439
/// }
400440
/// ```
401441
pub fn into_future(awaitable: &PyAny) -> PyResult<impl Future<Output = PyResult<PyObject>> + Send> {
402-
into_future_with_loop(get_current_loop(awaitable.py())?, awaitable)
442+
generic::into_future::<AsyncStdRuntime>(awaitable)
403443
}

0 commit comments

Comments
 (0)