@@ -7994,20 +7994,47 @@ def conditional_types(
7994
7994
) -> tuple [Type | None , Type | None ]:
7995
7995
"""Takes in the current type and a proposed type of an expression.
7996
7996
7997
- Returns a 2-tuple: The first element is the proposed type, if the expression
7998
- can be the proposed type. The second element is the type it would hold
7999
- if it was not the proposed type, if any. UninhabitedType means unreachable.
8000
- None means no new information can be inferred. If default is set it is returned
8001
- instead."""
7997
+ Returns a 2-tuple:
7998
+ The first element is the proposed type, if the expression can be the proposed type.
7999
+ The second element is the type it would hold if it was not the proposed type, if any.
8000
+ UninhabitedType means unreachable.
8001
+ None means no new information can be inferred.
8002
+ If default is set it is returned instead.
8003
+ """
8004
+ if proposed_type_ranges and len (proposed_type_ranges ) == 1 :
8005
+ # expand e.g. bool -> Literal[True] | Literal[False]
8006
+ target = proposed_type_ranges [0 ].item
8007
+ target = get_proper_type (target )
8008
+ if isinstance (target , LiteralType ) and (
8009
+ target .is_enum_literal () or isinstance (target .value , bool )
8010
+ ):
8011
+ enum_name = target .fallback .type .fullname
8012
+ current_type = try_expanding_sum_type_to_union (current_type , enum_name )
8013
+
8014
+ current_type = get_proper_type (current_type )
8015
+ if isinstance (current_type , UnionType ) and (default == current_type ):
8016
+ # factorize over union types
8017
+ # if we try to narrow A|B to C, we instead narrow A to C and B to C, and
8018
+ # return the union of the results
8019
+ result : list [tuple [Type | None , Type | None ]] = [
8020
+ conditional_types (
8021
+ union_item ,
8022
+ proposed_type_ranges ,
8023
+ default = union_item ,
8024
+ consider_runtime_isinstance = consider_runtime_isinstance ,
8025
+ )
8026
+ for union_item in get_proper_types (current_type .items )
8027
+ ]
8028
+ # separate list of tuples into two lists
8029
+ yes_types , no_types = zip (* result )
8030
+ yes_type = make_simplified_union ([t for t in yes_types if t is not None ])
8031
+ no_type = restrict_subtype_away (
8032
+ current_type , yes_type , consider_runtime_isinstance = consider_runtime_isinstance
8033
+ )
8034
+
8035
+ return yes_type , no_type
8036
+
8002
8037
if proposed_type_ranges :
8003
- if len (proposed_type_ranges ) == 1 :
8004
- target = proposed_type_ranges [0 ].item
8005
- target = get_proper_type (target )
8006
- if isinstance (target , LiteralType ) and (
8007
- target .is_enum_literal () or isinstance (target .value , bool )
8008
- ):
8009
- enum_name = target .fallback .type .fullname
8010
- current_type = try_expanding_sum_type_to_union (current_type , enum_name )
8011
8038
proposed_items = [type_range .item for type_range in proposed_type_ranges ]
8012
8039
proposed_type = make_simplified_union (proposed_items )
8013
8040
if isinstance (proposed_type , AnyType ):
0 commit comments