3
3
# This source code is licensed under the BSD-style license found in the
4
4
# LICENSE file in the root directory of this source tree.
5
5
6
- #
7
- # Tests the clone op which copies the data of the input tensor (possibly with new data format)
8
- #
9
6
10
7
from typing import Tuple
11
8
12
- import pytest
13
9
import torch
14
10
15
11
from executorch .backends .arm .test import common
28
24
input_t = Tuple [torch .Tensor ]
29
25
30
26
31
- class Clone (torch .nn .Module ):
32
- """A simple module that clones an input tensor."""
27
+ class CloneFirstArg (torch .nn .Module ):
28
+ def forward (self , x ):
29
+ return x .clone () + x
33
30
34
- def forward (self , x : torch .Tensor ):
35
- return x .clone ()
36
31
32
+ class CloneSecondArg (torch .nn .Module ):
33
+ def forward (self , x ):
34
+ return x * x .clone ()
35
+
36
+
37
+ class CloneOutput (torch .nn .Module ):
38
+ def forward (self , x ):
39
+ return (x / x ).clone ()
40
+
41
+
42
+ class CloneBothArgs (torch .nn .Module ):
43
+ def forward (self , x ):
44
+ return x .clone () + x .clone ()
45
+
46
+
47
+ class CloneAfterOtherOp (torch .nn .Module ):
48
+ def forward (self , x ):
49
+ x = x * 2
50
+ return x .clone () + x
51
+
52
+
53
+ class CloneParallelToOtherOp (torch .nn .Module ):
54
+ def forward (self , x ):
55
+ return x * 2 + x .clone ()
37
56
38
- test_data_suite = {
39
- "ones_1D_10" : lambda : (torch .ones (10 ),),
40
- "ones_1D_50" : lambda : (torch .ones (50 ),),
41
- "rand_1D_20" : lambda : (torch .rand (20 ),),
42
- "rand_2D_10x10" : lambda : (torch .rand (10 , 10 ),),
43
- "rand_3D_5x5x5" : lambda : (torch .rand (5 , 5 , 5 ),),
44
- "rand_4D_2x3x4x5" : lambda : (torch .rand (2 , 3 , 4 , 5 ),),
45
- "large_tensor" : lambda : (torch .rand (1000 ),),
46
- }
47
57
58
+ delegated_clones = {
59
+ "clone_first_arg" : lambda : (CloneFirstArg , (torch .rand (1 , 2 , 3 , 4 ),)),
60
+ "clone_second_arg" : lambda : (CloneSecondArg , (torch .rand (1 , 2 , 3 , 4 ),)),
61
+ "clone_output" : lambda : (CloneOutput , (torch .rand (1 , 2 , 3 , 4 ),)),
62
+ "clone_both_args" : lambda : (CloneBothArgs , (torch .rand (1 , 2 , 3 , 4 ),)),
63
+ "clone_after_other_op" : lambda : (CloneAfterOtherOp , (torch .rand (1 , 2 , 3 , 4 ),)),
64
+ "clone_parallel_to_other_op" : lambda : (
65
+ CloneParallelToOtherOp ,
66
+ (torch .rand (1 , 2 , 3 , 4 ),),
67
+ ),
68
+ }
48
69
49
- @common .parametrize ("test_data" , test_data_suite )
50
- def test_clone_tosa_FP (test_data : Tuple [torch .Tensor ]):
51
70
71
+ @common .parametrize ("input_data" , delegated_clones )
72
+ def test_clone_tosa_FP (input_data ):
73
+ module , input_tensor = input_data ()
52
74
pipeline = TosaPipelineFP [input_t ](
53
- Clone (),
54
- test_data (),
55
- aten_op ,
56
- exir_op ,
75
+ module (),
76
+ input_tensor ,
77
+ [],
57
78
)
58
-
59
79
pipeline .run ()
60
80
61
81
62
- @common .parametrize ("test_data" , test_data_suite )
63
- def test_clone_tosa_INT (test_data ):
82
+ @common .parametrize ("input_data" , delegated_clones )
83
+ def test_clone_tosa_INT (input_data ):
84
+ module , input_tensor = input_data ()
85
+
64
86
pipeline = TosaPipelineINT [input_t ](
65
- Clone (),
66
- test_data () ,
87
+ module (),
88
+ input_tensor ,
67
89
aten_op ,
68
90
exir_op ,
69
91
)
70
92
pipeline .run ()
71
93
72
94
73
- @common .parametrize ("test_data " , test_data_suite )
95
+ @common .parametrize ("input_data " , delegated_clones )
74
96
@common .XfailIfNoCorstone300
75
- @pytest .mark .xfail (
76
- reason = "Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477"
77
- )
78
- def test_clone_u55_INT (test_data ):
97
+ def test_clone_u55_INT (input_data ):
98
+ module , input_tensor = input_data ()
99
+
79
100
pipeline = EthosU55PipelineINT [input_t ](
80
- Clone (),
81
- test_data () ,
101
+ module (),
102
+ input_tensor ,
82
103
aten_op ,
83
104
exir_op ,
84
105
run_on_fvp = True ,
@@ -87,15 +108,14 @@ def test_clone_u55_INT(test_data):
87
108
pipeline .run ()
88
109
89
110
90
- @common .parametrize ("test_data " , test_data_suite )
111
+ @common .parametrize ("input_data " , delegated_clones )
91
112
@common .XfailIfNoCorstone320
92
- @pytest .mark .xfail (
93
- reason = "Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477"
94
- )
95
- def test_clone_u85_INT (test_data ):
113
+ def test_clone_u85_INT (input_data ):
114
+ module , input_tensor = input_data ()
115
+
96
116
pipeline = EthosU85PipelineINT [input_t ](
97
- Clone (),
98
- test_data () ,
117
+ module (),
118
+ input_tensor ,
99
119
aten_op ,
100
120
exir_op ,
101
121
run_on_fvp = True ,
@@ -104,27 +124,23 @@ def test_clone_u85_INT(test_data):
104
124
pipeline .run ()
105
125
106
126
107
- @common .parametrize ("test_data" , test_data_suite )
127
+ @common .parametrize ("test_data" , delegated_clones )
108
128
@common .SkipIfNoModelConverter
109
- @pytest .mark .xfail (
110
- reason = "Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477"
111
- )
112
129
def test_clone_vgf_FP (test_data ):
130
+ module , input_tensor = test_data ()
113
131
pipeline = VgfPipeline [input_t ](
114
- Clone (), test_data () , aten_op , exir_op , tosa_version = "TOSA-1.0+FP"
132
+ module (), input_tensor , aten_op , exir_op , tosa_version = "TOSA-1.0+FP"
115
133
)
116
134
pipeline .run ()
117
135
118
136
119
- @common .parametrize ("test_data" , test_data_suite )
137
+ @common .parametrize ("test_data" , delegated_clones )
120
138
@common .SkipIfNoModelConverter
121
- @pytest .mark .xfail (
122
- reason = "Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477"
123
- )
124
139
def test_clone_vgf_INT (test_data ):
140
+ module , input_tensor = test_data ()
125
141
pipeline = VgfPipeline [input_t ](
126
- Clone (),
127
- test_data () ,
142
+ module (),
143
+ input_tensor ,
128
144
aten_op ,
129
145
exir_op ,
130
146
tosa_version = "TOSA-1.0+INT" ,
0 commit comments