@@ -381,6 +381,66 @@ struct PyClassEnumVariant<'a> {
381
381
/* currently have no more options */
382
382
}
383
383
384
+ struct PyClassEnum < ' a > {
385
+ ident : & ' a syn:: Ident ,
386
+ // The underyling representation of the enum.
387
+ // It's used to implement __int__ and __richcmp__.
388
+ // This matters when the underyling representation may not fit in `isize`.
389
+ #[ allow( unused, dead_code) ]
390
+ repr : syn:: Ident ,
391
+ variants : Vec < PyClassEnumVariant < ' a > > ,
392
+ doc : PythonDoc ,
393
+ }
394
+
395
+ impl < ' a > PyClassEnum < ' a > {
396
+ fn new ( enum_ : & ' a syn:: ItemEnum ) -> syn:: Result < Self > {
397
+ fn is_numeric_type ( t : & syn:: Ident ) -> bool {
398
+ [
399
+ "u8" , "i8" , "u16" , "i16" , "u32" , "i32" , "u64" , "i64" , "u128" , "i128" , "usize" ,
400
+ "isize" ,
401
+ ]
402
+ . iter ( )
403
+ . any ( |& s| t == s)
404
+ }
405
+ struct Reprs ( syn:: punctuated:: Punctuated < syn:: Ident , Token ! [ , ] > ) ;
406
+ impl Parse for Reprs {
407
+ fn parse ( input : ParseStream ) -> Result < Self > {
408
+ let inner = Punctuated :: parse_terminated ( input) ?;
409
+ Ok ( Self ( inner) )
410
+ }
411
+ }
412
+ let ident = & enum_. ident ;
413
+ // According to the [reference](https://doc.rust-lang.org/reference/items/enumerations.html),
414
+ // "Under the default representation, the specified discriminant is interpreted as an isize
415
+ // value", so `isize` should be enough by default.
416
+ let mut repr = syn:: Ident :: new ( "isize" , proc_macro2:: Span :: call_site ( ) ) ;
417
+ for attr in & enum_. attrs {
418
+ if attr. path . is_ident ( "repr" ) {
419
+ let reprs: Reprs = attr. parse_args ( ) ?;
420
+ for r in reprs. 0 {
421
+ if is_numeric_type ( & r) {
422
+ repr = r;
423
+ break ;
424
+ }
425
+ }
426
+ }
427
+ }
428
+ let doc = utils:: get_doc ( & enum_. attrs , None ) ;
429
+
430
+ let variants = enum_
431
+ . variants
432
+ . iter ( )
433
+ . map ( extract_variant_data)
434
+ . collect :: < syn:: Result < _ > > ( ) ?;
435
+ Ok ( Self {
436
+ ident,
437
+ repr,
438
+ variants,
439
+ doc,
440
+ } )
441
+ }
442
+ }
443
+
384
444
pub fn build_py_enum (
385
445
enum_ : & syn:: ItemEnum ,
386
446
args : PyClassArgs ,
@@ -389,38 +449,32 @@ pub fn build_py_enum(
389
449
if enum_. variants . is_empty ( ) {
390
450
bail_spanned ! ( enum_. brace_token. span => "Empty enums can't be #[pyclass]." ) ;
391
451
}
392
- let variants: Vec < PyClassEnumVariant > = enum_
393
- . variants
394
- . iter ( )
395
- . map ( extract_variant_data)
396
- . collect :: < syn:: Result < _ > > ( ) ?;
397
- impl_enum ( enum_, args, variants, method_type)
452
+ let enum_ = PyClassEnum :: new ( enum_) ?;
453
+ impl_enum ( enum_, args, method_type)
398
454
}
399
455
400
456
fn impl_enum (
401
- enum_ : & syn:: ItemEnum ,
402
- attrs : PyClassArgs ,
403
- variants : Vec < PyClassEnumVariant > ,
457
+ enum_ : PyClassEnum ,
458
+ args : PyClassArgs ,
404
459
methods_type : PyClassMethodsType ,
405
460
) -> syn:: Result < TokenStream > {
406
- let enum_name = & enum_. ident ;
407
- let doc = utils:: get_doc ( & enum_. attrs , None ) ;
408
- let enum_cls = impl_enum_class ( enum_name, & attrs, variants, doc, methods_type) ?;
461
+ let enum_cls = impl_enum_class ( enum_, & args, methods_type) ?;
409
462
410
463
Ok ( quote ! {
411
464
#enum_cls
412
465
} )
413
466
}
414
467
415
468
fn impl_enum_class (
416
- cls : & syn:: Ident ,
417
- attr : & PyClassArgs ,
418
- variants : Vec < PyClassEnumVariant > ,
419
- doc : PythonDoc ,
469
+ enum_ : PyClassEnum ,
470
+ args : & PyClassArgs ,
420
471
methods_type : PyClassMethodsType ,
421
472
) -> syn:: Result < TokenStream > {
422
- let pytypeinfo = impl_pytypeinfo ( cls, attr, None ) ;
423
- let pyclass_impls = PyClassImplsBuilder :: new ( cls, attr, methods_type)
473
+ let cls = enum_. ident ;
474
+ let doc = enum_. doc ;
475
+ let variants = enum_. variants ;
476
+ let pytypeinfo = impl_pytypeinfo ( cls, args, None ) ;
477
+ let pyclass_impls = PyClassImplsBuilder :: new ( cls, args, methods_type)
424
478
. doc ( doc)
425
479
. impl_all ( ) ;
426
480
let descriptors = unit_variants_as_descriptors ( cls, variants. iter ( ) . map ( |v| v. ident ) ) ;
@@ -494,9 +548,6 @@ fn extract_variant_data(variant: &syn::Variant) -> syn::Result<PyClassEnumVarian
494
548
Fields :: Unit => & variant. ident ,
495
549
_ => bail_spanned ! ( variant. span( ) => "Currently only support unit variants." ) ,
496
550
} ;
497
- if let Some ( discriminant) = variant. discriminant . as_ref ( ) {
498
- bail_spanned ! ( discriminant. 0 . span( ) => "Currently does not support discriminats." )
499
- } ;
500
551
Ok ( PyClassEnumVariant { ident } )
501
552
}
502
553
0 commit comments