diff --git a/mypyc/analysis/ircheck.py b/mypyc/analysis/ircheck.py index 6980c9cee419..4885fb34d9ea 100644 --- a/mypyc/analysis/ircheck.py +++ b/mypyc/analysis/ircheck.py @@ -285,6 +285,16 @@ def check_frozenset_items_valid_literals(self, op: LoadLiteral, s: frozenset[obj else: self.fail(op, f"Invalid type for item of frozenset literal: {type(x)})") + def check_dict_items_valid_literals(self, op: LoadLiteral, d: dict[object, object]) -> None: + valid_types = (str, bytes, bool, int, float, complex) + for k, v in d.items(): + # Acceptable key types: str, bytes, bool, int, float, complex + if not isinstance(k, valid_types): + self.fail(op, f"Invalid type for key of dict literal: {type(k)})") + # Acceptable value types: str, bytes, bool, int, float, complex + if not isinstance(v, valid_types): + self.fail(op, f"Invalid type for value of dict literal: {type(v)})") + def visit_load_literal(self, op: LoadLiteral) -> None: expected_type = None if op.value is None: @@ -309,6 +319,9 @@ def visit_load_literal(self, op: LoadLiteral) -> None: # it's a set (when it's really a frozenset). expected_type = "builtins.set" self.check_frozenset_items_valid_literals(op, op.value) + elif isinstance(op.value, dict): + expected_type = "builtins.dict" + self.check_dict_items_valid_literals(op, op.value) assert expected_type is not None, "Missed a case for LoadLiteral check" diff --git a/mypyc/codegen/emitmodule.py b/mypyc/codegen/emitmodule.py index ca5db52ab7da..8b7806ff7470 100644 --- a/mypyc/codegen/emitmodule.py +++ b/mypyc/codegen/emitmodule.py @@ -708,6 +708,9 @@ def generate_literal_tables(self) -> None: # Descriptions of frozenset literals init_frozenset = c_array_initializer(literals.encoded_frozenset_values()) self.declare_global("const int []", "CPyLit_FrozenSet", initializer=init_frozenset) + # Descriptions of dict literals + init_dict = c_array_initializer(literals.encoded_dict_values()) + self.declare_global("const int []", "CPyLit_Dict", initializer=init_dict) def generate_export_table(self, decl_emitter: Emitter, code_emitter: Emitter) -> None: """Generate the declaration and definition of the group's export struct. @@ -926,7 +929,7 @@ def generate_globals_init(self, emitter: Emitter) -> None: for symbol, fixup in self.simple_inits: emitter.emit_line(f"{symbol} = {fixup};") - values = "CPyLit_Str, CPyLit_Bytes, CPyLit_Int, CPyLit_Float, CPyLit_Complex, CPyLit_Tuple, CPyLit_FrozenSet" + values = "CPyLit_Str, CPyLit_Bytes, CPyLit_Int, CPyLit_Float, CPyLit_Complex, CPyLit_Tuple, CPyLit_FrozenSet, CPyLit_Dict" emitter.emit_lines( f"if (CPyStatics_Initialize(CPyStatics, {values}) < 0) {{", "return -1;", "}" ) diff --git a/mypyc/codegen/literals.py b/mypyc/codegen/literals.py index 4cd41e0f4d32..79da68728d9a 100644 --- a/mypyc/codegen/literals.py +++ b/mypyc/codegen/literals.py @@ -3,15 +3,24 @@ from typing import Final, Union from typing_extensions import TypeGuard -# Supported Python literal types. All tuple / frozenset items must have supported +# Supported Python literal types. All tuple / frozenset / dict items must have supported # literal types as well, but we can't represent the type precisely. LiteralValue = Union[ - str, bytes, int, bool, float, complex, tuple[object, ...], frozenset[object], None + str, + bytes, + int, + bool, + float, + complex, + tuple[object, ...], + frozenset[object], + dict[object, object], + None, ] def _is_literal_value(obj: object) -> TypeGuard[LiteralValue]: - return isinstance(obj, (str, bytes, int, float, complex, tuple, frozenset, type(None))) + return isinstance(obj, (str, bytes, int, float, complex, tuple, frozenset, dict, type(None))) # Some literals are singletons and handled specially (None, False and True) @@ -30,6 +39,7 @@ def __init__(self) -> None: self.complex_literals: dict[complex, int] = {} self.tuple_literals: dict[tuple[object, ...], int] = {} self.frozenset_literals: dict[frozenset[object], int] = {} + self.dict_literals: dict[tuple[tuple[object, object], ...], int] = {} def record_literal(self, value: LiteralValue) -> None: """Ensure that the literal value is available in generated code.""" @@ -70,6 +80,16 @@ def record_literal(self, value: LiteralValue) -> None: assert _is_literal_value(item) self.record_literal(item) frozenset_literals[value] = len(frozenset_literals) + elif isinstance(value, dict): + items = self.make_dict_literal_key(value) # type: ignore [arg-type] + dict_literals = self.dict_literals + if items not in dict_literals: + for k, v in items: + assert _is_literal_value(k) + assert _is_literal_value(v) + self.record_literal(k) + self.record_literal(v) + dict_literals[items] = len(dict_literals) else: assert False, "invalid literal: %r" % value @@ -104,8 +124,18 @@ def literal_index(self, value: LiteralValue) -> int: n += len(self.tuple_literals) if isinstance(value, frozenset): return n + self.frozenset_literals[value] + n += len(self.frozenset_literals) + if isinstance(value, dict): + key = self.make_dict_literal_key(value) # type: ignore [arg-type] + return n + self.dict_literals[key] assert False, "invalid literal: %r" % value + def make_dict_literal_key( + self, value: dict[LiteralValue, LiteralValue] + ) -> tuple[tuple[LiteralValue, LiteralValue], ...]: + """Make a unique key for a literal dict.""" + return tuple(value.items()) + def num_literals(self) -> int: # The first three are for None, True and False return ( @@ -117,6 +147,7 @@ def num_literals(self) -> int: + len(self.complex_literals) + len(self.tuple_literals) + len(self.frozenset_literals) + + len(self.dict_literals) ) # The following methods return the C encodings of literal values @@ -143,6 +174,36 @@ def encoded_tuple_values(self) -> list[str]: def encoded_frozenset_values(self) -> list[str]: return self._encode_collection_values(self.frozenset_literals) + def encoded_dict_values(self) -> list[str]: + """Encode dict values into a C array. + + The format of the result is like this: + + + + + + ... + + + + ... + """ + values = self.dict_literals + value_by_index = {index: value for value, index in values.items()} + result = [] + count = len(values) + result.append(str(count)) + for i in range(count): + items = value_by_index[i] + result.append(str(len(items))) + for k, v in items: + index_k = self.literal_index(k) # type: ignore [arg-type] + index_v = self.literal_index(v) # type: ignore [arg-type] + result.append(str(index_k)) + result.append(str(index_v)) + return result + def _encode_collection_values( self, values: dict[tuple[object, ...], int] | dict[frozenset[object], int] ) -> list[str]: diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index 76c1e07a79d5..27aa19ef13ef 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -812,6 +812,10 @@ class LoadLiteral(RegisterOp): Tuple / frozenset literals must contain only valid literal values as items. + Dict literals must contain only literal keys and literal values. + Due to their mutability, dict literals will be copied from the main template + at each use. + NOTE: You can use this to load boxed (Python) int objects. Use Integer to load unboxed, tagged integers or fixed-width, low-level integers. diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index 59ecc4ac2c5c..3ed31ab3e29d 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -8,7 +8,8 @@ import math from collections.abc import Sequence -from typing import Callable +from functools import partial +from typing import Callable, Union from mypy.nodes import ( ARG_NAMED, @@ -72,6 +73,7 @@ RInstance, RTuple, bool_rprimitive, + dict_rprimitive, int_rprimitive, is_fixed_width_rtype, is_int_rprimitive, @@ -83,7 +85,7 @@ ) from mypyc.irbuild.ast_helpers import is_borrow_friendly_expr, process_conditional from mypyc.irbuild.builder import IRBuilder, int_borrow_friendly_op -from mypyc.irbuild.constant_fold import constant_fold_expr +from mypyc.irbuild.constant_fold import ConstantValue, constant_fold_expr from mypyc.irbuild.for_helpers import ( comprehension_helper, raise_error_if_contains_unreachable_names, @@ -104,7 +106,12 @@ translate_object_setattr, ) from mypyc.primitives.bytes_ops import bytes_slice_op -from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, exact_dict_set_item_op +from mypyc.primitives.dict_ops import ( + dict_get_item_op, + dict_new_op, + dict_template_copy_op, + exact_dict_set_item_op, +) from mypyc.primitives.generic_ops import iter_op, name_op from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op @@ -113,6 +120,8 @@ from mypyc.primitives.str_ops import str_slice_op from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op +ConstantValueTuple = Union[ConstantValue, tuple["ConstantValueTuple", ...]] + # Name and attribute references @@ -999,8 +1008,61 @@ def _visit_tuple_display(builder: IRBuilder, expr: TupleExpr) -> Value: return builder.primitive_op(list_tuple_op, [val_as_list], expr.line) +def dict_literal_values( + builder: IRBuilder, items: Sequence[tuple[Expression | None, Expression]], line: int +) -> dict[ConstantValueTuple, ConstantValueTuple] | None: + """Try to extract a constant dict from a dict literal, recursively staticizing nested dicts. + + If all keys and values are deeply immutable and constant (including nested dicts as values), + return the Python dict value. Otherwise, return None. + """ + + def constant_fold_expr_or_tuple( + builder: IRBuilder, expr: Expression + ) -> ConstantValueTuple | None: + value = constant_fold_expr(builder, expr) + if value is not None: + return value + if not isinstance(expr, TupleExpr): + return None + folded = tuple( + const + for const in map(partial(constant_fold_expr_or_tuple, builder), expr.items) + if const is not None + ) + return folded if len(folded) == len(expr.items) else None + + result = {} + for key_expr, value_expr in items: + if key_expr is None: + # ** unpacking, not a literal + # TODO: if ** is unpacking a dict literal we can use that, we just need logic + return None + key = constant_fold_expr_or_tuple(builder, key_expr) + if key is None: + return None + value = constant_fold_expr_or_tuple(builder, value_expr) + if value is None: + return None + result[key] = value + + return result or None + + def transform_dict_expr(builder: IRBuilder, expr: DictExpr) -> Value: - """First accepts all keys and values, then makes a dict out of them.""" + """First accepts all keys and values, then makes a dict out of them. + + Optimization: If all keys and values are deeply immutable, emit a static template dict + and at runtime use PyDict_Copy to return a fresh dict. + """ + # Try to constant fold the dict and get a static Value + static_dict = dict_literal_values(builder, expr.items, expr.line) + if static_dict is not None: + # Register the static dict and return a copy at runtime + static_val = builder.add(LoadLiteral(static_dict, dict_rprimitive)) # type: ignore [arg-type] + return builder.call_c(dict_template_copy_op, [static_val], expr.line) + + # If that fails, build dict at runtime key_value_pairs = [] for key_expr, value_expr in expr.items: key = builder.accept(key_expr) if key_expr is not None else None diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index e9dfd8de3683..afe42944a882 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -931,7 +931,8 @@ int CPyStatics_Initialize(PyObject **statics, const double *floats, const double *complex_numbers, const int *tuples, - const int *frozensets); + const int *frozensets, + const int *dicts); PyObject *CPy_Super(PyObject *builtins, PyObject *self); PyObject *CPy_CallReverseOpMethod(PyObject *left, PyObject *right, const char *op, _Py_Identifier *method); diff --git a/mypyc/lib-rt/misc_ops.c b/mypyc/lib-rt/misc_ops.c index ca09c347b4ff..64aa6478f5e1 100644 --- a/mypyc/lib-rt/misc_ops.c +++ b/mypyc/lib-rt/misc_ops.c @@ -615,7 +615,8 @@ int CPyStatics_Initialize(PyObject **statics, const double *floats, const double *complex_numbers, const int *tuples, - const int *frozensets) { + const int *frozensets, + const int *dicts) { PyObject **result = statics; // Start with some hard-coded values *result++ = Py_None; @@ -733,6 +734,31 @@ int CPyStatics_Initialize(PyObject **statics, *result++ = obj; } } + if (dicts) { + int num = *dicts++; + while (num-- > 0) { + int num_items = *dicts++; + PyObject *obj = PyDict_New(); + if (obj == NULL) { + return -1; + } + for (int i = 0; i < num_items; i++) { + PyObject *key = statics[*dicts++]; + PyObject *value = statics[*dicts++]; + Py_INCREF(key); + Py_INCREF(value); + if (PyDict_SetItem(obj, key, value) == -1) { + Py_DECREF(key); + Py_DECREF(value); + Py_DECREF(obj); + return -1; + } + Py_DECREF(key); + Py_DECREF(value); + } + *result++ = obj; + } + } return 0; } diff --git a/mypyc/primitives/dict_ops.py b/mypyc/primitives/dict_ops.py index f98bcc8ac2ec..5a3f065eee1e 100644 --- a/mypyc/primitives/dict_ops.py +++ b/mypyc/primitives/dict_ops.py @@ -62,6 +62,13 @@ priority=2, ) +dict_template_copy_op = custom_op( + arg_types=[dict_rprimitive], + return_type=dict_rprimitive, + c_function_name="PyDict_Copy", + error_kind=ERR_MAGIC, +) + # Generic one-argument dict constructor: dict(obj) dict_copy = function_op( name="builtins.dict", diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index 612f3266fd79..3519c4ac1d1f 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -1730,61 +1730,49 @@ L0: r0 = (a, b, c) return r0 def g(): - r0, r1, r2 :: str - r3, r4, r5 :: object - r6, r7 :: dict - r8 :: str - r9 :: object - r10 :: tuple - r11 :: dict - r12 :: object - r13 :: tuple[int, int, int] -L0: - r0 = 'a' - r1 = 'b' - r2 = 'c' - r3 = object 1 - r4 = object 2 - r5 = object 3 - r6 = CPyDict_Build(3, r0, r3, r1, r4, r2, r5) - r7 = __main__.globals :: static - r8 = 'f' - r9 = CPyDict_GetItem(r7, r8) - r10 = CPyTuple_LoadEmptyTupleConstant() - r11 = PyDict_Copy(r6) - r12 = PyObject_Call(r9, r10, r11) - r13 = unbox(tuple[int, int, int], r12) - return r13 -def h(): - r0, r1 :: str - r2, r3 :: object - r4, r5 :: dict - r6 :: str + r0, r1, r2 :: dict + r3 :: str + r4 :: object + r5 :: tuple + r6 :: dict r7 :: object - r8 :: dict - r9 :: i32 - r10 :: bit - r11 :: object - r12 :: tuple - r13 :: object - r14 :: tuple[int, int, int] + r8 :: tuple[int, int, int] +L0: + r0 = {'a': 1, 'b': 2, 'c': 3} + r1 = PyDict_Copy(r0) + r2 = __main__.globals :: static + r3 = 'f' + r4 = CPyDict_GetItem(r2, r3) + r5 = CPyTuple_LoadEmptyTupleConstant() + r6 = PyDict_Copy(r1) + r7 = PyObject_Call(r4, r5, r6) + r8 = unbox(tuple[int, int, int], r7) + return r8 +def h(): + r0, r1, r2 :: dict + r3 :: str + r4 :: object + r5 :: dict + r6 :: i32 + r7 :: bit + r8 :: object + r9 :: tuple + r10 :: object + r11 :: tuple[int, int, int] L0: - r0 = 'b' - r1 = 'c' - r2 = object 2 - r3 = object 3 - r4 = CPyDict_Build(2, r0, r2, r1, r3) - r5 = __main__.globals :: static - r6 = 'f' - r7 = CPyDict_GetItem(r5, r6) - r8 = PyDict_New() - r9 = CPyDict_UpdateInDisplay(r8, r4) - r10 = r9 >= 0 :: signed - r11 = object 1 - r12 = PyTuple_Pack(1, r11) - r13 = PyObject_Call(r7, r12, r8) - r14 = unbox(tuple[int, int, int], r13) - return r14 + r0 = {'b': 2, 'c': 3} + r1 = PyDict_Copy(r0) + r2 = __main__.globals :: static + r3 = 'f' + r4 = CPyDict_GetItem(r2, r3) + r5 = PyDict_New() + r6 = CPyDict_UpdateInDisplay(r5, r1) + r7 = r6 >= 0 :: signed + r8 = object 1 + r9 = PyTuple_Pack(1, r8) + r10 = PyObject_Call(r4, r9, r5) + r11 = unbox(tuple[int, int, int], r10) + return r11 [case testFunctionCallWithDefaultArgs] def f(x: int, y: int = 3, z: str = "test") -> None: diff --git a/mypyc/test-data/irbuild-dict.test b/mypyc/test-data/irbuild-dict.test index e0c014f07813..f8e200903aca 100644 --- a/mypyc/test-data/irbuild-dict.test +++ b/mypyc/test-data/irbuild-dict.test @@ -545,10 +545,9 @@ def f4(d, flag): flag :: bool r0 :: str r1 :: object - r2, r3 :: str - r4 :: object - r5 :: dict - r6, r7 :: object + r2 :: str + r3, r4 :: dict + r5, r6 :: object L0: if flag goto L1 else goto L2 :: bool L1: @@ -557,11 +556,25 @@ L1: return r1 L2: r2 = 'a' - r3 = 'c' - r4 = object 1 - r5 = CPyDict_Build(1, r3, r4) - r6 = CPyDict_SetDefault(d, r2, r5) - return r6 + r3 = {'c': 1} + r4 = PyDict_Copy(r3) + r5 = CPyDict_SetDefault(d, r2, r4) + return r5 L3: - r7 = box(None, 1) - return r7 + r6 = box(None, 1) + return r6 + +[case testNestedDictLiteral] +def f() -> None: + d = {"a": {2: 3.4}} +[out] +def f(): + r0 :: str + r1, r2, r3, d :: dict +L0: + r0 = 'a' + r1 = {2: 3.4} + r2 = PyDict_Copy(r1) + r3 = CPyDict_Build(1, r0, r2) + d = r3 + return 1 diff --git a/mypyc/test-data/irbuild-set.test b/mypyc/test-data/irbuild-set.test index 5586a2bf4cfb..ed174cf0d937 100644 --- a/mypyc/test-data/irbuild-set.test +++ b/mypyc/test-data/irbuild-set.test @@ -159,56 +159,49 @@ L5: b = r1 return 1 def test3(): - r0, r1, r2 :: str - r3, r4, r5 :: object - r6, tmp_dict :: dict - r7 :: set - r8 :: short_int - r9 :: native_int - r10 :: object - r11 :: tuple[bool, short_int, object] - r12 :: short_int - r13 :: bool - r14 :: object - r15, x, r16 :: int - r17 :: object - r18 :: i32 - r19, r20, r21 :: bit + r0, r1, tmp_dict :: dict + r2 :: set + r3 :: short_int + r4 :: native_int + r5 :: object + r6 :: tuple[bool, short_int, object] + r7 :: short_int + r8 :: bool + r9 :: object + r10, x, r11 :: int + r12 :: object + r13 :: i32 + r14, r15, r16 :: bit c :: set L0: - r0 = '1' - r1 = '3' - r2 = '5' - r3 = object 1 - r4 = object 3 - r5 = object 5 - r6 = CPyDict_Build(3, r3, r0, r4, r1, r5, r2) - tmp_dict = r6 - r7 = PySet_New(0) - r8 = 0 - r9 = PyDict_Size(tmp_dict) - r10 = CPyDict_GetKeysIter(tmp_dict) + r0 = {1: '1', 3: '3', 5: '5'} + r1 = PyDict_Copy(r0) + tmp_dict = r1 + r2 = PySet_New(0) + r3 = 0 + r4 = PyDict_Size(tmp_dict) + r5 = CPyDict_GetKeysIter(tmp_dict) L1: - r11 = CPyDict_NextKey(r10, r8) - r12 = r11[1] - r8 = r12 - r13 = r11[0] - if r13 goto L2 else goto L4 :: bool + r6 = CPyDict_NextKey(r5, r3) + r7 = r6[1] + r3 = r7 + r8 = r6[0] + if r8 goto L2 else goto L4 :: bool L2: - r14 = r11[2] - r15 = unbox(int, r14) - x = r15 - r16 = f(x) - r17 = box(int, r16) - r18 = PySet_Add(r7, r17) - r19 = r18 >= 0 :: signed + r9 = r6[2] + r10 = unbox(int, r9) + x = r10 + r11 = f(x) + r12 = box(int, r11) + r13 = PySet_Add(r2, r12) + r14 = r13 >= 0 :: signed L3: - r20 = CPyDict_CheckSize(tmp_dict, r9) + r15 = CPyDict_CheckSize(tmp_dict, r4) goto L1 L4: - r21 = CPy_NoErrOccurred() + r16 = CPy_NoErrOccurred() L5: - c = r7 + c = r2 return 1 def test4(): r0 :: set diff --git a/mypyc/test-data/irbuild-statements.test b/mypyc/test-data/irbuild-statements.test index 48b8e0e318b8..7f6d128835ef 100644 --- a/mypyc/test-data/irbuild-statements.test +++ b/mypyc/test-data/irbuild-statements.test @@ -757,49 +757,35 @@ def delDictMultiple() -> None: del d["one"], d["four"] [out] def delDict(): - r0, r1 :: str - r2, r3 :: object - r4, d :: dict - r5 :: str - r6 :: i32 - r7 :: bit + r0, r1, d :: dict + r2 :: str + r3 :: i32 + r4 :: bit L0: - r0 = 'one' - r1 = 'two' - r2 = object 1 - r3 = object 2 - r4 = CPyDict_Build(2, r0, r2, r1, r3) - d = r4 - r5 = 'one' - r6 = PyObject_DelItem(d, r5) - r7 = r6 >= 0 :: signed + r0 = {'one': 1, 'two': 2} + r1 = PyDict_Copy(r0) + d = r1 + r2 = 'one' + r3 = PyObject_DelItem(d, r2) + r4 = r3 >= 0 :: signed return 1 def delDictMultiple(): - r0, r1, r2, r3 :: str - r4, r5, r6, r7 :: object - r8, d :: dict - r9, r10 :: str - r11 :: i32 - r12 :: bit - r13 :: i32 - r14 :: bit + r0, r1, d :: dict + r2, r3 :: str + r4 :: i32 + r5 :: bit + r6 :: i32 + r7 :: bit L0: - r0 = 'one' - r1 = 'two' - r2 = 'three' + r0 = {'one': 1, 'two': 2, 'three': 3, 'four': 4} + r1 = PyDict_Copy(r0) + d = r1 + r2 = 'one' r3 = 'four' - r4 = object 1 - r5 = object 2 - r6 = object 3 - r7 = object 4 - r8 = CPyDict_Build(4, r0, r4, r1, r5, r2, r6, r3, r7) - d = r8 - r9 = 'one' - r10 = 'four' - r11 = PyObject_DelItem(d, r9) - r12 = r11 >= 0 :: signed - r13 = PyObject_DelItem(d, r10) - r14 = r13 >= 0 :: signed + r4 = PyObject_DelItem(d, r2) + r5 = r4 >= 0 :: signed + r6 = PyObject_DelItem(d, r3) + r7 = r6 >= 0 :: signed return 1 [case testDelAttribute] diff --git a/mypyc/test-data/run-dicts.test b/mypyc/test-data/run-dicts.test index 2b75b32c906e..a40966cc4a94 100644 --- a/mypyc/test-data/run-dicts.test +++ b/mypyc/test-data/run-dicts.test @@ -368,3 +368,115 @@ class subc(dict[Any, Any]): [file userdefineddict.py] class dict: pass + +[case testDictConstantFolding] +# Test constant folding of dict literals with deeply immutable keys/values +def get_headers(): + return {"k0": "v0", "k1": "v1"} + +def get_numbers(): + return {1: 2, 3: 4} + +def get_tuple_key(): + return {(1, 2): "a", (3, 4): "b"} + +def get_frozenset_key(): + return {frozenset({1, 2}): "x"} + +def get_nested_tuple(): + return {(1, (2, 3)): "ok"} + +def get_bool_none(): + return {True: None, False: 1} + +def get_bytes(): + return {b"x": b"y"} + +def get_mutable_value(): + return {"a": []} + +def get_mutable_key(): + return {(1, []): "bad"} + +def get_nested_mutable(): + return {(1, (2, [])): "bad"} + +def get_unhashable_key(): + return {[1, 2]: "bad"} + +def get_tuple_of_frozenset(): + return {(1, frozenset({2, 3})): "ok"} + +def get_tuple_of_tuple(): + return {(1, (2, (3, 4))): "ok"} + +def get_frozenset_of_tuple(): + return {frozenset({(1, 2), (3, 4)}): "ok"} + +class C: pass + +def get_user_defined(): + return {C(): 1} + +def test_mutation_independence(): + d1 = get_headers() + d2 = get_headers() + d1["k0"] = "changed" + assert d2["k0"] == "v0" + d2["k1"] = "changed2" + assert d1["k1"] == "v1" +def test_ineligible_not_folded(): + # These should not be constant folded + assert get_mutable_value()["a"] == [] + try: + get_mutable_key() + except TypeError: + pass + else: + assert False + try: + get_nested_mutable() + except TypeError: + pass + else: + assert False + try: + get_unhashable_key() + except TypeError: + pass + else: + assert False + # User-defined class instance as key should not be folded + d = get_user_defined() + assert list(d.values()) == [1] +def test_eligible_folded(): + assert get_headers() == {"k0": "v0", "k1": "v1"} + assert get_numbers() == {1: 2, 3: 4} + assert get_tuple_key() == {(1, 2): "a", (3, 4): "b"} + assert get_frozenset_key() == {frozenset({1, 2}): "x"} + assert get_nested_tuple() == {(1, (2, 3)): "ok"} + assert get_bool_none() == {True: None, False: 1} + assert get_bytes() == {b"x": b"y"} + assert get_tuple_of_frozenset() == {(1, frozenset({2, 3})): "ok"} + assert get_tuple_of_tuple() == {(1, (2, (3, 4))): "ok"} + assert get_frozenset_of_tuple() == {frozenset({(1, 2), (3, 4)}): "ok"} + +[case testDictLiteralIsImmutable] +def get_flat(): + return {"x": 1, "y": 2} +def get_nested(): + return {"a": {"b": 1}} +def test_shallow_mutation_independence(): + d1 = get_flat() + d2 = get_flat() + d1["x"] = 99 + assert d2["x"] == 1 + d2["y"] = 42 + assert d1["y"] == 2 +def test_deep_mutation_independence(): + d1 = get_nested() + d2 = get_nested() + d1["a"]["b"] = 99 + assert d2["a"]["b"] == 1 + d2["a"]["b"] = 42 + assert d1["a"]["b"] == 99