Skip to content

Commit 9a70a7d

Browse files
committed
Add a proper type to manage spec versions
Signed-off-by: Christian Poveda <[email protected]>
1 parent 642a7e3 commit 9a70a7d

File tree

1 file changed

+63
-33
lines changed
  • cyclonedx-bom-macros/src

1 file changed

+63
-33
lines changed

cyclonedx-bom-macros/src/lib.rs

Lines changed: 63 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::{error::Error, str::FromStr};
2+
13
use proc_macro::TokenStream;
24
use proc_macro2::Span;
35
use quote::quote;
@@ -9,35 +11,67 @@ use syn::{
911
Expr, Item,
1012
};
1113

12-
struct VersionFilter {
13-
version: String,
14+
#[derive(PartialEq, Eq)]
15+
struct Version {
16+
major: usize,
17+
minor: usize,
1418
}
1519

16-
fn extract_version(attrs: &mut Vec<syn::Attribute>) -> Option<String> {
17-
let mut version = None;
20+
impl FromStr for Version {
21+
type Err = Box<dyn Error>;
1822

19-
attrs.retain(|attr| {
20-
let path = attr.path();
23+
fn from_str(s: &str) -> Result<Self, Self::Err> {
24+
let (major_str, minor_str) = s
25+
.split_once('.')
26+
.ok_or_else(|| Self::Err::from("versions must have a `.`".to_owned()))?;
2127

22-
if path.is_ident("versioned") {
23-
version = Some(
24-
attr.parse_args::<syn::LitStr>()
25-
.expect("expected a string literal with a version number")
26-
.value(),
27-
);
28+
Ok(Self {
29+
major: major_str.parse()?,
30+
minor: minor_str.parse()?,
31+
})
32+
}
33+
}
2834

29-
false
30-
} else {
31-
true
32-
}
33-
});
35+
impl Version {
36+
fn extract_from_attrs(attrs: &mut Vec<syn::Attribute>) -> Option<Self> {
37+
let mut version = None;
38+
39+
attrs.retain(|attr| {
40+
let path = attr.path();
41+
42+
if path.is_ident("versioned") {
43+
version = Some(
44+
attr.parse_args::<syn::LitStr>()
45+
.expect("expected a string literal")
46+
.value()
47+
.parse()
48+
.expect("cannot parse version"),
49+
);
50+
51+
false
52+
} else {
53+
true
54+
}
55+
});
56+
57+
version
58+
}
3459

35-
version
60+
fn as_ident(&self) -> syn::Ident {
61+
syn::Ident::new(
62+
&format!("v{}_{}", self.major, self.minor),
63+
Span::call_site(),
64+
)
65+
}
66+
}
67+
68+
struct VersionFilter {
69+
version: Version,
3670
}
3771

3872
impl VersionFilter {
39-
fn matches(&self, found_version: &str) -> bool {
40-
self.version == found_version
73+
fn matches(&self, found_version: &Version) -> bool {
74+
&self.version == found_version
4175
}
4276

4377
fn filter_fields(
@@ -47,7 +81,7 @@ impl VersionFilter {
4781
fields
4882
.into_pairs()
4983
.filter_map(
50-
|mut pair| match extract_version(&mut pair.value_mut().attrs) {
84+
|mut pair| match Version::extract_from_attrs(&mut pair.value_mut().attrs) {
5185
Some(version) => self.matches(&version).then_some(pair),
5286
None => Some(pair),
5387
},
@@ -71,7 +105,7 @@ impl Fold for VersionFilter {
71105
match stmt {
72106
syn::Stmt::Local(syn::Local { ref mut attrs, .. })
73107
| syn::Stmt::Macro(syn::StmtMacro { ref mut attrs, .. }) => {
74-
if let Some(version) = extract_version(attrs) {
108+
if let Some(version) = Version::extract_from_attrs(attrs) {
75109
if !self.matches(&version) {
76110
stmt = parse_quote!({};);
77111
}
@@ -123,7 +157,7 @@ impl Fold for VersionFilter {
123157
| Expr::Unsafe(syn::ExprUnsafe { ref mut attrs, .. })
124158
| Expr::While(syn::ExprWhile { ref mut attrs, .. })
125159
| Expr::Yield(syn::ExprYield { ref mut attrs, .. }) => {
126-
if let Some(version) = extract_version(attrs) {
160+
if let Some(version) = Version::extract_from_attrs(attrs) {
127161
if !self.matches(&version) {
128162
expr = parse_quote!({});
129163
}
@@ -140,7 +174,7 @@ impl Fold for VersionFilter {
140174
.fields
141175
.into_pairs()
142176
.filter_map(
143-
|mut pair| match extract_version(&mut pair.value_mut().attrs) {
177+
|mut pair| match Version::extract_from_attrs(&mut pair.value_mut().attrs) {
144178
Some(version) => self.matches(&version).then_some(pair),
145179
None => Some(pair),
146180
},
@@ -152,7 +186,7 @@ impl Fold for VersionFilter {
152186

153187
fn fold_expr_match(&mut self, mut expr: syn::ExprMatch) -> syn::ExprMatch {
154188
expr.arms
155-
.retain_mut(|arm| match extract_version(&mut arm.attrs) {
189+
.retain_mut(|arm| match Version::extract_from_attrs(&mut arm.attrs) {
156190
Some(version) => self.matches(&version),
157191
None => true,
158192
});
@@ -177,7 +211,7 @@ impl Fold for VersionFilter {
177211
| Item::Type(syn::ItemType { ref mut attrs, .. })
178212
| Item::Union(syn::ItemUnion { ref mut attrs, .. })
179213
| Item::Use(syn::ItemUse { ref mut attrs, .. }) => {
180-
if let Some(version) = extract_version(attrs) {
214+
if let Some(version) = Version::extract_from_attrs(attrs) {
181215
if !self.matches(&version) {
182216
item = parse_quote!(
183217
use {};
@@ -199,21 +233,17 @@ pub fn versioned(input: TokenStream, annotated_item: TokenStream) -> TokenStream
199233

200234
// This parses the versions passed to the attribute, e.g. the `"1.3"`
201235
// and `"1.4"`in `#[versioned("1.3", "1.4")]
202-
// FIXME: we should do extra validations for the version numbers themselves.
203-
let versions: Vec<String> =
236+
let versions: Vec<Version> =
204237
parse_macro_input!(input with Punctuated::<syn::LitStr, Comma>::parse_terminated)
205238
.into_iter()
206-
.map(|s| s.value())
239+
.map(|s| s.value().parse().expect("cannot parse version"))
207240
.collect();
208241

209242
let mut tokens = proc_macro2::TokenStream::new();
210243

211244
for version in versions {
212245
let mod_vis = &module.vis;
213-
let mod_ident = syn::Ident::new(
214-
&format!("v{}", version.replace('.', "_")),
215-
Span::call_site(),
216-
);
246+
let mod_ident = version.as_ident();
217247

218248
let (_, items) = module.content.clone().unwrap();
219249

0 commit comments

Comments
 (0)