|
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | 5 | import itertools |
6 | 6 |
|
7 | | -from typing import Tuple |
| 7 | +from typing import Any, Callable, Tuple |
8 | 8 |
|
9 | 9 | import torch |
10 | 10 | from executorch.backends.arm.quantizer import is_annotated |
|
18 | 18 |
|
19 | 19 |
|
20 | 20 | class SingleOpModel(torch.nn.Module): |
21 | | - def __init__(self, op, example_input, **op_kwargs) -> None: |
| 21 | + def __init__( |
| 22 | + self, |
| 23 | + op: Callable[..., torch.Tensor], |
| 24 | + example_input: Tuple[Any, ...], |
| 25 | + **op_kwargs: Any, |
| 26 | + ) -> None: |
22 | 27 | super().__init__() |
23 | | - self.op = op |
24 | | - self._example_input = example_input |
25 | | - self.op_kwargs = op_kwargs |
| 28 | + self.op: Callable[..., torch.Tensor] = op |
| 29 | + self._example_input: Tuple[Any, ...] = example_input |
| 30 | + self.op_kwargs: dict[str, Any] = dict(op_kwargs) |
26 | 31 |
|
27 | | - def forward(self, x): |
| 32 | + def forward(self, x: Any) -> torch.Tensor: |
28 | 33 | return self.op(x, **self.op_kwargs) |
29 | 34 |
|
30 | | - def example_inputs(self): |
| 35 | + def example_inputs(self) -> Tuple[Any, ...]: |
31 | 36 | return self._example_input |
32 | 37 |
|
33 | 38 |
|
34 | | -def check_annotation(model): |
| 39 | +def check_annotation(model: SingleOpModel) -> None: |
35 | 40 | pipeline = TosaPipelineINT[input_t1](model, model.example_inputs(), [], []) |
36 | 41 | pipeline.pop_stage("check_count.exir") |
37 | 42 | pipeline.pop_stage("run_method_and_compare_outputs") |
|
0 commit comments