Commit 135e3b6

Arm backend: Fix mypy warnings in test/passes (#15488)
Signed-off-by: [email protected]
1 parent 8374421 commit 135e3b6

28 files changed: +406 −241 lines

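Most of the changes below repeat one pattern: test functions and `forward` methods gain explicit parameter and return annotations so mypy stops reporting them as untyped. A minimal sketch of the before/after, assuming mypy runs with `disallow_untyped_defs` enabled (the project's exact mypy configuration is not shown in this commit):

import torch


class Model(torch.nn.Module):  # illustrative, not a class from the commit
    # Before: `def forward(self, x, y):` -- mypy reports "Function is
    # missing a type annotation". After: parameters and return are typed.
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


def test_add() -> None:  # tests return nothing, hence `-> None`
    m = Model()
    # (1, 10) + (10, 1) broadcasts to (10, 10)
    assert m(torch.rand(1, 10), torch.rand(10, 1)).shape == (10, 10)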
backends/arm/test/passes/test_broadcast_args_pass.py
Lines changed: 6 additions & 4 deletions

@@ -4,25 +4,27 @@
 # LICENSE file in the root directory of this source tree.

 import operator
-from typing import Tuple
+from typing import Callable, Tuple

 import torch
 from executorch.backends.arm._passes import BroadcastArgsPass

 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import PassPipeline

-input_t = Tuple[torch.Tensor]  # Input x
+input_t = Tuple[torch.Tensor, torch.Tensor]


 class NeedsMultipleBroadcastsModel(torch.nn.Module):
     test_data = (torch.rand(1, 10), torch.rand(10, 1))

-    def __init__(self, op: operator):
+    def __init__(
+        self, op: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+    ) -> None:
         self.op = op
         super().__init__()

-    def forward(self, x: torch.Tensor, y: torch.Tensor):
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return self.op(x, y)

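The `op: operator` annotation removed above used the `operator` module itself as a type, which mypy rejects (a module is not a valid type). The replacement `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]` states what the model actually stores. An illustrative reconstruction (`BinOpModel` is a stand-in name, not from the commit):

import operator
from typing import Callable

import torch


class BinOpModel(torch.nn.Module):
    # `op: operator` would make mypy report: Module "operator" is not
    # valid as a type. The Callable below describes the stored function.
    def __init__(
        self, op: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
    ) -> None:
        super().__init__()
        self.op = op

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return self.op(x, y)


m = BinOpModel(operator.add)  # operator.add matches the Callable signature
print(m(torch.rand(1, 10), torch.rand(10, 1)).shape)  # torch.Size([10, 10])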
backends/arm/test/passes/test_cast_int64_pass.py
Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@
         "rand": (torch.rand(4),),
     }

-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x + 3

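Even this one-line annotation matters: by default mypy does not type-check the body of an unannotated function, so `return x + 3` was previously skipped entirely. A small sketch (standalone, not from the commit):

import torch


def untyped(x):
    # Unannotated: mypy treats the whole function as untyped and, with
    # default settings, skips checking this body.
    return x + 3


def typed(x: torch.Tensor) -> torch.Tensor:
    # Annotated: the body is checked and callers see a precise type.
    return x + 3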
backends/arm/test/passes/test_convert_expand_copy_to_repeat.py
Lines changed: 4 additions & 4 deletions

@@ -20,17 +20,17 @@
     Basic expand model using torch.Tensor.expand function
     """

-    def __init__(self):
-        super(Expand, self).__init__()
+    def __init__(self) -> None:
+        super().__init__()

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x.expand(3, 4)

     def get_inputs(self) -> input_t:
         return (torch.rand(3, 1),)


-def test_expand_to_repeat_tosa_INT():
+def test_expand_to_repeat_tosa_INT() -> None:
     module = Expand()
     pipeline = PassPipeline[input_t](
         module,

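Only the annotations here are mypy-driven; replacing `super(Expand, self).__init__()` with the zero-argument `super().__init__()` appears to be an incidental Python 3 modernization, since both forms behave identically in this class. A brief sketch:

import torch


class Expand(torch.nn.Module):  # mirrors the model above, for illustration
    def __init__(self) -> None:
        # super() resolves the class and instance implicitly; repeating
        # `Expand, self` is redundant and easy to get wrong on a rename.
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.expand(3, 4)  # broadcast the (3, 1) input to (3, 4)


print(Expand()(torch.rand(3, 1)).shape)  # torch.Size([3, 4])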
backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py
Lines changed: 53 additions & 39 deletions

@@ -3,7 +3,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Tuple, Union
+from typing import Callable, ClassVar, Dict, Tuple, Union

 import pytest

@@ -22,18 +22,21 @@
 input_t1 = Tuple[torch.Tensor]  # Input x
 input_t2 = Tuple[torch.Tensor, torch.Tensor]  # Input x, y

+Scalar = Union[bool, float, int]
+ArangeNoneParam = Tuple[Callable[[], input_t1], Tuple[Scalar, Scalar, Scalar]]
+FullNoneParam = Tuple[Callable[[], input_t1], Tuple[Tuple[int, ...], Scalar]]
+

 #####################################################
 ## Test arange(dtype=int64) -> arange(dtype=int32) ##
 #####################################################


 class ArangeDefaultIncrementViewLessThan(torch.nn.Module):
-
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return (torch.arange(10, dtype=torch.int64) + 1).view(-1, 1) < x

-    test_data = {
+    test_data: ClassVar[Dict[str, input_t1]] = {
         "randint": (
             torch.randint(
                 0,

@@ -46,7 +49,9 @@


 @common.parametrize("test_data", ArangeDefaultIncrementViewLessThan.test_data)
-def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1):
+def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_FP(
+    test_data: input_t1,
+) -> None:
     module = ArangeDefaultIncrementViewLessThan()
     aten_ops_checks = [
         "torch.ops.aten.lt.Tensor",

@@ -67,7 +72,9 @@


 @common.parametrize("test_data", ArangeDefaultIncrementViewLessThan.test_data)
-def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_INT(test_data: input_t1):
+def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_INT(
+    test_data: input_t1,
+) -> None:
     module = ArangeDefaultIncrementViewLessThan()
     aten_ops_checks = [
         "torch.ops.aten.lt.Tensor",

@@ -88,11 +95,10 @@


 class ArangeStartIncrementViewLessThan(torch.nn.Module):
-
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return (torch.arange(0, 10, dtype=torch.int64) + 1).view(-1, 1) < x

-    test_data = {
+    test_data: ClassVar[Dict[str, input_t1]] = {
         "randint": (
             torch.randint(
                 0,

@@ -105,7 +111,9 @@


 @common.parametrize("test_data", ArangeStartIncrementViewLessThan.test_data)
-def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1):
+def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_FP(
+    test_data: input_t1,
+) -> None:
     module = ArangeStartIncrementViewLessThan()
     aten_ops_checks = [
         "torch.ops.aten.lt.Tensor",

@@ -126,7 +134,9 @@


 @common.parametrize("test_data", ArangeStartIncrementViewLessThan.test_data)
-def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_INT(test_data: input_t1):
+def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_INT(
+    test_data: input_t1,
+) -> None:
     module = ArangeStartIncrementViewLessThan()
     aten_ops_checks = [
         "torch.ops.aten.lt.Tensor",

@@ -147,11 +157,10 @@


 class ArangeStartStepIncrementViewLessThan(torch.nn.Module):
-
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return (torch.arange(0, 10, 2, dtype=torch.int64) + 1).view(-1, 1) < x

-    test_data = {
+    test_data: ClassVar[Dict[str, input_t1]] = {
         "randint": (
             torch.randint(
                 0,

@@ -166,7 +175,7 @@
 @common.parametrize("test_data", ArangeStartStepIncrementViewLessThan.test_data)
 def test_convert_arange_start_step_int64_dtype_to_int32_pass_tosa_FP(
     test_data: input_t1,
-):
+) -> None:
     module = ArangeStartStepIncrementViewLessThan()
     aten_ops_checks = [
         "torch.ops.aten.lt.Tensor",

@@ -189,7 +198,7 @@
 @common.parametrize("test_data", ArangeStartStepIncrementViewLessThan.test_data)
 def test_convert_arange_start_step_int64_dtype_to_int32_pass_tosa_INT(
     test_data: input_t1,
-):
+) -> None:
     module = ArangeStartStepIncrementViewLessThan()
     aten_ops_checks = [
         "torch.ops.aten.lt.Tensor",

@@ -225,7 +234,7 @@
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return torch.arange(*self.args) + x

-    test_data = {
+    test_data: ClassVar[Dict[str, ArangeNoneParam]] = {
         "int64": (lambda: (torch.randn(10, 1),), (0, 10, 1)),
         "float32_start": (lambda: (torch.randn(10, 1),), (0.0, 10, 1)),
         "float32_stop": (lambda: (torch.randn(10, 1),), (0, 10.0, 1)),

@@ -238,23 +247,23 @@


 @common.parametrize("test_data", ArangeAddDtypeNone.test_data)
-def test_arange_dtype_none_tosa_FP(test_data):
-    input_data, init_data = test_data
+def test_arange_dtype_none_tosa_FP(test_data: ArangeNoneParam) -> None:
+    input_factory, init_data = test_data
     pipeline = TosaPipelineFP[input_t1](
         ArangeAddDtypeNone(*init_data),
-        input_data(),
+        input_factory(),
         ArangeAddDtypeNone.aten_op,
         ArangeAddDtypeNone.exir_op,
     )
     pipeline.run()


 @common.parametrize("test_data", ArangeAddDtypeNone.test_data)
-def test_arange_dtype_none_tosa_INT(test_data):
-    input_data, init_data = test_data
+def test_arange_dtype_none_tosa_INT(test_data: ArangeNoneParam) -> None:
+    input_factory, init_data = test_data
     pipeline = TosaPipelineINT[input_t1](
         ArangeAddDtypeNone(*init_data),
-        input_data(),
+        input_factory(),
         ArangeAddDtypeNone.aten_op,
         ArangeAddDtypeNone.exir_op,
     )

@@ -268,8 +277,7 @@


 class FullIncrementViewMulXLessThanY(torch.nn.Module):
-
-    def forward(self, x: torch.Tensor, y: torch.Tensor):
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return (
             (
                 torch.full(

@@ -286,7 +294,7 @@
             * x
         ) < y

-    test_data = {
+    test_data: ClassVar[Dict[str, input_t2]] = {
         "randint": (
             torch.randint(
                 0,

@@ -305,7 +313,9 @@


 @common.parametrize("test_data", FullIncrementViewMulXLessThanY.test_data)
-def test_convert_full_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1):
+def test_convert_full_int64_dtype_to_int32_pass_tosa_FP(
+    test_data: input_t2,
+) -> None:
     """
     There are four int64 placeholders in the original graph:
     1. _lifted_tensor_constant0: 1

@@ -347,7 +357,9 @@


 @common.parametrize("test_data", FullIncrementViewMulXLessThanY.test_data)
-def test_convert_full_int64_dtype_to_int32_pass_tosa_INT(test_data: input_t1):
+def test_convert_full_int64_dtype_to_int32_pass_tosa_INT(
+    test_data: input_t2,
+) -> None:
     """
     For INT profile, _lifted_tensor_constant0 is still int64 after applying ConvertInt64ConstOpsToInt32Pass().
     And an int64->int32 cast is inserted at the beginning of the graph.

@@ -380,8 +392,7 @@


 class RejectFullIncrementViewMulXLessThanY(torch.nn.Module):
-
-    def forward(self, x: torch.Tensor, y: torch.Tensor):
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return (
             (
                 torch.full(

@@ -398,7 +409,7 @@
             * x
         ) < y

-    test_data = {
+    test_data: ClassVar[Dict[str, input_t2]] = {
         "randint": (
             torch.randint(
                 0,

@@ -420,7 +431,9 @@
 @pytest.mark.xfail(
     reason="MLETORCH-1254: Add operator support check for aten.arange and aten.full"
 )
-def test_reject_convert_full_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1):
+def test_reject_convert_full_int64_dtype_to_int32_pass_tosa_FP(
+    test_data: input_t2,
+) -> None:
     module = RejectFullIncrementViewMulXLessThanY()
     aten_ops_checks = [
         "torch.ops.aten.full.default",

@@ -469,23 +482,23 @@


 @common.parametrize("test_data", AddConstFullDtypeNone.test_data)
-def test_full_dtype_none_tosa_FP(test_data):
-    input_data, init_data = test_data
+def test_full_dtype_none_tosa_FP(test_data: FullNoneParam) -> None:
+    input_factory, init_data = test_data
     pipeline = TosaPipelineFP[input_t1](
         AddConstFullDtypeNone(*init_data),
-        input_data(),
+        input_factory(),
         aten_op=[],
         exir_op=AddConstFullDtypeNone.exir_op,
     )
     pipeline.run()


 @common.parametrize("test_data", AddConstFullDtypeNone.test_data_bool)
-def test_full_dtype_none_tosa_FP_bool(test_data):
-    input_data, init_data = test_data
+def test_full_dtype_none_tosa_FP_bool(test_data: FullNoneParam) -> None:
+    input_factory, init_data = test_data
     pipeline = TosaPipelineFP[input_t1](
         AddConstFullDtypeNone(*init_data),
-        input_data(),
+        input_factory(),
         aten_op=[],
         exir_op=AddConstFullDtypeNone.exir_op,
     )

@@ -501,9 +514,10 @@
 )
 def test_full_dtype_none_tosa_INT(test_data):
     input_data, init_data = test_data
+    input_factory, init_data = test_data
     pipeline = TosaPipelineINT[input_t1](
         AddConstFullDtypeNone(*init_data),
-        input_data(),
+        input_factory(),
         aten_op=[],
         exir_op=AddConstFullDtypeNone.exir_op,
     )

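Two recurring fixes in this last file are worth spelling out: `ClassVar[Dict[...]]` tells mypy the `test_data` tables are class-level data rather than instance attributes, and the new aliases (`Scalar`, `ArangeNoneParam`, `FullNoneParam`) give the `(input factory, init args)` tuples a precise type, so the renamed `input_factory` variable is a known zero-argument callable instead of `Any`. A condensed, illustrative sketch under the same assumptions (`ArangeModel` and `run_case` are stand-ins, not names from the commit):

from typing import Callable, ClassVar, Dict, Tuple, Union

import torch

input_t1 = Tuple[torch.Tensor]
Scalar = Union[bool, float, int]
# (input factory, arange(start, stop, step) arguments)
ArangeNoneParam = Tuple[Callable[[], input_t1], Tuple[Scalar, Scalar, Scalar]]


class ArangeModel(torch.nn.Module):
    # ClassVar marks shared class-level data; mypy then rejects
    # assigning through an instance, e.g. `ArangeModel().test_data = {}`.
    test_data: ClassVar[Dict[str, ArangeNoneParam]] = {
        "int64": (lambda: (torch.randn(10, 1),), (0, 10, 1)),
    }


def run_case(param: ArangeNoneParam) -> None:
    input_factory, init_data = param
    (x,) = input_factory()  # known to return a one-tensor tuple
    start, stop, step = init_data
    print(torch.arange(start, stop, step).shape, x.shape)


run_case(ArangeModel.test_data["int64"])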