diff --git a/mypyc/annotate.py b/mypyc/annotate.py index bc282fc3ea6c..59975819829f 100644 --- a/mypyc/annotate.py +++ b/mypyc/annotate.py @@ -216,7 +216,7 @@ def function_annotations(func_ir: FuncIR, tree: MypyFile) -> dict[int, list[Anno ann = "Dynamic method call." elif name in op_hints: ann = op_hints[name] - elif name in ("CPyDict_GetItem", "CPyDict_SetItem"): + elif name in ("CPyDict_GetItemUnsafe", "CPyDict_SetItem"): if ( isinstance(op.args[0], LoadStatic) and isinstance(op.args[1], LoadLiteral) diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index 63930123135f..1c580b1f8936 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -126,7 +126,7 @@ ) from mypyc.irbuild.util import bytes_from_str, is_constant from mypyc.options import CompilerOptions -from mypyc.primitives.dict_ops import dict_get_item_op, dict_set_item_op +from mypyc.primitives.dict_ops import dict_set_item_op, exact_dict_get_item_op from mypyc.primitives.generic_ops import iter_op, next_op, py_setattr_op from mypyc.primitives.list_ops import list_get_item_unsafe_op, list_pop_last, to_list from mypyc.primitives.misc_ops import check_unpack_count_op, get_module_dict_op, import_op @@ -472,7 +472,7 @@ def get_module(self, module: str, line: int) -> Value: # Python 3.7 has a nice 'PyImport_GetModule' function that we can't use :( mod_dict = self.call_c(get_module_dict_op, [], line) # Get module object from modules dict. - return self.primitive_op(dict_get_item_op, [mod_dict, self.load_str(module)], line) + return self.primitive_op(exact_dict_get_item_op, [mod_dict, self.load_str(module)], line) def get_module_attr(self, module: str, attr: str, line: int) -> Value: """Look up an attribute of a module without storing it in the local namespace. @@ -1406,7 +1406,7 @@ def load_global(self, expr: NameExpr) -> Value: def load_global_str(self, name: str, line: int) -> Value: _globals = self.load_globals_dict() reg = self.load_str(name) - return self.primitive_op(dict_get_item_op, [_globals, reg], line) + return self.primitive_op(exact_dict_get_item_op, [_globals, reg], line) def load_globals_dict(self) -> Value: return self.add(LoadStatic(dict_rprimitive, "globals", self.module_name)) diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index 59ecc4ac2c5c..562c9f325bd5 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -104,7 +104,7 @@ translate_object_setattr, ) from mypyc.primitives.bytes_ops import bytes_slice_op -from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, exact_dict_set_item_op +from mypyc.primitives.dict_ops import dict_new_op, exact_dict_get_item_op, exact_dict_set_item_op from mypyc.primitives.generic_ops import iter_op, name_op from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op @@ -190,7 +190,7 @@ def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value: # instead load the module separately on each access. mod_dict = builder.call_c(get_module_dict_op, [], expr.line) obj = builder.primitive_op( - dict_get_item_op, [mod_dict, builder.load_str(expr.node.fullname)], expr.line + exact_dict_get_item_op, [mod_dict, builder.load_str(expr.node.fullname)], expr.line ) return obj else: diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index e9dfd8de3683..6287beff9c9e 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -711,6 +711,9 @@ tuple_T3CIO CPyDict_NextValue(PyObject *dict_or_iter, CPyTagged offset); tuple_T4CIOO CPyDict_NextItem(PyObject *dict_or_iter, CPyTagged offset); int CPyMapping_Check(PyObject *obj); +// Unsafe dict operations (assume PyDict_CheckExact(dict) is always true) +PyObject *CPyDict_GetItemUnsafe(PyObject *dict, PyObject *key); + // Check that dictionary didn't change size during iteration. static inline char CPyDict_CheckSize(PyObject *dict, Py_ssize_t size) { if (!PyDict_CheckExact(dict)) { diff --git a/mypyc/lib-rt/dict_ops.c b/mypyc/lib-rt/dict_ops.c index b102aba57307..9c22d34c677c 100644 --- a/mypyc/lib-rt/dict_ops.c +++ b/mypyc/lib-rt/dict_ops.c @@ -15,15 +15,7 @@ // some indirections. PyObject *CPyDict_GetItem(PyObject *dict, PyObject *key) { if (PyDict_CheckExact(dict)) { - PyObject *res = PyDict_GetItemWithError(dict, key); - if (!res) { - if (!PyErr_Occurred()) { - PyErr_SetObject(PyExc_KeyError, key); - } - } else { - Py_INCREF(res); - } - return res; + return CPyDict_GetItemUnsafe(dict, key); } else { return PyObject_GetItem(dict, key); } @@ -489,3 +481,22 @@ tuple_T4CIOO CPyDict_NextItem(PyObject *dict_or_iter, CPyTagged offset) { int CPyMapping_Check(PyObject *obj) { return Py_TYPE(obj)->tp_flags & Py_TPFLAGS_MAPPING; } + +// ======================= +// Unsafe dict operations +// ======================= + +// Unsafe: assumes dict is a true dict (PyDict_CheckExact(dict) is always true) + +PyObject *CPyDict_GetItemUnsafe(PyObject *dict, PyObject *key) { + // No type check, direct call + PyObject *res = PyDict_GetItemWithError(dict, key); + if (!res) { + if (!PyErr_Occurred()) { + PyErr_SetObject(PyExc_KeyError, key); + } + } else { + Py_INCREF(res); + } + return res; +} diff --git a/mypyc/primitives/dict_ops.py b/mypyc/primitives/dict_ops.py index f98bcc8ac2ec..c5106a6ef9ea 100644 --- a/mypyc/primitives/dict_ops.py +++ b/mypyc/primitives/dict_ops.py @@ -19,6 +19,7 @@ ERR_NEG_INT, binary_op, custom_op, + custom_primitive_op, function_op, load_address_op, method_op, @@ -80,7 +81,7 @@ error_kind=ERR_NEVER, ) -# dict[key] +# dict[key] = value dict_get_item_op = method_op( name="__getitem__", arg_types=[dict_rprimitive, object_rprimitive], @@ -89,6 +90,16 @@ error_kind=ERR_MAGIC, ) +# dict[key] = value (exact dict only, no subclasses) +# NOTE: this is currently for internal use only, and not used for CallExpr specialization +exact_dict_get_item_op = custom_primitive_op( + name="__getitem__", + arg_types=[dict_rprimitive, object_rprimitive], + return_type=object_rprimitive, + c_function_name="CPyDict_GetItemUnsafe", + error_kind=ERR_MAGIC, +) + # dict[key] = value dict_set_item_op = method_op( name="__setitem__", diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index 612f3266fd79..c42b9736d6df 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -622,7 +622,7 @@ def f(x): L0: r0 = __main__.globals :: static r1 = 'g' - r2 = CPyDict_GetItem(r0, r1) + r2 = CPyDict_GetItemUnsafe(r0, r1) r3 = box(int, x) r4 = [r3] r5 = load_address r4 @@ -631,13 +631,13 @@ L0: r7 = unbox(int, r6) r8 = __main__.globals :: static r9 = 'h' - r10 = CPyDict_GetItem(r8, r9) + r10 = CPyDict_GetItemUnsafe(r8, r9) r11 = PyObject_Vectorcall(r10, 0, 0, 0) r12 = unbox(int, r11) r13 = CPyTagged_Add(r7, r12) r14 = __main__.globals :: static r15 = 'two' - r16 = CPyDict_GetItem(r14, r15) + r16 = CPyDict_GetItemUnsafe(r14, r15) r17 = PyObject_Vectorcall(r16, 0, 0, 0) r18 = unbox(int, r17) r19 = CPyTagged_Add(r13, r18) @@ -1147,7 +1147,7 @@ def call_python_function(x): L0: r0 = __main__.globals :: static r1 = 'f' - r2 = CPyDict_GetItem(r0, r1) + r2 = CPyDict_GetItemUnsafe(r0, r1) r3 = box(int, x) r4 = [r3] r5 = load_address r4 @@ -1165,7 +1165,7 @@ def return_callable_type(): L0: r0 = __main__.globals :: static r1 = 'return_float' - r2 = CPyDict_GetItem(r0, r1) + r2 = CPyDict_GetItemUnsafe(r0, r1) return r2 def call_callable_type(): r0, f, r1 :: object @@ -1436,7 +1436,7 @@ def f(): L0: r0 = __main__.globals :: static r1 = 'x' - r2 = CPyDict_GetItem(r0, r1) + r2 = CPyDict_GetItemUnsafe(r0, r1) r3 = unbox(int, r2) r4 = builtins :: module r5 = 'print' @@ -1484,7 +1484,7 @@ L2: r9 = r8 >= 0 :: signed r10 = __main__.globals :: static r11 = 'x' - r12 = CPyDict_GetItem(r10, r11) + r12 = CPyDict_GetItemUnsafe(r10, r11) r13 = unbox(int, r12) r14 = builtins :: module r15 = 'print' @@ -1680,7 +1680,7 @@ L0: r0 = (2, 4, 6) r1 = __main__.globals :: static r2 = 'f' - r3 = CPyDict_GetItem(r1, r2) + r3 = CPyDict_GetItemUnsafe(r1, r2) r4 = box(tuple[int, int, int], r0) r5 = PyObject_CallObject(r3, r4) r6 = unbox(tuple[int, int, int], r5) @@ -1701,7 +1701,7 @@ L0: r0 = (4, 6) r1 = __main__.globals :: static r2 = 'f' - r3 = CPyDict_GetItem(r1, r2) + r3 = CPyDict_GetItemUnsafe(r1, r2) r4 = PyList_New(1) r5 = object 1 r6 = list_items r4 @@ -1749,7 +1749,7 @@ L0: r6 = CPyDict_Build(3, r0, r3, r1, r4, r2, r5) r7 = __main__.globals :: static r8 = 'f' - r9 = CPyDict_GetItem(r7, r8) + r9 = CPyDict_GetItemUnsafe(r7, r8) r10 = CPyTuple_LoadEmptyTupleConstant() r11 = PyDict_Copy(r6) r12 = PyObject_Call(r9, r10, r11) @@ -1776,7 +1776,7 @@ L0: r4 = CPyDict_Build(2, r0, r2, r1, r3) r5 = __main__.globals :: static r6 = 'f' - r7 = CPyDict_GetItem(r5, r6) + r7 = CPyDict_GetItemUnsafe(r5, r6) r8 = PyDict_New() r9 = CPyDict_UpdateInDisplay(r8, r4) r10 = r9 >= 0 :: signed @@ -2239,7 +2239,7 @@ L2: r19 = box(tuple[object, object], r18) r20 = __main__.globals :: static r21 = 'NamedTuple' - r22 = CPyDict_GetItem(r20, r21) + r22 = CPyDict_GetItemUnsafe(r20, r21) r23 = [r9, r19] r24 = load_address r23 r25 = PyObject_Vectorcall(r22, r24, 2, 0) @@ -2251,7 +2251,7 @@ L2: r30 = '' r31 = __main__.globals :: static r32 = 'Lol' - r33 = CPyDict_GetItem(r31, r32) + r33 = CPyDict_GetItemUnsafe(r31, r32) r34 = object 1 r35 = [r34, r30] r36 = load_address r35 @@ -2264,7 +2264,7 @@ L2: r42 = r41 >= 0 :: signed r43 = __main__.globals :: static r44 = 'List' - r45 = CPyDict_GetItem(r43, r44) + r45 = CPyDict_GetItemUnsafe(r43, r44) r46 = load_address PyLong_Type r47 = PyObject_GetItem(r45, r46) r48 = __main__.globals :: static @@ -2274,10 +2274,10 @@ L2: r52 = 'Bar' r53 = __main__.globals :: static r54 = 'Foo' - r55 = CPyDict_GetItem(r53, r54) + r55 = CPyDict_GetItemUnsafe(r53, r54) r56 = __main__.globals :: static r57 = 'NewType' - r58 = CPyDict_GetItem(r56, r57) + r58 = CPyDict_GetItemUnsafe(r56, r57) r59 = [r52, r55] r60 = load_address r59 r61 = PyObject_Vectorcall(r58, r60, 2, 0) @@ -2610,14 +2610,14 @@ L0: r1.__mypyc_env__ = r0; r2 = is_error r3 = __main__.globals :: static r4 = 'b' - r5 = CPyDict_GetItem(r3, r4) + r5 = CPyDict_GetItemUnsafe(r3, r4) r6 = [r1] r7 = load_address r6 r8 = PyObject_Vectorcall(r5, r7, 1, 0) keep_alive r1 r9 = __main__.globals :: static r10 = 'a' - r11 = CPyDict_GetItem(r9, r10) + r11 = CPyDict_GetItemUnsafe(r9, r10) r12 = [r8] r13 = load_address r12 r14 = PyObject_Vectorcall(r11, r13, 1, 0) @@ -2681,17 +2681,17 @@ L2: typing = r8 :: module r9 = __main__.globals :: static r10 = 'c' - r11 = CPyDict_GetItem(r9, r10) + r11 = CPyDict_GetItemUnsafe(r9, r10) r12 = __main__.globals :: static r13 = 'b' - r14 = CPyDict_GetItem(r12, r13) + r14 = CPyDict_GetItemUnsafe(r12, r13) r15 = [r11] r16 = load_address r15 r17 = PyObject_Vectorcall(r14, r16, 1, 0) keep_alive r11 r18 = __main__.globals :: static r19 = 'a' - r20 = CPyDict_GetItem(r18, r19) + r20 = CPyDict_GetItemUnsafe(r18, r19) r21 = [r17] r22 = load_address r21 r23 = PyObject_Vectorcall(r20, r22, 1, 0) @@ -3320,7 +3320,7 @@ L2: r6 = 'dataclasses' r7 = PyImport_GetModuleDict() r8 = 'dataclasses' - r9 = CPyDict_GetItem(r7, r8) + r9 = CPyDict_GetItemUnsafe(r7, r8) r10 = CPyDict_SetItem(r0, r6, r9) r11 = r10 >= 0 :: signed r12 = __main__.globals :: static @@ -3336,7 +3336,7 @@ L4: r18 = 'enum' r19 = PyImport_GetModuleDict() r20 = 'enum' - r21 = CPyDict_GetItem(r19, r20) + r21 = CPyDict_GetItemUnsafe(r19, r20) r22 = CPyDict_SetItem(r12, r18, r21) r23 = r22 >= 0 :: signed return 1 @@ -3372,12 +3372,12 @@ L2: r6 = 'p' r7 = PyImport_GetModuleDict() r8 = 'p' - r9 = CPyDict_GetItem(r7, r8) + r9 = CPyDict_GetItemUnsafe(r7, r8) r10 = CPyDict_SetItem(r0, r6, r9) r11 = r10 >= 0 :: signed r12 = PyImport_GetModuleDict() r13 = 'p' - r14 = CPyDict_GetItem(r12, r13) + r14 = CPyDict_GetItemUnsafe(r12, r13) r15 = 'x' r16 = CPyObject_GetAttr(r14, r15) r17 = unbox(int, r16) diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test index a2d3b23ccfd9..40d356e3a138 100644 --- a/mypyc/test-data/irbuild-classes.test +++ b/mypyc/test-data/irbuild-classes.test @@ -282,7 +282,7 @@ L2: r13 = 'T' r14 = __main__.globals :: static r15 = 'TypeVar' - r16 = CPyDict_GetItem(r14, r15) + r16 = CPyDict_GetItemUnsafe(r14, r15) r17 = [r13] r18 = load_address r17 r19 = PyObject_Vectorcall(r16, r18, 1, 0) @@ -322,10 +322,10 @@ L2: r50 = __main__.S :: type r51 = __main__.globals :: static r52 = 'Generic' - r53 = CPyDict_GetItem(r51, r52) + r53 = CPyDict_GetItemUnsafe(r51, r52) r54 = __main__.globals :: static r55 = 'T' - r56 = CPyDict_GetItem(r54, r55) + r56 = CPyDict_GetItemUnsafe(r54, r55) r57 = PyObject_GetItem(r53, r56) r58 = PyTuple_Pack(3, r49, r50, r57) r59 = '__main__' @@ -1073,7 +1073,7 @@ L0: __mypyc_self__.x = 20 r0 = __main__.globals :: static r1 = 'LOL' - r2 = CPyDict_GetItem(r0, r1) + r2 = CPyDict_GetItemUnsafe(r0, r1) r3 = cast(str, r2) __mypyc_self__.y = r3 r4 = box(None, 1) diff --git a/mypyc/test-data/irbuild-set.test b/mypyc/test-data/irbuild-set.test index 5586a2bf4cfb..505ff28cec4a 100644 --- a/mypyc/test-data/irbuild-set.test +++ b/mypyc/test-data/irbuild-set.test @@ -661,7 +661,7 @@ def not_precomputed_non_final_name(i): L0: r0 = __main__.globals :: static r1 = 'non_const' - r2 = CPyDict_GetItem(r0, r1) + r2 = CPyDict_GetItemUnsafe(r0, r1) r3 = unbox(int, r2) r4 = PySet_New(0) r5 = box(int, r3) @@ -780,7 +780,7 @@ def not_precomputed(): L0: r0 = __main__.globals :: static r1 = 'non_const' - r2 = CPyDict_GetItem(r0, r1) + r2 = CPyDict_GetItemUnsafe(r0, r1) r3 = unbox(int, r2) r4 = PySet_New(0) r5 = box(int, r3) diff --git a/mypyc/test-data/irbuild-singledispatch.test b/mypyc/test-data/irbuild-singledispatch.test index 1060ee63c57d..205b1431f341 100644 --- a/mypyc/test-data/irbuild-singledispatch.test +++ b/mypyc/test-data/irbuild-singledispatch.test @@ -131,7 +131,7 @@ def f(arg): L0: r0 = __main__.globals :: static r1 = 'f' - r2 = CPyDict_GetItem(r0, r1) + r2 = CPyDict_GetItemUnsafe(r0, r1) r3 = f_obj.__call__(r2, arg) return r3 def g(arg): @@ -262,7 +262,7 @@ def f(x): L0: r0 = __main__.globals :: static r1 = 'f' - r2 = CPyDict_GetItem(r0, r1) + r2 = CPyDict_GetItemUnsafe(r0, r1) r3 = f_obj.__call__(r2, x) return r3 def test():