Skip to content

Commit 54ff364

Browse files
authored
[mypyc] Refactor: use new-style primitives for unary and method ops (#18230)
Also fix an issue with redundant coercions for some primitive ops. Add a few tests.
1 parent 725145e commit 54ff364

File tree

9 files changed

+39
-45
lines changed

9 files changed

+39
-45
lines changed

mypyc/irbuild/builder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ def add_to_non_ext_dict(
404404
) -> None:
405405
# Add an attribute entry into the class dict of a non-extension class.
406406
key_unicode = self.load_str(key)
407-
self.call_c(dict_set_item_op, [non_ext.dict, key_unicode, val], line)
407+
self.primitive_op(dict_set_item_op, [non_ext.dict, key_unicode, val], line)
408408

409409
def gen_import(self, id: str, line: int) -> None:
410410
self.imports[id] = None
@@ -435,7 +435,7 @@ def get_module(self, module: str, line: int) -> Value:
435435
# Python 3.7 has a nice 'PyImport_GetModule' function that we can't use :(
436436
mod_dict = self.call_c(get_module_dict_op, [], line)
437437
# Get module object from modules dict.
438-
return self.call_c(dict_get_item_op, [mod_dict, self.load_str(module)], line)
438+
return self.primitive_op(dict_get_item_op, [mod_dict, self.load_str(module)], line)
439439

440440
def get_module_attr(self, module: str, attr: str, line: int) -> Value:
441441
"""Look up an attribute of a module without storing it in the local namespace.
@@ -817,7 +817,7 @@ def process_iterator_tuple_assignment(
817817
self.activate_block(ok_block)
818818

819819
for litem in reversed(post_star_vals):
820-
ritem = self.call_c(list_pop_last, [iter_list], line)
820+
ritem = self.primitive_op(list_pop_last, [iter_list], line)
821821
self.assign(litem, ritem, line)
822822

823823
# Assign the starred value
@@ -1302,7 +1302,7 @@ def load_global(self, expr: NameExpr) -> Value:
13021302
def load_global_str(self, name: str, line: int) -> Value:
13031303
_globals = self.load_globals_dict()
13041304
reg = self.load_str(name)
1305-
return self.call_c(dict_get_item_op, [_globals, reg], line)
1305+
return self.primitive_op(dict_get_item_op, [_globals, reg], line)
13061306

13071307
def load_globals_dict(self) -> Value:
13081308
return self.add(LoadStatic(dict_rprimitive, "globals", self.module_name))

mypyc/irbuild/classdef.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def finalize(self, ir: ClassIR) -> None:
256256
)
257257

258258
# Add the non-extension class to the dict
259-
self.builder.call_c(
259+
self.builder.primitive_op(
260260
dict_set_item_op,
261261
[
262262
self.builder.load_globals_dict(),
@@ -466,7 +466,7 @@ def allocate_class(builder: IRBuilder, cdef: ClassDef) -> Value:
466466
builder.add(InitStatic(tp, cdef.name, builder.module_name, NAMESPACE_TYPE))
467467

468468
# Add it to the dict
469-
builder.call_c(
469+
builder.primitive_op(
470470
dict_set_item_op, [builder.load_globals_dict(), builder.load_str(cdef.name), tp], cdef.line
471471
)
472472

@@ -493,7 +493,7 @@ def make_generic_base_class(
493493
else:
494494
arg = builder.new_tuple(args, line)
495495

496-
base = builder.call_c(py_get_item_op, [gent, arg], line)
496+
base = builder.primitive_op(py_get_item_op, [gent, arg], line)
497497
return base
498498

499499

@@ -661,7 +661,7 @@ def add_non_ext_class_attr_ann(
661661
typ = builder.add(LoadAddress(type_object_op.type, type_object_op.src, stmt.line))
662662

663663
key = builder.load_str(lvalue.name)
664-
builder.call_c(dict_set_item_op, [non_ext.anns, key, typ], stmt.line)
664+
builder.primitive_op(dict_set_item_op, [non_ext.anns, key, typ], stmt.line)
665665

666666

667667
def add_non_ext_class_attr(

mypyc/irbuild/expression.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
Integer,
6161
LoadAddress,
6262
LoadLiteral,
63+
PrimitiveDescription,
6364
RaiseStandardError,
6465
Register,
6566
TupleGet,
@@ -99,7 +100,7 @@
99100
from mypyc.primitives.generic_ops import iter_op
100101
from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op
101102
from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op
102-
from mypyc.primitives.registry import CFunctionDescription, builtin_names
103+
from mypyc.primitives.registry import builtin_names
103104
from mypyc.primitives.set_ops import set_add_op, set_in_op, set_update_op
104105
from mypyc.primitives.str_ops import str_slice_op
105106
from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op
@@ -182,7 +183,7 @@ def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value:
182183
# AST doesn't include a Var node for the module. We
183184
# instead load the module separately on each access.
184185
mod_dict = builder.call_c(get_module_dict_op, [], expr.line)
185-
obj = builder.call_c(
186+
obj = builder.primitive_op(
186187
dict_get_item_op, [mod_dict, builder.load_str(expr.node.fullname)], expr.line
187188
)
188189
return obj
@@ -979,8 +980,8 @@ def _visit_display(
979980
builder: IRBuilder,
980981
items: list[Expression],
981982
constructor_op: Callable[[list[Value], int], Value],
982-
append_op: CFunctionDescription,
983-
extend_op: CFunctionDescription,
983+
append_op: PrimitiveDescription,
984+
extend_op: PrimitiveDescription,
984985
line: int,
985986
is_list: bool,
986987
) -> Value:
@@ -1001,7 +1002,7 @@ def _visit_display(
10011002
if result is None:
10021003
result = constructor_op(initial_items, line)
10031004

1004-
builder.call_c(extend_op if starred else append_op, [result, value], line)
1005+
builder.primitive_op(extend_op if starred else append_op, [result, value], line)
10051006

10061007
if result is None:
10071008
result = constructor_op(initial_items, line)
@@ -1030,7 +1031,7 @@ def transform_dictionary_comprehension(builder: IRBuilder, o: DictionaryComprehe
10301031
def gen_inner_stmts() -> None:
10311032
k = builder.accept(o.key)
10321033
v = builder.accept(o.value)
1033-
builder.call_c(dict_set_item_op, [builder.read(d), k, v], o.line)
1034+
builder.primitive_op(dict_set_item_op, [builder.read(d), k, v], o.line)
10341035

10351036
comprehension_helper(builder, loop_params, gen_inner_stmts, o.line)
10361037
return builder.read(d)

mypyc/irbuild/for_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def translate_list_comprehension(builder: IRBuilder, gen: GeneratorExpr) -> Valu
251251

252252
def gen_inner_stmts() -> None:
253253
e = builder.accept(gen.left_expr)
254-
builder.call_c(list_append_op, [builder.read(list_ops), e], gen.line)
254+
builder.primitive_op(list_append_op, [builder.read(list_ops), e], gen.line)
255255

256256
comprehension_helper(builder, loop_params, gen_inner_stmts, gen.line)
257257
return builder.read(list_ops)
@@ -286,7 +286,7 @@ def translate_set_comprehension(builder: IRBuilder, gen: GeneratorExpr) -> Value
286286

287287
def gen_inner_stmts() -> None:
288288
e = builder.accept(gen.left_expr)
289-
builder.call_c(set_add_op, [builder.read(set_ops), e], gen.line)
289+
builder.primitive_op(set_add_op, [builder.read(set_ops), e], gen.line)
290290

291291
comprehension_helper(builder, loop_params, gen_inner_stmts, gen.line)
292292
return builder.read(set_ops)

mypyc/irbuild/function.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def transform_decorator(builder: IRBuilder, dec: Decorator) -> None:
133133

134134
if decorated_func is not None:
135135
# Set the callable object representing the decorated function as a global.
136-
builder.call_c(
136+
builder.primitive_op(
137137
dict_set_item_op,
138138
[builder.load_globals_dict(), builder.load_str(dec.func.name), decorated_func],
139139
decorated_func.line,
@@ -849,7 +849,7 @@ def generate_singledispatch_dispatch_function(
849849
dispatch_func_obj, "dispatch_cache", dict_rprimitive, line
850850
)
851851
call_find_impl, use_cache, call_func = BasicBlock(), BasicBlock(), BasicBlock()
852-
get_result = builder.call_c(dict_get_method_with_none, [dispatch_cache, arg_type], line)
852+
get_result = builder.primitive_op(dict_get_method_with_none, [dispatch_cache, arg_type], line)
853853
is_not_none = builder.translate_is_op(get_result, builder.none_object(), "is not", line)
854854
impl_to_use = Register(object_rprimitive)
855855
builder.add_bool_branch(is_not_none, use_cache, call_find_impl)
@@ -862,7 +862,7 @@ def generate_singledispatch_dispatch_function(
862862
find_impl = builder.load_module_attr_by_fullname("functools._find_impl", line)
863863
registry = load_singledispatch_registry(builder, dispatch_func_obj, line)
864864
uncached_impl = builder.py_call(find_impl, [arg_type, registry], line)
865-
builder.call_c(dict_set_item_op, [dispatch_cache, arg_type, uncached_impl], line)
865+
builder.primitive_op(dict_set_item_op, [dispatch_cache, arg_type, uncached_impl], line)
866866
builder.assign(impl_to_use, uncached_impl, line)
867867
builder.goto(call_func)
868868

@@ -1039,7 +1039,7 @@ def maybe_insert_into_registry_dict(builder: IRBuilder, fitem: FuncDef) -> None:
10391039
registry = load_singledispatch_registry(builder, dispatch_func_obj, line)
10401040
for typ in types:
10411041
loaded_type = load_type(builder, typ, line)
1042-
builder.call_c(dict_set_item_op, [registry, loaded_type, to_insert], line)
1042+
builder.primitive_op(dict_set_item_op, [registry, loaded_type, to_insert], line)
10431043
dispatch_cache = builder.builder.get_attr(
10441044
dispatch_func_obj, "dispatch_cache", dict_rprimitive, line
10451045
)

mypyc/irbuild/ll_builder.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -762,7 +762,7 @@ def _construct_varargs(
762762
if kind == ARG_STAR:
763763
if star_result is None:
764764
star_result = self.new_list_op(star_values, line)
765-
self.call_c(list_extend_op, [star_result, value], line)
765+
self.primitive_op(list_extend_op, [star_result, value], line)
766766
elif kind == ARG_STAR2:
767767
if star2_result is None:
768768
star2_result = self._create_dict(star2_keys, star2_values, line)
@@ -1939,7 +1939,7 @@ def primitive_op(
19391939
desc.priority,
19401940
is_pure=desc.is_pure,
19411941
)
1942-
return self.call_c(c_desc, args, line, result_type)
1942+
return self.call_c(c_desc, args, line, result_type=result_type)
19431943

19441944
# This primitive gets transformed in a lowering pass to
19451945
# lower-level IR ops using a custom transform function.
@@ -2005,7 +2005,7 @@ def matching_primitive_op(
20052005
else:
20062006
matching = desc
20072007
if matching:
2008-
return self.primitive_op(matching, args, line=line)
2008+
return self.primitive_op(matching, args, line=line, result_type=result_type)
20092009
return None
20102010

20112011
def int_op(self, type: RType, lhs: Value, rhs: Value, op: int, line: int = -1) -> Value:
@@ -2346,11 +2346,11 @@ def translate_special_method_call(
23462346
23472347
Return None if no translation found; otherwise return the target register.
23482348
"""
2349-
call_c_ops_candidates = method_call_ops.get(name, [])
2350-
call_c_op = self.matching_call_c(
2351-
call_c_ops_candidates, [base_reg] + args, line, result_type, can_borrow=can_borrow
2349+
primitive_ops_candidates = method_call_ops.get(name, [])
2350+
primitive_op = self.matching_primitive_op(
2351+
primitive_ops_candidates, [base_reg] + args, line, result_type, can_borrow=can_borrow
23522352
)
2353-
return call_c_op
2353+
return primitive_op
23542354

23552355
def translate_eq_cmp(self, lreg: Value, rreg: Value, expr_op: str, line: int) -> Value | None:
23562356
"""Add a equality comparison operation.

mypyc/primitives/registry.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ class LoadAddressDescription(NamedTuple):
7070
src: str # name of the target to load
7171

7272

73-
# CallC op for method call (such as 'str.join')
74-
method_call_ops: dict[str, list[CFunctionDescription]] = {}
73+
# Primitive ops for method call (such as 'str.join')
74+
method_call_ops: dict[str, list[PrimitiveDescription]] = {}
7575

7676
# Primitive ops for top level function call (such as 'builtins.list')
7777
function_ops: dict[str, list[PrimitiveDescription]] = {}
@@ -99,7 +99,7 @@ def method_op(
9999
is_borrowed: bool = False,
100100
priority: int = 1,
101101
is_pure: bool = False,
102-
) -> CFunctionDescription:
102+
) -> PrimitiveDescription:
103103
"""Define a c function call op that replaces a method call.
104104
105105
This will be automatically generated by matching against the AST.
@@ -129,7 +129,7 @@ def method_op(
129129
if extra_int_constants is None:
130130
extra_int_constants = []
131131
ops = method_call_ops.setdefault(name, [])
132-
desc = CFunctionDescription(
132+
desc = PrimitiveDescription(
133133
name,
134134
arg_types,
135135
return_type,

mypyc/test/test_cheader.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
from mypyc.ir.ops import PrimitiveDescription
1111
from mypyc.primitives import registry
12-
from mypyc.primitives.registry import CFunctionDescription
1312

1413

1514
class TestHeaderInclusion(unittest.TestCase):
@@ -26,14 +25,8 @@ def check_name(name: str) -> None:
2625
rf"\b{name}\b", header
2726
), f'"{name}" is used in mypyc.primitives but not declared in CPy.h'
2827

29-
for old_values in [registry.method_call_ops.values()]:
30-
for old_ops in old_values:
31-
if isinstance(old_ops, CFunctionDescription):
32-
old_ops = [old_ops]
33-
for old_op in old_ops:
34-
check_name(old_op.c_function_name)
35-
3628
for values in [
29+
registry.method_call_ops.values(),
3730
registry.binary_ops.values(),
3831
registry.unary_ops.values(),
3932
registry.function_ops.values(),

mypyc/test/test_emitfunc.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def test_dec_ref_tuple_nested(self) -> None:
303303
def test_list_get_item(self) -> None:
304304
self.assert_emit(
305305
CallC(
306-
list_get_item_op.c_function_name,
306+
str(list_get_item_op.c_function_name),
307307
[self.m, self.k],
308308
list_get_item_op.return_type,
309309
list_get_item_op.steals,
@@ -317,7 +317,7 @@ def test_list_get_item(self) -> None:
317317
def test_list_set_item(self) -> None:
318318
self.assert_emit(
319319
CallC(
320-
list_set_item_op.c_function_name,
320+
str(list_set_item_op.c_function_name),
321321
[self.l, self.n, self.o],
322322
list_set_item_op.return_type,
323323
list_set_item_op.steals,
@@ -353,7 +353,7 @@ def test_unbox_i64(self) -> None:
353353
def test_list_append(self) -> None:
354354
self.assert_emit(
355355
CallC(
356-
list_append_op.c_function_name,
356+
str(list_append_op.c_function_name),
357357
[self.l, self.o],
358358
list_append_op.return_type,
359359
list_append_op.steals,
@@ -493,7 +493,7 @@ def test_set_attr_init_with_bitmap(self) -> None:
493493
def test_dict_get_item(self) -> None:
494494
self.assert_emit(
495495
CallC(
496-
dict_get_item_op.c_function_name,
496+
str(dict_get_item_op.c_function_name),
497497
[self.d, self.o2],
498498
dict_get_item_op.return_type,
499499
dict_get_item_op.steals,
@@ -507,7 +507,7 @@ def test_dict_get_item(self) -> None:
507507
def test_dict_set_item(self) -> None:
508508
self.assert_emit(
509509
CallC(
510-
dict_set_item_op.c_function_name,
510+
str(dict_set_item_op.c_function_name),
511511
[self.d, self.o, self.o2],
512512
dict_set_item_op.return_type,
513513
dict_set_item_op.steals,
@@ -521,7 +521,7 @@ def test_dict_set_item(self) -> None:
521521
def test_dict_update(self) -> None:
522522
self.assert_emit(
523523
CallC(
524-
dict_update_op.c_function_name,
524+
str(dict_update_op.c_function_name),
525525
[self.d, self.o],
526526
dict_update_op.return_type,
527527
dict_update_op.steals,

0 commit comments

Comments
 (0)