Skip to content

Commit 5d6f6a7

Browse files
author
Andrew J Westlake
committed
Made tokio test_main fully customizable
1 parent 81e556f commit 5d6f6a7

File tree

3 files changed

+67
-10
lines changed

3 files changed

+67
-10
lines changed

pyo3-asyncio-macros/src/lib.rs

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@ mod tokio;
66

77
use proc_macro::TokenStream;
88
use quote::{quote, quote_spanned};
9-
use syn::spanned::Spanned;
9+
use syn::{
10+
parse::{Parse, ParseStream, Result},
11+
spanned::Spanned,
12+
Attribute,
13+
};
1014

1115
/// Enables an async main function that uses the async-std runtime.
1216
///
@@ -235,10 +239,53 @@ pub fn tokio_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
235239
result.into()
236240
}
237241

242+
enum Item {
243+
Attribute(Vec<Attribute>),
244+
String(syn::LitStr),
245+
}
246+
247+
impl Parse for Item {
248+
fn parse(input: ParseStream<'_>) -> Result<Self> {
249+
let lookahead = input.lookahead1();
250+
251+
if lookahead.peek(syn::Token![#]) {
252+
Attribute::parse_outer(input).map(Item::Attribute)
253+
} else {
254+
input.parse().map(Item::String)
255+
}
256+
}
257+
}
258+
259+
struct TokioTestMainArgs {
260+
attrs: Vec<Attribute>,
261+
suite_name: String,
262+
}
263+
264+
impl Parse for TokioTestMainArgs {
265+
fn parse(input: ParseStream<'_>) -> Result<Self> {
266+
let mut args: syn::punctuated::Punctuated<Item, syn::Token![,]> =
267+
input.parse_terminated(Item::parse)?;
268+
269+
let suite_name = match args.pop().unwrap() {
270+
syn::punctuated::Pair::Punctuated(Item::String(s), _)
271+
| syn::punctuated::Pair::End(Item::String(s)) => s.value(),
272+
_ => panic!(),
273+
};
274+
275+
let attrs = match args.pop().unwrap() {
276+
syn::punctuated::Pair::Punctuated(Item::Attribute(attrs), _) => attrs,
277+
_ => panic!(),
278+
};
279+
280+
Ok(Self { attrs, suite_name })
281+
}
282+
}
283+
238284
#[cfg(not(test))]
239285
#[proc_macro]
240286
pub fn tokio_test_main(args: TokenStream) -> TokenStream {
241-
let suite_name = syn::parse_macro_input!(args as syn::LitStr);
287+
let TokioTestMainArgs { attrs, suite_name } =
288+
syn::parse_macro_input!(args as TokioTestMainArgs);
242289

243290
let result = quote! {
244291
#[derive(Clone)]
@@ -259,12 +306,16 @@ pub fn tokio_test_main(args: TokenStream) -> TokenStream {
259306

260307
inventory::collect!(Test);
261308

262-
fn main() {
263-
pyo3_asyncio::tokio::init_multi_thread();
264-
pyo3_asyncio::tokio::testing::test_main(
265-
#suite_name,
266-
inventory::iter::<Test>().map(|test| test.clone()).collect()
267-
);
309+
#(#attrs)*
310+
async fn main() -> pyo3::PyResult<()> {
311+
let args = pyo3_asyncio::testing::parse_args(#suite_name);
312+
313+
pyo3_asyncio::testing::test_harness(
314+
inventory::iter::<Test>().map(|test| test.clone()).collect(), args
315+
)
316+
.await?;
317+
318+
Ok(())
268319
}
269320
};
270321
result.into()

pytests/test_tokio_current_thread_asyncio.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,7 @@ mod common;
22
mod tokio_asyncio;
33

44
// TODO: Fix current thread init
5-
pyo3_asyncio::tokio::test_main!("PyO3 Asyncio Test Suite for Tokio Current-Thread Runtime");
5+
pyo3_asyncio::tokio::test_main!(
6+
#[pyo3_asyncio::tokio::main(flavor = "current_thread")],
7+
"PyO3 Asyncio Tokio Current Thread Test Suite"
8+
);
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
mod common;
22
mod tokio_asyncio;
33

4-
pyo3_asyncio::tokio::test_main!("PyO3 Asyncio Test Suite for Tokio Multi-Thread Runtime");
4+
pyo3_asyncio::tokio::test_main!(
5+
#[pyo3_asyncio::tokio::main],
6+
"PyO3 Asyncio Tokio Multi Thread Test Suite"
7+
);

0 commit comments

Comments
 (0)