|
241 | 241 | from mypy.typevars import fill_typevars, fill_typevars_with_any, has_no_typevars |
242 | 242 | from mypy.util import is_dunder, is_sunder |
243 | 243 | from mypy.visitor import NodeVisitor |
| 244 | +from mypy.types import LiteralType |
244 | 245 |
|
245 | 246 | T = TypeVar("T") |
246 | 247 |
|
@@ -6517,7 +6518,6 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa |
6517 | 6518 | # Step 1: Obtain the types of each operand and whether or not we can |
6518 | 6519 | # narrow their types. (For example, we shouldn't try narrowing the |
6519 | 6520 | # types of literal string or enum expressions). |
6520 | | - |
6521 | 6521 | operands = [collapse_walrus(x) for x in node.operands] |
6522 | 6522 | operand_types = [] |
6523 | 6523 | narrowable_operand_index_to_hash = {} |
@@ -6581,18 +6581,39 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa |
6581 | 6581 |
|
6582 | 6582 | if left_index in narrowable_operand_index_to_hash: |
6583 | 6583 | # We only try and narrow away 'None' for now |
6584 | | - if is_overlapping_none(item_type): |
6585 | | - collection_item_type = get_proper_type(builtin_item_type(iterable_type)) |
6586 | | - if ( |
6587 | | - collection_item_type is not None |
6588 | | - and not is_overlapping_none(collection_item_type) |
6589 | | - and not ( |
6590 | | - isinstance(collection_item_type, Instance) |
6591 | | - and collection_item_type.type.fullname == "builtins.object" |
6592 | | - ) |
6593 | | - and is_overlapping_erased_types(item_type, collection_item_type) |
6594 | | - ): |
6595 | | - if_map[operands[left_index]] = remove_optional(item_type) |
| 6584 | + if is_overlapping_none(item_type): |
| 6585 | + collection_item_type = get_proper_type(builtin_item_type(iterable_type)) |
| 6586 | + if ( |
| 6587 | + collection_item_type is not None |
| 6588 | + and not is_overlapping_none(collection_item_type) |
| 6589 | + and not ( |
| 6590 | + isinstance(collection_item_type, Instance) |
| 6591 | + and collection_item_type.type.fullname == "builtins.object" |
| 6592 | + ) |
| 6593 | + and is_overlapping_erased_types(item_type, collection_item_type) |
| 6594 | + ): |
| 6595 | + if_map[operands[left_index]] = remove_optional(item_type) |
| 6596 | + if_map[operands[left_index]] = remove_optional(item_type) |
| 6597 | + literal_types = [] |
| 6598 | + if isinstance(get_proper_type(iterable_type), TupleType): |
| 6599 | + # Check if this is an enum instance that can be narrowed |
| 6600 | + tuple_type = get_proper_type(iterable_type) |
| 6601 | + for i, item_type in enumerate(tuple_type.items): |
| 6602 | + if isinstance(item_type, Instance): |
| 6603 | + |
| 6604 | + if item_type.type.is_enum: |
| 6605 | + # Enum values in tuples are represented as Instance types, not LiteralType |
| 6606 | + if hasattr(item_type, 'last_known_value') and item_type.last_known_value: |
| 6607 | + # Use the existing literal representation |
| 6608 | + literal_types.append(item_type.last_known_value) |
| 6609 | + else: |
| 6610 | + # using the instance directly |
| 6611 | + literal_types.append(item_type) |
| 6612 | + # If we found enum literals in the tuple, narrow the left operand |
| 6613 | + if literal_types: |
| 6614 | + union_type = make_simplified_union(literal_types) |
| 6615 | + # Applying type narrowing for the true branch of the 'in' check |
| 6616 | + if_map[operands[left_index]] = union_type |
6596 | 6617 |
|
6597 | 6618 | if right_index in narrowable_operand_index_to_hash: |
6598 | 6619 | if_type, else_type = self.conditional_types_for_iterable( |
|
0 commit comments