Skip to content

Commit b935c3c

Browse files
curryyzhengLuthaf
authored andcommitted
feat: allow to add attribute to specific soa type
1 parent 5c76a54 commit b935c3c

File tree

7 files changed

+107
-15
lines changed

7 files changed

+107
-15
lines changed

soa-derive-internal/src/input.rs

Lines changed: 68 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
use std::convert::TryInto;
2+
13
use proc_macro2::{Span, TokenStream};
2-
use syn::{Data, DeriveInput, Ident, Field, Visibility, Meta, MetaNameValue, Lit};
34
use quote::quote;
5+
use syn::{
6+
Data, DeriveInput, Field, Ident, Lit, Meta, MetaList, MetaNameValue, NestedMeta, Visibility,
7+
};
48

59
/// Representing the struct we are deriving
610
pub struct Input {
@@ -11,29 +15,73 @@ pub struct Input {
1115
/// The list of fields in the struct
1216
pub fields: Vec<Field>,
1317
/// The struct overall visibility
14-
pub visibility: Visibility
18+
pub visibility: Visibility,
19+
20+
pub vec_attrs: Vec<Meta>,
21+
pub slice_attrs: Vec<Meta>,
22+
pub ref_attrs: Vec<Meta>,
23+
pub ptr_attrs: Vec<Meta>,
1524
}
1625

1726
impl Input {
1827
pub fn new(input: DeriveInput) -> Input {
1928
let fields = match input.data {
20-
Data::Struct(s) => {
21-
s.fields.iter().cloned().collect::<Vec<_>>()
22-
}
29+
Data::Struct(s) => s.fields.iter().cloned().collect::<Vec<_>>(),
2330
_ => panic!("#[derive(StructOfArray)] only supports structs."),
2431
};
2532

2633
let mut derives: Vec<Ident> = vec![];
34+
let mut vec_attrs = Vec::new();
35+
let mut slice_attrs = Vec::new();
36+
let mut ref_attrs = Vec::new();
37+
let mut ptr_attrs = Vec::new();
2738
for attr in input.attrs {
2839
if let Ok(meta) = attr.parse_meta() {
2940
if meta.path().is_ident("soa_derive") {
3041
match meta {
31-
Meta::NameValue(MetaNameValue{lit: Lit::Str(string), ..}) => {
42+
Meta::NameValue(MetaNameValue {
43+
lit: Lit::Str(string),
44+
..
45+
}) => {
3246
for value in string.value().split(',') {
3347
derives.push(Ident::new(value.trim(), Span::call_site()));
3448
}
3549
}
36-
_ => panic!("expected #[soa_derive = \"Traits, To, Derive\"], got #[{}]", quote!(#meta))
50+
_ => panic!(
51+
"expected #[soa_derive = \"Traits, To, Derive\"], got #[{}]",
52+
quote!(#meta)
53+
),
54+
}
55+
} else if meta.path().is_ident("soa_attr") {
56+
match meta.clone() {
57+
Meta::List(MetaList { nested, .. }) => {
58+
let [soa_type, attr]: [NestedMeta; 2] =
59+
nested.into_iter().collect::<Vec<_>>().try_into().unwrap_or_else(|_| panic!("expected #[soa_derive(\"Types, To, Add, Attribute\", \"Attributes\", got #[{}])]", quote!(#meta)));
60+
let attr = match attr {
61+
NestedMeta::Meta(meta) => meta,
62+
NestedMeta::Lit(_) => {
63+
panic!("expected a attribute, got {}", quote!(attr))
64+
}
65+
};
66+
match soa_type {
67+
NestedMeta::Meta(Meta::Path(path)) => match path.get_ident() {
68+
Some(ident) => match ident.to_string().as_str() {
69+
"Vec" => vec_attrs.push(attr),
70+
"Slice" => slice_attrs.push(attr),
71+
"Ref" => ref_attrs.push(attr),
72+
"ptr" => ptr_attrs.push(attr),
73+
_ => panic!("expected a soa type, got {}", ident),
74+
},
75+
None => {
76+
panic!("expected a soa type, got {}", quote!(#path))
77+
}
78+
},
79+
_ => {
80+
panic!("expected a soa type, got {}", quote!(#soa_type))
81+
}
82+
}
83+
}
84+
_ => panic!("expected #[soa_attr(...), got #[{}]]", quote!(#meta)),
3785
}
3886
}
3987
}
@@ -43,7 +91,11 @@ impl Input {
4391
name: input.ident,
4492
derives: derives,
4593
fields: fields,
46-
visibility: input.vis
94+
visibility: input.vis,
95+
vec_attrs,
96+
slice_attrs,
97+
ref_attrs,
98+
ptr_attrs,
4799
}
48100
}
49101

@@ -64,12 +116,14 @@ impl Input {
64116
if self.derives.is_empty() {
65117
TokenStream::new()
66118
} else {
67-
let derives = &self.derives.iter()
68-
.cloned()
69-
.filter(|name| name != "Clone")
70-
.filter(|name| name != "Deserialize")
71-
.filter(|name| name != "Serialize")
72-
.collect::<Vec<_>>();
119+
let derives = &self
120+
.derives
121+
.iter()
122+
.cloned()
123+
.filter(|name| name != "Clone")
124+
.filter(|name| name != "Deserialize")
125+
.filter(|name| name != "Serialize")
126+
.collect::<Vec<_>>();
73127
quote!(
74128
#[derive(
75129
#(#derives,)*

soa-derive-internal/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ mod refs;
1717
mod slice;
1818
mod vec;
1919

20-
#[proc_macro_derive(StructOfArray, attributes(soa_derive))]
20+
#[proc_macro_derive(StructOfArray, attributes(soa_derive, soa_attr))]
2121
pub fn soa_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
2222
let ast = syn::parse(input).unwrap();
2323
let input = input::Input::new(ast);

soa-derive-internal/src/ptr.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ pub fn derive(input: &Input) -> TokenStream {
77
let name = &input.name;
88
let visibility = &input.visibility;
99
let other_derive = &input.derive_with_exceptions();
10+
let attrs = &input.ptr_attrs;
1011
let vec_name = &input.vec_name();
1112
let ptr_name = &input.ptr_name();
1213
let ptr_mut_name = &input.ptr_mut_name();
@@ -42,6 +43,7 @@ pub fn derive(input: &Input) -> TokenStream {
4243
#[doc = #doc_url]
4344
/// with struct of array layout.
4445
#other_derive
46+
#(#[#attrs])*
4547
#[derive(Copy, Clone)]
4648
#visibility struct #ptr_name {
4749
#(

soa-derive-internal/src/refs.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ pub fn derive(input: &Input) -> TokenStream {
77
let name = &input.name;
88
let visibility = &input.visibility;
99
let other_derive = &input.derive_with_exceptions();
10+
let attrs = &input.ref_attrs;
1011
let vec_name = &input.vec_name();
1112
let ref_name = &input.ref_name();
1213
let ref_mut_name = &input.ref_mut_name();
@@ -38,6 +39,7 @@ pub fn derive(input: &Input) -> TokenStream {
3839
#[doc = #doc_url]
3940
/// with struct of array layout.
4041
#other_derive
42+
#(#[#attrs])*
4143
#[derive(Copy, Clone)]
4244
#visibility struct #ref_name<'a> {
4345
#(
@@ -50,6 +52,7 @@ pub fn derive(input: &Input) -> TokenStream {
5052
#[doc = #doc_url]
5153
/// with struct of array layout.
5254
#other_derive
55+
#(#[#attrs])*
5356
#visibility struct #ref_mut_name<'a> {
5457
#(
5558
#[doc = #fields_mut_doc]

soa-derive-internal/src/slice.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ pub fn derive(input: &Input) -> TokenStream {
99
let other_derive = &input.derive_with_exceptions();
1010
let visibility = &input.visibility;
1111
let slice_name = &input.slice_name();
12+
let attrs = &input.slice_attrs;
1213
let vec_name = &input.vec_name();
1314
let ref_name = &input.ref_name();
1415
let ptr_name = &input.ptr_name();
@@ -48,6 +49,7 @@ pub fn derive(input: &Input) -> TokenStream {
4849
#[allow(dead_code)]
4950
#[derive(Copy, Clone)]
5051
#other_derive
52+
#(#[#attrs])*
5153
#visibility struct #slice_name<'a> {
5254
#(
5355
#[doc = #fields_doc]
@@ -239,6 +241,7 @@ pub fn derive_mut(input: &Input) -> TokenStream {
239241
let slice_name = &input.slice_name();
240242
let slice_mut_name = &input.slice_mut_name();
241243
let vec_name = &input.vec_name();
244+
let attrs = &input.slice_attrs;
242245
let ref_mut_name = &input.ref_mut_name();
243246
let ptr_name = &input.ptr_name();
244247
let ptr_mut_name = &input.ptr_mut_name();
@@ -280,6 +283,7 @@ pub fn derive_mut(input: &Input) -> TokenStream {
280283
/// .
281284
#[allow(dead_code)]
282285
#other_derive
286+
#(#[#attrs])*
283287
#visibility struct #slice_mut_name<'a> {
284288
#(
285289
#[doc = #fields_doc]

soa-derive-internal/src/vec.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ pub fn derive(input: &Input) -> TokenStream {
99
let name = &input.name;
1010
let vec_name_str = format!("Vec<{}>", name);
1111
let other_derive = &input.derive();
12+
let attrs = &input.vec_attrs;
1213
let visibility = &input.visibility;
1314
let vec_name = &input.vec_name();
1415
let slice_name = &input.slice_name();
@@ -42,6 +43,7 @@ pub fn derive(input: &Input) -> TokenStream {
4243
/// ` with Struct of Array (SoA) layout
4344
#[allow(dead_code)]
4445
#other_derive
46+
#(#[#attrs])*
4547
#visibility struct #vec_name {
4648
#(
4749
#[doc = #fields_doc]

tests/soa_attr.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
use soa_derive::StructOfArray;
2+
3+
#[derive(Debug, Clone, PartialEq, StructOfArray)]
4+
#[soa_attr(Vec, cfg_attr(test, derive(PartialEq, Debug)))]
5+
pub struct Particle {
6+
pub name: String,
7+
pub mass: f64,
8+
}
9+
10+
impl Particle {
11+
pub fn new(name: String, mass: f64) -> Self {
12+
Particle { name, mass }
13+
}
14+
}
15+
16+
#[test]
17+
fn eq_test() {
18+
let particles0 = ParticleVec {
19+
name: vec![String::from("foo"), String::from("bar")],
20+
mass: vec![1.0, 2.0],
21+
};
22+
let particles1 = ParticleVec {
23+
name: vec![String::from("foo"), String::from("bar")],
24+
mass: vec![1.0, 2.0],
25+
};
26+
assert_eq!(particles0, particles1);
27+
}

0 commit comments

Comments
 (0)