@@ -701,24 +701,70 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
701
701
# x in (...)/[...]
702
702
# x not in (...)/[...]
703
703
first_op = e .operators [0 ]
704
- if (
705
- first_op in ["in" , "not in" ]
706
- and len (e .operators ) == 1
707
- and isinstance (e .operands [1 ], (TupleExpr , ListExpr ))
708
- ):
709
- items = e .operands [1 ].items
704
+ if first_op in ["in" , "not in" ] and len (e .operators ) == 1 :
705
+ result = try_specialize_in_expr (builder , first_op , e .operands [0 ], e .operands [1 ], e .line )
706
+ if result is not None :
707
+ return result
708
+
709
+ if len (e .operators ) == 1 :
710
+ # Special some common simple cases
711
+ if first_op in ("is" , "is not" ):
712
+ right_expr = e .operands [1 ]
713
+ if isinstance (right_expr , NameExpr ) and right_expr .fullname == "builtins.None" :
714
+ # Special case 'is None' / 'is not None'.
715
+ return translate_is_none (builder , e .operands [0 ], negated = first_op != "is" )
716
+ left_expr = e .operands [0 ]
717
+ if is_int_rprimitive (builder .node_type (left_expr )):
718
+ right_expr = e .operands [1 ]
719
+ if is_int_rprimitive (builder .node_type (right_expr )):
720
+ if first_op in int_borrow_friendly_op :
721
+ borrow_left = is_borrow_friendly_expr (builder , right_expr )
722
+ left = builder .accept (left_expr , can_borrow = borrow_left )
723
+ right = builder .accept (right_expr , can_borrow = True )
724
+ return builder .binary_op (left , right , first_op , e .line )
725
+
726
+ # TODO: Don't produce an expression when used in conditional context
727
+ # All of the trickiness here is due to support for chained conditionals
728
+ # (`e1 < e2 > e3`, etc). `e1 < e2 > e3` is approximately equivalent to
729
+ # `e1 < e2 and e2 > e3` except that `e2` is only evaluated once.
730
+ expr_type = builder .node_type (e )
731
+
732
+ # go(i, prev) generates code for `ei opi e{i+1} op{i+1} ... en`,
733
+ # assuming that prev contains the value of `ei`.
734
+ def go (i : int , prev : Value ) -> Value :
735
+ if i == len (e .operators ) - 1 :
736
+ return transform_basic_comparison (
737
+ builder , e .operators [i ], prev , builder .accept (e .operands [i + 1 ]), e .line
738
+ )
739
+
740
+ next = builder .accept (e .operands [i + 1 ])
741
+ return builder .builder .shortcircuit_helper (
742
+ "and" ,
743
+ expr_type ,
744
+ lambda : transform_basic_comparison (builder , e .operators [i ], prev , next , e .line ),
745
+ lambda : go (i + 1 , next ),
746
+ e .line ,
747
+ )
748
+
749
+ return go (0 , builder .accept (e .operands [0 ]))
750
+
751
+
752
+ def try_specialize_in_expr (
753
+ builder : IRBuilder , op : str , lhs : Expression , rhs : Expression , line : int
754
+ ) -> Value | None :
755
+ if isinstance (rhs , (TupleExpr , ListExpr )):
756
+ items = rhs .items
710
757
n_items = len (items )
711
758
# x in y -> x == y[0] or ... or x == y[n]
712
759
# x not in y -> x != y[0] and ... and x != y[n]
713
760
# 16 is arbitrarily chosen to limit code size
714
761
if 1 < n_items < 16 :
715
- if e . operators [ 0 ] == "in" :
762
+ if op == "in" :
716
763
bin_op = "or"
717
764
cmp_op = "=="
718
765
else :
719
766
bin_op = "and"
720
767
cmp_op = "!="
721
- lhs = e .operands [0 ]
722
768
mypy_file = builder .graph ["builtins" ].tree
723
769
assert mypy_file is not None
724
770
info = mypy_file .names ["bool" ].node
@@ -738,78 +784,34 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
738
784
# x in [y]/(y) -> x == y
739
785
# x not in [y]/(y) -> x != y
740
786
elif n_items == 1 :
741
- if e . operators [ 0 ] == "in" :
787
+ if op == "in" :
742
788
cmp_op = "=="
743
789
else :
744
790
cmp_op = "!="
745
- e .operators = [cmp_op ]
746
- e .operands [1 ] = items [0 ]
791
+ left = builder .accept (lhs )
792
+ right = builder .accept (items [0 ])
793
+ return transform_basic_comparison (builder , cmp_op , left , right , line )
747
794
# x in []/() -> False
748
795
# x not in []/() -> True
749
796
elif n_items == 0 :
750
- if e . operators [ 0 ] == "in" :
797
+ if op == "in" :
751
798
return builder .false ()
752
799
else :
753
800
return builder .true ()
754
801
755
802
# x in {...}
756
803
# x not in {...}
757
- if (
758
- first_op in ("in" , "not in" )
759
- and len (e .operators ) == 1
760
- and isinstance (e .operands [1 ], SetExpr )
761
- ):
762
- set_literal = precompute_set_literal (builder , e .operands [1 ])
804
+ if isinstance (rhs , SetExpr ):
805
+ set_literal = precompute_set_literal (builder , rhs )
763
806
if set_literal is not None :
764
- lhs = e .operands [0 ]
765
807
result = builder .builder .primitive_op (
766
- set_in_op , [builder .accept (lhs ), set_literal ], e . line , bool_rprimitive
808
+ set_in_op , [builder .accept (lhs ), set_literal ], line , bool_rprimitive
767
809
)
768
- if first_op == "not in" :
769
- return builder .unary_op (result , "not" , e . line )
810
+ if op == "not in" :
811
+ return builder .unary_op (result , "not" , line )
770
812
return result
771
813
772
- if len (e .operators ) == 1 :
773
- # Special some common simple cases
774
- if first_op in ("is" , "is not" ):
775
- right_expr = e .operands [1 ]
776
- if isinstance (right_expr , NameExpr ) and right_expr .fullname == "builtins.None" :
777
- # Special case 'is None' / 'is not None'.
778
- return translate_is_none (builder , e .operands [0 ], negated = first_op != "is" )
779
- left_expr = e .operands [0 ]
780
- if is_int_rprimitive (builder .node_type (left_expr )):
781
- right_expr = e .operands [1 ]
782
- if is_int_rprimitive (builder .node_type (right_expr )):
783
- if first_op in int_borrow_friendly_op :
784
- borrow_left = is_borrow_friendly_expr (builder , right_expr )
785
- left = builder .accept (left_expr , can_borrow = borrow_left )
786
- right = builder .accept (right_expr , can_borrow = True )
787
- return builder .binary_op (left , right , first_op , e .line )
788
-
789
- # TODO: Don't produce an expression when used in conditional context
790
- # All of the trickiness here is due to support for chained conditionals
791
- # (`e1 < e2 > e3`, etc). `e1 < e2 > e3` is approximately equivalent to
792
- # `e1 < e2 and e2 > e3` except that `e2` is only evaluated once.
793
- expr_type = builder .node_type (e )
794
-
795
- # go(i, prev) generates code for `ei opi e{i+1} op{i+1} ... en`,
796
- # assuming that prev contains the value of `ei`.
797
- def go (i : int , prev : Value ) -> Value :
798
- if i == len (e .operators ) - 1 :
799
- return transform_basic_comparison (
800
- builder , e .operators [i ], prev , builder .accept (e .operands [i + 1 ]), e .line
801
- )
802
-
803
- next = builder .accept (e .operands [i + 1 ])
804
- return builder .builder .shortcircuit_helper (
805
- "and" ,
806
- expr_type ,
807
- lambda : transform_basic_comparison (builder , e .operators [i ], prev , next , e .line ),
808
- lambda : go (i + 1 , next ),
809
- e .line ,
810
- )
811
-
812
- return go (0 , builder .accept (e .operands [0 ]))
814
+ return None
813
815
814
816
815
817
def translate_is_none (builder : IRBuilder , expr : Expression , negated : bool ) -> Value :
0 commit comments