Skip to content

Commit 52e6fb6

Browse files
committed
Don't panic when attribute is applied to weird stuff
1 parent 03a1d40 commit 52e6fb6

File tree

5 files changed

+62
-24
lines changed

5 files changed

+62
-24
lines changed

enumflags_derive/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@ edition = "2018"
1212
proc-macro = true
1313

1414
[dependencies]
15-
syn = "^1.0"
15+
syn = { version = "^1.0", features = ["full"] }
1616
quote = "^1.0"
1717
proc-macro2 = "^1.0"

enumflags_derive/src/lib.rs

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ extern crate proc_macro;
44
extern crate quote;
55

66
use std::convert::TryFrom;
7-
use syn::{Data, Ident, DeriveInput, DataEnum, spanned::Spanned};
7+
use syn::{Ident, Item, ItemEnum, spanned::Spanned, parse_macro_input};
88
use proc_macro2::{TokenStream, Span};
99

1010
struct Flag {
@@ -24,18 +24,11 @@ pub fn bitflags_internal(
2424
_attr: proc_macro::TokenStream,
2525
input: proc_macro::TokenStream,
2626
) -> proc_macro::TokenStream {
27-
let ast: DeriveInput = syn::parse(input).unwrap();
28-
29-
let output = match ast.data {
30-
Data::Enum(ref data) => {
31-
gen_enumflags(&ast.ident, &ast, data)
32-
}
33-
Data::Struct(ref data) => {
34-
Err(syn::Error::new_spanned(data.struct_token, "#[bitflags] requires an enum"))
35-
}
36-
Data::Union(ref data) => {
37-
Err(syn::Error::new_spanned(data.union_token, "#[bitflags] requires an enum"))
38-
}
27+
let ast = parse_macro_input!(input as Item);
28+
let output = match ast {
29+
Item::Enum(ref item_enum) => gen_enumflags(item_enum),
30+
_ => Err(syn::Error::new_spanned(&ast,
31+
"#[bitflags] requires an enum")),
3932
};
4033

4134
output.unwrap_or_else(|err| {
@@ -197,31 +190,39 @@ fn check_flag(
197190
}
198191
}
199192

200-
fn gen_enumflags(ident: &Ident, item: &DeriveInput, data: &DataEnum)
193+
fn gen_enumflags(ast: &ItemEnum)
201194
-> Result<TokenStream, syn::Error>
202195
{
196+
let ident = &ast.ident;
197+
203198
let span = Span::call_site();
204199
// for quote! interpolation
205200
let variant_names =
206-
data.variants.iter()
201+
ast.variants.iter()
207202
.map(|v| &v.ident)
208203
.collect::<Vec<_>>();
209-
let repeated_name = vec![&ident; data.variants.len()];
204+
let repeated_name = vec![&ident; ast.variants.len()];
210205

211-
let variants = collect_flags(data.variants.iter())?;
206+
let variants = collect_flags(ast.variants.iter())?;
212207
let deferred = variants.iter()
213208
.flat_map(|variant| check_flag(ident, variant).transpose())
214209
.collect::<Result<Vec<_>, _>>()?;
215210

216-
let ty = extract_repr(&item.attrs)?
211+
let ty = extract_repr(&ast.attrs)?
217212
.ok_or_else(|| syn::Error::new_spanned(&ident,
218213
"repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield."))?;
219-
type_bits(&ty)?;
214+
let bits = type_bits(&ty)?;
215+
216+
if (bits as usize) < variants.len() {
217+
return Err(syn::Error::new_spanned(&ty,
218+
format!("Not enough bits for {} flags", variants.len())));
219+
}
220+
220221
let std_path = quote_spanned!(span => ::enumflags2::_internal::core);
221222

222223
Ok(quote_spanned! {
223224
span =>
224-
#item
225+
#ast
225226
#(#deferred)*
226227
impl #std_path::ops::Not for #ident {
227228
type Output = ::enumflags2::BitFlags<#ident>;

test_suite/ui/not_enum.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,8 @@
22
#[derive(Copy, Clone)]
33
struct Foo(u16);
44

5+
#[enumflags2::bitflags]
6+
#[derive(Copy, Clone)]
7+
const WTF: u8 = 42;
8+
59
fn main() {}

test_suite/ui/not_enum.stderr

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
error: #[bitflags] requires an enum
2-
--> $DIR/not_enum.rs:3:1
2+
--> $DIR/not_enum.rs:2:1
33
|
4-
3 | struct Foo(u16);
5-
| ^^^^^^
4+
2 | / #[derive(Copy, Clone)]
5+
3 | | struct Foo(u16);
6+
| |________________^
7+
8+
error: `derive` may only be applied to structs, enums and unions
9+
--> $DIR/not_enum.rs:6:1
10+
|
11+
6 | #[derive(Copy, Clone)]
12+
| ^^^^^^^^^^^^^^^^^^^^^^
13+
14+
error: #[bitflags] requires an enum
15+
--> $DIR/not_enum.rs:6:1
16+
|
17+
6 | / #[derive(Copy, Clone)]
18+
7 | | const WTF: u8 = 42;
19+
| |___________________^
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
error: #[bitflags] requires an enum
2+
--> $DIR/not_enum.rs:2:1
3+
|
4+
2 | / #[derive(Copy, Clone)]
5+
3 | | struct Foo(u16);
6+
| |________________^
7+
8+
error[E0774]: `derive` may only be applied to structs, enums and unions
9+
--> $DIR/not_enum.rs:6:1
10+
|
11+
6 | #[derive(Copy, Clone)]
12+
| ^^^^^^^^^^^^^^^^^^^^^^
13+
14+
error: #[bitflags] requires an enum
15+
--> $DIR/not_enum.rs:6:1
16+
|
17+
6 | / #[derive(Copy, Clone)]
18+
7 | | const WTF: u8 = 42;
19+
| |___________________^

0 commit comments

Comments
 (0)