@@ -6,7 +6,11 @@ mod tokio;
6
6
7
7
use proc_macro:: TokenStream ;
8
8
use quote:: { quote, quote_spanned} ;
9
- use syn:: spanned:: Spanned ;
9
+ use syn:: {
10
+ parse:: { Parse , ParseStream , Result } ,
11
+ spanned:: Spanned ,
12
+ Attribute ,
13
+ } ;
10
14
11
15
/// Enables an async main function that uses the async-std runtime.
12
16
///
@@ -235,10 +239,53 @@ pub fn tokio_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
235
239
result. into ( )
236
240
}
237
241
242
+ enum Item {
243
+ Attribute ( Vec < Attribute > ) ,
244
+ String ( syn:: LitStr ) ,
245
+ }
246
+
247
+ impl Parse for Item {
248
+ fn parse ( input : ParseStream < ' _ > ) -> Result < Self > {
249
+ let lookahead = input. lookahead1 ( ) ;
250
+
251
+ if lookahead. peek ( syn:: Token ![ #] ) {
252
+ Attribute :: parse_outer ( input) . map ( Item :: Attribute )
253
+ } else {
254
+ input. parse ( ) . map ( Item :: String )
255
+ }
256
+ }
257
+ }
258
+
259
+ struct TokioTestMainArgs {
260
+ attrs : Vec < Attribute > ,
261
+ suite_name : String ,
262
+ }
263
+
264
+ impl Parse for TokioTestMainArgs {
265
+ fn parse ( input : ParseStream < ' _ > ) -> Result < Self > {
266
+ let mut args: syn:: punctuated:: Punctuated < Item , syn:: Token ![ , ] > =
267
+ input. parse_terminated ( Item :: parse) ?;
268
+
269
+ let suite_name = match args. pop ( ) . unwrap ( ) {
270
+ syn:: punctuated:: Pair :: Punctuated ( Item :: String ( s) , _)
271
+ | syn:: punctuated:: Pair :: End ( Item :: String ( s) ) => s. value ( ) ,
272
+ _ => panic ! ( ) ,
273
+ } ;
274
+
275
+ let attrs = match args. pop ( ) . unwrap ( ) {
276
+ syn:: punctuated:: Pair :: Punctuated ( Item :: Attribute ( attrs) , _) => attrs,
277
+ _ => panic ! ( ) ,
278
+ } ;
279
+
280
+ Ok ( Self { attrs, suite_name } )
281
+ }
282
+ }
283
+
238
284
#[ cfg( not( test) ) ]
239
285
#[ proc_macro]
240
286
pub fn tokio_test_main ( args : TokenStream ) -> TokenStream {
241
- let suite_name = syn:: parse_macro_input!( args as syn:: LitStr ) ;
287
+ let TokioTestMainArgs { attrs, suite_name } =
288
+ syn:: parse_macro_input!( args as TokioTestMainArgs ) ;
242
289
243
290
let result = quote ! {
244
291
#[ derive( Clone ) ]
@@ -259,12 +306,16 @@ pub fn tokio_test_main(args: TokenStream) -> TokenStream {
259
306
260
307
inventory:: collect!( Test ) ;
261
308
262
- fn main( ) {
263
- pyo3_asyncio:: tokio:: init_multi_thread( ) ;
264
- pyo3_asyncio:: tokio:: testing:: test_main(
265
- #suite_name,
266
- inventory:: iter:: <Test >( ) . map( |test| test. clone( ) ) . collect( )
267
- ) ;
309
+ #( #attrs) *
310
+ async fn main( ) -> pyo3:: PyResult <( ) > {
311
+ let args = pyo3_asyncio:: testing:: parse_args( #suite_name) ;
312
+
313
+ pyo3_asyncio:: testing:: test_harness(
314
+ inventory:: iter:: <Test >( ) . map( |test| test. clone( ) ) . collect( ) , args
315
+ )
316
+ . await ?;
317
+
318
+ Ok ( ( ) )
268
319
}
269
320
} ;
270
321
result. into ( )
0 commit comments