diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index a600afff4bc9..2df47680d27c 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -752,35 +752,51 @@ def go(i: int, prev: Value) -> Value: def try_specialize_in_expr( builder: IRBuilder, op: str, lhs: Expression, rhs: Expression, line: int ) -> Value | None: + left: Value | None = None + items: list[Value] | None = None + if isinstance(rhs, (TupleExpr, ListExpr)): - items = rhs.items + left = builder.accept(lhs) + items = [builder.accept(item) for item in rhs.items] + elif isinstance(builder.node_type(rhs), RTuple): + left = builder.accept(lhs) + tuple_val = builder.accept(rhs) + assert isinstance(tuple_val.type, RTuple) + items = [builder.add(TupleGet(tuple_val, i)) for i in range(len(tuple_val.type.types))] + + if items is not None: + assert left is not None n_items = len(items) # x in y -> x == y[0] or ... or x == y[n] # x not in y -> x != y[0] and ... and x != y[n] - # 16 is arbitrarily chosen to limit code size - if 1 < n_items < 16: + if n_items > 1: if op == "in": - bin_op = "or" cmp_op = "==" else: - bin_op = "and" cmp_op = "!=" - mypy_file = builder.graph["builtins"].tree - assert mypy_file is not None - info = mypy_file.names["bool"].node - assert isinstance(info, TypeInfo), info - bool_type = Instance(info, []) - exprs = [] + out = BasicBlock() for item in items: - expr = ComparisonExpr([cmp_op], [lhs, item]) - builder.types[expr] = bool_type - exprs.append(expr) - - or_expr: Expression = exprs.pop(0) - for expr in exprs: - or_expr = OpExpr(bin_op, or_expr, expr) - builder.types[or_expr] = bool_type - return builder.accept(or_expr) + cmp = transform_basic_comparison(builder, cmp_op, left, item, line) + bool_val = builder.builder.bool_value(cmp) + next_block = BasicBlock() + if op == "in": + builder.add_bool_branch(bool_val, out, next_block) + else: + builder.add_bool_branch(bool_val, next_block, out) + builder.activate_block(next_block) + result_reg = Register(bool_rprimitive) + end = BasicBlock() + if op == "in": + values = builder.false(), builder.true() + else: + values = builder.true(), builder.false() + builder.assign(result_reg, values[0], line) + builder.goto(end) + builder.activate_block(out) + builder.assign(result_reg, values[1], line) + builder.goto(end) + builder.activate_block(end) + return result_reg # x in [y]/(y) -> x == y # x not in [y]/(y) -> x != y elif n_items == 1: @@ -788,8 +804,7 @@ def try_specialize_in_expr( cmp_op = "==" else: cmp_op = "!=" - left = builder.accept(lhs) - right = builder.accept(items[0]) + right = items[0] return transform_basic_comparison(builder, cmp_op, left, right, line) # x in []/() -> False # x not in []/() -> True diff --git a/mypyc/test-data/irbuild-tuple.test b/mypyc/test-data/irbuild-tuple.test index 712b9c26355a..00ea7f074a5d 100644 --- a/mypyc/test-data/irbuild-tuple.test +++ b/mypyc/test-data/irbuild-tuple.test @@ -184,31 +184,66 @@ def f(i: int) -> bool: [out] def f(i): i :: int - r0 :: bit - r1 :: bool - r2 :: bit + r0, r1, r2 :: bit r3 :: bool - r4 :: bit L0: r0 = int_eq i, 2 - if r0 goto L1 else goto L2 :: bool + if r0 goto L4 else goto L1 :: bool L1: - r1 = r0 - goto L3 + r1 = int_eq i, 4 + if r1 goto L4 else goto L2 :: bool L2: - r2 = int_eq i, 4 - r1 = r2 + r2 = int_eq i, 6 + if r2 goto L4 else goto L3 :: bool L3: - if r1 goto L4 else goto L5 :: bool + r3 = 0 + goto L5 L4: - r3 = r1 - goto L6 + r3 = 1 L5: - r4 = int_eq i, 6 - r3 = r4 -L6: return r3 +[case testTupleOperatorNotIn] +def x() -> int: + return 1 +def y() -> int: + return 2 +def z() -> int: + return 3 + +def f() -> bool: + return z() not in (x(), y()) +[out] +def x(): +L0: + return 2 +def y(): +L0: + return 4 +def z(): +L0: + return 6 +def f(): + r0, r1, r2 :: int + r3, r4 :: bit + r5 :: bool +L0: + r0 = z() + r1 = x() + r2 = y() + r3 = int_ne r0, r1 + if r3 goto L1 else goto L3 :: bool +L1: + r4 = int_ne r0, r2 + if r4 goto L2 else goto L3 :: bool +L2: + r5 = 1 + goto L4 +L3: + r5 = 0 +L4: + return r5 + [case testTupleOperatorInFinalTuple] from typing import Final @@ -221,9 +256,8 @@ def f(x): x :: int r0 :: tuple[int, int] r1 :: bool - r2, r3 :: object - r4 :: i32 - r5 :: bit + r2, r3 :: int + r4, r5 :: bit r6 :: bool L0: r0 = __main__.tt :: static @@ -232,11 +266,19 @@ L1: r1 = raise NameError('value for final name "tt" was not set') unreachable L2: - r2 = box(int, x) - r3 = box(tuple[int, int], r0) - r4 = PySequence_Contains(r3, r2) - r5 = r4 >= 0 :: signed - r6 = truncate r4: i32 to builtins.bool + r2 = r0[0] + r3 = r0[1] + r4 = int_eq x, r2 + if r4 goto L5 else goto L3 :: bool +L3: + r5 = int_eq x, r3 + if r5 goto L5 else goto L4 :: bool +L4: + r6 = 0 + goto L6 +L5: + r6 = 1 +L6: return r6 [case testTupleBuiltFromList] diff --git a/mypyc/test-data/run-lists.test b/mypyc/test-data/run-lists.test index 54bcc0384604..1569579c1156 100644 --- a/mypyc/test-data/run-lists.test +++ b/mypyc/test-data/run-lists.test @@ -364,7 +364,6 @@ def test_multiply() -> None: assert l1 == [1, 1, 1] [case testOperatorInExpression] - def tuple_in_int0(i: int) -> bool: return i in [] @@ -416,71 +415,68 @@ def list_not_in_str(s: "str") -> bool: def list_in_mixed(i: object): return i in [[], (), "", 0, 0.0, False, 0j, {}, set(), type] -[file driver.py] - -from native import * - -assert not tuple_in_int0(0) -assert not tuple_in_int1(0) -assert tuple_in_int1(1) -assert not tuple_in_int3(0) -assert tuple_in_int3(1) -assert tuple_in_int3(2) -assert tuple_in_int3(3) -assert not tuple_in_int3(4) - -assert tuple_not_in_int0(0) -assert tuple_not_in_int1(0) -assert not tuple_not_in_int1(1) -assert tuple_not_in_int3(0) -assert not tuple_not_in_int3(1) -assert not tuple_not_in_int3(2) -assert not tuple_not_in_int3(3) -assert tuple_not_in_int3(4) - -assert tuple_in_str("foo") -assert tuple_in_str("bar") -assert tuple_in_str("baz") -assert not tuple_in_str("apple") -assert not tuple_in_str("pie") -assert not tuple_in_str("\0") -assert not tuple_in_str("") - -assert not list_in_int0(0) -assert not list_in_int1(0) -assert list_in_int1(1) -assert not list_in_int3(0) -assert list_in_int3(1) -assert list_in_int3(2) -assert list_in_int3(3) -assert not list_in_int3(4) - -assert list_not_in_int0(0) -assert list_not_in_int1(0) -assert not list_not_in_int1(1) -assert list_not_in_int3(0) -assert not list_not_in_int3(1) -assert not list_not_in_int3(2) -assert not list_not_in_int3(3) -assert list_not_in_int3(4) - -assert list_in_str("foo") -assert list_in_str("bar") -assert list_in_str("baz") -assert not list_in_str("apple") -assert not list_in_str("pie") -assert not list_in_str("\0") -assert not list_in_str("") - -assert list_in_mixed(0) -assert list_in_mixed([]) -assert list_in_mixed({}) -assert list_in_mixed(()) -assert list_in_mixed(False) -assert list_in_mixed(0.0) -assert not list_in_mixed([1]) -assert not list_in_mixed(object) -assert list_in_mixed(type) +def test_in_operator_various_cases() -> None: + assert not tuple_in_int0(0) + assert not tuple_in_int1(0) + assert tuple_in_int1(1) + assert not tuple_in_int3(0) + assert tuple_in_int3(1) + assert tuple_in_int3(2) + assert tuple_in_int3(3) + assert not tuple_in_int3(4) + + assert tuple_not_in_int0(0) + assert tuple_not_in_int1(0) + assert not tuple_not_in_int1(1) + assert tuple_not_in_int3(0) + assert not tuple_not_in_int3(1) + assert not tuple_not_in_int3(2) + assert not tuple_not_in_int3(3) + assert tuple_not_in_int3(4) + + assert tuple_in_str("foo") + assert tuple_in_str("bar") + assert tuple_in_str("baz") + assert not tuple_in_str("apple") + assert not tuple_in_str("pie") + assert not tuple_in_str("\0") + assert not tuple_in_str("") + + assert not list_in_int0(0) + assert not list_in_int1(0) + assert list_in_int1(1) + assert not list_in_int3(0) + assert list_in_int3(1) + assert list_in_int3(2) + assert list_in_int3(3) + assert not list_in_int3(4) + + assert list_not_in_int0(0) + assert list_not_in_int1(0) + assert not list_not_in_int1(1) + assert list_not_in_int3(0) + assert not list_not_in_int3(1) + assert not list_not_in_int3(2) + assert not list_not_in_int3(3) + assert list_not_in_int3(4) + + assert list_in_str("foo") + assert list_in_str("bar") + assert list_in_str("baz") + assert not list_in_str("apple") + assert not list_in_str("pie") + assert not list_in_str("\0") + assert not list_in_str("") + + assert list_in_mixed(0) + assert list_in_mixed([]) + assert list_in_mixed({}) + assert list_in_mixed(()) + assert list_in_mixed(False) + assert list_in_mixed(0.0) + assert not list_in_mixed([1]) + assert not list_in_mixed(object) + assert list_in_mixed(type) [case testListBuiltFromGenerator] def test_from_gen() -> None: diff --git a/mypyc/test-data/run-tuples.test b/mypyc/test-data/run-tuples.test index 5d9485288cfb..f5e1733d429b 100644 --- a/mypyc/test-data/run-tuples.test +++ b/mypyc/test-data/run-tuples.test @@ -292,6 +292,89 @@ TUPLE: Final[Tuple[str, ...]] = ('x', 'y') def test_final_boxed_tuple() -> None: t = TUPLE assert t == ('x', 'y') + assert 'x' in TUPLE + assert 'y' in TUPLE + b: object = 'z' in TUPLE + assert not b + assert 'z' not in TUPLE + b2: object = 'x' not in TUPLE + assert not b2 + b3: object = 'y' not in TUPLE + assert not b3 + +TUP2: Final = ('x', 'y') +TUP1: Final = ('x',) +TUP0: Final = () + +def test_final_tuple_in() -> None: + assert 'x' + str() in TUP2 + assert 'y' + str() in TUP2 + b: object = 'z' + str() in TUP2 + assert not b + + assert 'x' + str() in TUP1 + b2: object = 'y' in TUP1 + assert not b2 + + b3: object = 'x' in TUP0 + assert not b3 + +def test_final_tuple_not_in() -> None: + assert 'z' + str() not in TUP2 + b: object = 'x' + str() not in TUP2 + assert not b + b2: object = 'y' + str() not in TUP2 + assert not b2 + + assert 'y' + str() not in TUP1 + b3: object = 'x' not in TUP1 + assert not b2 + + assert 'x' not in TUP0 + +log = [] + +def f_a() -> str: + log.append('f_a') + return 'a' + +def f_a2() -> str: + log.append('f_a2') + return 'a' + +def f_b() -> str: + log.append('f_b') + return 'b' + +def f_c() -> str: + log.append('f_c') + return 'c' + +def test_tuple_in_order_of_evaluation() -> None: + log.clear() + assert f_a() in (f_b(), f_a2()) + assert log ==["f_a", "f_b", "f_a2"] + + log.clear() + assert f_a() not in (f_b(), f_c()) + assert log ==["f_a", "f_b", "f_c"] + + log.clear() + assert f_a() in (f_b(), f_a2(), f_c()) + assert log ==["f_a", "f_b", "f_a2", "f_c"] + +def f_t() -> tuple[str, ...]: + log.append('f_t') + return ('x', 'a') + +def test_tuple_in_non_specialized() -> None: + log.clear() + assert f_a() in f_t() + assert log == ["f_a", "f_t"] + + log.clear() + assert f_b() not in f_t() + assert log == ["f_b", "f_t"] def test_add() -> None: res = (1, 2, 3, 4)