Skip to content

Commit 9e87d0f

Browse files
superbobrycopybara-github
authored andcommitted
Minor improvements to the dataclasses overlay
* Relaxed type checking for dataclasses with unknown/dynamic base classes. The set of fields for such dataclasses is unknown, so we cannot precisely check instantiation and `dataclasses.replace` calls. * Allowed inheriting `__match_args__`. * Ensured that `kw_only=` fields are excluded from `__match_args__`. PiperOrigin-RevId: 826512868
1 parent 3879e23 commit 9e87d0f

File tree

4 files changed

+132
-4
lines changed

4 files changed

+132
-4
lines changed

pytype/abstract/_classes.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,12 +429,19 @@ def __init__(
429429
self._populate_decorator_metadata()
430430
if "__dataclass_fields__" in self.metadata:
431431
self.match_args = tuple(
432-
attr.name for attr in self.metadata["__dataclass_fields__"]
432+
attr.name
433+
for attr in self.metadata["__dataclass_fields__"]
434+
if not attr.kw_only
433435
)
434436
elif self.load_lazy_attribute("__match_args__"):
435437
self.match_args = self._convert_str_tuple("__match_args__") or ()
436438
else:
437-
self.match_args = ()
439+
for base in self.mro[1:]:
440+
if isinstance(base, class_mixin.Class) and hasattr(base, "match_args"):
441+
self.match_args = base.match_args
442+
break
443+
else:
444+
self.match_args = ()
438445

439446
@classmethod
440447
def make(

pytype/overlays/classgen.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,21 @@ def make_init(self, node, cls, attrs, init_method_name="__init__"):
119119
else:
120120
pos_params.append(param)
121121

122+
# If the class has unknown bases or is dynamic, we can't know all possible
123+
# fields, so accept arbitrary positional and keyword arguments.
124+
has_unknown_fields = (
125+
self.ctx.convert.unsolvable in cls.mro or cls.is_dynamic
126+
)
127+
122128
return overlay_utils.make_method(
123-
self.ctx, node, init_method_name, pos_params, 0, kwonly_params
129+
self.ctx,
130+
node,
131+
init_method_name,
132+
pos_params,
133+
posonly_count=0,
134+
kwonly_params=kwonly_params,
135+
varargs=Param("args") if has_unknown_fields else None,
136+
kwargs=Param("kwargs") if has_unknown_fields else None,
124137
)
125138

126139
def call(self, node, func, args, alias_map=None):

pytype/overlays/dataclass_overlay.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,13 +293,18 @@ def _match_args_sequentially(self, node, args, alias_map, match_all_views):
293293
fields = obj.cls.metadata["__dataclass_fields__"]
294294
# 0 or more fields can be replaced, so we give every field a default.
295295
default = self.ctx.new_unsolvable(node)
296+
# If the class has unknown bases or is dynamic, we can't know all possible
297+
# fields, so we accept arbitrary keyword arguments via kwargs_name.
298+
has_unknown_fields = (
299+
self.ctx.convert.unsolvable in obj.cls.mro or obj.cls.is_dynamic
300+
)
296301
replace = abstract.SimpleFunction.build(
297302
name=self.name,
298303
param_names=("obj",),
299304
posonly_count=1,
300305
varargs_name=None,
301306
kwonly_params=tuple(f.name for f in fields),
302-
kwargs_name=None,
307+
kwargs_name="kwargs" if has_unknown_fields else None,
303308
defaults={f.name: default for f in fields},
304309
annotations={f.name: f.typ for f in fields},
305310
ctx=self.ctx,

pytype/tests/test_dataclasses.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,33 @@ def __init__(self, a: bool) -> None: ...
147147
""",
148148
)
149149

150+
def test_init_unknown_base(self):
151+
self.CheckWithErrors("""
152+
import dataclasses
153+
from foo import Base # pytype: disable=import-error
154+
@dataclasses.dataclass
155+
class A(Base):
156+
x: int
157+
A(x=42)
158+
A(x="wrong") # wrong-arg-types
159+
A(x=42, y="from_base")
160+
A(42, "from_base")
161+
""")
162+
163+
def test_init_dynamic_base(self):
164+
self.CheckWithErrors("""
165+
import dataclasses
166+
class Base:
167+
_HAS_DYNAMIC_ATTRIBUTES = True
168+
@dataclasses.dataclass
169+
class A(Base):
170+
x: int
171+
A(x=42)
172+
A(x="wrong") # wrong-arg-types
173+
A(x=42, y="from_base")
174+
A(42, "from_base")
175+
""")
176+
150177
def test_field(self):
151178
ty = self.Infer("""
152179
from typing import List
@@ -968,6 +995,31 @@ class C:
968995
errors, {"e": ["Expected", "str", "Actual", "int"]}
969996
)
970997

998+
def test_replace_unknown_base(self):
999+
self.CheckWithErrors("""
1000+
import dataclasses
1001+
from foo import Base # pytype: disable=import-error
1002+
@dataclasses.dataclass
1003+
class A(Base):
1004+
x: int
1005+
a = A(x=42)
1006+
dataclasses.replace(a, x="wrong") # wrong-arg-types
1007+
dataclasses.replace(a, y="from_base")
1008+
""")
1009+
1010+
def test_replace_dynamic_base(self):
1011+
self.CheckWithErrors("""
1012+
import dataclasses
1013+
class Base:
1014+
_HAS_DYNAMIC_ATTRIBUTES = True
1015+
@dataclasses.dataclass
1016+
class A(Base):
1017+
x: int
1018+
a = A(x=42)
1019+
dataclasses.replace(a, x="wrong") # wrong-arg-types
1020+
dataclasses.replace(a, y="from_base")
1021+
""")
1022+
9711023

9721024
class TestPyiDataclass(test_base.BaseTest):
9731025
"""Tests for @dataclasses in pyi files."""
@@ -1718,6 +1770,57 @@ def f(x, y):
17181770
print("not matched")
17191771
""")
17201772

1773+
def test_inheritance(self):
1774+
with self.DepTree([(
1775+
"foo.pyi",
1776+
"""
1777+
import dataclasses
1778+
@dataclasses.dataclass
1779+
class Point:
1780+
x: float
1781+
y: float
1782+
1783+
class OtherPoint(Point):
1784+
...
1785+
""",
1786+
)]):
1787+
self.Check("""
1788+
import foo
1789+
def f(x, y):
1790+
p = foo.OtherPoint(x, y)
1791+
match p:
1792+
case foo.OtherPoint(x, y):
1793+
print(f"({x}, {y})")
1794+
case _:
1795+
print("not matched")
1796+
""")
1797+
1798+
def test_kw_only(self):
1799+
with self.DepTree([(
1800+
"foo.pyi",
1801+
"""
1802+
import dataclasses
1803+
@dataclasses.dataclass
1804+
class Point:
1805+
x: float
1806+
_: dataclasses.KW_ONLY
1807+
y: float
1808+
1809+
class PointWithKwOnly(Point):
1810+
...
1811+
""",
1812+
)]):
1813+
self.CheckWithErrors("""
1814+
import foo
1815+
def f(x, y):
1816+
p = foo.PointWithKwOnly(x, y=y)
1817+
match p:
1818+
case foo.PointWithKwOnly(x, y): # match-error
1819+
print(f"({x}, {y})")
1820+
case _:
1821+
print("not matched")
1822+
""")
1823+
17211824

17221825
if __name__ == "__main__":
17231826
test_base.main()

0 commit comments

Comments
 (0)