Skip to content

Commit 1f9505c

Browse files
authored
[mypyc] Speed up "in" against final fixed-length tuple (#19682)
Previously the `in` operation here boxed the tuple: ``` TUP: Final = ('x', 'y') ... if s in TUP: ... ``` Now we don't box the tuple and inline the comparisons against each tuple item instead, which is more efficient. Also make the semantics closer to Python and add tests.
1 parent abf61fb commit 1f9505c

File tree

4 files changed

+247
-111
lines changed

4 files changed

+247
-111
lines changed

mypyc/irbuild/expression.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -752,44 +752,59 @@ def go(i: int, prev: Value) -> Value:
752752
def try_specialize_in_expr(
753753
builder: IRBuilder, op: str, lhs: Expression, rhs: Expression, line: int
754754
) -> Value | None:
755+
left: Value | None = None
756+
items: list[Value] | None = None
757+
755758
if isinstance(rhs, (TupleExpr, ListExpr)):
756-
items = rhs.items
759+
left = builder.accept(lhs)
760+
items = [builder.accept(item) for item in rhs.items]
761+
elif isinstance(builder.node_type(rhs), RTuple):
762+
left = builder.accept(lhs)
763+
tuple_val = builder.accept(rhs)
764+
assert isinstance(tuple_val.type, RTuple)
765+
items = [builder.add(TupleGet(tuple_val, i)) for i in range(len(tuple_val.type.types))]
766+
767+
if items is not None:
768+
assert left is not None
757769
n_items = len(items)
758770
# x in y -> x == y[0] or ... or x == y[n]
759771
# x not in y -> x != y[0] and ... and x != y[n]
760-
# 16 is arbitrarily chosen to limit code size
761-
if 1 < n_items < 16:
772+
if n_items > 1:
762773
if op == "in":
763-
bin_op = "or"
764774
cmp_op = "=="
765775
else:
766-
bin_op = "and"
767776
cmp_op = "!="
768-
mypy_file = builder.graph["builtins"].tree
769-
assert mypy_file is not None
770-
info = mypy_file.names["bool"].node
771-
assert isinstance(info, TypeInfo), info
772-
bool_type = Instance(info, [])
773-
exprs = []
777+
out = BasicBlock()
774778
for item in items:
775-
expr = ComparisonExpr([cmp_op], [lhs, item])
776-
builder.types[expr] = bool_type
777-
exprs.append(expr)
778-
779-
or_expr: Expression = exprs.pop(0)
780-
for expr in exprs:
781-
or_expr = OpExpr(bin_op, or_expr, expr)
782-
builder.types[or_expr] = bool_type
783-
return builder.accept(or_expr)
779+
cmp = transform_basic_comparison(builder, cmp_op, left, item, line)
780+
bool_val = builder.builder.bool_value(cmp)
781+
next_block = BasicBlock()
782+
if op == "in":
783+
builder.add_bool_branch(bool_val, out, next_block)
784+
else:
785+
builder.add_bool_branch(bool_val, next_block, out)
786+
builder.activate_block(next_block)
787+
result_reg = Register(bool_rprimitive)
788+
end = BasicBlock()
789+
if op == "in":
790+
values = builder.false(), builder.true()
791+
else:
792+
values = builder.true(), builder.false()
793+
builder.assign(result_reg, values[0], line)
794+
builder.goto(end)
795+
builder.activate_block(out)
796+
builder.assign(result_reg, values[1], line)
797+
builder.goto(end)
798+
builder.activate_block(end)
799+
return result_reg
784800
# x in [y]/(y) -> x == y
785801
# x not in [y]/(y) -> x != y
786802
elif n_items == 1:
787803
if op == "in":
788804
cmp_op = "=="
789805
else:
790806
cmp_op = "!="
791-
left = builder.accept(lhs)
792-
right = builder.accept(items[0])
807+
right = items[0]
793808
return transform_basic_comparison(builder, cmp_op, left, right, line)
794809
# x in []/() -> False
795810
# x not in []/() -> True

mypyc/test-data/irbuild-tuple.test

Lines changed: 65 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -184,31 +184,66 @@ def f(i: int) -> bool:
184184
[out]
185185
def f(i):
186186
i :: int
187-
r0 :: bit
188-
r1 :: bool
189-
r2 :: bit
187+
r0, r1, r2 :: bit
190188
r3 :: bool
191-
r4 :: bit
192189
L0:
193190
r0 = int_eq i, 2
194-
if r0 goto L1 else goto L2 :: bool
191+
if r0 goto L4 else goto L1 :: bool
195192
L1:
196-
r1 = r0
197-
goto L3
193+
r1 = int_eq i, 4
194+
if r1 goto L4 else goto L2 :: bool
198195
L2:
199-
r2 = int_eq i, 4
200-
r1 = r2
196+
r2 = int_eq i, 6
197+
if r2 goto L4 else goto L3 :: bool
201198
L3:
202-
if r1 goto L4 else goto L5 :: bool
199+
r3 = 0
200+
goto L5
203201
L4:
204-
r3 = r1
205-
goto L6
202+
r3 = 1
206203
L5:
207-
r4 = int_eq i, 6
208-
r3 = r4
209-
L6:
210204
return r3
211205

206+
[case testTupleOperatorNotIn]
207+
def x() -> int:
208+
return 1
209+
def y() -> int:
210+
return 2
211+
def z() -> int:
212+
return 3
213+
214+
def f() -> bool:
215+
return z() not in (x(), y())
216+
[out]
217+
def x():
218+
L0:
219+
return 2
220+
def y():
221+
L0:
222+
return 4
223+
def z():
224+
L0:
225+
return 6
226+
def f():
227+
r0, r1, r2 :: int
228+
r3, r4 :: bit
229+
r5 :: bool
230+
L0:
231+
r0 = z()
232+
r1 = x()
233+
r2 = y()
234+
r3 = int_ne r0, r1
235+
if r3 goto L1 else goto L3 :: bool
236+
L1:
237+
r4 = int_ne r0, r2
238+
if r4 goto L2 else goto L3 :: bool
239+
L2:
240+
r5 = 1
241+
goto L4
242+
L3:
243+
r5 = 0
244+
L4:
245+
return r5
246+
212247
[case testTupleOperatorInFinalTuple]
213248
from typing import Final
214249

@@ -221,9 +256,8 @@ def f(x):
221256
x :: int
222257
r0 :: tuple[int, int]
223258
r1 :: bool
224-
r2, r3 :: object
225-
r4 :: i32
226-
r5 :: bit
259+
r2, r3 :: int
260+
r4, r5 :: bit
227261
r6 :: bool
228262
L0:
229263
r0 = __main__.tt :: static
@@ -232,11 +266,19 @@ L1:
232266
r1 = raise NameError('value for final name "tt" was not set')
233267
unreachable
234268
L2:
235-
r2 = box(int, x)
236-
r3 = box(tuple[int, int], r0)
237-
r4 = PySequence_Contains(r3, r2)
238-
r5 = r4 >= 0 :: signed
239-
r6 = truncate r4: i32 to builtins.bool
269+
r2 = r0[0]
270+
r3 = r0[1]
271+
r4 = int_eq x, r2
272+
if r4 goto L5 else goto L3 :: bool
273+
L3:
274+
r5 = int_eq x, r3
275+
if r5 goto L5 else goto L4 :: bool
276+
L4:
277+
r6 = 0
278+
goto L6
279+
L5:
280+
r6 = 1
281+
L6:
240282
return r6
241283

242284
[case testTupleBuiltFromList]

mypyc/test-data/run-lists.test

Lines changed: 62 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,6 @@ def test_multiply() -> None:
364364
assert l1 == [1, 1, 1]
365365

366366
[case testOperatorInExpression]
367-
368367
def tuple_in_int0(i: int) -> bool:
369368
return i in []
370369

@@ -416,71 +415,68 @@ def list_not_in_str(s: "str") -> bool:
416415
def list_in_mixed(i: object):
417416
return i in [[], (), "", 0, 0.0, False, 0j, {}, set(), type]
418417

419-
[file driver.py]
420-
421-
from native import *
422-
423-
assert not tuple_in_int0(0)
424-
assert not tuple_in_int1(0)
425-
assert tuple_in_int1(1)
426-
assert not tuple_in_int3(0)
427-
assert tuple_in_int3(1)
428-
assert tuple_in_int3(2)
429-
assert tuple_in_int3(3)
430-
assert not tuple_in_int3(4)
431-
432-
assert tuple_not_in_int0(0)
433-
assert tuple_not_in_int1(0)
434-
assert not tuple_not_in_int1(1)
435-
assert tuple_not_in_int3(0)
436-
assert not tuple_not_in_int3(1)
437-
assert not tuple_not_in_int3(2)
438-
assert not tuple_not_in_int3(3)
439-
assert tuple_not_in_int3(4)
440-
441-
assert tuple_in_str("foo")
442-
assert tuple_in_str("bar")
443-
assert tuple_in_str("baz")
444-
assert not tuple_in_str("apple")
445-
assert not tuple_in_str("pie")
446-
assert not tuple_in_str("\0")
447-
assert not tuple_in_str("")
448-
449-
assert not list_in_int0(0)
450-
assert not list_in_int1(0)
451-
assert list_in_int1(1)
452-
assert not list_in_int3(0)
453-
assert list_in_int3(1)
454-
assert list_in_int3(2)
455-
assert list_in_int3(3)
456-
assert not list_in_int3(4)
457-
458-
assert list_not_in_int0(0)
459-
assert list_not_in_int1(0)
460-
assert not list_not_in_int1(1)
461-
assert list_not_in_int3(0)
462-
assert not list_not_in_int3(1)
463-
assert not list_not_in_int3(2)
464-
assert not list_not_in_int3(3)
465-
assert list_not_in_int3(4)
466-
467-
assert list_in_str("foo")
468-
assert list_in_str("bar")
469-
assert list_in_str("baz")
470-
assert not list_in_str("apple")
471-
assert not list_in_str("pie")
472-
assert not list_in_str("\0")
473-
assert not list_in_str("")
474-
475-
assert list_in_mixed(0)
476-
assert list_in_mixed([])
477-
assert list_in_mixed({})
478-
assert list_in_mixed(())
479-
assert list_in_mixed(False)
480-
assert list_in_mixed(0.0)
481-
assert not list_in_mixed([1])
482-
assert not list_in_mixed(object)
483-
assert list_in_mixed(type)
418+
def test_in_operator_various_cases() -> None:
419+
assert not tuple_in_int0(0)
420+
assert not tuple_in_int1(0)
421+
assert tuple_in_int1(1)
422+
assert not tuple_in_int3(0)
423+
assert tuple_in_int3(1)
424+
assert tuple_in_int3(2)
425+
assert tuple_in_int3(3)
426+
assert not tuple_in_int3(4)
427+
428+
assert tuple_not_in_int0(0)
429+
assert tuple_not_in_int1(0)
430+
assert not tuple_not_in_int1(1)
431+
assert tuple_not_in_int3(0)
432+
assert not tuple_not_in_int3(1)
433+
assert not tuple_not_in_int3(2)
434+
assert not tuple_not_in_int3(3)
435+
assert tuple_not_in_int3(4)
436+
437+
assert tuple_in_str("foo")
438+
assert tuple_in_str("bar")
439+
assert tuple_in_str("baz")
440+
assert not tuple_in_str("apple")
441+
assert not tuple_in_str("pie")
442+
assert not tuple_in_str("\0")
443+
assert not tuple_in_str("")
444+
445+
assert not list_in_int0(0)
446+
assert not list_in_int1(0)
447+
assert list_in_int1(1)
448+
assert not list_in_int3(0)
449+
assert list_in_int3(1)
450+
assert list_in_int3(2)
451+
assert list_in_int3(3)
452+
assert not list_in_int3(4)
453+
454+
assert list_not_in_int0(0)
455+
assert list_not_in_int1(0)
456+
assert not list_not_in_int1(1)
457+
assert list_not_in_int3(0)
458+
assert not list_not_in_int3(1)
459+
assert not list_not_in_int3(2)
460+
assert not list_not_in_int3(3)
461+
assert list_not_in_int3(4)
462+
463+
assert list_in_str("foo")
464+
assert list_in_str("bar")
465+
assert list_in_str("baz")
466+
assert not list_in_str("apple")
467+
assert not list_in_str("pie")
468+
assert not list_in_str("\0")
469+
assert not list_in_str("")
470+
471+
assert list_in_mixed(0)
472+
assert list_in_mixed([])
473+
assert list_in_mixed({})
474+
assert list_in_mixed(())
475+
assert list_in_mixed(False)
476+
assert list_in_mixed(0.0)
477+
assert not list_in_mixed([1])
478+
assert not list_in_mixed(object)
479+
assert list_in_mixed(type)
484480

485481
[case testListBuiltFromGenerator]
486482
def test_from_gen() -> None:

0 commit comments

Comments
 (0)