Skip to content

Commit 902bb40

Browse files
committed
Implement __int__ and __eq__ betweeen enum and int
1 parent 9027dc9 commit 902bb40

File tree

2 files changed

+90
-5
lines changed

2 files changed

+90
-5
lines changed

pyo3-macros-backend/src/pyclass.rs

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -498,30 +498,58 @@ fn impl_enum_class(
498498
}
499499
};
500500

501+
let repr_type = &enum_.repr;
502+
503+
let default_int = {
504+
// This implementation allows us to convert &T to #repr_type without implementing `Copy`
505+
let variants_to_int = variants.iter().map(|variant| {
506+
let variant_name = variant.ident;
507+
quote! { #cls::#variant_name => #cls::#variant_name as #repr_type, }
508+
});
509+
quote! {
510+
#[doc(hidden)]
511+
#[allow(non_snake_case)]
512+
#[pyo3(name = "__int__")]
513+
fn __pyo3__int__(&self) -> #repr_type {
514+
match self {
515+
#(#variants_to_int)*
516+
}
517+
}
518+
}
519+
};
520+
501521
let default_richcmp = {
502522
let variants_eq = variants.iter().map(|variant| {
503523
let variant_name = variant.ident;
504-
quote! {(#cls::#variant_name, #cls::#variant_name) => true.to_object(py),}
524+
quote! {(#cls::#variant_name, #cls::#variant_name) => Ok(true.to_object(py)),}
505525
});
506526
quote! {
507527
#[doc(hidden)]
508528
#[allow(non_snake_case)]
509529
#[pyo3(name = "__richcmp__")]
510-
fn __pyo3__richcmp__(&self, py: ::pyo3::Python, other: &Self, op: ::pyo3::basic::CompareOp) -> PyObject {
530+
fn __pyo3__richcmp__(&self, py: ::pyo3::Python, other: &PyAny, op: ::pyo3::basic::CompareOp) -> PyResult<PyObject> {
511531
match op {
512532
::pyo3::basic::CompareOp::Eq => {
533+
if let Ok(i) = other.extract::<#repr_type>() {
534+
let self_val = self.__pyo3__int__();
535+
return Ok((self_val == i).to_object(py));
536+
}
537+
let other = other.extract::<PyRef<Self>>()?;
538+
let other = &*other;
513539
match (self, other) {
514540
#(#variants_eq)*
515-
_ => false.to_object(py),
541+
_ => Ok(false.to_object(py)),
516542
}
517543
}
518-
_ => py.NotImplemented(),
544+
_ => Ok(py.NotImplemented()),
519545
}
520546
}
521547
}
522548
};
523549

524-
let default_impls = gen_default_slot_impls(cls, vec![default_repr_impl, default_richcmp]);
550+
let default_impls =
551+
gen_default_slot_impls(cls, vec![default_repr_impl, default_richcmp, default_int]);
552+
525553
Ok(quote! {
526554

527555
#pytypeinfo

tests/test_enum.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,60 @@ fn test_custom_discriminant() {
9090
"#);
9191
})
9292
}
93+
94+
#[test]
95+
fn test_enum_to_int() {
96+
Python::with_gil(|py| {
97+
let one = Py::new(py, CustomDiscriminant::One).unwrap();
98+
py_assert!(py, one, "int(one) == 1");
99+
let v = Py::new(py, MyEnum::Variant).unwrap();
100+
let v_value = MyEnum::Variant as isize;
101+
py_run!(py, v v_value, "int(v) == v_value");
102+
})
103+
}
104+
105+
#[test]
106+
fn test_enum_compare_int() {
107+
Python::with_gil(|py| {
108+
let one = Py::new(py, CustomDiscriminant::One).unwrap();
109+
py_run!(
110+
py,
111+
one,
112+
r#"
113+
assert one == 1
114+
assert 1 == one
115+
assert one != 2
116+
"#
117+
)
118+
})
119+
}
120+
121+
#[pyclass]
122+
#[repr(u8)]
123+
enum SmallEnum {
124+
V = 1,
125+
}
126+
127+
#[test]
128+
fn test_enum_compare_int_no_throw_when_overflow() {
129+
Python::with_gil(|py| {
130+
let v = Py::new(py, SmallEnum::V).unwrap();
131+
py_assert!(py, v, "v != 1<<30")
132+
})
133+
}
134+
135+
#[pyclass]
136+
#[repr(usize)]
137+
enum BigEnum {
138+
V = usize::MAX,
139+
}
140+
141+
#[test]
142+
fn test_big_enum_no_overflow() {
143+
Python::with_gil(|py| {
144+
let usize_max = usize::MAX;
145+
let v = Py::new(py, BigEnum::V).unwrap();
146+
py_assert!(py, usize_max v, "v == usize_max");
147+
py_assert!(py, usize_max v, "int(v) == usize_max");
148+
})
149+
}

0 commit comments

Comments
 (0)