Skip to content

Commit 3975635

Browse files
authored
feat: make it so immutable references & full ownership can be taken depending on context (#58)
This pull requests introduces changes that make the use of the `test_context` macro more flexible. Now, if the teardown is not skipped (default behavior), either an `immutable` or a `mutable` reference can be used for the context. If the teardown is skipped with the `skip_teardown` option, an `immutable`, a `mutable` reference or full ownership can be taken. So now the following is possible: ```rust #[test_context(TeardownPanicContext, skip_teardown)] #[tokio::test] async fn test_async_skip_teardown(_ctx: &mut TeardownPanicContext) {} #[test_context(TeardownPanicContext, skip_teardown)] #[tokio::test] async fn test_async_skip_teardown_with_immutable_ref(_ctx: &TeardownPanicContext) {} #[test_context(TeardownPanicContext, skip_teardown)] #[tokio::test] async fn test_async_skip_teardown_with_full_ownership(_ctx: TeardownPanicContext) {} #[test_context(TeardownPanicContext, skip_teardown)] #[test] fn test_sync_skip_teardown(_ctx: &mut TeardownPanicContext) {} #[test_context(TeardownPanicContext, skip_teardown)] #[test] fn test_sync_skip_teardown_with_immutable_ref(_ctx: &TeardownPanicContext) {} #[test_context(TeardownPanicContext, skip_teardown)] #[test] fn test_sync_skip_teardown_with_full_ownership(_ctx: TeardownPanicContext) {} ```
1 parent 5e407cd commit 3975635

File tree

6 files changed

+237
-94
lines changed

6 files changed

+237
-94
lines changed

README.md

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,8 @@ tests annotated with `#[tokio::test]` continue to work as usual without the feat
127127

128128
## Skipping the teardown execution
129129

130-
Also, if you don't care about the
131-
teardown execution for a specific test, you can use the `skip_teardown` keyword on the macro
132-
like this:
130+
Also, if you don't care about the teardown execution for a specific test,
131+
you can use the `skip_teardown` keyword on the macro like this:
133132

134133
```rust
135134
use test_context::{test_context, TestContext};
@@ -144,11 +143,50 @@ like this:
144143

145144
#[test_context(MyContext, skip_teardown)]
146145
#[test]
147-
fn test_without_teardown(ctx: &mut MyContext) {
146+
fn test_without_teardown(ctx: &MyContext) {}
147+
```
148+
149+
## Taking ownership of the context vs taking a reference
150+
151+
If the teardown is ON (default behavior), you can only take a reference to the context, either mutable or immutable, as follows:
152+
153+
```rust
154+
#[test_context(MyContext)]
155+
#[test]
156+
fn test_with_teardown_using_immutable_ref(ctx: &MyContext) {}
157+
158+
#[test_context(MyContext)]
159+
#[test]
160+
fn test_with_teardown_using_mutable_ref(ctx: &mut MyContext) {}
161+
```
162+
163+
❌The following is invalid:
164+
165+
```rust
166+
#[test_context(MyContext)]
167+
#[test]
168+
fn test_with_teardown_taking_ownership(ctx: MyContext) {}
169+
```
170+
171+
If the teardown is skipped (as specified in the section above), you can take an immutable ref, mutable ref or full ownership of the context:
172+
173+
```rust
174+
#[test_context(MyContext, skip_teardown)]
175+
#[test]
176+
fn test_without_teardown(ctx: MyContext) {
148177
// Perform any operations that require full ownership of your context
149178
}
179+
180+
#[test_context(MyContext, skip_teardown)]
181+
#[test]
182+
fn test_without_teardown_taking_a_ref(ctx: &MyContext) {}
183+
184+
#[test_context(MyContext, skip_teardown)]
185+
#[test]
186+
fn test_without_teardown_taking_a_mut_ref(ctx: &mut MyContext) {}
150187
```
151188

189+
152190
## ⚠️ Ensure that the context type specified in the macro matches the test function argument type exactly
153191

154192
The error occurs when a context type with an absolute path is mixed with an it's alias.
@@ -161,8 +199,8 @@ mod database {
161199
162200
pub struct Connection;
163201
164-
impl TestContext for :Connection {
165-
fn setup() -> Self {Connection}
202+
impl TestContext for Connection {
203+
fn setup() -> Self { Connection }
166204
fn teardown(self) {...}
167205
}
168206
}

release-plz.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ pr_branch_prefix = "release-"
33
pr_labels = ["release"]
44
git_tag_enable = true
55
git_tag_name = "v{{ version }}"
6+
git_release_name = "v{{ version }}"
67
pr_draft = true

test-context-macros/src/lib.rs

Lines changed: 95 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
mod args;
1+
mod macro_args;
2+
mod test_args;
23

3-
use args::TestContextArgs;
4+
use crate::test_args::{ContextArg, ContextArgMode, TestArg};
5+
use macro_args::TestContextArgs;
46
use proc_macro::TokenStream;
57
use quote::{format_ident, quote};
6-
use syn::Ident;
8+
use syn::ItemFn;
79

810
/// Macro to use on tests to add the setup/teardown functionality of your context.
911
///
@@ -30,59 +32,108 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
3032
let args = syn::parse_macro_input!(attr as TestContextArgs);
3133
let input = syn::parse_macro_input!(item as syn::ItemFn);
3234

33-
let (input, context_arg_name) = remove_context_arg(input, args.context_type.clone());
34-
let input = refactor_input_body(input, &args, context_arg_name);
35+
let (input, context_args) = remove_context_args(input, args.context_type.clone());
36+
37+
if context_args.len() != 1 {
38+
panic!("Exactly one Context argument must be defined");
39+
}
40+
41+
let context_arg = context_args.into_iter().next().unwrap();
42+
43+
if !args.skip_teardown && context_arg.mode == ContextArgMode::Owned {
44+
panic!(
45+
"It is not possible to take ownership of the context if the teardown has to be ran."
46+
);
47+
}
48+
49+
let input = refactor_input_body(input, &args, context_arg);
3550

3651
quote! { #input }.into()
3752
}
3853

39-
fn refactor_input_body(
54+
fn remove_context_args(
4055
mut input: syn::ItemFn,
56+
expected_context_type: syn::Type,
57+
) -> (syn::ItemFn, Vec<ContextArg>) {
58+
let test_args: Vec<TestArg> = input
59+
.sig
60+
.inputs
61+
.into_iter()
62+
.map(|arg| TestArg::parse_arg_with_expected_context(arg, &expected_context_type))
63+
.collect();
64+
65+
let context_args: Vec<ContextArg> = test_args
66+
.iter()
67+
.cloned()
68+
.filter_map(|arg| match arg {
69+
TestArg::Any(_) => None,
70+
TestArg::Context(context_arg_info) => Some(context_arg_info),
71+
})
72+
.collect();
73+
74+
let new_args: syn::punctuated::Punctuated<_, _> = test_args
75+
.into_iter()
76+
.filter_map(|arg| match arg {
77+
TestArg::Any(fn_arg) => Some(fn_arg),
78+
TestArg::Context(_) => None,
79+
})
80+
.collect();
81+
82+
input.sig.inputs = new_args;
83+
84+
(input, context_args)
85+
}
86+
87+
fn refactor_input_body(
88+
input: syn::ItemFn,
4189
args: &TestContextArgs,
42-
context_arg_name: Option<Ident>,
90+
context_arg: ContextArg,
4391
) -> syn::ItemFn {
4492
let context_type = &args.context_type;
45-
let context_arg_name = context_arg_name.unwrap_or_else(|| format_ident!("test_ctx"));
4693
let result_name = format_ident!("wrapped_result");
4794
let body = &input.block;
4895
let is_async = input.sig.asyncness.is_some();
96+
let context_arg_name = context_arg.name;
4997

50-
let body = match (is_async, args.skip_teardown) {
51-
(true, true) => {
52-
quote! {
53-
use test_context::futures::FutureExt;
54-
let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
55-
let #context_arg_name = &mut __context;
56-
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
57-
}
98+
let context_binding = match context_arg.mode {
99+
ContextArgMode::Owned => quote! { let #context_arg_name = __context; },
100+
ContextArgMode::Reference => quote! { let #context_arg_name = &__context; },
101+
ContextArgMode::MutableReference => quote! { let #context_arg_name = &mut __context; },
102+
};
103+
104+
let body = if args.skip_teardown && is_async {
105+
quote! {
106+
use test_context::futures::FutureExt;
107+
let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
108+
#context_binding
109+
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
58110
}
59-
(true, false) => {
60-
quote! {
61-
use test_context::futures::FutureExt;
62-
let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
63-
let #context_arg_name = &mut __context;
64-
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
65-
<#context_type as test_context::AsyncTestContext>::teardown(__context).await;
66-
}
111+
} else if args.skip_teardown && !is_async {
112+
quote! {
113+
let mut __context = <#context_type as test_context::TestContext>::setup();
114+
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
115+
#context_binding
116+
#body
117+
}));
67118
}
68-
(false, true) => {
69-
quote! {
70-
let mut __context = <#context_type as test_context::TestContext>::setup();
71-
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
72-
let #context_arg_name = &mut __context;
73-
#body
74-
}));
75-
}
119+
} else if !args.skip_teardown && is_async {
120+
quote! {
121+
use test_context::futures::FutureExt;
122+
let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
123+
#context_binding
124+
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
125+
<#context_type as test_context::AsyncTestContext>::teardown(__context).await;
76126
}
77-
(false, false) => {
78-
quote! {
79-
let mut __context = <#context_type as test_context::TestContext>::setup();
80-
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
81-
let #context_arg_name = &mut __context;
82-
#body
83-
}));
84-
<#context_type as test_context::TestContext>::teardown(__context);
85-
}
127+
}
128+
// !args.skip_teardown && !is_async
129+
else {
130+
quote! {
131+
let mut __context = <#context_type as test_context::TestContext>::setup();
132+
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
133+
#context_binding
134+
#body
135+
}));
136+
<#context_type as test_context::TestContext>::teardown(__context);
86137
}
87138
};
88139

@@ -98,46 +149,8 @@ fn refactor_input_body(
98149
}
99150
};
100151

101-
input.block = Box::new(syn::parse2(body).unwrap());
102-
103-
input
104-
}
105-
106-
fn remove_context_arg(
107-
mut input: syn::ItemFn,
108-
expected_context_type: syn::Type,
109-
) -> (syn::ItemFn, Option<syn::Ident>) {
110-
let mut context_arg_name = None;
111-
let mut new_args = syn::punctuated::Punctuated::new();
112-
113-
for arg in &input.sig.inputs {
114-
// Extract function arg:
115-
if let syn::FnArg::Typed(pat_type) = arg {
116-
// Extract arg identifier:
117-
if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
118-
// Check that context arg is only ref or mutable ref:
119-
if let syn::Type::Reference(type_ref) = &*pat_type.ty {
120-
// Check that context has expected type:
121-
if types_equal(&type_ref.elem, &expected_context_type) {
122-
context_arg_name = Some(pat_ident.ident.clone());
123-
continue;
124-
}
125-
}
126-
}
127-
}
128-
129-
new_args.push(arg.clone());
130-
}
131-
132-
input.sig.inputs = new_args;
133-
134-
(input, context_arg_name)
135-
}
136-
137-
fn types_equal(a: &syn::Type, b: &syn::Type) -> bool {
138-
if let (syn::Type::Path(a_path), syn::Type::Path(b_path)) = (a, b) {
139-
return a_path.path.segments.last().unwrap().ident
140-
== b_path.path.segments.last().unwrap().ident;
152+
ItemFn {
153+
block: Box::new(syn::parse2(body).unwrap()),
154+
..input
141155
}
142-
quote!(#a).to_string() == quote!(#b).to_string()
143156
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
use quote::quote;
2+
use syn::FnArg;
3+
4+
#[derive(Clone)]
5+
pub struct ContextArg {
6+
/// The identifier name used for the context argument.
7+
pub name: syn::Ident,
8+
/// The mode in which the context was passed to the test function.
9+
pub mode: ContextArgMode,
10+
}
11+
12+
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
13+
pub enum ContextArgMode {
14+
/// The argument was passed as an owned value (`ContextType`). Only valid with `skip_teardown`.
15+
Owned,
16+
/// The argument was passed as an immutable reference (`&ContextType`).
17+
Reference,
18+
/// The argument was passed as a mutable reference (`&mut ContextType`).
19+
MutableReference,
20+
}
21+
22+
#[derive(Clone)]
23+
pub enum TestArg {
24+
Any(FnArg),
25+
Context(ContextArg),
26+
}
27+
28+
impl TestArg {
29+
pub fn parse_arg_with_expected_context(arg: FnArg, expected_context_type: &syn::Type) -> Self {
30+
// Check if the argument is the context argument
31+
if let syn::FnArg::Typed(pat_type) = &arg
32+
&& let syn::Pat::Ident(pat_ident) = &*pat_type.pat
33+
{
34+
let arg_type = &*pat_type.ty;
35+
// Check for mutable/immutable reference
36+
if let syn::Type::Reference(type_ref) = arg_type
37+
&& types_equal(&type_ref.elem, expected_context_type)
38+
{
39+
let mode = if type_ref.mutability.is_some() {
40+
ContextArgMode::MutableReference
41+
} else {
42+
ContextArgMode::Reference
43+
};
44+
45+
TestArg::Context(ContextArg {
46+
name: pat_ident.ident.clone(),
47+
mode,
48+
})
49+
} else if types_equal(arg_type, expected_context_type) {
50+
TestArg::Context(ContextArg {
51+
name: pat_ident.ident.clone(),
52+
mode: ContextArgMode::Owned,
53+
})
54+
} else {
55+
TestArg::Any(arg)
56+
}
57+
} else {
58+
TestArg::Any(arg)
59+
}
60+
}
61+
}
62+
63+
fn types_equal(a: &syn::Type, b: &syn::Type) -> bool {
64+
if let (syn::Type::Path(a_path), syn::Type::Path(b_path)) = (a, b) {
65+
return a_path.path.segments.last().unwrap().ident
66+
== b_path.path.segments.last().unwrap().ident;
67+
}
68+
quote!(#a).to_string() == quote!(#b).to_string()
69+
}

0 commit comments

Comments
 (0)