diff --git a/mypyc/analysis/ircheck.py b/mypyc/analysis/ircheck.py index 6980c9cee419..0af0c4b0ec28 100644 --- a/mypyc/analysis/ircheck.py +++ b/mypyc/analysis/ircheck.py @@ -197,7 +197,10 @@ def can_coerce_to(src: RType, dest: RType) -> bool: if isinstance(src, RPrimitive): # If either src or dest is a disjoint type, then they must both be. if src.name in disjoint_types and dest.name in disjoint_types: - return src.name == dest.name + return src.name == dest.name or ( + src.name in ("builtins.dict", "builtins.dict[exact]") + and dest.name in ("builtins.dict", "builtins.dict[exact]") + ) return src.size == dest.size if isinstance(src, RInstance): return is_object_rprimitive(dest) diff --git a/mypyc/ir/rtypes.py b/mypyc/ir/rtypes.py index 34824a59cd5c..7cb3d59e2ea9 100644 --- a/mypyc/ir/rtypes.py +++ b/mypyc/ir/rtypes.py @@ -487,7 +487,12 @@ def __hash__(self) -> int: "builtins.list", is_unboxed=False, is_refcounted=True, may_be_immortal=False ) -# Python dict object (or an instance of a subclass of dict). +# Python dict object. +exact_dict_rprimitive: Final = RPrimitive( + "builtins.dict[exact]", is_unboxed=False, is_refcounted=True +) + +# An instance of a subclass of dict. dict_rprimitive: Final = RPrimitive("builtins.dict", is_unboxed=False, is_refcounted=True) # Python set object (or an instance of a subclass of set). @@ -608,7 +613,14 @@ def is_list_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]: def is_dict_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]: - return isinstance(rtype, RPrimitive) and rtype.name == "builtins.dict" + return isinstance(rtype, RPrimitive) and rtype.name in ( + "builtins.dict", + "builtins.dict[exact]", + ) + + +def is_exact_dict_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]: + return isinstance(rtype, RPrimitive) and rtype.name == "builtins.dict[exact]" def is_set_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]: diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py index 8e6e450c64dc..63a48b6e2906 100644 --- a/mypyc/primitives/misc_ops.py +++ b/mypyc/primitives/misc_ops.py @@ -13,6 +13,7 @@ c_pyssize_t_rprimitive, cstring_rprimitive, dict_rprimitive, + exact_dict_rprimitive, float_rprimitive, int_rprimitive, none_rprimitive, @@ -161,7 +162,7 @@ # Get the sys.modules dictionary get_module_dict_op = custom_op( arg_types=[], - return_type=dict_rprimitive, + return_type=exact_dict_rprimitive, c_function_name="PyImport_GetModuleDict", error_kind=ERR_NEVER, is_borrowed=True, diff --git a/mypyc/rt_subtype.py b/mypyc/rt_subtype.py index 004e56ed75bc..01619158a954 100644 --- a/mypyc/rt_subtype.py +++ b/mypyc/rt_subtype.py @@ -27,6 +27,8 @@ RVoid, is_bit_rprimitive, is_bool_rprimitive, + is_dict_rprimitive, + is_exact_dict_rprimitive, is_int_rprimitive, is_short_int_rprimitive, ) @@ -58,6 +60,8 @@ def visit_rprimitive(self, left: RPrimitive) -> bool: return True if is_bit_rprimitive(left) and is_bool_rprimitive(self.right): return True + if is_exact_dict_rprimitive(left) and is_dict_rprimitive(self.right): + return True return left is self.right def visit_rtuple(self, left: RTuple) -> bool: diff --git a/mypyc/subtype.py b/mypyc/subtype.py index 726a48d7a01d..6feb4b83b5cf 100644 --- a/mypyc/subtype.py +++ b/mypyc/subtype.py @@ -14,6 +14,8 @@ RVoid, is_bit_rprimitive, is_bool_rprimitive, + is_dict_rprimitive, + is_exact_dict_rprimitive, is_fixed_width_rtype, is_int_rprimitive, is_object_rprimitive, @@ -67,6 +69,9 @@ def visit_rprimitive(self, left: RPrimitive) -> bool: elif is_fixed_width_rtype(left): if is_int_rprimitive(right): return True + elif is_exact_dict_rprimitive(left): + if is_dict_rprimitive(right): + return True return left is right def visit_rtuple(self, left: RTuple) -> bool: diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index 612f3266fd79..340d1c230031 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -3290,7 +3290,7 @@ def root(): r4 :: str r5 :: object r6 :: str - r7 :: dict + r7 :: dict[exact] r8 :: str r9 :: object r10 :: i32 @@ -3301,7 +3301,7 @@ def root(): r16 :: str r17 :: object r18 :: str - r19 :: dict + r19 :: dict[exact] r20 :: str r21 :: object r22 :: i32 @@ -3347,12 +3347,12 @@ def submodule(): r4 :: str r5 :: object r6 :: str - r7 :: dict + r7 :: dict[exact] r8 :: str r9 :: object r10 :: i32 r11 :: bit - r12 :: dict + r12 :: dict[exact] r13 :: str r14 :: object r15 :: str