Skip to content

Commit 3879e23

Browse files
superbobrycopybara-github
authored andcommitted
Fixed an edge-case in unpacked argument matching
PiperOrigin-RevId: 824947727
1 parent 02e3f33 commit 3879e23

File tree

3 files changed

+43
-3
lines changed

3 files changed

+43
-3
lines changed

pytype/matcher.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,7 +1516,12 @@ def _match_heterogeneous_tuple_instance(
15161516
# accidentally violate _satisfies_common_superclass.
15171517
new_substs = []
15181518
for instance_param in instance.pyval:
1519-
if copy_params_directly and instance_param.bindings:
1519+
if abstract_utils.is_var_splat(instance_param):
1520+
instance_param = abstract_utils.unwrap_splat(instance_param)
1521+
new_subst = self._match_all_bindings(
1522+
instance_param, class_param, subst, view
1523+
)
1524+
elif copy_params_directly and instance_param.bindings:
15201525
new_subst = {
15211526
class_param.full_name: view[instance_param].AssignToNewVariable(
15221527
self._node
@@ -1528,7 +1533,8 @@ def _match_heterogeneous_tuple_instance(
15281533
)
15291534
if new_subst is None:
15301535
return None
1531-
new_substs.append(new_subst)
1536+
if new_subst is not None:
1537+
new_substs.append(new_subst)
15321538
if new_substs:
15331539
subst = self._merge_substs(subst, new_substs)
15341540
if not instance.pyval:

pytype/output.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,12 @@ def _value_to_parameter_types(self, node, v, instance, template, seen, view):
149149
type_arguments = []
150150
for t in template:
151151
if isinstance(instance, abstract.Tuple):
152+
elem_var = instance.pyval[t]
153+
if abstract_utils.is_var_splat(elem_var):
154+
elem_var = abstract_utils.unwrap_splat(elem_var)
152155
param_values = {
153156
val: view
154-
for val in self._get_values(node, instance.pyval[t], view)
157+
for val in self._get_values(node, elem_var, view)
155158
}
156159
elif instance.has_instance_type_parameter(t):
157160
param_values = {

pytype/tests/test_functions2.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,37 @@ def test_unpack_str(self):
651651
""",
652652
)
653653

654+
def test_unpack_tuple(self):
655+
# The **kwargs unpacking in the wrapper seems to prevent pytype from
656+
# eagerly expanding the splat in the tuple literal.
657+
ty = self.Infer("""
658+
def f(*, xs: tuple[int, ...], **kwargs: object):
659+
def wrapper():
660+
out = f(
661+
xs=(42, *kwargs.pop("xs", ())),
662+
**kwargs,
663+
)()
664+
return wrapper
665+
""")
666+
self.assertTypesMatchPytd(
667+
ty,
668+
"""
669+
from typing import Any, Callable
670+
def f(*, xs: tuple[int, ...], **kwargs: object) -> Callable[[], Any]: ...
671+
""",
672+
)
673+
674+
def test_unpack_tuple_invalid(self):
675+
self.InferWithErrors("""
676+
def f(*, xs: tuple[int, ...], **kwargs: object):
677+
def wrapper():
678+
out = f( # wrong-arg-types
679+
xs=(object(), *kwargs.pop("xs", ())),
680+
**kwargs,
681+
)()
682+
return wrapper
683+
""")
684+
654685
def test_unpack_nonliteral(self):
655686
ty = self.Infer("""
656687
def f(x, **kwargs):

0 commit comments

Comments
 (0)