Skip to content

Commit 81e556f

Browse files
author
Andrew J Westlake
committed
Added tokio test attributes
1 parent 04b6c4b commit 81e556f

File tree

11 files changed

+216
-243
lines changed

11 files changed

+216
-243
lines changed

Cargo.toml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,6 @@ path = "examples/tokio_multi_thread.rs"
4343
required-features = ["attributes", "tokio-runtime"]
4444

4545

46-
[[test]]
47-
name = "test_async_std_test_macro"
48-
path = "pytests/test_async_std_test_macro.rs"
49-
harness = false
50-
required-features = ["attributes", "async-std-runtime", "testing"]
51-
5246
[[test]]
5347
name = "test_async_std_asyncio"
5448
path = "pytests/test_async_std_asyncio.rs"

pyo3-asyncio-macros/src/lib.rs

Lines changed: 94 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -103,26 +103,14 @@ pub fn tokio_main(args: TokenStream, item: TokenStream) -> TokenStream {
103103
tokio::main(args, item, true)
104104
}
105105

106-
/// Enables an async main function that uses the async-std runtime.
107-
///
108-
/// # Examples
109-
///
110-
/// ```ignore
111-
/// #[pyo3_asyncio::async_std::main]
112-
/// async fn main() -> PyResult<()> {
113-
/// Ok(())
114-
/// }
115-
/// ```
116106
#[cfg(not(test))] // NOTE: exporting main breaks tests, we should file an issue.
117107
#[proc_macro_attribute]
118108
pub fn async_std_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
119109
let input = syn::parse_macro_input!(item as syn::ItemFn);
120110

121111
let sig = &input.sig;
122-
// let inputs = &input.sig.inputs;
123112
let name = &input.sig.ident;
124113
let body = &input.block;
125-
// let attrs = &input.attrs;
126114
let vis = &input.vis;
127115

128116
let fn_impl = if input.sig.asyncness.is_none() {
@@ -133,8 +121,7 @@ pub fn async_std_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
133121
}
134122

135123
Box::pin(async_std::task::spawn_blocking(move || {
136-
#name();
137-
Ok(())
124+
#name()
138125
}))
139126
}
140127
}
@@ -153,7 +140,10 @@ pub fn async_std_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
153140
let result = quote! {
154141
#fn_impl
155142

156-
inventory::submit!(crate::Test { name: stringify!(#name), test_fn: &#name });
143+
inventory::submit!(crate::Test {
144+
name: format!("{}::{}", std::module_path!(), stringify!(#name)),
145+
test_fn: &#name
146+
});
157147
};
158148

159149
result.into()
@@ -167,13 +157,13 @@ pub fn async_std_test_main(args: TokenStream) -> TokenStream {
167157
let result = quote! {
168158
#[derive(Clone)]
169159
pub(crate) struct Test {
170-
pub name: &'static str,
160+
pub name: String,
171161
pub test_fn: &'static (dyn Fn() -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> + Send + Sync),
172162
}
173163

174164
impl pyo3_asyncio::testing::TestTrait for Test {
175165
fn name(&self) -> &str {
176-
self.name
166+
self.name.as_str()
177167
}
178168

179169
fn task(self) -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> {
@@ -192,3 +182,90 @@ pub fn async_std_test_main(args: TokenStream) -> TokenStream {
192182
};
193183
result.into()
194184
}
185+
186+
#[cfg(not(test))] // NOTE: exporting main breaks tests, we should file an issue.
187+
#[proc_macro_attribute]
188+
pub fn tokio_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
189+
let input = syn::parse_macro_input!(item as syn::ItemFn);
190+
191+
let sig = &input.sig;
192+
let name = &input.sig.ident;
193+
let body = &input.block;
194+
let vis = &input.vis;
195+
196+
let fn_impl = if input.sig.asyncness.is_none() {
197+
quote! {
198+
#vis fn #name() -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> {
199+
#sig {
200+
#body
201+
}
202+
203+
Box::pin(async {
204+
match pyo3_asyncio::tokio::get_runtime().spawn_blocking(&#name).await {
205+
Ok(result) => result,
206+
Err(e) => {
207+
assert!(e.is_panic());
208+
Err(pyo3::exceptions::PyException::new_err("rust future panicked"))
209+
}
210+
}
211+
})
212+
}
213+
}
214+
} else {
215+
quote! {
216+
#vis fn #name() -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> {
217+
#sig {
218+
#body
219+
}
220+
221+
Box::pin(#name())
222+
}
223+
}
224+
};
225+
226+
let result = quote! {
227+
#fn_impl
228+
229+
inventory::submit!(crate::Test {
230+
name: format!("{}::{}", std::module_path!(), stringify!(#name)),
231+
test_fn: &#name
232+
});
233+
};
234+
235+
result.into()
236+
}
237+
238+
#[cfg(not(test))]
239+
#[proc_macro]
240+
pub fn tokio_test_main(args: TokenStream) -> TokenStream {
241+
let suite_name = syn::parse_macro_input!(args as syn::LitStr);
242+
243+
let result = quote! {
244+
#[derive(Clone)]
245+
pub(crate) struct Test {
246+
pub name: String,
247+
pub test_fn: &'static (dyn Fn() -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> + Send + Sync),
248+
}
249+
250+
impl pyo3_asyncio::testing::TestTrait for Test {
251+
fn name(&self) -> &str {
252+
self.name.as_str()
253+
}
254+
255+
fn task(self) -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> {
256+
(self.test_fn)()
257+
}
258+
}
259+
260+
inventory::collect!(Test);
261+
262+
fn main() {
263+
pyo3_asyncio::tokio::init_multi_thread();
264+
pyo3_asyncio::tokio::testing::test_main(
265+
#suite_name,
266+
inventory::iter::<Test>().map(|test| test.clone()).collect()
267+
);
268+
}
269+
};
270+
result.into()
271+
}

pytests/common/mod.rs

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{future::Future, thread, time::Duration};
1+
use std::{thread, time::Duration};
22

33
use pyo3::prelude::*;
44

@@ -12,25 +12,20 @@ async def sleep_for_1s(sleep_for):
1212
await sleep_for(1)
1313
"#;
1414

15-
pub(super) fn test_into_future(
16-
py: Python,
17-
) -> PyResult<impl Future<Output = PyResult<()>> + Send + 'static> {
18-
let test_mod: PyObject =
19-
PyModule::from_code(py, TEST_MOD, "test_rust_coroutine/test_mod.py", "test_mod")?.into();
20-
21-
Ok(async move {
22-
Python::with_gil(|py| {
23-
pyo3_asyncio::into_future(
24-
test_mod
25-
.call_method1(py, "py_sleep", (1.into_py(py),))?
26-
.as_ref(py),
27-
)
28-
})?
29-
.await?;
30-
Ok(())
31-
})
15+
pub(super) async fn test_into_future() -> PyResult<()> {
16+
let fut = Python::with_gil(|py| {
17+
let test_mod =
18+
PyModule::from_code(py, TEST_MOD, "test_rust_coroutine/test_mod.py", "test_mod")?;
19+
20+
pyo3_asyncio::into_future(test_mod.call_method1("py_sleep", (1.into_py(py),))?)
21+
})?;
22+
23+
fut.await?;
24+
25+
Ok(())
3226
}
3327

34-
pub(super) fn test_blocking_sleep() {
28+
pub(super) fn test_blocking_sleep() -> PyResult<()> {
3529
thread::sleep(Duration::from_secs(1));
30+
Ok(())
3631
}

pytests/test_async_std_asyncio.rs

Lines changed: 43 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,10 @@
11
mod common;
22

3-
use std::{future::Future, time::Duration};
3+
use std::time::Duration;
44

55
use async_std::task;
66
use pyo3::{prelude::*, wrap_pyfunction};
77

8-
use pyo3_asyncio::{
9-
async_std::testing::{new_sync_test, test_main},
10-
testing::Test,
11-
};
12-
138
#[pyfunction]
149
fn sleep_for(py: Python, secs: &PyAny) -> PyResult<PyObject> {
1510
let secs = secs.extract()?;
@@ -20,91 +15,53 @@ fn sleep_for(py: Python, secs: &PyAny) -> PyResult<PyObject> {
2015
})
2116
}
2217

23-
fn test_into_coroutine(
24-
py: Python,
25-
) -> PyResult<impl Future<Output = PyResult<()>> + Send + 'static> {
26-
let sleeper_mod: Py<PyModule> = PyModule::new(py, "rust_sleeper")?.into();
27-
28-
sleeper_mod
29-
.as_ref(py)
30-
.add_wrapped(wrap_pyfunction!(sleep_for))?;
31-
32-
let test_mod: PyObject = PyModule::from_code(
33-
py,
34-
common::TEST_MOD,
35-
"test_rust_coroutine/test_mod.py",
36-
"test_mod",
37-
)?
38-
.into();
39-
40-
Ok(async move {
41-
Python::with_gil(|py| {
42-
pyo3_asyncio::into_future(
43-
test_mod
44-
.call_method1(py, "sleep_for_1s", (sleeper_mod.getattr(py, "sleep_for")?,))?
45-
.as_ref(py),
46-
)
47-
})?
48-
.await?;
49-
Ok(())
50-
})
18+
#[pyo3_asyncio::async_std::test]
19+
async fn test_into_coroutine() -> PyResult<()> {
20+
let fut = Python::with_gil(|py| {
21+
let sleeper_mod = PyModule::new(py, "rust_sleeper")?;
22+
23+
sleeper_mod.add_wrapped(wrap_pyfunction!(sleep_for))?;
24+
25+
let test_mod = PyModule::from_code(
26+
py,
27+
common::TEST_MOD,
28+
"test_rust_coroutine/test_mod.py",
29+
"test_mod",
30+
)?;
31+
32+
pyo3_asyncio::into_future(
33+
test_mod.call_method1("sleep_for_1s", (sleeper_mod.getattr("sleep_for")?,))?,
34+
)
35+
})?;
36+
37+
fut.await?;
38+
39+
Ok(())
5140
}
5241

53-
fn test_async_sleep<'p>(
54-
py: Python<'p>,
55-
) -> PyResult<impl Future<Output = PyResult<()>> + Send + 'static> {
56-
let asyncio = PyObject::from(py.import("asyncio")?);
42+
#[pyo3_asyncio::async_std::test]
43+
async fn test_async_sleep() -> PyResult<()> {
44+
let asyncio =
45+
Python::with_gil(|py| py.import("asyncio").map(|asyncio| PyObject::from(asyncio)))?;
5746

58-
Ok(async move {
59-
task::sleep(Duration::from_secs(1)).await;
47+
task::sleep(Duration::from_secs(1)).await;
6048

61-
Python::with_gil(|py| {
62-
pyo3_asyncio::into_future(asyncio.as_ref(py).call_method1("sleep", (1.0,))?)
63-
})?
64-
.await?;
49+
Python::with_gil(|py| {
50+
pyo3_asyncio::into_future(asyncio.as_ref(py).call_method1("sleep", (1.0,))?)
51+
})?
52+
.await?;
6553

66-
Ok(())
67-
})
54+
Ok(())
6855
}
6956

70-
fn main() {
71-
test_main(
72-
"PyO3 Asyncio Test Suite",
73-
vec![
74-
Test::new_async(
75-
"test_async_sleep".into(),
76-
Python::with_gil(|py| {
77-
test_async_sleep(py)
78-
.map_err(|e| {
79-
e.print_and_set_sys_last_vars(py);
80-
})
81-
.unwrap()
82-
}),
83-
),
84-
new_sync_test("test_blocking_sleep".into(), || {
85-
common::test_blocking_sleep();
86-
Ok(())
87-
}),
88-
Test::new_async(
89-
"test_into_coroutine".into(),
90-
Python::with_gil(|py| {
91-
test_into_coroutine(py)
92-
.map_err(|e| {
93-
e.print_and_set_sys_last_vars(py);
94-
})
95-
.unwrap()
96-
}),
97-
),
98-
Test::new_async(
99-
"test_into_future".into(),
100-
Python::with_gil(|py| {
101-
common::test_into_future(py)
102-
.map_err(|e| {
103-
e.print_and_set_sys_last_vars(py);
104-
})
105-
.unwrap()
106-
}),
107-
),
108-
],
109-
)
57+
#[pyo3_asyncio::async_std::test]
58+
fn test_blocking_sleep() -> PyResult<()> {
59+
common::test_blocking_sleep()
11060
}
61+
62+
#[pyo3_asyncio::async_std::test]
63+
async fn test_into_future() -> PyResult<()> {
64+
common::test_into_future().await
65+
}
66+
67+
pyo3_asyncio::async_std::test_main!("PyO3 Asyncio Test Suite for Async-Std Runtime");

pytests/test_async_std_test_macro.rs

Lines changed: 0 additions & 17 deletions
This file was deleted.
Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
mod common;
22
mod tokio_asyncio;
33

4-
fn main() {
5-
pyo3_asyncio::tokio::init_current_thread();
6-
7-
tokio_asyncio::test_main("PyO3 Asyncio Tokio Current-Thread Test Suite");
8-
}
4+
// TODO: Fix current thread init
5+
pyo3_asyncio::tokio::test_main!("PyO3 Asyncio Test Suite for Tokio Current-Thread Runtime");

0 commit comments

Comments
 (0)