Skip to content

Commit f596a03

Browse files
author
Andrew J Westlake
committed
Added support for async-std test attribute using the inventory crate
1 parent 62eef13 commit f596a03

File tree

6 files changed

+157
-11
lines changed

6 files changed

+157
-11
lines changed

Cargo.toml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ name = "tokio_multi_thread"
4242
path = "examples/tokio_multi_thread.rs"
4343
required-features = ["attributes", "tokio-runtime"]
4444

45+
46+
[[test]]
47+
name = "test_async_std_test_macro"
48+
path = "pytests/test_async_std_test_macro.rs"
49+
harness = false
50+
required-features = ["attributes", "async-std-runtime", "testing"]
51+
4552
[[test]]
4653
name = "test_async_std_asyncio"
4754
path = "pytests/test_async_std_asyncio.rs"
@@ -94,4 +101,7 @@ optional = true
94101
[dependencies.tokio]
95102
version = "1.0"
96103
features = ["full"]
97-
optional = true
104+
optional = true
105+
106+
[dev-dependencies]
107+
inventory = "0.1"

pyo3-asyncio-macros/src/lib.rs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,93 @@ pub fn async_std_main(_attr: TokenStream, item: TokenStream) -> TokenStream {
102102
pub fn tokio_main(args: TokenStream, item: TokenStream) -> TokenStream {
103103
tokio::main(args, item, true)
104104
}
105+
106+
/// Enables an async main function that uses the async-std runtime.
107+
///
108+
/// # Examples
109+
///
110+
/// ```ignore
111+
/// #[pyo3_asyncio::async_std::main]
112+
/// async fn main() -> PyResult<()> {
113+
/// Ok(())
114+
/// }
115+
/// ```
116+
#[cfg(not(test))] // NOTE: exporting main breaks tests, we should file an issue.
117+
#[proc_macro_attribute]
118+
pub fn async_std_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
119+
let input = syn::parse_macro_input!(item as syn::ItemFn);
120+
121+
let sig = &input.sig;
122+
// let inputs = &input.sig.inputs;
123+
let name = &input.sig.ident;
124+
let body = &input.block;
125+
// let attrs = &input.attrs;
126+
let vis = &input.vis;
127+
128+
let fn_impl = if input.sig.asyncness.is_none() {
129+
quote! {
130+
#vis fn #name() -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> {
131+
#sig {
132+
#body
133+
}
134+
135+
Box::pin(async_std::task::spawn_blocking(move || {
136+
#name();
137+
Ok(())
138+
}))
139+
}
140+
}
141+
} else {
142+
quote! {
143+
#vis fn #name() -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> {
144+
#sig {
145+
#body
146+
}
147+
148+
Box::pin(#name())
149+
}
150+
}
151+
};
152+
153+
let result = quote! {
154+
#fn_impl
155+
156+
inventory::submit!(crate::Test { name: stringify!(#name), test_fn: &#name });
157+
};
158+
159+
result.into()
160+
}
161+
162+
#[cfg(not(test))]
163+
#[proc_macro]
164+
pub fn async_std_test_main(args: TokenStream) -> TokenStream {
165+
let suite_name = syn::parse_macro_input!(args as syn::LitStr);
166+
167+
let result = quote! {
168+
#[derive(Clone)]
169+
pub(crate) struct Test {
170+
pub name: &'static str,
171+
pub test_fn: &'static (dyn Fn() -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> + Send + Sync),
172+
}
173+
174+
impl pyo3_asyncio::testing::TestTrait for Test {
175+
fn name(&self) -> &str {
176+
self.name
177+
}
178+
179+
fn task(self) -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> {
180+
(self.test_fn)()
181+
}
182+
}
183+
184+
inventory::collect!(Test);
185+
186+
fn main() {
187+
pyo3_asyncio::async_std::testing::test_main(
188+
#suite_name,
189+
inventory::iter::<Test>().map(|test| test.clone()).collect()
190+
);
191+
}
192+
};
193+
result.into()
194+
}

pytests/test_async_std_test_macro.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
use std::{time::Duration, thread};
2+
3+
use pyo3::prelude::*;
4+
5+
#[pyo3_asyncio::async_std::test]
6+
fn test() {
7+
thread::sleep(Duration::from_secs(1));
8+
}
9+
10+
#[pyo3_asyncio::async_std::test]
11+
async fn test_sleep() -> PyResult<()> {
12+
async_std::task::sleep(Duration::from_secs(1)).await;
13+
14+
Ok(())
15+
}
16+
17+
pyo3_asyncio::async_std::test_main!("PyO3 test async-std test macros");

src/async_std.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ use crate::generic::{self, JoinError, Runtime};
88
/// <span class="module-item stab portability" style="display: inline; border-radius: 3px; padding: 2px; font-size: 80%; line-height: 1.2;"><code>attributes</code></span> Sets up the async-std runtime and runs an async fn as main
99
#[cfg(feature = "attributes")]
1010
pub use pyo3_asyncio_macros::async_std_main as main;
11+
#[cfg(feature = "attributes")]
12+
pub use pyo3_asyncio_macros::async_std_test as test;
13+
14+
#[cfg(feature = "attributes")]
15+
pub use pyo3_asyncio_macros::async_std_test_main as test_main;
1116

1217
struct AsyncStdJoinError;
1318

@@ -157,7 +162,10 @@ pub mod testing {
157162
//!
158163
//! ```no_run
159164
//! fn main() {
160-
//! pyo3_asyncio::async_std::testing::test_main("Example Test Suite", vec![]);
165+
//! pyo3_asyncio::async_std::testing::test_main(
166+
//! "Example Test Suite",
167+
//! Vec::<pyo3_asyncio::testing::Test>::new()
168+
//! );
161169
//! }
162170
//! ```
163171
//!
@@ -194,7 +202,11 @@ pub mod testing {
194202
use async_std::task;
195203
use pyo3::prelude::*;
196204

197-
use crate::{async_std::AsyncStdRuntime, generic, testing::Test};
205+
use crate::{
206+
async_std::AsyncStdRuntime,
207+
generic,
208+
testing::{Test, TestTrait},
209+
};
198210

199211
/// Construct a test from a blocking function (like the traditional `#[test]` attribute)
200212
pub fn new_sync_test<F>(name: String, func: F) -> Test
@@ -208,7 +220,7 @@ pub mod testing {
208220
///
209221
/// This is meant to perform the necessary initialization for most test cases. If you want
210222
/// additional control over the initialization, you can use this function as a template.
211-
pub fn test_main(suite_name: &str, tests: Vec<Test>) {
212-
generic::testing::test_main::<AsyncStdRuntime>(suite_name, tests)
223+
pub fn test_main(suite_name: &str, tests: Vec<impl TestTrait + 'static>) {
224+
generic::testing::test_main::<AsyncStdRuntime, _>(suite_name, tests)
213225
}
214226
}

src/generic.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ pub mod testing {
244244
use crate::{
245245
dump_err,
246246
generic::{run_until_complete, Runtime},
247-
testing::{parse_args, test_harness, Test},
247+
testing::{parse_args, test_harness, TestTrait},
248248
with_runtime,
249249
};
250250

@@ -253,9 +253,10 @@ pub mod testing {
253253
/// This is meant to perform the necessary initialization for most test cases. If you want
254254
/// additional control over the initialization, you can use this
255255
/// function as a template.
256-
pub fn test_main<R>(suite_name: &str, tests: Vec<Test>)
256+
pub fn test_main<R, T>(suite_name: &str, tests: Vec<T>)
257257
where
258258
R: Runtime,
259+
T: TestTrait + 'static,
259260
{
260261
Python::with_gil(|py| {
261262
with_runtime(py, || {

src/testing.rs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,20 @@ pub struct Test {
8282
task: Pin<Box<dyn Future<Output = PyResult<()>> + Send>>,
8383
}
8484

85+
pub trait TestTrait: Send {
86+
fn name(&self) -> &str;
87+
fn task(self) -> Pin<Box<dyn Future<Output = PyResult<()>> + Send>>;
88+
}
89+
90+
impl TestTrait for Test {
91+
fn name(&self) -> &str {
92+
self.name.as_str()
93+
}
94+
fn task(self) -> Pin<Box<dyn Future<Output = PyResult<()>> + Send>> {
95+
self.task
96+
}
97+
}
98+
8599
impl Test {
86100
/// Construct a test from a future
87101
pub fn new_async(
@@ -96,21 +110,23 @@ impl Test {
96110
}
97111

98112
/// Run a sequence of tests while applying any necessary filtering from the `Args`
99-
pub async fn test_harness(tests: Vec<Test>, args: Args) -> PyResult<()> {
113+
pub async fn test_harness(tests: Vec<impl TestTrait + 'static>, args: Args) -> PyResult<()> {
100114
stream::iter(tests)
101115
.for_each_concurrent(Some(4), |test| {
102116
let mut ignore = false;
103117

104118
if let Some(filter) = args.filter.as_ref() {
105-
if !test.name.contains(filter) {
119+
if !test.name().contains(filter) {
106120
ignore = true;
107121
}
108122
}
109123

110124
async move {
111125
if !ignore {
112-
test.task.await.unwrap();
113-
println!("test {} ... ok", test.name);
126+
let name = test.name().to_string();
127+
test.task().await.unwrap();
128+
129+
println!("test {} ... ok", name);
114130
}
115131
}
116132
})

0 commit comments

Comments
 (0)