Skip to content

Commit 6df3f57

Browse files
committed
Improve support for functools.partial of overloaded callable protocol
Resolves #18637 Mypy's behaviour here is not correct (see test case), but this PR makes mypy's behaviour match what it used to be before we added the functools.partial plugin
1 parent 0451880 commit 6df3f57

File tree

2 files changed

+65
-41
lines changed

2 files changed

+65
-41
lines changed

mypy/checker.py

Lines changed: 48 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -701,50 +701,57 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
701701
def extract_callable_type(self, inner_type: Type | None, ctx: Context) -> CallableType | None:
702702
"""Get type as seen by an overload item caller."""
703703
inner_type = get_proper_type(inner_type)
704-
outer_type: CallableType | None = None
705-
if inner_type is not None and not isinstance(inner_type, AnyType):
706-
if isinstance(inner_type, TypeVarLikeType):
707-
inner_type = get_proper_type(inner_type.upper_bound)
708-
if isinstance(inner_type, TypeType):
709-
inner_type = get_proper_type(
710-
self.expr_checker.analyze_type_type_callee(inner_type.item, ctx)
711-
)
704+
outer_type: FunctionLike | None = None
705+
if inner_type is None or isinstance(inner_type, AnyType):
706+
return None
707+
if isinstance(inner_type, TypeVarLikeType):
708+
inner_type = get_proper_type(inner_type.upper_bound)
709+
if isinstance(inner_type, TypeType):
710+
inner_type = get_proper_type(
711+
self.expr_checker.analyze_type_type_callee(inner_type.item, ctx)
712+
)
712713

713-
if isinstance(inner_type, CallableType):
714-
outer_type = inner_type
715-
elif isinstance(inner_type, Instance):
716-
inner_call = get_proper_type(
717-
analyze_member_access(
718-
name="__call__",
719-
typ=inner_type,
720-
context=ctx,
721-
is_lvalue=False,
722-
is_super=False,
723-
is_operator=True,
724-
msg=self.msg,
725-
original_type=inner_type,
726-
chk=self,
727-
)
714+
if isinstance(inner_type, FunctionLike):
715+
outer_type = inner_type
716+
elif isinstance(inner_type, Instance):
717+
inner_call = get_proper_type(
718+
analyze_member_access(
719+
name="__call__",
720+
typ=inner_type,
721+
context=ctx,
722+
is_lvalue=False,
723+
is_super=False,
724+
is_operator=True,
725+
msg=self.msg,
726+
original_type=inner_type,
727+
chk=self,
728728
)
729-
if isinstance(inner_call, CallableType):
730-
outer_type = inner_call
731-
elif isinstance(inner_type, UnionType):
732-
union_type = make_simplified_union(inner_type.items)
733-
if isinstance(union_type, UnionType):
734-
items = []
735-
for item in union_type.items:
736-
callable_item = self.extract_callable_type(item, ctx)
737-
if callable_item is None:
738-
break
739-
items.append(callable_item)
740-
else:
741-
joined_type = get_proper_type(join.join_type_list(items))
742-
if isinstance(joined_type, CallableType):
743-
outer_type = joined_type
729+
)
730+
if isinstance(inner_call, FunctionLike):
731+
outer_type = inner_call
732+
elif isinstance(inner_type, UnionType):
733+
union_type = make_simplified_union(inner_type.items)
734+
if isinstance(union_type, UnionType):
735+
items = []
736+
for item in union_type.items:
737+
callable_item = self.extract_callable_type(item, ctx)
738+
if callable_item is None:
739+
break
740+
items.append(callable_item)
744741
else:
745-
return self.extract_callable_type(union_type, ctx)
746-
if outer_type is None:
747-
self.msg.not_callable(inner_type, ctx)
742+
joined_type = get_proper_type(join.join_type_list(items))
743+
if isinstance(joined_type, FunctionLike):
744+
outer_type = joined_type
745+
else:
746+
return self.extract_callable_type(union_type, ctx)
747+
748+
if outer_type is None:
749+
self.msg.not_callable(inner_type, ctx)
750+
return None
751+
if isinstance(outer_type, Overloaded):
752+
return None
753+
754+
assert isinstance(outer_type, CallableType)
748755
return outer_type
749756

750757
def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:

test-data/unit/check-functools.test

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,3 +640,20 @@ hp = partial(h, 1)
640640
reveal_type(hp(1)) # N: Revealed type is "builtins.int"
641641
hp("a") # E: Argument 1 to "h" has incompatible type "str"; expected "int"
642642
[builtins fixtures/tuple.pyi]
643+
644+
[case testFunctoolsPartialOverloadedCallableProtocol]
645+
from functools import partial
646+
from typing import Callable, Protocol, overload
647+
648+
class P(Protocol):
649+
@overload
650+
def __call__(self, x: int) -> int: ...
651+
@overload
652+
def __call__(self, x: str) -> str: ...
653+
654+
def f(x: P):
655+
reveal_type(partial(x, 1)()) # N: Revealed type is "builtins.int"
656+
657+
# TODO: but this is incorrect, predating the functools.partial plugin
658+
reveal_type(partial(x, "a")()) # N: Revealed type is "builtins.int"
659+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)