|
65 | 65 | CallExpr, |
66 | 66 | ClassDef, |
67 | 67 | ComparisonExpr, |
| 68 | + ComplexExpr, |
68 | 69 | Context, |
69 | 70 | ContinueStmt, |
70 | 71 | Decorator, |
71 | 72 | DelStmt, |
| 73 | + DictExpr, |
72 | 74 | EllipsisExpr, |
73 | 75 | Expression, |
74 | 76 | ExpressionStmt, |
|
100 | 102 | RaiseStmt, |
101 | 103 | RefExpr, |
102 | 104 | ReturnStmt, |
| 105 | + SetExpr, |
103 | 106 | StarExpr, |
104 | 107 | Statement, |
105 | 108 | StrExpr, |
@@ -350,6 +353,9 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface): |
350 | 353 | # functions such as open(), etc. |
351 | 354 | plugin: Plugin |
352 | 355 |
|
| 356 | + # A helper state to produce unique temporary names on demand. |
| 357 | + _unique_id: int |
| 358 | + |
353 | 359 | def __init__( |
354 | 360 | self, |
355 | 361 | errors: Errors, |
@@ -413,6 +419,7 @@ def __init__( |
413 | 419 | self, self.msg, self.plugin, per_line_checking_time_ns |
414 | 420 | ) |
415 | 421 | self.pattern_checker = PatternChecker(self, self.msg, self.plugin, options) |
| 422 | + self._unique_id = 0 |
416 | 423 |
|
417 | 424 | @property |
418 | 425 | def type_context(self) -> list[Type | None]: |
@@ -5273,19 +5280,7 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None: |
5273 | 5280 | return |
5274 | 5281 |
|
5275 | 5282 | def visit_match_stmt(self, s: MatchStmt) -> None: |
5276 | | - named_subject: Expression |
5277 | | - if isinstance(s.subject, CallExpr): |
5278 | | - # Create a dummy subject expression to handle cases where a match statement's subject |
5279 | | - # is not a literal value. This lets us correctly narrow types and check exhaustivity |
5280 | | - # This is hack! |
5281 | | - id = s.subject.callee.fullname if isinstance(s.subject.callee, RefExpr) else "" |
5282 | | - name = "dummy-match-" + id |
5283 | | - v = Var(name) |
5284 | | - named_subject = NameExpr(name) |
5285 | | - named_subject.node = v |
5286 | | - else: |
5287 | | - named_subject = s.subject |
5288 | | - |
| 5283 | + named_subject = self._make_named_statement_for_match(s.subject) |
5289 | 5284 | with self.binder.frame_context(can_skip=False, fall_through=0): |
5290 | 5285 | subject_type = get_proper_type(self.expr_checker.accept(s.subject)) |
5291 | 5286 |
|
@@ -5362,6 +5357,38 @@ def visit_match_stmt(self, s: MatchStmt) -> None: |
5362 | 5357 | with self.binder.frame_context(can_skip=False, fall_through=2): |
5363 | 5358 | pass |
5364 | 5359 |
|
| 5360 | + def _make_named_statement_for_match(self, subject: Expression) -> Expression: |
| 5361 | + """Construct a fake NameExpr for inference if a match clause is complex.""" |
| 5362 | + expressions_to_preserve = ( |
| 5363 | + # Already named - we should infer type of it as given |
| 5364 | + NameExpr, |
| 5365 | + AssignmentExpr, |
| 5366 | + # Collection literals defined inline - we want to infer types of variables |
| 5367 | + # included there, not exprs as a whole |
| 5368 | + ListExpr, |
| 5369 | + DictExpr, |
| 5370 | + TupleExpr, |
| 5371 | + SetExpr, |
| 5372 | + # Primitive literals - their type is known, no need to name them |
| 5373 | + IntExpr, |
| 5374 | + StrExpr, |
| 5375 | + BytesExpr, |
| 5376 | + FloatExpr, |
| 5377 | + ComplexExpr, |
| 5378 | + EllipsisExpr, |
| 5379 | + ) |
| 5380 | + if isinstance(subject, expressions_to_preserve): |
| 5381 | + return subject |
| 5382 | + else: |
| 5383 | + # Create a dummy subject expression to handle cases where a match statement's subject |
| 5384 | + # is not a literal value. This lets us correctly narrow types and check exhaustivity |
| 5385 | + # This is hack! |
| 5386 | + name = self.new_unique_dummy_name("match") |
| 5387 | + v = Var(name) |
| 5388 | + named_subject = NameExpr(name) |
| 5389 | + named_subject.node = v |
| 5390 | + return named_subject |
| 5391 | + |
5365 | 5392 | def _get_recursive_sub_patterns_map( |
5366 | 5393 | self, expr: Expression, typ: Type |
5367 | 5394 | ) -> dict[Expression, Type]: |
@@ -7715,6 +7742,12 @@ def warn_deprecated_overload_item( |
7715 | 7742 | if candidate == target: |
7716 | 7743 | self.warn_deprecated(item.func, context) |
7717 | 7744 |
|
| 7745 | + def new_unique_dummy_name(self, namespace: str) -> str: |
| 7746 | + """Generate a name that is guaranteed to be unique for this TypeChecker instance.""" |
| 7747 | + name = f"dummy-{namespace}-{self._unique_id}" |
| 7748 | + self._unique_id += 1 |
| 7749 | + return name |
| 7750 | + |
7718 | 7751 |
|
7719 | 7752 | class CollectArgTypeVarTypes(TypeTraverserVisitor): |
7720 | 7753 | """Collects the non-nested argument types in a set.""" |
|
0 commit comments