Skip to content

Commit 02e3f33

Browse files
superbobrycopybara-github
authored andcommitted
Another fix to the argument handling in the functools overlay
Turns out * positional arguments aren't necessarily present in `Signature.param_names`; * *args/**kwargs need special handling as well. PiperOrigin-RevId: 823630645
1 parent 2466abd commit 02e3f33

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

pytype/overlays/functools_overlay.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,12 @@ def get_signatures(self) -> Sequence[function.Signature]:
179179
for name, value, _ in sig.iter_args(args):
180180
if value is None:
181181
continue
182-
if sig.param_names.index(name) < sig.posonly_count:
182+
if name == sig.varargs_name or name == sig.kwargs_name:
183+
continue # Nothing to do for packed parameters.
184+
if (
185+
name not in sig.param_names or
186+
sig.param_names.index(name) < sig.posonly_count
187+
):
183188
# The parameter is positional-only, meaning that it cannot be
184189
# overwritten via a keyword argument. Remove it.
185190
bound_param_names.add(name)

pytype/tests/test_attr2.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,33 @@ class Foo:
244244
assert_type(foo.x, int)
245245
""")
246246

247+
def test_partial_with_star_args_as_converter(self):
248+
self.Check("""
249+
import attr
250+
import functools
251+
def f(*args: str) -> str:
252+
return "".join(args)
253+
@attr.s
254+
class Foo:
255+
x = attr.ib(converter=functools.partial(f, "foo", "bar"))
256+
foo = Foo(x=0)
257+
assert_type(foo.x, str)
258+
""")
259+
260+
def test_partial_as_converter_with_factory(self):
261+
# This is a smoke test for signature construction in the functools overlay.
262+
self.Check("""
263+
import collections
264+
import functools
265+
import attr
266+
@attr.s(auto_attribs=True)
267+
class Foo(object):
268+
x = attr.ib(
269+
factory=dict,
270+
converter=functools.partial(collections.defaultdict, lambda: 0),
271+
)
272+
""")
273+
247274
def test_partial_overloaded_as_converter(self):
248275
self.Check("""
249276
import attr

0 commit comments

Comments
 (0)