@@ -127,29 +127,52 @@ def test_numerical_diff_prints(self):
127127 self .fail ()
128128
129129
130- class TestDumpOperatorsAndDtypes (unittest .TestCase ):
131- def test_dump_ops_and_dtypes (self ):
132- model = Linear (20 , 30 )
133- (
134- ArmTester (
135- model ,
136- example_inputs = model .get_inputs (),
137- compile_spec = common .get_tosa_compile_spec (),
138- )
139- .quantize ()
140- .dump_dtype_distribution ()
141- .dump_operator_distribution ()
142- .export ()
143- .dump_dtype_distribution ()
144- .dump_operator_distribution ()
145- .to_edge ()
146- .dump_dtype_distribution ()
147- .dump_operator_distribution ()
148- .partition ()
149- .dump_dtype_distribution ()
150- .dump_operator_distribution ()
130+ def test_dump_ops_and_dtypes ():
131+ model = Linear (20 , 30 )
132+ (
133+ ArmTester (
134+ model ,
135+ example_inputs = model .get_inputs (),
136+ compile_spec = common .get_tosa_compile_spec (),
137+ )
138+ .quantize ()
139+ .dump_dtype_distribution ()
140+ .dump_operator_distribution ()
141+ .export ()
142+ .dump_dtype_distribution ()
143+ .dump_operator_distribution ()
144+ .to_edge ()
145+ .dump_dtype_distribution ()
146+ .dump_operator_distribution ()
147+ .partition ()
148+ .dump_dtype_distribution ()
149+ .dump_operator_distribution ()
150+ )
151+ # Just test that there are no execptions.
152+
153+
154+ def test_dump_ops_and_dtypes_parseable ():
155+ model = Linear (20 , 30 )
156+ (
157+ ArmTester (
158+ model ,
159+ example_inputs = model .get_inputs (),
160+ compile_spec = common .get_tosa_compile_spec (),
151161 )
152- # Just test that there are no execeptions.
162+ .quantize ()
163+ .dump_dtype_distribution (print_table = False )
164+ .dump_operator_distribution (print_table = False )
165+ .export ()
166+ .dump_dtype_distribution (print_table = False )
167+ .dump_operator_distribution (print_table = False )
168+ .to_edge ()
169+ .dump_dtype_distribution (print_table = False )
170+ .dump_operator_distribution (print_table = False )
171+ .partition ()
172+ .dump_dtype_distribution (print_table = False )
173+ .dump_operator_distribution (print_table = False )
174+ )
175+ # Just test that there are no execptions.
153176
154177
155178class TestCollateTosaTests (unittest .TestCase ):
@@ -186,3 +209,41 @@ def test_collate_tosa_BI_tests(self):
186209
187210 os .environ .pop ("TOSA_TESTCASES_BASE_PATH" )
188211 shutil .rmtree ("test_collate_tosa_tests" , ignore_errors = True )
212+
213+
214+ def test_dump_tosa_ops (caplog ):
215+ caplog .set_level (logging .INFO )
216+ model = Linear (20 , 30 )
217+ (
218+ ArmTester (
219+ model ,
220+ example_inputs = model .get_inputs (),
221+ compile_spec = common .get_tosa_compile_spec (),
222+ )
223+ .quantize ()
224+ .export ()
225+ .to_edge ()
226+ .partition ()
227+ .dump_operator_distribution ()
228+ )
229+ assert "TOSA operators:" in caplog .text
230+
231+
232+ def test_fail_dump_tosa_ops (caplog ):
233+ caplog .set_level (logging .INFO )
234+
235+ class Add (torch .nn .Module ):
236+ def forward (self , x ):
237+ return x + x
238+
239+ model = Add ()
240+ compile_spec = common .get_u55_compile_spec ()
241+ (
242+ ArmTester (model , example_inputs = (torch .ones (5 ),), compile_spec = compile_spec )
243+ .quantize ()
244+ .export ()
245+ .to_edge ()
246+ .partition ()
247+ .dump_operator_distribution ()
248+ )
249+ assert "Can not get operator distribution for Vela command stream." in caplog .text
0 commit comments