Skip to content

Commit 053c785

Browse files
author
Andrew J Westlake
committed
Added support for customizable tokio runtimes
1 parent 31e1984 commit 053c785

File tree

8 files changed

+96
-46
lines changed

8 files changed

+96
-46
lines changed

Cargo.toml

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,26 @@ harness = false
2626
required-features = ["async-std-runtime", "testing"]
2727

2828
[[test]]
29-
name = "test_tokio_asyncio"
30-
path = "pytests/test_tokio_asyncio.rs"
29+
name = "test_tokio_current_thread_asyncio"
30+
path = "pytests/test_tokio_current_thread_asyncio.rs"
3131
harness = false
3232
required-features = ["tokio-runtime", "testing"]
3333

3434
[[test]]
35-
name = "test_tokio_run_forever"
36-
path = "pytests/test_tokio_run_forever.rs"
35+
name = "test_tokio_current_thread_run_forever"
36+
path = "pytests/test_tokio_current_thread_run_forever.rs"
37+
harness = false
38+
required-features = ["tokio-runtime", "testing"]
39+
40+
[[test]]
41+
name = "test_tokio_multi_thread_asyncio"
42+
path = "pytests/test_tokio_multi_thread_asyncio.rs"
43+
harness = false
44+
required-features = ["tokio-runtime", "testing"]
45+
46+
[[test]]
47+
name = "test_tokio_multi_thread_run_forever"
48+
path = "pytests/test_tokio_multi_thread_run_forever.rs"
3749
harness = false
3850
required-features = ["tokio-runtime", "testing"]
3951

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
mod common;
2+
mod tokio_asyncio;
3+
4+
fn main() {
5+
pyo3_asyncio::tokio::init_current_thread();
6+
7+
tokio_asyncio::test_main("PyO3 Asyncio Tokio Current-Thread Test Suite");
8+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
mod tokio_run_forever;
2+
3+
fn main() {
4+
pyo3_asyncio::tokio::init_current_thread();
5+
6+
tokio_run_forever::test_main();
7+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
mod common;
2+
mod tokio_asyncio;
3+
4+
fn main() {
5+
pyo3_asyncio::tokio::init_multi_thread();
6+
7+
tokio_asyncio::test_main("PyO3 Asyncio Tokio Multi-Thread Test Suite");
8+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
mod tokio_run_forever;
2+
3+
fn main() {
4+
pyo3_asyncio::tokio::init_multi_thread();
5+
6+
tokio_run_forever::test_main();
7+
}

pytests/test_tokio_asyncio.rs renamed to pytests/tokio_asyncio/mod.rs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
mod common;
2-
31
use std::{future::Future, time::Duration};
42

53
use pyo3::{prelude::*, wrap_pyfunction};
64

7-
use pyo3_asyncio::{
8-
testing::Test,
9-
tokio::testing::{new_sync_test, test_main},
10-
};
5+
use pyo3_asyncio::{testing::Test, tokio::testing::new_sync_test};
6+
7+
// enforce the inclusion of the common module
8+
use crate::common;
119

1210
#[pyfunction]
1311
fn sleep_for(py: Python, secs: &PyAny) -> PyResult<PyObject> {
@@ -67,9 +65,9 @@ fn test_async_sleep<'p>(
6765
})
6866
}
6967

70-
fn main() {
71-
test_main(
72-
"PyO3 Asyncio Test Suite",
68+
pub(super) fn test_main(suite_name: &str) {
69+
pyo3_asyncio::tokio::testing::test_main(
70+
suite_name,
7371
vec![
7472
Test::new_async(
7573
"test_async_sleep".into(),

pytests/test_tokio_run_forever.rs renamed to pytests/tokio_run_forever/mod.rs

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,6 @@
1-
use std::{future::pending, thread, time::Duration};
1+
use std::time::Duration;
22

3-
use lazy_static::lazy_static;
43
use pyo3::prelude::*;
5-
use tokio::runtime::{Builder, Runtime};
6-
7-
lazy_static! {
8-
static ref CURRENT_THREAD_RUNTIME: Runtime = {
9-
Builder::new_current_thread()
10-
.enable_all()
11-
.build()
12-
.expect("Couldn't build the runtime")
13-
};
14-
}
154

165
fn dump_err(py: Python<'_>) -> impl FnOnce(PyErr) + '_ {
176
move |e| {
@@ -21,14 +10,10 @@ fn dump_err(py: Python<'_>) -> impl FnOnce(PyErr) + '_ {
2110
}
2211
}
2312

24-
fn main() {
25-
thread::spawn(|| {
26-
CURRENT_THREAD_RUNTIME.block_on(pending::<()>());
27-
});
28-
13+
pub(super) fn test_main() {
2914
Python::with_gil(|py| {
3015
pyo3_asyncio::with_runtime(py, || {
31-
CURRENT_THREAD_RUNTIME.spawn(async move {
16+
pyo3_asyncio::tokio::get_runtime().spawn(async move {
3217
tokio::time::sleep(Duration::from_secs(1)).await;
3318

3419
Python::with_gil(|py| {

src/tokio.rs

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,14 @@ use ::tokio::{
77
runtime::{Builder, Runtime},
88
task,
99
};
10-
use lazy_static::lazy_static;
10+
use once_cell::sync::OnceCell;
1111
use pyo3::prelude::*;
1212

1313
use crate::generic;
1414

15-
lazy_static! {
16-
static ref CURRENT_THREAD_RUNTIME: Runtime = {
17-
Builder::new_current_thread()
18-
.enable_all()
19-
.build()
20-
.expect("Couldn't build the runtime")
21-
};
22-
}
15+
static TOKIO_RUNTIME: OnceCell<Runtime> = OnceCell::new();
16+
17+
const EXPECT_TOKIO_INIT: &str = "Tokio runtime must be initialized";
2318

2419
impl generic::JoinError for task::JoinError {
2520
fn is_panic(&self) -> bool {
@@ -37,19 +32,48 @@ impl generic::Runtime for TokioRuntime {
3732
where
3833
F: Future<Output = ()> + Send + 'static,
3934
{
40-
CURRENT_THREAD_RUNTIME.spawn(async move {
35+
get_runtime().spawn(async move {
4136
fut.await;
4237
})
4338
}
4439
}
4540

46-
/// Initialize the Tokio Runtime
47-
pub fn init() {
41+
/// Initialize the Tokio Runtime with a custom build
42+
pub fn init(runtime: Runtime) {
43+
TOKIO_RUNTIME
44+
.set(runtime)
45+
.expect("Tokio Runtime has already been initialized");
46+
}
47+
48+
/// Initialize the Tokio Runtime with current-thread scheduler
49+
pub fn init_current_thread() {
50+
init(
51+
Builder::new_current_thread()
52+
.enable_all()
53+
.build()
54+
.expect("Couldn't build the current-thread Tokio runtime"),
55+
);
56+
4857
thread::spawn(|| {
49-
CURRENT_THREAD_RUNTIME.block_on(pending::<()>());
58+
get_runtime().block_on(pending::<()>());
5059
});
5160
}
5261

62+
/// Get a reference to the current tokio runtime
63+
pub fn get_runtime<'a>() -> &'a Runtime {
64+
TOKIO_RUNTIME.get().expect(EXPECT_TOKIO_INIT)
65+
}
66+
67+
/// Initialize the Tokio Runtime with the multi-thread scheduler
68+
pub fn init_multi_thread() {
69+
init(
70+
Builder::new_multi_thread()
71+
.enable_all()
72+
.build()
73+
.expect("Couldn't build the current-thread Tokio runtime"),
74+
);
75+
}
76+
5377
/// Run the event loop until the given Future completes
5478
///
5579
/// The event loop runs until the given future is complete.
@@ -76,7 +100,7 @@ pub fn init() {
76100
/// #
77101
/// # Python::with_gil(|py| {
78102
/// # pyo3_asyncio::with_runtime(py, || {
79-
/// pyo3_asyncio::tokio::init();
103+
/// pyo3_asyncio::tokio::init_current_thread();
80104
/// pyo3_asyncio::tokio::run_until_complete(py, async move {
81105
/// tokio::time::sleep(Duration::from_secs(1)).await;
82106
/// Ok(())
@@ -234,11 +258,12 @@ pub mod testing {
234258
/// This is meant to perform the necessary initialization for most test cases. If you want
235259
/// additional control over the initialization (i.e. env_logger initialization), you can use this
236260
/// function as a template.
261+
///
262+
/// Note: The tokio runtime must be initialized before calling this function!
237263
pub fn test_main(suite_name: &str, tests: Vec<Test>) {
238264
Python::with_gil(|py| {
239265
with_runtime(py, || {
240266
let args = parse_args(suite_name);
241-
crate::tokio::init();
242267
crate::tokio::run_until_complete(py, test_harness(tests, args))?;
243268

244269
Ok(())

0 commit comments

Comments
 (0)