Skip to content

Commit 0b33b75

Browse files
committed
Use associated constants instead of functions where possible
1 parent 5445541 commit 0b33b75

File tree

3 files changed

+40
-51
lines changed

3 files changed

+40
-51
lines changed

enumflags_derive/src/lib.rs

Lines changed: 24 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ pub fn bitflags_internal(
2626
) -> proc_macro::TokenStream {
2727
let ast: DeriveInput = syn::parse(input).unwrap();
2828

29-
let impls = match ast.data {
29+
let output = match ast.data {
3030
Data::Enum(ref data) => {
3131
gen_enumflags(&ast.ident, &ast, data)
3232
}
@@ -38,12 +38,13 @@ pub fn bitflags_internal(
3838
}
3939
};
4040

41-
let impls = impls.unwrap_or_else(|err| err.to_compile_error());
42-
let combined = quote! {
43-
#ast
44-
#impls
45-
};
46-
combined.into()
41+
output.unwrap_or_else(|err| {
42+
let error = err.to_compile_error();
43+
quote! {
44+
#ast
45+
#error
46+
}
47+
}).into()
4748
}
4849

4950
/// Try to evaluate the expression given.
@@ -178,10 +179,11 @@ fn gen_enumflags(ident: &Ident, item: &DeriveInput, data: &DataEnum)
178179
{
179180
let span = Span::call_site();
180181
// for quote! interpolation
181-
let variant_names = data.variants.iter().map(|v| &v.ident);
182-
let variant_count = data.variants.len();
183-
184-
let repeated_name = std::iter::repeat(&ident);
182+
let variant_names =
183+
data.variants.iter()
184+
.map(|v| &v.ident)
185+
.collect::<Vec<_>>();
186+
let repeated_name = vec![&ident; data.variants.len()];
185187

186188
let variants = collect_flags(data.variants.iter())?;
187189
let deferred = variants.iter()
@@ -192,16 +194,11 @@ fn gen_enumflags(ident: &Ident, item: &DeriveInput, data: &DataEnum)
192194
.ok_or_else(|| syn::Error::new_spanned(&ident,
193195
"repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield."))?;
194196
let std_path = quote_spanned!(span => ::enumflags2::_internal::core);
195-
let all = if variant_count == 0 {
196-
quote!(0)
197-
} else {
198-
let repeated_name = repeated_name.clone();
199-
let variant_names = variant_names.clone();
200-
quote!(#(#repeated_name::#variant_names as #ty)|*)
201-
};
202197

203198
Ok(quote_spanned! {
204-
span => #(#deferred)*
199+
span =>
200+
#item
201+
#(#deferred)*
205202
impl #std_path::ops::Not for #ident {
206203
type Output = ::enumflags2::BitFlags<#ident>;
207204
fn not(self) -> Self::Output {
@@ -236,23 +233,17 @@ fn gen_enumflags(ident: &Ident, item: &DeriveInput, data: &DataEnum)
236233
impl ::enumflags2::_internal::RawBitFlags for #ident {
237234
type Type = #ty;
238235

239-
fn all_bits() -> Self::Type {
240-
// make sure it's evaluated at compile time
241-
const VALUE: #ty = #all;
242-
VALUE
243-
}
236+
const ALL_BITS: Self::Type =
237+
0 #(| (#repeated_name::#variant_names as #ty))*;
244238

245-
fn bits(self) -> Self::Type {
246-
self as #ty
247-
}
239+
const FLAG_LIST: &'static [Self] =
240+
&[#(#repeated_name::#variant_names),*];
248241

249-
fn flag_list() -> &'static [Self] {
250-
const VARIANTS: [#ident; #variant_count] = [#(#repeated_name :: #variant_names),*];
251-
&VARIANTS
252-
}
242+
const BITFLAGS_TYPE_NAME : &'static str =
243+
concat!("BitFlags<", stringify!(#ident), ">");
253244

254-
fn bitflags_type_name() -> &'static str {
255-
concat!("BitFlags<", stringify!(#ident), ">")
245+
fn bits(self) -> Self::Type {
246+
self as #ty
256247
}
257248
}
258249

src/formatting.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ where
66
T: BitFlag + fmt::Debug,
77
{
88
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
9-
let name = T::bitflags_type_name();
9+
let name = T::BITFLAGS_TYPE_NAME;
1010
let bits = DebugBinaryFormatter(&self.val);
1111
let iter = if !self.is_empty() {
12-
let iter = T::flag_list().iter().filter(|&&flag| self.contains(flag));
12+
let iter = T::FLAG_LIST.iter().filter(|&&flag| self.contains(flag));
1313
Some(FlagFormatter(iter))
1414
} else {
1515
None

src/lib.rs

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -167,21 +167,19 @@ pub mod _internal {
167167
/// The underlying integer type.
168168
type Type: BitFlagNum;
169169

170-
/// Return a value with all flag bits set.
171-
fn all_bits() -> Self::Type;
170+
/// A value with all flag bits set.
171+
const ALL_BITS: Self::Type;
172172

173-
/// Return the bits as a number type.
174-
fn bits(self) -> Self::Type;
173+
/// A slice that contains each variant exactly one.
174+
const FLAG_LIST: &'static [Self];
175175

176-
/// Return a slice that contains each variant exactly one.
177-
fn flag_list() -> &'static [Self];
178-
179-
/// Return the name of the type for debug formatting purposes.
176+
/// The name of the type for debug formatting purposes.
180177
///
181178
/// This is typically `BitFlags<EnumName>`
182-
fn bitflags_type_name() -> &'static str {
183-
"BitFlags"
184-
}
179+
const BITFLAGS_TYPE_NAME: &'static str;
180+
181+
/// Return the bits as a number type.
182+
fn bits(self) -> Self::Type;
185183
}
186184

187185
use ::core::ops::{BitAnd, BitOr, BitXor, Not};
@@ -333,12 +331,12 @@ where
333331
/// assert_eq!(empty.contains(MyFlag::Three), true);
334332
/// ```
335333
pub fn all() -> Self {
336-
unsafe { BitFlags::new(T::all_bits()) }
334+
unsafe { BitFlags::new(T::ALL_BITS) }
337335
}
338336

339337
/// Returns true if all flags are set
340338
pub fn is_all(self) -> bool {
341-
self.val == T::all_bits()
339+
self.val == T::ALL_BITS
342340
}
343341

344342
/// Returns true if no flag is set
@@ -383,7 +381,7 @@ where
383381

384382
/// Truncates flags that are illegal
385383
pub fn from_bits_truncate(bits: T::Type) -> Self {
386-
unsafe { BitFlags::new(bits & T::all_bits()) }
384+
unsafe { BitFlags::new(bits & T::ALL_BITS) }
387385
}
388386

389387
/// Toggles the matching bits
@@ -403,7 +401,7 @@ where
403401

404402
/// Returns an iterator that yields each set flag
405403
pub fn iter(self) -> impl Iterator<Item = T> {
406-
T::flag_list().iter().cloned().filter(move |&flag| self.contains(flag))
404+
T::FLAG_LIST.iter().cloned().filter(move |&flag| self.contains(flag))
407405
}
408406
}
409407

@@ -485,7 +483,7 @@ where
485483
{
486484
type Output = BitFlags<T>;
487485
fn not(self) -> BitFlags<T> {
488-
unsafe { BitFlags::new(!self.bits() & T::all_bits()) }
486+
unsafe { BitFlags::new(!self.bits() & T::ALL_BITS) }
489487
}
490488
}
491489

0 commit comments

Comments
 (0)