|
2 | 2 |
|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
5 |
| -from typing import Any |
| 5 | +from typing import Any, Final, Literal, TypedDict, cast |
| 6 | +from typing_extensions import NotRequired |
6 | 7 |
|
7 | 8 | from mypy.nodes import (
|
8 | 9 | ARG_NAMED,
|
|
31 | 32 | from mypy.types import FINAL_DECORATOR_NAMES
|
32 | 33 | from mypyc.errors import Errors
|
33 | 34 |
|
34 |
| -DATACLASS_DECORATORS = {"dataclasses.dataclass", "attr.s", "attr.attrs"} |
| 35 | +MYPYC_ATTRS: Final[frozenset[MypycAttr]] = frozenset( |
| 36 | + ["native_class", "allow_interpreted_subclasses", "serializable", "free_list_len"] |
| 37 | +) |
| 38 | + |
| 39 | +DATACLASS_DECORATORS: Final = frozenset(["dataclasses.dataclass", "attr.s", "attr.attrs"]) |
| 40 | + |
| 41 | + |
| 42 | +MypycAttr = Literal[ |
| 43 | + "native_class", "allow_interpreted_subclasses", "serializable", "free_list_len" |
| 44 | +] |
| 45 | + |
| 46 | + |
| 47 | +class MypycAttrs(TypedDict): |
| 48 | + native_class: NotRequired[bool] |
| 49 | + allow_interpreted_subclasses: NotRequired[bool] |
| 50 | + serializable: NotRequired[bool] |
| 51 | + free_list_len: NotRequired[int] |
35 | 52 |
|
36 | 53 |
|
37 | 54 | def is_final_decorator(d: Expression) -> bool:
|
@@ -112,21 +129,39 @@ def get_mypyc_attr_call(d: Expression) -> CallExpr | None:
|
112 | 129 | return None
|
113 | 130 |
|
114 | 131 |
|
115 |
| -def get_mypyc_attrs(stmt: ClassDef | Decorator) -> tuple[dict[str, Any], dict[str, int]]: |
| 132 | +def get_mypyc_attrs( |
| 133 | + stmt: ClassDef | Decorator, path: str, errors: Errors |
| 134 | +) -> tuple[MypycAttrs, dict[MypycAttr, int]]: |
116 | 135 | """Collect all the mypyc_attr attributes on a class definition or a function."""
|
117 |
| - attrs: dict[str, Any] = {} |
118 |
| - lines: dict[str, int] = {} |
| 136 | + attrs: MypycAttrs = {} |
| 137 | + lines: dict[MypycAttr, int] = {} |
| 138 | + |
| 139 | + def set_mypyc_attr(key: str, value: Any, line: int) -> None: |
| 140 | + if key in MYPYC_ATTRS: |
| 141 | + key = cast(MypycAttr, key) |
| 142 | + attrs[key] = value |
| 143 | + lines[key] = line |
| 144 | + else: |
| 145 | + errors.error(f'"{key}" is not a supported "mypyc_attr"', path, line) |
| 146 | + supported_keys = '", "'.join(sorted(MYPYC_ATTRS)) |
| 147 | + errors.note(f'supported keys: "{supported_keys}"', path, line) |
| 148 | + |
119 | 149 | for dec in stmt.decorators:
|
120 |
| - d = get_mypyc_attr_call(dec) |
121 |
| - if d: |
| 150 | + if d := get_mypyc_attr_call(dec): |
| 151 | + line = d.line |
122 | 152 | for name, arg in zip(d.arg_names, d.args):
|
123 | 153 | if name is None:
|
124 | 154 | if isinstance(arg, StrExpr):
|
125 |
| - attrs[arg.value] = True |
126 |
| - lines[arg.value] = d.line |
| 155 | + set_mypyc_attr(arg.value, True, line) |
| 156 | + else: |
| 157 | + errors.error( |
| 158 | + 'All "mypyc_attr" positional arguments must be string literals.', |
| 159 | + path, |
| 160 | + line, |
| 161 | + ) |
127 | 162 | else:
|
128 |
| - attrs[name] = get_mypyc_attr_literal(arg) |
129 |
| - lines[name] = d.line |
| 163 | + arg_value = get_mypyc_attr_literal(arg) |
| 164 | + set_mypyc_attr(name, arg_value, line) |
130 | 165 |
|
131 | 166 | return attrs, lines
|
132 | 167 |
|
|
0 commit comments