Skip to content

Commit af7a820

Browse files
authored
use singleton in simple enum IntoPyObject (#5665)
1 parent 016f795 commit af7a820

File tree

3 files changed

+72
-19
lines changed

3 files changed

+72
-19
lines changed

newsfragments/5665.changed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`IntoPyObject` for simple enums now uses a sigleton value, allowing identity (python `is`) comparisons

pyo3-macros-backend/src/pyclass.rs

Lines changed: 69 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -950,6 +950,7 @@ fn impl_simple_enum(
950950
methods_type: PyClassMethodsType,
951951
ctx: &Ctx,
952952
) -> Result<TokenStream> {
953+
let Ctx { pyo3_path, .. } = ctx;
953954
let cls = simple_enum.ident;
954955
let ty: syn::Type = syn::parse_quote!(#cls);
955956
let variants = simple_enum.variants;
@@ -1021,7 +1022,7 @@ fn impl_simple_enum(
10211022
default_slots.extend(default_hash_slot);
10221023
default_slots.extend(default_str_slot);
10231024

1024-
let pyclass_impls = PyClassImplsBuilder::new(
1025+
let impl_builder = PyClassImplsBuilder::new(
10251026
cls,
10261027
args,
10271028
methods_type,
@@ -1034,8 +1035,58 @@ fn impl_simple_enum(
10341035
),
10351036
default_slots,
10361037
)
1037-
.doc(doc)
1038-
.impl_all(ctx)?;
1038+
.doc(doc);
1039+
1040+
let enum_into_pyobject_impl = {
1041+
let output_type = if cfg!(feature = "experimental-inspect") {
1042+
quote!(const OUTPUT_TYPE: #pyo3_path::inspect::TypeHint = <#cls as #pyo3_path::PyTypeInfo>::TYPE_HINT;)
1043+
} else {
1044+
TokenStream::new()
1045+
};
1046+
1047+
let num = variants.len();
1048+
let i = (0..num).map(proc_macro2::Literal::usize_unsuffixed);
1049+
let variant_idents = variants.iter().map(|v| v.ident);
1050+
let cfgs = variants.iter().map(|v| &v.cfg_attrs);
1051+
quote! {
1052+
impl<'py> #pyo3_path::conversion::IntoPyObject<'py> for #cls {
1053+
type Target = Self;
1054+
type Output = #pyo3_path::Bound<'py, <Self as #pyo3_path::conversion::IntoPyObject<'py>>::Target>;
1055+
type Error = #pyo3_path::PyErr;
1056+
#output_type
1057+
1058+
fn into_pyobject(self, py: #pyo3_path::Python<'py>) -> ::std::result::Result<
1059+
<Self as #pyo3_path::conversion::IntoPyObject<'py>>::Output,
1060+
<Self as #pyo3_path::conversion::IntoPyObject<'py>>::Error,
1061+
> {
1062+
// TODO(icxolu): switch this to lookup the variants on the type object, once that is immutable
1063+
const LOCK: #pyo3_path::sync::PyOnceLock<#pyo3_path::Py<#cls>> = #pyo3_path::sync::PyOnceLock::<#pyo3_path::Py<#cls>>::new();
1064+
static SINGLETON: [#pyo3_path::sync::PyOnceLock<#pyo3_path::Py<#cls>>; #num] = [LOCK; #num];
1065+
let idx: usize = match self {
1066+
#(
1067+
#(#cfgs)*
1068+
Self::#variant_idents => #i,
1069+
)*
1070+
};
1071+
#[allow(unreachable_code)]
1072+
SINGLETON[idx].get_or_try_init(py, || {
1073+
#pyo3_path::Py::new(py, self)
1074+
}).map(|obj| ::std::clone::Clone::clone(obj.bind(py)))
1075+
}
1076+
}
1077+
}
1078+
};
1079+
1080+
let pyclass_impls: TokenStream = [
1081+
impl_builder.impl_pyclass(ctx),
1082+
enum_into_pyobject_impl,
1083+
impl_builder.impl_pyclassimpl(ctx)?,
1084+
impl_builder.impl_add_to_module(ctx),
1085+
impl_builder.impl_freelist(ctx),
1086+
impl_builder.impl_introspection(ctx),
1087+
]
1088+
.into_iter()
1089+
.collect();
10391090

10401091
Ok(quote! {
10411092
#variant_cfg_check
@@ -1120,11 +1171,7 @@ fn impl_complex_enum(
11201171
}
11211172
}
11221173
});
1123-
let output_type = if cfg!(feature = "experimental-inspect") {
1124-
quote!(const OUTPUT_TYPE: #pyo3_path::inspect::TypeHint = <#cls as #pyo3_path::PyTypeInfo>::TYPE_HINT;)
1125-
} else {
1126-
TokenStream::new()
1127-
};
1174+
let output_type = get_conversion_type_hint(ctx, &format_ident!("OUTPUT_TYPE"), cls);
11281175
quote! {
11291176
impl<'py> #pyo3_path::conversion::IntoPyObject<'py> for #cls {
11301177
type Target = Self;
@@ -2345,11 +2392,7 @@ impl<'a> PyClassImplsBuilder<'a> {
23452392
let attr = self.attr;
23462393
// If #cls is not extended type, we allow Self->PyObject conversion
23472394
if attr.options.extends.is_none() {
2348-
let output_type = if cfg!(feature = "experimental-inspect") {
2349-
quote!(const OUTPUT_TYPE: #pyo3_path::inspect::TypeHint = <#cls as #pyo3_path::PyTypeInfo>::TYPE_HINT;)
2350-
} else {
2351-
TokenStream::new()
2352-
};
2395+
let output_type = get_conversion_type_hint(ctx, &format_ident!("OUTPUT_TYPE"), cls);
23532396
quote! {
23542397
impl<'py> #pyo3_path::conversion::IntoPyObject<'py> for #cls {
23552398
type Target = Self;
@@ -2532,11 +2575,7 @@ impl<'a> PyClassImplsBuilder<'a> {
25322575
let extract_pyclass_with_clone = if let Some(from_py_object) =
25332576
self.attr.options.from_py_object
25342577
{
2535-
let input_type = if cfg!(feature = "experimental-inspect") {
2536-
quote!(const INPUT_TYPE: #pyo3_path::inspect::TypeHint = <#cls as #pyo3_path::PyTypeInfo>::TYPE_HINT;)
2537-
} else {
2538-
TokenStream::new()
2539-
};
2578+
let input_type = get_conversion_type_hint(ctx, &format_ident!("INPUT_TYPE"), cls);
25402579
quote_spanned! { from_py_object.span() =>
25412580
impl<'a, 'py> #pyo3_path::FromPyObject<'a, 'py> for #cls
25422581
where
@@ -2770,6 +2809,18 @@ fn generate_cfg_check(variants: &[PyClassEnumUnitVariant<'_>], cls: &syn::Ident)
27702809
}
27712810
}
27722811

2812+
fn get_conversion_type_hint(
2813+
Ctx { pyo3_path, .. }: &Ctx,
2814+
konst: &Ident,
2815+
cls: &Ident,
2816+
) -> TokenStream {
2817+
if cfg!(feature = "experimental-inspect") {
2818+
quote!(const #konst: #pyo3_path::inspect::TypeHint = <#cls as #pyo3_path::PyTypeInfo>::TYPE_HINT;)
2819+
} else {
2820+
TokenStream::new()
2821+
}
2822+
}
2823+
27732824
const UNIQUE_GET: &str = "`get` may only be specified once";
27742825
const UNIQUE_SET: &str = "`set` may only be specified once";
27752826
const UNIQUE_NAME: &str = "`name` may only be specified once";

tests/test_enum.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ fn test_return_enum() {
5454
let f = wrap_pyfunction!(return_enum)(py).unwrap();
5555
let mynum = py.get_type::<MyEnum>();
5656

57-
py_run!(py, f mynum, "assert f() == mynum.Variant")
57+
py_run!(py, f mynum, "assert f() == mynum.Variant");
58+
py_run!(py, f mynum, "assert f() is mynum.Variant");
5859
});
5960
}
6061

0 commit comments

Comments
 (0)