Skip to content

Commit 4569450

Browse files
authored
Merge pull request #14 from awestlake87/init-changes
Init changes for #13
2 parents 3e4edf2 + 1be4008 commit 4569450

File tree

6 files changed

+108
-41
lines changed

6 files changed

+108
-41
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,6 @@ features = ["unstable"]
9898
optional = true
9999

100100
[dependencies.tokio]
101-
version = "1.0"
101+
version = "1.4"
102102
features = ["full"]
103103
optional = true

pytests/common/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,11 @@ pub(super) async fn test_other_awaitables() -> PyResult<()> {
5151

5252
Ok(())
5353
}
54+
55+
pub(super) fn test_init_twice() -> PyResult<()> {
56+
// try_init has already been called in test main - ensure a second call doesn't mess the other
57+
// tests up
58+
Python::with_gil(|py| pyo3_asyncio::try_init(py))?;
59+
60+
Ok(())
61+
}

pytests/test_async_std_asyncio.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ async fn test_other_awaitables() -> PyResult<()> {
6969
common::test_other_awaitables().await
7070
}
7171

72+
#[pyo3_asyncio::async_std::test]
73+
fn test_init_twice() -> PyResult<()> {
74+
common::test_init_twice()
75+
}
76+
7277
#[pyo3_asyncio::async_std::main]
7378
async fn main() -> pyo3::PyResult<()> {
7479
pyo3_asyncio::testing::main().await

pytests/tokio_asyncio/mod.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,18 @@ async fn test_into_future() -> PyResult<()> {
6767
async fn test_other_awaitables() -> PyResult<()> {
6868
common::test_other_awaitables().await
6969
}
70+
71+
#[pyo3_asyncio::tokio::test]
72+
fn test_init_twice() -> PyResult<()> {
73+
common::test_init_twice()
74+
}
75+
76+
#[pyo3_asyncio::tokio::test]
77+
fn test_init_tokio_twice() -> PyResult<()> {
78+
// tokio has already been initialized in test main. call these functions to
79+
// make sure they don't cause problems with the other tests.
80+
pyo3_asyncio::tokio::init_multi_thread_once();
81+
pyo3_asyncio::tokio::init_current_thread_once();
82+
83+
Ok(())
84+
}

src/lib.rs

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -193,29 +193,29 @@ where
193193

194194
/// Attempt to initialize the Python and Rust event loops
195195
///
196-
/// Must be called at the start of your program
197-
fn try_init(py: Python) -> PyResult<()> {
198-
let asyncio = py.import("asyncio")?;
199-
200-
let ensure_future = asyncio.getattr("ensure_future")?;
201-
202-
let event_loop = asyncio.call_method0("get_event_loop")?;
203-
let executor = py
204-
.import("concurrent.futures.thread")?
205-
.getattr("ThreadPoolExecutor")?
206-
.call0()?;
207-
208-
event_loop.call_method1("set_default_executor", (executor,))?;
209-
210-
let call_soon = event_loop.getattr("call_soon_threadsafe")?;
211-
let create_future = event_loop.getattr("create_future")?;
212-
213-
ASYNCIO.get_or_init(|| asyncio.into());
214-
ENSURE_FUTURE.get_or_init(|| ensure_future.into());
215-
EVENT_LOOP.get_or_init(|| event_loop.into());
216-
EXECUTOR.get_or_init(|| executor.into());
217-
CALL_SOON.get_or_init(|| call_soon.into());
218-
CREATE_FUTURE.get_or_init(|| create_future.into());
196+
/// - Must be called before any other pyo3-asyncio functions.
197+
/// - Calling `try_init` a second time returns `Ok(())` and does nothing.
198+
/// > In future versions this may return an `Err`.
199+
pub fn try_init(py: Python) -> PyResult<()> {
200+
EVENT_LOOP.get_or_try_init(|| -> PyResult<PyObject> {
201+
let asyncio = py.import("asyncio")?;
202+
let ensure_future = asyncio.getattr("ensure_future")?;
203+
let event_loop = asyncio.call_method0("get_event_loop")?;
204+
let executor = py
205+
.import("concurrent.futures.thread")?
206+
.getattr("ThreadPoolExecutor")?
207+
.call0()?;
208+
event_loop.call_method1("set_default_executor", (executor,))?;
209+
let call_soon = event_loop.getattr("call_soon_threadsafe")?;
210+
let create_future = event_loop.getattr("create_future")?;
211+
212+
ASYNCIO.get_or_init(|| asyncio.into());
213+
ENSURE_FUTURE.get_or_init(|| ensure_future.into());
214+
EXECUTOR.get_or_init(|| executor.into());
215+
CALL_SOON.get_or_init(|| call_soon.into());
216+
CREATE_FUTURE.get_or_init(|| create_future.into());
217+
Ok(event_loop.into())
218+
})?;
219219

220220
Ok(())
221221
}
@@ -284,7 +284,7 @@ pub fn run_forever(py: Python) -> PyResult<()> {
284284
}
285285

286286
/// Shutdown the event loops and perform any necessary cleanup
287-
fn try_close(py: Python) -> PyResult<()> {
287+
pub fn try_close(py: Python) -> PyResult<()> {
288288
// Shutdown the executor and wait until all threads are cleaned up
289289
EXECUTOR
290290
.get()

src/tokio.rs

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -63,33 +63,72 @@ pub fn init(runtime: Runtime) {
6363
.expect("Tokio Runtime has already been initialized");
6464
}
6565

66+
fn current_thread() -> Runtime {
67+
Builder::new_current_thread()
68+
.enable_all()
69+
.build()
70+
.expect("Couldn't build the current-thread Tokio runtime")
71+
}
72+
73+
fn start_current_thread() {
74+
thread::spawn(move || {
75+
TOKIO_RUNTIME.get().unwrap().block_on(pending::<()>());
76+
});
77+
}
78+
6679
/// Initialize the Tokio Runtime with current-thread scheduler
80+
///
81+
/// # Panics
82+
/// This function will panic if called a second time. See [`init_current_thread_once`] if you want
83+
/// to avoid this panic.
6784
pub fn init_current_thread() {
68-
init(
69-
Builder::new_current_thread()
70-
.enable_all()
71-
.build()
72-
.expect("Couldn't build the current-thread Tokio runtime"),
73-
);
74-
75-
thread::spawn(|| {
76-
get_runtime().block_on(pending::<()>());
77-
});
85+
init(current_thread());
86+
start_current_thread();
7887
}
7988

8089
/// Get a reference to the current tokio runtime
8190
pub fn get_runtime<'a>() -> &'a Runtime {
8291
TOKIO_RUNTIME.get().expect(EXPECT_TOKIO_INIT)
8392
}
8493

94+
fn multi_thread() -> Runtime {
95+
Builder::new_multi_thread()
96+
.enable_all()
97+
.build()
98+
.expect("Couldn't build the multi-thread Tokio runtime")
99+
}
100+
85101
/// Initialize the Tokio Runtime with the multi-thread scheduler
102+
///
103+
/// # Panics
104+
/// This function will panic if called a second time. See [`init_multi_thread_once`] if you want to
105+
/// avoid this panic.
86106
pub fn init_multi_thread() {
87-
init(
88-
Builder::new_multi_thread()
89-
.enable_all()
90-
.build()
91-
.expect("Couldn't build the multi-thread Tokio runtime"),
92-
);
107+
init(multi_thread());
108+
}
109+
110+
/// Ensure that the Tokio Runtime is initialized
111+
///
112+
/// If the runtime has not been initialized already, the multi-thread scheduler
113+
/// is used. Calling this function a second time is a no-op.
114+
pub fn init_multi_thread_once() {
115+
TOKIO_RUNTIME.get_or_init(|| multi_thread());
116+
}
117+
118+
/// Ensure that the Tokio Runtime is initialized
119+
///
120+
/// If the runtime has not been initialized already, the current-thread
121+
/// scheduler is used. Calling this function a second time is a no-op.
122+
pub fn init_current_thread_once() {
123+
let mut initialized = false;
124+
TOKIO_RUNTIME.get_or_init(|| {
125+
initialized = true;
126+
current_thread()
127+
});
128+
129+
if initialized {
130+
start_current_thread();
131+
}
93132
}
94133

95134
/// Run the event loop until the given Future completes

0 commit comments

Comments
 (0)