Skip to content

Commit a8631d1

Browse files
perheldagrima1304
authored andcommitted
Arm backend: Fix mypy warnings in test/misc (pytorch#15354)
Signed-off-by: [email protected]
1 parent 66ba13f commit a8631d1

File tree

8 files changed

+126
-69
lines changed

8 files changed

+126
-69
lines changed

backends/arm/test/misc/test_debug_feats.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,19 @@ def forward(self, x):
5050

5151

5252
def _tosa_FP_pipeline(module: torch.nn.Module, test_data: input_t1, dump_file=None):
53-
54-
pipeline = TosaPipelineFP[input_t1](module, test_data, [], [])
53+
aten_ops: list[str] = []
54+
exir_ops: list[str] = []
55+
pipeline = TosaPipelineFP[input_t1](module, test_data, aten_ops, exir_ops)
5556
pipeline.dump_artifact("to_edge_transform_and_lower")
5657
pipeline.dump_artifact("to_edge_transform_and_lower", suffix=dump_file)
5758
pipeline.pop_stage("run_method_and_compare_outputs")
5859
pipeline.run()
5960

6061

6162
def _tosa_INT_pipeline(module: torch.nn.Module, test_data: input_t1, dump_file=None):
62-
63-
pipeline = TosaPipelineINT[input_t1](module, test_data, [], [])
63+
aten_ops: list[str] = []
64+
exir_ops: list[str] = []
65+
pipeline = TosaPipelineINT[input_t1](module, test_data, aten_ops, exir_ops)
6466
pipeline.dump_artifact("to_edge_transform_and_lower")
6567
pipeline.dump_artifact("to_edge_transform_and_lower", suffix=dump_file)
6668
pipeline.pop_stage("run_method_and_compare_outputs")
@@ -105,11 +107,13 @@ def test_INT_artifact(test_data: input_t1):
105107

106108
@common.parametrize("test_data", Linear.inputs)
107109
def test_numerical_diff_print(test_data: input_t1):
110+
aten_ops: list[str] = []
111+
exir_ops: list[str] = []
108112
pipeline = TosaPipelineINT[input_t1](
109113
Linear(),
110114
test_data,
111-
[],
112-
[],
115+
aten_ops,
116+
exir_ops,
113117
custom_path="diff_print_test",
114118
)
115119
pipeline.pop_stage("run_method_and_compare_outputs")
@@ -131,7 +135,9 @@ def test_numerical_diff_print(test_data: input_t1):
131135

132136
@common.parametrize("test_data", Linear.inputs)
133137
def test_dump_ops_and_dtypes(test_data: input_t1):
134-
pipeline = TosaPipelineINT[input_t1](Linear(), test_data, [], [])
138+
aten_ops: list[str] = []
139+
exir_ops: list[str] = []
140+
pipeline = TosaPipelineINT[input_t1](Linear(), test_data, aten_ops, exir_ops)
135141
pipeline.pop_stage("run_method_and_compare_outputs")
136142
pipeline.add_stage_after("quantize", pipeline.tester.dump_dtype_distribution)
137143
pipeline.add_stage_after("quantize", pipeline.tester.dump_operator_distribution)
@@ -149,7 +155,9 @@ def test_dump_ops_and_dtypes(test_data: input_t1):
149155

150156
@common.parametrize("test_data", Linear.inputs)
151157
def test_dump_ops_and_dtypes_parseable(test_data: input_t1):
152-
pipeline = TosaPipelineINT[input_t1](Linear(), test_data, [], [])
158+
aten_ops: list[str] = []
159+
exir_ops: list[str] = []
160+
pipeline = TosaPipelineINT[input_t1](Linear(), test_data, aten_ops, exir_ops)
153161
pipeline.pop_stage("run_method_and_compare_outputs")
154162
pipeline.add_stage_after("quantize", pipeline.tester.dump_dtype_distribution, False)
155163
pipeline.add_stage_after(
@@ -177,7 +185,9 @@ def test_collate_tosa_INT_tests(test_data: input_t1):
177185
# Set the environment variable to trigger the collation of TOSA tests
178186
os.environ["TOSA_TESTCASES_BASE_PATH"] = "test_collate_tosa_tests"
179187
# Clear out the directory
180-
pipeline = TosaPipelineINT[input_t1](Linear(), test_data, [], [])
188+
aten_ops: list[str] = []
189+
exir_ops: list[str] = []
190+
pipeline = TosaPipelineINT[input_t1](Linear(), test_data, aten_ops, exir_ops)
181191
pipeline.pop_stage("run_method_and_compare_outputs")
182192
pipeline.run()
183193

@@ -197,11 +207,13 @@ def test_collate_tosa_INT_tests(test_data: input_t1):
197207
@common.parametrize("test_data", Linear.inputs)
198208
def test_dump_tosa_debug_json(test_data: input_t1):
199209
with tempfile.TemporaryDirectory() as tmpdir:
210+
aten_ops: list[str] = []
211+
exir_ops: list[str] = []
200212
pipeline = TosaPipelineINT[input_t1](
201213
module=Linear(),
202214
test_data=test_data,
203-
aten_op=[],
204-
exir_op=[],
215+
aten_op=aten_ops,
216+
exir_op=exir_ops,
205217
custom_path=tmpdir,
206218
tosa_debug_mode=ArmCompileSpec.DebugMode.JSON,
207219
)
@@ -228,11 +240,13 @@ def test_dump_tosa_debug_json(test_data: input_t1):
228240
@common.parametrize("test_data", Linear.inputs)
229241
def test_dump_tosa_debug_tosa(test_data: input_t1):
230242
with tempfile.TemporaryDirectory() as tmpdir:
243+
aten_ops: list[str] = []
244+
exir_ops: list[str] = []
231245
pipeline = TosaPipelineINT[input_t1](
232246
module=Linear(),
233247
test_data=test_data,
234-
aten_op=[],
235-
exir_op=[],
248+
aten_op=aten_ops,
249+
exir_op=exir_ops,
236250
custom_path=tmpdir,
237251
tosa_debug_mode=ArmCompileSpec.DebugMode.TOSA,
238252
)
@@ -248,7 +262,9 @@ def test_dump_tosa_debug_tosa(test_data: input_t1):
248262

249263
@common.parametrize("test_data", Linear.inputs)
250264
def test_dump_tosa_ops(caplog, test_data: input_t1):
251-
pipeline = TosaPipelineINT[input_t1](Linear(), test_data, [], [])
265+
aten_ops: list[str] = []
266+
exir_ops: list[str] = []
267+
pipeline = TosaPipelineINT[input_t1](Linear(), test_data, aten_ops, exir_ops)
252268
pipeline.pop_stage("run_method_and_compare_outputs")
253269
pipeline.dump_operator_distribution("to_edge_transform_and_lower")
254270
pipeline.run()
@@ -267,8 +283,10 @@ def forward(self, x):
267283
@common.parametrize("test_data", Add.inputs)
268284
@common.XfailIfNoCorstone300
269285
def test_fail_dump_tosa_ops(caplog, test_data: input_t1):
286+
aten_ops: list[str] = []
287+
exir_ops: list[str] = []
270288
pipeline = EthosU55PipelineINT[input_t1](
271-
Add(), test_data, [], [], use_to_edge_transform_and_lower=True
289+
Add(), test_data, aten_ops, exir_ops, use_to_edge_transform_and_lower=True
272290
)
273291
pipeline.dump_operator_distribution("to_edge_transform_and_lower")
274292
pipeline.run()

backends/arm/test/misc/test_debug_hook.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55

66
from dataclasses import dataclass
77
from types import SimpleNamespace
8+
from typing import cast
89

910
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
1011
from executorch.backends.arm.debug.schema import DebugHook, DebugSchema
1112
from executorch.backends.arm.test import common
1213

14+
from torch.fx import Node
15+
1316

1417
@dataclass
1518
class DebugHookTestCase:
@@ -95,9 +98,9 @@ def create_mock_node_3():
9598
return fx_node_mock
9699

97100

98-
def _compare_tosa_and_schema(debug_event: DebugSchema, tosa_op):
101+
def _compare_tosa_and_schema(debug_event: DebugSchema, tosa_op: str) -> None:
99102
tosa_info = debug_event.tosa_info
100-
103+
assert tosa_info is not None
101104
assert tosa_info.node_name == tosa_op
102105

103106
# The mapping between op_ids to operator names could change
@@ -159,7 +162,7 @@ def _compare_node_and_schema(debug_event: DebugSchema, mocked_node):
159162
@common.parametrize("test_data", TESTCASES)
160163
def test_debug_hook_add_json(test_data: DebugHookTestCase):
161164
hook = DebugHook(ArmCompileSpec.DebugMode.JSON)
162-
hook.add(test_data.mock_node, test_data.tosa_op, test_data.op_id)
165+
hook.add(cast(Node, test_data.mock_node), test_data.tosa_op, test_data.op_id)
163166

164167
debug_events = hook._debug_events
165168
assert len(debug_events) == test_data.expected_events
@@ -172,7 +175,7 @@ def test_debug_hook_add_json(test_data: DebugHookTestCase):
172175
@common.parametrize("test_data", TESTCASES)
173176
def test_debug_hook_add_tosa(test_data: DebugHookTestCase):
174177
hook = DebugHook(ArmCompileSpec.DebugMode.TOSA)
175-
hook.add(test_data.mock_node, test_data.tosa_op, test_data.op_id)
178+
hook.add(cast(Node, test_data.mock_node), test_data.tosa_op, test_data.op_id)
176179

177180
debug_events = hook._debug_events
178181
assert len(debug_events) == test_data.expected_events

backends/arm/test/misc/test_dim_order.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,28 +96,32 @@ def forward(self, x):
9696

9797

9898
@common.parametrize("module", test_modules)
99-
def test_dim_order_tosa_FP(module):
100-
pipeline = TosaPipelineFP[input_t1](module(), module.inputs, [])
99+
def test_dim_order_tosa_FP(module) -> None:
100+
aten_ops: list[str] = []
101+
pipeline = TosaPipelineFP[input_t1](module(), module.inputs, aten_ops)
101102
pipeline.run()
102103

103104

104105
@common.parametrize("module", test_modules)
105-
def test_dim_order_tosa_INT(module):
106+
def test_dim_order_tosa_INT(module) -> None:
107+
aten_ops: list[str] = []
106108
pipeline = TosaPipelineINT[input_t1](
107-
module(), module.inputs, [], symmetric_io_quantization=True
109+
module(), module.inputs, aten_ops, symmetric_io_quantization=True
108110
)
109111
pipeline.run()
110112

111113

112114
@common.XfailIfNoCorstone300
113115
@common.parametrize("module", test_modules)
114-
def test_dim_order_u55_INT(module):
115-
pipeline = EthosU55PipelineINT[input_t1](module(), module.inputs, [])
116+
def test_dim_order_u55_INT(module) -> None:
117+
aten_ops: list[str] = []
118+
pipeline = EthosU55PipelineINT[input_t1](module(), module.inputs, aten_ops)
116119
pipeline.run()
117120

118121

119122
@common.XfailIfNoCorstone320
120123
@common.parametrize("module", test_modules)
121-
def test_dim_order_u85_INT(module):
122-
pipeline = EthosU85PipelineINT[input_t1](module(), module.inputs, [])
124+
def test_dim_order_u85_INT(module) -> None:
125+
aten_ops: list[str] = []
126+
pipeline = EthosU85PipelineINT[input_t1](module(), module.inputs, aten_ops)
123127
pipeline.run()

backends/arm/test/misc/test_lifted_tensor.py

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import operator
7-
from typing import Tuple, Union
7+
from collections.abc import Callable
8+
from typing import Union
89

910
import torch
1011
from executorch.backends.arm.test import common
@@ -15,12 +16,22 @@
1516
from executorch.backends.test.harness.stages import StageType
1617

1718

18-
input_t1 = Tuple[torch.Tensor]
19+
LiftedTensorInputs = tuple[torch.Tensor, int]
20+
LiftedTensorCase = tuple[
21+
Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
22+
LiftedTensorInputs,
23+
]
24+
LiftedScalarTensorInputs = tuple[torch.Tensor, ...]
25+
LiftedScalarTensorCase = tuple[
26+
Callable[[torch.Tensor, Union[float, int, torch.Tensor]], torch.Tensor],
27+
LiftedScalarTensorInputs,
28+
Union[float, int, torch.Tensor],
29+
]
1930

2031

2132
class LiftedTensor(torch.nn.Module):
2233

23-
test_data = {
34+
test_data: dict[str, LiftedTensorCase] = {
2435
# test_name: (operator, test_data, length)
2536
"add": (operator.add, (torch.randn(2, 2), 2)),
2637
"truediv": (operator.truediv, (torch.ones(2, 2), 2)),
@@ -39,7 +50,7 @@ def forward(self, x: torch.Tensor, length) -> torch.Tensor:
3950

4051

4152
class LiftedScalarTensor(torch.nn.Module):
42-
test_data = {
53+
test_data: dict[str, LiftedScalarTensorCase] = {
4354
# test_name: (operator, test_data)
4455
"add": (operator.add, (torch.randn(2, 2),), 1.0),
4556
"truediv": (operator.truediv, (torch.randn(4, 2),), 1.0),
@@ -60,14 +71,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
6071

6172

6273
@common.parametrize("test_data", LiftedTensor.test_data)
63-
def test_partition_lifted_tensor_tosa_FP(test_data: input_t1):
64-
op = test_data[0]
65-
data = test_data[1:]
74+
def test_partition_lifted_tensor_tosa_FP(test_data: LiftedTensorCase) -> None:
75+
op, inputs = test_data
6676
module = LiftedTensor(op)
67-
pipeline = TosaPipelineFP[input_t1](
77+
aten_ops: list[str] = []
78+
pipeline = TosaPipelineFP[LiftedTensorInputs](
6879
module,
69-
*data,
70-
[],
80+
inputs,
81+
aten_ops,
7182
exir_op=[],
7283
use_to_edge_transform_and_lower=False,
7384
)
@@ -81,14 +92,14 @@ def test_partition_lifted_tensor_tosa_FP(test_data: input_t1):
8192

8293

8394
@common.parametrize("test_data", LiftedTensor.test_data)
84-
def test_partition_lifted_tensor_tosa_INT(test_data: input_t1):
85-
op = test_data[0]
86-
data = test_data[1:]
95+
def test_partition_lifted_tensor_tosa_INT(test_data: LiftedTensorCase) -> None:
96+
op, inputs = test_data
8797
module = LiftedTensor(op)
88-
pipeline = TosaPipelineINT[input_t1](
98+
aten_ops: list[str] = []
99+
pipeline = TosaPipelineINT[LiftedTensorInputs](
89100
module,
90-
*data,
91-
[],
101+
inputs,
102+
aten_ops,
92103
exir_op=[],
93104
use_to_edge_transform_and_lower=False,
94105
)
@@ -102,29 +113,33 @@ def test_partition_lifted_tensor_tosa_INT(test_data: input_t1):
102113

103114

104115
@common.parametrize("test_data", LiftedScalarTensor.test_data)
105-
def test_partition_lifted_scalar_tensor_tosa_FP(test_data: input_t1):
106-
op = test_data[0]
107-
data = test_data[1:]
108-
module = LiftedScalarTensor(op, data[-1])
109-
pipeline = TosaPipelineFP[input_t1](
116+
def test_partition_lifted_scalar_tensor_tosa_FP(
117+
test_data: LiftedScalarTensorCase,
118+
) -> None:
119+
op, tensor_inputs, scalar_arg = test_data
120+
module = LiftedScalarTensor(op, scalar_arg)
121+
aten_ops: list[str] = []
122+
pipeline = TosaPipelineFP[LiftedScalarTensorInputs](
110123
module,
111-
data[0],
112-
[],
124+
tensor_inputs,
125+
aten_ops,
113126
exir_op=[],
114127
use_to_edge_transform_and_lower=False,
115128
)
116129
pipeline.run()
117130

118131

119132
@common.parametrize("test_data", LiftedScalarTensor.test_data)
120-
def test_partition_lifted_scalar_tensor_tosa_INT(test_data: input_t1):
121-
op = test_data[0]
122-
data = test_data[1:]
123-
module = LiftedScalarTensor(op, data[-1])
124-
pipeline = TosaPipelineINT[input_t1](
133+
def test_partition_lifted_scalar_tensor_tosa_INT(
134+
test_data: LiftedScalarTensorCase,
135+
) -> None:
136+
op, tensor_inputs, scalar_arg = test_data
137+
module = LiftedScalarTensor(op, scalar_arg)
138+
aten_ops: list[str] = []
139+
pipeline = TosaPipelineINT[LiftedScalarTensorInputs](
125140
module,
126-
data[0],
127-
[],
141+
tensor_inputs,
142+
aten_ops,
128143
exir_op=[],
129144
use_to_edge_transform_and_lower=False,
130145
)

backends/arm/test/misc/test_multiple_delegates.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
2929

3030
@common.parametrize("test_data", MultipleDelegatesModule.inputs)
3131
def test_tosa_FP_pipeline(test_data: input_t1):
32-
pipeline = TosaPipelineFP[input_t1](MultipleDelegatesModule(), test_data, [], [])
32+
aten_ops: list[str] = []
33+
exir_ops: list[str] = []
34+
pipeline = TosaPipelineFP[input_t1](
35+
MultipleDelegatesModule(), test_data, aten_ops, exir_ops
36+
)
3337
pipeline.change_args(
3438
"check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 2}
3539
)
@@ -38,8 +42,10 @@ def test_tosa_FP_pipeline(test_data: input_t1):
3842

3943
@common.parametrize("test_data", MultipleDelegatesModule.inputs)
4044
def test_tosa_INT_pipeline(test_data: input_t1):
45+
aten_ops: list[str] = []
46+
exir_ops: list[str] = []
4147
pipeline = TosaPipelineINT[input_t1](
42-
MultipleDelegatesModule(), test_data, [], [], qtol=1
48+
MultipleDelegatesModule(), test_data, aten_ops, exir_ops, qtol=1
4349
)
4450
pipeline.change_args(
4551
"check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 2}

0 commit comments

Comments
 (0)