|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | | -from typing import Any |
| 5 | +from typing import Any, Final, Literal, TypedDict, cast |
| 6 | +from typing_extensions import NotRequired |
6 | 7 |
|
7 | 8 | from mypy.nodes import ( |
8 | 9 | ARG_NAMED, |
|
31 | 32 | from mypy.types import FINAL_DECORATOR_NAMES |
32 | 33 | from mypyc.errors import Errors |
33 | 34 |
|
34 | | -DATACLASS_DECORATORS = {"dataclasses.dataclass", "attr.s", "attr.attrs"} |
| 35 | +MYPYC_ATTRS: Final[frozenset[MypycAttr]] = frozenset( |
| 36 | + ["native_class", "allow_interpreted_subclasses", "serializable", "free_list_len"] |
| 37 | +) |
| 38 | + |
| 39 | +DATACLASS_DECORATORS: Final = frozenset(["dataclasses.dataclass", "attr.s", "attr.attrs"]) |
| 40 | + |
| 41 | + |
| 42 | +MypycAttr = Literal[ |
| 43 | + "native_class", "allow_interpreted_subclasses", "serializable", "free_list_len" |
| 44 | +] |
| 45 | + |
| 46 | + |
| 47 | +class MypycAttrs(TypedDict): |
| 48 | + native_class: NotRequired[bool] |
| 49 | + allow_interpreted_subclasses: NotRequired[bool] |
| 50 | + serializable: NotRequired[bool] |
| 51 | + free_list_len: NotRequired[int] |
35 | 52 |
|
36 | 53 |
|
37 | 54 | def is_final_decorator(d: Expression) -> bool: |
@@ -112,21 +129,39 @@ def get_mypyc_attr_call(d: Expression) -> CallExpr | None: |
112 | 129 | return None |
113 | 130 |
|
114 | 131 |
|
115 | | -def get_mypyc_attrs(stmt: ClassDef | Decorator) -> tuple[dict[str, Any], dict[str, int]]: |
| 132 | +def get_mypyc_attrs( |
| 133 | + stmt: ClassDef | Decorator, path: str, errors: Errors |
| 134 | +) -> tuple[MypycAttrs, dict[MypycAttr, int]]: |
116 | 135 | """Collect all the mypyc_attr attributes on a class definition or a function.""" |
117 | | - attrs: dict[str, Any] = {} |
118 | | - lines: dict[str, int] = {} |
| 136 | + attrs: MypycAttrs = {} |
| 137 | + lines: dict[MypycAttr, int] = {} |
| 138 | + |
| 139 | + def set_mypyc_attr(key: str, value: Any, line: int) -> None: |
| 140 | + if key in MYPYC_ATTRS: |
| 141 | + key = cast(MypycAttr, key) |
| 142 | + attrs[key] = value |
| 143 | + lines[key] = line |
| 144 | + else: |
| 145 | + errors.error(f'"{key}" is not a supported "mypyc_attr"', path, line) |
| 146 | + supported_keys = '", "'.join(sorted(MYPYC_ATTRS)) |
| 147 | + errors.note(f'supported keys: "{supported_keys}"', path, line) |
| 148 | + |
119 | 149 | for dec in stmt.decorators: |
120 | | - d = get_mypyc_attr_call(dec) |
121 | | - if d: |
| 150 | + if d := get_mypyc_attr_call(dec): |
| 151 | + line = d.line |
122 | 152 | for name, arg in zip(d.arg_names, d.args): |
123 | 153 | if name is None: |
124 | 154 | if isinstance(arg, StrExpr): |
125 | | - attrs[arg.value] = True |
126 | | - lines[arg.value] = d.line |
| 155 | + set_mypyc_attr(arg.value, True, line) |
| 156 | + else: |
| 157 | + errors.error( |
| 158 | + 'All "mypyc_attr" positional arguments must be string literals.', |
| 159 | + path, |
| 160 | + line, |
| 161 | + ) |
127 | 162 | else: |
128 | | - attrs[name] = get_mypyc_attr_literal(arg) |
129 | | - lines[name] = d.line |
| 163 | + arg_value = get_mypyc_attr_literal(arg) |
| 164 | + set_mypyc_attr(name, arg_value, line) |
130 | 165 |
|
131 | 166 | return attrs, lines |
132 | 167 |
|
|
0 commit comments