@@ -3,82 +3,110 @@ extern crate proc_macro;
33#[ macro_use]
44extern crate quote;
55
6+ use std:: convert:: TryFrom ;
67use syn:: { Data , Ident , DeriveInput , DataEnum , spanned:: Spanned } ;
7- use proc_macro2:: TokenStream ;
8- use proc_macro2:: Span ;
8+ use proc_macro2:: { TokenStream , Span } ;
99
10- /// Shorthand for a quoted `compile_error!`.
11- macro_rules! error {
12- ( $span: expr => $( $x: tt) * ) => {
13- quote_spanned!( $span => compile_error!( $( $x) * ) ; )
14- } ;
15- ( $( $x: tt) * ) => {
16- quote!( compile_error!( $( $x) * ) ; )
17- } ;
10+ struct Flag {
11+ name : Ident ,
12+ span : Span ,
13+ value : FlagValue ,
1814}
1915
20- #[ proc_macro_derive( BitFlags_internal ) ]
21- pub fn derive_enum_flags ( input : proc_macro:: TokenStream )
22- -> proc_macro:: TokenStream
23- {
16+ enum FlagValue {
17+ Literal ( u128 ) ,
18+ Deferred ,
19+ Inferred ,
20+ }
21+
22+ #[ proc_macro_attribute]
23+ pub fn bitflags_internal (
24+ _attr : proc_macro:: TokenStream ,
25+ input : proc_macro:: TokenStream ,
26+ ) -> proc_macro:: TokenStream {
2427 let ast: DeriveInput = syn:: parse ( input) . unwrap ( ) ;
2528
26- match ast. data {
29+ let output = match ast. data {
2730 Data :: Enum ( ref data) => {
2831 gen_enumflags ( & ast. ident , & ast, data)
29- . unwrap_or_else ( |err| err. to_compile_error ( ) )
30- . into ( )
3132 }
3233 Data :: Struct ( ref data) => {
33- error ! ( data. struct_token. span => "BitFlags can only be derived on enums" ) . into ( )
34+ Err ( syn :: Error :: new_spanned ( data. struct_token , "#[bitflags] requires an enum" ) )
3435 }
3536 Data :: Union ( ref data) => {
36- error ! ( data. union_token. span => "BitFlags can only be derived on enums" ) . into ( )
37+ Err ( syn :: Error :: new_spanned ( data. union_token , "#[bitflags] requires an enum" ) )
3738 }
38- }
39- }
39+ } ;
4040
41- /// Try to evaluate the expression given.
42- fn fold_expr ( expr : & syn:: Expr ) -> Result < Option < u64 > , syn:: Error > {
43- /// Recurse, but bubble-up both kinds of errors.
44- /// (I miss my monad transformers)
45- macro_rules! fold_expr {
46- ( $( $x: tt) * ) => {
47- match fold_expr( $( $x) * ) ? {
48- Some ( x) => x,
49- None => return Ok ( None ) ,
50- }
41+ output. unwrap_or_else ( |err| {
42+ let error = err. to_compile_error ( ) ;
43+ quote ! {
44+ #ast
45+ #error
5146 }
52- }
47+ } ) . into ( )
48+ }
5349
50+ /// Try to evaluate the expression given.
51+ fn fold_expr ( expr : & syn:: Expr ) -> Option < u128 > {
5452 use syn:: Expr ;
5553 match expr {
5654 Expr :: Lit ( ref expr_lit) => {
5755 match expr_lit. lit {
58- syn:: Lit :: Int ( ref lit_int) => {
59- Ok ( Some ( lit_int. base10_parse ( )
60- . map_err ( |_| syn:: Error :: new_spanned ( lit_int,
61- "Integer literal out of range" ) ) ?) )
62- }
63- _ => Ok ( None ) ,
56+ syn:: Lit :: Int ( ref lit_int) => lit_int. base10_parse ( ) . ok ( ) ,
57+ _ => None ,
6458 }
6559 } ,
6660 Expr :: Binary ( ref expr_binary) => {
67- let l = fold_expr ! ( & expr_binary. left) ;
68- let r = fold_expr ! ( & expr_binary. right) ;
61+ let l = fold_expr ( & expr_binary. left ) ? ;
62+ let r = fold_expr ( & expr_binary. right ) ? ;
6963 match & expr_binary. op {
70- syn:: BinOp :: Shl ( _) => Ok ( Some ( l << r) ) ,
71- _ => Ok ( None ) ,
64+ syn:: BinOp :: Shl ( _) => {
65+ u32:: try_from ( r) . ok ( )
66+ . and_then ( |r| l. checked_shl ( r) )
67+ }
68+ _ => None ,
7269 }
7370 }
74- _ => Ok ( None ) ,
71+ _ => None ,
7572 }
7673}
7774
75+ fn collect_flags < ' a > ( variants : impl Iterator < Item =& ' a syn:: Variant > )
76+ -> Result < Vec < Flag > , syn:: Error >
77+ {
78+ variants
79+ . map ( |variant| {
80+ // MSRV: Would this be cleaner with `matches!`?
81+ match variant. fields {
82+ syn:: Fields :: Unit => ( ) ,
83+ _ => return Err ( syn:: Error :: new_spanned ( & variant. fields ,
84+ "Bitflag variants cannot contain additional data" ) ) ,
85+ }
86+
87+ let value = if let Some ( ref expr) = variant. discriminant {
88+ if let Some ( n) = fold_expr ( & expr. 1 ) {
89+ FlagValue :: Literal ( n)
90+ } else {
91+ FlagValue :: Deferred
92+ }
93+ } else {
94+ FlagValue :: Inferred
95+ } ;
96+
97+ Ok ( Flag {
98+ name : variant. ident . clone ( ) ,
99+ span : variant. span ( ) ,
100+ value,
101+ } )
102+ } )
103+ . collect ( )
104+ }
105+
78106/// Given a list of attributes, find the `repr`, if any, and return the integer
79107/// type specified.
80108fn extract_repr ( attrs : & [ syn:: Attribute ] )
81- -> Result < Option < syn :: Ident > , syn:: Error >
109+ -> Result < Option < Ident > , syn:: Error >
82110{
83111 use syn:: { Meta , NestedMeta } ;
84112 attrs. iter ( )
@@ -104,80 +132,97 @@ fn extract_repr(attrs: &[syn::Attribute])
104132 . transpose ( )
105133}
106134
135+ /// Check the repr and return the number of bits available
136+ fn type_bits ( ty : & Ident ) -> Result < u8 , syn:: Error > {
137+ // This would be so much easier if we could just match on an Ident...
138+ if ty == "usize" {
139+ Err ( syn:: Error :: new_spanned ( ty,
140+ "#[repr(usize)] is not supported. Use u32 or u64 instead." ) )
141+ }
142+ else if ty == "i8" || ty == "i16" || ty == "i32"
143+ || ty == "i64" || ty == "i128" || ty == "isize" {
144+ Err ( syn:: Error :: new_spanned ( ty,
145+ "Signed types in a repr are not supported." ) )
146+ }
147+ else if ty == "u8" { Ok ( 8 ) }
148+ else if ty == "u16" { Ok ( 16 ) }
149+ else if ty == "u32" { Ok ( 32 ) }
150+ else if ty == "u64" { Ok ( 64 ) }
151+ else if ty == "u128" { Ok ( 128 ) }
152+ else {
153+ Err ( syn:: Error :: new_spanned ( ty,
154+ "repr must be an integer type for #[bitflags]." ) )
155+ }
156+ }
157+
107158/// Returns deferred checks
108- fn verify_flag_values < ' a > (
159+ fn check_flag (
109160 type_name : & Ident ,
110- variants : impl Iterator < Item =& ' a syn:: Variant >
111- ) -> Result < TokenStream , syn:: Error > {
112- let mut deferred_checks: Vec < TokenStream > = vec ! [ ] ;
113- for variant in variants {
114- // I'd use matches! if not for MSRV...
115- match variant. fields {
116- syn:: Fields :: Unit => ( ) ,
117- _ => return Err ( syn:: Error :: new_spanned ( & variant. fields ,
118- "Bitflag variants cannot contain additional data" ) ) ,
161+ flag : & Flag ,
162+ ) -> Result < Option < TokenStream > , syn:: Error > {
163+ use FlagValue :: * ;
164+ match flag. value {
165+ Literal ( n) => {
166+ if !n. is_power_of_two ( ) {
167+ Err ( syn:: Error :: new ( flag. span ,
168+ "Flags must have exactly one set bit" ) )
169+ } else {
170+ Ok ( None )
171+ }
119172 }
173+ Inferred => {
174+ Err ( syn:: Error :: new ( flag. span ,
175+ "Please add an explicit discriminant" ) )
176+ }
177+ Deferred => {
178+ let variant_name = & flag. name ;
179+ // MSRV: Use an unnamed constant (`const _: ...`).
180+ let assertion_name = syn:: Ident :: new (
181+ & format ! ( "__enumflags_assertion_{}_{}" ,
182+ type_name, flag. name) ,
183+ Span :: call_site ( ) ) ; // call_site because def_site is unstable
120184
121- let discr = variant. discriminant . as_ref ( )
122- . ok_or_else ( || syn:: Error :: new_spanned ( variant,
123- "Please add an explicit discriminant" ) ) ?;
124- match fold_expr ( & discr. 1 ) ? {
125- Some ( flag) => {
126- if !flag. is_power_of_two ( ) {
127- return Err ( syn:: Error :: new_spanned ( & discr. 1 ,
128- "Flags must have exactly one set bit" ) ) ;
129- }
130- }
131- None => {
132- let variant_name = & variant. ident ;
133- // TODO: Remove this madness when Debian ships a new compiler.
134- let assertion_name = syn:: Ident :: new (
135- & format ! ( "__enumflags_assertion_{}_{}" ,
136- type_name, variant_name) ,
137- Span :: call_site ( ) ) ; // call_site because def_site is unstable
138-
139- deferred_checks. push ( quote_spanned ! ( variant. span( ) =>
140- #[ doc( hidden) ]
141- const #assertion_name:
142- <<[ ( ) ; (
143- ( #type_name:: #variant_name as u64 ) . wrapping_sub( 1 ) &
144- ( #type_name:: #variant_name as u64 ) == 0 &&
145- ( #type_name:: #variant_name as u64 ) != 0
146- ) as usize ] as enumflags2:: _internal:: AssertionHelper >
147- :: Status as enumflags2:: _internal:: ExactlyOneBitSet >:: X = ( ) ;
148- ) ) ;
149- }
185+ Ok ( Some ( quote_spanned ! ( flag. span =>
186+ #[ doc( hidden) ]
187+ const #assertion_name:
188+ <<[ ( ) ; (
189+ ( #type_name:: #variant_name as u128 ) . wrapping_sub( 1 ) &
190+ ( #type_name:: #variant_name as u128 ) == 0 &&
191+ ( #type_name:: #variant_name as u128 ) != 0
192+ ) as usize ] as enumflags2:: _internal:: AssertionHelper >
193+ :: Status as enumflags2:: _internal:: ExactlyOneBitSet >:: X
194+ = ( ) ;
195+ ) ) )
150196 }
151197 }
152-
153- Ok ( quote ! (
154- #( #deferred_checks) *
155- ) )
156198}
157199
158200fn gen_enumflags ( ident : & Ident , item : & DeriveInput , data : & DataEnum )
159201 -> Result < TokenStream , syn:: Error >
160202{
161203 let span = Span :: call_site ( ) ;
162204 // for quote! interpolation
163- let variants = data . variants . iter ( ) . map ( |v| & v . ident ) ;
164- let variants_len = data. variants . len ( ) ;
165- let names = std :: iter :: repeat ( & ident) ;
166- let ty = extract_repr ( & item . attrs ) ?
167- . unwrap_or_else ( || Ident :: new ( "usize" , span ) ) ;
205+ let variant_names =
206+ data. variants . iter ( )
207+ . map ( |v| & v . ident )
208+ . collect :: < Vec < _ > > ( ) ;
209+ let repeated_name = vec ! [ & ident ; data . variants . len ( ) ] ;
168210
169- let deferred = verify_flag_values ( ident, data. variants . iter ( ) ) ?;
211+ let variants = collect_flags ( data. variants . iter ( ) ) ?;
212+ let deferred = variants. iter ( )
213+ . flat_map ( |variant| check_flag ( ident, variant) . transpose ( ) )
214+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
215+
216+ let ty = extract_repr ( & item. attrs ) ?
217+ . ok_or_else ( || syn:: Error :: new_spanned ( & ident,
218+ "repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield." ) ) ?;
219+ type_bits ( & ty) ?;
170220 let std_path = quote_spanned ! ( span => :: enumflags2:: _internal:: core) ;
171- let all = if variants_len == 0 {
172- quote ! ( 0 )
173- } else {
174- let names = names. clone ( ) ;
175- let variants = variants. clone ( ) ;
176- quote ! ( #( #names:: #variants as #ty) |* )
177- } ;
178221
179222 Ok ( quote_spanned ! {
180- span => #deferred
223+ span =>
224+ #item
225+ #( #deferred) *
181226 impl #std_path:: ops:: Not for #ident {
182227 type Output = :: enumflags2:: BitFlags <#ident>;
183228 fn not( self ) -> Self :: Output {
@@ -212,26 +257,22 @@ fn gen_enumflags(ident: &Ident, item: &DeriveInput, data: &DataEnum)
212257 impl :: enumflags2:: _internal:: RawBitFlags for #ident {
213258 type Type = #ty;
214259
215- fn all_bits( ) -> Self :: Type {
216- // make sure it's evaluated at compile time
217- const VALUE : #ty = #all;
218- VALUE
219- }
260+ const EMPTY : Self :: Type = 0 ;
220261
221- fn bits( self ) -> Self :: Type {
222- self as #ty
223- }
262+ const ALL_BITS : Self :: Type =
263+ 0 #( | ( #repeated_name:: #variant_names as #ty) ) * ;
224264
225- fn flag_list( ) -> & ' static [ Self ] {
226- const VARIANTS : [ #ident; #variants_len] = [ #( #names :: #variants) , * ] ;
227- & VARIANTS
228- }
265+ const FLAG_LIST : & ' static [ Self ] =
266+ & [ #( #repeated_name:: #variant_names) , * ] ;
267+
268+ const BITFLAGS_TYPE_NAME : & ' static str =
269+ concat!( "BitFlags<" , stringify!( #ident) , ">" ) ;
229270
230- fn bitflags_type_name ( ) -> & ' static str {
231- concat! ( "BitFlags<" , stringify! ( #ident ) , ">" )
271+ fn bits ( self ) -> Self :: Type {
272+ self as #ty
232273 }
233274 }
234275
235- impl :: enumflags2:: RawBitFlags for #ident { }
276+ impl :: enumflags2:: BitFlag for #ident { }
236277 } )
237278}
0 commit comments