Skip to content

Commit af2caf7

Browse files
committed
support generics
Signed-off-by: Teo Koon Peng <[email protected]>
1 parent b612a05 commit af2caf7

File tree

3 files changed

+75
-47
lines changed

3 files changed

+75
-47
lines changed

macros/src/buffer.rs

Lines changed: 59 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,35 @@
11
use proc_macro::TokenStream;
22
use quote::{format_ident, quote};
3-
use syn::{DeriveInput, Ident, Type};
3+
use syn::{parse_quote, Ident, ItemStruct, Type};
44

55
use crate::Result;
66

7-
pub(crate) fn impl_joined_value(ast: DeriveInput) -> Result<TokenStream> {
8-
let struct_ident = ast.ident;
9-
let (field_ident, field_type): (Vec<Ident>, Vec<Type>) = match ast.data {
10-
syn::Data::Struct(data) => get_fields_map(data.fields)?.into_iter().unzip(),
11-
_ => return Err("expected struct".to_string()),
12-
};
13-
let map_key: Vec<String> = field_ident.iter().map(|v| v.to_string()).collect();
7+
pub(crate) fn impl_joined_value(input_struct: &ItemStruct) -> Result<TokenStream> {
8+
let struct_ident = &input_struct.ident;
9+
let (impl_generics, ty_generics, where_clause) = input_struct.generics.split_for_impl();
10+
let (field_ident, field_type): (Vec<_>, Vec<_>) =
11+
get_fields_map(&input_struct.fields)?.into_iter().unzip();
1412
let struct_buffer_ident = format_ident!("__bevy_impulse_{}_Buffers", struct_ident);
1513

16-
let impl_buffer_map_layout =
17-
buffer_map_layout(&struct_buffer_ident, &field_ident, &field_type, &map_key);
18-
let impl_joined = joined(&struct_buffer_ident, &struct_ident, &field_ident);
19-
20-
let gen = quote! {
21-
impl ::bevy_impulse::JoinedValue for #struct_ident {
22-
type Buffers = #struct_buffer_ident;
23-
}
24-
14+
let buffer_struct: ItemStruct = parse_quote! {
2515
#[derive(Clone)]
2616
#[allow(non_camel_case_types)]
27-
struct #struct_buffer_ident {
17+
struct #struct_buffer_ident #impl_generics #where_clause {
2818
#(
2919
#field_ident: ::bevy_impulse::Buffer<#field_type>,
3020
)*
3121
}
22+
};
23+
24+
let impl_buffer_map_layout = impl_buffer_map_layout(&buffer_struct, &input_struct)?;
25+
let impl_joined = impl_joined(&buffer_struct, &input_struct)?;
26+
27+
let gen = quote! {
28+
impl #impl_generics ::bevy_impulse::JoinedValue for #struct_ident #ty_generics #where_clause {
29+
type Buffers = #struct_buffer_ident #ty_generics;
30+
}
31+
32+
#buffer_struct
3233

3334
#impl_buffer_map_layout
3435

@@ -38,28 +39,38 @@ pub(crate) fn impl_joined_value(ast: DeriveInput) -> Result<TokenStream> {
3839
Ok(gen.into())
3940
}
4041

41-
fn get_fields_map(fields: syn::Fields) -> Result<Vec<(Ident, Type)>> {
42+
fn get_fields_map(fields: &syn::Fields) -> Result<Vec<(&Ident, &Type)>> {
4243
match fields {
4344
syn::Fields::Named(data) => {
4445
let mut idents_types = Vec::with_capacity(data.named.len());
45-
for field in data.named {
46-
let ident = field.ident.ok_or("expected named fields".to_string())?;
47-
idents_types.push((ident, field.ty));
46+
for field in &data.named {
47+
let ident = field
48+
.ident
49+
.as_ref()
50+
.ok_or("expected named fields".to_string())?;
51+
idents_types.push((ident, &field.ty));
4852
}
4953
Ok(idents_types)
5054
}
5155
_ => return Err("expected named fields".to_string()),
5256
}
5357
}
5458

55-
fn buffer_map_layout(
56-
struct_ident: &Ident,
57-
field_ident: &Vec<Ident>,
58-
field_type: &Vec<Type>,
59-
map_key: &Vec<String>,
60-
) -> proc_macro2::TokenStream {
61-
quote! {
62-
impl ::bevy_impulse::BufferMapLayout for #struct_ident {
59+
/// Params:
60+
/// buffer_struct: The struct to implement `BufferMapLayout`.
61+
/// item_struct: The struct which `buffer_struct` is derived from.
62+
fn impl_buffer_map_layout(
63+
buffer_struct: &ItemStruct,
64+
item_struct: &ItemStruct,
65+
) -> Result<proc_macro2::TokenStream> {
66+
let struct_ident = &buffer_struct.ident;
67+
let (impl_generics, ty_generics, where_clause) = buffer_struct.generics.split_for_impl();
68+
let (field_ident, field_type): (Vec<_>, Vec<_>) =
69+
get_fields_map(&item_struct.fields)?.into_iter().unzip();
70+
let map_key: Vec<String> = field_ident.iter().map(|v| v.to_string()).collect();
71+
72+
Ok(quote! {
73+
impl #impl_generics ::bevy_impulse::BufferMapLayout for #struct_ident #ty_generics #where_clause {
6374
fn buffer_list(&self) -> ::smallvec::SmallVec<[AnyBuffer; 8]> {
6475
use smallvec::smallvec;
6576
smallvec![#(
@@ -83,16 +94,25 @@ fn buffer_map_layout(
8394
}
8495
}
8596
}
97+
.into())
8698
}
8799

88-
fn joined(
89-
struct_ident: &Ident,
90-
item_struct_ident: &Ident,
91-
field_ident: &Vec<Ident>,
92-
) -> proc_macro2::TokenStream {
93-
quote! {
94-
impl ::bevy_impulse::Joined for #struct_ident {
95-
type Item = #item_struct_ident;
100+
/// Params:
101+
/// joined_struct: The struct to implement `Joined`.
102+
/// item_struct: The associated `Item` type to use for the `Joined` implementation.
103+
fn impl_joined(
104+
joined_struct: &ItemStruct,
105+
item_struct: &ItemStruct,
106+
) -> Result<proc_macro2::TokenStream> {
107+
let struct_ident = &joined_struct.ident;
108+
let item_struct_ident = &item_struct.ident;
109+
let (impl_generics, ty_generics, where_clause) = item_struct.generics.split_for_impl();
110+
let (field_ident, _): (Vec<_>, Vec<_>) =
111+
get_fields_map(&item_struct.fields)?.into_iter().unzip();
112+
113+
Ok(quote! {
114+
impl #impl_generics ::bevy_impulse::Joined for #struct_ident #ty_generics #where_clause {
115+
type Item = #item_struct_ident #ty_generics;
96116

97117
fn pull(&self, session: ::bevy_ecs::prelude::Entity, world: &mut ::bevy_ecs::prelude::World) -> Result<Self::Item, ::bevy_impulse::OperationError> {
98118
#(
@@ -104,5 +124,5 @@ fn joined(
104124
)*})
105125
}
106126
}
107-
}
127+
}.into())
108128
}

macros/src/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use buffer::impl_joined_value;
2020

2121
use proc_macro::TokenStream;
2222
use quote::quote;
23-
use syn::{parse_macro_input, DeriveInput};
23+
use syn::{parse_macro_input, DeriveInput, ItemStruct};
2424

2525
#[proc_macro_derive(Stream)]
2626
pub fn simple_stream_macro(item: TokenStream) -> TokenStream {
@@ -67,8 +67,8 @@ type Result<T> = std::result::Result<T, String>;
6767

6868
#[proc_macro_derive(JoinedValue)]
6969
pub fn derive_joined_value(input: TokenStream) -> TokenStream {
70-
let input = parse_macro_input!(input as DeriveInput);
71-
match impl_joined_value(input) {
70+
let input = parse_macro_input!(input as ItemStruct);
71+
match impl_joined_value(&input) {
7272
Ok(tokens) => tokens,
7373
Err(msg) => quote! {
7474
compile_error!(#msg);

src/buffer/buffer_map.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -316,10 +316,11 @@ mod tests {
316316
use crate::{prelude::*, testing::*, BufferMap};
317317

318318
#[derive(Clone, JoinedValue)]
319-
struct TestJoinedValue {
319+
struct TestJoinedValue<T: Send + Sync + 'static + Clone> {
320320
integer: i64,
321321
float: f64,
322322
string: String,
323+
generic: T,
323324
}
324325

325326
#[test]
@@ -330,32 +331,36 @@ mod tests {
330331
let buffer_i64 = builder.create_buffer(BufferSettings::default());
331332
let buffer_f64 = builder.create_buffer(BufferSettings::default());
332333
let buffer_string = builder.create_buffer(BufferSettings::default());
334+
let buffer_generic = builder.create_buffer(BufferSettings::default());
333335

334336
let mut buffers = BufferMap::default();
335337
buffers.insert("integer", buffer_i64);
336338
buffers.insert("float", buffer_f64);
337339
buffers.insert("string", buffer_string);
340+
buffers.insert("generic", buffer_generic);
338341

339342
scope.input.chain(builder).fork_unzip((
340343
|chain: Chain<_>| chain.connect(buffer_i64.input_slot()),
341344
|chain: Chain<_>| chain.connect(buffer_f64.input_slot()),
342345
|chain: Chain<_>| chain.connect(buffer_string.input_slot()),
346+
|chain: Chain<_>| chain.connect(buffer_generic.input_slot()),
343347
));
344348

345349
builder.try_join(&buffers).unwrap().connect(scope.terminate);
346350
});
347351

348352
let mut promise = context.command(|commands| {
349353
commands
350-
.request((5_i64, 3.14_f64, "hello".to_string()), workflow)
354+
.request((5_i64, 3.14_f64, "hello".to_string(), "world"), workflow)
351355
.take_response()
352356
});
353357

354358
context.run_with_conditions(&mut promise, Duration::from_secs(2));
355-
let value: TestJoinedValue = promise.take().available().unwrap();
359+
let value: TestJoinedValue<&'static str> = promise.take().available().unwrap();
356360
assert_eq!(value.integer, 5);
357361
assert_eq!(value.float, 3.14);
358362
assert_eq!(value.string, "hello");
363+
assert_eq!(value.generic, "world");
359364
assert!(context.no_unhandled_errors());
360365
}
361366

@@ -368,28 +373,31 @@ mod tests {
368373
integer: builder.create_buffer(BufferSettings::default()),
369374
float: builder.create_buffer(BufferSettings::default()),
370375
string: builder.create_buffer(BufferSettings::default()),
376+
generic: builder.create_buffer(BufferSettings::default()),
371377
};
372378

373379
scope.input.chain(builder).fork_unzip((
374380
|chain: Chain<_>| chain.connect(buffers.integer.input_slot()),
375381
|chain: Chain<_>| chain.connect(buffers.float.input_slot()),
376382
|chain: Chain<_>| chain.connect(buffers.string.input_slot()),
383+
|chain: Chain<_>| chain.connect(buffers.generic.input_slot()),
377384
));
378385

379386
builder.join(buffers).connect(scope.terminate);
380387
});
381388

382389
let mut promise = context.command(|commands| {
383390
commands
384-
.request((5_i64, 3.14_f64, "hello".to_string()), workflow)
391+
.request((5_i64, 3.14_f64, "hello".to_string(), "world"), workflow)
385392
.take_response()
386393
});
387394

388395
context.run_with_conditions(&mut promise, Duration::from_secs(2));
389-
let value: TestJoinedValue = promise.take().available().unwrap();
396+
let value: TestJoinedValue<&'static str> = promise.take().available().unwrap();
390397
assert_eq!(value.integer, 5);
391398
assert_eq!(value.float, 3.14);
392399
assert_eq!(value.string, "hello");
400+
assert_eq!(value.generic, "world");
393401
assert!(context.no_unhandled_errors());
394402
}
395403
}

0 commit comments

Comments
 (0)