@@ -50,17 +50,19 @@ def forward(self, x):
5050
5151
5252def  _tosa_FP_pipeline (module : torch .nn .Module , test_data : input_t1 , dump_file = None ):
53- 
54-     pipeline  =  TosaPipelineFP [input_t1 ](module , test_data , [], [])
53+     aten_ops : list [str ] =  []
54+     exir_ops : list [str ] =  []
55+     pipeline  =  TosaPipelineFP [input_t1 ](module , test_data , aten_ops , exir_ops )
5556    pipeline .dump_artifact ("to_edge_transform_and_lower" )
5657    pipeline .dump_artifact ("to_edge_transform_and_lower" , suffix = dump_file )
5758    pipeline .pop_stage ("run_method_and_compare_outputs" )
5859    pipeline .run ()
5960
6061
6162def  _tosa_INT_pipeline (module : torch .nn .Module , test_data : input_t1 , dump_file = None ):
62- 
63-     pipeline  =  TosaPipelineINT [input_t1 ](module , test_data , [], [])
63+     aten_ops : list [str ] =  []
64+     exir_ops : list [str ] =  []
65+     pipeline  =  TosaPipelineINT [input_t1 ](module , test_data , aten_ops , exir_ops )
6466    pipeline .dump_artifact ("to_edge_transform_and_lower" )
6567    pipeline .dump_artifact ("to_edge_transform_and_lower" , suffix = dump_file )
6668    pipeline .pop_stage ("run_method_and_compare_outputs" )
@@ -105,11 +107,13 @@ def test_INT_artifact(test_data: input_t1):
105107
106108@common .parametrize ("test_data" , Linear .inputs ) 
107109def  test_numerical_diff_print (test_data : input_t1 ):
110+     aten_ops : list [str ] =  []
111+     exir_ops : list [str ] =  []
108112    pipeline  =  TosaPipelineINT [input_t1 ](
109113        Linear (),
110114        test_data ,
111-         [] ,
112-         [] ,
115+         aten_ops ,
116+         exir_ops ,
113117        custom_path = "diff_print_test" ,
114118    )
115119    pipeline .pop_stage ("run_method_and_compare_outputs" )
@@ -131,7 +135,9 @@ def test_numerical_diff_print(test_data: input_t1):
131135
132136@common .parametrize ("test_data" , Linear .inputs ) 
133137def  test_dump_ops_and_dtypes (test_data : input_t1 ):
134-     pipeline  =  TosaPipelineINT [input_t1 ](Linear (), test_data , [], [])
138+     aten_ops : list [str ] =  []
139+     exir_ops : list [str ] =  []
140+     pipeline  =  TosaPipelineINT [input_t1 ](Linear (), test_data , aten_ops , exir_ops )
135141    pipeline .pop_stage ("run_method_and_compare_outputs" )
136142    pipeline .add_stage_after ("quantize" , pipeline .tester .dump_dtype_distribution )
137143    pipeline .add_stage_after ("quantize" , pipeline .tester .dump_operator_distribution )
@@ -149,7 +155,9 @@ def test_dump_ops_and_dtypes(test_data: input_t1):
149155
150156@common .parametrize ("test_data" , Linear .inputs ) 
151157def  test_dump_ops_and_dtypes_parseable (test_data : input_t1 ):
152-     pipeline  =  TosaPipelineINT [input_t1 ](Linear (), test_data , [], [])
158+     aten_ops : list [str ] =  []
159+     exir_ops : list [str ] =  []
160+     pipeline  =  TosaPipelineINT [input_t1 ](Linear (), test_data , aten_ops , exir_ops )
153161    pipeline .pop_stage ("run_method_and_compare_outputs" )
154162    pipeline .add_stage_after ("quantize" , pipeline .tester .dump_dtype_distribution , False )
155163    pipeline .add_stage_after (
@@ -177,7 +185,9 @@ def test_collate_tosa_INT_tests(test_data: input_t1):
177185    # Set the environment variable to trigger the collation of TOSA tests 
178186    os .environ ["TOSA_TESTCASES_BASE_PATH" ] =  "test_collate_tosa_tests" 
179187    # Clear out the directory 
180-     pipeline  =  TosaPipelineINT [input_t1 ](Linear (), test_data , [], [])
188+     aten_ops : list [str ] =  []
189+     exir_ops : list [str ] =  []
190+     pipeline  =  TosaPipelineINT [input_t1 ](Linear (), test_data , aten_ops , exir_ops )
181191    pipeline .pop_stage ("run_method_and_compare_outputs" )
182192    pipeline .run ()
183193
@@ -197,11 +207,13 @@ def test_collate_tosa_INT_tests(test_data: input_t1):
197207@common .parametrize ("test_data" , Linear .inputs ) 
198208def  test_dump_tosa_debug_json (test_data : input_t1 ):
199209    with  tempfile .TemporaryDirectory () as  tmpdir :
210+         aten_ops : list [str ] =  []
211+         exir_ops : list [str ] =  []
200212        pipeline  =  TosaPipelineINT [input_t1 ](
201213            module = Linear (),
202214            test_data = test_data ,
203-             aten_op = [] ,
204-             exir_op = [] ,
215+             aten_op = aten_ops ,
216+             exir_op = exir_ops ,
205217            custom_path = tmpdir ,
206218            tosa_debug_mode = ArmCompileSpec .DebugMode .JSON ,
207219        )
@@ -228,11 +240,13 @@ def test_dump_tosa_debug_json(test_data: input_t1):
228240@common .parametrize ("test_data" , Linear .inputs ) 
229241def  test_dump_tosa_debug_tosa (test_data : input_t1 ):
230242    with  tempfile .TemporaryDirectory () as  tmpdir :
243+         aten_ops : list [str ] =  []
244+         exir_ops : list [str ] =  []
231245        pipeline  =  TosaPipelineINT [input_t1 ](
232246            module = Linear (),
233247            test_data = test_data ,
234-             aten_op = [] ,
235-             exir_op = [] ,
248+             aten_op = aten_ops ,
249+             exir_op = exir_ops ,
236250            custom_path = tmpdir ,
237251            tosa_debug_mode = ArmCompileSpec .DebugMode .TOSA ,
238252        )
@@ -248,7 +262,9 @@ def test_dump_tosa_debug_tosa(test_data: input_t1):
248262
249263@common .parametrize ("test_data" , Linear .inputs ) 
250264def  test_dump_tosa_ops (caplog , test_data : input_t1 ):
251-     pipeline  =  TosaPipelineINT [input_t1 ](Linear (), test_data , [], [])
265+     aten_ops : list [str ] =  []
266+     exir_ops : list [str ] =  []
267+     pipeline  =  TosaPipelineINT [input_t1 ](Linear (), test_data , aten_ops , exir_ops )
252268    pipeline .pop_stage ("run_method_and_compare_outputs" )
253269    pipeline .dump_operator_distribution ("to_edge_transform_and_lower" )
254270    pipeline .run ()
@@ -267,8 +283,10 @@ def forward(self, x):
267283@common .parametrize ("test_data" , Add .inputs ) 
268284@common .XfailIfNoCorstone300  
269285def  test_fail_dump_tosa_ops (caplog , test_data : input_t1 ):
286+     aten_ops : list [str ] =  []
287+     exir_ops : list [str ] =  []
270288    pipeline  =  EthosU55PipelineINT [input_t1 ](
271-         Add (), test_data , [], [] , use_to_edge_transform_and_lower = True 
289+         Add (), test_data , aten_ops ,  exir_ops , use_to_edge_transform_and_lower = True 
272290    )
273291    pipeline .dump_operator_distribution ("to_edge_transform_and_lower" )
274292    pipeline .run ()
0 commit comments