Skip to content

Commit 9be2dea

Browse files
superbobrycopybara-github
authored andcommitted
Allowed functools.partial to be used as a converter= in attrs.field aka attr.ib
pytype can now do proper checking of such converter declarations. PiperOrigin-RevId: 820705345
1 parent 34342a6 commit 9be2dea

File tree

3 files changed

+66
-11
lines changed

3 files changed

+66
-11
lines changed

pytype/abstract/function.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,22 @@ def get_signatures(func: "_function_base.Function") -> "list[Signature]":
7272
return []
7373
elif isinstance(func.cls, _abstract.CallableClass):
7474
return [Signature.from_callable(func.cls)]
75+
elif isinstance(func, (_abstract.InterpreterClass, _abstract.PyTDClass)):
76+
if isinstance(func, _abstract.PyTDClass) and "__init__" in func:
77+
func.load_lazy_attribute("__init__")
78+
if (init_var := func.members.get("__init__")) and len(init_var.data) == 1:
79+
sigs = []
80+
for sig in get_signatures(init_var.data[0]):
81+
sig = sig.drop_first_parameter() # drop "self"
82+
sigs.append(
83+
sig._replace(annotations=sig.annotations | {"return": func})
84+
)
85+
return sigs
86+
# The class does not have __init__? Bail out!
87+
# TODO(slebedev): Consider handling __new__ and metaclass.__call__ here.
88+
return [Signature.from_any()]
89+
elif hasattr(func, "get_signatures"):
90+
return func.get_signatures()
7591
else:
7692
unwrapped = abstract_utils.maybe_unwrap_decorated_function(func)
7793
if unwrapped:

pytype/overlays/functools_overlay.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from collections.abc import Mapping, Sequence
5+
from collections.abc import Sequence
66
import threading
77
from typing import Any, TYPE_CHECKING
88

@@ -129,8 +129,8 @@ class BoundPartial(abstract.Instance, mixin.HasSlots):
129129
"""An instance of functools.partial."""
130130

131131
underlying: cfg.Variable
132-
args: Sequence[cfg.Variable]
133-
kwargs: Mapping[str, cfg.Variable]
132+
args: tuple[cfg.Variable, ...]
133+
kwargs: dict[str, cfg.Variable]
134134

135135
def __init__(self, ctx, cls, container=None):
136136
super().__init__(cls, ctx, container)
@@ -139,11 +139,19 @@ def __init__(self, ctx, cls, container=None):
139139
"__call__", NativeFunction("__call__", self.call_slot, self.ctx)
140140
)
141141

142-
@property
143-
def func(self) -> cfg.Variable:
144-
# The ``func`` attribute marks this class as a wrapper for
145-
# ``maybe_unwrap_decorated_function``.
146-
return self.underlying
142+
def get_signatures(self) -> Sequence[function.Signature]:
143+
sigs = []
144+
args = function.Args(posargs=self.args, namedargs=self.kwargs)
145+
for data in self.underlying.data:
146+
for sig in function.get_signatures(data):
147+
# Use the partial arguments as defaults in the signature, making them
148+
# optional but overwritable.
149+
defaults = sig.defaults.copy()
150+
for name, value, _ in sig.iter_args(args):
151+
if value is not None:
152+
defaults[name] = value
153+
sigs.append(sig._replace(defaults=defaults))
154+
return sigs
147155

148156
def call_slot(self, node: cfg.CFGNode, *args, **kwargs):
149157
return function.call_function(

pytype/tests/test_attr2.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,40 @@ def f(x: int) -> str:
226226
@attr.s
227227
class Foo:
228228
x = attr.ib(converter=functools.partial(f))
229-
# We don't yet infer the right type for Foo.x in this case, but we at
230-
# least want to check that constructing a Foo doesn't generate errors.
231-
Foo(x=0)
229+
foo = Foo(x=0)
230+
assert_type(foo.x, str)
231+
""")
232+
233+
def test_partial_overloaded_as_converter(self):
234+
self.Check("""
235+
import attr
236+
import functools
237+
from typing import overload
238+
@overload
239+
def f(x: int, y: int) -> int:
240+
return ''
241+
@overload
242+
def f(x: str, y: int) -> str:
243+
return ''
244+
@attr.s
245+
class Foo:
246+
x = attr.ib(converter=functools.partial(f, 42))
247+
foo = Foo(x=0)
248+
assert_type(foo.x, int)
249+
""")
250+
251+
def test_partial_class_as_converter(self):
252+
self.Check("""
253+
import attr
254+
import functools
255+
class C:
256+
def __init__(self, x: int, y: int) -> None:
257+
self.x = x
258+
@attr.s
259+
class Foo:
260+
x = attr.ib(converter=functools.partial(C, 42))
261+
foo = Foo(x=0)
262+
assert_type(foo.x, C)
232263
""")
233264

234265

0 commit comments

Comments
 (0)