44
55from typing import Tuple
66
7- import pytest
87import torch
98
109from executorch .backends .arm .test import common
1918input_tt = Tuple [torch .Tensor , torch .Tensor ]
2019
2120
22- def make_float_div_inputs (B : int = 4 , T : int = 64 ) -> input_tt :
23- x = torch .randn (B , T )
24- # guard against zero in denominator
25- y = torch .randn (B , T ).abs () + 1e-3
26- return x , y
27-
28-
2921class DivTensorModeFloat (torch .nn .Module ):
3022 """
3123 torch.div(x, y, rounding_mode=mode) with
@@ -44,11 +36,24 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
4436 return torch .div (x , y , rounding_mode = self .mode )
4537
4638
47- @pytest .mark .parametrize ("mode" , [None , "floor" , "trunc" ])
48- def test_div_tensor_mode_tosa_FP (mode ):
39+ test_data = {
40+ "mode_none" : lambda : (None , (torch .randn (4 , 8 ), torch .randn (4 , 8 ).abs () + 1e-3 )),
41+ "mode_floor" : lambda : (
42+ "floor" ,
43+ (torch .randn (4 , 8 ), torch .randn (4 , 8 ).abs () + 1e-3 ),
44+ ),
45+ "mode_trunc" : lambda : (
46+ "trunc" ,
47+ (torch .randn (4 , 8 ), torch .randn (4 , 8 ).abs () + 1e-3 ),
48+ ),
49+ "int_denominator" : lambda : (None , (torch .randn (4 , 8 ), 2 )),
50+ }
51+
4952
53+ @common .parametrize ("data" , test_data )
54+ def test_div_tensor_mode_tosa_FP (data ):
55+ mode , inputs = data ()
5056 model = DivTensorModeFloat (mode )
51- inputs = make_float_div_inputs ()
5257
5358 pipeline = TosaPipelineFP [input_tt ](
5459 model ,
@@ -61,11 +66,10 @@ def test_div_tensor_mode_tosa_FP(mode):
6166 pipeline .run ()
6267
6368
64- @pytest . mark . parametrize ("mode " , [ None , "floor" , "trunc" ] )
65- def test_div_tensor_mode_tosa_INT (mode ):
66-
69+ @common . parametrize ("data " , test_data )
70+ def test_div_tensor_mode_tosa_INT (data ):
71+ mode , inputs = data ()
6772 model = DivTensorModeFloat (mode )
68- inputs = make_float_div_inputs ()
6973
7074 pipeline = TosaPipelineINT [input_tt ](
7175 model ,
@@ -79,11 +83,12 @@ def test_div_tensor_mode_tosa_INT(mode):
7983
8084
8185@common .XfailIfNoCorstone300
82- @pytest .mark .parametrize ("mode" , [None , "floor" ])
83- def test_div_tensor_mode_u55_INT (mode ):
84-
86+ @common .parametrize (
87+ "data" , test_data , xfails = {"mode_trunc" : "CPU op missing in unittests" }
88+ )
89+ def test_div_tensor_mode_u55_INT (data ):
90+ mode , inputs = data ()
8591 model = DivTensorModeFloat (mode )
86- inputs = make_float_div_inputs ()
8792
8893 pipeline = EthosU55PipelineINT [input_tt ](
8994 model ,
@@ -97,11 +102,10 @@ def test_div_tensor_mode_u55_INT(mode):
97102
98103
99104@common .XfailIfNoCorstone320
100- @pytest . mark . parametrize ("mode " , [ None , "floor" , "trunc" ] )
101- def test_div_tensor_mode_u85_INT (mode ):
102-
105+ @common . parametrize ("data " , test_data )
106+ def test_div_tensor_mode_u85_INT (data ):
107+ mode , inputs = data ()
103108 model = DivTensorModeFloat (mode )
104- inputs = make_float_div_inputs ()
105109
106110 pipeline = EthosU85PipelineINT [input_tt ](
107111 model ,
@@ -115,11 +119,10 @@ def test_div_tensor_mode_u85_INT(mode):
115119
116120
117121@common .SkipIfNoModelConverter
118- @pytest . mark . parametrize ("mode " , [ None , "floor" , "trunc" ] )
119- def test_div_tensor_mode_vgf_INT (mode ):
120-
122+ @common . parametrize ("data " , test_data )
123+ def test_div_tensor_mode_vgf_INT (data ):
124+ mode , inputs = data ()
121125 model = DivTensorModeFloat (mode )
122- inputs = make_float_div_inputs ()
123126
124127 pipeline = VgfPipeline [input_tt ](
125128 model ,
@@ -134,11 +137,10 @@ def test_div_tensor_mode_vgf_INT(mode):
134137
135138
136139@common .SkipIfNoModelConverter
137- @pytest . mark . parametrize ("mode " , [ None , "floor" , "trunc" ] )
138- def test_div_tensor_mode_vgf_FP (mode ):
139-
140+ @common . parametrize ("data " , test_data )
141+ def test_div_tensor_mode_vgf_FP (data ):
142+ mode , inputs = data ()
140143 model = DivTensorModeFloat (mode )
141- inputs = make_float_div_inputs ()
142144
143145 pipeline = VgfPipeline [input_tt ](
144146 model ,
0 commit comments