diff --git a/mypyc/irbuild/prepare.py b/mypyc/irbuild/prepare.py index 20f2aeef8e6e..e4f43b38b0dc 100644 --- a/mypyc/irbuild/prepare.py +++ b/mypyc/irbuild/prepare.py @@ -357,7 +357,7 @@ def prepare_class_def( ir = mapper.type_to_ir[cdef.info] info = cdef.info - attrs, attrs_lines = get_mypyc_attrs(cdef) + attrs, attrs_lines = get_mypyc_attrs(cdef, path, errors) if attrs.get("allow_interpreted_subclasses") is True: ir.allow_interpreted_subclasses = True if attrs.get("serializable") is True: diff --git a/mypyc/irbuild/util.py b/mypyc/irbuild/util.py index eca2cac7e9db..3028e940f7f9 100644 --- a/mypyc/irbuild/util.py +++ b/mypyc/irbuild/util.py @@ -2,7 +2,8 @@ from __future__ import annotations -from typing import Any +from typing import Any, Final, Literal, TypedDict, cast +from typing_extensions import NotRequired from mypy.nodes import ( ARG_NAMED, @@ -31,7 +32,23 @@ from mypy.types import FINAL_DECORATOR_NAMES from mypyc.errors import Errors -DATACLASS_DECORATORS = {"dataclasses.dataclass", "attr.s", "attr.attrs"} +MYPYC_ATTRS: Final[frozenset[MypycAttr]] = frozenset( + ["native_class", "allow_interpreted_subclasses", "serializable", "free_list_len"] +) + +DATACLASS_DECORATORS: Final = frozenset(["dataclasses.dataclass", "attr.s", "attr.attrs"]) + + +MypycAttr = Literal[ + "native_class", "allow_interpreted_subclasses", "serializable", "free_list_len" +] + + +class MypycAttrs(TypedDict): + native_class: NotRequired[bool] + allow_interpreted_subclasses: NotRequired[bool] + serializable: NotRequired[bool] + free_list_len: NotRequired[int] def is_final_decorator(d: Expression) -> bool: @@ -112,21 +129,39 @@ def get_mypyc_attr_call(d: Expression) -> CallExpr | None: return None -def get_mypyc_attrs(stmt: ClassDef | Decorator) -> tuple[dict[str, Any], dict[str, int]]: +def get_mypyc_attrs( + stmt: ClassDef | Decorator, path: str, errors: Errors +) -> tuple[MypycAttrs, dict[MypycAttr, int]]: """Collect all the mypyc_attr attributes on a class definition or a function.""" - attrs: dict[str, Any] = {} - lines: dict[str, int] = {} + attrs: MypycAttrs = {} + lines: dict[MypycAttr, int] = {} + + def set_mypyc_attr(key: str, value: Any, line: int) -> None: + if key in MYPYC_ATTRS: + key = cast(MypycAttr, key) + attrs[key] = value + lines[key] = line + else: + errors.error(f'"{key}" is not a supported "mypyc_attr"', path, line) + supported_keys = '", "'.join(sorted(MYPYC_ATTRS)) + errors.note(f'supported keys: "{supported_keys}"', path, line) + for dec in stmt.decorators: - d = get_mypyc_attr_call(dec) - if d: + if d := get_mypyc_attr_call(dec): + line = d.line for name, arg in zip(d.arg_names, d.args): if name is None: if isinstance(arg, StrExpr): - attrs[arg.value] = True - lines[arg.value] = d.line + set_mypyc_attr(arg.value, True, line) + else: + errors.error( + 'All "mypyc_attr" positional arguments must be string literals.', + path, + line, + ) else: - attrs[name] = get_mypyc_attr_literal(arg) - lines[name] = d.line + arg_value = get_mypyc_attr_literal(arg) + set_mypyc_attr(name, arg_value, line) return attrs, lines diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test index a98b3a7d3dcf..64364cf91ef8 100644 --- a/mypyc/test-data/irbuild-classes.test +++ b/mypyc/test-data/irbuild-classes.test @@ -2686,3 +2686,18 @@ L0: r6 = PyObject_VectorcallMethod(r3, r5, 9223372036854775812, 0) keep_alive r2, self, key, val return 1 + +[case testInvalidMypycAttr] +from mypy_extensions import mypyc_attr + +@mypyc_attr("allow_interpreted_subclasses", "invalid_arg") # E: "invalid_arg" is not a supported "mypyc_attr" \ + # N: supported keys: "allow_interpreted_subclasses", "free_list_len", "native_class", "serializable" +class InvalidArg: + pass +@mypyc_attr(invalid_kwarg=True) # E: "invalid_kwarg" is not a supported "mypyc_attr" \ + # N: supported keys: "allow_interpreted_subclasses", "free_list_len", "native_class", "serializable" +class InvalidKwarg: + pass +@mypyc_attr(str()) # E: All "mypyc_attr" positional arguments must be string literals. +class InvalidLiteral: + pass