44# LICENSE file in the root directory of this source tree.
55
66import logging
7+ from typing import Tuple
78
89import torch
910from executorch .backends .arm .test import common
10- from executorch .backends .arm .test .tester .arm_tester import ArmTester
11- from executorch .backends .arm .tosa_partitioner import TOSAPartitioner
11+ from executorch .backends .arm .test .tester .test_pipeline import TosaPipelineMI
1212from executorch .exir .backend .operator_support import (
1313 DontPartition ,
1414 DontPartitionModule ,
1515 DontPartitionName ,
1616)
1717from executorch .exir .dialects ._ops import ops as exir_ops
1818
19+ input_t1 = Tuple [torch .Tensor , torch .Tensor ] # Input x, y
20+
1921
2022class CustomPartitioning (torch .nn .Module ):
21- inputs = (torch .randn (10 , 4 , 5 ), torch .randn (10 , 4 , 5 ))
23+ inputs = {
24+ "randn" : (torch .randn (10 , 4 , 5 ), torch .randn (10 , 4 , 5 )),
25+ }
2226
2327 def forward (self , x : torch .Tensor , y : torch .Tensor ):
2428 z = x + y
@@ -27,7 +31,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
2731
2832
2933class NestedModule (torch .nn .Module ):
30- inputs = (torch .randn (10 , 4 , 5 ), torch .randn (10 , 4 , 5 ))
34+ inputs = {
35+ "randn" : (torch .randn (10 , 4 , 5 ), torch .randn (10 , 4 , 5 )),
36+ }
3137
3238 def __init__ (self ):
3339 super ().__init__ ()
@@ -39,192 +45,139 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
3945 return self .nested (a , b )
4046
4147
42- def test_single_reject (caplog ):
48+ @common .parametrize ("test_data" , CustomPartitioning .inputs )
49+ def test_single_reject (caplog , test_data : input_t1 ):
4350 caplog .set_level (logging .INFO )
4451
4552 module = CustomPartitioning ()
46- inputs = module .inputs
47- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" )
53+ pipeline = TosaPipelineMI [input_t1 ](module , test_data , [], exir_op = [])
4854 check = DontPartition (exir_ops .edge .aten .sigmoid .default )
49- partitioner = TOSAPartitioner (compile_spec , additional_checks = [check ])
50- (
51- ArmTester (
52- module ,
53- example_inputs = inputs ,
54- compile_spec = compile_spec ,
55- )
56- .export ()
57- .to_edge_transform_and_lower (partitioners = [partitioner ])
58- .check (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
59- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 2 })
60- .to_executorch ()
61- .run_method_and_compare_outputs (inputs = inputs )
55+ pipeline .change_args ("to_edge_transform_and_lower" , additional_checks = [check ])
56+ pipeline .change_args (
57+ "check_count.exir" , {"torch.ops.higher_order.executorch_call_delegate" : 2 }
58+ )
59+ pipeline .change_args (
60+ "check_count.exir" ,
61+ {"executorch_exir_dialects_edge__ops_aten_sigmoid_default" : 1 },
6262 )
63+ pipeline .run ()
6364 assert check .has_rejected_node ()
6465 assert "Rejected by DontPartition" in caplog .text
6566
6667
67- def test_multiple_reject ():
68+ @common .parametrize ("test_data" , CustomPartitioning .inputs )
69+ def test_multiple_reject (test_data : input_t1 ):
6870 module = CustomPartitioning ()
69- inputs = module .inputs
70- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" )
71+ pipeline = TosaPipelineMI [input_t1 ](module , test_data , [], exir_op = [])
7172 check = DontPartition (
7273 exir_ops .edge .aten .sigmoid .default , exir_ops .edge .aten .mul .Tensor
7374 )
74- partitioner = TOSAPartitioner (compile_spec , additional_checks = [check ])
75- (
76- ArmTester (
77- module ,
78- example_inputs = inputs ,
79- compile_spec = compile_spec ,
80- )
81- .export ()
82- .to_edge_transform_and_lower (partitioners = [partitioner ])
83- .check (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
84- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
85- .to_executorch ()
86- .run_method_and_compare_outputs (inputs = inputs )
75+ pipeline .change_args ("to_edge_transform_and_lower" , additional_checks = [check ])
76+ pipeline .change_args (
77+ "check_count.exir" , {"torch.ops.higher_order.executorch_call_delegate" : 2 }
78+ )
79+ pipeline .change_args (
80+ "check_count.exir" ,
81+ {"executorch_exir_dialects_edge__ops_aten_sigmoid_default" : 1 },
8782 )
83+ pipeline .run ()
8884 assert check .has_rejected_node ()
8985
9086
91- def test_torch_op_reject (caplog ):
87+ @common .parametrize ("test_data" , CustomPartitioning .inputs )
88+ def test_torch_op_reject (caplog , test_data : input_t1 ):
9289 caplog .set_level (logging .INFO )
9390
9491 module = CustomPartitioning ()
95- inputs = module .inputs
96- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" )
9792 check = DontPartition (torch .ops .aten .sigmoid .default )
98- partitioner = TOSAPartitioner (compile_spec , additional_checks = [check ])
99- (
100- ArmTester (
101- module ,
102- example_inputs = inputs ,
103- compile_spec = compile_spec ,
104- )
105- .export ()
106- .to_edge_transform_and_lower (partitioners = [partitioner ])
107- .check (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
108- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 2 })
109- .to_executorch ()
110- .run_method_and_compare_outputs (inputs = inputs )
93+ pipeline = TosaPipelineMI [input_t1 ](module , test_data , [], exir_op = [])
94+ pipeline .change_args ("to_edge_transform_and_lower" , additional_checks = [check ])
95+ pipeline .change_args (
96+ "check_count.exir" , {"torch.ops.higher_order.executorch_call_delegate" : 2 }
11197 )
98+ pipeline .change_args (
99+ "check_count.exir" ,
100+ {"executorch_exir_dialects_edge__ops_aten_sigmoid_default" : 1 },
101+ )
102+ pipeline .run ()
112103 assert check .has_rejected_node ()
113104 assert "Rejected by DontPartition" in caplog .text
114105
115106
116- def test_string_op_reject ():
107+ @common .parametrize ("test_data" , CustomPartitioning .inputs )
108+ def test_string_op_reject (test_data : input_t1 ):
117109 module = CustomPartitioning ()
118- inputs = module .inputs
119- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" )
120110 check = DontPartition ("aten.sigmoid.default" )
121- partitioner = TOSAPartitioner (compile_spec , additional_checks = [check ])
122- (
123- ArmTester (
124- module ,
125- example_inputs = inputs ,
126- compile_spec = compile_spec ,
127- )
128- .export ()
129- .to_edge_transform_and_lower (partitioners = [partitioner ])
130- .check (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
131- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 2 })
132- .to_executorch ()
133- .run_method_and_compare_outputs (inputs = inputs )
111+ pipeline = TosaPipelineMI [input_t1 ](module , test_data , [], exir_op = [])
112+ pipeline .change_args ("to_edge_transform_and_lower" , additional_checks = [check ])
113+ pipeline .change_args (
114+ "check_count.exir" , {"torch.ops.higher_order.executorch_call_delegate" : 2 }
134115 )
135-
116+ pipeline .change_args (
117+ "check_count.exir" ,
118+ {"executorch_exir_dialects_edge__ops_aten_sigmoid_default" : 1 },
119+ )
120+ pipeline .run ()
136121 assert check .has_rejected_node ()
137122
138123
139- def test_name_reject (caplog ):
124+ @common .parametrize ("test_data" , CustomPartitioning .inputs )
125+ def test_name_reject (caplog , test_data : input_t1 ):
140126 caplog .set_level (logging .INFO )
141127
142128 module = CustomPartitioning ()
143- inputs = module .inputs
144- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" )
145129 check = DontPartitionName ("mul" , "sigmoid" , exact = False )
146- partitioner = TOSAPartitioner (compile_spec , additional_checks = [check ])
147- (
148- ArmTester (
149- module ,
150- example_inputs = inputs ,
151- compile_spec = compile_spec ,
152- )
153- .export ()
154- .to_edge_transform_and_lower (partitioners = [partitioner ])
155- .check (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
156- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
157- .to_executorch ()
158- .run_method_and_compare_outputs (inputs = inputs )
130+ pipeline = TosaPipelineMI [input_t1 ](module , test_data , [], exir_op = [])
131+ pipeline .change_args ("to_edge_transform_and_lower" , additional_checks = [check ])
132+ pipeline .change_args (
133+ "check_count.exir" ,
134+ {"executorch_exir_dialects_edge__ops_aten_sigmoid_default" : 1 },
159135 )
136+ pipeline .run ()
160137 assert check .has_rejected_node ()
161138 assert "Rejected by DontPartitionName" in caplog .text
162139
163140
164- def test_module_reject ():
141+ @common .parametrize ("test_data" , CustomPartitioning .inputs )
142+ def test_module_reject (test_data : input_t1 ):
165143 module = NestedModule ()
166- inputs = module .inputs
167- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" )
168144 check = DontPartitionModule (module_name = "CustomPartitioning" )
169- partitioner = TOSAPartitioner (compile_spec , additional_checks = [check ])
170- (
171- ArmTester (
172- module ,
173- example_inputs = inputs ,
174- compile_spec = compile_spec ,
175- )
176- .export ()
177- .to_edge_transform_and_lower (partitioners = [partitioner ])
178- .check (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
179- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
180- .to_executorch ()
181- .run_method_and_compare_outputs (inputs = inputs )
145+ pipeline = TosaPipelineMI [input_t1 ](module , test_data , [], exir_op = [])
146+ pipeline .change_args ("to_edge_transform_and_lower" , additional_checks = [check ])
147+ pipeline .change_args (
148+ "check_count.exir" ,
149+ {"executorch_exir_dialects_edge__ops_aten_sigmoid_default" : 1 },
182150 )
151+ pipeline .run ()
183152 assert check .has_rejected_node ()
184153
185154
186- def test_inexact_module_reject (caplog ):
155+ @common .parametrize ("test_data" , CustomPartitioning .inputs )
156+ def test_inexact_module_reject (caplog , test_data : input_t1 ):
187157 caplog .set_level (logging .INFO )
188158
189159 module = NestedModule ()
190- inputs = module .inputs
191- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" )
192160 check = DontPartitionModule (module_name = "Custom" , exact = False )
193- partitioner = TOSAPartitioner (compile_spec , additional_checks = [check ])
194- (
195- ArmTester (
196- module ,
197- example_inputs = inputs ,
198- compile_spec = compile_spec ,
199- )
200- .export ()
201- .to_edge_transform_and_lower (partitioners = [partitioner ])
202- .check (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
203- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
204- .to_executorch ()
205- .run_method_and_compare_outputs (inputs = inputs )
161+ pipeline = TosaPipelineMI [input_t1 ](module , test_data , [], exir_op = [])
162+ pipeline .change_args ("to_edge_transform_and_lower" , additional_checks = [check ])
163+ pipeline .change_args (
164+ "check_count.exir" ,
165+ {"executorch_exir_dialects_edge__ops_aten_sigmoid_default" : 1 },
206166 )
167+ pipeline .run ()
207168 assert check .has_rejected_node ()
208169 assert "Rejected by DontPartitionModule" in caplog .text
209170
210171
211- def test_module_instance_reject ():
172+ @common .parametrize ("test_data" , CustomPartitioning .inputs )
173+ def test_module_instance_reject (test_data : input_t1 ):
212174 module = NestedModule ()
213- inputs = module .inputs
214- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" )
215175 check = DontPartitionModule (instance_name = "nested" )
216- partitioner = TOSAPartitioner (compile_spec , additional_checks = [check ])
217- (
218- ArmTester (
219- module ,
220- example_inputs = inputs ,
221- compile_spec = compile_spec ,
222- )
223- .export ()
224- .to_edge_transform_and_lower (partitioners = [partitioner ])
225- .check (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
226- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
227- .to_executorch ()
228- .run_method_and_compare_outputs (inputs = inputs )
176+ pipeline = TosaPipelineMI [input_t1 ](module , test_data , [], exir_op = [])
177+ pipeline .change_args ("to_edge_transform_and_lower" , additional_checks = [check ])
178+ pipeline .change_args (
179+ "check_count.exir" ,
180+ {"executorch_exir_dialects_edge__ops_aten_sigmoid_default" : 1 },
229181 )
182+ pipeline .run ()
230183 assert check .has_rejected_node ()
0 commit comments