Skip to content

Commit 03a1d40

Browse files
committed
Merge branch 'attribute-macro' into master
2 parents 3f6e5ef + 2805cc3 commit 03a1d40

35 files changed

+413
-261
lines changed

enumflags_derive/src/lib.rs

Lines changed: 157 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -3,82 +3,110 @@ extern crate proc_macro;
33
#[macro_use]
44
extern crate quote;
55

6+
use std::convert::TryFrom;
67
use syn::{Data, Ident, DeriveInput, DataEnum, spanned::Spanned};
7-
use proc_macro2::TokenStream;
8-
use proc_macro2::Span;
8+
use proc_macro2::{TokenStream, Span};
99

10-
/// Shorthand for a quoted `compile_error!`.
11-
macro_rules! error {
12-
($span:expr => $($x:tt)*) => {
13-
quote_spanned!($span => compile_error!($($x)*);)
14-
};
15-
($($x:tt)*) => {
16-
quote!(compile_error!($($x)*);)
17-
};
10+
struct Flag {
11+
name: Ident,
12+
span: Span,
13+
value: FlagValue,
1814
}
1915

20-
#[proc_macro_derive(BitFlags_internal)]
21-
pub fn derive_enum_flags(input: proc_macro::TokenStream)
22-
-> proc_macro::TokenStream
23-
{
16+
enum FlagValue {
17+
Literal(u128),
18+
Deferred,
19+
Inferred,
20+
}
21+
22+
#[proc_macro_attribute]
23+
pub fn bitflags_internal(
24+
_attr: proc_macro::TokenStream,
25+
input: proc_macro::TokenStream,
26+
) -> proc_macro::TokenStream {
2427
let ast: DeriveInput = syn::parse(input).unwrap();
2528

26-
match ast.data {
29+
let output = match ast.data {
2730
Data::Enum(ref data) => {
2831
gen_enumflags(&ast.ident, &ast, data)
29-
.unwrap_or_else(|err| err.to_compile_error())
30-
.into()
3132
}
3233
Data::Struct(ref data) => {
33-
error!(data.struct_token.span => "BitFlags can only be derived on enums").into()
34+
Err(syn::Error::new_spanned(data.struct_token, "#[bitflags] requires an enum"))
3435
}
3536
Data::Union(ref data) => {
36-
error!(data.union_token.span => "BitFlags can only be derived on enums").into()
37+
Err(syn::Error::new_spanned(data.union_token, "#[bitflags] requires an enum"))
3738
}
38-
}
39-
}
39+
};
4040

41-
/// Try to evaluate the expression given.
42-
fn fold_expr(expr: &syn::Expr) -> Result<Option<u64>, syn::Error> {
43-
/// Recurse, but bubble-up both kinds of errors.
44-
/// (I miss my monad transformers)
45-
macro_rules! fold_expr {
46-
($($x:tt)*) => {
47-
match fold_expr($($x)*)? {
48-
Some(x) => x,
49-
None => return Ok(None),
50-
}
41+
output.unwrap_or_else(|err| {
42+
let error = err.to_compile_error();
43+
quote! {
44+
#ast
45+
#error
5146
}
52-
}
47+
}).into()
48+
}
5349

50+
/// Try to evaluate the expression given.
51+
fn fold_expr(expr: &syn::Expr) -> Option<u128> {
5452
use syn::Expr;
5553
match expr {
5654
Expr::Lit(ref expr_lit) => {
5755
match expr_lit.lit {
58-
syn::Lit::Int(ref lit_int) => {
59-
Ok(Some(lit_int.base10_parse()
60-
.map_err(|_| syn::Error::new_spanned(lit_int,
61-
"Integer literal out of range"))?))
62-
}
63-
_ => Ok(None),
56+
syn::Lit::Int(ref lit_int) => lit_int.base10_parse().ok(),
57+
_ => None,
6458
}
6559
},
6660
Expr::Binary(ref expr_binary) => {
67-
let l = fold_expr!(&expr_binary.left);
68-
let r = fold_expr!(&expr_binary.right);
61+
let l = fold_expr(&expr_binary.left)?;
62+
let r = fold_expr(&expr_binary.right)?;
6963
match &expr_binary.op {
70-
syn::BinOp::Shl(_) => Ok(Some(l << r)),
71-
_ => Ok(None),
64+
syn::BinOp::Shl(_) => {
65+
u32::try_from(r).ok()
66+
.and_then(|r| l.checked_shl(r))
67+
}
68+
_ => None,
7269
}
7370
}
74-
_ => Ok(None),
71+
_ => None,
7572
}
7673
}
7774

75+
fn collect_flags<'a>(variants: impl Iterator<Item=&'a syn::Variant>)
76+
-> Result<Vec<Flag>, syn::Error>
77+
{
78+
variants
79+
.map(|variant| {
80+
// MSRV: Would this be cleaner with `matches!`?
81+
match variant.fields {
82+
syn::Fields::Unit => (),
83+
_ => return Err(syn::Error::new_spanned(&variant.fields,
84+
"Bitflag variants cannot contain additional data")),
85+
}
86+
87+
let value = if let Some(ref expr) = variant.discriminant {
88+
if let Some(n) = fold_expr(&expr.1) {
89+
FlagValue::Literal(n)
90+
} else {
91+
FlagValue::Deferred
92+
}
93+
} else {
94+
FlagValue::Inferred
95+
};
96+
97+
Ok(Flag {
98+
name: variant.ident.clone(),
99+
span: variant.span(),
100+
value,
101+
})
102+
})
103+
.collect()
104+
}
105+
78106
/// Given a list of attributes, find the `repr`, if any, and return the integer
79107
/// type specified.
80108
fn extract_repr(attrs: &[syn::Attribute])
81-
-> Result<Option<syn::Ident>, syn::Error>
109+
-> Result<Option<Ident>, syn::Error>
82110
{
83111
use syn::{Meta, NestedMeta};
84112
attrs.iter()
@@ -104,80 +132,97 @@ fn extract_repr(attrs: &[syn::Attribute])
104132
.transpose()
105133
}
106134

135+
/// Check the repr and return the number of bits available
136+
fn type_bits(ty: &Ident) -> Result<u8, syn::Error> {
137+
// This would be so much easier if we could just match on an Ident...
138+
if ty == "usize" {
139+
Err(syn::Error::new_spanned(ty,
140+
"#[repr(usize)] is not supported. Use u32 or u64 instead."))
141+
}
142+
else if ty == "i8" || ty == "i16" || ty == "i32"
143+
|| ty == "i64" || ty == "i128" || ty == "isize" {
144+
Err(syn::Error::new_spanned(ty,
145+
"Signed types in a repr are not supported."))
146+
}
147+
else if ty == "u8" { Ok(8) }
148+
else if ty == "u16" { Ok(16) }
149+
else if ty == "u32" { Ok(32) }
150+
else if ty == "u64" { Ok(64) }
151+
else if ty == "u128" { Ok(128) }
152+
else {
153+
Err(syn::Error::new_spanned(ty,
154+
"repr must be an integer type for #[bitflags]."))
155+
}
156+
}
157+
107158
/// Returns deferred checks
108-
fn verify_flag_values<'a>(
159+
fn check_flag(
109160
type_name: &Ident,
110-
variants: impl Iterator<Item=&'a syn::Variant>
111-
) -> Result<TokenStream, syn::Error> {
112-
let mut deferred_checks: Vec<TokenStream> = vec![];
113-
for variant in variants {
114-
// I'd use matches! if not for MSRV...
115-
match variant.fields {
116-
syn::Fields::Unit => (),
117-
_ => return Err(syn::Error::new_spanned(&variant.fields,
118-
"Bitflag variants cannot contain additional data")),
161+
flag: &Flag,
162+
) -> Result<Option<TokenStream>, syn::Error> {
163+
use FlagValue::*;
164+
match flag.value {
165+
Literal(n) => {
166+
if !n.is_power_of_two() {
167+
Err(syn::Error::new(flag.span,
168+
"Flags must have exactly one set bit"))
169+
} else {
170+
Ok(None)
171+
}
119172
}
173+
Inferred => {
174+
Err(syn::Error::new(flag.span,
175+
"Please add an explicit discriminant"))
176+
}
177+
Deferred => {
178+
let variant_name = &flag.name;
179+
// MSRV: Use an unnamed constant (`const _: ...`).
180+
let assertion_name = syn::Ident::new(
181+
&format!("__enumflags_assertion_{}_{}",
182+
type_name, flag.name),
183+
Span::call_site()); // call_site because def_site is unstable
120184

121-
let discr = variant.discriminant.as_ref()
122-
.ok_or_else(|| syn::Error::new_spanned(variant,
123-
"Please add an explicit discriminant"))?;
124-
match fold_expr(&discr.1)? {
125-
Some(flag) => {
126-
if !flag.is_power_of_two() {
127-
return Err(syn::Error::new_spanned(&discr.1,
128-
"Flags must have exactly one set bit"));
129-
}
130-
}
131-
None => {
132-
let variant_name = &variant.ident;
133-
// TODO: Remove this madness when Debian ships a new compiler.
134-
let assertion_name = syn::Ident::new(
135-
&format!("__enumflags_assertion_{}_{}",
136-
type_name, variant_name),
137-
Span::call_site()); // call_site because def_site is unstable
138-
139-
deferred_checks.push(quote_spanned!(variant.span() =>
140-
#[doc(hidden)]
141-
const #assertion_name:
142-
<<[(); (
143-
(#type_name::#variant_name as u64).wrapping_sub(1) &
144-
(#type_name::#variant_name as u64) == 0 &&
145-
(#type_name::#variant_name as u64) != 0
146-
) as usize] as enumflags2::_internal::AssertionHelper>
147-
::Status as enumflags2::_internal::ExactlyOneBitSet>::X = ();
148-
));
149-
}
185+
Ok(Some(quote_spanned!(flag.span =>
186+
#[doc(hidden)]
187+
const #assertion_name:
188+
<<[(); (
189+
(#type_name::#variant_name as u128).wrapping_sub(1) &
190+
(#type_name::#variant_name as u128) == 0 &&
191+
(#type_name::#variant_name as u128) != 0
192+
) as usize] as enumflags2::_internal::AssertionHelper>
193+
::Status as enumflags2::_internal::ExactlyOneBitSet>::X
194+
= ();
195+
)))
150196
}
151197
}
152-
153-
Ok(quote!(
154-
#(#deferred_checks)*
155-
))
156198
}
157199

158200
fn gen_enumflags(ident: &Ident, item: &DeriveInput, data: &DataEnum)
159201
-> Result<TokenStream, syn::Error>
160202
{
161203
let span = Span::call_site();
162204
// for quote! interpolation
163-
let variants = data.variants.iter().map(|v| &v.ident);
164-
let variants_len = data.variants.len();
165-
let names = std::iter::repeat(&ident);
166-
let ty = extract_repr(&item.attrs)?
167-
.unwrap_or_else(|| Ident::new("usize", span));
205+
let variant_names =
206+
data.variants.iter()
207+
.map(|v| &v.ident)
208+
.collect::<Vec<_>>();
209+
let repeated_name = vec![&ident; data.variants.len()];
168210

169-
let deferred = verify_flag_values(ident, data.variants.iter())?;
211+
let variants = collect_flags(data.variants.iter())?;
212+
let deferred = variants.iter()
213+
.flat_map(|variant| check_flag(ident, variant).transpose())
214+
.collect::<Result<Vec<_>, _>>()?;
215+
216+
let ty = extract_repr(&item.attrs)?
217+
.ok_or_else(|| syn::Error::new_spanned(&ident,
218+
"repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield."))?;
219+
type_bits(&ty)?;
170220
let std_path = quote_spanned!(span => ::enumflags2::_internal::core);
171-
let all = if variants_len == 0 {
172-
quote!(0)
173-
} else {
174-
let names = names.clone();
175-
let variants = variants.clone();
176-
quote!(#(#names::#variants as #ty)|*)
177-
};
178221

179222
Ok(quote_spanned! {
180-
span => #deferred
223+
span =>
224+
#item
225+
#(#deferred)*
181226
impl #std_path::ops::Not for #ident {
182227
type Output = ::enumflags2::BitFlags<#ident>;
183228
fn not(self) -> Self::Output {
@@ -212,26 +257,22 @@ fn gen_enumflags(ident: &Ident, item: &DeriveInput, data: &DataEnum)
212257
impl ::enumflags2::_internal::RawBitFlags for #ident {
213258
type Type = #ty;
214259

215-
fn all_bits() -> Self::Type {
216-
// make sure it's evaluated at compile time
217-
const VALUE: #ty = #all;
218-
VALUE
219-
}
260+
const EMPTY: Self::Type = 0;
220261

221-
fn bits(self) -> Self::Type {
222-
self as #ty
223-
}
262+
const ALL_BITS: Self::Type =
263+
0 #(| (#repeated_name::#variant_names as #ty))*;
224264

225-
fn flag_list() -> &'static [Self] {
226-
const VARIANTS: [#ident; #variants_len] = [#(#names :: #variants),*];
227-
&VARIANTS
228-
}
265+
const FLAG_LIST: &'static [Self] =
266+
&[#(#repeated_name::#variant_names),*];
267+
268+
const BITFLAGS_TYPE_NAME : &'static str =
269+
concat!("BitFlags<", stringify!(#ident), ">");
229270

230-
fn bitflags_type_name() -> &'static str {
231-
concat!("BitFlags<", stringify!(#ident), ">")
271+
fn bits(self) -> Self::Type {
272+
self as #ty
232273
}
233274
}
234275

235-
impl ::enumflags2::RawBitFlags for #ident {}
276+
impl ::enumflags2::BitFlag for #ident {}
236277
})
237278
}

0 commit comments

Comments
 (0)