Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]:
ret: list[Statement] = []
current_overload: list[OverloadPart] = []
current_overload_name: str | None = None
seen_unconditional_func_def = False
last_unconditional_func_def: str | None = None
last_if_stmt: IfStmt | None = None
last_if_overload: Decorator | FuncDef | OverloadedFuncDef | None = None
last_if_stmt_overload_name: str | None = None
Expand All @@ -641,7 +641,7 @@ def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]:
if_overload_name: str | None = None
if_block_with_overload: Block | None = None
if_unknown_truth_value: IfStmt | None = None
if isinstance(stmt, IfStmt) and seen_unconditional_func_def is False:
if isinstance(stmt, IfStmt):
# Check IfStmt block to determine if function overloads can be merged
if_overload_name = self._check_ifstmt_for_overloads(stmt, current_overload_name)
if if_overload_name is not None:
Expand Down Expand Up @@ -669,11 +669,18 @@ def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]:
last_if_unknown_truth_value = None
current_overload.append(stmt)
if isinstance(stmt, FuncDef):
seen_unconditional_func_def = True
# This is, strictly speaking, wrong: there might be a decorated
# implementation. However, it only affects the error message we show:
# ideally it's "already defined", but "implementation must come last"
# is also reasonable.
# TODO: can we get rid of this completely and just always emit
# "implementation must come last" instead?
last_unconditional_func_def = stmt.name
elif (
current_overload_name is not None
and isinstance(stmt, IfStmt)
and if_overload_name == current_overload_name
and last_unconditional_func_def != current_overload_name
):
# IfStmt only contains stmts relevant to current_overload.
# Check if stmts are reachable and add them to current_overload,
Expand Down Expand Up @@ -729,7 +736,7 @@ def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]:
# most of mypy/mypyc assumes that all the functions in an OverloadedFuncDef are
# related, but multiple underscore functions next to each other aren't necessarily
# related
seen_unconditional_func_def = False
last_unconditional_func_def = None
if isinstance(stmt, Decorator) and not unnamed_function(stmt.name):
current_overload = [stmt]
current_overload_name = stmt.name
Expand Down
34 changes: 34 additions & 0 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -6310,6 +6310,40 @@ reveal_type(f12(A())) # N: Revealed type is "__main__.A"

[typing fixtures/typing-medium.pyi]

[case testAdjacentConditionalOverloads]
# flags: --always-true true_alias
from typing import overload

true_alias = True

if true_alias:
@overload
def ham(v: str) -> list[str]: ...

@overload
def ham(v: int) -> list[int]: ...

def ham(v: "int | str") -> "list[str] | list[int]":
return []

if true_alias:
@overload
def spam(v: str) -> str: ...

@overload
def spam(v: int) -> int: ...

def spam(v: "int | str") -> "str | int":
return ""

reveal_type(ham) # N: Revealed type is "Overload(def (v: builtins.str) -> builtins.list[builtins.str], def (v: builtins.int) -> builtins.list[builtins.int])"
reveal_type(spam) # N: Revealed type is "Overload(def (v: builtins.str) -> builtins.str, def (v: builtins.int) -> builtins.int)"

reveal_type(ham("")) # N: Revealed type is "builtins.list[builtins.str]"
reveal_type(ham(0)) # N: Revealed type is "builtins.list[builtins.int]"
reveal_type(spam("")) # N: Revealed type is "builtins.str"
reveal_type(spam(0)) # N: Revealed type is "builtins.int"

[case testOverloadIfUnconditionalFuncDef]
# flags: --always-true True --always-false False
from typing import overload
Expand Down