diff --git a/mypyc/codegen/emitclass.py b/mypyc/codegen/emitclass.py index 122f62a0d582..b14716a0c8f2 100644 --- a/mypyc/codegen/emitclass.py +++ b/mypyc/codegen/emitclass.py @@ -318,6 +318,19 @@ def emit_line() -> None: if emitter.capi_version < (3, 12): fields["tp_dictoffset"] = base_size fields["tp_weaklistoffset"] = weak_offset + elif cl.supports_weakref: + # __weakref__ lives right after the struct + # TODO: It should get a member in the struct instead of doing this nonsense. + emitter.emit_lines( + f"PyMemberDef {members_name}[] = {{", + f'{{"__weakref__", T_OBJECT_EX, {base_size}, 0, NULL}},', + "{0}", + "};", + ) + if emitter.capi_version < (3, 12): + # versions >= 3.12 set Py_TPFLAGS_MANAGED_WEAKREF flag instead + # https://docs.python.org/3.12/extending/newtypes.html#weak-reference-support + fields["tp_weaklistoffset"] = base_size else: fields["tp_basicsize"] = base_size @@ -376,6 +389,9 @@ def emit_line() -> None: fields["tp_call"] = "PyVectorcall_Call" if has_managed_dict(cl, emitter): flags.append("Py_TPFLAGS_MANAGED_DICT") + if cl.supports_weakref and emitter.capi_version >= (3, 12): + flags.append("Py_TPFLAGS_MANAGED_WEAKREF") + fields["tp_flags"] = " | ".join(flags) fields["tp_doc"] = f"PyDoc_STR({native_class_doc_initializer(cl)})" @@ -892,6 +908,13 @@ def generate_dealloc_for_class( emitter.emit_line("static void") emitter.emit_line(f"{dealloc_func_name}({cl.struct_name(emitter.names)} *self)") emitter.emit_line("{") + if cl.supports_weakref: + if emitter.capi_version < (3, 12): + emitter.emit_line("if (self->weakreflist != NULL) {") + emitter.emit_line("PyObject_ClearWeakRefs((PyObject *) self);") + emitter.emit_line("}") + else: + emitter.emit_line("PyObject_ClearWeakRefs((PyObject *) self);") if has_tp_finalize: emitter.emit_line("PyObject *type, *value, *traceback;") emitter.emit_line("PyErr_Fetch(&type, &value, &traceback);") diff --git a/mypyc/ir/class_ir.py b/mypyc/ir/class_ir.py index 0a56aaf5d101..de667b70ce11 100644 --- a/mypyc/ir/class_ir.py +++ b/mypyc/ir/class_ir.py @@ -109,6 +109,8 @@ def __init__( self.inherits_python = False # Do instances of this class have __dict__? self.has_dict = False + # Do instances of this class have __weakref__? + self.supports_weakref = False # Do we allow interpreted subclasses? Derived from a mypyc_attr. self.allow_interpreted_subclasses = False # Does this class need getseters to be generated for its attributes? (getseters are also @@ -381,6 +383,7 @@ def serialize(self) -> JsonDict: "is_final_class": self.is_final_class, "inherits_python": self.inherits_python, "has_dict": self.has_dict, + "supports_weakref": self.supports_weakref, "allow_interpreted_subclasses": self.allow_interpreted_subclasses, "needs_getseters": self.needs_getseters, "_serializable": self._serializable, @@ -440,6 +443,7 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ClassIR: ir.is_final_class = data["is_final_class"] ir.inherits_python = data["inherits_python"] ir.has_dict = data["has_dict"] + ir.supports_weakref = data["supports_weakref"] ir.allow_interpreted_subclasses = data["allow_interpreted_subclasses"] ir.needs_getseters = data["needs_getseters"] ir._serializable = data["_serializable"] diff --git a/mypyc/irbuild/prepare.py b/mypyc/irbuild/prepare.py index 20f2aeef8e6e..9a1557031659 100644 --- a/mypyc/irbuild/prepare.py +++ b/mypyc/irbuild/prepare.py @@ -363,6 +363,9 @@ def prepare_class_def( if attrs.get("serializable") is True: # Supports copy.copy and pickle (including subclasses) ir._serializable = True + if attrs.get("supports_weakref") is True: + # Has a tp_weakrefoffset slot allowing the creation of weak references (including subclasses) + ir.supports_weakref = True free_list_len = attrs.get("free_list_len") if free_list_len is not None: diff --git a/mypyc/irbuild/vtable.py b/mypyc/irbuild/vtable.py index 2d4f7261e4ca..766b4086c594 100644 --- a/mypyc/irbuild/vtable.py +++ b/mypyc/irbuild/vtable.py @@ -15,6 +15,8 @@ def compute_vtable(cls: ClassIR) -> None: if not cls.is_generated: cls.has_dict = any(x.inherits_python for x in cls.mro) + # TODO: define more weakref triggers + cls.supports_weakref = cls.supports_weakref or cls.has_dict for t in cls.mro[1:]: # Make sure all ancestors are processed first diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test index a98b3a7d3dcf..9e3829fa1fe0 100644 --- a/mypyc/test-data/irbuild-classes.test +++ b/mypyc/test-data/irbuild-classes.test @@ -1449,6 +1449,48 @@ class TestOverload: def __mypyc_generator_helper__(self, x: Any) -> Any: return x +[case testMypycAttrSupportsWeakref] +import weakref +from mypy_extensions import mypyc_attr + +@mypyc_attr(supports_weakref=True) +class WeakrefClass: + pass + +obj = WeakrefClass() +ref = weakref.ref(obj) +assert ref() is obj + +[case testMypycAttrSupportsWeakrefInheritance] +import weakref +from mypy_extensions import mypyc_attr + +@mypyc_attr(supports_weakref=True) +class WeakrefClass: + pass + +class WeakrefInheritor(WeakrefClass): + pass + +obj = WeakrefInheritor() +ref = weakref.ref(obj) +assert ref() is obj + +[case testMypycAttrSupportsWeakrefSubclass] +import weakref +from mypy_extensions import mypyc_attr + +class NativeClass: + pass + +@mypyc_attr(supports_weakref=True) +class WeakrefSubclass(NativeClass): + pass + +obj = WeakrefSubclass() +ref = weakref.ref(obj) +assert ref() is obj + [case testNativeBufferFastPath] from typing import Final from mypy_extensions import u8