Skip to content

Commit abf61fb

Browse files
authored
[mypyc] Refactor building IR for "in" against a tuple (#19679)
This covers `x in (a, b)` and `x in [a, b]`. Also add an irbuild test for a use case that is currently inefficient. This is in preparation for improving IR building for "in" against tuple.
1 parent 3fcfcb8 commit abf61fb

File tree

2 files changed

+95
-63
lines changed

2 files changed

+95
-63
lines changed

mypyc/irbuild/expression.py

Lines changed: 65 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -701,24 +701,70 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
701701
# x in (...)/[...]
702702
# x not in (...)/[...]
703703
first_op = e.operators[0]
704-
if (
705-
first_op in ["in", "not in"]
706-
and len(e.operators) == 1
707-
and isinstance(e.operands[1], (TupleExpr, ListExpr))
708-
):
709-
items = e.operands[1].items
704+
if first_op in ["in", "not in"] and len(e.operators) == 1:
705+
result = try_specialize_in_expr(builder, first_op, e.operands[0], e.operands[1], e.line)
706+
if result is not None:
707+
return result
708+
709+
if len(e.operators) == 1:
710+
# Special some common simple cases
711+
if first_op in ("is", "is not"):
712+
right_expr = e.operands[1]
713+
if isinstance(right_expr, NameExpr) and right_expr.fullname == "builtins.None":
714+
# Special case 'is None' / 'is not None'.
715+
return translate_is_none(builder, e.operands[0], negated=first_op != "is")
716+
left_expr = e.operands[0]
717+
if is_int_rprimitive(builder.node_type(left_expr)):
718+
right_expr = e.operands[1]
719+
if is_int_rprimitive(builder.node_type(right_expr)):
720+
if first_op in int_borrow_friendly_op:
721+
borrow_left = is_borrow_friendly_expr(builder, right_expr)
722+
left = builder.accept(left_expr, can_borrow=borrow_left)
723+
right = builder.accept(right_expr, can_borrow=True)
724+
return builder.binary_op(left, right, first_op, e.line)
725+
726+
# TODO: Don't produce an expression when used in conditional context
727+
# All of the trickiness here is due to support for chained conditionals
728+
# (`e1 < e2 > e3`, etc). `e1 < e2 > e3` is approximately equivalent to
729+
# `e1 < e2 and e2 > e3` except that `e2` is only evaluated once.
730+
expr_type = builder.node_type(e)
731+
732+
# go(i, prev) generates code for `ei opi e{i+1} op{i+1} ... en`,
733+
# assuming that prev contains the value of `ei`.
734+
def go(i: int, prev: Value) -> Value:
735+
if i == len(e.operators) - 1:
736+
return transform_basic_comparison(
737+
builder, e.operators[i], prev, builder.accept(e.operands[i + 1]), e.line
738+
)
739+
740+
next = builder.accept(e.operands[i + 1])
741+
return builder.builder.shortcircuit_helper(
742+
"and",
743+
expr_type,
744+
lambda: transform_basic_comparison(builder, e.operators[i], prev, next, e.line),
745+
lambda: go(i + 1, next),
746+
e.line,
747+
)
748+
749+
return go(0, builder.accept(e.operands[0]))
750+
751+
752+
def try_specialize_in_expr(
753+
builder: IRBuilder, op: str, lhs: Expression, rhs: Expression, line: int
754+
) -> Value | None:
755+
if isinstance(rhs, (TupleExpr, ListExpr)):
756+
items = rhs.items
710757
n_items = len(items)
711758
# x in y -> x == y[0] or ... or x == y[n]
712759
# x not in y -> x != y[0] and ... and x != y[n]
713760
# 16 is arbitrarily chosen to limit code size
714761
if 1 < n_items < 16:
715-
if e.operators[0] == "in":
762+
if op == "in":
716763
bin_op = "or"
717764
cmp_op = "=="
718765
else:
719766
bin_op = "and"
720767
cmp_op = "!="
721-
lhs = e.operands[0]
722768
mypy_file = builder.graph["builtins"].tree
723769
assert mypy_file is not None
724770
info = mypy_file.names["bool"].node
@@ -738,78 +784,34 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
738784
# x in [y]/(y) -> x == y
739785
# x not in [y]/(y) -> x != y
740786
elif n_items == 1:
741-
if e.operators[0] == "in":
787+
if op == "in":
742788
cmp_op = "=="
743789
else:
744790
cmp_op = "!="
745-
e.operators = [cmp_op]
746-
e.operands[1] = items[0]
791+
left = builder.accept(lhs)
792+
right = builder.accept(items[0])
793+
return transform_basic_comparison(builder, cmp_op, left, right, line)
747794
# x in []/() -> False
748795
# x not in []/() -> True
749796
elif n_items == 0:
750-
if e.operators[0] == "in":
797+
if op == "in":
751798
return builder.false()
752799
else:
753800
return builder.true()
754801

755802
# x in {...}
756803
# x not in {...}
757-
if (
758-
first_op in ("in", "not in")
759-
and len(e.operators) == 1
760-
and isinstance(e.operands[1], SetExpr)
761-
):
762-
set_literal = precompute_set_literal(builder, e.operands[1])
804+
if isinstance(rhs, SetExpr):
805+
set_literal = precompute_set_literal(builder, rhs)
763806
if set_literal is not None:
764-
lhs = e.operands[0]
765807
result = builder.builder.primitive_op(
766-
set_in_op, [builder.accept(lhs), set_literal], e.line, bool_rprimitive
808+
set_in_op, [builder.accept(lhs), set_literal], line, bool_rprimitive
767809
)
768-
if first_op == "not in":
769-
return builder.unary_op(result, "not", e.line)
810+
if op == "not in":
811+
return builder.unary_op(result, "not", line)
770812
return result
771813

772-
if len(e.operators) == 1:
773-
# Special some common simple cases
774-
if first_op in ("is", "is not"):
775-
right_expr = e.operands[1]
776-
if isinstance(right_expr, NameExpr) and right_expr.fullname == "builtins.None":
777-
# Special case 'is None' / 'is not None'.
778-
return translate_is_none(builder, e.operands[0], negated=first_op != "is")
779-
left_expr = e.operands[0]
780-
if is_int_rprimitive(builder.node_type(left_expr)):
781-
right_expr = e.operands[1]
782-
if is_int_rprimitive(builder.node_type(right_expr)):
783-
if first_op in int_borrow_friendly_op:
784-
borrow_left = is_borrow_friendly_expr(builder, right_expr)
785-
left = builder.accept(left_expr, can_borrow=borrow_left)
786-
right = builder.accept(right_expr, can_borrow=True)
787-
return builder.binary_op(left, right, first_op, e.line)
788-
789-
# TODO: Don't produce an expression when used in conditional context
790-
# All of the trickiness here is due to support for chained conditionals
791-
# (`e1 < e2 > e3`, etc). `e1 < e2 > e3` is approximately equivalent to
792-
# `e1 < e2 and e2 > e3` except that `e2` is only evaluated once.
793-
expr_type = builder.node_type(e)
794-
795-
# go(i, prev) generates code for `ei opi e{i+1} op{i+1} ... en`,
796-
# assuming that prev contains the value of `ei`.
797-
def go(i: int, prev: Value) -> Value:
798-
if i == len(e.operators) - 1:
799-
return transform_basic_comparison(
800-
builder, e.operators[i], prev, builder.accept(e.operands[i + 1]), e.line
801-
)
802-
803-
next = builder.accept(e.operands[i + 1])
804-
return builder.builder.shortcircuit_helper(
805-
"and",
806-
expr_type,
807-
lambda: transform_basic_comparison(builder, e.operators[i], prev, next, e.line),
808-
lambda: go(i + 1, next),
809-
e.line,
810-
)
811-
812-
return go(0, builder.accept(e.operands[0]))
814+
return None
813815

814816

815817
def translate_is_none(builder: IRBuilder, expr: Expression, negated: bool) -> Value:

mypyc/test-data/irbuild-tuple.test

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,36 @@ L5:
209209
L6:
210210
return r3
211211

212+
[case testTupleOperatorInFinalTuple]
213+
from typing import Final
214+
215+
tt: Final = (1, 2)
216+
217+
def f(x: int) -> bool:
218+
return x in tt
219+
[out]
220+
def f(x):
221+
x :: int
222+
r0 :: tuple[int, int]
223+
r1 :: bool
224+
r2, r3 :: object
225+
r4 :: i32
226+
r5 :: bit
227+
r6 :: bool
228+
L0:
229+
r0 = __main__.tt :: static
230+
if is_error(r0) goto L1 else goto L2
231+
L1:
232+
r1 = raise NameError('value for final name "tt" was not set')
233+
unreachable
234+
L2:
235+
r2 = box(int, x)
236+
r3 = box(tuple[int, int], r0)
237+
r4 = PySequence_Contains(r3, r2)
238+
r5 = r4 >= 0 :: signed
239+
r6 = truncate r4: i32 to builtins.bool
240+
return r6
241+
212242
[case testTupleBuiltFromList]
213243
def f(val: int) -> bool:
214244
return val % 2 == 0

0 commit comments

Comments
 (0)