33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
55
6- from typing import Tuple , Union
6+ from typing import Callable , ClassVar , Dict , Tuple , Union
77
88import pytest
99
2222input_t1 = Tuple [torch .Tensor ] # Input x
2323input_t2 = Tuple [torch .Tensor , torch .Tensor ] # Input x, y
2424
25+ Scalar = Union [bool , float , int ]
26+ ArangeNoneParam = Tuple [Callable [[], input_t1 ], Tuple [Scalar , Scalar , Scalar ]]
27+ FullNoneParam = Tuple [Callable [[], input_t1 ], Tuple [Tuple [int , ...], Scalar ]]
28+
2529
2630#####################################################
2731## Test arange(dtype=int64) -> arange(dtype=int32) ##
2832#####################################################
2933
3034
3135class ArangeDefaultIncrementViewLessThan (torch .nn .Module ):
32-
33- def forward (self , x : torch .Tensor ):
36+ def forward (self , x : torch .Tensor ) -> torch .Tensor :
3437 return (torch .arange (10 , dtype = torch .int64 ) + 1 ).view (- 1 , 1 ) < x
3538
36- test_data = {
39+ test_data : ClassVar [ Dict [ str , input_t1 ]] = {
3740 "randint" : (
3841 torch .randint (
3942 0 ,
@@ -46,7 +49,9 @@ def forward(self, x: torch.Tensor):
4649
4750
4851@common .parametrize ("test_data" , ArangeDefaultIncrementViewLessThan .test_data )
49- def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_FP (test_data : input_t1 ):
52+ def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_FP (
53+ test_data : input_t1 ,
54+ ) -> None :
5055 module = ArangeDefaultIncrementViewLessThan ()
5156 aten_ops_checks = [
5257 "torch.ops.aten.lt.Tensor" ,
@@ -67,7 +72,9 @@ def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_FP(test_data: inp
6772
6873
6974@common .parametrize ("test_data" , ArangeDefaultIncrementViewLessThan .test_data )
70- def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_INT (test_data : input_t1 ):
75+ def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_INT (
76+ test_data : input_t1 ,
77+ ) -> None :
7178 module = ArangeDefaultIncrementViewLessThan ()
7279 aten_ops_checks = [
7380 "torch.ops.aten.lt.Tensor" ,
@@ -88,11 +95,10 @@ def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_INT(test_data: in
8895
8996
9097class ArangeStartIncrementViewLessThan (torch .nn .Module ):
91-
92- def forward (self , x : torch .Tensor ):
98+ def forward (self , x : torch .Tensor ) -> torch .Tensor :
9399 return (torch .arange (0 , 10 , dtype = torch .int64 ) + 1 ).view (- 1 , 1 ) < x
94100
95- test_data = {
101+ test_data : ClassVar [ Dict [ str , input_t1 ]] = {
96102 "randint" : (
97103 torch .randint (
98104 0 ,
@@ -105,7 +111,9 @@ def forward(self, x: torch.Tensor):
105111
106112
107113@common .parametrize ("test_data" , ArangeStartIncrementViewLessThan .test_data )
108- def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_FP (test_data : input_t1 ):
114+ def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_FP (
115+ test_data : input_t1 ,
116+ ) -> None :
109117 module = ArangeStartIncrementViewLessThan ()
110118 aten_ops_checks = [
111119 "torch.ops.aten.lt.Tensor" ,
@@ -126,7 +134,9 @@ def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_FP(test_data: input
126134
127135
128136@common .parametrize ("test_data" , ArangeStartIncrementViewLessThan .test_data )
129- def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_INT (test_data : input_t1 ):
137+ def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_INT (
138+ test_data : input_t1 ,
139+ ) -> None :
130140 module = ArangeStartIncrementViewLessThan ()
131141 aten_ops_checks = [
132142 "torch.ops.aten.lt.Tensor" ,
@@ -147,11 +157,10 @@ def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_INT(test_data: inpu
147157
148158
149159class ArangeStartStepIncrementViewLessThan (torch .nn .Module ):
150-
151- def forward (self , x : torch .Tensor ):
160+ def forward (self , x : torch .Tensor ) -> torch .Tensor :
152161 return (torch .arange (0 , 10 , 2 , dtype = torch .int64 ) + 1 ).view (- 1 , 1 ) < x
153162
154- test_data = {
163+ test_data : ClassVar [ Dict [ str , input_t1 ]] = {
155164 "randint" : (
156165 torch .randint (
157166 0 ,
@@ -166,7 +175,7 @@ def forward(self, x: torch.Tensor):
166175@common .parametrize ("test_data" , ArangeStartStepIncrementViewLessThan .test_data )
167176def test_convert_arange_start_step_int64_dtype_to_int32_pass_tosa_FP (
168177 test_data : input_t1 ,
169- ):
178+ ) -> None :
170179 module = ArangeStartStepIncrementViewLessThan ()
171180 aten_ops_checks = [
172181 "torch.ops.aten.lt.Tensor" ,
@@ -189,7 +198,7 @@ def test_convert_arange_start_step_int64_dtype_to_int32_pass_tosa_FP(
189198@common .parametrize ("test_data" , ArangeStartStepIncrementViewLessThan .test_data )
190199def test_convert_arange_start_step_int64_dtype_to_int32_pass_tosa_INT (
191200 test_data : input_t1 ,
192- ):
201+ ) -> None :
193202 module = ArangeStartStepIncrementViewLessThan ()
194203 aten_ops_checks = [
195204 "torch.ops.aten.lt.Tensor" ,
@@ -225,7 +234,7 @@ def __init__(self, start: float, stop: float, step: float):
225234 def forward (self , x : torch .Tensor ) -> torch .Tensor :
226235 return torch .arange (* self .args ) + x
227236
228- test_data = {
237+ test_data : ClassVar [ Dict [ str , ArangeNoneParam ]] = {
229238 "int64" : (lambda : (torch .randn (10 , 1 ),), (0 , 10 , 1 )),
230239 "float32_start" : (lambda : (torch .randn (10 , 1 ),), (0.0 , 10 , 1 )),
231240 "float32_stop" : (lambda : (torch .randn (10 , 1 ),), (0 , 10.0 , 1 )),
@@ -238,23 +247,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
238247
239248
240249@common .parametrize ("test_data" , ArangeAddDtypeNone .test_data )
241- def test_arange_dtype_none_tosa_FP (test_data ) :
242- input_data , init_data = test_data
250+ def test_arange_dtype_none_tosa_FP (test_data : ArangeNoneParam ) -> None :
251+ input_factory , init_data = test_data
243252 pipeline = TosaPipelineFP [input_t1 ](
244253 ArangeAddDtypeNone (* init_data ),
245- input_data (),
254+ input_factory (),
246255 ArangeAddDtypeNone .aten_op ,
247256 ArangeAddDtypeNone .exir_op ,
248257 )
249258 pipeline .run ()
250259
251260
252261@common .parametrize ("test_data" , ArangeAddDtypeNone .test_data )
253- def test_arange_dtype_none_tosa_INT (test_data ) :
254- input_data , init_data = test_data
262+ def test_arange_dtype_none_tosa_INT (test_data : ArangeNoneParam ) -> None :
263+ input_factory , init_data = test_data
255264 pipeline = TosaPipelineINT [input_t1 ](
256265 ArangeAddDtypeNone (* init_data ),
257- input_data (),
266+ input_factory (),
258267 ArangeAddDtypeNone .aten_op ,
259268 ArangeAddDtypeNone .exir_op ,
260269 )
@@ -268,8 +277,7 @@ def test_arange_dtype_none_tosa_INT(test_data):
268277
269278
270279class FullIncrementViewMulXLessThanY (torch .nn .Module ):
271-
272- def forward (self , x : torch .Tensor , y : torch .Tensor ):
280+ def forward (self , x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
273281 return (
274282 (
275283 torch .full (
@@ -286,7 +294,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
286294 * x
287295 ) < y
288296
289- test_data = {
297+ test_data : ClassVar [ Dict [ str , input_t2 ]] = {
290298 "randint" : (
291299 torch .randint (
292300 0 ,
@@ -305,7 +313,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
305313
306314
307315@common .parametrize ("test_data" , FullIncrementViewMulXLessThanY .test_data )
308- def test_convert_full_int64_dtype_to_int32_pass_tosa_FP (test_data : input_t1 ):
316+ def test_convert_full_int64_dtype_to_int32_pass_tosa_FP (
317+ test_data : input_t2 ,
318+ ) -> None :
309319 """
310320 There are four int64 placeholders in the original graph:
311321 1. _lifted_tensor_constant0: 1
@@ -347,7 +357,9 @@ def test_convert_full_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1):
347357
348358
349359@common .parametrize ("test_data" , FullIncrementViewMulXLessThanY .test_data )
350- def test_convert_full_int64_dtype_to_int32_pass_tosa_INT (test_data : input_t1 ):
360+ def test_convert_full_int64_dtype_to_int32_pass_tosa_INT (
361+ test_data : input_t2 ,
362+ ) -> None :
351363 """
352364 For INT profile, _lifted_tensor_constant0 is still int64 after applying ConvertInt64ConstOpsToInt32Pass().
353365 And an int64->int32 cast is inserted at the beginning of the graph.
@@ -380,8 +392,7 @@ def test_convert_full_int64_dtype_to_int32_pass_tosa_INT(test_data: input_t1):
380392
381393
382394class RejectFullIncrementViewMulXLessThanY (torch .nn .Module ):
383-
384- def forward (self , x : torch .Tensor , y : torch .Tensor ):
395+ def forward (self , x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
385396 return (
386397 (
387398 torch .full (
@@ -398,7 +409,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
398409 * x
399410 ) < y
400411
401- test_data = {
412+ test_data : ClassVar [ Dict [ str , input_t2 ]] = {
402413 "randint" : (
403414 torch .randint (
404415 0 ,
@@ -420,7 +431,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
420431@pytest .mark .xfail (
421432 reason = "MLETORCH-1254: Add operator support check for aten.arange and aten.full"
422433)
423- def test_reject_convert_full_int64_dtype_to_int32_pass_tosa_FP (test_data : input_t1 ):
434+ def test_reject_convert_full_int64_dtype_to_int32_pass_tosa_FP (
435+ test_data : input_t2 ,
436+ ) -> None :
424437 module = RejectFullIncrementViewMulXLessThanY ()
425438 aten_ops_checks = [
426439 "torch.ops.aten.full.default" ,
@@ -469,23 +482,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
469482
470483
471484@common .parametrize ("test_data" , AddConstFullDtypeNone .test_data )
472- def test_full_dtype_none_tosa_FP (test_data ) :
473- input_data , init_data = test_data
485+ def test_full_dtype_none_tosa_FP (test_data : FullNoneParam ) -> None :
486+ input_factory , init_data = test_data
474487 pipeline = TosaPipelineFP [input_t1 ](
475488 AddConstFullDtypeNone (* init_data ),
476- input_data (),
489+ input_factory (),
477490 aten_op = [],
478491 exir_op = AddConstFullDtypeNone .exir_op ,
479492 )
480493 pipeline .run ()
481494
482495
483496@common .parametrize ("test_data" , AddConstFullDtypeNone .test_data_bool )
484- def test_full_dtype_none_tosa_FP_bool (test_data ) :
485- input_data , init_data = test_data
497+ def test_full_dtype_none_tosa_FP_bool (test_data : FullNoneParam ) -> None :
498+ input_factory , init_data = test_data
486499 pipeline = TosaPipelineFP [input_t1 ](
487500 AddConstFullDtypeNone (* init_data ),
488- input_data (),
501+ input_factory (),
489502 aten_op = [],
490503 exir_op = AddConstFullDtypeNone .exir_op ,
491504 )
@@ -501,9 +514,10 @@ def test_full_dtype_none_tosa_FP_bool(test_data):
501514)
502515def test_full_dtype_none_tosa_INT (test_data ):
503516 input_data , init_data = test_data
517+ input_factory , init_data = test_data
504518 pipeline = TosaPipelineINT [input_t1 ](
505519 AddConstFullDtypeNone (* init_data ),
506- input_data (),
520+ input_factory (),
507521 aten_op = [],
508522 exir_op = AddConstFullDtypeNone .exir_op ,
509523 )
0 commit comments