66
77import logging
88import os
9+ import shutil
910import tempfile
1011import unittest
1112
@@ -126,8 +127,62 @@ def test_numerical_diff_prints(self):
126127 self .fail ()
127128
128129
129- class TestDumpOperatorsAndDtypes (unittest .TestCase ):
130- def test_dump_ops_and_dtypes (self ):
130+ def test_dump_ops_and_dtypes ():
131+ model = Linear (20 , 30 )
132+ (
133+ ArmTester (
134+ model ,
135+ example_inputs = model .get_inputs (),
136+ compile_spec = common .get_tosa_compile_spec (),
137+ )
138+ .quantize ()
139+ .dump_dtype_distribution ()
140+ .dump_operator_distribution ()
141+ .export ()
142+ .dump_dtype_distribution ()
143+ .dump_operator_distribution ()
144+ .to_edge ()
145+ .dump_dtype_distribution ()
146+ .dump_operator_distribution ()
147+ .partition ()
148+ .dump_dtype_distribution ()
149+ .dump_operator_distribution ()
150+ )
151+ # Just test that there are no execptions.
152+
153+
154+ def test_dump_ops_and_dtypes_parseable ():
155+ model = Linear (20 , 30 )
156+ (
157+ ArmTester (
158+ model ,
159+ example_inputs = model .get_inputs (),
160+ compile_spec = common .get_tosa_compile_spec (),
161+ )
162+ .quantize ()
163+ .dump_dtype_distribution (print_table = False )
164+ .dump_operator_distribution (print_table = False )
165+ .export ()
166+ .dump_dtype_distribution (print_table = False )
167+ .dump_operator_distribution (print_table = False )
168+ .to_edge ()
169+ .dump_dtype_distribution (print_table = False )
170+ .dump_operator_distribution (print_table = False )
171+ .partition ()
172+ .dump_dtype_distribution (print_table = False )
173+ .dump_operator_distribution (print_table = False )
174+ )
175+ # Just test that there are no execptions.
176+
177+
178+ class TestCollateTosaTests (unittest .TestCase ):
179+ """Tests the collation of TOSA tests through setting the environment variable TOSA_TESTCASE_BASE_PATH."""
180+
181+ def test_collate_tosa_BI_tests (self ):
182+ # Set the environment variable to trigger the collation of TOSA tests
183+ os .environ ["TOSA_TESTCASES_BASE_PATH" ] = "test_collate_tosa_tests"
184+ # Clear out the directory
185+
131186 model = Linear (20 , 30 )
132187 (
133188 ArmTester (
@@ -136,16 +191,59 @@ def test_dump_ops_and_dtypes(self):
136191 compile_spec = common .get_tosa_compile_spec (),
137192 )
138193 .quantize ()
139- .dump_dtype_distribution ()
140- .dump_operator_distribution ()
141194 .export ()
142- .dump_dtype_distribution ()
143- .dump_operator_distribution ()
144195 .to_edge ()
145- .dump_dtype_distribution ()
146- .dump_operator_distribution ()
147196 .partition ()
148- .dump_dtype_distribution ()
149- .dump_operator_distribution ()
197+ .to_executorch ()
198+ )
199+ # test that the output directory is created and contains the expected files
200+ assert os .path .exists (
201+ "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests"
202+ )
203+ assert os .path .exists (
204+ "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/output_tag8.tosa"
205+ )
206+ assert os .path .exists (
207+ "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/desc_tag8.json"
208+ )
209+
210+ os .environ .pop ("TOSA_TESTCASES_BASE_PATH" )
211+ shutil .rmtree ("test_collate_tosa_tests" , ignore_errors = True )
212+
213+
214+ def test_dump_tosa_ops (caplog ):
215+ caplog .set_level (logging .INFO )
216+ model = Linear (20 , 30 )
217+ (
218+ ArmTester (
219+ model ,
220+ example_inputs = model .get_inputs (),
221+ compile_spec = common .get_tosa_compile_spec (),
150222 )
151- # Just test that there are no execeptions.
223+ .quantize ()
224+ .export ()
225+ .to_edge ()
226+ .partition ()
227+ .dump_operator_distribution ()
228+ )
229+ assert "TOSA operators:" in caplog .text
230+
231+
232+ def test_fail_dump_tosa_ops (caplog ):
233+ caplog .set_level (logging .INFO )
234+
235+ class Add (torch .nn .Module ):
236+ def forward (self , x ):
237+ return x + x
238+
239+ model = Add ()
240+ compile_spec = common .get_u55_compile_spec ()
241+ (
242+ ArmTester (model , example_inputs = (torch .ones (5 ),), compile_spec = compile_spec )
243+ .quantize ()
244+ .export ()
245+ .to_edge ()
246+ .partition ()
247+ .dump_operator_distribution ()
248+ )
249+ assert "Can not get operator distribution for Vela command stream." in caplog .text
0 commit comments