|
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | 5 |
|
6 | 6 | import logging |
| 7 | +import warnings as _warnings |
7 | 8 |
|
8 | 9 | from typing import ( |
9 | 10 | Any, |
@@ -226,6 +227,12 @@ def find_pos(self, stage_id: str): |
226 | 227 |
|
227 | 228 | raise Exception(f"Stage id {stage_id} not found in pipeline") |
228 | 229 |
|
| 230 | + def has_stage(self, stage_id: str): |
| 231 | + try: |
| 232 | + return self.find_pos(stage_id) >= 0 |
| 233 | + except: |
| 234 | + return False |
| 235 | + |
229 | 236 | def add_stage_after(self, stage_id: str, func: Callable, *args, **kwargs): |
230 | 237 | """Adds a stage after the given stage id.""" |
231 | 238 | pos = self.find_pos(stage_id) + 1 |
@@ -271,7 +278,34 @@ def run(self): |
271 | 278 | raise e |
272 | 279 |
|
273 | 280 |
|
274 | | -class TosaPipelineINT(BasePipelineMaker, Generic[T]): |
| 281 | +class TOSAPipelineMaker(BasePipelineMaker, Generic[T]): |
| 282 | + |
| 283 | + @staticmethod |
| 284 | + def is_tosa_ref_model_available(): |
| 285 | + """Checks if the TOSA reference model is available.""" |
| 286 | + # Not all deployments of ET have the TOSA reference model available. |
| 287 | + # Make sure we don't try to use it if it's not available. |
| 288 | + try: |
| 289 | + import tosa_reference_model |
| 290 | + |
| 291 | + # Check if the module has content |
| 292 | + return bool(dir(tosa_reference_model)) |
| 293 | + except ImportError: |
| 294 | + return False |
| 295 | + |
| 296 | + def run(self): |
| 297 | + if ( |
| 298 | + self.has_stage("run_method_and_compare_outputs") |
| 299 | + and not self.is_tosa_ref_model_available() |
| 300 | + ): |
| 301 | + _warnings.warn( |
| 302 | + "Warning: Skipping run_method_and_compare_outputs stage. TOSA reference model is not available." |
| 303 | + ) |
| 304 | + self.pop_stage("run_method_and_compare_outputs") |
| 305 | + super().run() |
| 306 | + |
| 307 | + |
| 308 | +class TosaPipelineINT(TOSAPipelineMaker, Generic[T]): |
275 | 309 | """ |
276 | 310 | Lowers a graph to INT TOSA spec (with quantization) and tests it with the TOSA reference model. |
277 | 311 |
|
@@ -375,7 +409,7 @@ def __init__( |
375 | 409 | ) |
376 | 410 |
|
377 | 411 |
|
378 | | -class TosaPipelineFP(BasePipelineMaker, Generic[T]): |
| 412 | +class TosaPipelineFP(TOSAPipelineMaker, Generic[T]): |
379 | 413 | """ |
380 | 414 | Lowers a graph to FP TOSA spec and tests it with the TOSA reference model. |
381 | 415 |
|
@@ -629,7 +663,7 @@ def __init__( |
629 | 663 | ) |
630 | 664 |
|
631 | 665 |
|
632 | | -class PassPipeline(BasePipelineMaker, Generic[T]): |
| 666 | +class PassPipeline(TOSAPipelineMaker, Generic[T]): |
633 | 667 | """ |
634 | 668 | Runs single passes directly on an edge_program and checks operators before/after. |
635 | 669 |
|
@@ -719,7 +753,7 @@ def __init__( |
719 | 753 | self.add_stage(self.tester.run_method_and_compare_outputs) |
720 | 754 |
|
721 | 755 |
|
722 | | -class TransformAnnotationPassPipeline(BasePipelineMaker, Generic[T]): |
| 756 | +class TransformAnnotationPassPipeline(TOSAPipelineMaker, Generic[T]): |
723 | 757 | """ |
724 | 758 | Runs transform_for_annotation_pipeline passes directly on an exported program and checks output. |
725 | 759 |
|
@@ -775,7 +809,7 @@ def __init__( |
775 | 809 | ) |
776 | 810 |
|
777 | 811 |
|
778 | | -class OpNotSupportedPipeline(BasePipelineMaker, Generic[T]): |
| 812 | +class OpNotSupportedPipeline(TOSAPipelineMaker, Generic[T]): |
779 | 813 | """ |
780 | 814 | Runs the partitioner on a module and checks that ops are not delegated to test |
781 | 815 | SupportedTOSAOperatorChecks. |
|
0 commit comments