Skip to content

Commit cf6f328

Browse files
committed
Arm backend: use tosa_ref_model only if installed
Not making these tests XFail because there is still value in running these tests w/o validating output since we do go through a lot of other checks in the AoT flow when generating a PTE. Tosa ref model being not installed is not the common case anyway. Added explicit warnings (which should show through pytest) as a reminder for the comparison is being skipped.
1 parent fafcf13 commit cf6f328

File tree

3 files changed

+79
-25
lines changed

3 files changed

+79
-25
lines changed

backends/arm/test/models/test_conformer.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from typing import Tuple
77

8+
import conftest
89
import pytest
910

1011
import torch
@@ -56,6 +57,7 @@ def test_conformer_tosa_FP():
5657
TestConformer.model_example_inputs,
5758
aten_op=TestConformer.aten_ops,
5859
exir_op=[],
60+
run_on_tosa_ref_model=conftest.is_option_enabled("tosa_ref_model"),
5961
use_to_edge_transform_and_lower=True,
6062
)
6163
pipeline.run()
@@ -67,17 +69,24 @@ def test_conformer_tosa_INT():
6769
TestConformer.model_example_inputs,
6870
aten_op=TestConformer.aten_ops,
6971
exir_op=[],
72+
run_on_tosa_ref_model=conftest.is_option_enabled("tosa_ref_model"),
7073
use_to_edge_transform_and_lower=True,
7174
)
7275
pipeline.pop_stage("check_count.exir")
73-
pipeline.change_args(
74-
"run_method_and_compare_outputs",
75-
get_test_inputs(
76-
TestConformer.dim, TestConformer.lengths, TestConformer.num_examples
77-
),
78-
rtol=1.0,
79-
atol=3.0,
80-
)
76+
77+
try:
78+
if pipeline.find_pos("run_method_and_compare_outputs") >=0:
79+
pipeline.change_args(
80+
"run_method_and_compare_outputs",
81+
get_test_inputs(
82+
TestConformer.dim, TestConformer.lengths, TestConformer.num_examples
83+
),
84+
rtol=1.0,
85+
atol=3.0,
86+
)
87+
except Exception as e:
88+
# tosa_ref_model must not be available
89+
assert not conftest.is_option_enabled("tosa_ref_model"), f"TOSA reference model should be disabled, but error occurred: {e}"
8190
pipeline.run()
8291

8392

backends/arm/test/models/test_lstm_arm.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from executorch.backends.arm.test.tester.test_pipeline import (
1212
EthosU55PipelineINT,
1313
EthosU85PipelineINT,
14+
is_tosa_ref_model_available,
1415
TosaPipelineFP,
1516
TosaPipelineINT,
1617
VgfPipeline,
@@ -51,7 +52,16 @@ def test_lstm_tosa_FP():
5152
exir_op=[],
5253
use_to_edge_transform_and_lower=True,
5354
)
54-
pipeline.change_args("run_method_and_compare_outputs", get_test_inputs(), atol=3e-1)
55+
try:
56+
if pipeline.find_pos("run_method_and_compare_outputs") >= 0:
57+
pipeline.change_args(
58+
"run_method_and_compare_outputs", get_test_inputs(), atol=3e-1
59+
)
60+
except Exception as e:
61+
# tosa_ref_model must not be available
62+
assert (
63+
is_tosa_ref_model_available() == False
64+
), "Expected TOSA reference model to be disabled, but error occurred: {e}"
5565
pipeline.run()
5666

5767

@@ -63,9 +73,16 @@ def test_lstm_tosa_INT():
6373
exir_op=[],
6474
use_to_edge_transform_and_lower=True,
6575
)
66-
pipeline.change_args(
67-
"run_method_and_compare_outputs", get_test_inputs(), atol=3e-1, qtol=1.0
68-
)
76+
try:
77+
if pipeline.find_pos("run_method_and_compare_outputs") >= 0:
78+
pipeline.change_args(
79+
"run_method_and_compare_outputs", get_test_inputs(), atol=3e-1, qtol=1.0
80+
)
81+
except Exception as e:
82+
# tosa_ref_model must not be available
83+
assert (
84+
is_tosa_ref_model_available() == False
85+
), "Expected TOSA reference model to be disabled, but error occurred: {e}"
6986
pipeline.run()
7087

7188

backends/arm/test/tester/test_pipeline.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import logging
7+
import warnings as _warnings
78

89
from typing import (
910
Any,
@@ -41,6 +42,20 @@
4142
""" Generic type used for test data in the pipeline. Depends on which type the operator expects."""
4243

4344

45+
def is_tosa_ref_model_available():
46+
"""Checks if the TOSA reference model is available."""
47+
# Not all deployments of ET have the TOSA reference model available.
48+
# Make sure we don't try to use it if it's not available.
49+
try:
50+
import tosa_tools.tosa_ref_model as tosa_reference_model
51+
52+
if not dir(tosa_reference_model):
53+
return False
54+
except ImportError:
55+
return False
56+
return True
57+
58+
4459
class BasePipelineMaker(Generic[T]):
4560
"""
4661
The BasePiplineMaker defines a list of stages to be applied to a torch.nn.module for lowering it
@@ -283,8 +298,6 @@ class TosaPipelineINT(BasePipelineMaker, Generic[T]):
283298
exir_ops: Exir dialect ops expected to be found in the graph after to_edge.
284299
if not using use_edge_to_transform_and_lower.
285300
286-
run_on_tosa_ref_model: Set to true to test the tosa file on the TOSA reference model.
287-
288301
tosa_version: A string for identifying the TOSA version, see common.get_tosa_compile_spec for
289302
options.
290303
use_edge_to_transform_and_lower: Selects betweeen two possible ways of lowering the module.
@@ -297,7 +310,6 @@ def __init__(
297310
test_data: T,
298311
aten_op: str | List[str],
299312
exir_op: Optional[str | List[str]] = None,
300-
run_on_tosa_ref_model: bool = True,
301313
symmetric_io_quantization: bool = False,
302314
per_channel_quantization: bool = True,
303315
use_to_edge_transform_and_lower: bool = True,
@@ -360,14 +372,18 @@ def __init__(
360372
suffix="quant_nodes",
361373
)
362374

363-
if run_on_tosa_ref_model:
375+
if is_tosa_ref_model_available():
364376
self.add_stage(
365377
self.tester.run_method_and_compare_outputs,
366378
atol=atol,
367379
rtol=rtol,
368380
qtol=qtol,
369381
inputs=self.test_data,
370382
)
383+
else:
384+
_warnings.warn(
385+
"Warning: Skipping run_method_and_compare_outputs stage. tosa reference model is not available."
386+
)
371387

372388

373389
class TosaPipelineFP(BasePipelineMaker, Generic[T]):
@@ -382,8 +398,6 @@ class TosaPipelineFP(BasePipelineMaker, Generic[T]):
382398
exir_ops: Exir dialect ops expected to be found in the graph after to_edge.
383399
if not using use_edge_to_transform_and_lower.
384400
385-
run_on_tosa_ref_model: Set to true to test the tosa file on the TOSA reference model.
386-
387401
tosa_version: A string for identifying the TOSA version, see common.get_tosa_compile_spec for
388402
options.
389403
use_edge_to_transform_and_lower: Selects betweeen two possible ways of lowering the module.
@@ -435,14 +449,18 @@ def __init__(
435449
suffix="quant_nodes",
436450
)
437451

438-
if run_on_tosa_ref_model:
452+
if is_tosa_ref_model_available():
439453
self.add_stage(
440454
self.tester.run_method_and_compare_outputs,
441455
atol=atol,
442456
rtol=rtol,
443457
qtol=qtol,
444458
inputs=self.test_data,
445459
)
460+
else:
461+
_warnings.warn(
462+
"Warning: Skipping run_method_and_compare_outputs stage. tosa reference model is not available"
463+
)
446464

447465

448466
class EthosU55PipelineINT(BasePipelineMaker, Generic[T]):
@@ -701,7 +719,12 @@ def __init__(
701719
self.add_stage(self.tester.check_count, ops_after_pass, suffix="after")
702720
if ops_not_after_pass:
703721
self.add_stage(self.tester.check_not, ops_not_after_pass, suffix="after")
704-
self.add_stage(self.tester.run_method_and_compare_outputs)
722+
if is_tosa_ref_model_available():
723+
self.add_stage(self.tester.run_method_and_compare_outputs)
724+
else:
725+
_warnings.warn(
726+
"Warning: Skipping run_method_and_compare_outputs stage. Tosa reference model is not available."
727+
)
705728

706729

707730
class TransformAnnotationPassPipeline(BasePipelineMaker, Generic[T]):
@@ -748,11 +771,16 @@ def __init__(
748771
self.pop_stage("to_executorch")
749772
self.pop_stage("to_edge_transform_and_lower")
750773
self.pop_stage("check.aten")
751-
self.add_stage(
752-
self.tester.run_method_and_compare_outputs,
753-
inputs=test_data,
754-
run_eager_mode=True,
755-
)
774+
if is_tosa_ref_model_available():
775+
self.add_stage(
776+
self.tester.run_method_and_compare_outputs,
777+
inputs=test_data,
778+
run_eager_mode=True,
779+
)
780+
else:
781+
_warnings.warn(
782+
"Warning: Skipping run_method_and_compare_outputs stage. Tosa reference model is not available."
783+
)
756784

757785

758786
class OpNotSupportedPipeline(BasePipelineMaker, Generic[T]):

0 commit comments

Comments
 (0)