Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]:
ret: list[Statement] = []
current_overload: list[OverloadPart] = []
current_overload_name: str | None = None
seen_unconditional_func_def = False
last_unconditional_func_def: str | None = None
last_if_stmt: IfStmt | None = None
last_if_overload: Decorator | FuncDef | OverloadedFuncDef | None = None
last_if_stmt_overload_name: str | None = None
Expand All @@ -641,7 +641,7 @@ def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]:
if_overload_name: str | None = None
if_block_with_overload: Block | None = None
if_unknown_truth_value: IfStmt | None = None
if isinstance(stmt, IfStmt) and seen_unconditional_func_def is False:
if isinstance(stmt, IfStmt):
# Check IfStmt block to determine if function overloads can be merged
if_overload_name = self._check_ifstmt_for_overloads(stmt, current_overload_name)
if if_overload_name is not None:
Expand Down Expand Up @@ -669,11 +669,18 @@ def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]:
last_if_unknown_truth_value = None
current_overload.append(stmt)
if isinstance(stmt, FuncDef):
seen_unconditional_func_def = True
# This is, strictly speaking, wrong: there might be a decorated
# implementation. However, it only affects the error message we show:
# ideally it's "already defined", but "implementation must come last"
# is also reasonable.
# TODO: can we get rid of this completely and just always emit
# "implementation must come last" instead?
last_unconditional_func_def = stmt.name
elif (
current_overload_name is not None
and isinstance(stmt, IfStmt)
and if_overload_name == current_overload_name
and last_unconditional_func_def != current_overload_name
):
# IfStmt only contains stmts relevant to current_overload.
# Check if stmts are reachable and add them to current_overload,
Expand Down Expand Up @@ -729,7 +736,7 @@ def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]:
# most of mypy/mypyc assumes that all the functions in an OverloadedFuncDef are
# related, but multiple underscore functions next to each other aren't necessarily
# related
seen_unconditional_func_def = False
last_unconditional_func_def = None
if isinstance(stmt, Decorator) and not unnamed_function(stmt.name):
current_overload = [stmt]
current_overload_name = stmt.name
Expand Down
25 changes: 25 additions & 0 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -6310,6 +6310,31 @@ reveal_type(f12(A())) # N: Revealed type is "__main__.A"

[typing fixtures/typing-medium.pyi]

[case testAdjacentConditionalOverloads]
# flags: --always-true True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe use some other name than True, since this overlaps with the normal True, which is confusing.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, but I'm not sure it's actually better - all surrounding tests use --always-true True, so now this one is inconsistent which might be surprising or misleading for future readers.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I didn't notice the others. This may be still useful in case future readers use the new approach as an example.

from typing import overload

if True:
@overload
def ham(v: str) -> list[str]: ...

@overload
def ham(v: int) -> list[int]: ...

def ham(v: "int | str") -> "list[str] | list[int]":
return []

if True:
@overload
def spam(v: str) -> list[str]: ...

@overload
def spam(v: int) -> list[int]: ...

def spam(v: "int | str") -> "list[str] | list[int]":
return []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above. Actually this looks identical to ham. Was this intentional?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it was (or rather didn't matter) as the only thing I was testing here was that two adjacent conditional overloads are interpreted correctly. But your other comment raises a good point, modified accordingly.



[case testOverloadIfUnconditionalFuncDef]
# flags: --always-true True --always-false False
from typing import overload
Expand Down