44# LICENSE file in the root directory of this source tree.
55
66import logging
7+ import warnings as _warnings
78
89from typing import (
910 Any ,
4142""" Generic type used for test data in the pipeline. Depends on which type the operator expects."""
4243
4344
45+ def is_tosa_ref_model_available ():
46+ """Checks if the TOSA reference model is available."""
47+ # Not all deployments of ET have the TOSA reference model available.
48+ # Make sure we don't try to use it if it's not available.
49+ try :
50+ import tosa_tools .tosa_ref_model as tosa_reference_model
51+
52+ if not dir (tosa_reference_model ):
53+ return False
54+ except ImportError :
55+ return False
56+ return True
57+
58+
4459class BasePipelineMaker (Generic [T ]):
4560 """
4661 The BasePiplineMaker defines a list of stages to be applied to a torch.nn.module for lowering it
@@ -283,8 +298,6 @@ class TosaPipelineINT(BasePipelineMaker, Generic[T]):
283298 exir_ops: Exir dialect ops expected to be found in the graph after to_edge.
284299 if not using use_edge_to_transform_and_lower.
285300
286- run_on_tosa_ref_model: Set to true to test the tosa file on the TOSA reference model.
287-
288301 tosa_version: A string for identifying the TOSA version, see common.get_tosa_compile_spec for
289302 options.
290303 use_edge_to_transform_and_lower: Selects betweeen two possible ways of lowering the module.
@@ -297,7 +310,6 @@ def __init__(
297310 test_data : T ,
298311 aten_op : str | List [str ],
299312 exir_op : Optional [str | List [str ]] = None ,
300- run_on_tosa_ref_model : bool = True ,
301313 symmetric_io_quantization : bool = False ,
302314 per_channel_quantization : bool = True ,
303315 use_to_edge_transform_and_lower : bool = True ,
@@ -360,14 +372,18 @@ def __init__(
360372 suffix = "quant_nodes" ,
361373 )
362374
363- if run_on_tosa_ref_model :
375+ if is_tosa_ref_model_available () :
364376 self .add_stage (
365377 self .tester .run_method_and_compare_outputs ,
366378 atol = atol ,
367379 rtol = rtol ,
368380 qtol = qtol ,
369381 inputs = self .test_data ,
370382 )
383+ else :
384+ _warnings .warn (
385+ "Warning: Skipping run_method_and_compare_outputs stage. tosa reference model is not available."
386+ )
371387
372388
373389class TosaPipelineFP (BasePipelineMaker , Generic [T ]):
@@ -382,8 +398,6 @@ class TosaPipelineFP(BasePipelineMaker, Generic[T]):
382398 exir_ops: Exir dialect ops expected to be found in the graph after to_edge.
383399 if not using use_edge_to_transform_and_lower.
384400
385- run_on_tosa_ref_model: Set to true to test the tosa file on the TOSA reference model.
386-
387401 tosa_version: A string for identifying the TOSA version, see common.get_tosa_compile_spec for
388402 options.
389403 use_edge_to_transform_and_lower: Selects betweeen two possible ways of lowering the module.
@@ -435,14 +449,18 @@ def __init__(
435449 suffix = "quant_nodes" ,
436450 )
437451
438- if run_on_tosa_ref_model :
452+ if is_tosa_ref_model_available () :
439453 self .add_stage (
440454 self .tester .run_method_and_compare_outputs ,
441455 atol = atol ,
442456 rtol = rtol ,
443457 qtol = qtol ,
444458 inputs = self .test_data ,
445459 )
460+ else :
461+ _warnings .warn (
462+ "Warning: Skipping run_method_and_compare_outputs stage. tosa reference model is not available"
463+ )
446464
447465
448466class EthosU55PipelineINT (BasePipelineMaker , Generic [T ]):
@@ -701,7 +719,12 @@ def __init__(
701719 self .add_stage (self .tester .check_count , ops_after_pass , suffix = "after" )
702720 if ops_not_after_pass :
703721 self .add_stage (self .tester .check_not , ops_not_after_pass , suffix = "after" )
704- self .add_stage (self .tester .run_method_and_compare_outputs )
722+ if is_tosa_ref_model_available ():
723+ self .add_stage (self .tester .run_method_and_compare_outputs )
724+ else :
725+ _warnings .warn (
726+ "Warning: Skipping run_method_and_compare_outputs stage. Tosa reference model is not available."
727+ )
705728
706729
707730class TransformAnnotationPassPipeline (BasePipelineMaker , Generic [T ]):
@@ -748,11 +771,16 @@ def __init__(
748771 self .pop_stage ("to_executorch" )
749772 self .pop_stage ("to_edge_transform_and_lower" )
750773 self .pop_stage ("check.aten" )
751- self .add_stage (
752- self .tester .run_method_and_compare_outputs ,
753- inputs = test_data ,
754- run_eager_mode = True ,
755- )
774+ if is_tosa_ref_model_available ():
775+ self .add_stage (
776+ self .tester .run_method_and_compare_outputs ,
777+ inputs = test_data ,
778+ run_eager_mode = True ,
779+ )
780+ else :
781+ _warnings .warn (
782+ "Warning: Skipping run_method_and_compare_outputs stage. Tosa reference model is not available."
783+ )
756784
757785
758786class OpNotSupportedPipeline (BasePipelineMaker , Generic [T ]):
0 commit comments