@@ -9,21 +9,28 @@ use syn::{
99 parse:: { Parse , ParseStream } ,
1010 parse_macro_input,
1111 spanned:: Spanned ,
12- Ident , Item , ItemEnum , Token ,
12+ Expr , Ident , Item , ItemEnum , Token , Variant ,
1313} ;
1414
15- #[ derive( Debug ) ]
16- struct Flag {
15+ struct Flag < ' a > {
1716 name : Ident ,
1817 span : Span ,
19- value : FlagValue ,
18+ value : FlagValue < ' a > ,
2019}
2120
22- #[ derive( Debug ) ]
23- enum FlagValue {
21+ enum FlagValue < ' a > {
2422 Literal ( u128 ) ,
2523 Deferred ,
26- Inferred ,
24+ Inferred ( & ' a mut Variant ) ,
25+ }
26+
27+ impl FlagValue < ' _ > {
28+ fn is_inferred ( & self ) -> bool {
29+ match self {
30+ FlagValue :: Inferred ( _) => true ,
31+ _ => false ,
32+ }
33+ }
2734}
2835
2936struct Parameters {
@@ -54,9 +61,9 @@ pub fn bitflags_internal(
5461 input : proc_macro:: TokenStream ,
5562) -> proc_macro:: TokenStream {
5663 let Parameters { default } = parse_macro_input ! ( attr as Parameters ) ;
57- let ast = parse_macro_input ! ( input as Item ) ;
64+ let mut ast = parse_macro_input ! ( input as Item ) ;
5865 let output = match ast {
59- Item :: Enum ( ref item_enum) => gen_enumflags ( item_enum, default) ,
66+ Item :: Enum ( ref mut item_enum) => gen_enumflags ( item_enum, default) ,
6067 _ => Err ( syn:: Error :: new_spanned (
6168 & ast,
6269 "#[bitflags] requires an enum" ,
@@ -76,7 +83,6 @@ pub fn bitflags_internal(
7683
7784/// Try to evaluate the expression given.
7885fn fold_expr ( expr : & syn:: Expr ) -> Option < u128 > {
79- use syn:: Expr ;
8086 match expr {
8187 Expr :: Lit ( ref expr_lit) => match expr_lit. lit {
8288 syn:: Lit :: Int ( ref lit_int) => lit_int. base10_parse ( ) . ok ( ) ,
@@ -98,8 +104,8 @@ fn fold_expr(expr: &syn::Expr) -> Option<u128> {
98104}
99105
100106fn collect_flags < ' a > (
101- variants : impl Iterator < Item = & ' a syn :: Variant > ,
102- ) -> Result < Vec < Flag > , syn:: Error > {
107+ variants : impl Iterator < Item = & ' a mut Variant > ,
108+ ) -> Result < Vec < Flag < ' a > > , syn:: Error > {
103109 variants
104110 . map ( |variant| {
105111 // MSRV: Would this be cleaner with `matches!`?
@@ -113,25 +119,51 @@ fn collect_flags<'a>(
113119 }
114120 }
115121
122+ let name = variant. ident . clone ( ) ;
123+ let span = variant. span ( ) ;
116124 let value = if let Some ( ref expr) = variant. discriminant {
117125 if let Some ( n) = fold_expr ( & expr. 1 ) {
118126 FlagValue :: Literal ( n)
119127 } else {
120128 FlagValue :: Deferred
121129 }
122130 } else {
123- FlagValue :: Inferred
131+ FlagValue :: Inferred ( variant )
124132 } ;
125133
126- Ok ( Flag {
127- name : variant. ident . clone ( ) ,
128- span : variant. span ( ) ,
129- value,
130- } )
134+ Ok ( Flag { name, span, value } )
131135 } )
132136 . collect ( )
133137}
134138
139+ fn inferred_value ( type_name : & Ident , previous_variants : & [ Ident ] , repr : & Ident ) -> Expr {
140+ let tokens = if previous_variants. is_empty ( ) {
141+ quote ! ( 1 )
142+ } else {
143+ quote ! ( :: enumflags2:: _internal:: next_bit(
144+ #( #type_name:: #previous_variants as u128 ) |*
145+ ) as #repr)
146+ } ;
147+
148+ syn:: parse2 ( tokens) . expect ( "couldn't parse inferred value" )
149+ }
150+
151+ fn infer_values < ' a > ( flags : & mut [ Flag ] , type_name : & Ident , repr : & Ident ) {
152+ let mut previous_variants: Vec < Ident > = flags. iter ( )
153+ . filter ( |flag| !flag. value . is_inferred ( ) )
154+ . map ( |flag| flag. name . clone ( ) ) . collect ( ) ;
155+
156+ for flag in flags {
157+ match flag. value {
158+ FlagValue :: Inferred ( ref mut variant) => {
159+ variant. discriminant = Some ( ( <Token ! [ =] >:: default ( ) , inferred_value ( type_name, & previous_variants, repr) ) ) ;
160+ previous_variants. push ( flag. name . clone ( ) ) ;
161+ }
162+ _ => { }
163+ }
164+ }
165+ }
166+
135167/// Given a list of attributes, find the `repr`, if any, and return the integer
136168/// type specified.
137169fn extract_repr ( attrs : & [ syn:: Attribute ] ) -> Result < Option < Ident > , syn:: Error > {
@@ -210,10 +242,7 @@ fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenSt
210242 Ok ( None )
211243 }
212244 }
213- Inferred => Err ( syn:: Error :: new (
214- flag. span ,
215- "Please add an explicit discriminant" ,
216- ) ) ,
245+ Inferred ( _) => Ok ( None ) ,
217246 Deferred => {
218247 let variant_name = & flag. name ;
219248 // MSRV: Use an unnamed constant (`const _: ...`).
@@ -235,33 +264,34 @@ fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenSt
235264 }
236265}
237266
238- fn gen_enumflags ( ast : & ItemEnum , default : Vec < Ident > ) -> Result < TokenStream , syn:: Error > {
267+ fn gen_enumflags ( ast : & mut ItemEnum , default : Vec < Ident > ) -> Result < TokenStream , syn:: Error > {
239268 let ident = & ast. ident ;
240269
241270 let span = Span :: call_site ( ) ;
242- // for quote! interpolation
243- let variant_names = ast. variants . iter ( ) . map ( |v| & v. ident ) . collect :: < Vec < _ > > ( ) ;
244- let repeated_name = vec ! [ & ident; ast. variants. len( ) ] ;
245271
246- let ty = extract_repr ( & ast. attrs ) ?
272+ let repr = extract_repr ( & ast. attrs ) ?
247273 . ok_or_else ( || syn:: Error :: new_spanned ( & ident,
248274 "repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield." ) ) ?;
249- let bits = type_bits ( & ty ) ?;
275+ let bits = type_bits ( & repr ) ?;
250276
251- let variants = collect_flags ( ast. variants . iter ( ) ) ?;
277+ let mut variants = collect_flags ( ast. variants . iter_mut ( ) ) ?;
252278 let deferred = variants
253279 . iter ( )
254280 . flat_map ( |variant| check_flag ( ident, variant, bits) . transpose ( ) )
255281 . collect :: < Result < Vec < _ > , _ > > ( ) ?;
256282
283+ infer_values ( & mut variants, ident, & repr) ;
284+
257285 if ( bits as usize ) < variants. len ( ) {
258286 return Err ( syn:: Error :: new_spanned (
259- & ty ,
287+ & repr ,
260288 format ! ( "Not enough bits for {} flags" , variants. len( ) ) ,
261289 ) ) ;
262290 }
263291
264292 let std_path = quote_spanned ! ( span => :: enumflags2:: _internal:: core) ;
293+ let variant_names = ast. variants . iter ( ) . map ( |v| & v. ident ) . collect :: < Vec < _ > > ( ) ;
294+ let repeated_name = vec ! [ & ident; ast. variants. len( ) ] ;
265295
266296 Ok ( quote_spanned ! {
267297 span =>
@@ -303,15 +333,15 @@ fn gen_enumflags(ast: &ItemEnum, default: Vec<Ident>) -> Result<TokenStream, syn
303333 }
304334
305335 impl :: enumflags2:: _internal:: RawBitFlags for #ident {
306- type Numeric = #ty ;
336+ type Numeric = #repr ;
307337
308338 const EMPTY : Self :: Numeric = 0 ;
309339
310340 const DEFAULT : Self :: Numeric =
311- 0 #( | ( #repeated_name:: #default as #ty ) ) * ;
341+ 0 #( | ( #repeated_name:: #default as #repr ) ) * ;
312342
313343 const ALL_BITS : Self :: Numeric =
314- 0 #( | ( #repeated_name:: #variant_names as #ty ) ) * ;
344+ 0 #( | ( #repeated_name:: #variant_names as #repr ) ) * ;
315345
316346 const FLAG_LIST : & ' static [ Self ] =
317347 & [ #( #repeated_name:: #variant_names) , * ] ;
@@ -320,7 +350,7 @@ fn gen_enumflags(ast: &ItemEnum, default: Vec<Ident>) -> Result<TokenStream, syn
320350 concat!( "BitFlags<" , stringify!( #ident) , ">" ) ;
321351
322352 fn bits( self ) -> Self :: Numeric {
323- self as #ty
353+ self as #repr
324354 }
325355 }
326356
0 commit comments