33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
55
6- #
7- # Tests the clone op which copies the data of the input tensor (possibly with new data format)
8- #
96
107from typing import Tuple
118
12- import pytest
139import torch
1410
1511from executorch .backends .arm .test import common
2824input_t = Tuple [torch .Tensor ]
2925
3026
31- class Clone (torch .nn .Module ):
32- """A simple module that clones an input tensor."""
27+ class CloneFirstArg (torch .nn .Module ):
28+ def forward (self , x ):
29+ return x .clone () + x
3330
34- def forward (self , x : torch .Tensor ):
35- return x .clone ()
3631
32+ class CloneSecondArg (torch .nn .Module ):
33+ def forward (self , x ):
34+ return x * x .clone ()
35+
36+
37+ class CloneOutput (torch .nn .Module ):
38+ def forward (self , x ):
39+ return (x / x ).clone ()
40+
41+
42+ class CloneBothArgs (torch .nn .Module ):
43+ def forward (self , x ):
44+ return x .clone () + x .clone ()
45+
46+
47+ class CloneAfterOtherOp (torch .nn .Module ):
48+ def forward (self , x ):
49+ x = x * 2
50+ return x .clone () + x
51+
52+
53+ class CloneParallelToOtherOp (torch .nn .Module ):
54+ def forward (self , x ):
55+ return x * 2 + x .clone ()
3756
38- test_data_suite = {
39- "ones_1D_10" : lambda : (torch .ones (10 ),),
40- "ones_1D_50" : lambda : (torch .ones (50 ),),
41- "rand_1D_20" : lambda : (torch .rand (20 ),),
42- "rand_2D_10x10" : lambda : (torch .rand (10 , 10 ),),
43- "rand_3D_5x5x5" : lambda : (torch .rand (5 , 5 , 5 ),),
44- "rand_4D_2x3x4x5" : lambda : (torch .rand (2 , 3 , 4 , 5 ),),
45- "large_tensor" : lambda : (torch .rand (1000 ),),
46- }
4757
58+ delegated_clones = {
59+ "clone_first_arg" : lambda : (CloneFirstArg , (torch .rand (1 , 2 , 3 , 4 ),)),
60+ "clone_second_arg" : lambda : (CloneSecondArg , (torch .rand (1 , 2 , 3 , 4 ),)),
61+ "clone_output" : lambda : (CloneOutput , (torch .rand (1 , 2 , 3 , 4 ),)),
62+ "clone_both_args" : lambda : (CloneBothArgs , (torch .rand (1 , 2 , 3 , 4 ),)),
63+ "clone_after_other_op" : lambda : (CloneAfterOtherOp , (torch .rand (1 , 2 , 3 , 4 ),)),
64+ "clone_parallel_to_other_op" : lambda : (
65+ CloneParallelToOtherOp ,
66+ (torch .rand (1 , 2 , 3 , 4 ),),
67+ ),
68+ }
4869
49- @common .parametrize ("test_data" , test_data_suite )
50- def test_clone_tosa_FP (test_data : Tuple [torch .Tensor ]):
5170
71+ @common .parametrize ("input_data" , delegated_clones )
72+ def test_clone_tosa_FP (input_data ):
73+ module , input_tensor = input_data ()
5274 pipeline = TosaPipelineFP [input_t ](
53- Clone (),
54- test_data (),
55- aten_op ,
56- exir_op ,
75+ module (),
76+ input_tensor ,
77+ [],
5778 )
58-
5979 pipeline .run ()
6080
6181
62- @common .parametrize ("test_data" , test_data_suite )
63- def test_clone_tosa_INT (test_data ):
82+ @common .parametrize ("input_data" , delegated_clones )
83+ def test_clone_tosa_INT (input_data ):
84+ module , input_tensor = input_data ()
85+
6486 pipeline = TosaPipelineINT [input_t ](
65- Clone (),
66- test_data () ,
87+ module (),
88+ input_tensor ,
6789 aten_op ,
6890 exir_op ,
6991 )
7092 pipeline .run ()
7193
7294
73- @common .parametrize ("test_data " , test_data_suite )
95+ @common .parametrize ("input_data " , delegated_clones )
7496@common .XfailIfNoCorstone300
75- @pytest .mark .xfail (
76- reason = "Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477"
77- )
78- def test_clone_u55_INT (test_data ):
97+ def test_clone_u55_INT (input_data ):
98+ module , input_tensor = input_data ()
99+
79100 pipeline = EthosU55PipelineINT [input_t ](
80- Clone (),
81- test_data () ,
101+ module (),
102+ input_tensor ,
82103 aten_op ,
83104 exir_op ,
84105 run_on_fvp = True ,
@@ -87,15 +108,14 @@ def test_clone_u55_INT(test_data):
87108 pipeline .run ()
88109
89110
90- @common .parametrize ("test_data " , test_data_suite )
111+ @common .parametrize ("input_data " , delegated_clones )
91112@common .XfailIfNoCorstone320
92- @pytest .mark .xfail (
93- reason = "Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477"
94- )
95- def test_clone_u85_INT (test_data ):
113+ def test_clone_u85_INT (input_data ):
114+ module , input_tensor = input_data ()
115+
96116 pipeline = EthosU85PipelineINT [input_t ](
97- Clone (),
98- test_data () ,
117+ module (),
118+ input_tensor ,
99119 aten_op ,
100120 exir_op ,
101121 run_on_fvp = True ,
@@ -104,27 +124,23 @@ def test_clone_u85_INT(test_data):
104124 pipeline .run ()
105125
106126
107- @common .parametrize ("test_data" , test_data_suite )
127+ @common .parametrize ("test_data" , delegated_clones )
108128@common .SkipIfNoModelConverter
109- @pytest .mark .xfail (
110- reason = "Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477"
111- )
112129def test_clone_vgf_FP (test_data ):
130+ module , input_tensor = test_data ()
113131 pipeline = VgfPipeline [input_t ](
114- Clone (), test_data () , aten_op , exir_op , tosa_version = "TOSA-1.0+FP"
132+ module (), input_tensor , aten_op , exir_op , tosa_version = "TOSA-1.0+FP"
115133 )
116134 pipeline .run ()
117135
118136
119- @common .parametrize ("test_data" , test_data_suite )
137+ @common .parametrize ("test_data" , delegated_clones )
120138@common .SkipIfNoModelConverter
121- @pytest .mark .xfail (
122- reason = "Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477"
123- )
124139def test_clone_vgf_INT (test_data ):
140+ module , input_tensor = test_data ()
125141 pipeline = VgfPipeline [input_t ](
126- Clone (),
127- test_data () ,
142+ module (),
143+ input_tensor ,
128144 aten_op ,
129145 exir_op ,
130146 tosa_version = "TOSA-1.0+INT" ,
0 commit comments