Skip to content

Commit a7b8632

Browse files
committed
Refactor: separate parsing from validation
1 parent 377b2d7 commit a7b8632

File tree

3 files changed

+99
-59
lines changed

3 files changed

+99
-59
lines changed

enumflags_derive/src/lib.rs

Lines changed: 95 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,19 @@ extern crate proc_macro;
44
extern crate quote;
55

66
use syn::{Data, Ident, DeriveInput, DataEnum, spanned::Spanned};
7-
use proc_macro2::TokenStream;
8-
use proc_macro2::Span;
7+
use proc_macro2::{TokenStream, Span};
8+
9+
struct Flag {
10+
name: Ident,
11+
span: Span,
12+
value: FlagValue,
13+
}
14+
15+
enum FlagValue {
16+
Literal(u64),
17+
Deferred,
18+
Inferred,
19+
}
920

1021
#[proc_macro_attribute]
1122
pub fn bitflags_internal(
@@ -35,6 +46,9 @@ pub fn bitflags_internal(
3546
}
3647

3748
/// Try to evaluate the expression given.
49+
///
50+
/// Returns `Err` when the expression is erroneous, and `Ok(None)` when
51+
/// information outside of the syntax, such as a `const`.
3852
fn fold_expr(expr: &syn::Expr) -> Result<Option<u64>, syn::Error> {
3953
/// Recurse, but bubble-up both kinds of errors.
4054
/// (I miss my monad transformers)
@@ -71,6 +85,37 @@ fn fold_expr(expr: &syn::Expr) -> Result<Option<u64>, syn::Error> {
7185
}
7286
}
7387

88+
fn collect_flags<'a>(variants: impl Iterator<Item=&'a syn::Variant>)
89+
-> Result<Vec<Flag>, syn::Error>
90+
{
91+
variants
92+
.map(|variant| {
93+
// MSRV: Would this be cleaner with `matches!`?
94+
match variant.fields {
95+
syn::Fields::Unit => (),
96+
_ => return Err(syn::Error::new_spanned(&variant.fields,
97+
"Bitflag variants cannot contain additional data")),
98+
}
99+
100+
let value = if let Some(ref expr) = variant.discriminant {
101+
if let Some(n) = fold_expr(&expr.1)? {
102+
FlagValue::Literal(n)
103+
} else {
104+
FlagValue::Deferred
105+
}
106+
} else {
107+
FlagValue::Inferred
108+
};
109+
110+
Ok(Flag {
111+
name: variant.ident.clone(),
112+
span: variant.span(),
113+
value,
114+
})
115+
})
116+
.collect()
117+
}
118+
74119
/// Given a list of attributes, find the `repr`, if any, and return the integer
75120
/// type specified.
76121
fn extract_repr(attrs: &[syn::Attribute])
@@ -101,81 +146,76 @@ fn extract_repr(attrs: &[syn::Attribute])
101146
}
102147

103148
/// Returns deferred checks
104-
fn verify_flag_values<'a>(
149+
fn check_flag(
105150
type_name: &Ident,
106-
variants: impl Iterator<Item=&'a syn::Variant>
107-
) -> Result<TokenStream, syn::Error> {
108-
let mut deferred_checks: Vec<TokenStream> = vec![];
109-
for variant in variants {
110-
// I'd use matches! if not for MSRV...
111-
match variant.fields {
112-
syn::Fields::Unit => (),
113-
_ => return Err(syn::Error::new_spanned(&variant.fields,
114-
"Bitflag variants cannot contain additional data")),
151+
flag: &Flag,
152+
) -> Result<Option<TokenStream>, syn::Error> {
153+
use FlagValue::*;
154+
match flag.value {
155+
Literal(n) => {
156+
if !n.is_power_of_two() {
157+
Err(syn::Error::new(flag.span,
158+
"Flags must have exactly one set bit"))
159+
} else {
160+
Ok(None)
161+
}
115162
}
163+
Inferred => {
164+
Err(syn::Error::new(flag.span,
165+
"Please add an explicit discriminant"))
166+
}
167+
Deferred => {
168+
let variant_name = &flag.name;
169+
// MSRV: Use an unnamed constant (`const _: ...`).
170+
let assertion_name = syn::Ident::new(
171+
&format!("__enumflags_assertion_{}_{}",
172+
type_name, flag.name),
173+
Span::call_site()); // call_site because def_site is unstable
116174

117-
let discr = variant.discriminant.as_ref()
118-
.ok_or_else(|| syn::Error::new_spanned(variant,
119-
"Please add an explicit discriminant"))?;
120-
match fold_expr(&discr.1)? {
121-
Some(flag) => {
122-
if !flag.is_power_of_two() {
123-
return Err(syn::Error::new_spanned(&discr.1,
124-
"Flags must have exactly one set bit"));
125-
}
126-
}
127-
None => {
128-
let variant_name = &variant.ident;
129-
// TODO: Remove this madness when Debian ships a new compiler.
130-
let assertion_name = syn::Ident::new(
131-
&format!("__enumflags_assertion_{}_{}",
132-
type_name, variant_name),
133-
Span::call_site()); // call_site because def_site is unstable
134-
135-
deferred_checks.push(quote_spanned!(variant.span() =>
136-
#[doc(hidden)]
137-
const #assertion_name:
138-
<<[(); (
139-
(#type_name::#variant_name as u64).wrapping_sub(1) &
140-
(#type_name::#variant_name as u64) == 0 &&
141-
(#type_name::#variant_name as u64) != 0
142-
) as usize] as enumflags2::_internal::AssertionHelper>
143-
::Status as enumflags2::_internal::ExactlyOneBitSet>::X = ();
144-
));
145-
}
175+
Ok(Some(quote_spanned!(flag.span =>
176+
#[doc(hidden)]
177+
const #assertion_name:
178+
<<[(); (
179+
(#type_name::#variant_name as u64).wrapping_sub(1) &
180+
(#type_name::#variant_name as u64) == 0 &&
181+
(#type_name::#variant_name as u64) != 0
182+
) as usize] as enumflags2::_internal::AssertionHelper>
183+
::Status as enumflags2::_internal::ExactlyOneBitSet>::X
184+
= ();
185+
)))
146186
}
147187
}
148-
149-
Ok(quote!(
150-
#(#deferred_checks)*
151-
))
152188
}
153189

154190
fn gen_enumflags(ident: &Ident, item: &DeriveInput, data: &DataEnum)
155191
-> Result<TokenStream, syn::Error>
156192
{
157193
let span = Span::call_site();
158194
// for quote! interpolation
159-
let variants = data.variants.iter().map(|v| &v.ident);
160-
let variants_len = data.variants.len();
161-
let names = std::iter::repeat(&ident);
195+
let variant_names = data.variants.iter().map(|v| &v.ident);
196+
let variant_count = data.variants.len();
197+
198+
let repeated_name = std::iter::repeat(&ident);
162199

163-
let deferred = verify_flag_values(ident, data.variants.iter())?;
200+
let variants = collect_flags(data.variants.iter())?;
201+
let deferred = variants.iter()
202+
.flat_map(|variant| check_flag(ident, variant).transpose())
203+
.collect::<Result<Vec<_>, _>>()?;
164204

165205
let ty = extract_repr(&item.attrs)?
166206
.ok_or_else(|| syn::Error::new_spanned(&ident,
167207
"repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield."))?;
168208
let std_path = quote_spanned!(span => ::enumflags2::_internal::core);
169-
let all = if variants_len == 0 {
209+
let all = if variant_count == 0 {
170210
quote!(0)
171211
} else {
172-
let names = names.clone();
173-
let variants = variants.clone();
174-
quote!(#(#names::#variants as #ty)|*)
212+
let repeated_name = repeated_name.clone();
213+
let variant_names = variant_names.clone();
214+
quote!(#(#repeated_name::#variant_names as #ty)|*)
175215
};
176216

177217
Ok(quote_spanned! {
178-
span => #deferred
218+
span => #(#deferred)*
179219
impl #std_path::ops::Not for #ident {
180220
type Output = ::enumflags2::BitFlags<#ident>;
181221
fn not(self) -> Self::Output {
@@ -221,7 +261,7 @@ fn gen_enumflags(ident: &Ident, item: &DeriveInput, data: &DataEnum)
221261
}
222262

223263
fn flag_list() -> &'static [Self] {
224-
const VARIANTS: [#ident; #variants_len] = [#(#names :: #variants),*];
264+
const VARIANTS: [#ident; #variant_count] = [#(#repeated_name :: #variant_names),*];
225265
&VARIANTS
226266
}
227267

test_suite/ui/multiple_bits.stderr

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
error: Flags must have exactly one set bit
2-
--> $DIR/multiple_bits.rs:5:20
2+
--> $DIR/multiple_bits.rs:5:5
33
|
44
5 | MultipleBits = 6,
5-
| ^
5+
| ^^^^^^^^^^^^
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
error: Flags must have exactly one set bit
2-
--> $DIR/zero_disciminant.rs:4:12
2+
--> $DIR/zero_disciminant.rs:4:5
33
|
44
4 | Zero = 0,
5-
| ^
5+
| ^^^^

0 commit comments

Comments
 (0)