@@ -701,24 +701,70 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
701701 # x in (...)/[...]
702702 # x not in (...)/[...]
703703 first_op = e .operators [0 ]
704- if (
705- first_op in ["in" , "not in" ]
706- and len (e .operators ) == 1
707- and isinstance (e .operands [1 ], (TupleExpr , ListExpr ))
708- ):
709- items = e .operands [1 ].items
704+ if first_op in ["in" , "not in" ] and len (e .operators ) == 1 :
705+ result = try_specialize_in_expr (builder , first_op , e .operands [0 ], e .operands [1 ], e .line )
706+ if result is not None :
707+ return result
708+
709+ if len (e .operators ) == 1 :
710+ # Special some common simple cases
711+ if first_op in ("is" , "is not" ):
712+ right_expr = e .operands [1 ]
713+ if isinstance (right_expr , NameExpr ) and right_expr .fullname == "builtins.None" :
714+ # Special case 'is None' / 'is not None'.
715+ return translate_is_none (builder , e .operands [0 ], negated = first_op != "is" )
716+ left_expr = e .operands [0 ]
717+ if is_int_rprimitive (builder .node_type (left_expr )):
718+ right_expr = e .operands [1 ]
719+ if is_int_rprimitive (builder .node_type (right_expr )):
720+ if first_op in int_borrow_friendly_op :
721+ borrow_left = is_borrow_friendly_expr (builder , right_expr )
722+ left = builder .accept (left_expr , can_borrow = borrow_left )
723+ right = builder .accept (right_expr , can_borrow = True )
724+ return builder .binary_op (left , right , first_op , e .line )
725+
726+ # TODO: Don't produce an expression when used in conditional context
727+ # All of the trickiness here is due to support for chained conditionals
728+ # (`e1 < e2 > e3`, etc). `e1 < e2 > e3` is approximately equivalent to
729+ # `e1 < e2 and e2 > e3` except that `e2` is only evaluated once.
730+ expr_type = builder .node_type (e )
731+
732+ # go(i, prev) generates code for `ei opi e{i+1} op{i+1} ... en`,
733+ # assuming that prev contains the value of `ei`.
734+ def go (i : int , prev : Value ) -> Value :
735+ if i == len (e .operators ) - 1 :
736+ return transform_basic_comparison (
737+ builder , e .operators [i ], prev , builder .accept (e .operands [i + 1 ]), e .line
738+ )
739+
740+ next = builder .accept (e .operands [i + 1 ])
741+ return builder .builder .shortcircuit_helper (
742+ "and" ,
743+ expr_type ,
744+ lambda : transform_basic_comparison (builder , e .operators [i ], prev , next , e .line ),
745+ lambda : go (i + 1 , next ),
746+ e .line ,
747+ )
748+
749+ return go (0 , builder .accept (e .operands [0 ]))
750+
751+
752+ def try_specialize_in_expr (
753+ builder : IRBuilder , op : str , lhs : Expression , rhs : Expression , line : int
754+ ) -> Value | None :
755+ if isinstance (rhs , (TupleExpr , ListExpr )):
756+ items = rhs .items
710757 n_items = len (items )
711758 # x in y -> x == y[0] or ... or x == y[n]
712759 # x not in y -> x != y[0] and ... and x != y[n]
713760 # 16 is arbitrarily chosen to limit code size
714761 if 1 < n_items < 16 :
715- if e . operators [ 0 ] == "in" :
762+ if op == "in" :
716763 bin_op = "or"
717764 cmp_op = "=="
718765 else :
719766 bin_op = "and"
720767 cmp_op = "!="
721- lhs = e .operands [0 ]
722768 mypy_file = builder .graph ["builtins" ].tree
723769 assert mypy_file is not None
724770 info = mypy_file .names ["bool" ].node
@@ -738,78 +784,34 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
738784 # x in [y]/(y) -> x == y
739785 # x not in [y]/(y) -> x != y
740786 elif n_items == 1 :
741- if e . operators [ 0 ] == "in" :
787+ if op == "in" :
742788 cmp_op = "=="
743789 else :
744790 cmp_op = "!="
745- e .operators = [cmp_op ]
746- e .operands [1 ] = items [0 ]
791+ left = builder .accept (lhs )
792+ right = builder .accept (items [0 ])
793+ return transform_basic_comparison (builder , cmp_op , left , right , line )
747794 # x in []/() -> False
748795 # x not in []/() -> True
749796 elif n_items == 0 :
750- if e . operators [ 0 ] == "in" :
797+ if op == "in" :
751798 return builder .false ()
752799 else :
753800 return builder .true ()
754801
755802 # x in {...}
756803 # x not in {...}
757- if (
758- first_op in ("in" , "not in" )
759- and len (e .operators ) == 1
760- and isinstance (e .operands [1 ], SetExpr )
761- ):
762- set_literal = precompute_set_literal (builder , e .operands [1 ])
804+ if isinstance (rhs , SetExpr ):
805+ set_literal = precompute_set_literal (builder , rhs )
763806 if set_literal is not None :
764- lhs = e .operands [0 ]
765807 result = builder .builder .primitive_op (
766- set_in_op , [builder .accept (lhs ), set_literal ], e . line , bool_rprimitive
808+ set_in_op , [builder .accept (lhs ), set_literal ], line , bool_rprimitive
767809 )
768- if first_op == "not in" :
769- return builder .unary_op (result , "not" , e . line )
810+ if op == "not in" :
811+ return builder .unary_op (result , "not" , line )
770812 return result
771813
772- if len (e .operators ) == 1 :
773- # Special some common simple cases
774- if first_op in ("is" , "is not" ):
775- right_expr = e .operands [1 ]
776- if isinstance (right_expr , NameExpr ) and right_expr .fullname == "builtins.None" :
777- # Special case 'is None' / 'is not None'.
778- return translate_is_none (builder , e .operands [0 ], negated = first_op != "is" )
779- left_expr = e .operands [0 ]
780- if is_int_rprimitive (builder .node_type (left_expr )):
781- right_expr = e .operands [1 ]
782- if is_int_rprimitive (builder .node_type (right_expr )):
783- if first_op in int_borrow_friendly_op :
784- borrow_left = is_borrow_friendly_expr (builder , right_expr )
785- left = builder .accept (left_expr , can_borrow = borrow_left )
786- right = builder .accept (right_expr , can_borrow = True )
787- return builder .binary_op (left , right , first_op , e .line )
788-
789- # TODO: Don't produce an expression when used in conditional context
790- # All of the trickiness here is due to support for chained conditionals
791- # (`e1 < e2 > e3`, etc). `e1 < e2 > e3` is approximately equivalent to
792- # `e1 < e2 and e2 > e3` except that `e2` is only evaluated once.
793- expr_type = builder .node_type (e )
794-
795- # go(i, prev) generates code for `ei opi e{i+1} op{i+1} ... en`,
796- # assuming that prev contains the value of `ei`.
797- def go (i : int , prev : Value ) -> Value :
798- if i == len (e .operators ) - 1 :
799- return transform_basic_comparison (
800- builder , e .operators [i ], prev , builder .accept (e .operands [i + 1 ]), e .line
801- )
802-
803- next = builder .accept (e .operands [i + 1 ])
804- return builder .builder .shortcircuit_helper (
805- "and" ,
806- expr_type ,
807- lambda : transform_basic_comparison (builder , e .operators [i ], prev , next , e .line ),
808- lambda : go (i + 1 , next ),
809- e .line ,
810- )
811-
812- return go (0 , builder .accept (e .operands [0 ]))
814+ return None
813815
814816
815817def translate_is_none (builder : IRBuilder , expr : Expression , negated : bool ) -> Value :
0 commit comments