Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 37 additions & 22 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,44 +752,59 @@ def go(i: int, prev: Value) -> Value:
def try_specialize_in_expr(
builder: IRBuilder, op: str, lhs: Expression, rhs: Expression, line: int
) -> Value | None:
left: Value | None = None
items: list[Value] | None = None

if isinstance(rhs, (TupleExpr, ListExpr)):
items = rhs.items
left = builder.accept(lhs)
items = [builder.accept(item) for item in rhs.items]
elif isinstance(builder.node_type(rhs), RTuple):
left = builder.accept(lhs)
tuple_val = builder.accept(rhs)
assert isinstance(tuple_val.type, RTuple)
items = [builder.add(TupleGet(tuple_val, i)) for i in range(len(tuple_val.type.types))]

if items is not None:
assert left is not None
n_items = len(items)
# x in y -> x == y[0] or ... or x == y[n]
# x not in y -> x != y[0] and ... and x != y[n]
# 16 is arbitrarily chosen to limit code size
if 1 < n_items < 16:
if n_items > 1:
if op == "in":
bin_op = "or"
cmp_op = "=="
else:
bin_op = "and"
cmp_op = "!="
mypy_file = builder.graph["builtins"].tree
assert mypy_file is not None
info = mypy_file.names["bool"].node
assert isinstance(info, TypeInfo), info
bool_type = Instance(info, [])
exprs = []
out = BasicBlock()
for item in items:
expr = ComparisonExpr([cmp_op], [lhs, item])
builder.types[expr] = bool_type
exprs.append(expr)

or_expr: Expression = exprs.pop(0)
for expr in exprs:
or_expr = OpExpr(bin_op, or_expr, expr)
builder.types[or_expr] = bool_type
return builder.accept(or_expr)
cmp = transform_basic_comparison(builder, cmp_op, left, item, line)
bool_val = builder.builder.bool_value(cmp)
next_block = BasicBlock()
if op == "in":
builder.add_bool_branch(bool_val, out, next_block)
else:
builder.add_bool_branch(bool_val, next_block, out)
builder.activate_block(next_block)
result_reg = Register(bool_rprimitive)
end = BasicBlock()
if op == "in":
values = builder.false(), builder.true()
else:
values = builder.true(), builder.false()
builder.assign(result_reg, values[0], line)
builder.goto(end)
builder.activate_block(out)
builder.assign(result_reg, values[1], line)
builder.goto(end)
builder.activate_block(end)
return result_reg
# x in [y]/(y) -> x == y
# x not in [y]/(y) -> x != y
elif n_items == 1:
if op == "in":
cmp_op = "=="
else:
cmp_op = "!="
left = builder.accept(lhs)
right = builder.accept(items[0])
right = items[0]
return transform_basic_comparison(builder, cmp_op, left, right, line)
# x in []/() -> False
# x not in []/() -> True
Expand Down
88 changes: 65 additions & 23 deletions mypyc/test-data/irbuild-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -184,31 +184,66 @@ def f(i: int) -> bool:
[out]
def f(i):
i :: int
r0 :: bit
r1 :: bool
r2 :: bit
r0, r1, r2 :: bit
r3 :: bool
r4 :: bit
L0:
r0 = int_eq i, 2
if r0 goto L1 else goto L2 :: bool
if r0 goto L4 else goto L1 :: bool
L1:
r1 = r0
goto L3
r1 = int_eq i, 4
if r1 goto L4 else goto L2 :: bool
L2:
r2 = int_eq i, 4
r1 = r2
r2 = int_eq i, 6
if r2 goto L4 else goto L3 :: bool
L3:
if r1 goto L4 else goto L5 :: bool
r3 = 0
goto L5
L4:
r3 = r1
goto L6
r3 = 1
L5:
r4 = int_eq i, 6
r3 = r4
L6:
return r3

[case testTupleOperatorNotIn]
def x() -> int:
return 1
def y() -> int:
return 2
def z() -> int:
return 3

def f() -> bool:
return z() not in (x(), y())
[out]
def x():
L0:
return 2
def y():
L0:
return 4
def z():
L0:
return 6
def f():
r0, r1, r2 :: int
r3, r4 :: bit
r5 :: bool
L0:
r0 = z()
r1 = x()
r2 = y()
r3 = int_ne r0, r1
if r3 goto L1 else goto L3 :: bool
L1:
r4 = int_ne r0, r2
if r4 goto L2 else goto L3 :: bool
L2:
r5 = 1
goto L4
L3:
r5 = 0
L4:
return r5

[case testTupleOperatorInFinalTuple]
from typing import Final

Expand All @@ -221,9 +256,8 @@ def f(x):
x :: int
r0 :: tuple[int, int]
r1 :: bool
r2, r3 :: object
r4 :: i32
r5 :: bit
r2, r3 :: int
r4, r5 :: bit
r6 :: bool
L0:
r0 = __main__.tt :: static
Expand All @@ -232,11 +266,19 @@ L1:
r1 = raise NameError('value for final name "tt" was not set')
unreachable
L2:
r2 = box(int, x)
r3 = box(tuple[int, int], r0)
r4 = PySequence_Contains(r3, r2)
r5 = r4 >= 0 :: signed
r6 = truncate r4: i32 to builtins.bool
r2 = r0[0]
r3 = r0[1]
r4 = int_eq x, r2
if r4 goto L5 else goto L3 :: bool
L3:
r5 = int_eq x, r3
if r5 goto L5 else goto L4 :: bool
L4:
r6 = 0
goto L6
L5:
r6 = 1
L6:
return r6

[case testTupleBuiltFromList]
Expand Down
128 changes: 62 additions & 66 deletions mypyc/test-data/run-lists.test
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,6 @@ def test_multiply() -> None:
assert l1 == [1, 1, 1]

[case testOperatorInExpression]

def tuple_in_int0(i: int) -> bool:
return i in []

Expand Down Expand Up @@ -416,71 +415,68 @@ def list_not_in_str(s: "str") -> bool:
def list_in_mixed(i: object):
return i in [[], (), "", 0, 0.0, False, 0j, {}, set(), type]

[file driver.py]

from native import *

assert not tuple_in_int0(0)
assert not tuple_in_int1(0)
assert tuple_in_int1(1)
assert not tuple_in_int3(0)
assert tuple_in_int3(1)
assert tuple_in_int3(2)
assert tuple_in_int3(3)
assert not tuple_in_int3(4)

assert tuple_not_in_int0(0)
assert tuple_not_in_int1(0)
assert not tuple_not_in_int1(1)
assert tuple_not_in_int3(0)
assert not tuple_not_in_int3(1)
assert not tuple_not_in_int3(2)
assert not tuple_not_in_int3(3)
assert tuple_not_in_int3(4)

assert tuple_in_str("foo")
assert tuple_in_str("bar")
assert tuple_in_str("baz")
assert not tuple_in_str("apple")
assert not tuple_in_str("pie")
assert not tuple_in_str("\0")
assert not tuple_in_str("")

assert not list_in_int0(0)
assert not list_in_int1(0)
assert list_in_int1(1)
assert not list_in_int3(0)
assert list_in_int3(1)
assert list_in_int3(2)
assert list_in_int3(3)
assert not list_in_int3(4)

assert list_not_in_int0(0)
assert list_not_in_int1(0)
assert not list_not_in_int1(1)
assert list_not_in_int3(0)
assert not list_not_in_int3(1)
assert not list_not_in_int3(2)
assert not list_not_in_int3(3)
assert list_not_in_int3(4)

assert list_in_str("foo")
assert list_in_str("bar")
assert list_in_str("baz")
assert not list_in_str("apple")
assert not list_in_str("pie")
assert not list_in_str("\0")
assert not list_in_str("")

assert list_in_mixed(0)
assert list_in_mixed([])
assert list_in_mixed({})
assert list_in_mixed(())
assert list_in_mixed(False)
assert list_in_mixed(0.0)
assert not list_in_mixed([1])
assert not list_in_mixed(object)
assert list_in_mixed(type)
def test_in_operator_various_cases() -> None:
assert not tuple_in_int0(0)
assert not tuple_in_int1(0)
assert tuple_in_int1(1)
assert not tuple_in_int3(0)
assert tuple_in_int3(1)
assert tuple_in_int3(2)
assert tuple_in_int3(3)
assert not tuple_in_int3(4)

assert tuple_not_in_int0(0)
assert tuple_not_in_int1(0)
assert not tuple_not_in_int1(1)
assert tuple_not_in_int3(0)
assert not tuple_not_in_int3(1)
assert not tuple_not_in_int3(2)
assert not tuple_not_in_int3(3)
assert tuple_not_in_int3(4)

assert tuple_in_str("foo")
assert tuple_in_str("bar")
assert tuple_in_str("baz")
assert not tuple_in_str("apple")
assert not tuple_in_str("pie")
assert not tuple_in_str("\0")
assert not tuple_in_str("")

assert not list_in_int0(0)
assert not list_in_int1(0)
assert list_in_int1(1)
assert not list_in_int3(0)
assert list_in_int3(1)
assert list_in_int3(2)
assert list_in_int3(3)
assert not list_in_int3(4)

assert list_not_in_int0(0)
assert list_not_in_int1(0)
assert not list_not_in_int1(1)
assert list_not_in_int3(0)
assert not list_not_in_int3(1)
assert not list_not_in_int3(2)
assert not list_not_in_int3(3)
assert list_not_in_int3(4)

assert list_in_str("foo")
assert list_in_str("bar")
assert list_in_str("baz")
assert not list_in_str("apple")
assert not list_in_str("pie")
assert not list_in_str("\0")
assert not list_in_str("")

assert list_in_mixed(0)
assert list_in_mixed([])
assert list_in_mixed({})
assert list_in_mixed(())
assert list_in_mixed(False)
assert list_in_mixed(0.0)
assert not list_in_mixed([1])
assert not list_in_mixed(object)
assert list_in_mixed(type)

[case testListBuiltFromGenerator]
def test_from_gen() -> None:
Expand Down
Loading
Loading