Skip to content

Commit 638f99b

Browse files
committed
[mypyc] Refactor "in" expression IR transform
1 parent 3fcfcb8 commit 638f99b

File tree

1 file changed

+65
-63
lines changed

1 file changed

+65
-63
lines changed

mypyc/irbuild/expression.py

Lines changed: 65 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -701,24 +701,70 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
701701
# x in (...)/[...]
702702
# x not in (...)/[...]
703703
first_op = e.operators[0]
704-
if (
705-
first_op in ["in", "not in"]
706-
and len(e.operators) == 1
707-
and isinstance(e.operands[1], (TupleExpr, ListExpr))
708-
):
709-
items = e.operands[1].items
704+
if first_op in ["in", "not in"] and len(e.operators) == 1:
705+
result = try_specialize_in_expr(builder, first_op, e.operands[0], e.operands[1], e.line)
706+
if result is not None:
707+
return result
708+
709+
if len(e.operators) == 1:
710+
# Special some common simple cases
711+
if first_op in ("is", "is not"):
712+
right_expr = e.operands[1]
713+
if isinstance(right_expr, NameExpr) and right_expr.fullname == "builtins.None":
714+
# Special case 'is None' / 'is not None'.
715+
return translate_is_none(builder, e.operands[0], negated=first_op != "is")
716+
left_expr = e.operands[0]
717+
if is_int_rprimitive(builder.node_type(left_expr)):
718+
right_expr = e.operands[1]
719+
if is_int_rprimitive(builder.node_type(right_expr)):
720+
if first_op in int_borrow_friendly_op:
721+
borrow_left = is_borrow_friendly_expr(builder, right_expr)
722+
left = builder.accept(left_expr, can_borrow=borrow_left)
723+
right = builder.accept(right_expr, can_borrow=True)
724+
return builder.binary_op(left, right, first_op, e.line)
725+
726+
# TODO: Don't produce an expression when used in conditional context
727+
# All of the trickiness here is due to support for chained conditionals
728+
# (`e1 < e2 > e3`, etc). `e1 < e2 > e3` is approximately equivalent to
729+
# `e1 < e2 and e2 > e3` except that `e2` is only evaluated once.
730+
expr_type = builder.node_type(e)
731+
732+
# go(i, prev) generates code for `ei opi e{i+1} op{i+1} ... en`,
733+
# assuming that prev contains the value of `ei`.
734+
def go(i: int, prev: Value) -> Value:
735+
if i == len(e.operators) - 1:
736+
return transform_basic_comparison(
737+
builder, e.operators[i], prev, builder.accept(e.operands[i + 1]), e.line
738+
)
739+
740+
next = builder.accept(e.operands[i + 1])
741+
return builder.builder.shortcircuit_helper(
742+
"and",
743+
expr_type,
744+
lambda: transform_basic_comparison(builder, e.operators[i], prev, next, e.line),
745+
lambda: go(i + 1, next),
746+
e.line,
747+
)
748+
749+
return go(0, builder.accept(e.operands[0]))
750+
751+
752+
def try_specialize_in_expr(
753+
builder: IRBuilder, op: str, lhs: Expression, rhs: Expression, line: int
754+
) -> Value | None:
755+
if isinstance(rhs, (TupleExpr, ListExpr)):
756+
items = rhs.items
710757
n_items = len(items)
711758
# x in y -> x == y[0] or ... or x == y[n]
712759
# x not in y -> x != y[0] and ... and x != y[n]
713760
# 16 is arbitrarily chosen to limit code size
714761
if 1 < n_items < 16:
715-
if e.operators[0] == "in":
762+
if op == "in":
716763
bin_op = "or"
717764
cmp_op = "=="
718765
else:
719766
bin_op = "and"
720767
cmp_op = "!="
721-
lhs = e.operands[0]
722768
mypy_file = builder.graph["builtins"].tree
723769
assert mypy_file is not None
724770
info = mypy_file.names["bool"].node
@@ -738,78 +784,34 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
738784
# x in [y]/(y) -> x == y
739785
# x not in [y]/(y) -> x != y
740786
elif n_items == 1:
741-
if e.operators[0] == "in":
787+
if op == "in":
742788
cmp_op = "=="
743789
else:
744790
cmp_op = "!="
745-
e.operators = [cmp_op]
746-
e.operands[1] = items[0]
791+
left = builder.accept(lhs)
792+
right = builder.accept(items[0])
793+
return transform_basic_comparison(builder, cmp_op, left, right, line)
747794
# x in []/() -> False
748795
# x not in []/() -> True
749796
elif n_items == 0:
750-
if e.operators[0] == "in":
797+
if op == "in":
751798
return builder.false()
752799
else:
753800
return builder.true()
754801

755802
# x in {...}
756803
# x not in {...}
757-
if (
758-
first_op in ("in", "not in")
759-
and len(e.operators) == 1
760-
and isinstance(e.operands[1], SetExpr)
761-
):
762-
set_literal = precompute_set_literal(builder, e.operands[1])
804+
if isinstance(rhs, SetExpr):
805+
set_literal = precompute_set_literal(builder, rhs)
763806
if set_literal is not None:
764-
lhs = e.operands[0]
765807
result = builder.builder.primitive_op(
766-
set_in_op, [builder.accept(lhs), set_literal], e.line, bool_rprimitive
808+
set_in_op, [builder.accept(lhs), set_literal], line, bool_rprimitive
767809
)
768-
if first_op == "not in":
769-
return builder.unary_op(result, "not", e.line)
810+
if op == "not in":
811+
return builder.unary_op(result, "not", line)
770812
return result
771813

772-
if len(e.operators) == 1:
773-
# Special some common simple cases
774-
if first_op in ("is", "is not"):
775-
right_expr = e.operands[1]
776-
if isinstance(right_expr, NameExpr) and right_expr.fullname == "builtins.None":
777-
# Special case 'is None' / 'is not None'.
778-
return translate_is_none(builder, e.operands[0], negated=first_op != "is")
779-
left_expr = e.operands[0]
780-
if is_int_rprimitive(builder.node_type(left_expr)):
781-
right_expr = e.operands[1]
782-
if is_int_rprimitive(builder.node_type(right_expr)):
783-
if first_op in int_borrow_friendly_op:
784-
borrow_left = is_borrow_friendly_expr(builder, right_expr)
785-
left = builder.accept(left_expr, can_borrow=borrow_left)
786-
right = builder.accept(right_expr, can_borrow=True)
787-
return builder.binary_op(left, right, first_op, e.line)
788-
789-
# TODO: Don't produce an expression when used in conditional context
790-
# All of the trickiness here is due to support for chained conditionals
791-
# (`e1 < e2 > e3`, etc). `e1 < e2 > e3` is approximately equivalent to
792-
# `e1 < e2 and e2 > e3` except that `e2` is only evaluated once.
793-
expr_type = builder.node_type(e)
794-
795-
# go(i, prev) generates code for `ei opi e{i+1} op{i+1} ... en`,
796-
# assuming that prev contains the value of `ei`.
797-
def go(i: int, prev: Value) -> Value:
798-
if i == len(e.operators) - 1:
799-
return transform_basic_comparison(
800-
builder, e.operators[i], prev, builder.accept(e.operands[i + 1]), e.line
801-
)
802-
803-
next = builder.accept(e.operands[i + 1])
804-
return builder.builder.shortcircuit_helper(
805-
"and",
806-
expr_type,
807-
lambda: transform_basic_comparison(builder, e.operators[i], prev, next, e.line),
808-
lambda: go(i + 1, next),
809-
e.line,
810-
)
811-
812-
return go(0, builder.accept(e.operands[0]))
814+
return None
813815

814816

815817
def translate_is_none(builder: IRBuilder, expr: Expression, negated: bool) -> Value:

0 commit comments

Comments
 (0)