Skip to content

Commit 6eb7b5a

Browse files
committed
fix: use custom_op
1 parent 56a6ef5 commit 6eb7b5a

File tree

6 files changed

+74
-22
lines changed

6 files changed

+74
-22
lines changed

mypyc/codegen/emit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
is_str_rprimitive,
4949
is_tuple_rprimitive,
5050
is_uint8_rprimitive,
51+
is_weakref_rprimitive,
5152
object_rprimitive,
5253
optional_value_type,
5354
)
@@ -704,7 +705,7 @@ def emit_cast(
704705
self.emit_lines(f" {dest} = {src};", "else {")
705706
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
706707
self.emit_line("}")
707-
elif is_object_rprimitive(typ):
708+
elif is_object_rprimitive(typ) or is_weakref_rprimitive(typ):
708709
if declare_dest:
709710
self.emit_line(f"PyObject *{dest};")
710711
self.emit_arg_check(src, dest, typ, "", optional)

mypyc/ir/rtypes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,10 @@ def is_sequence_rprimitive(rtype: RType) -> bool:
637637
)
638638

639639

640+
def is_weakref_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
641+
return isinstance(rtype, RPrimitive) and rtype.name == "weakref.ReferenceType"
642+
643+
640644
class TupleNameVisitor(RTypeVisitor[str]):
641645
"""Produce a tuple name based on the concrete representations of types."""
642646

mypyc/irbuild/specialize.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
RefExpr,
3232
StrExpr,
3333
TupleExpr,
34+
Var,
3435
)
3536
from mypy.types import AnyType, TypeOfAny
3637
from mypyc.ir.ops import (
@@ -103,6 +104,7 @@
103104
str_encode_utf8_strict,
104105
)
105106
from mypyc.primitives.tuple_ops import isinstance_tuple, new_tuple_set_item_op
107+
from mypyc.primitives.weakref_ops import weakref_deref_op
106108

107109
# Specializers are attempted before compiling the arguments to the
108110
# function. Specializers can return None to indicate that they failed
@@ -140,6 +142,15 @@ def apply_function_specialization(
140142
builder: IRBuilder, expr: CallExpr, callee: RefExpr
141143
) -> Value | None:
142144
"""Invoke the Specializer callback for a function if one has been registered"""
145+
if (
146+
isinstance(callee, NameExpr)
147+
and isinstance(callee.node, Var)
148+
# NOTE: why is this not a weakref rprimitive?
149+
# TODO: fix to weakref rprimitive so _apply_specialization can use the custom_op
150+
and str(callee.node.type).startswith("weakref.ReferenceType")
151+
and len(expr.args) == 0
152+
):
153+
return builder.call_c(weakref_deref_op, [builder.accept(expr.callee)], expr.line)
143154
return _apply_specialization(builder, expr, callee, callee.fullname)
144155

145156

mypyc/lib-rt/misc_ops.c

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,3 +1098,18 @@ void CPy_SetImmortal(PyObject *obj) {
10981098
}
10991099

11001100
#endif
1101+
1102+
1103+
PyObject *CPyWeakref_GetRef(PyObject *ref)
1104+
{
1105+
PyObject *obj = NULL;
1106+
int success = PyWeakref_GetRef(ref, &obj);
1107+
if (success == -1) {
1108+
return NULL;
1109+
} else if (obj == NULL) {
1110+
Py_INCREF(Py_None);
1111+
return Py_None;
1112+
} else {
1113+
return obj;
1114+
}
1115+
}

mypyc/primitives/weakref_ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from mypyc.ir.ops import ERR_MAGIC
22
from mypyc.ir.rtypes import object_rprimitive, pointer_rprimitive, weakref_rprimitive
3-
from mypyc.primitives.registry import ERR_NEG_INT, function_op, method_op
3+
from mypyc.primitives.registry import custom_op, function_op
44

55
# Weakref operations
66

@@ -21,10 +21,10 @@
2121
error_kind=ERR_MAGIC,
2222
)
2323

24-
deref_op = method_op(
25-
name="__call__",
24+
# TODO: generate specialized versions of this that return the properr rtype
25+
weakref_deref_op = custom_op(
2626
arg_types=[weakref_rprimitive],
2727
return_type=object_rprimitive,
28-
c_function_name="PyWeakref_GetRef",
29-
error_kind=ERR_NEG_INT,
28+
c_function_name="CPyWeakref_GetRef",
29+
error_kind=ERR_MAGIC,
3030
)

mypyc/test-data/irbuild-weakref.test

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,62 +2,83 @@
22
import weakref
33
from typing import Any, Callable
44
def f(x: object) -> object:
5-
return weakref.ref(x)()
5+
ref = weakref.ref(x)
6+
return ref()
67

78
[out]
89
def f(x):
910
x :: object
1011
r0 :: weakref.ReferenceType
11-
r1 :: object
12+
ref :: object
13+
r1 :: weakref.ReferenceType
14+
r2 :: object
1215
L0:
1316
r0 = PyWeakref_NewRef(x, 0)
14-
r1 = PyWeakref_GetRef(r0)
15-
return r1
17+
ref = r0
18+
r1 = cast(weakref.ReferenceType, ref)
19+
r2 = CPyWeakref_GetRef(r1)
20+
return r2
1621

1722
[case testWeakrefRefCallback]
1823
import weakref
1924
from typing import Any, Callable
2025
def f(x: object, cb: Callable[[object], Any]) -> object:
21-
return weakref.ref(x, cb)()
26+
ref = weakref.ref(x, cb)
27+
return ref()
2228

2329
[out]
2430
def f(x, cb):
2531
x, cb :: object
2632
r0 :: weakref.ReferenceType
27-
r1 :: object
33+
ref :: object
34+
r1 :: weakref.ReferenceType
35+
r2 :: object
2836
L0:
2937
r0 = PyWeakref_NewRef(x, cb)
30-
r1 = PyWeakref_GetRef(r0)
31-
return r1
38+
ref = r0
39+
r1 = cast(weakref.ReferenceType, ref)
40+
r2 = CPyWeakref_GetRef(r1)
41+
return r2
3242

3343
[case testFromWeakrefRef]
3444
from typing import Any, Callable
3545
from weakref import ref
3646
def f(x: object) -> object:
37-
return ref(x)()
47+
r = ref(x)
48+
return r()
3849

3950
[out]
4051
def f(x):
4152
x :: object
4253
r0 :: weakref.ReferenceType
43-
r1 :: object
54+
r :: object
55+
r1 :: weakref.ReferenceType
56+
r2 :: object
4457
L0:
4558
r0 = PyWeakref_NewRef(x, 0)
46-
r1 = PyWeakref_GetRef(r0)
47-
return r1
59+
r = r0
60+
r1 = cast(weakref.ReferenceType, r)
61+
r2 = CPyWeakref_GetRef(r1)
62+
return r2
4863

4964
[case testFromWeakrefRefCallback]
5065
from typing import Any, Callable
5166
from weakref import ref
5267
def f(x: object, cb: Callable[[object], Any]) -> object:
53-
return ref(x, cb)()
68+
r = ref(x, cb)
69+
return r()
5470

5571
[out]
5672
def f(x, cb):
5773
x, cb :: object
5874
r0 :: weakref.ReferenceType
59-
r1 :: object
75+
r :: object
76+
r1 :: weakref.ReferenceType
77+
r2 :: object
6078
L0:
6179
r0 = PyWeakref_NewRef(x, cb)
62-
r1 = PyWeakref_GetRef(r0)
63-
return r1
80+
r = r0
81+
r1 = cast(weakref.ReferenceType, r)
82+
r2 = CPyWeakref_GetRef(r1)
83+
return r2
84+

0 commit comments

Comments
 (0)