Skip to content

Commit 2b458bc

Browse files
committed
feat: dict literals wip
1 parent dce8e1c commit 2b458bc

File tree

4 files changed

+213
-5
lines changed

4 files changed

+213
-5
lines changed

mypyc/codegen/literals.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
from typing import Final, Union
44
from typing_extensions import TypeGuard
55

6-
# Supported Python literal types. All tuple / frozenset / dict items (for dicts: both
# keys and values) must have supported literal types as well, but we can't represent
# the type precisely.
LiteralValue = Union[
    str,
    bytes,
    int,
    bool,
    float,
    complex,
    tuple[object, ...],
    frozenset[object],
    dict[object, object],
    None,
]
1111

1212

1313
def _is_literal_value(obj: object) -> TypeGuard[LiteralValue]:
14-
return isinstance(obj, (str, bytes, int, float, complex, tuple, frozenset, type(None)))
14+
return isinstance(obj, (str, bytes, int, float, complex, tuple, frozenset, dict, type(None)))
1515

1616

1717
# Some literals are singletons and handled specially (None, False and True)
@@ -30,6 +30,7 @@ def __init__(self) -> None:
3030
self.complex_literals: dict[complex, int] = {}
3131
self.tuple_literals: dict[tuple[object, ...], int] = {}
3232
self.frozenset_literals: dict[frozenset[object], int] = {}
33+
self.dict_literals: dict[tuple[tuple[object, object], ...], int] = {}
3334

3435
def record_literal(self, value: LiteralValue) -> None:
3536
"""Ensure that the literal value is available in generated code."""
@@ -70,6 +71,17 @@ def record_literal(self, value: LiteralValue) -> None:
7071
assert _is_literal_value(item)
7172
self.record_literal(item)
7273
frozenset_literals[value] = len(frozenset_literals)
74+
elif isinstance(value, dict):
75+
# Represent dicts as a tuple of sorted (key, value) pairs for uniqueness
76+
items = tuple(sorted(value.items()))
77+
dict_literals = self.dict_literals
78+
if items not in dict_literals:
79+
for k, v in items:
80+
assert _is_literal_value(k)
81+
assert _is_literal_value(v)
82+
self.record_literal(k)
83+
self.record_literal(v)
84+
dict_literals[items] = len(dict_literals)
7385
else:
7486
assert False, "invalid literal: %r" % value
7587

@@ -104,6 +116,10 @@ def literal_index(self, value: LiteralValue) -> int:
104116
n += len(self.tuple_literals)
105117
if isinstance(value, frozenset):
106118
return n + self.frozenset_literals[value]
119+
n += len(self.frozenset_literals)
120+
if isinstance(value, dict):
121+
items = tuple(sorted(value.items()))
122+
return n + self.dict_literals[items]
107123
assert False, "invalid literal: %r" % value
108124

109125
def num_literals(self) -> int:
@@ -117,6 +133,7 @@ def num_literals(self) -> int:
117133
+ len(self.complex_literals)
118134
+ len(self.tuple_literals)
119135
+ len(self.frozenset_literals)
136+
+ len(self.dict_literals)
120137
)
121138

122139
# The following methods return the C encodings of literal values
@@ -143,6 +160,36 @@ def encoded_tuple_values(self) -> list[str]:
143160
def encoded_frozenset_values(self) -> list[str]:
144161
return self._encode_collection_values(self.frozenset_literals)
145162

163+
def encoded_dict_values(self) -> list[str]:
    """Encode dict values into a C array.

    Dicts are emitted in the order of the indexes stored in
    self.dict_literals. The format of the result is like this:

       <number of dicts>
       <length of the first dict>
       <literal index of first key>
       <literal index of first value>
       ...
       <literal index of last key>
       <literal index of last value>
       <length of the second dict>
       ...

    Every entry is the decimal string form of an integer.
    """
    values = self.dict_literals
    # Invert the mapping so dicts can be emitted in index order.
    value_by_index = {index: value for value, index in values.items()}
    count = len(values)
    result: list[str] = [str(count)]
    for i in range(count):
        items = value_by_index[i]
        result.append(str(len(items)))
        for k, v in items:
            # Keys and values are stored as plain "object"; narrow them the same
            # way record_literal does before asking for their literal index.
            assert _is_literal_value(k)
            assert _is_literal_value(v)
            result.append(str(self.literal_index(k)))
            result.append(str(self.literal_index(v)))
    return result
192+
146193
def _encode_collection_values(
147194
self, values: dict[tuple[object, ...], int] | dict[frozenset[object], int]
148195
) -> list[str]:

mypyc/irbuild/expression.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,11 @@
5959
BasicBlock,
6060
Call,
6161
ComparisonOp,
62+
InitStatic,
6263
Integer,
6364
LoadAddress,
6465
LoadLiteral,
66+
LoadStatic,
6567
PrimitiveDescription,
6668
RaiseStandardError,
6769
Register,
@@ -73,6 +75,7 @@
7375
RInstance,
7476
RTuple,
7577
bool_rprimitive,
78+
dict_rprimitive,
7679
int_rprimitive,
7780
is_fixed_width_rtype,
7881
is_int_rprimitive,
@@ -100,7 +103,7 @@
100103
)
101104
from mypyc.irbuild.specialize import apply_function_specialization, apply_method_specialization
102105
from mypyc.primitives.bytes_ops import bytes_slice_op
103-
from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, exact_dict_set_item_op
106+
from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, dict_set_item_op, dict_template_copy_op, exact_dict_set_item_op
104107
from mypyc.primitives.generic_ops import iter_op, name_op
105108
from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op
106109
from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op
@@ -1009,8 +1012,53 @@ def _visit_tuple_display(builder: IRBuilder, expr: TupleExpr) -> Value:
10091012
return builder.primitive_op(list_tuple_op, [val_as_list], expr.line)
10101013

10111014

1015+
def dict_literal_values(
    builder: IRBuilder, items: Sequence[tuple[Expression | None, Expression]], line: int
) -> Value | None:
    """Try to turn a dict display into a load of a shared template dict.

    Every key and value expression is passed through constant_fold_expr. If any
    of them cannot be folded, or the display contains a "**" unpacking, or the
    display is empty, nothing is emitted and None is returned so the caller can
    build the dict the regular way.

    Otherwise the first request for a given (module, contents) pair emits the
    ops that build the dict and an InitStatic that stores it under a name
    derived from the contents; every request then emits and returns a
    LoadStatic of that name. The caller must copy the loaded dict before
    handing it to user code, since the static is shared.

    Args:
        builder: IR builder the ops are added to.
        items: (key, value) expression pairs of the dict display; a key of None
            marks a "**" unpacking item.
        line: Source line number attached to the emitted ops.

    Returns:
        The LoadStatic value for the template dict, or None if the display is
        not eligible.
    """
    # Function-scope import: the file's import block is not fully visible here.
    import hashlib

    result = {}
    for key_expr, value_expr in items:
        if key_expr is None:
            # "**mapping" unpacking, not a literal item.
            # TODO: when the unpacked operand is itself a constant dict display,
            #       its items could be merged in here instead of giving up.
            return None
        key = constant_fold_expr(builder, key_expr)
        if key is None:
            return None
        value = constant_fold_expr(builder, value_expr)
        if value is None:
            return None
        result[key] = value

    if not result:
        # Nothing to share for "{}"; a template plus a copy would only add work.
        return None

    # Derive the static's name from the items *in insertion order*:
    #  - insertion order is observable on a dict, so literals that differ only in
    #    item order must not share a template;
    #  - sorting the items raises TypeError for keys of mixed or unorderable types;
    #  - repr() keeps equal-but-distinct constants such as 1 and 1.0 apart;
    #  - the built-in hash() of str/bytes changes between interpreter runs and can
    #    collide, whereas a SHA-256 digest is stable and collision-resistant.
    fingerprint = repr(tuple(result.items()))
    digest = hashlib.sha256(fingerprint.encode("utf-8", "backslashreplace")).hexdigest()
    static_name = f"__mypyc_dict_template__{digest}"

    # Statics are referenced per module, so the registry is keyed per module too.
    cache_key = (builder.module_name, static_name)
    if cache_key not in _static_dicts:
        # Only emit the construction ops when they are actually stored.
        dict_reg = builder.call_c(dict_new_op, [], line)
        for k, v in result.items():
            key = builder.load_literal_value(k)
            value = builder.load_literal_value(v)
            builder.primitive_op(dict_set_item_op, [dict_reg, key, value], line)
        _static_dicts[cache_key] = builder.add(
            InitStatic(dict_reg, static_name, builder.module_name)
        )
    # NOTE(review): the InitStatic above is emitted at the first use site in
    # compilation order, not in module initialisation. Another use site that runs
    # earlier at runtime would load the static before it is set - confirm, and
    # move the initialisation to module init.
    # NOTE(review): _static_dicts lives for the whole compiler process; compiling
    # two modules with the same name in one process would reuse a stale entry and
    # skip the InitStatic - presumably this state belongs on the builder.

    return builder.add(LoadStatic(dict_rprimitive, static_name, builder.module_name))
1048+
10121049
def transform_dict_expr(builder: IRBuilder, expr: DictExpr) -> Value:
1013-
"""First accepts all keys and values, then makes a dict out of them."""
1050+
"""First accepts all keys and values, then makes a dict out of them.
1051+
1052+
Optimization: If all keys and values are deeply immutable, emit a static template dict
1053+
and at runtime use PyDict_Copy to return a fresh dict.
1054+
"""
1055+
# Try to constant fold the dict and get a static Value
1056+
template = dict_literal_values(builder, expr.items, expr.line)
1057+
if template is not None:
1058+
# At runtime, return PyDict_Copy(template)
1059+
return builder.call_c(dict_template_copy_op, [template], expr.line)
1060+
1061+
# Fallback: build dict at runtime as before
10141062
key_value_pairs = []
10151063
for key_expr, value_expr in expr.items:
10161064
key = builder.accept(key_expr) if key_expr is not None else None
@@ -1020,6 +1068,8 @@ def transform_dict_expr(builder: IRBuilder, expr: DictExpr) -> Value:
10201068
return builder.builder.make_dict(key_value_pairs, expr.line)
10211069

10221070

1071+
_static_dicts = {}
1072+
10231073
def transform_set_expr(builder: IRBuilder, expr: SetExpr) -> Value:
10241074
return _visit_display(
10251075
builder, expr.items, builder.new_set_op, set_add_op, set_update_op, expr.line, False

mypyc/primitives/dict_ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@
6262
priority=2,
6363
)
6464

65+
# Copy a dict by calling PyDict_Copy. transform_dict_expr applies this to the
# shared template of a constant dict display so that each evaluation yields its
# own dict object.
dict_template_copy_op = custom_op(
66+
arg_types=[dict_rprimitive],
67+
return_type=dict_rprimitive,
68+
c_function_name="PyDict_Copy",
69+
error_kind=ERR_MAGIC,
70+
)
71+
6572
# Generic one-argument dict constructor: dict(obj)
6673
dict_copy = function_op(
6774
name="builtins.dict",

mypyc/test-data/run-dicts.test

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,3 +368,107 @@ class subc(dict[Any, Any]):
368368
[file userdefineddict.py]
369369
class dict:
370370
pass
371+
372+
[case testDictConstantFolding]
373+
# Test constant folding of dict literals with deeply immutable keys/values
374+
def get_headers():
375+
return {"k0": "v0", "k1": "v1"}
376+
377+
def get_numbers():
378+
return {1: 2, 3: 4}
379+
380+
def get_tuple_key():
381+
return {(1, 2): "a", (3, 4): "b"}
382+
383+
def get_frozenset_key():
384+
return {frozenset({1, 2}): "x"}
385+
386+
def get_nested_tuple():
387+
return {(1, (2, 3)): "ok"}
388+
389+
def get_bool_none():
390+
return {True: None, False: 1}
391+
392+
def get_bytes():
393+
return {b"x": b"y"}
394+
395+
def get_mutable_value():
396+
return {"a": []}
397+
398+
def get_mutable_key():
399+
return {(1, []): "bad"}
400+
401+
def get_nested_mutable():
402+
return {(1, (2, [])): "bad"}
403+
404+
def get_unhashable_key():
405+
return {[1, 2]: "bad"}
406+
407+
def get_tuple_of_frozenset():
408+
return {(1, frozenset({2, 3})): "ok"}
409+
410+
def get_tuple_of_tuple():
411+
return {(1, (2, (3, 4))): "ok"}
412+
413+
def get_frozenset_of_tuple():
414+
return {frozenset({(1, 2), (3, 4)}): "ok"}
415+
416+
class C: pass
417+
418+
def get_user_defined():
419+
return {C(): 1}
420+
421+
def test_mutation_independence():
422+
d1 = get_headers()
423+
d2 = get_headers()
424+
d1["k0"] = "changed"
425+
assert d2["k0"] == "v0"
426+
d2["k1"] = "changed2"
427+
assert d1["k1"] == "v1"
428+
429+
def test_ineligible_not_folded():
430+
# These should not be constant folded
431+
assert get_mutable_value()["a"] == []
432+
try:
433+
get_mutable_key()
434+
except TypeError:
435+
pass
436+
else:
437+
assert False
438+
try:
439+
get_nested_mutable()
440+
except TypeError:
441+
pass
442+
else:
443+
assert False
444+
try:
445+
get_unhashable_key()
446+
except TypeError:
447+
pass
448+
else:
449+
assert False
450+
# User-defined class instance as key should not be folded
451+
d = get_user_defined()
452+
assert list(d.values()) == [1]
453+
454+
def test_eligible_folded():
455+
assert get_headers() == {"k0": "v0", "k1": "v1"}
456+
assert get_numbers() == {1: 2, 3: 4}
457+
assert get_tuple_key() == {(1, 2): "a", (3, 4): "b"}
458+
assert get_frozenset_key() == {frozenset({1, 2}): "x"}
459+
assert get_nested_tuple() == {(1, (2, 3)): "ok"}
460+
assert get_bool_none() == {True: None, False: 1}
461+
assert get_bytes() == {b"x": b"y"}
462+
assert get_tuple_of_frozenset() == {(1, frozenset({2, 3})): "ok"}
463+
assert get_tuple_of_tuple() == {(1, (2, (3, 4))): "ok"}
464+
assert get_frozenset_of_tuple() == {frozenset({(1, 2), (3, 4)}): "ok"}
465+
466+
[file driver.py]
467+
from native import (
468+
test_mutation_independence,
469+
test_ineligible_not_folded,
470+
test_eligible_folded,
471+
)
472+
test_mutation_independence()
473+
test_ineligible_not_folded()
474+
test_eligible_folded()

0 commit comments

Comments
 (0)