|
4 | 4 | # LICENSE file in the root directory of this source tree.
|
5 | 5 |
|
6 | 6 | import logging
|
| 7 | +import warnings as _warnings |
7 | 8 |
|
8 | 9 | from typing import (
|
9 | 10 | Any,
|
@@ -226,6 +227,12 @@ def find_pos(self, stage_id: str):
|
226 | 227 |
|
227 | 228 | raise Exception(f"Stage id {stage_id} not found in pipeline")
|
228 | 229 |
|
| 230 | + def has_stage(self, stage_id: str): |
| 231 | + try: |
| 232 | + return self.find_pos(stage_id) >= 0 |
| 233 | + except: |
| 234 | + return False |
| 235 | + |
229 | 236 | def add_stage_after(self, stage_id: str, func: Callable, *args, **kwargs):
|
230 | 237 | """Adds a stage after the given stage id."""
|
231 | 238 | pos = self.find_pos(stage_id) + 1
|
@@ -271,7 +278,34 @@ def run(self):
|
271 | 278 | raise e
|
272 | 279 |
|
273 | 280 |
|
274 |
| -class TosaPipelineINT(BasePipelineMaker, Generic[T]): |
| 281 | +class TOSAPipelineMaker(BasePipelineMaker, Generic[T]): |
| 282 | + |
| 283 | + @staticmethod |
| 284 | + def is_tosa_ref_model_available(): |
| 285 | + """Checks if the TOSA reference model is available.""" |
| 286 | + # Not all deployments of ET have the TOSA reference model available. |
| 287 | + # Make sure we don't try to use it if it's not available. |
| 288 | + try: |
| 289 | + import tosa_reference_model |
| 290 | + |
| 291 | + # Check if the module has content |
| 292 | + return bool(dir(tosa_reference_model)) |
| 293 | + except ImportError: |
| 294 | + return False |
| 295 | + |
| 296 | + def run(self): |
| 297 | + if ( |
| 298 | + self.has_stage("run_method_and_compare_outputs") |
| 299 | + and not self.is_tosa_ref_model_available() |
| 300 | + ): |
| 301 | + _warnings.warn( |
| 302 | + "Warning: Skipping run_method_and_compare_outputs stage. TOSA reference model is not available." |
| 303 | + ) |
| 304 | + self.pop_stage("run_method_and_compare_outputs") |
| 305 | + super().run() |
| 306 | + |
| 307 | + |
| 308 | +class TosaPipelineINT(TOSAPipelineMaker, Generic[T]): |
275 | 309 | """
|
276 | 310 | Lowers a graph to INT TOSA spec (with quantization) and tests it with the TOSA reference model.
|
277 | 311 |
|
@@ -375,7 +409,7 @@ def __init__(
|
375 | 409 | )
|
376 | 410 |
|
377 | 411 |
|
378 |
| -class TosaPipelineFP(BasePipelineMaker, Generic[T]): |
| 412 | +class TosaPipelineFP(TOSAPipelineMaker, Generic[T]): |
379 | 413 | """
|
380 | 414 | Lowers a graph to FP TOSA spec and tests it with the TOSA reference model.
|
381 | 415 |
|
@@ -629,7 +663,7 @@ def __init__(
|
629 | 663 | )
|
630 | 664 |
|
631 | 665 |
|
632 |
| -class PassPipeline(BasePipelineMaker, Generic[T]): |
| 666 | +class PassPipeline(TOSAPipelineMaker, Generic[T]): |
633 | 667 | """
|
634 | 668 | Runs single passes directly on an edge_program and checks operators before/after.
|
635 | 669 |
|
@@ -719,7 +753,7 @@ def __init__(
|
719 | 753 | self.add_stage(self.tester.run_method_and_compare_outputs)
|
720 | 754 |
|
721 | 755 |
|
722 |
| -class TransformAnnotationPassPipeline(BasePipelineMaker, Generic[T]): |
| 756 | +class TransformAnnotationPassPipeline(TOSAPipelineMaker, Generic[T]): |
723 | 757 | """
|
724 | 758 | Runs transform_for_annotation_pipeline passes directly on an exported program and checks output.
|
725 | 759 |
|
@@ -775,7 +809,7 @@ def __init__(
|
775 | 809 | )
|
776 | 810 |
|
777 | 811 |
|
778 |
| -class OpNotSupportedPipeline(BasePipelineMaker, Generic[T]): |
| 812 | +class OpNotSupportedPipeline(TOSAPipelineMaker, Generic[T]): |
779 | 813 | """
|
780 | 814 | Runs the partitioner on a module and checks that ops are not delegated to test
|
781 | 815 | SupportedTOSAOperatorChecks.
|
|
0 commit comments