@@ -5983,121 +5983,10 @@ def find_isinstance_check_helper(
59835983 ),
59845984 )
59855985 elif isinstance (node , ComparisonExpr ):
5986- # Step 1: Obtain the types of each operand and whether or not we can
5987- # narrow their types. (For example, we shouldn't try narrowing the
5988- # types of literal string or enum expressions).
5989-
5990- operands = [collapse_walrus (x ) for x in node .operands ]
5991- operand_types = []
5992- narrowable_operand_index_to_hash = {}
5993- for i , expr in enumerate (operands ):
5994- if not self .has_type (expr ):
5995- return {}, {}
5996- expr_type = self .lookup_type (expr )
5997- operand_types .append (expr_type )
5998-
5999- if (
6000- literal (expr ) == LITERAL_TYPE
6001- and not is_literal_none (expr )
6002- and not self .is_literal_enum (expr )
6003- ):
6004- h = literal_hash (expr )
6005- if h is not None :
6006- narrowable_operand_index_to_hash [i ] = h
6007-
6008- # Step 2: Group operands chained by either the 'is' or '==' operands
6009- # together. For all other operands, we keep them in groups of size 2.
6010- # So the expression:
6011- #
6012- # x0 == x1 == x2 < x3 < x4 is x5 is x6 is not x7 is not x8
6013- #
6014- # ...is converted into the simplified operator list:
6015- #
6016- # [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]),
6017- # ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])]
6018- #
6019- # We group identity/equality expressions so we can propagate information
6020- # we discover about one operand across the entire chain. We don't bother
6021- # handling 'is not' and '!=' chains in a special way: those are very rare
6022- # in practice.
6023-
6024- simplified_operator_list = group_comparison_operands (
6025- node .pairwise (), narrowable_operand_index_to_hash , {"==" , "is" }
6026- )
6027-
6028- # Step 3: Analyze each group and infer more precise type maps for each
6029- # assignable operand, if possible. We combine these type maps together
6030- # in the final step.
6031-
6032- partial_type_maps = []
6033- for operator , expr_indices in simplified_operator_list :
6034- if operator in {"is" , "is not" , "==" , "!=" }:
6035- if_map , else_map = self .equality_type_narrowing_helper (
6036- node ,
6037- operator ,
6038- operands ,
6039- operand_types ,
6040- expr_indices ,
6041- narrowable_operand_index_to_hash ,
6042- )
6043- elif operator in {"in" , "not in" }:
6044- assert len (expr_indices ) == 2
6045- left_index , right_index = expr_indices
6046- item_type = operand_types [left_index ]
6047- iterable_type = operand_types [right_index ]
6048-
6049- if_map , else_map = {}, {}
6050-
6051- if left_index in narrowable_operand_index_to_hash :
6052- # We only try and narrow away 'None' for now
6053- if is_overlapping_none (item_type ):
6054- collection_item_type = get_proper_type (
6055- builtin_item_type (iterable_type )
6056- )
6057- if (
6058- collection_item_type is not None
6059- and not is_overlapping_none (collection_item_type )
6060- and not (
6061- isinstance (collection_item_type , Instance )
6062- and collection_item_type .type .fullname == "builtins.object"
6063- )
6064- and is_overlapping_erased_types (item_type , collection_item_type )
6065- ):
6066- if_map [operands [left_index ]] = remove_optional (item_type )
6067-
6068- if right_index in narrowable_operand_index_to_hash :
6069- if_type , else_type = self .conditional_types_for_iterable (
6070- item_type , iterable_type
6071- )
6072- expr = operands [right_index ]
6073- if if_type is None :
6074- if_map = None
6075- else :
6076- if_map [expr ] = if_type
6077- if else_type is None :
6078- else_map = None
6079- else :
6080- else_map [expr ] = else_type
6081-
6082- else :
6083- if_map = {}
6084- else_map = {}
6085-
6086- if operator in {"is not" , "!=" , "not in" }:
6087- if_map , else_map = else_map , if_map
6088-
6089- partial_type_maps .append ((if_map , else_map ))
6090-
6091- # If we have found non-trivial restrictions from the regular comparisons,
6092- # then return soon. Otherwise try to infer restrictions involving `len(x)`.
6093- # TODO: support regular and len() narrowing in the same chain.
6094- if any (m != ({}, {}) for m in partial_type_maps ):
6095- return reduce_conditional_maps (partial_type_maps )
6096- else :
6097- # Use meet for `and` maps to get correct results for chained checks
6098- # like `if 1 < len(x) < 4: ...`
6099- return reduce_conditional_maps (self .find_tuple_len_narrowing (node ), use_meet = True )
5986+ return self .comparison_type_narrowing_helper (node )
61005987 elif isinstance (node , AssignmentExpr ):
5988+ if_map : dict [Expression , Type ] | None
5989+ else_map : dict [Expression , Type ] | None
61015990 if_map = {}
61025991 else_map = {}
61035992
@@ -6184,6 +6073,121 @@ def find_isinstance_check_helper(
61846073 else_map = {node : else_type } if not isinstance (else_type , UninhabitedType ) else None
61856074 return if_map , else_map
61866075
6076+ def comparison_type_narrowing_helper (self , node : ComparisonExpr ) -> tuple [TypeMap , TypeMap ]:
6077+ """Infer type narrowing from a comparison expression."""
6078+ # Step 1: Obtain the types of each operand and whether or not we can
6079+ # narrow their types. (For example, we shouldn't try narrowing the
6080+ # types of literal string or enum expressions).
6081+
6082+ operands = [collapse_walrus (x ) for x in node .operands ]
6083+ operand_types = []
6084+ narrowable_operand_index_to_hash = {}
6085+ for i , expr in enumerate (operands ):
6086+ if not self .has_type (expr ):
6087+ return {}, {}
6088+ expr_type = self .lookup_type (expr )
6089+ operand_types .append (expr_type )
6090+
6091+ if (
6092+ literal (expr ) == LITERAL_TYPE
6093+ and not is_literal_none (expr )
6094+ and not self .is_literal_enum (expr )
6095+ ):
6096+ h = literal_hash (expr )
6097+ if h is not None :
6098+ narrowable_operand_index_to_hash [i ] = h
6099+
6100+ # Step 2: Group operands chained by either the 'is' or '==' operands
6101+ # together. For all other operands, we keep them in groups of size 2.
6102+ # So the expression:
6103+ #
6104+ # x0 == x1 == x2 < x3 < x4 is x5 is x6 is not x7 is not x8
6105+ #
6106+ # ...is converted into the simplified operator list:
6107+ #
6108+ # [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]),
6109+ # ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])]
6110+ #
6111+ # We group identity/equality expressions so we can propagate information
6112+ # we discover about one operand across the entire chain. We don't bother
6113+ # handling 'is not' and '!=' chains in a special way: those are very rare
6114+ # in practice.
6115+
6116+ simplified_operator_list = group_comparison_operands (
6117+ node .pairwise (), narrowable_operand_index_to_hash , {"==" , "is" }
6118+ )
6119+
6120+ # Step 3: Analyze each group and infer more precise type maps for each
6121+ # assignable operand, if possible. We combine these type maps together
6122+ # in the final step.
6123+
6124+ partial_type_maps = []
6125+ for operator , expr_indices in simplified_operator_list :
6126+ if operator in {"is" , "is not" , "==" , "!=" }:
6127+ if_map , else_map = self .equality_type_narrowing_helper (
6128+ node ,
6129+ operator ,
6130+ operands ,
6131+ operand_types ,
6132+ expr_indices ,
6133+ narrowable_operand_index_to_hash ,
6134+ )
6135+ elif operator in {"in" , "not in" }:
6136+ assert len (expr_indices ) == 2
6137+ left_index , right_index = expr_indices
6138+ item_type = operand_types [left_index ]
6139+ iterable_type = operand_types [right_index ]
6140+
6141+ if_map , else_map = {}, {}
6142+
6143+ if left_index in narrowable_operand_index_to_hash :
6144+ # We only try and narrow away 'None' for now
6145+ if is_overlapping_none (item_type ):
6146+ collection_item_type = get_proper_type (builtin_item_type (iterable_type ))
6147+ if (
6148+ collection_item_type is not None
6149+ and not is_overlapping_none (collection_item_type )
6150+ and not (
6151+ isinstance (collection_item_type , Instance )
6152+ and collection_item_type .type .fullname == "builtins.object"
6153+ )
6154+ and is_overlapping_erased_types (item_type , collection_item_type )
6155+ ):
6156+ if_map [operands [left_index ]] = remove_optional (item_type )
6157+
6158+ if right_index in narrowable_operand_index_to_hash :
6159+ if_type , else_type = self .conditional_types_for_iterable (
6160+ item_type , iterable_type
6161+ )
6162+ expr = operands [right_index ]
6163+ if if_type is None :
6164+ if_map = None
6165+ else :
6166+ if_map [expr ] = if_type
6167+ if else_type is None :
6168+ else_map = None
6169+ else :
6170+ else_map [expr ] = else_type
6171+
6172+ else :
6173+ if_map = {}
6174+ else_map = {}
6175+
6176+ if operator in {"is not" , "!=" , "not in" }:
6177+ if_map , else_map = else_map , if_map
6178+
6179+ partial_type_maps .append ((if_map , else_map ))
6180+
6181+ # If we have found non-trivial restrictions from the regular comparisons,
6182+ # then return soon. Otherwise try to infer restrictions involving `len(x)`.
6183+ # TODO: support regular and len() narrowing in the same chain.
6184+ if any (m != ({}, {}) for m in partial_type_maps ):
6185+ return reduce_conditional_maps (partial_type_maps )
6186+ else :
6187+ # Use meet for `and` maps to get correct results for chained checks
6188+ # like `if 1 < len(x) < 4: ...`
6189+ return reduce_conditional_maps (self .find_tuple_len_narrowing (node ), use_meet = True )
6190+
61876191 def equality_type_narrowing_helper (
61886192 self ,
61896193 node : ComparisonExpr ,
0 commit comments