Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions backends/arm/test/tester/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import logging
import warnings as _warnings

from typing import (
Any,
Expand Down Expand Up @@ -226,6 +227,12 @@ def find_pos(self, stage_id: str):

raise Exception(f"Stage id {stage_id} not found in pipeline")

def has_stage(self, stage_id: str):
try:
return self.find_pos(stage_id) >= 0
except:
return False

def add_stage_after(self, stage_id: str, func: Callable, *args, **kwargs):
"""Adds a stage after the given stage id."""
pos = self.find_pos(stage_id) + 1
Expand Down Expand Up @@ -271,7 +278,34 @@ def run(self):
raise e


class TosaPipelineINT(BasePipelineMaker, Generic[T]):
class TOSAPipelineMaker(BasePipelineMaker, Generic[T]):

@staticmethod
def is_tosa_ref_model_available():
"""Checks if the TOSA reference model is available."""
# Not all deployments of ET have the TOSA reference model available.
# Make sure we don't try to use it if it's not available.
try:
import tosa_reference_model

# Check if the module has content
return bool(dir(tosa_reference_model))
except ImportError:
return False

def run(self):
if (
self.has_stage("run_method_and_compare_outputs")
and not self.is_tosa_ref_model_available()
):
_warnings.warn(
"Warning: Skipping run_method_and_compare_outputs stage. TOSA reference model is not available."
)
self.pop_stage("run_method_and_compare_outputs")
super().run()


class TosaPipelineINT(TOSAPipelineMaker, Generic[T]):
"""
Lowers a graph to INT TOSA spec (with quantization) and tests it with the TOSA reference model.

Expand Down Expand Up @@ -375,7 +409,7 @@ def __init__(
)


class TosaPipelineFP(BasePipelineMaker, Generic[T]):
class TosaPipelineFP(TOSAPipelineMaker, Generic[T]):
"""
Lowers a graph to FP TOSA spec and tests it with the TOSA reference model.

Expand Down Expand Up @@ -629,7 +663,7 @@ def __init__(
)


class PassPipeline(BasePipelineMaker, Generic[T]):
class PassPipeline(TOSAPipelineMaker, Generic[T]):
"""
Runs single passes directly on an edge_program and checks operators before/after.

Expand Down Expand Up @@ -719,7 +753,7 @@ def __init__(
self.add_stage(self.tester.run_method_and_compare_outputs)


class TransformAnnotationPassPipeline(BasePipelineMaker, Generic[T]):
class TransformAnnotationPassPipeline(TOSAPipelineMaker, Generic[T]):
"""
Runs transform_for_annotation_pipeline passes directly on an exported program and checks output.

Expand Down Expand Up @@ -775,7 +809,7 @@ def __init__(
)


class OpNotSupportedPipeline(BasePipelineMaker, Generic[T]):
class OpNotSupportedPipeline(TOSAPipelineMaker, Generic[T]):
"""
Runs the partitioner on a module and checks that ops are not delegated to test
SupportedTOSAOperatorChecks.
Expand Down
Loading