4
4
5
5
from typing import Tuple
6
6
7
- import pytest
8
7
import torch
9
8
10
9
from executorch .backends .arm .test import common
19
18
input_tt = Tuple [torch .Tensor , torch .Tensor ]
20
19
21
20
22
- def make_float_div_inputs (B : int = 4 , T : int = 64 ) -> input_tt :
23
- x = torch .randn (B , T )
24
- # guard against zero in denominator
25
- y = torch .randn (B , T ).abs () + 1e-3
26
- return x , y
27
-
28
-
29
21
class DivTensorModeFloat (torch .nn .Module ):
30
22
"""
31
23
torch.div(x, y, rounding_mode=mode) with
@@ -44,11 +36,24 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
44
36
return torch .div (x , y , rounding_mode = self .mode )
45
37
46
38
47
- @pytest .mark .parametrize ("mode" , [None , "floor" , "trunc" ])
48
- def test_div_tensor_mode_tosa_FP (mode ):
39
+ test_data = {
40
+ "mode_none" : lambda : (None , (torch .randn (4 , 8 ), torch .randn (4 , 8 ).abs () + 1e-3 )),
41
+ "mode_floor" : lambda : (
42
+ "floor" ,
43
+ (torch .randn (4 , 8 ), torch .randn (4 , 8 ).abs () + 1e-3 ),
44
+ ),
45
+ "mode_trunc" : lambda : (
46
+ "trunc" ,
47
+ (torch .randn (4 , 8 ), torch .randn (4 , 8 ).abs () + 1e-3 ),
48
+ ),
49
+ "int_denominator" : lambda : (None , (torch .randn (4 , 8 ), 2 )),
50
+ }
51
+
49
52
53
+ @common .parametrize ("data" , test_data )
54
+ def test_div_tensor_mode_tosa_FP (data ):
55
+ mode , inputs = data ()
50
56
model = DivTensorModeFloat (mode )
51
- inputs = make_float_div_inputs ()
52
57
53
58
pipeline = TosaPipelineFP [input_tt ](
54
59
model ,
@@ -61,11 +66,10 @@ def test_div_tensor_mode_tosa_FP(mode):
61
66
pipeline .run ()
62
67
63
68
64
- @pytest . mark . parametrize ("mode " , [ None , "floor" , "trunc" ] )
65
- def test_div_tensor_mode_tosa_INT (mode ):
66
-
69
+ @common . parametrize ("data " , test_data )
70
+ def test_div_tensor_mode_tosa_INT (data ):
71
+ mode , inputs = data ()
67
72
model = DivTensorModeFloat (mode )
68
- inputs = make_float_div_inputs ()
69
73
70
74
pipeline = TosaPipelineINT [input_tt ](
71
75
model ,
@@ -79,11 +83,12 @@ def test_div_tensor_mode_tosa_INT(mode):
79
83
80
84
81
85
@common .XfailIfNoCorstone300
82
- @pytest .mark .parametrize ("mode" , [None , "floor" ])
83
- def test_div_tensor_mode_u55_INT (mode ):
84
-
86
+ @common .parametrize (
87
+ "data" , test_data , xfails = {"mode_trunc" : "CPU op missing in unittests" }
88
+ )
89
+ def test_div_tensor_mode_u55_INT (data ):
90
+ mode , inputs = data ()
85
91
model = DivTensorModeFloat (mode )
86
- inputs = make_float_div_inputs ()
87
92
88
93
pipeline = EthosU55PipelineINT [input_tt ](
89
94
model ,
@@ -97,11 +102,10 @@ def test_div_tensor_mode_u55_INT(mode):
97
102
98
103
99
104
@common .XfailIfNoCorstone320
100
- @pytest . mark . parametrize ("mode " , [ None , "floor" , "trunc" ] )
101
- def test_div_tensor_mode_u85_INT (mode ):
102
-
105
+ @common . parametrize ("data " , test_data )
106
+ def test_div_tensor_mode_u85_INT (data ):
107
+ mode , inputs = data ()
103
108
model = DivTensorModeFloat (mode )
104
- inputs = make_float_div_inputs ()
105
109
106
110
pipeline = EthosU85PipelineINT [input_tt ](
107
111
model ,
@@ -115,11 +119,10 @@ def test_div_tensor_mode_u85_INT(mode):
115
119
116
120
117
121
@common .SkipIfNoModelConverter
118
- @pytest . mark . parametrize ("mode " , [ None , "floor" , "trunc" ] )
119
- def test_div_tensor_mode_vgf_INT (mode ):
120
-
122
+ @common . parametrize ("data " , test_data )
123
+ def test_div_tensor_mode_vgf_INT (data ):
124
+ mode , inputs = data ()
121
125
model = DivTensorModeFloat (mode )
122
- inputs = make_float_div_inputs ()
123
126
124
127
pipeline = VgfPipeline [input_tt ](
125
128
model ,
@@ -134,11 +137,10 @@ def test_div_tensor_mode_vgf_INT(mode):
134
137
135
138
136
139
@common .SkipIfNoModelConverter
137
- @pytest . mark . parametrize ("mode " , [ None , "floor" , "trunc" ] )
138
- def test_div_tensor_mode_vgf_FP (mode ):
139
-
140
+ @common . parametrize ("data " , test_data )
141
+ def test_div_tensor_mode_vgf_FP (data ):
142
+ mode , inputs = data ()
140
143
model = DivTensorModeFloat (mode )
141
- inputs = make_float_div_inputs ()
142
144
143
145
pipeline = VgfPipeline [input_tt ](
144
146
model ,
0 commit comments