Skip to content

Commit 095df17

Browse files
authored
Allow adjacent conditionally-defined overloads (#19042)
Fixes #19015. Fixes #17521.
1 parent 82e0eb6 commit 095df17

File tree

2 files changed

+45
-4
lines changed

2 files changed

+45
-4
lines changed

mypy/fastparse.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]:
631631
ret: list[Statement] = []
632632
current_overload: list[OverloadPart] = []
633633
current_overload_name: str | None = None
634-
seen_unconditional_func_def = False
634+
last_unconditional_func_def: str | None = None
635635
last_if_stmt: IfStmt | None = None
636636
last_if_overload: Decorator | FuncDef | OverloadedFuncDef | None = None
637637
last_if_stmt_overload_name: str | None = None
@@ -641,7 +641,7 @@ def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]:
641641
if_overload_name: str | None = None
642642
if_block_with_overload: Block | None = None
643643
if_unknown_truth_value: IfStmt | None = None
644-
if isinstance(stmt, IfStmt) and seen_unconditional_func_def is False:
644+
if isinstance(stmt, IfStmt):
645645
# Check IfStmt block to determine if function overloads can be merged
646646
if_overload_name = self._check_ifstmt_for_overloads(stmt, current_overload_name)
647647
if if_overload_name is not None:
@@ -669,11 +669,18 @@ def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]:
669669
last_if_unknown_truth_value = None
670670
current_overload.append(stmt)
671671
if isinstance(stmt, FuncDef):
672-
seen_unconditional_func_def = True
672+
# This is, strictly speaking, wrong: there might be a decorated
673+
# implementation. However, it only affects the error message we show:
674+
# ideally it's "already defined", but "implementation must come last"
675+
# is also reasonable.
676+
# TODO: can we get rid of this completely and just always emit
677+
# "implementation must come last" instead?
678+
last_unconditional_func_def = stmt.name
673679
elif (
674680
current_overload_name is not None
675681
and isinstance(stmt, IfStmt)
676682
and if_overload_name == current_overload_name
683+
and last_unconditional_func_def != current_overload_name
677684
):
678685
# IfStmt only contains stmts relevant to current_overload.
679686
# Check if stmts are reachable and add them to current_overload,
@@ -729,7 +736,7 @@ def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]:
729736
# most of mypy/mypyc assumes that all the functions in an OverloadedFuncDef are
730737
# related, but multiple underscore functions next to each other aren't necessarily
731738
# related
732-
seen_unconditional_func_def = False
739+
last_unconditional_func_def = None
733740
if isinstance(stmt, Decorator) and not unnamed_function(stmt.name):
734741
current_overload = [stmt]
735742
current_overload_name = stmt.name

test-data/unit/check-overloading.test

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6310,6 +6310,40 @@ reveal_type(f12(A())) # N: Revealed type is "__main__.A"
63106310

63116311
[typing fixtures/typing-medium.pyi]
63126312

6313+
[case testAdjacentConditionalOverloads]
6314+
# flags: --always-true true_alias
6315+
from typing import overload
6316+
6317+
true_alias = True
6318+
6319+
if true_alias:
6320+
@overload
6321+
def ham(v: str) -> list[str]: ...
6322+
6323+
@overload
6324+
def ham(v: int) -> list[int]: ...
6325+
6326+
def ham(v: "int | str") -> "list[str] | list[int]":
6327+
return []
6328+
6329+
if true_alias:
6330+
@overload
6331+
def spam(v: str) -> str: ...
6332+
6333+
@overload
6334+
def spam(v: int) -> int: ...
6335+
6336+
def spam(v: "int | str") -> "str | int":
6337+
return ""
6338+
6339+
reveal_type(ham) # N: Revealed type is "Overload(def (v: builtins.str) -> builtins.list[builtins.str], def (v: builtins.int) -> builtins.list[builtins.int])"
6340+
reveal_type(spam) # N: Revealed type is "Overload(def (v: builtins.str) -> builtins.str, def (v: builtins.int) -> builtins.int)"
6341+
6342+
reveal_type(ham("")) # N: Revealed type is "builtins.list[builtins.str]"
6343+
reveal_type(ham(0)) # N: Revealed type is "builtins.list[builtins.int]"
6344+
reveal_type(spam("")) # N: Revealed type is "builtins.str"
6345+
reveal_type(spam(0)) # N: Revealed type is "builtins.int"
6346+
63136347
[case testOverloadIfUnconditionalFuncDef]
63146348
# flags: --always-true True --always-false False
63156349
from typing import overload

0 commit comments

Comments
 (0)