@@ -6413,33 +6413,42 @@ def equality_type_narrowing_helper(
64136413 should_narrow_by_identity = True
64146414 else :
64156415
6416- def is_exactly_literal_type (t : Type ) -> bool :
6417- return isinstance (get_proper_type (t ), LiteralType )
6418-
64196416 def has_no_custom_eq_checks (t : Type ) -> bool :
64206417 return not custom_special_method (
64216418 t , "__eq__" , check_all = False
64226419 ) and not custom_special_method (t , "__ne__" , check_all = False )
64236420
6424- is_valid_target = is_exactly_literal_type
6425- coerce_only_in_literal_context = True
6426-
64276421 expr_types = [operand_types [i ] for i in expr_indices ]
64286422 should_narrow_by_identity = all (
64296423 map (has_no_custom_eq_checks , expr_types )
64306424 ) and not is_ambiguous_mix_of_enums (expr_types )
64316425
6432- if_map : TypeMap = {}
6433- else_map : TypeMap = {}
6434- if should_narrow_by_identity :
6435- if_map , else_map = self .refine_identity_comparison_expression (
6436- operands ,
6437- operand_types ,
6438- expr_indices ,
6439- narrowable_operand_index_to_hash .keys (),
6440- is_valid_target ,
6441- coerce_only_in_literal_context ,
6442- )
6426+ def is_exactly_literal_type_possibly_except_enum (t : Type ) -> bool :
6427+ p_t = get_proper_type (t )
6428+ if isinstance (p_t , LiteralType ):
6429+ if should_narrow_by_identity :
6430+ return True
6431+ else :
6432+ return not p_t .fallback .type .is_enum
6433+ else :
6434+ return False
6435+
6436+ is_valid_target = is_exactly_literal_type_possibly_except_enum
6437+ coerce_only_in_literal_context = True
6438+
6439+ if_map , else_map = self .refine_identity_comparison_expression (
6440+ operands ,
6441+ operand_types ,
6442+ expr_indices ,
6443+ narrowable_operand_index_to_hash .keys (),
6444+ is_valid_target ,
6445+ coerce_only_in_literal_context ,
6446+ )
6447+ if not should_narrow_by_identity :
6448+ # refine_identity_comparison_expression narrows against a single literal
6449+ # -- and we know that literal will only go to the positive branch.
6450+ # This means that the negative branch narrowing is actually correct.
6451+ if_map = {}
64436452
64446453 if if_map == {} and else_map == {}:
64456454 if_map , else_map = self .refine_away_none_in_comparison (
0 commit comments