Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mypyc/annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def function_annotations(func_ir: FuncIR, tree: MypyFile) -> dict[int, list[Anno
ann = "Dynamic method call."
elif name in op_hints:
ann = op_hints[name]
elif name in ("CPyDict_GetItem", "CPyDict_SetItem"):
elif name in ("CPyDict_GetItemUnsafe", "CPyDict_SetItem"):
if (
isinstance(op.args[0], LoadStatic)
and isinstance(op.args[1], LoadLiteral)
Expand Down
6 changes: 3 additions & 3 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@
)
from mypyc.irbuild.util import bytes_from_str, is_constant
from mypyc.options import CompilerOptions
from mypyc.primitives.dict_ops import dict_get_item_op, dict_set_item_op
from mypyc.primitives.dict_ops import dict_set_item_op, exact_dict_get_item_op
from mypyc.primitives.generic_ops import iter_op, next_op, py_setattr_op
from mypyc.primitives.list_ops import list_get_item_unsafe_op, list_pop_last, to_list
from mypyc.primitives.misc_ops import check_unpack_count_op, get_module_dict_op, import_op
Expand Down Expand Up @@ -472,7 +472,7 @@ def get_module(self, module: str, line: int) -> Value:
# Python 3.7 has a nice 'PyImport_GetModule' function that we can't use :(
mod_dict = self.call_c(get_module_dict_op, [], line)
# Get module object from modules dict.
return self.primitive_op(dict_get_item_op, [mod_dict, self.load_str(module)], line)
return self.primitive_op(exact_dict_get_item_op, [mod_dict, self.load_str(module)], line)

def get_module_attr(self, module: str, attr: str, line: int) -> Value:
"""Look up an attribute of a module without storing it in the local namespace.
Expand Down Expand Up @@ -1406,7 +1406,7 @@ def load_global(self, expr: NameExpr) -> Value:
def load_global_str(self, name: str, line: int) -> Value:
_globals = self.load_globals_dict()
reg = self.load_str(name)
return self.primitive_op(dict_get_item_op, [_globals, reg], line)
return self.primitive_op(exact_dict_get_item_op, [_globals, reg], line)

def load_globals_dict(self) -> Value:
return self.add(LoadStatic(dict_rprimitive, "globals", self.module_name))
Expand Down
4 changes: 2 additions & 2 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
translate_object_setattr,
)
from mypyc.primitives.bytes_ops import bytes_slice_op
from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, exact_dict_set_item_op
from mypyc.primitives.dict_ops import dict_new_op, exact_dict_get_item_op, exact_dict_set_item_op
from mypyc.primitives.generic_ops import iter_op, name_op
from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op
from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op
Expand Down Expand Up @@ -190,7 +190,7 @@ def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value:
# instead load the module separately on each access.
mod_dict = builder.call_c(get_module_dict_op, [], expr.line)
obj = builder.primitive_op(
dict_get_item_op, [mod_dict, builder.load_str(expr.node.fullname)], expr.line
exact_dict_get_item_op, [mod_dict, builder.load_str(expr.node.fullname)], expr.line
)
return obj
else:
Expand Down
3 changes: 3 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,9 @@ tuple_T3CIO CPyDict_NextValue(PyObject *dict_or_iter, CPyTagged offset);
tuple_T4CIOO CPyDict_NextItem(PyObject *dict_or_iter, CPyTagged offset);
int CPyMapping_Check(PyObject *obj);

// Unsafe dict operations (assume PyDict_CheckExact(dict) is always true)
PyObject *CPyDict_GetItemUnsafe(PyObject *dict, PyObject *key);

// Check that dictionary didn't change size during iteration.
static inline char CPyDict_CheckSize(PyObject *dict, Py_ssize_t size) {
if (!PyDict_CheckExact(dict)) {
Expand Down
29 changes: 20 additions & 9 deletions mypyc/lib-rt/dict_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,7 @@
// some indirections.
PyObject *CPyDict_GetItem(PyObject *dict, PyObject *key) {
if (PyDict_CheckExact(dict)) {
PyObject *res = PyDict_GetItemWithError(dict, key);
if (!res) {
if (!PyErr_Occurred()) {
PyErr_SetObject(PyExc_KeyError, key);
}
} else {
Py_INCREF(res);
}
return res;
return CPyDict_GetItemUnsafe(dict, key);
} else {
return PyObject_GetItem(dict, key);
}
Expand Down Expand Up @@ -489,3 +481,22 @@ tuple_T4CIOO CPyDict_NextItem(PyObject *dict_or_iter, CPyTagged offset) {
int CPyMapping_Check(PyObject *obj) {
return Py_TYPE(obj)->tp_flags & Py_TPFLAGS_MAPPING;
}

// =======================
// Unsafe dict operations
// =======================

// Unsafe: assumes dict is a true dict (PyDict_CheckExact(dict) is always true)

PyObject *CPyDict_GetItemUnsafe(PyObject *dict, PyObject *key) {
// No type check, direct call
PyObject *res = PyDict_GetItemWithError(dict, key);
if (!res) {
if (!PyErr_Occurred()) {
PyErr_SetObject(PyExc_KeyError, key);
}
} else {
Py_INCREF(res);
}
return res;
}
13 changes: 12 additions & 1 deletion mypyc/primitives/dict_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ERR_NEG_INT,
binary_op,
custom_op,
custom_primitive_op,
function_op,
load_address_op,
method_op,
Expand Down Expand Up @@ -80,7 +81,7 @@
error_kind=ERR_NEVER,
)

# dict[key]
# dict[key] = value
dict_get_item_op = method_op(
name="__getitem__",
arg_types=[dict_rprimitive, object_rprimitive],
Expand All @@ -89,6 +90,16 @@
error_kind=ERR_MAGIC,
)

# dict[key] = value (exact dict only, no subclasses)
# NOTE: this is currently for internal use only, and not used for CallExpr specialization
exact_dict_get_item_op = custom_primitive_op(
name="__getitem__",
arg_types=[dict_rprimitive, object_rprimitive],
return_type=object_rprimitive,
c_function_name="CPyDict_GetItemUnsafe",
error_kind=ERR_MAGIC,
)

# dict[key] = value
dict_set_item_op = method_op(
name="__setitem__",
Expand Down
50 changes: 25 additions & 25 deletions mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ def f(x):
L0:
r0 = __main__.globals :: static
r1 = 'g'
r2 = CPyDict_GetItem(r0, r1)
r2 = CPyDict_GetItemUnsafe(r0, r1)
r3 = box(int, x)
r4 = [r3]
r5 = load_address r4
Expand All @@ -631,13 +631,13 @@ L0:
r7 = unbox(int, r6)
r8 = __main__.globals :: static
r9 = 'h'
r10 = CPyDict_GetItem(r8, r9)
r10 = CPyDict_GetItemUnsafe(r8, r9)
r11 = PyObject_Vectorcall(r10, 0, 0, 0)
r12 = unbox(int, r11)
r13 = CPyTagged_Add(r7, r12)
r14 = __main__.globals :: static
r15 = 'two'
r16 = CPyDict_GetItem(r14, r15)
r16 = CPyDict_GetItemUnsafe(r14, r15)
r17 = PyObject_Vectorcall(r16, 0, 0, 0)
r18 = unbox(int, r17)
r19 = CPyTagged_Add(r13, r18)
Expand Down Expand Up @@ -1147,7 +1147,7 @@ def call_python_function(x):
L0:
r0 = __main__.globals :: static
r1 = 'f'
r2 = CPyDict_GetItem(r0, r1)
r2 = CPyDict_GetItemUnsafe(r0, r1)
r3 = box(int, x)
r4 = [r3]
r5 = load_address r4
Expand All @@ -1165,7 +1165,7 @@ def return_callable_type():
L0:
r0 = __main__.globals :: static
r1 = 'return_float'
r2 = CPyDict_GetItem(r0, r1)
r2 = CPyDict_GetItemUnsafe(r0, r1)
return r2
def call_callable_type():
r0, f, r1 :: object
Expand Down Expand Up @@ -1436,7 +1436,7 @@ def f():
L0:
r0 = __main__.globals :: static
r1 = 'x'
r2 = CPyDict_GetItem(r0, r1)
r2 = CPyDict_GetItemUnsafe(r0, r1)
r3 = unbox(int, r2)
r4 = builtins :: module
r5 = 'print'
Expand Down Expand Up @@ -1484,7 +1484,7 @@ L2:
r9 = r8 >= 0 :: signed
r10 = __main__.globals :: static
r11 = 'x'
r12 = CPyDict_GetItem(r10, r11)
r12 = CPyDict_GetItemUnsafe(r10, r11)
r13 = unbox(int, r12)
r14 = builtins :: module
r15 = 'print'
Expand Down Expand Up @@ -1680,7 +1680,7 @@ L0:
r0 = (2, 4, 6)
r1 = __main__.globals :: static
r2 = 'f'
r3 = CPyDict_GetItem(r1, r2)
r3 = CPyDict_GetItemUnsafe(r1, r2)
r4 = box(tuple[int, int, int], r0)
r5 = PyObject_CallObject(r3, r4)
r6 = unbox(tuple[int, int, int], r5)
Expand All @@ -1701,7 +1701,7 @@ L0:
r0 = (4, 6)
r1 = __main__.globals :: static
r2 = 'f'
r3 = CPyDict_GetItem(r1, r2)
r3 = CPyDict_GetItemUnsafe(r1, r2)
r4 = PyList_New(1)
r5 = object 1
r6 = list_items r4
Expand Down Expand Up @@ -1749,7 +1749,7 @@ L0:
r6 = CPyDict_Build(3, r0, r3, r1, r4, r2, r5)
r7 = __main__.globals :: static
r8 = 'f'
r9 = CPyDict_GetItem(r7, r8)
r9 = CPyDict_GetItemUnsafe(r7, r8)
r10 = CPyTuple_LoadEmptyTupleConstant()
r11 = PyDict_Copy(r6)
r12 = PyObject_Call(r9, r10, r11)
Expand All @@ -1776,7 +1776,7 @@ L0:
r4 = CPyDict_Build(2, r0, r2, r1, r3)
r5 = __main__.globals :: static
r6 = 'f'
r7 = CPyDict_GetItem(r5, r6)
r7 = CPyDict_GetItemUnsafe(r5, r6)
r8 = PyDict_New()
r9 = CPyDict_UpdateInDisplay(r8, r4)
r10 = r9 >= 0 :: signed
Expand Down Expand Up @@ -2239,7 +2239,7 @@ L2:
r19 = box(tuple[object, object], r18)
r20 = __main__.globals :: static
r21 = 'NamedTuple'
r22 = CPyDict_GetItem(r20, r21)
r22 = CPyDict_GetItemUnsafe(r20, r21)
r23 = [r9, r19]
r24 = load_address r23
r25 = PyObject_Vectorcall(r22, r24, 2, 0)
Expand All @@ -2251,7 +2251,7 @@ L2:
r30 = ''
r31 = __main__.globals :: static
r32 = 'Lol'
r33 = CPyDict_GetItem(r31, r32)
r33 = CPyDict_GetItemUnsafe(r31, r32)
r34 = object 1
r35 = [r34, r30]
r36 = load_address r35
Expand All @@ -2264,7 +2264,7 @@ L2:
r42 = r41 >= 0 :: signed
r43 = __main__.globals :: static
r44 = 'List'
r45 = CPyDict_GetItem(r43, r44)
r45 = CPyDict_GetItemUnsafe(r43, r44)
r46 = load_address PyLong_Type
r47 = PyObject_GetItem(r45, r46)
r48 = __main__.globals :: static
Expand All @@ -2274,10 +2274,10 @@ L2:
r52 = 'Bar'
r53 = __main__.globals :: static
r54 = 'Foo'
r55 = CPyDict_GetItem(r53, r54)
r55 = CPyDict_GetItemUnsafe(r53, r54)
r56 = __main__.globals :: static
r57 = 'NewType'
r58 = CPyDict_GetItem(r56, r57)
r58 = CPyDict_GetItemUnsafe(r56, r57)
r59 = [r52, r55]
r60 = load_address r59
r61 = PyObject_Vectorcall(r58, r60, 2, 0)
Expand Down Expand Up @@ -2610,14 +2610,14 @@ L0:
r1.__mypyc_env__ = r0; r2 = is_error
r3 = __main__.globals :: static
r4 = 'b'
r5 = CPyDict_GetItem(r3, r4)
r5 = CPyDict_GetItemUnsafe(r3, r4)
r6 = [r1]
r7 = load_address r6
r8 = PyObject_Vectorcall(r5, r7, 1, 0)
keep_alive r1
r9 = __main__.globals :: static
r10 = 'a'
r11 = CPyDict_GetItem(r9, r10)
r11 = CPyDict_GetItemUnsafe(r9, r10)
r12 = [r8]
r13 = load_address r12
r14 = PyObject_Vectorcall(r11, r13, 1, 0)
Expand Down Expand Up @@ -2681,17 +2681,17 @@ L2:
typing = r8 :: module
r9 = __main__.globals :: static
r10 = 'c'
r11 = CPyDict_GetItem(r9, r10)
r11 = CPyDict_GetItemUnsafe(r9, r10)
r12 = __main__.globals :: static
r13 = 'b'
r14 = CPyDict_GetItem(r12, r13)
r14 = CPyDict_GetItemUnsafe(r12, r13)
r15 = [r11]
r16 = load_address r15
r17 = PyObject_Vectorcall(r14, r16, 1, 0)
keep_alive r11
r18 = __main__.globals :: static
r19 = 'a'
r20 = CPyDict_GetItem(r18, r19)
r20 = CPyDict_GetItemUnsafe(r18, r19)
r21 = [r17]
r22 = load_address r21
r23 = PyObject_Vectorcall(r20, r22, 1, 0)
Expand Down Expand Up @@ -3320,7 +3320,7 @@ L2:
r6 = 'dataclasses'
r7 = PyImport_GetModuleDict()
r8 = 'dataclasses'
r9 = CPyDict_GetItem(r7, r8)
r9 = CPyDict_GetItemUnsafe(r7, r8)
r10 = CPyDict_SetItem(r0, r6, r9)
r11 = r10 >= 0 :: signed
r12 = __main__.globals :: static
Expand All @@ -3336,7 +3336,7 @@ L4:
r18 = 'enum'
r19 = PyImport_GetModuleDict()
r20 = 'enum'
r21 = CPyDict_GetItem(r19, r20)
r21 = CPyDict_GetItemUnsafe(r19, r20)
r22 = CPyDict_SetItem(r12, r18, r21)
r23 = r22 >= 0 :: signed
return 1
Expand Down Expand Up @@ -3372,12 +3372,12 @@ L2:
r6 = 'p'
r7 = PyImport_GetModuleDict()
r8 = 'p'
r9 = CPyDict_GetItem(r7, r8)
r9 = CPyDict_GetItemUnsafe(r7, r8)
r10 = CPyDict_SetItem(r0, r6, r9)
r11 = r10 >= 0 :: signed
r12 = PyImport_GetModuleDict()
r13 = 'p'
r14 = CPyDict_GetItem(r12, r13)
r14 = CPyDict_GetItemUnsafe(r12, r13)
r15 = 'x'
r16 = CPyObject_GetAttr(r14, r15)
r17 = unbox(int, r16)
Expand Down
8 changes: 4 additions & 4 deletions mypyc/test-data/irbuild-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ L2:
r13 = 'T'
r14 = __main__.globals :: static
r15 = 'TypeVar'
r16 = CPyDict_GetItem(r14, r15)
r16 = CPyDict_GetItemUnsafe(r14, r15)
r17 = [r13]
r18 = load_address r17
r19 = PyObject_Vectorcall(r16, r18, 1, 0)
Expand Down Expand Up @@ -322,10 +322,10 @@ L2:
r50 = __main__.S :: type
r51 = __main__.globals :: static
r52 = 'Generic'
r53 = CPyDict_GetItem(r51, r52)
r53 = CPyDict_GetItemUnsafe(r51, r52)
r54 = __main__.globals :: static
r55 = 'T'
r56 = CPyDict_GetItem(r54, r55)
r56 = CPyDict_GetItemUnsafe(r54, r55)
r57 = PyObject_GetItem(r53, r56)
r58 = PyTuple_Pack(3, r49, r50, r57)
r59 = '__main__'
Expand Down Expand Up @@ -1073,7 +1073,7 @@ L0:
__mypyc_self__.x = 20
r0 = __main__.globals :: static
r1 = 'LOL'
r2 = CPyDict_GetItem(r0, r1)
r2 = CPyDict_GetItemUnsafe(r0, r1)
r3 = cast(str, r2)
__mypyc_self__.y = r3
r4 = box(None, 1)
Expand Down
4 changes: 2 additions & 2 deletions mypyc/test-data/irbuild-set.test
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ def not_precomputed_non_final_name(i):
L0:
r0 = __main__.globals :: static
r1 = 'non_const'
r2 = CPyDict_GetItem(r0, r1)
r2 = CPyDict_GetItemUnsafe(r0, r1)
r3 = unbox(int, r2)
r4 = PySet_New(0)
r5 = box(int, r3)
Expand Down Expand Up @@ -780,7 +780,7 @@ def not_precomputed():
L0:
r0 = __main__.globals :: static
r1 = 'non_const'
r2 = CPyDict_GetItem(r0, r1)
r2 = CPyDict_GetItemUnsafe(r0, r1)
r3 = unbox(int, r2)
r4 = PySet_New(0)
r5 = box(int, r3)
Expand Down
Loading