diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index f652449f5289..b490c2a52e57 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Callable, Optional +from typing import Callable, Final, Optional from mypy.nodes import ( ARG_NAMED, @@ -89,7 +89,7 @@ dict_setdefault_spec_init_op, dict_values_op, ) -from mypyc.primitives.list_ops import new_list_set_item_op +from mypyc.primitives.list_ops import isinstance_list, new_list_set_item_op from mypyc.primitives.str_ops import ( str_encode_ascii_strict, str_encode_latin1_strict, @@ -546,6 +546,9 @@ def gen_inner_stmts() -> None: return retval +isinstance_primitives: Final = {"builtins.list": isinstance_list} + + @specialize_function("builtins.isinstance") def translate_isinstance(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None: """Special case for builtins.isinstance. @@ -554,11 +557,10 @@ def translate_isinstance(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> there is no need to coerce something to a new type before checking what type it is, and the coercion could lead to bugs. """ - if ( - len(expr.args) == 2 - and expr.arg_kinds == [ARG_POS, ARG_POS] - and isinstance(expr.args[1], (RefExpr, TupleExpr)) - ): + if not (len(expr.args) == 2 and expr.arg_kinds == [ARG_POS, ARG_POS]): + return None + + if isinstance(expr.args[1], (RefExpr, TupleExpr)): builder.types[expr.args[0]] = AnyType(TypeOfAny.from_error) irs = builder.flatten_classes(expr.args[1]) @@ -569,6 +571,15 @@ def translate_isinstance(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> ) obj = builder.accept(expr.args[0], can_borrow=can_borrow) return builder.builder.isinstance_helper(obj, irs, expr.line) + + if isinstance(expr.args[1], RefExpr): + node = expr.args[1].node + if node: + desc = isinstance_primitives.get(node.fullname) + if desc: + obj = builder.accept(expr.args[0]) + return builder.primitive_op(desc, [obj], expr.line) + return None diff --git a/mypyc/primitives/list_ops.py b/mypyc/primitives/list_ops.py index d0e0af9f987f..7442e31c9118 100644 --- a/mypyc/primitives/list_ops.py +++ b/mypyc/primitives/list_ops.py @@ -55,6 +55,15 @@ extra_int_constants=[(0, int_rprimitive)], ) +# isinstance(obj, list) +isinstance_list = function_op( + name="builtins.isinstance", + arg_types=[object_rprimitive], + return_type=bit_rprimitive, + c_function_name="PyList_Check", + error_kind=ERR_NEVER, +) + new_list_op = custom_op( arg_types=[c_pyssize_t_rprimitive], return_type=list_rprimitive, diff --git a/mypyc/test-data/irbuild-lists.test b/mypyc/test-data/irbuild-lists.test index 72caa5fad8d8..efd38870974d 100644 --- a/mypyc/test-data/irbuild-lists.test +++ b/mypyc/test-data/irbuild-lists.test @@ -498,29 +498,23 @@ def nested_union(a: Union[List[str], List[Optional[str]]]) -> None: [out] def narrow(a): a :: union[list, int] - r0 :: object - r1 :: i32 - r2 :: bit - r3 :: bool - r4 :: list - r5 :: native_int - r6 :: short_int - r7 :: int + r0 :: bit + r1 :: list + r2 :: native_int + r3 :: short_int + r4 :: int L0: - r0 = load_address PyList_Type - r1 = PyObject_IsInstance(a, r0) - r2 = r1 >= 0 :: signed - r3 = truncate r1: i32 to builtins.bool - if r3 goto L1 else goto L2 :: bool + r0 = PyList_Check(a) + if r0 goto L1 else goto L2 :: bool L1: - r4 = borrow cast(list, a) - r5 = var_object_size r4 - r6 = r5 << 1 + r1 = borrow cast(list, a) + r2 = var_object_size r1 + r3 = r2 << 1 keep_alive a - return r6 + return r3 L2: - r7 = unbox(int, a) - return r7 + r4 = unbox(int, a) + return r4 def loop(a): a :: list r0 :: short_int diff --git a/mypyc/test-data/run-lists.test b/mypyc/test-data/run-lists.test index 85e0926027c5..ee1bd27e6352 100644 --- a/mypyc/test-data/run-lists.test +++ b/mypyc/test-data/run-lists.test @@ -466,7 +466,7 @@ assert not list_in_mixed(object) assert list_in_mixed(type) [case testListBuiltFromGenerator] -def test() -> None: +def test_from_gen() -> None: source_a = ["a", "b", "c"] a = list(x + "f2" for x in source_a) assert a == ["af2", "bf2", "cf2"] @@ -486,12 +486,6 @@ def test() -> None: f = list("str:" + x for x in source_str) assert f == ["str:a", "str:b", "str:c", "str:d"] -[case testNextBug] -from typing import List, Optional - -def test(x: List[int]) -> None: - res = next((i for i in x), None) - [case testListGetItemWithBorrow] from typing import List @@ -537,3 +531,35 @@ def test_sorted() -> None: assert sorted((2, 1, 3)) == res assert sorted({2, 1, 3}) == res assert sorted({2: "", 1: "", 3: ""}) == res + +[case testIsInstance] +from copysubclass import subc +def test_built_in() -> None: + assert isinstance([], list) + assert isinstance([1,2,3], list) + assert isinstance(['a','b'], list) + assert isinstance(subc(), list) + assert isinstance(subc([1,2,3]), list) + assert isinstance(subc(['a','b']), list) + + assert not isinstance({}, list) + assert not isinstance((), list) + assert not isinstance((1,2,3), list) + assert not isinstance(('a','b'), list) + assert not isinstance(1, list) + assert not isinstance('a', list) + +def test_user_defined() -> None: + from userdefinedlist import list + + assert isinstance(list(), list) + assert not isinstance([list()], list) + +[file copysubclass.py] +from typing import Any +class subc(list[Any]): + pass + +[file userdefinedlist.py] +class list: + pass