Skip to content

Commit ef4bdb2

Browse files
authored
fix: Regression in version 0.5.0 (#55)
An attempt to fix the issue defined in #53 & some refactors
1 parent 90f0f92 commit ef4bdb2

File tree

5 files changed

+65
-92
lines changed

5 files changed

+65
-92
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ resolver = "2"
33
members = ["test-context", "test-context-macros"]
44

55
[workspace.package]
6-
edition = "2021"
6+
edition = "2024"
77
version = "0.5.0"
8-
rust-version = "1.75.0"
8+
rust-version = "1.91.0"
99
homepage = "https://github.com/JasterV/test-context"
1010
repository = "https://github.com/JasterV/test-context"
1111
authors = [

rust-toolchain.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[toolchain]
2-
channel = "1.90"
2+
channel = "1.91"
33
components = ["rustfmt", "clippy", "rust-src", "rust-analyzer"]

test-context-macros/src/args.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use syn::{parse::Parse, Token, Type};
1+
use syn::{Token, Type, parse::Parse};
22

33
pub(crate) struct TestContextArgs {
44
pub(crate) context_type: Type,

test-context-macros/src/lib.rs

Lines changed: 60 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ mod args;
33
use args::TestContextArgs;
44
use proc_macro::TokenStream;
55
use quote::{format_ident, quote};
6-
use syn::{Block, Ident};
6+
use syn::Ident;
77

88
/// Macro to use on tests to add the setup/teardown functionality of your context.
99
///
@@ -28,111 +28,82 @@ use syn::{Block, Ident};
2828
#[proc_macro_attribute]
2929
pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
3030
let args = syn::parse_macro_input!(attr as TestContextArgs);
31-
3231
let input = syn::parse_macro_input!(item as syn::ItemFn);
33-
let is_async = input.sig.asyncness.is_some();
34-
35-
let (new_input, context_arg_name) =
36-
extract_and_remove_context_arg(input.clone(), args.context_type.clone());
37-
38-
let wrapper_body = if is_async {
39-
async_wrapper_body(args, &context_arg_name, &input.block)
40-
} else {
41-
sync_wrapper_body(args, &context_arg_name, &input.block)
42-
};
4332

44-
let mut result_input = new_input;
45-
result_input.block = Box::new(syn::parse2(wrapper_body).unwrap());
33+
let (input, context_arg_name) = remove_context_arg(input, args.context_type.clone());
34+
let input = refactor_input_body(input, &args, context_arg_name);
4635

47-
quote! { #result_input }.into()
36+
quote! { #input }.into()
4837
}
4938

50-
fn async_wrapper_body(
51-
args: TestContextArgs,
52-
context_arg_name: &Option<syn::Ident>,
53-
body: &Block,
54-
) -> proc_macro2::TokenStream {
55-
let context_type = args.context_type;
39+
fn refactor_input_body(
40+
mut input: syn::ItemFn,
41+
args: &TestContextArgs,
42+
context_arg_name: Option<Ident>,
43+
) -> syn::ItemFn {
44+
let context_type = &args.context_type;
45+
let context_arg_name = context_arg_name.unwrap_or_else(|| format_ident!("test_ctx"));
5646
let result_name = format_ident!("wrapped_result");
47+
let body = &input.block;
48+
let is_async = input.sig.asyncness.is_some();
5749

58-
let binding = format_ident!("test_ctx");
59-
let context_name = context_arg_name.as_ref().unwrap_or(&binding);
60-
61-
let body = if args.skip_teardown {
62-
quote! {
63-
let #context_name = <#context_type as test_context::AsyncTestContext>::setup().await;
64-
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
65-
}
66-
} else {
67-
quote! {
68-
let mut #context_name = <#context_type as test_context::AsyncTestContext>::setup().await;
69-
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
70-
<#context_type as test_context::AsyncTestContext>::teardown(#context_name).await;
50+
let body = match (is_async, args.skip_teardown) {
51+
(true, true) => {
52+
quote! {
53+
use test_context::futures::FutureExt;
54+
let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
55+
let #context_arg_name = &mut __context;
56+
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
57+
}
7158
}
72-
};
73-
74-
let handle_wrapped_result = handle_result(result_name);
75-
76-
quote! {
77-
{
78-
use test_context::futures::FutureExt;
79-
#body
80-
#handle_wrapped_result
59+
(true, false) => {
60+
quote! {
61+
use test_context::futures::FutureExt;
62+
let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
63+
let #context_arg_name = &mut __context;
64+
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
65+
<#context_type as test_context::AsyncTestContext>::teardown(__context).await;
66+
}
8167
}
82-
}
83-
}
84-
85-
fn sync_wrapper_body(
86-
args: TestContextArgs,
87-
context_arg_name: &Option<syn::Ident>,
88-
body: &Block,
89-
) -> proc_macro2::TokenStream {
90-
let context_type = args.context_type;
91-
let result_name = format_ident!("wrapped_result");
92-
93-
let binding = format_ident!("test_ctx");
94-
let context_name = context_arg_name.as_ref().unwrap_or(&binding);
95-
96-
let body = if args.skip_teardown {
97-
quote! {
98-
let mut #context_name = <#context_type as test_context::TestContext>::setup();
99-
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
100-
let #context_name = &mut #context_name;
101-
#body
102-
}));
68+
(false, true) => {
69+
quote! {
70+
let mut __context = <#context_type as test_context::TestContext>::setup();
71+
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
72+
let #context_arg_name = &mut __context;
73+
#body
74+
}));
75+
}
10376
}
104-
} else {
105-
quote! {
106-
let mut #context_name = <#context_type as test_context::TestContext>::setup();
107-
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
108-
#body
109-
}));
110-
<#context_type as test_context::TestContext>::teardown(#context_name);
77+
(false, false) => {
78+
quote! {
79+
let mut __context = <#context_type as test_context::TestContext>::setup();
80+
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
81+
let #context_arg_name = &mut __context;
82+
#body
83+
}));
84+
<#context_type as test_context::TestContext>::teardown(__context);
85+
}
11186
}
11287
};
11388

114-
let handle_wrapped_result = handle_result(result_name);
115-
116-
quote! {
89+
let body = quote! {
11790
{
11891
#body
119-
#handle_wrapped_result
120-
}
121-
}
122-
}
123-
124-
fn handle_result(result_name: Ident) -> proc_macro2::TokenStream {
125-
quote! {
126-
match #result_name {
127-
Ok(value) => value,
128-
Err(err) => {
129-
std::panic::resume_unwind(err);
92+
match #result_name {
93+
Ok(value) => value,
94+
Err(err) => {
95+
std::panic::resume_unwind(err);
96+
}
13097
}
13198
}
132-
}
99+
};
100+
101+
input.block = Box::new(syn::parse2(body).unwrap());
102+
103+
input
133104
}
134105

135-
fn extract_and_remove_context_arg(
106+
fn remove_context_arg(
136107
mut input: syn::ItemFn,
137108
expected_context_type: syn::Type,
138109
) -> (syn::ItemFn, Option<syn::Ident>) {
@@ -154,10 +125,12 @@ fn extract_and_remove_context_arg(
154125
}
155126
}
156127
}
128+
157129
new_args.push(arg.clone());
158130
}
159131

160132
input.sig.inputs = new_args;
133+
161134
(input, context_arg_name)
162135
}
163136

test-context/tests/test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::marker::PhantomData;
22

33
use rstest::rstest;
4-
use test_context::{test_context, AsyncTestContext, TestContext};
4+
use test_context::{AsyncTestContext, TestContext, test_context};
55

66
struct Context {
77
n: u32,

0 commit comments

Comments
 (0)