Skip to content

Commit c22d51a

Browse files
author
Andrew J Westlake
committed
Wrapped into_future awaitable arg in call to ensure_future in order to support passing asyncio.Future / asyncio.Task objects as well as coroutines
1 parent aa6fe06 commit c22d51a

File tree

1 file changed

+36
-7
lines changed

1 file changed

+36
-7
lines changed

src/lib.rs

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,17 @@ pub mod doc_test {
141141
const EXPECT_INIT: &str = "PyO3 Asyncio has not been initialized";
142142

143143
static ASYNCIO: OnceCell<PyObject> = OnceCell::new();
144+
static ENSURE_FUTURE: OnceCell<PyObject> = OnceCell::new();
144145
static EVENT_LOOP: OnceCell<PyObject> = OnceCell::new();
145146
static EXECUTOR: OnceCell<PyObject> = OnceCell::new();
146147
static CALL_SOON: OnceCell<PyObject> = OnceCell::new();
147148
static CREATE_TASK: OnceCell<PyObject> = OnceCell::new();
148149
static CREATE_FUTURE: OnceCell<PyObject> = OnceCell::new();
149150

151+
fn ensure_future(py: Python) -> &PyAny {
152+
ENSURE_FUTURE.get().expect(EXPECT_INIT).as_ref(py)
153+
}
154+
150155
#[allow(clippy::needless_doctest_main)]
151156
/// Wraps the provided function with the initialization and finalization for PyO3 Asyncio
152157
///
@@ -192,6 +197,9 @@ where
192197
/// Must be called at the start of your program
193198
fn try_init(py: Python) -> PyResult<()> {
194199
let asyncio = py.import("asyncio")?;
200+
201+
let ensure_future = asyncio.getattr("ensure_future")?;
202+
195203
let event_loop = asyncio.call_method0("get_event_loop")?;
196204
let executor = py
197205
.import("concurrent.futures.thread")?
@@ -205,6 +213,7 @@ fn try_init(py: Python) -> PyResult<()> {
205213
let create_future = event_loop.getattr("create_future")?;
206214

207215
ASYNCIO.get_or_init(|| asyncio.into());
216+
ENSURE_FUTURE.get_or_init(|| ensure_future.into());
208217
EVENT_LOOP.get_or_init(|| event_loop.into());
209218
EXECUTOR.get_or_init(|| executor.into());
210219
CALL_SOON.get_or_init(|| call_soon.into());
@@ -321,6 +330,26 @@ impl PyTaskCompleter {
321330
}
322331
}
323332

333+
#[pyclass]
334+
struct PyEnsureFuture {
335+
awaitable: PyObject,
336+
tx: Option<oneshot::Sender<PyResult<PyObject>>>,
337+
}
338+
339+
#[pymethods]
340+
impl PyEnsureFuture {
341+
#[call]
342+
pub fn __call__(&mut self) -> PyResult<()> {
343+
Python::with_gil(|py| {
344+
let task = ensure_future(py).call1((self.awaitable.as_ref(py),))?;
345+
let on_complete = PyTaskCompleter { tx: self.tx.take() };
346+
task.call_method1("add_done_callback", (on_complete,))?;
347+
348+
Ok(())
349+
})
350+
}
351+
}
352+
324353
/// Convert a Python `awaitable` into a Rust Future
325354
///
326355
/// This function converts the `awaitable` into a Python Task using `run_coroutine_threadsafe`. A
@@ -373,13 +402,13 @@ pub fn into_future(awaitable: &PyAny) -> PyResult<impl Future<Output = PyResult<
373402
let py = awaitable.py();
374403
let (tx, rx) = oneshot::channel();
375404

376-
let task = CREATE_TASK
377-
.get()
378-
.expect(EXPECT_INIT)
379-
.call1(py, (awaitable, get_event_loop(py)))?;
380-
let on_complete = PyTaskCompleter { tx: Some(tx) };
381-
382-
task.call_method1(py, "add_done_callback", (on_complete,))?;
405+
CALL_SOON.get().expect(EXPECT_INIT).call1(
406+
py,
407+
(PyEnsureFuture {
408+
awaitable: awaitable.into(),
409+
tx: Some(tx),
410+
},),
411+
)?;
383412

384413
Ok(async move {
385414
match rx.await {

0 commit comments

Comments
 (0)