diff --git a/mypyc/ir/rtypes.py b/mypyc/ir/rtypes.py index c0871bba258c..c683268b4387 100644 --- a/mypyc/ir/rtypes.py +++ b/mypyc/ir/rtypes.py @@ -512,6 +512,11 @@ def __hash__(self) -> int: # Python range object. range_rprimitive: Final = RPrimitive("builtins.range", is_unboxed=False, is_refcounted=True) +# Python weak reference object +weakref_rprimitive: Final = RPrimitive( + "weakref.ReferenceType", is_unboxed=False, is_refcounted=True +) + def is_tagged(rtype: RType) -> bool: return rtype is int_rprimitive or rtype is short_int_rprimitive @@ -632,6 +637,10 @@ def is_sequence_rprimitive(rtype: RType) -> bool: ) +def is_weakref_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]: + return isinstance(rtype, RPrimitive) and rtype.name == "weakref.ReferenceType" + + class TupleNameVisitor(RTypeVisitor[str]): """Produce a tuple name based on the concrete representations of types.""" diff --git a/mypyc/irbuild/mapper.py b/mypyc/irbuild/mapper.py index 815688d90fb6..145a297845b8 100644 --- a/mypyc/irbuild/mapper.py +++ b/mypyc/irbuild/mapper.py @@ -46,6 +46,7 @@ str_rprimitive, tuple_rprimitive, uint8_rprimitive, + weakref_rprimitive, ) @@ -102,6 +103,8 @@ def type_to_rtype(self, typ: Type | None) -> RType: return tuple_rprimitive # Varying-length tuple elif typ.type.fullname == "builtins.range": return range_rprimitive + elif typ.type.fullname == "weakref.ReferenceType": + return weakref_rprimitive elif typ.type in self.type_to_ir: inst = RInstance(self.type_to_ir[typ.type]) # Treat protocols as Union[protocol, object], so that we can do fast diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index 3015640fb3fd..1d863706a946 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -65,6 +65,7 @@ is_int_rprimitive, is_list_rprimitive, is_uint8_rprimitive, + is_weakref_rprimitive, list_rprimitive, set_rprimitive, str_rprimitive, @@ -103,6 +104,7 @@ str_encode_utf8_strict, ) from mypyc.primitives.tuple_ops import isinstance_tuple, new_tuple_set_item_op +from mypyc.primitives.weakref_ops import weakref_deref_op # Specializers are attempted before compiling the arguments to the # function. Specializers can return None to indicate that they failed @@ -140,6 +142,8 @@ def apply_function_specialization( builder: IRBuilder, expr: CallExpr, callee: RefExpr ) -> Value | None: """Invoke the Specializer callback for a function if one has been registered""" + if is_weakref_rprimitive(builder.node_type(callee)) and len(expr.args) == 0: + return builder.call_c(weakref_deref_op, [builder.accept(expr.callee)], expr.line) return _apply_specialization(builder, expr, callee, callee.fullname) diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 1881aa97f308..56eb2fcc79f1 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -928,6 +928,7 @@ PyObject *CPySingledispatch_RegisterFunction(PyObject *singledispatch_func, PyOb PyObject *CPy_GetAIter(PyObject *obj); PyObject *CPy_GetANext(PyObject *aiter); +PyObject *CPyWeakref_GetRef(PyObject *ref); void CPy_SetTypeAliasTypeComputeFunction(PyObject *alias, PyObject *compute_value); void CPyTrace_LogEvent(const char *location, const char *line, const char *op, const char *details); diff --git a/mypyc/lib-rt/misc_ops.c b/mypyc/lib-rt/misc_ops.c index 0c9d7812ac6c..e658aba92ee4 100644 --- a/mypyc/lib-rt/misc_ops.c +++ b/mypyc/lib-rt/misc_ops.c @@ -1113,3 +1113,18 @@ void CPy_SetImmortal(PyObject *obj) { } #endif + + +PyObject *CPyWeakref_GetRef(PyObject *ref) +{ + PyObject *obj = NULL; + int success = PyWeakref_GetRef(ref, &obj); + if (success == -1) { + return NULL; + } else if (obj == NULL) { + Py_INCREF(Py_None); + return Py_None; + } else { + return obj; + } +} diff --git a/mypyc/primitives/weakref_ops.py b/mypyc/primitives/weakref_ops.py index a7ac035b22a4..593dbca43093 100644 --- a/mypyc/primitives/weakref_ops.py +++ b/mypyc/primitives/weakref_ops.py @@ -1,13 +1,13 @@ from mypyc.ir.ops import ERR_MAGIC -from mypyc.ir.rtypes import object_rprimitive, pointer_rprimitive -from mypyc.primitives.registry import function_op +from mypyc.ir.rtypes import object_rprimitive, pointer_rprimitive, weakref_rprimitive +from mypyc.primitives.registry import custom_op, function_op # Weakref operations new_ref_op = function_op( name="weakref.ReferenceType", arg_types=[object_rprimitive], - return_type=object_rprimitive, + return_type=weakref_rprimitive, c_function_name="PyWeakref_NewRef", extra_int_constants=[(0, pointer_rprimitive)], error_kind=ERR_MAGIC, @@ -16,7 +16,15 @@ new_ref__with_callback_op = function_op( name="weakref.ReferenceType", arg_types=[object_rprimitive, object_rprimitive], - return_type=object_rprimitive, + return_type=weakref_rprimitive, c_function_name="PyWeakref_NewRef", error_kind=ERR_MAGIC, ) + +# TODO: generate specialized versions of this that return the properr rtype +weakref_deref_op = custom_op( + arg_types=[weakref_rprimitive], + return_type=object_rprimitive, + c_function_name="CPyWeakref_GetRef", + error_kind=ERR_MAGIC, +) diff --git a/mypyc/test-data/irbuild-weakref.test b/mypyc/test-data/irbuild-weakref.test index 58ac6417d297..fe380d1489a2 100644 --- a/mypyc/test-data/irbuild-weakref.test +++ b/mypyc/test-data/irbuild-weakref.test @@ -2,50 +2,70 @@ import weakref from typing import Any, Callable def f(x: object) -> object: - return weakref.ref(x) + ref = weakref.ref(x) + return ref() [out] def f(x): - x, r0 :: object + x :: object + r0, ref :: weakref.ReferenceType + r1 :: object L0: r0 = PyWeakref_NewRef(x, 0) - return r0 + ref = r0 + r1 = CPyWeakref_GetRef(ref) + return r1 [case testWeakrefRefCallback] import weakref from typing import Any, Callable def f(x: object, cb: Callable[[object], Any]) -> object: - return weakref.ref(x, cb) + ref = weakref.ref(x, cb) + return ref() [out] def f(x, cb): - x, cb, r0 :: object + x, cb :: object + r0, ref :: weakref.ReferenceType + r1 :: object L0: r0 = PyWeakref_NewRef(x, cb) - return r0 + ref = r0 + r1 = CPyWeakref_GetRef(ref) + return r1 [case testFromWeakrefRef] from typing import Any, Callable from weakref import ref def f(x: object) -> object: - return ref(x) + r = ref(x) + return r() [out] def f(x): - x, r0 :: object + x :: object + r0, r :: weakref.ReferenceType + r1 :: object L0: r0 = PyWeakref_NewRef(x, 0) - return r0 + r = r0 + r1 = CPyWeakref_GetRef(r) + return r1 [case testFromWeakrefRefCallback] from typing import Any, Callable from weakref import ref def f(x: object, cb: Callable[[object], Any]) -> object: - return ref(x, cb) + r = ref(x, cb) + return r() [out] def f(x, cb): - x, cb, r0 :: object + x, cb :: object + r0, r :: weakref.ReferenceType + r1 :: object L0: r0 = PyWeakref_NewRef(x, cb) - return r0 + r = r0 + r1 = CPyWeakref_GetRef(r) + return r1 diff --git a/test-data/unit/lib-stub/weakref.pyi b/test-data/unit/lib-stub/weakref.pyi index 34e01f4d48f1..81ee8de8cfdc 100644 --- a/test-data/unit/lib-stub/weakref.pyi +++ b/test-data/unit/lib-stub/weakref.pyi @@ -1,5 +1,5 @@ from collections.abc import Callable -from typing import Any, Generic, TypeVar +from typing import Any, Generic, Optional, TypeVar from typing_extensions import Self _T = TypeVar("_T") @@ -7,5 +7,6 @@ _T = TypeVar("_T") class ReferenceType(Generic[_T]): # "weakref" __callback__: Callable[[Self], Any] def __new__(cls, o: _T, callback: Callable[[Self], Any] | None = ..., /) -> Self: ... + def __call__(self) -> Optional[_T]: ... ref = ReferenceType