Skip to content

Commit d194a24

Browse files
authored
Introduce #[cgp_impl] to simplify provider trait implementation (#174)
* Draft implement new #[cgp_impl] macro * Finish draft #[cgp_impl] macro * Basic test is working * Improve impl macro * Revert use of Refl for composite type * Remove Refl * Test use #[cgp_impl] in cgp-anyhow-error * Use provider trait name directly instead * Only replace self var when self receiver is used * Add __ to context variable to avoid name clash * Allow component name to be specified * Fix clippy
1 parent ba9bf32 commit d194a24

File tree

14 files changed

+311
-40
lines changed

14 files changed

+311
-40
lines changed

crates/cgp-core/src/prelude.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pub use cgp_field::types::{
1818
};
1919
pub use cgp_macro::{
2020
BuildField, CgpData, CgpRecord, CgpVariant, ExtractField, FromVariant, HasField, HasFields,
21-
Product, Sum, Symbol, cgp_auto_getter, cgp_component, cgp_context, cgp_getter,
21+
Product, Sum, Symbol, cgp_auto_getter, cgp_component, cgp_context, cgp_getter, cgp_impl,
2222
cgp_new_provider, cgp_preset, cgp_provider, cgp_type, check_components,
2323
delegate_and_check_components, delegate_components, product, re_export_imports, replace_with,
2424
};

crates/cgp-error-anyhow/src/impls/raise_anyhow_error.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ use cgp_core::prelude::*;
77

88
pub struct RaiseAnyhowError;
99

10-
#[cgp_provider]
11-
impl<Context, E> ErrorRaiser<Context, E> for RaiseAnyhowError
10+
#[cgp_impl(RaiseAnyhowError)]
11+
impl<Context, E> ErrorRaiser<E> for Context
1212
where
1313
Context: HasErrorType<Error = Error>,
1414
E: StdError + Send + Sync + 'static,
@@ -18,8 +18,8 @@ where
1818
}
1919
}
2020

21-
#[cgp_provider]
22-
impl<Context, Detail> ErrorWrapper<Context, Detail> for RaiseAnyhowError
21+
#[cgp_impl(RaiseAnyhowError)]
22+
impl<Context, Detail> ErrorWrapper<Detail> for Context
2323
where
2424
Context: HasErrorType<Error = Error>,
2525
Detail: Display + Send + Sync + 'static,

crates/cgp-macro-lib/src/derive_component/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@ mod use_context_impl;
1313
mod use_delegate_impl;
1414

1515
pub use derive::*;
16+
pub use replace_self_receiver::*;
1617
pub use replace_self_type::*;
1718
pub use snake_case::*;

crates/cgp-macro-lib/src/derive_component/provider_trait.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
use alloc::vec::Vec;
22

3-
use quote::quote;
3+
use quote::{ToTokens, quote};
44
use syn::punctuated::Punctuated;
55
use syn::token::Comma;
66
use syn::{Ident, ItemTrait, TraitItem, TypeParamBound, parse2};
77

8-
use crate::derive_component::replace_self_receiver::replace_self_receiver;
8+
use crate::derive_component::replace_self_receiver::replace_self_receiver_in_signature;
99
use crate::derive_component::replace_self_type::{
1010
iter_parse_and_replace_self_type, parse_and_replace_self_type,
1111
};
12+
use crate::derive_component::to_snake_case_ident;
1213
use crate::parse::parse_is_provider_params;
1314

1415
pub fn derive_provider_trait(
@@ -89,7 +90,11 @@ pub fn derive_provider_trait(
8990
parse_and_replace_self_type(item, context_type, &local_assoc_types)?;
9091

9192
if let TraitItem::Fn(func) = &mut replaced_item {
92-
replace_self_receiver(func, context_type);
93+
replace_self_receiver_in_signature(
94+
&mut func.sig,
95+
&to_snake_case_ident(context_type),
96+
context_type.to_token_stream(),
97+
);
9398
}
9499

95100
*item = replaced_item;
Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,41 @@
1-
use proc_macro2::Ident;
2-
use syn::{FnArg, TraitItemFn, parse_quote};
1+
use proc_macro2::{Ident, TokenStream};
2+
use syn::{FnArg, Receiver, Signature, parse_quote};
33

4-
use crate::derive_component::snake_case::to_snake_case_ident;
5-
6-
pub fn replace_self_receiver(func: &mut TraitItemFn, replaced_type: &Ident) {
7-
if let Some(arg) = func.sig.inputs.first_mut()
4+
pub fn replace_self_receiver_in_signature(
5+
sig: &mut Signature,
6+
replaced_var: &Ident,
7+
replaced_type: TokenStream,
8+
) {
9+
if let Some(arg) = sig.inputs.first_mut()
810
&& let FnArg::Receiver(receiver) = arg
911
{
10-
let replaced_var = to_snake_case_ident(replaced_type);
12+
*arg = replace_self_receiver(receiver, replaced_var, replaced_type);
13+
}
14+
}
1115

12-
match (&receiver.reference, &receiver.mutability) {
13-
(None, None) => {
14-
*arg = parse_quote!(#replaced_var : #replaced_type);
15-
}
16-
(Some((_and, None)), None) => {
17-
*arg = parse_quote!(#replaced_var : & #replaced_type);
18-
}
19-
(Some((_and, Some(life))), None) => {
20-
*arg = parse_quote!(#replaced_var : & #life #replaced_type);
21-
}
22-
(Some((_and, None)), Some(_mut)) => {
23-
*arg = parse_quote!(#replaced_var : &mut #replaced_type);
24-
}
25-
(Some((_and, Some(life))), Some(_mut)) => {
26-
*arg = parse_quote!(#replaced_var : & #life mut #replaced_type);
27-
}
28-
_ => {}
16+
pub fn replace_self_receiver(
17+
receiver: &mut Receiver,
18+
replaced_var: &Ident,
19+
replaced_type: TokenStream,
20+
) -> FnArg {
21+
match (&receiver.reference, &receiver.mutability) {
22+
(None, None) => {
23+
parse_quote!(#replaced_var : #replaced_type)
24+
}
25+
(Some((_and, None)), None) => {
26+
parse_quote!(#replaced_var : & #replaced_type)
27+
}
28+
(Some((_and, Some(life))), None) => {
29+
parse_quote!(#replaced_var : & #life #replaced_type)
30+
}
31+
(Some((_and, None)), Some(_mut)) => {
32+
parse_quote!(#replaced_var : &mut #replaced_type)
33+
}
34+
(Some((_and, Some(life))), Some(_mut)) => {
35+
parse_quote!(#replaced_var : & #life mut #replaced_type)
36+
}
37+
(None, Some(_mut)) => {
38+
parse_quote!(#replaced_var : mut #replaced_type)
2939
}
3040
}
3141
}

crates/cgp-macro-lib/src/derive_component/replace_self_type.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,17 @@ pub fn parse_and_replace_self_type<T>(
2727
where
2828
T: ToTokens + Parse,
2929
{
30-
let stream = replace_self_type(val.to_token_stream(), replaced_ident, local_assoc_types);
30+
let stream = replace_self_type(
31+
val.to_token_stream(),
32+
replaced_ident.to_token_stream(),
33+
local_assoc_types,
34+
);
3135
syn::parse2(stream)
3236
}
3337

3438
pub fn replace_self_type(
3539
stream: TokenStream,
36-
replaced_ident: &Ident,
40+
replaced_ident: TokenStream,
3741
local_assoc_types: &Vec<Ident>,
3842
) -> TokenStream {
3943
let self_type = format_ident!("Self");
@@ -57,7 +61,7 @@ pub fn replace_self_type(
5761
Some(TokenTree::Ident(assoc_type))
5862
if local_assoc_types.contains(assoc_type) =>
5963
{
60-
ident
64+
ident.to_token_stream()
6165
}
6266
_ => replaced_ident,
6367
}
@@ -68,14 +72,14 @@ pub fn replace_self_type(
6872
_ => replaced_ident,
6973
};
7074

71-
result_stream.push(TokenTree::Ident(replaced));
75+
result_stream.extend(replaced);
7276
} else {
7377
result_stream.push(TokenTree::Ident(ident));
7478
}
7579
}
7680
TokenTree::Group(group) => {
7781
let replaced_stream =
78-
replace_self_type(group.stream(), replaced_ident, local_assoc_types);
82+
replace_self_type(group.stream(), replaced_ident.clone(), local_assoc_types);
7983
let replaced_group = Group::new(group.delimiter(), replaced_stream);
8084

8185
result_stream.push(TokenTree::Group(replaced_group));

crates/cgp-macro-lib/src/derive_component/snake_case.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,8 @@ pub fn to_snake_case_str(val: &str) -> String {
1919
}
2020

2121
pub fn to_snake_case_ident(val: &Ident) -> Ident {
22-
Ident::new(&to_snake_case_str(&val.to_string()), Span::call_site())
22+
Ident::new(
23+
&format!("__{}__", to_snake_case_str(&val.to_string())),
24+
Span::call_site(),
25+
)
2326
}

crates/cgp-macro-lib/src/derive_getter/parse.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ fn parse_receiver(context_ident: &Ident, arg: &FnArg) -> syn::Result<(ReceiverMo
174174
Type::Reference(ty) => {
175175
let receiver = parse2(replace_self_type(
176176
ty.elem.to_token_stream(),
177-
context_ident,
177+
context_ident.to_token_stream(),
178178
&Vec::new(),
179179
))?;
180180
Ok((ReceiverMode::Type(receiver), ty.mutability))
@@ -191,7 +191,7 @@ fn parse_return_type(context_type: &Ident, return_type: &ReturnType) -> syn::Res
191191
match return_type {
192192
ReturnType::Type(_, ty) => parse2(replace_self_type(
193193
ty.to_token_stream(),
194-
context_type,
194+
context_type.to_token_stream(),
195195
&Vec::new(),
196196
)),
197197
_ => Err(Error::new(
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
use proc_macro2::{Group, Span, TokenStream, TokenTree};
2+
use quote::{ToTokens, format_ident, quote};
3+
use syn::parse::discouraged::Speculative;
4+
use syn::parse::{Parse, ParseStream};
5+
use syn::spanned::Spanned;
6+
use syn::token::{Colon, For};
7+
use syn::{Error, FnArg, Ident, ImplItem, ItemImpl, Type, parse2};
8+
9+
use crate::derive_component::{replace_self_receiver, replace_self_type, to_snake_case_ident};
10+
use crate::derive_provider::{
11+
derive_component_name_from_provider_impl, derive_is_provider_for, derive_provider_struct,
12+
};
13+
use crate::parse::SimpleType;
14+
15+
pub fn cgp_impl(attr: TokenStream, body: TokenStream) -> syn::Result<TokenStream> {
16+
let spec: ImplProviderSpec = parse2(attr)?;
17+
let item_impl: ItemImpl = parse2(body)?;
18+
19+
let consumer_trait_path = &item_impl
20+
.trait_
21+
.as_ref()
22+
.ok_or_else(|| Error::new(item_impl.span(), "expect impl trait to contain path"))?
23+
.1;
24+
25+
let consumer_trait_path: SimpleType = parse2(consumer_trait_path.to_token_stream())?;
26+
27+
let provider_impl =
28+
transform_impl_trait(&item_impl, &consumer_trait_path, &spec.provider_type)?;
29+
30+
let component_type = match &spec.component_type {
31+
Some(component_type) => component_type.clone(),
32+
None => derive_component_name_from_provider_impl(&provider_impl)?,
33+
};
34+
35+
let is_provider_for_impl: ItemImpl = derive_is_provider_for(&component_type, &provider_impl)?;
36+
37+
let provider_struct = if spec.new_struct {
38+
Some(derive_provider_struct(&provider_impl)?)
39+
} else {
40+
None
41+
};
42+
43+
Ok(quote! {
44+
#provider_struct
45+
46+
#provider_impl
47+
48+
#is_provider_for_impl
49+
})
50+
}
51+
52+
pub struct ImplProviderSpec {
53+
pub new_struct: bool,
54+
pub provider_type: Type,
55+
pub component_type: Option<Type>,
56+
}
57+
58+
impl Parse for ImplProviderSpec {
59+
fn parse(input: ParseStream) -> syn::Result<Self> {
60+
let new_struct = {
61+
let fork = input.fork();
62+
let new_ident: Option<Ident> = fork.parse().ok();
63+
match new_ident {
64+
Some(new_ident) if new_ident == "new" => {
65+
input.advance_to(&fork);
66+
true
67+
}
68+
_ => false,
69+
}
70+
};
71+
72+
let provider_type = input.parse()?;
73+
74+
let component_type = if let Some(_colon) = input.parse::<Option<Colon>>()? {
75+
let component_type: Type = input.parse()?;
76+
Some(component_type)
77+
} else {
78+
None
79+
};
80+
81+
Ok(ImplProviderSpec {
82+
new_struct,
83+
provider_type,
84+
component_type,
85+
})
86+
}
87+
}
88+
89+
pub fn transform_impl_trait(
90+
item_impl: &ItemImpl,
91+
consumer_trait_path: &SimpleType,
92+
provider_type: &Type,
93+
) -> syn::Result<ItemImpl> {
94+
let context_type = item_impl.self_ty.as_ref();
95+
96+
let context_var = if let Ok(ident) = parse2::<Ident>(context_type.to_token_stream()) {
97+
to_snake_case_ident(&ident)
98+
} else {
99+
Ident::new("__context__", Span::call_site())
100+
};
101+
102+
let local_assoc_types: Vec<Ident> = item_impl
103+
.items
104+
.iter()
105+
.filter_map(|item| {
106+
if let ImplItem::Type(assoc_type) = item {
107+
Some(assoc_type.ident.clone())
108+
} else {
109+
None
110+
}
111+
})
112+
.collect();
113+
114+
let raw_out_impl = replace_self_type(
115+
item_impl.to_token_stream(),
116+
context_type.to_token_stream(),
117+
&local_assoc_types,
118+
);
119+
120+
let mut out_impl: ItemImpl = parse2(raw_out_impl)?;
121+
out_impl.self_ty = Box::new(provider_type.clone());
122+
123+
let mut provider_trait_path: SimpleType = consumer_trait_path.clone();
124+
125+
match &mut provider_trait_path.generics {
126+
Some(generics) => {
127+
generics
128+
.args
129+
.insert(0, parse2(context_type.to_token_stream())?);
130+
}
131+
None => {
132+
provider_trait_path.generics = Some(parse2(quote! { < #context_type > })?);
133+
}
134+
}
135+
136+
out_impl.trait_ = Some((
137+
None,
138+
parse2(provider_trait_path.to_token_stream())?,
139+
For(Span::call_site()),
140+
));
141+
142+
for item in out_impl.items.iter_mut() {
143+
if let ImplItem::Fn(item_fn) = item
144+
&& let Some(arg) = item_fn.sig.inputs.first_mut()
145+
&& let FnArg::Receiver(receiver) = arg
146+
{
147+
*arg = replace_self_receiver(receiver, &context_var, context_type.to_token_stream());
148+
149+
let replaced_block = replace_self_var(item_fn.block.to_token_stream(), &context_var);
150+
item_fn.block = parse2(replaced_block)?;
151+
}
152+
}
153+
154+
Ok(out_impl)
155+
}
156+
157+
fn replace_self_var(stream: TokenStream, replaced_ident: &Ident) -> TokenStream {
158+
let self_ident = format_ident!("self");
159+
160+
let mut result_stream: Vec<TokenTree> = Vec::new();
161+
162+
let token_iter = stream.into_iter();
163+
164+
for tree in token_iter {
165+
match tree {
166+
TokenTree::Ident(ident) => {
167+
if ident == self_ident {
168+
result_stream.push(TokenTree::Ident(replaced_ident.clone()));
169+
} else {
170+
result_stream.push(TokenTree::Ident(ident));
171+
}
172+
}
173+
TokenTree::Group(group) => {
174+
let replaced_stream = replace_self_var(group.stream(), replaced_ident);
175+
let replaced_group = Group::new(group.delimiter(), replaced_stream);
176+
177+
result_stream.push(TokenTree::Group(replaced_group));
178+
}
179+
TokenTree::Punct(punct) => {
180+
result_stream.push(TokenTree::Punct(punct));
181+
}
182+
TokenTree::Literal(lit) => result_stream.push(TokenTree::Literal(lit)),
183+
}
184+
}
185+
186+
result_stream.into_iter().collect()
187+
}

0 commit comments

Comments
 (0)