Skip to content

Commit 365d4c1

Browse files
authored
Arm backend: use tosa_ref_model only if installed
Differential Revision: D79887501 Pull Request resolved: #13221
1 parent 02a203e commit 365d4c1

File tree

1 file changed

+39
-5
lines changed

1 file changed

+39
-5
lines changed

backends/arm/test/tester/test_pipeline.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import logging
7+
import warnings as _warnings
78

89
from typing import (
910
Any,
@@ -226,6 +227,12 @@ def find_pos(self, stage_id: str):
226227

227228
raise Exception(f"Stage id {stage_id} not found in pipeline")
228229

230+
def has_stage(self, stage_id: str):
231+
try:
232+
return self.find_pos(stage_id) >= 0
233+
except:
234+
return False
235+
229236
def add_stage_after(self, stage_id: str, func: Callable, *args, **kwargs):
230237
"""Adds a stage after the given stage id."""
231238
pos = self.find_pos(stage_id) + 1
@@ -271,7 +278,34 @@ def run(self):
271278
raise e
272279

273280

274-
class TosaPipelineINT(BasePipelineMaker, Generic[T]):
281+
class TOSAPipelineMaker(BasePipelineMaker, Generic[T]):
282+
283+
@staticmethod
284+
def is_tosa_ref_model_available():
285+
"""Checks if the TOSA reference model is available."""
286+
# Not all deployments of ET have the TOSA reference model available.
287+
# Make sure we don't try to use it if it's not available.
288+
try:
289+
import tosa_reference_model
290+
291+
# Check if the module has content
292+
return bool(dir(tosa_reference_model))
293+
except ImportError:
294+
return False
295+
296+
def run(self):
297+
if (
298+
self.has_stage("run_method_and_compare_outputs")
299+
and not self.is_tosa_ref_model_available()
300+
):
301+
_warnings.warn(
302+
"Warning: Skipping run_method_and_compare_outputs stage. TOSA reference model is not available."
303+
)
304+
self.pop_stage("run_method_and_compare_outputs")
305+
super().run()
306+
307+
308+
class TosaPipelineINT(TOSAPipelineMaker, Generic[T]):
275309
"""
276310
Lowers a graph to INT TOSA spec (with quantization) and tests it with the TOSA reference model.
277311
@@ -375,7 +409,7 @@ def __init__(
375409
)
376410

377411

378-
class TosaPipelineFP(BasePipelineMaker, Generic[T]):
412+
class TosaPipelineFP(TOSAPipelineMaker, Generic[T]):
379413
"""
380414
Lowers a graph to FP TOSA spec and tests it with the TOSA reference model.
381415
@@ -629,7 +663,7 @@ def __init__(
629663
)
630664

631665

632-
class PassPipeline(BasePipelineMaker, Generic[T]):
666+
class PassPipeline(TOSAPipelineMaker, Generic[T]):
633667
"""
634668
Runs single passes directly on an edge_program and checks operators before/after.
635669
@@ -719,7 +753,7 @@ def __init__(
719753
self.add_stage(self.tester.run_method_and_compare_outputs)
720754

721755

722-
class TransformAnnotationPassPipeline(BasePipelineMaker, Generic[T]):
756+
class TransformAnnotationPassPipeline(TOSAPipelineMaker, Generic[T]):
723757
"""
724758
Runs transform_for_annotation_pipeline passes directly on an exported program and checks output.
725759
@@ -775,7 +809,7 @@ def __init__(
775809
)
776810

777811

778-
class OpNotSupportedPipeline(BasePipelineMaker, Generic[T]):
812+
class OpNotSupportedPipeline(TOSAPipelineMaker, Generic[T]):
779813
"""
780814
Runs the partitioner on a module and checks that ops are not delegated to test
781815
SupportedTOSAOperatorChecks.

0 commit comments

Comments
 (0)