Skip to content

Commit 28079be

Browse files
committed
Parse #[repr(..)] for #[pyclass] enums.
1 parent 8a03778 commit 28079be

File tree

2 files changed

+93
-21
lines changed

2 files changed

+93
-21
lines changed

pyo3-macros-backend/src/pyclass.rs

Lines changed: 72 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,66 @@ struct PyClassEnumVariant<'a> {
381381
/* currently have no more options */
382382
}
383383

384+
struct PyClassEnum<'a> {
385+
ident: &'a syn::Ident,
386+
// The underyling representation of the enum.
387+
// It's used to implement __int__ and __richcmp__.
388+
// This matters when the underyling representation may not fit in `isize`.
389+
#[allow(unused, dead_code)]
390+
repr: syn::Ident,
391+
variants: Vec<PyClassEnumVariant<'a>>,
392+
doc: PythonDoc,
393+
}
394+
395+
impl<'a> PyClassEnum<'a> {
396+
fn new(enum_: &'a syn::ItemEnum) -> syn::Result<Self> {
397+
fn is_numeric_type(t: &syn::Ident) -> bool {
398+
[
399+
"u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "u128", "i128", "usize",
400+
"isize",
401+
]
402+
.iter()
403+
.any(|&s| t == s)
404+
}
405+
struct Reprs(syn::punctuated::Punctuated<syn::Ident, Token![,]>);
406+
impl Parse for Reprs {
407+
fn parse(input: ParseStream) -> Result<Self> {
408+
let inner = Punctuated::parse_terminated(input)?;
409+
Ok(Self(inner))
410+
}
411+
}
412+
let ident = &enum_.ident;
413+
// According to the [reference](https://doc.rust-lang.org/reference/items/enumerations.html),
414+
// "Under the default representation, the specified discriminant is interpreted as an isize
415+
// value", so `isize` should be enough by default.
416+
let mut repr = syn::Ident::new("isize", proc_macro2::Span::call_site());
417+
for attr in &enum_.attrs {
418+
if attr.path.is_ident("repr") {
419+
let reprs: Reprs = attr.parse_args()?;
420+
for r in reprs.0 {
421+
if is_numeric_type(&r) {
422+
repr = r;
423+
break;
424+
}
425+
}
426+
}
427+
}
428+
let doc = utils::get_doc(&enum_.attrs, None);
429+
430+
let variants = enum_
431+
.variants
432+
.iter()
433+
.map(extract_variant_data)
434+
.collect::<syn::Result<_>>()?;
435+
Ok(Self {
436+
ident,
437+
repr,
438+
variants,
439+
doc,
440+
})
441+
}
442+
}
443+
384444
pub fn build_py_enum(
385445
enum_: &syn::ItemEnum,
386446
args: PyClassArgs,
@@ -389,38 +449,32 @@ pub fn build_py_enum(
389449
if enum_.variants.is_empty() {
390450
bail_spanned!(enum_.brace_token.span => "Empty enums can't be #[pyclass].");
391451
}
392-
let variants: Vec<PyClassEnumVariant> = enum_
393-
.variants
394-
.iter()
395-
.map(extract_variant_data)
396-
.collect::<syn::Result<_>>()?;
397-
impl_enum(enum_, args, variants, method_type)
452+
let enum_ = PyClassEnum::new(enum_)?;
453+
impl_enum(enum_, args, method_type)
398454
}
399455

400456
fn impl_enum(
401-
enum_: &syn::ItemEnum,
402-
attrs: PyClassArgs,
403-
variants: Vec<PyClassEnumVariant>,
457+
enum_: PyClassEnum,
458+
args: PyClassArgs,
404459
methods_type: PyClassMethodsType,
405460
) -> syn::Result<TokenStream> {
406-
let enum_name = &enum_.ident;
407-
let doc = utils::get_doc(&enum_.attrs, None);
408-
let enum_cls = impl_enum_class(enum_name, &attrs, variants, doc, methods_type)?;
461+
let enum_cls = impl_enum_class(enum_, &args, methods_type)?;
409462

410463
Ok(quote! {
411464
#enum_cls
412465
})
413466
}
414467

415468
fn impl_enum_class(
416-
cls: &syn::Ident,
417-
attr: &PyClassArgs,
418-
variants: Vec<PyClassEnumVariant>,
419-
doc: PythonDoc,
469+
enum_: PyClassEnum,
470+
args: &PyClassArgs,
420471
methods_type: PyClassMethodsType,
421472
) -> syn::Result<TokenStream> {
422-
let pytypeinfo = impl_pytypeinfo(cls, attr, None);
423-
let pyclass_impls = PyClassImplsBuilder::new(cls, attr, methods_type)
473+
let cls = enum_.ident;
474+
let doc = enum_.doc;
475+
let variants = enum_.variants;
476+
let pytypeinfo = impl_pytypeinfo(cls, args, None);
477+
let pyclass_impls = PyClassImplsBuilder::new(cls, args, methods_type)
424478
.doc(doc)
425479
.impl_all();
426480
let descriptors = unit_variants_as_descriptors(cls, variants.iter().map(|v| v.ident));
@@ -494,9 +548,6 @@ fn extract_variant_data(variant: &syn::Variant) -> syn::Result<PyClassEnumVarian
494548
Fields::Unit => &variant.ident,
495549
_ => bail_spanned!(variant.span() => "Currently only support unit variants."),
496550
};
497-
if let Some(discriminant) = variant.discriminant.as_ref() {
498-
bail_spanned!(discriminant.0.span() => "Currently does not support discriminats.")
499-
};
500551
Ok(PyClassEnumVariant { ident })
501552
}
502553

tests/test_enum.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,24 @@ fn test_default_repr_correct() {
6161
py_assert!(py, var2, "repr(var2) == 'MyEnum.OtherVariant'");
6262
})
6363
}
64+
65+
#[pyclass]
66+
enum CustomDiscriminant {
67+
One = 1,
68+
Two = 2,
69+
}
70+
71+
#[test]
72+
fn test_custom_discriminant() {
73+
Python::with_gil(|py| {
74+
#[allow(non_snake_case)]
75+
let CustomDiscriminant = py.get_type::<CustomDiscriminant>();
76+
let one = Py::new(py, CustomDiscriminant::One).unwrap();
77+
let two = Py::new(py, CustomDiscriminant::Two).unwrap();
78+
py_run!(py, CustomDiscriminant one two, r#"
79+
assert CustomDiscriminant.One == one
80+
assert CustomDiscriminant.Two == two
81+
assert one != two
82+
"#);
83+
})
84+
}

0 commit comments

Comments
 (0)