Skip to content

Commit 6817b4a

Browse files
committed
Allow adjacent conditionally-defined overloads of different functions
1 parent bd1f51a commit 6817b4a

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

mypy/fastparse.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]:
629629
ret: list[Statement] = []
630630
current_overload: list[OverloadPart] = []
631631
current_overload_name: str | None = None
632-
seen_unconditional_func_def = False
632+
last_unconditional_func_def: str | None = None
633633
last_if_stmt: IfStmt | None = None
634634
last_if_overload: Decorator | FuncDef | OverloadedFuncDef | None = None
635635
last_if_stmt_overload_name: str | None = None
@@ -639,7 +639,7 @@ def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]:
639639
if_overload_name: str | None = None
640640
if_block_with_overload: Block | None = None
641641
if_unknown_truth_value: IfStmt | None = None
642-
if isinstance(stmt, IfStmt) and seen_unconditional_func_def is False:
642+
if isinstance(stmt, IfStmt):
643643
# Check IfStmt block to determine if function overloads can be merged
644644
if_overload_name = self._check_ifstmt_for_overloads(stmt, current_overload_name)
645645
if if_overload_name is not None:
@@ -667,11 +667,18 @@ def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]:
667667
last_if_unknown_truth_value = None
668668
current_overload.append(stmt)
669669
if isinstance(stmt, FuncDef):
670-
seen_unconditional_func_def = True
670+
# This is, strictly speaking, wrong: there might be a decorated
671+
# implementation. However, it only affects the error message we show:
672+
# ideally it's "already defined", but "implementation must come last"
673+
# is also reasonable.
674+
# TODO: can we get rid of this completely and just always emit
675+
# "implementation must come last" instead?
676+
last_unconditional_func_def = stmt.name
671677
elif (
672678
current_overload_name is not None
673679
and isinstance(stmt, IfStmt)
674680
and if_overload_name == current_overload_name
681+
and last_unconditional_func_def != current_overload_name
675682
):
676683
# IfStmt only contains stmts relevant to current_overload.
677684
# Check if stmts are reachable and add them to current_overload,
@@ -727,7 +734,7 @@ def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]:
727734
# most of mypy/mypyc assumes that all the functions in an OverloadedFuncDef are
728735
# related, but multiple underscore functions next to each other aren't necessarily
729736
# related
730-
seen_unconditional_func_def = False
737+
last_unconditional_func_def = None
731738
if isinstance(stmt, Decorator) and not unnamed_function(stmt.name):
732739
current_overload = [stmt]
733740
current_overload_name = stmt.name

test-data/unit/check-overloading.test

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6310,6 +6310,31 @@ reveal_type(f12(A())) # N: Revealed type is "__main__.A"
63106310

63116311
[typing fixtures/typing-medium.pyi]
63126312

6313+
[case testAdjacentConditionalOverloads]
6314+
# flags: --always-true True
6315+
from typing import overload
6316+
6317+
if True:
6318+
@overload
6319+
def ham(v: str) -> list[str]: ...
6320+
6321+
@overload
6322+
def ham(v: int) -> list[int]: ...
6323+
6324+
def ham(v: "int | str") -> "list[str] | list[int]":
6325+
return []
6326+
6327+
if True:
6328+
@overload
6329+
def spam(v: str) -> list[str]: ...
6330+
6331+
@overload
6332+
def spam(v: int) -> list[int]: ...
6333+
6334+
def spam(v: "int | str") -> "list[str] | list[int]":
6335+
return []
6336+
6337+
63136338
[case testOverloadIfUnconditionalFuncDef]
63146339
# flags: --always-true True --always-false False
63156340
from typing import overload

0 commit comments

Comments
 (0)