Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions mypyc/codegen/emitclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,19 @@ def emit_line() -> None:
if emitter.capi_version < (3, 12):
fields["tp_dictoffset"] = base_size
fields["tp_weaklistoffset"] = weak_offset
elif cl.supports_weakref:
# __weakref__ lives right after the struct
# TODO: It should get a member in the struct instead of doing this nonsense.
emitter.emit_lines(
f"PyMemberDef {members_name}[] = {{",
f'{{"__weakref__", T_OBJECT_EX, {base_size}, 0, NULL}},',
"{0}",
"};",
)
if emitter.capi_version < (3, 12):
# versions >= 3.12 set Py_TPFLAGS_MANAGED_WEAKREF flag instead
# https://docs.python.org/3.12/extending/newtypes.html#weak-reference-support
fields["tp_weaklistoffset"] = base_size
else:
fields["tp_basicsize"] = base_size

Expand Down Expand Up @@ -376,6 +389,9 @@ def emit_line() -> None:
fields["tp_call"] = "PyVectorcall_Call"
if has_managed_dict(cl, emitter):
flags.append("Py_TPFLAGS_MANAGED_DICT")
if cl.supports_weakref and emitter.capi_version >= (3, 12):
flags.append("Py_TPFLAGS_MANAGED_WEAKREF")

fields["tp_flags"] = " | ".join(flags)

fields["tp_doc"] = f"PyDoc_STR({native_class_doc_initializer(cl)})"
Expand Down Expand Up @@ -892,6 +908,13 @@ def generate_dealloc_for_class(
emitter.emit_line("static void")
emitter.emit_line(f"{dealloc_func_name}({cl.struct_name(emitter.names)} *self)")
emitter.emit_line("{")
if cl.supports_weakref:
if emitter.capi_version < (3, 12):
emitter.emit_line("if (self->weakreflist != NULL) {")
emitter.emit_line("PyObject_ClearWeakRefs((PyObject *) self);")
emitter.emit_line("}")
else:
emitter.emit_line("PyObject_ClearWeakRefs((PyObject *) self);")
if has_tp_finalize:
emitter.emit_line("PyObject *type, *value, *traceback;")
emitter.emit_line("PyErr_Fetch(&type, &value, &traceback);")
Expand Down
4 changes: 4 additions & 0 deletions mypyc/ir/class_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def __init__(
self.inherits_python = False
# Do instances of this class have __dict__?
self.has_dict = False
# Do instances of this class have __weakref__?
self.supports_weakref = False
# Do we allow interpreted subclasses? Derived from a mypyc_attr.
self.allow_interpreted_subclasses = False
# Does this class need getseters to be generated for its attributes? (getseters are also
Expand Down Expand Up @@ -381,6 +383,7 @@ def serialize(self) -> JsonDict:
"is_final_class": self.is_final_class,
"inherits_python": self.inherits_python,
"has_dict": self.has_dict,
"supports_weakref": self.supports_weakref,
"allow_interpreted_subclasses": self.allow_interpreted_subclasses,
"needs_getseters": self.needs_getseters,
"_serializable": self._serializable,
Expand Down Expand Up @@ -440,6 +443,7 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ClassIR:
ir.is_final_class = data["is_final_class"]
ir.inherits_python = data["inherits_python"]
ir.has_dict = data["has_dict"]
ir.supports_weakref = data["supports_weakref"]
ir.allow_interpreted_subclasses = data["allow_interpreted_subclasses"]
ir.needs_getseters = data["needs_getseters"]
ir._serializable = data["_serializable"]
Expand Down
3 changes: 3 additions & 0 deletions mypyc/irbuild/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,9 @@ def prepare_class_def(
if attrs.get("serializable") is True:
# Supports copy.copy and pickle (including subclasses)
ir._serializable = True
if attrs.get("supports_weakref") is True:
# Has a tp_weakrefoffset slot allowing the creation of weak references (including subclasses)
ir.supports_weakref = True

free_list_len = attrs.get("free_list_len")
if free_list_len is not None:
Expand Down
2 changes: 2 additions & 0 deletions mypyc/irbuild/vtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def compute_vtable(cls: ClassIR) -> None:

if not cls.is_generated:
cls.has_dict = any(x.inherits_python for x in cls.mro)
# TODO: define more weakref triggers
cls.supports_weakref = cls.supports_weakref or cls.has_dict

for t in cls.mro[1:]:
# Make sure all ancestors are processed first
Expand Down
42 changes: 42 additions & 0 deletions mypyc/test-data/irbuild-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -1449,6 +1449,48 @@ class TestOverload:
def __mypyc_generator_helper__(self, x: Any) -> Any:
return x

[case testMypycAttrSupportsWeakref]
import weakref
from mypy_extensions import mypyc_attr

@mypyc_attr(supports_weakref=True)
class WeakrefClass:
pass

obj = WeakrefClass()
ref = weakref.ref(obj)
assert ref() is obj

[case testMypycAttrSupportsWeakrefInheritance]
import weakref
from mypy_extensions import mypyc_attr

@mypyc_attr(supports_weakref=True)
class WeakrefClass:
pass

class WeakrefInheritor(WeakrefClass):
pass

obj = WeakrefInheritor()
ref = weakref.ref(obj)
assert ref() is obj

[case testMypycAttrSupportsWeakrefSubclass]
import weakref
from mypy_extensions import mypyc_attr

class NativeClass:
pass

@mypyc_attr(supports_weakref=True)
class WeakrefSubclass(NativeClass):
pass

obj = WeakrefSubclass()
ref = weakref.ref(obj)
assert ref() is obj

[case testNativeBufferFastPath]
from typing import Final
from mypy_extensions import u8
Expand Down
Loading