Skip to content

Commit 640ad44

Browse files
authored
Merge pull request #11 from awestlake87/bugfix-support-all-awaitables
Bugfix - Missing support for asyncio.Future and asyncio.Task in into_future fn
2 parents aa6fe06 + 67ca64d commit 640ad44

File tree

5 files changed

+69
-11
lines changed

5 files changed

+69
-11
lines changed

pytests/common/mod.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,25 @@ pub(super) fn test_blocking_sleep() -> PyResult<()> {
2929
thread::sleep(Duration::from_secs(1));
3030
Ok(())
3131
}
32+
33+
pub(super) async fn test_other_awaitables() -> PyResult<()> {
34+
let fut = Python::with_gil(|py| {
35+
let functools = py.import("functools")?;
36+
let time = py.import("time")?;
37+
38+
// spawn a blocking sleep in the threadpool executor - returns a task, not a coroutine
39+
let task = pyo3_asyncio::get_event_loop(py).call_method1(
40+
"run_in_executor",
41+
(
42+
py.None(),
43+
functools.call_method1("partial", (time.getattr("sleep")?, 1))?,
44+
),
45+
)?;
46+
47+
pyo3_asyncio::into_future(task)
48+
})?;
49+
50+
fut.await?;
51+
52+
Ok(())
53+
}

pytests/test_async_std_asyncio.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ async fn test_into_future() -> PyResult<()> {
6464
common::test_into_future().await
6565
}
6666

67+
#[pyo3_asyncio::async_std::test]
68+
async fn test_other_awaitables() -> PyResult<()> {
69+
common::test_other_awaitables().await
70+
}
71+
6772
#[pyo3_asyncio::async_std::main]
6873
async fn main() -> pyo3::PyResult<()> {
6974
pyo3_asyncio::testing::main().await

pytests/tokio_asyncio/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,8 @@ fn test_blocking_sleep() -> PyResult<()> {
6262
async fn test_into_future() -> PyResult<()> {
6363
common::test_into_future().await
6464
}
65+
66+
#[pyo3_asyncio::tokio::test]
67+
async fn test_other_awaitables() -> PyResult<()> {
68+
common::test_other_awaitables().await
69+
}

src/lib.rs

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,16 @@ pub mod doc_test {
141141
const EXPECT_INIT: &str = "PyO3 Asyncio has not been initialized";
142142

143143
static ASYNCIO: OnceCell<PyObject> = OnceCell::new();
144+
static ENSURE_FUTURE: OnceCell<PyObject> = OnceCell::new();
144145
static EVENT_LOOP: OnceCell<PyObject> = OnceCell::new();
145146
static EXECUTOR: OnceCell<PyObject> = OnceCell::new();
146147
static CALL_SOON: OnceCell<PyObject> = OnceCell::new();
147-
static CREATE_TASK: OnceCell<PyObject> = OnceCell::new();
148148
static CREATE_FUTURE: OnceCell<PyObject> = OnceCell::new();
149149

150+
fn ensure_future(py: Python) -> &PyAny {
151+
ENSURE_FUTURE.get().expect(EXPECT_INIT).as_ref(py)
152+
}
153+
150154
#[allow(clippy::needless_doctest_main)]
151155
/// Wraps the provided function with the initialization and finalization for PyO3 Asyncio
152156
///
@@ -192,6 +196,9 @@ where
192196
/// Must be called at the start of your program
193197
fn try_init(py: Python) -> PyResult<()> {
194198
let asyncio = py.import("asyncio")?;
199+
200+
let ensure_future = asyncio.getattr("ensure_future")?;
201+
195202
let event_loop = asyncio.call_method0("get_event_loop")?;
196203
let executor = py
197204
.import("concurrent.futures.thread")?
@@ -201,14 +208,13 @@ fn try_init(py: Python) -> PyResult<()> {
201208
event_loop.call_method1("set_default_executor", (executor,))?;
202209

203210
let call_soon = event_loop.getattr("call_soon_threadsafe")?;
204-
let create_task = asyncio.getattr("run_coroutine_threadsafe")?;
205211
let create_future = event_loop.getattr("create_future")?;
206212

207213
ASYNCIO.get_or_init(|| asyncio.into());
214+
ENSURE_FUTURE.get_or_init(|| ensure_future.into());
208215
EVENT_LOOP.get_or_init(|| event_loop.into());
209216
EXECUTOR.get_or_init(|| executor.into());
210217
CALL_SOON.get_or_init(|| call_soon.into());
211-
CREATE_TASK.get_or_init(|| create_task.into());
212218
CREATE_FUTURE.get_or_init(|| create_future.into());
213219

214220
Ok(())
@@ -321,6 +327,26 @@ impl PyTaskCompleter {
321327
}
322328
}
323329

330+
#[pyclass]
331+
struct PyEnsureFuture {
332+
awaitable: PyObject,
333+
tx: Option<oneshot::Sender<PyResult<PyObject>>>,
334+
}
335+
336+
#[pymethods]
337+
impl PyEnsureFuture {
338+
#[call]
339+
pub fn __call__(&mut self) -> PyResult<()> {
340+
Python::with_gil(|py| {
341+
let task = ensure_future(py).call1((self.awaitable.as_ref(py),))?;
342+
let on_complete = PyTaskCompleter { tx: self.tx.take() };
343+
task.call_method1("add_done_callback", (on_complete,))?;
344+
345+
Ok(())
346+
})
347+
}
348+
}
349+
324350
/// Convert a Python `awaitable` into a Rust Future
325351
///
326352
/// This function converts the `awaitable` into a Python Task using `run_coroutine_threadsafe`. A
@@ -373,13 +399,13 @@ pub fn into_future(awaitable: &PyAny) -> PyResult<impl Future<Output = PyResult<
373399
let py = awaitable.py();
374400
let (tx, rx) = oneshot::channel();
375401

376-
let task = CREATE_TASK
377-
.get()
378-
.expect(EXPECT_INIT)
379-
.call1(py, (awaitable, get_event_loop(py)))?;
380-
let on_complete = PyTaskCompleter { tx: Some(tx) };
381-
382-
task.call_method1(py, "add_done_callback", (on_complete,))?;
402+
CALL_SOON.get().expect(EXPECT_INIT).call1(
403+
py,
404+
(PyEnsureFuture {
405+
awaitable: awaitable.into(),
406+
tx: Some(tx),
407+
},),
408+
)?;
383409

384410
Ok(async move {
385411
match rx.await {

src/tokio.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ pub fn init_multi_thread() {
8888
Builder::new_multi_thread()
8989
.enable_all()
9090
.build()
91-
.expect("Couldn't build the current-thread Tokio runtime"),
91+
.expect("Couldn't build the multi-thread Tokio runtime"),
9292
);
9393
}
9494

0 commit comments

Comments
 (0)