diff --git a/pytype/abstract/_classes.py b/pytype/abstract/_classes.py index ee90484d5..f47266cdb 100644 --- a/pytype/abstract/_classes.py +++ b/pytype/abstract/_classes.py @@ -429,12 +429,19 @@ def __init__( self._populate_decorator_metadata() if "__dataclass_fields__" in self.metadata: self.match_args = tuple( - attr.name for attr in self.metadata["__dataclass_fields__"] + attr.name + for attr in self.metadata["__dataclass_fields__"] + if not attr.kw_only ) elif self.load_lazy_attribute("__match_args__"): self.match_args = self._convert_str_tuple("__match_args__") or () else: - self.match_args = () + for base in self.mro[1:]: + if isinstance(base, class_mixin.Class) and hasattr(base, "match_args"): + self.match_args = base.match_args + break + else: + self.match_args = () @classmethod def make( diff --git a/pytype/overlays/classgen.py b/pytype/overlays/classgen.py index 8bd1afd08..0107b70d1 100644 --- a/pytype/overlays/classgen.py +++ b/pytype/overlays/classgen.py @@ -119,8 +119,21 @@ def make_init(self, node, cls, attrs, init_method_name="__init__"): else: pos_params.append(param) + # If the class has unknown bases or is dynamic, we can't know all possible + # fields, so accept arbitrary positional and keyword arguments. + has_unknown_fields = ( + self.ctx.convert.unsolvable in cls.mro or cls.is_dynamic + ) + return overlay_utils.make_method( - self.ctx, node, init_method_name, pos_params, 0, kwonly_params + self.ctx, + node, + init_method_name, + pos_params, + posonly_count=0, + kwonly_params=kwonly_params, + varargs=Param("args") if has_unknown_fields else None, + kwargs=Param("kwargs") if has_unknown_fields else None, ) def call(self, node, func, args, alias_map=None): diff --git a/pytype/overlays/dataclass_overlay.py b/pytype/overlays/dataclass_overlay.py index 590ce189b..798b47765 100644 --- a/pytype/overlays/dataclass_overlay.py +++ b/pytype/overlays/dataclass_overlay.py @@ -293,13 +293,18 @@ def _match_args_sequentially(self, node, args, alias_map, match_all_views): fields = obj.cls.metadata["__dataclass_fields__"] # 0 or more fields can be replaced, so we give every field a default. default = self.ctx.new_unsolvable(node) + # If the class has unknown bases or is dynamic, we can't know all possible + # fields, so we accept arbitrary keyword arguments via kwargs_name. + has_unknown_fields = ( + self.ctx.convert.unsolvable in obj.cls.mro or obj.cls.is_dynamic + ) replace = abstract.SimpleFunction.build( name=self.name, param_names=("obj",), posonly_count=1, varargs_name=None, kwonly_params=tuple(f.name for f in fields), - kwargs_name=None, + kwargs_name="kwargs" if has_unknown_fields else None, defaults={f.name: default for f in fields}, annotations={f.name: f.typ for f in fields}, ctx=self.ctx, diff --git a/pytype/tests/test_dataclasses.py b/pytype/tests/test_dataclasses.py index 2b941e90f..aa1cae6a6 100644 --- a/pytype/tests/test_dataclasses.py +++ b/pytype/tests/test_dataclasses.py @@ -147,6 +147,33 @@ def __init__(self, a: bool) -> None: ... """, ) + def test_init_unknown_base(self): + self.CheckWithErrors(""" + import dataclasses + from foo import Base # pytype: disable=import-error + @dataclasses.dataclass + class A(Base): + x: int + A(x=42) + A(x="wrong") # wrong-arg-types + A(x=42, y="from_base") + A(42, "from_base") + """) + + def test_init_dynamic_base(self): + self.CheckWithErrors(""" + import dataclasses + class Base: + _HAS_DYNAMIC_ATTRIBUTES = True + @dataclasses.dataclass + class A(Base): + x: int + A(x=42) + A(x="wrong") # wrong-arg-types + A(x=42, y="from_base") + A(42, "from_base") + """) + def test_field(self): ty = self.Infer(""" from typing import List @@ -968,6 +995,31 @@ class C: errors, {"e": ["Expected", "str", "Actual", "int"]} ) + def test_replace_unknown_base(self): + self.CheckWithErrors(""" + import dataclasses + from foo import Base # pytype: disable=import-error + @dataclasses.dataclass + class A(Base): + x: int + a = A(x=42) + dataclasses.replace(a, x="wrong") # wrong-arg-types + dataclasses.replace(a, y="from_base") + """) + + def test_replace_dynamic_base(self): + self.CheckWithErrors(""" + import dataclasses + class Base: + _HAS_DYNAMIC_ATTRIBUTES = True + @dataclasses.dataclass + class A(Base): + x: int + a = A(x=42) + dataclasses.replace(a, x="wrong") # wrong-arg-types + dataclasses.replace(a, y="from_base") + """) + class TestPyiDataclass(test_base.BaseTest): """Tests for @dataclasses in pyi files.""" @@ -1718,6 +1770,57 @@ def f(x, y): print("not matched") """) + def test_inheritance(self): + with self.DepTree([( + "foo.pyi", + """ + import dataclasses + @dataclasses.dataclass + class Point: + x: float + y: float + + class OtherPoint(Point): + ... + """, + )]): + self.Check(""" + import foo + def f(x, y): + p = foo.OtherPoint(x, y) + match p: + case foo.OtherPoint(x, y): + print(f"({x}, {y})") + case _: + print("not matched") + """) + + def test_kw_only(self): + with self.DepTree([( + "foo.pyi", + """ + import dataclasses + @dataclasses.dataclass + class Point: + x: float + _: dataclasses.KW_ONLY + y: float + + class PointWithKwOnly(Point): + ... + """, + )]): + self.CheckWithErrors(""" + import foo + def f(x, y): + p = foo.PointWithKwOnly(x, y=y) + match p: + case foo.PointWithKwOnly(x, y): # match-error + print(f"({x}, {y})") + case _: + print("not matched") + """) + if __name__ == "__main__": test_base.main()