@@ -4,8 +4,19 @@ extern crate proc_macro;
44extern crate quote;
55
66use syn:: { Data , Ident , DeriveInput , DataEnum , spanned:: Spanned } ;
7- use proc_macro2:: TokenStream ;
8- use proc_macro2:: Span ;
7+ use proc_macro2:: { TokenStream , Span } ;
8+
9+ struct Flag {
10+ name : Ident ,
11+ span : Span ,
12+ value : FlagValue ,
13+ }
14+
15+ enum FlagValue {
16+ Literal ( u64 ) ,
17+ Deferred ,
18+ Inferred ,
19+ }
920
1021#[ proc_macro_attribute]
1122pub fn bitflags_internal (
@@ -35,6 +46,9 @@ pub fn bitflags_internal(
3546}
3647
3748/// Try to evaluate the expression given.
49+ ///
50+ /// Returns `Err` when the expression is erroneous, and `Ok(None)` when
51+ /// information outside of the syntax, such as a `const`.
3852fn fold_expr ( expr : & syn:: Expr ) -> Result < Option < u64 > , syn:: Error > {
3953 /// Recurse, but bubble-up both kinds of errors.
4054 /// (I miss my monad transformers)
@@ -71,6 +85,37 @@ fn fold_expr(expr: &syn::Expr) -> Result<Option<u64>, syn::Error> {
7185 }
7286}
7387
88+ fn collect_flags < ' a > ( variants : impl Iterator < Item =& ' a syn:: Variant > )
89+ -> Result < Vec < Flag > , syn:: Error >
90+ {
91+ variants
92+ . map ( |variant| {
93+ // MSRV: Would this be cleaner with `matches!`?
94+ match variant. fields {
95+ syn:: Fields :: Unit => ( ) ,
96+ _ => return Err ( syn:: Error :: new_spanned ( & variant. fields ,
97+ "Bitflag variants cannot contain additional data" ) ) ,
98+ }
99+
100+ let value = if let Some ( ref expr) = variant. discriminant {
101+ if let Some ( n) = fold_expr ( & expr. 1 ) ? {
102+ FlagValue :: Literal ( n)
103+ } else {
104+ FlagValue :: Deferred
105+ }
106+ } else {
107+ FlagValue :: Inferred
108+ } ;
109+
110+ Ok ( Flag {
111+ name : variant. ident . clone ( ) ,
112+ span : variant. span ( ) ,
113+ value,
114+ } )
115+ } )
116+ . collect ( )
117+ }
118+
74119/// Given a list of attributes, find the `repr`, if any, and return the integer
75120/// type specified.
76121fn extract_repr ( attrs : & [ syn:: Attribute ] )
@@ -101,81 +146,76 @@ fn extract_repr(attrs: &[syn::Attribute])
101146}
102147
103148/// Returns deferred checks
104- fn verify_flag_values < ' a > (
149+ fn check_flag (
105150 type_name : & Ident ,
106- variants : impl Iterator < Item =& ' a syn:: Variant >
107- ) -> Result < TokenStream , syn:: Error > {
108- let mut deferred_checks: Vec < TokenStream > = vec ! [ ] ;
109- for variant in variants {
110- // I'd use matches! if not for MSRV...
111- match variant. fields {
112- syn:: Fields :: Unit => ( ) ,
113- _ => return Err ( syn:: Error :: new_spanned ( & variant. fields ,
114- "Bitflag variants cannot contain additional data" ) ) ,
151+ flag : & Flag ,
152+ ) -> Result < Option < TokenStream > , syn:: Error > {
153+ use FlagValue :: * ;
154+ match flag. value {
155+ Literal ( n) => {
156+ if !n. is_power_of_two ( ) {
157+ Err ( syn:: Error :: new ( flag. span ,
158+ "Flags must have exactly one set bit" ) )
159+ } else {
160+ Ok ( None )
161+ }
115162 }
163+ Inferred => {
164+ Err ( syn:: Error :: new ( flag. span ,
165+ "Please add an explicit discriminant" ) )
166+ }
167+ Deferred => {
168+ let variant_name = & flag. name ;
169+ // MSRV: Use an unnamed constant (`const _: ...`).
170+ let assertion_name = syn:: Ident :: new (
171+ & format ! ( "__enumflags_assertion_{}_{}" ,
172+ type_name, flag. name) ,
173+ Span :: call_site ( ) ) ; // call_site because def_site is unstable
116174
117- let discr = variant. discriminant . as_ref ( )
118- . ok_or_else ( || syn:: Error :: new_spanned ( variant,
119- "Please add an explicit discriminant" ) ) ?;
120- match fold_expr ( & discr. 1 ) ? {
121- Some ( flag) => {
122- if !flag. is_power_of_two ( ) {
123- return Err ( syn:: Error :: new_spanned ( & discr. 1 ,
124- "Flags must have exactly one set bit" ) ) ;
125- }
126- }
127- None => {
128- let variant_name = & variant. ident ;
129- // TODO: Remove this madness when Debian ships a new compiler.
130- let assertion_name = syn:: Ident :: new (
131- & format ! ( "__enumflags_assertion_{}_{}" ,
132- type_name, variant_name) ,
133- Span :: call_site ( ) ) ; // call_site because def_site is unstable
134-
135- deferred_checks. push ( quote_spanned ! ( variant. span( ) =>
136- #[ doc( hidden) ]
137- const #assertion_name:
138- <<[ ( ) ; (
139- ( #type_name:: #variant_name as u64 ) . wrapping_sub( 1 ) &
140- ( #type_name:: #variant_name as u64 ) == 0 &&
141- ( #type_name:: #variant_name as u64 ) != 0
142- ) as usize ] as enumflags2:: _internal:: AssertionHelper >
143- :: Status as enumflags2:: _internal:: ExactlyOneBitSet >:: X = ( ) ;
144- ) ) ;
145- }
175+ Ok ( Some ( quote_spanned ! ( flag. span =>
176+ #[ doc( hidden) ]
177+ const #assertion_name:
178+ <<[ ( ) ; (
179+ ( #type_name:: #variant_name as u64 ) . wrapping_sub( 1 ) &
180+ ( #type_name:: #variant_name as u64 ) == 0 &&
181+ ( #type_name:: #variant_name as u64 ) != 0
182+ ) as usize ] as enumflags2:: _internal:: AssertionHelper >
183+ :: Status as enumflags2:: _internal:: ExactlyOneBitSet >:: X
184+ = ( ) ;
185+ ) ) )
146186 }
147187 }
148-
149- Ok ( quote ! (
150- #( #deferred_checks) *
151- ) )
152188}
153189
154190fn gen_enumflags ( ident : & Ident , item : & DeriveInput , data : & DataEnum )
155191 -> Result < TokenStream , syn:: Error >
156192{
157193 let span = Span :: call_site ( ) ;
158194 // for quote! interpolation
159- let variants = data. variants . iter ( ) . map ( |v| & v. ident ) ;
160- let variants_len = data. variants . len ( ) ;
161- let names = std:: iter:: repeat ( & ident) ;
195+ let variant_names = data. variants . iter ( ) . map ( |v| & v. ident ) ;
196+ let variant_count = data. variants . len ( ) ;
197+
198+ let repeated_name = std:: iter:: repeat ( & ident) ;
162199
163- let deferred = verify_flag_values ( ident, data. variants . iter ( ) ) ?;
200+ let variants = collect_flags ( data. variants . iter ( ) ) ?;
201+ let deferred = variants. iter ( )
202+ . flat_map ( |variant| check_flag ( ident, variant) . transpose ( ) )
203+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
164204
165205 let ty = extract_repr ( & item. attrs ) ?
166206 . ok_or_else ( || syn:: Error :: new_spanned ( & ident,
167207 "repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield." ) ) ?;
168208 let std_path = quote_spanned ! ( span => :: enumflags2:: _internal:: core) ;
169- let all = if variants_len == 0 {
209+ let all = if variant_count == 0 {
170210 quote ! ( 0 )
171211 } else {
172- let names = names . clone ( ) ;
173- let variants = variants . clone ( ) ;
174- quote ! ( #( #names :: #variants as #ty) |* )
212+ let repeated_name = repeated_name . clone ( ) ;
213+ let variant_names = variant_names . clone ( ) ;
214+ quote ! ( #( #repeated_name :: #variant_names as #ty) |* )
175215 } ;
176216
177217 Ok ( quote_spanned ! {
178- span => #deferred
218+ span => #( # deferred) *
179219 impl #std_path:: ops:: Not for #ident {
180220 type Output = :: enumflags2:: BitFlags <#ident>;
181221 fn not( self ) -> Self :: Output {
@@ -221,7 +261,7 @@ fn gen_enumflags(ident: &Ident, item: &DeriveInput, data: &DataEnum)
221261 }
222262
223263 fn flag_list( ) -> & ' static [ Self ] {
224- const VARIANTS : [ #ident; #variants_len ] = [ #( #names :: #variants ) , * ] ;
264+ const VARIANTS : [ #ident; #variant_count ] = [ #( #repeated_name :: #variant_names ) , * ] ;
225265 & VARIANTS
226266 }
227267
0 commit comments