Skip to content

Commit 5902da5

Browse files
committed
[mypyc] Speed up "in" against tuple
Also make the semantics closer to Python.
1 parent abf61fb commit 5902da5

File tree

2 files changed

+125
-37
lines changed

2 files changed

+125
-37
lines changed

mypyc/irbuild/expression.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -752,44 +752,60 @@ def go(i: int, prev: Value) -> Value:
752752
def try_specialize_in_expr(
753753
builder: IRBuilder, op: str, lhs: Expression, rhs: Expression, line: int
754754
) -> Value | None:
755-
if isinstance(rhs, (TupleExpr, ListExpr)):
756-
items = rhs.items
755+
left: Value | None = None
756+
items: list[Value] | None = None
757+
758+
# 16 is arbitrarily chosen to limit code size
759+
if isinstance(rhs, (TupleExpr, ListExpr)) and len(rhs.items) < 16:
760+
left = builder.accept(lhs)
761+
items = [builder.accept(item) for item in rhs.items]
762+
elif isinstance(builder.node_type(rhs), RTuple):
763+
left = builder.accept(lhs)
764+
tuple_val = builder.accept(rhs)
765+
assert isinstance(tuple_val.type, RTuple)
766+
items = [builder.add(TupleGet(tuple_val, i)) for i in range(len(tuple_val.type.types))]
767+
768+
if items is not None:
769+
assert left is not None
757770
n_items = len(items)
758771
# x in y -> x == y[0] or ... or x == y[n]
759772
# x not in y -> x != y[0] and ... and x != y[n]
760-
# 16 is arbitrarily chosen to limit code size
761-
if 1 < n_items < 16:
773+
if n_items > 1:
762774
if op == "in":
763-
bin_op = "or"
764775
cmp_op = "=="
765776
else:
766-
bin_op = "and"
767777
cmp_op = "!="
768-
mypy_file = builder.graph["builtins"].tree
769-
assert mypy_file is not None
770-
info = mypy_file.names["bool"].node
771-
assert isinstance(info, TypeInfo), info
772-
bool_type = Instance(info, [])
773-
exprs = []
778+
out = BasicBlock()
774779
for item in items:
775-
expr = ComparisonExpr([cmp_op], [lhs, item])
776-
builder.types[expr] = bool_type
777-
exprs.append(expr)
778-
779-
or_expr: Expression = exprs.pop(0)
780-
for expr in exprs:
781-
or_expr = OpExpr(bin_op, or_expr, expr)
782-
builder.types[or_expr] = bool_type
783-
return builder.accept(or_expr)
780+
x = transform_basic_comparison(builder, cmp_op, left, item, line)
781+
b = builder.builder.bool_value(x)
782+
nxt = BasicBlock()
783+
if op == "in":
784+
builder.add_bool_branch(b, out, nxt)
785+
else:
786+
builder.add_bool_branch(b, nxt, out)
787+
builder.activate_block(nxt)
788+
r = Register(bool_rprimitive)
789+
end = BasicBlock()
790+
if op == "in":
791+
values = builder.false(), builder.true()
792+
else:
793+
values = builder.true(), builder.false()
794+
builder.assign(r, values[0], line)
795+
builder.goto(end)
796+
builder.activate_block(out)
797+
builder.assign(r, values[1], line)
798+
builder.goto(end)
799+
builder.activate_block(end)
800+
return r
784801
# x in [y]/(y) -> x == y
785802
# x not in [y]/(y) -> x != y
786803
elif n_items == 1:
787804
if op == "in":
788805
cmp_op = "=="
789806
else:
790807
cmp_op = "!="
791-
left = builder.accept(lhs)
792-
right = builder.accept(items[0])
808+
right = items[0]
793809
return transform_basic_comparison(builder, cmp_op, left, right, line)
794810
# x in []/() -> False
795811
# x not in []/() -> True

mypyc/test-data/irbuild-tuple.test

Lines changed: 86 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -184,30 +184,102 @@ def f(i: int) -> bool:
184184
[out]
185185
def f(i):
186186
i :: int
187-
r0 :: bit
188-
r1 :: bool
189-
r2 :: bit
187+
r0, r1, r2 :: bit
190188
r3 :: bool
191-
r4 :: bit
192189
L0:
193190
r0 = int_eq i, 2
194-
if r0 goto L1 else goto L2 :: bool
191+
if r0 goto L4 else goto L1 :: bool
195192
L1:
196-
r1 = r0
197-
goto L3
193+
r1 = int_eq i, 4
194+
if r1 goto L4 else goto L2 :: bool
195+
L2:
196+
r2 = int_eq i, 6
197+
if r2 goto L4 else goto L3 :: bool
198+
L3:
199+
r3 = 0
200+
goto L5
201+
L4:
202+
r3 = 1
203+
L5:
204+
return r3
205+
206+
[case testTupleOperatorNotIn]
207+
def x() -> int:
208+
return 1
209+
def y() -> int:
210+
return 2
211+
def z() -> int:
212+
return 3
213+
214+
def f() -> bool:
215+
return z() not in (x(), y())
216+
[out]
217+
def x():
218+
L0:
219+
return 2
220+
def y():
221+
L0:
222+
return 4
223+
def z():
224+
L0:
225+
return 6
226+
def f():
227+
r0, r1, r2 :: int
228+
r3, r4 :: bit
229+
r5 :: bool
230+
L0:
231+
r0 = z()
232+
r1 = x()
233+
r2 = y()
234+
r3 = int_ne r0, r1
235+
if r3 goto L1 else goto L3 :: bool
236+
L1:
237+
r4 = int_ne r0, r2
238+
if r4 goto L2 else goto L3 :: bool
239+
L2:
240+
r5 = 1
241+
goto L4
242+
L3:
243+
r5 = 0
244+
L4:
245+
return r5
246+
247+
[case testTupleOperatorInFinalTuple]
248+
from typing import Final
249+
250+
tt: Final = (1, 2)
251+
252+
def f(x: int) -> bool:
253+
return x in tt
254+
[out]
255+
def f(x):
256+
x :: int
257+
r0 :: tuple[int, int]
258+
r1 :: bool
259+
r2, r3 :: int
260+
r4, r5 :: bit
261+
r6 :: bool
262+
L0:
263+
r0 = __main__.tt :: static
264+
if is_error(r0) goto L1 else goto L2
265+
L1:
266+
r1 = raise NameError('value for final name "tt" was not set')
267+
unreachable
198268
L2:
199-
r2 = int_eq i, 4
200-
r1 = r2
269+
r2 = r0[0]
270+
r3 = r0[1]
271+
r4 = int_eq x, r2
272+
if r4 goto L5 else goto L3 :: bool
201273
L3:
202-
if r1 goto L4 else goto L5 :: bool
274+
r5 = int_eq x, r3
275+
if r5 goto L5 else goto L4 :: bool
203276
L4:
204-
r3 = r1
277+
r6 = 0
205278
goto L6
206279
L5:
207-
r4 = int_eq i, 6
208-
r3 = r4
280+
r6 = 1
209281
L6:
210-
return r3
282+
return r6
211283

212284
[case testTupleOperatorInFinalTuple]
213285
from typing import Final

0 commit comments

Comments
 (0)