@@ -258,6 +258,8 @@ class TosaPipelineBI(BasePipelineMaker, Generic[T]):
258258 exir_ops: Exir dialect ops expected to be found in the graph after to_edge.
259259 if not using use_edge_to_transform_and_lower.
260260
261+ run_on_tosa_ref_model: Set to true to test the tosa file on the TOSA reference model.
262+
261263 tosa_version: A string for identifying the TOSA version, see common.get_tosa_compile_spec for
262264 options.
263265 use_edge_to_transform_and_lower: Selects betweeen two possible ways of lowering the module.
@@ -270,6 +272,7 @@ def __init__(
270272 test_data : T ,
271273 aten_op : str | List [str ],
272274 exir_op : Optional [str | List [str ]] = None ,
275+ run_on_tosa_ref_model : bool = True ,
273276 tosa_version : str = "TOSA-0.80+BI" ,
274277 symmetric_io_quantization : bool = False ,
275278 use_to_edge_transform_and_lower : bool = True ,
@@ -324,13 +327,14 @@ def __init__(
324327 suffix = "quant_nodes" ,
325328 )
326329
327- self .add_stage (
328- self .tester .run_method_and_compare_outputs ,
329- atol = atol ,
330- rtol = rtol ,
331- qtol = qtol ,
332- inputs = self .test_data ,
333- )
330+ if run_on_tosa_ref_model :
331+ self .add_stage (
332+ self .tester .run_method_and_compare_outputs ,
333+ atol = atol ,
334+ rtol = rtol ,
335+ qtol = qtol ,
336+ inputs = self .test_data ,
337+ )
334338
335339
336340class TosaPipelineMI (BasePipelineMaker , Generic [T ]):
@@ -345,6 +349,8 @@ class TosaPipelineMI(BasePipelineMaker, Generic[T]):
345349 exir_ops: Exir dialect ops expected to be found in the graph after to_edge.
346350 if not using use_edge_to_transform_and_lower.
347351
352+ run_on_tosa_ref_model: Set to true to test the tosa file on the TOSA reference model.
353+
348354 tosa_version: A string for identifying the TOSA version, see common.get_tosa_compile_spec for
349355 options.
350356 use_edge_to_transform_and_lower: Selects betweeen two possible ways of lowering the module.
@@ -357,6 +363,7 @@ def __init__(
357363 test_data : T ,
358364 aten_op : str | List [str ],
359365 exir_op : Optional [str | List [str ]] = None ,
366+ run_on_tosa_ref_model : bool = True ,
360367 tosa_version : str = "TOSA-0.80+MI" ,
361368 use_to_edge_transform_and_lower : bool = True ,
362369 custom_path : str = None ,
@@ -385,13 +392,14 @@ def __init__(
385392 suffix = "quant_nodes" ,
386393 )
387394
388- self .add_stage (
389- self .tester .run_method_and_compare_outputs ,
390- atol = atol ,
391- rtol = rtol ,
392- qtol = qtol ,
393- inputs = self .test_data ,
394- )
395+ if run_on_tosa_ref_model :
396+ self .add_stage (
397+ self .tester .run_method_and_compare_outputs ,
398+ atol = atol ,
399+ rtol = rtol ,
400+ qtol = qtol ,
401+ inputs = self .test_data ,
402+ )
395403
396404
397405class EthosU55PipelineBI (BasePipelineMaker , Generic [T ]):
0 commit comments