@@ -141,12 +141,16 @@ pub mod doc_test {
141
141
const EXPECT_INIT : & str = "PyO3 Asyncio has not been initialized" ;
142
142
143
143
static ASYNCIO : OnceCell < PyObject > = OnceCell :: new ( ) ;
144
+ static ENSURE_FUTURE : OnceCell < PyObject > = OnceCell :: new ( ) ;
144
145
static EVENT_LOOP : OnceCell < PyObject > = OnceCell :: new ( ) ;
145
146
static EXECUTOR : OnceCell < PyObject > = OnceCell :: new ( ) ;
146
147
static CALL_SOON : OnceCell < PyObject > = OnceCell :: new ( ) ;
147
- static CREATE_TASK : OnceCell < PyObject > = OnceCell :: new ( ) ;
148
148
static CREATE_FUTURE : OnceCell < PyObject > = OnceCell :: new ( ) ;
149
149
150
+ fn ensure_future ( py : Python ) -> & PyAny {
151
+ ENSURE_FUTURE . get ( ) . expect ( EXPECT_INIT ) . as_ref ( py)
152
+ }
153
+
150
154
#[ allow( clippy:: needless_doctest_main) ]
151
155
/// Wraps the provided function with the initialization and finalization for PyO3 Asyncio
152
156
///
@@ -192,6 +196,9 @@ where
192
196
/// Must be called at the start of your program
193
197
fn try_init ( py : Python ) -> PyResult < ( ) > {
194
198
let asyncio = py. import ( "asyncio" ) ?;
199
+
200
+ let ensure_future = asyncio. getattr ( "ensure_future" ) ?;
201
+
195
202
let event_loop = asyncio. call_method0 ( "get_event_loop" ) ?;
196
203
let executor = py
197
204
. import ( "concurrent.futures.thread" ) ?
@@ -201,14 +208,13 @@ fn try_init(py: Python) -> PyResult<()> {
201
208
event_loop. call_method1 ( "set_default_executor" , ( executor, ) ) ?;
202
209
203
210
let call_soon = event_loop. getattr ( "call_soon_threadsafe" ) ?;
204
- let create_task = asyncio. getattr ( "run_coroutine_threadsafe" ) ?;
205
211
let create_future = event_loop. getattr ( "create_future" ) ?;
206
212
207
213
ASYNCIO . get_or_init ( || asyncio. into ( ) ) ;
214
+ ENSURE_FUTURE . get_or_init ( || ensure_future. into ( ) ) ;
208
215
EVENT_LOOP . get_or_init ( || event_loop. into ( ) ) ;
209
216
EXECUTOR . get_or_init ( || executor. into ( ) ) ;
210
217
CALL_SOON . get_or_init ( || call_soon. into ( ) ) ;
211
- CREATE_TASK . get_or_init ( || create_task. into ( ) ) ;
212
218
CREATE_FUTURE . get_or_init ( || create_future. into ( ) ) ;
213
219
214
220
Ok ( ( ) )
@@ -321,6 +327,26 @@ impl PyTaskCompleter {
321
327
}
322
328
}
323
329
330
+ #[ pyclass]
331
+ struct PyEnsureFuture {
332
+ awaitable : PyObject ,
333
+ tx : Option < oneshot:: Sender < PyResult < PyObject > > > ,
334
+ }
335
+
336
+ #[ pymethods]
337
+ impl PyEnsureFuture {
338
+ #[ call]
339
+ pub fn __call__ ( & mut self ) -> PyResult < ( ) > {
340
+ Python :: with_gil ( |py| {
341
+ let task = ensure_future ( py) . call1 ( ( self . awaitable . as_ref ( py) , ) ) ?;
342
+ let on_complete = PyTaskCompleter { tx : self . tx . take ( ) } ;
343
+ task. call_method1 ( "add_done_callback" , ( on_complete, ) ) ?;
344
+
345
+ Ok ( ( ) )
346
+ } )
347
+ }
348
+ }
349
+
324
350
/// Convert a Python `awaitable` into a Rust Future
325
351
///
326
352
/// This function converts the `awaitable` into a Python Task using `run_coroutine_threadsafe`. A
@@ -373,13 +399,13 @@ pub fn into_future(awaitable: &PyAny) -> PyResult<impl Future<Output = PyResult<
373
399
let py = awaitable. py ( ) ;
374
400
let ( tx, rx) = oneshot:: channel ( ) ;
375
401
376
- let task = CREATE_TASK
377
- . get ( )
378
- . expect ( EXPECT_INIT )
379
- . call1 ( py , ( awaitable, get_event_loop ( py ) ) ) ? ;
380
- let on_complete = PyTaskCompleter { tx : Some ( tx) } ;
381
-
382
- task . call_method1 ( py , "add_done_callback" , ( on_complete , ) ) ?;
402
+ CALL_SOON . get ( ) . expect ( EXPECT_INIT ) . call1 (
403
+ py ,
404
+ ( PyEnsureFuture {
405
+ awaitable : awaitable . into ( ) ,
406
+ tx : Some ( tx) ,
407
+ } , ) ,
408
+ ) ?;
383
409
384
410
Ok ( async move {
385
411
match rx. await {
0 commit comments