Skip to content

Commit 1e00af5

Browse files
ahlincdavidhewittmejrsIcxolu
authored
Implement #[init] method attribute in #[pymethods] (#4951)
* Implement #[init] method attribute in #[pymethods] This allows to control objects initialization flow in the Rust code in case of inheritance from native Python types. * Apply suggestions from code review Co-authored-by: Bruno Kolenbrander <[email protected]> * review feedback * expose `PySuper` on PyPy, GraalPy and ABI3 * fix graalpy issue --------- Co-authored-by: David Hewitt <[email protected]> Co-authored-by: Bruno Kolenbrander <[email protected]> Co-authored-by: Icxolu <[email protected]>
1 parent 9ef1b50 commit 1e00af5

File tree

15 files changed

+293
-13
lines changed

15 files changed

+293
-13
lines changed

guide/src/class.md

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ There is a [detailed discussion on thread-safety](./class/thread-safety.md) late
145145

146146
By default, it is not possible to create an instance of a custom class from Python code.
147147
To declare a constructor, you need to define a method and annotate it with the `#[new]` attribute.
148-
Only Python's `__new__` method can be specified, `__init__` is not available.
148+
A constructor is accessible as Python's `__new__` method.
149149

150150
```rust
151151
# #![allow(dead_code)]
@@ -192,6 +192,76 @@ If no method marked with `#[new]` is declared, object instances can only be crea
192192

193193
For arguments, see the [`Method arguments`](#method-arguments) section below.
194194

195+
## Initializer
196+
197+
An initializer implements Python's `__init__` method.
198+
199+
It may be required when it's needed to control an object initalization flow on the Rust code.
200+
If possible handling this in `__new__` should be preferred, but in some cases, like subclassing native types, overwriting `__init__` might be necessary.
201+
For example, you define a class that extends `PyDict` and don't want that the original `__init__` method of `PyDict` been called.
202+
In this case by defining an own `__init__` method it's possible to stop initialization flow.
203+
204+
If you declare an `__init__` method you may need to call a super class' `__init__` method explicitly like in Python code.
205+
206+
To declare an initializer, you need to define the `__init__` method.
207+
Like in Python `__init__` must have the `self` receiver as the first argument, followed by the same arguments as the constructor.
208+
It can either return `()` or `PyResult<()>`.
209+
210+
```rust
211+
# #![allow(dead_code)]
212+
# use pyo3::prelude::*;
213+
# #[cfg(not(any(Py_LIMITED_API, GraalPy)))]
214+
use pyo3::types::{PyDict, PyTuple, PySuper};
215+
# #[cfg(not(any(Py_LIMITED_API, GraalPy)))]
216+
use crate::pyo3::PyTypeInfo;
217+
218+
# #[cfg(not(any(Py_LIMITED_API, GraalPy)))]
219+
#[pyclass(extends = PyDict)]
220+
struct MyDict;
221+
222+
# #[cfg(not(any(Py_LIMITED_API, GraalPy)))]
223+
#[pymethods]
224+
impl MyDict {
225+
# #[allow(unused_variables)]
226+
#[new]
227+
#[pyo3(signature = (*args, **kwargs))]
228+
fn __new__(
229+
args: &Bound<'_, PyTuple>,
230+
kwargs: Option<&Bound<'_, PyDict>>,
231+
) -> PyResult<Self> {
232+
Ok(Self)
233+
}
234+
235+
#[pyo3(signature = (*args, **kwargs))]
236+
fn __init__(
237+
slf: &Bound<'_, Self>,
238+
args: &Bound<'_, PyTuple>,
239+
kwargs: Option<&Bound<'_, PyDict>>,
240+
) -> PyResult<()> {
241+
// call the super types __init__
242+
PySuper::new(&PyDict::type_object(slf.py()), slf)?
243+
.call_method("__init__", args.to_owned(), kwargs)?;
244+
// Note: if `MyDict` allows further subclassing, and this is called from such a subclass,
245+
// then this will not that any overrides into account that such a subclass may have defined.
246+
// In such a case it may be preferred to just call `slf.set_item` and let Python figure it out.
247+
slf.as_super().set_item("my_key", "always insert this key")?;
248+
Ok(())
249+
}
250+
}
251+
252+
# #[cfg(not(any(Py_LIMITED_API, GraalPy)))]
253+
# fn main() {
254+
# Python::attach(|py| {
255+
# let typeobj = py.get_type::<MyDict>();
256+
# let obj = typeobj.call((), None).unwrap().cast_into::<MyDict>().unwrap();
257+
# // check __init__ was called
258+
# assert_eq!(obj.get_item("my_key").unwrap().extract::<&str>().unwrap(), "always insert this key");
259+
# });
260+
# }
261+
# #[cfg(any(Py_LIMITED_API, GraalPy))]
262+
# fn main() {}
263+
```
264+
195265
## Adding the class to a module
196266

197267
The next step is to create the Python module and add our class to it:

guide/src/class/protocols.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ Because of the double-underscores surrounding their name, these are also known a
66

77
PyO3 makes it possible for every magic method to be implemented in `#[pymethods]` just as they would be done in a regular Python class, with a few notable differences:
88

9-
- `__new__` and `__init__` are replaced by the [`#[new]` attribute](../class.md#constructor).
9+
- `__new__` is replaced by the [`#[new]` attribute](../class.md#constructor).
1010
- `__del__` is not yet supported, but may be in the future.
1111
- `__buffer__` and `__release_buffer__` are currently not supported and instead PyO3 supports [`__getbuffer__` and `__releasebuffer__`](#buffer-objects) methods (these predate [PEP 688](https://peps.python.org/pep-0688/#python-level-buffer-protocol)), again this may change in the future.
1212
- PyO3 adds [`__traverse__` and `__clear__`](#garbage-collector-integration) methods for controlling garbage collection.

newsfragments/4951.added.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Added `__init__` support in `#[pymethods]`.
2+
expose `PySuper` on PyPy, GraalPy and ABI3

pyo3-macros-backend/src/pymethod.rs

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ impl PyMethodKind {
8181
match name {
8282
// Protocol implemented through slots
8383
"__new__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__NEW__)),
84+
"__init__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__INIT__)),
8485
"__str__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__STR__)),
8586
"__repr__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__REPR__)),
8687
"__hash__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__HASH__)),
@@ -312,11 +313,11 @@ fn ensure_no_forbidden_protocol_attributes(
312313
method_name: &str,
313314
) -> syn::Result<()> {
314315
if let Some(signature) = &spec.signature.attribute {
315-
// __new__ and __call__ are allowed to have a signature, but nothing else is.
316+
// __new__, __init__ and __call__ are allowed to have a signature, but nothing else is.
316317
if !matches!(
317318
proto_kind,
318319
PyMethodProtoKind::Slot(SlotDef {
319-
calling_convention: SlotCallingConvention::TpNew,
320+
calling_convention: SlotCallingConvention::TpNew | SlotCallingConvention::TpInit,
320321
..
321322
})
322323
) && !matches!(proto_kind, PyMethodProtoKind::Call)
@@ -367,7 +368,6 @@ pub fn impl_py_method_def(
367368

368369
fn impl_call_slot(cls: &syn::Type, spec: &FnSpec<'_>, ctx: &Ctx) -> Result<MethodAndSlotDef> {
369370
let Ctx { pyo3_path, .. } = ctx;
370-
371371
let wrapper_ident = syn::Ident::new("__pymethod___call____", Span::call_site());
372372
let associated_method =
373373
spec.get_wrapper_function(&wrapper_ident, Some(cls), CallingConvention::Varargs, ctx)?;
@@ -904,6 +904,7 @@ impl PropertyType<'_> {
904904
}
905905

906906
pub const __NEW__: SlotDef = SlotDef::new("Py_tp_new", "newfunc");
907+
pub const __INIT__: SlotDef = SlotDef::new("Py_tp_init", "initproc");
907908
pub const __STR__: SlotDef = SlotDef::new("Py_tp_str", "reprfunc");
908909
pub const __REPR__: SlotDef = SlotDef::new("Py_tp_repr", "reprfunc");
909910
pub const __HASH__: SlotDef =
@@ -1164,13 +1165,15 @@ enum SlotCallingConvention {
11641165
FixedArguments(&'static [Ty]),
11651166
/// Arbitrary arguments for `__new__` from the signature (extracted from args / kwargs)
11661167
TpNew,
1168+
TpInit,
11671169
}
11681170

11691171
impl SlotDef {
11701172
const fn new(slot: &'static str, func_ty: &'static str) -> Self {
11711173
// The FFI function pointer type determines the arguments and return type
11721174
let (calling_convention, ret_ty) = match func_ty.as_bytes() {
11731175
b"newfunc" => (SlotCallingConvention::TpNew, Ty::Object),
1176+
b"initproc" => (SlotCallingConvention::TpInit, Ty::Int),
11741177
b"reprfunc" => (SlotCallingConvention::FixedArguments(&[]), Ty::Object),
11751178
b"hashfunc" => (SlotCallingConvention::FixedArguments(&[]), Ty::PyHashT),
11761179
b"richcmpfunc" => (
@@ -1384,6 +1387,35 @@ fn generate_method_body(
13841387
};
13851388
(arg_idents, arg_types, body)
13861389
}
1390+
SlotCallingConvention::TpInit => {
1391+
let arg_idents = vec![
1392+
format_ident!("_slf"),
1393+
format_ident!("_args"),
1394+
format_ident!("_kwargs"),
1395+
];
1396+
let arg_types = vec![
1397+
quote! { *mut #pyo3_path::ffi::PyObject },
1398+
quote! { *mut #pyo3_path::ffi::PyObject },
1399+
quote! { *mut #pyo3_path::ffi::PyObject },
1400+
];
1401+
let (arg_convert, args) = impl_arg_params(spec, Some(cls), false, holders, ctx);
1402+
let call = quote! {{
1403+
let r = #cls::#rust_name(#self_arg #(#args),*);
1404+
#pyo3_path::impl_::wrap::converter(&r)
1405+
.wrap(r)
1406+
.map_err(::core::convert::Into::<#pyo3_path::PyErr>::into)?
1407+
}};
1408+
let output = quote_spanned! { *output_span => result.convert(py) };
1409+
1410+
let body = quote! {
1411+
use #pyo3_path::impl_::callback::IntoPyCallbackOutput;
1412+
#warnings
1413+
#arg_convert
1414+
let result = #call;
1415+
#output
1416+
};
1417+
(arg_idents, arg_types, body)
1418+
}
13871419
SlotCallingConvention::FixedArguments(arguments) => {
13881420
let arg_idents: Vec<_> = std::iter::once(format_ident!("_slf"))
13891421
.chain((0..arguments.len()).map(|i| format_ident!("arg{}", i)))

pytests/src/pyclasses.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use std::{thread, time};
33
use pyo3::exceptions::{PyStopIteration, PyValueError};
44
use pyo3::prelude::*;
55
use pyo3::types::PyType;
6+
#[cfg(not(any(Py_LIMITED_API, GraalPy)))]
7+
use pyo3::types::{PyDict, PyTuple};
68

79
#[pyclass(from_py_object)]
810
#[derive(Clone, Default)]
@@ -104,6 +106,34 @@ impl ClassWithDict {
104106
}
105107
}
106108

109+
#[cfg(not(any(Py_LIMITED_API, GraalPy)))] // Can't subclass native types on abi3 yet
110+
#[pyclass(extends = PyDict)]
111+
struct SubClassWithInit;
112+
113+
#[cfg(not(any(Py_LIMITED_API, GraalPy)))]
114+
#[pymethods]
115+
impl SubClassWithInit {
116+
#[new]
117+
#[pyo3(signature = (*args, **kwargs))]
118+
#[allow(unused_variables)]
119+
fn __new__(args: &Bound<'_, PyTuple>, kwargs: Option<&Bound<'_, PyDict>>) -> Self {
120+
Self
121+
}
122+
123+
#[pyo3(signature = (*args, **kwargs))]
124+
fn __init__(
125+
self_: &Bound<'_, Self>,
126+
args: &Bound<'_, PyTuple>,
127+
kwargs: Option<&Bound<'_, PyDict>>,
128+
) -> PyResult<()> {
129+
self_
130+
.py_super()?
131+
.call_method("__init__", args.to_owned(), kwargs)?;
132+
self_.as_super().set_item("__init__", true)?;
133+
Ok(())
134+
}
135+
}
136+
107137
#[pyclass(skip_from_py_object)]
108138
#[derive(Clone)]
109139
struct ClassWithDecorators {
@@ -173,6 +203,9 @@ pub mod pyclasses {
173203
#[cfg(any(Py_3_10, not(Py_LIMITED_API)))]
174204
#[pymodule_export]
175205
use super::ClassWithDict;
206+
#[cfg(not(any(Py_LIMITED_API, GraalPy)))]
207+
#[pymodule_export]
208+
use super::SubClassWithInit;
176209
#[pymodule_export]
177210
use super::{
178211
map_a_class, AssertingBaseClass, ClassWithDecorators, ClassWithoutConstructor, EmptyClass,

pytests/stubs/pyclasses.pyi

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from _typeshed import Incomplete
2-
from typing import final
2+
from typing import Any, final
33

44
class AssertingBaseClass:
55
def __new__(cls, /, expected_type: type) -> AssertingBaseClass: ...
@@ -53,6 +53,11 @@ class PyClassThreadIter:
5353
def __new__(cls, /) -> PyClassThreadIter: ...
5454
def __next__(self, /) -> int: ...
5555

56+
@final
57+
class SubClassWithInit(dict):
58+
def __init__(self, /, *args, **kwargs) -> Any: ...
59+
def __new__(cls, /, *args, **kwargs) -> SubClassWithInit: ...
60+
5661
def map_a_class(
5762
cls: EmptyClass | tuple[EmptyClass, EmptyClass] | Incomplete,
5863
) -> EmptyClass | tuple[EmptyClass, EmptyClass] | Incomplete: ...

pytests/tests/test_pyclasses.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,16 @@ def test_class_method(benchmark):
164164
def test_static_method(benchmark):
165165
cls = pyclasses.ClassWithDecorators
166166
benchmark(lambda: cls.static_method())
167+
168+
169+
def test_class_init_method():
170+
try:
171+
SubClassWithInit = pyclasses.SubClassWithInit
172+
except AttributeError:
173+
pytest.skip("not defined using abi3")
174+
175+
d = SubClassWithInit()
176+
assert d == {"__init__": True}
177+
178+
d = SubClassWithInit({"a": 1}, b=2)
179+
assert d == {"__init__": True, "a": 1, "b": 2}

src/impl_/trampoline.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,12 @@ trampolines!(
162162
kwargs: *mut ffi::PyObject,
163163
) -> *mut ffi::PyObject;
164164

165+
pub fn initproc(
166+
slf: *mut ffi::PyObject,
167+
args: *mut ffi::PyObject,
168+
kwargs: *mut ffi::PyObject,
169+
) -> c_int;
170+
165171
pub fn objobjproc(slf: *mut ffi::PyObject, arg1: *mut ffi::PyObject) -> c_int;
166172

167173
pub fn reprfunc(slf: *mut ffi::PyObject) -> *mut ffi::PyObject;

src/types/any.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ use crate::instance::Bound;
88
use crate::internal::get_slot::TP_DESCR_GET;
99
use crate::py_result_ext::PyResultExt;
1010
use crate::type_object::{PyTypeCheck, PyTypeInfo};
11-
#[cfg(not(any(PyPy, GraalPy)))]
1211
use crate::types::PySuper;
1312
use crate::types::{PyDict, PyIterator, PyList, PyString, PyType};
1413
use crate::{err, ffi, Borrowed, BoundObject, IntoPyObjectExt, Py, Python};
@@ -939,7 +938,6 @@ pub trait PyAnyMethods<'py>: crate::sealed::Sealed {
939938
/// Return a proxy object that delegates method calls to a parent or sibling class of type.
940939
///
941940
/// This is equivalent to the Python expression `super()`
942-
#[cfg(not(any(PyPy, GraalPy)))]
943941
fn py_super(&self) -> PyResult<Bound<'py, PySuper>>;
944942
}
945943

@@ -1611,7 +1609,6 @@ impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> {
16111609
)
16121610
}
16131611

1614-
#[cfg(not(any(PyPy, GraalPy)))]
16151612
fn py_super(&self) -> PyResult<Bound<'py, PySuper>> {
16161613
PySuper::new(&self.get_type(), self)
16171614
}

src/types/mod.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ pub use self::mutex::{PyMutex, PyMutexGuard};
3939
pub use self::none::PyNone;
4040
pub use self::notimplemented::PyNotImplemented;
4141
pub use self::num::PyInt;
42-
#[cfg(not(any(PyPy, GraalPy)))]
4342
pub use self::pysuper::PySuper;
4443
pub use self::range::{PyRange, PyRangeMethods};
4544
pub use self::sequence::{PySequence, PySequenceMethods};
@@ -278,7 +277,6 @@ mod mutex;
278277
mod none;
279278
mod notimplemented;
280279
mod num;
281-
#[cfg(not(any(PyPy, GraalPy)))]
282280
mod pysuper;
283281
pub(crate) mod range;
284282
pub(crate) mod sequence;

0 commit comments

Comments
 (0)