@@ -3,7 +3,7 @@ mod args;
33use args:: TestContextArgs ;
44use proc_macro:: TokenStream ;
55use quote:: { format_ident, quote} ;
6- use syn:: { Block , Ident } ;
6+ use syn:: Ident ;
77
88/// Macro to use on tests to add the setup/teardown functionality of your context.
99///
@@ -28,111 +28,82 @@ use syn::{Block, Ident};
2828#[ proc_macro_attribute]
2929pub fn test_context ( attr : TokenStream , item : TokenStream ) -> TokenStream {
3030 let args = syn:: parse_macro_input!( attr as TestContextArgs ) ;
31-
3231 let input = syn:: parse_macro_input!( item as syn:: ItemFn ) ;
33- let is_async = input. sig . asyncness . is_some ( ) ;
34-
35- let ( new_input, context_arg_name) =
36- extract_and_remove_context_arg ( input. clone ( ) , args. context_type . clone ( ) ) ;
37-
38- let wrapper_body = if is_async {
39- async_wrapper_body ( args, & context_arg_name, & input. block )
40- } else {
41- sync_wrapper_body ( args, & context_arg_name, & input. block )
42- } ;
4332
44- let mut result_input = new_input ;
45- result_input . block = Box :: new ( syn :: parse2 ( wrapper_body ) . unwrap ( ) ) ;
33+ let ( input , context_arg_name ) = remove_context_arg ( input , args . context_type . clone ( ) ) ;
34+ let input = refactor_input_body ( input , & args , context_arg_name ) ;
4635
47- quote ! { #result_input } . into ( )
36+ quote ! { #input } . into ( )
4837}
4938
50- fn async_wrapper_body (
51- args : TestContextArgs ,
52- context_arg_name : & Option < syn:: Ident > ,
53- body : & Block ,
54- ) -> proc_macro2:: TokenStream {
55- let context_type = args. context_type ;
39+ fn refactor_input_body (
40+ mut input : syn:: ItemFn ,
41+ args : & TestContextArgs ,
42+ context_arg_name : Option < Ident > ,
43+ ) -> syn:: ItemFn {
44+ let context_type = & args. context_type ;
45+ let context_arg_name = context_arg_name. unwrap_or_else ( || format_ident ! ( "test_ctx" ) ) ;
5646 let result_name = format_ident ! ( "wrapped_result" ) ;
47+ let body = & input. block ;
48+ let is_async = input. sig . asyncness . is_some ( ) ;
5749
58- let binding = format_ident ! ( "test_ctx" ) ;
59- let context_name = context_arg_name. as_ref ( ) . unwrap_or ( & binding) ;
60-
61- let body = if args. skip_teardown {
62- quote ! {
63- let #context_name = <#context_type as test_context:: AsyncTestContext >:: setup( ) . await ;
64- let #result_name = std:: panic:: AssertUnwindSafe ( async { #body } ) . catch_unwind( ) . await ;
65- }
66- } else {
67- quote ! {
68- let mut #context_name = <#context_type as test_context:: AsyncTestContext >:: setup( ) . await ;
69- let #result_name = std:: panic:: AssertUnwindSafe ( async { #body } ) . catch_unwind( ) . await ;
70- <#context_type as test_context:: AsyncTestContext >:: teardown( #context_name) . await ;
50+ let body = match ( is_async, args. skip_teardown ) {
51+ ( true , true ) => {
52+ quote ! {
53+ use test_context:: futures:: FutureExt ;
54+ let mut __context = <#context_type as test_context:: AsyncTestContext >:: setup( ) . await ;
55+ let #context_arg_name = & mut __context;
56+ let #result_name = std:: panic:: AssertUnwindSafe ( async { #body } ) . catch_unwind( ) . await ;
57+ }
7158 }
72- } ;
73-
74- let handle_wrapped_result = handle_result ( result_name) ;
75-
76- quote ! {
77- {
78- use test_context:: futures:: FutureExt ;
79- #body
80- #handle_wrapped_result
59+ ( true , false ) => {
60+ quote ! {
61+ use test_context:: futures:: FutureExt ;
62+ let mut __context = <#context_type as test_context:: AsyncTestContext >:: setup( ) . await ;
63+ let #context_arg_name = & mut __context;
64+ let #result_name = std:: panic:: AssertUnwindSafe ( async { #body } ) . catch_unwind( ) . await ;
65+ <#context_type as test_context:: AsyncTestContext >:: teardown( __context) . await ;
66+ }
8167 }
82- }
83- }
84-
85- fn sync_wrapper_body (
86- args : TestContextArgs ,
87- context_arg_name : & Option < syn:: Ident > ,
88- body : & Block ,
89- ) -> proc_macro2:: TokenStream {
90- let context_type = args. context_type ;
91- let result_name = format_ident ! ( "wrapped_result" ) ;
92-
93- let binding = format_ident ! ( "test_ctx" ) ;
94- let context_name = context_arg_name. as_ref ( ) . unwrap_or ( & binding) ;
95-
96- let body = if args. skip_teardown {
97- quote ! {
98- let mut #context_name = <#context_type as test_context:: TestContext >:: setup( ) ;
99- let #result_name = std:: panic:: catch_unwind( std:: panic:: AssertUnwindSafe ( || {
100- let #context_name = & mut #context_name;
101- #body
102- } ) ) ;
68+ ( false , true ) => {
69+ quote ! {
70+ let mut __context = <#context_type as test_context:: TestContext >:: setup( ) ;
71+ let #result_name = std:: panic:: catch_unwind( std:: panic:: AssertUnwindSafe ( || {
72+ let #context_arg_name = & mut __context;
73+ #body
74+ } ) ) ;
75+ }
10376 }
104- } else {
105- quote ! {
106- let mut #context_name = <#context_type as test_context:: TestContext >:: setup( ) ;
107- let #result_name = std:: panic:: catch_unwind( std:: panic:: AssertUnwindSafe ( || {
108- #body
109- } ) ) ;
110- <#context_type as test_context:: TestContext >:: teardown( #context_name) ;
77+ ( false , false ) => {
78+ quote ! {
79+ let mut __context = <#context_type as test_context:: TestContext >:: setup( ) ;
80+ let #result_name = std:: panic:: catch_unwind( std:: panic:: AssertUnwindSafe ( || {
81+ let #context_arg_name = & mut __context;
82+ #body
83+ } ) ) ;
84+ <#context_type as test_context:: TestContext >:: teardown( __context) ;
85+ }
11186 }
11287 } ;
11388
114- let handle_wrapped_result = handle_result ( result_name) ;
115-
116- quote ! {
89+ let body = quote ! {
11790 {
11891 #body
119- #handle_wrapped_result
120- }
121- }
122- }
123-
124- fn handle_result ( result_name : Ident ) -> proc_macro2:: TokenStream {
125- quote ! {
126- match #result_name {
127- Ok ( value) => value,
128- Err ( err) => {
129- std:: panic:: resume_unwind( err) ;
92+ match #result_name {
93+ Ok ( value) => value,
94+ Err ( err) => {
95+ std:: panic:: resume_unwind( err) ;
96+ }
13097 }
13198 }
132- }
99+ } ;
100+
101+ input. block = Box :: new ( syn:: parse2 ( body) . unwrap ( ) ) ;
102+
103+ input
133104}
134105
135- fn extract_and_remove_context_arg (
106+ fn remove_context_arg (
136107 mut input : syn:: ItemFn ,
137108 expected_context_type : syn:: Type ,
138109) -> ( syn:: ItemFn , Option < syn:: Ident > ) {
@@ -154,10 +125,12 @@ fn extract_and_remove_context_arg(
154125 }
155126 }
156127 }
128+
157129 new_args. push ( arg. clone ( ) ) ;
158130 }
159131
160132 input. sig . inputs = new_args;
133+
161134 ( input, context_arg_name)
162135}
163136
0 commit comments