diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index e2335c07b87..0340710bee4 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -217,13 +217,6 @@ def is_vgf(compile_spec: List[CompileSpec]) -> bool: return False -def get_tosa_spec(compile_spec: List[CompileSpec]) -> TosaSpecification: - for spec in compile_spec: - if spec.key == "tosa_spec": - return TosaSpecification.create_from_string(spec.value.decode()) - raise ValueError("Could not find TOSA version in CompileSpec") - - def get_intermediate_path(compile_spec: List[CompileSpec]) -> Optional[str]: for spec in compile_spec: if spec.key == "debug_artifact_path": diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 28bb70be2b1..4518feeb403 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -20,12 +20,11 @@ from executorch.backends.arm._passes import ArmPassManager from executorch.backends.arm.quantizer import QuantizationConfig -from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.backends.arm.tosa_specification import get_tosa_spec, TosaSpecification from .arm_quantizer_utils import is_annotated, mark_node_as_annotated from .quantization_annotator import annotate_graph from executorch.backends.arm.arm_backend import ( - get_tosa_spec, is_ethosu, is_vgf, ) # usort: skip diff --git a/backends/arm/test/misc/test_tosa_spec.py b/backends/arm/test/misc/test_tosa_spec.py index 66f7dcf0745..a2f5f7d85ee 100644 --- a/backends/arm/test/misc/test_tosa_spec.py +++ b/backends/arm/test/misc/test_tosa_spec.py @@ -5,9 +5,11 @@ import unittest -from executorch.backends.arm.arm_backend import get_tosa_spec - -from executorch.backends.arm.tosa_specification import Tosa_1_00, TosaSpecification +from executorch.backends.arm.tosa_specification import ( + get_tosa_spec, + Tosa_1_00, + TosaSpecification, +) from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized # type: ignore[import-untyped] diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 421ec0adc61..c56ce3542b6 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -8,7 +8,6 @@ from typing import Tuple import torch -from executorch.backends.arm.arm_backend import get_tosa_spec from executorch.backends.arm.quantizer import arm_quantizer from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.test_pipeline import ( @@ -18,7 +17,7 @@ TosaPipelineINT, VgfPipeline, ) -from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.backends.arm.tosa_specification import get_tosa_spec, TosaSpecification from executorch.backends.xnnpack.test.tester import Quantize from torchao.quantization.pt2e import HistogramObserver from torchao.quantization.pt2e.quantizer import QuantizationSpec diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index bd06e817d8f..e3336f1a684 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -18,10 +18,13 @@ import numpy as np import torch -from executorch.backends.arm.arm_backend import get_tosa_spec, is_tosa +from executorch.backends.arm.arm_backend import is_tosa from executorch.backends.arm.test.conftest import is_option_enabled -from executorch.backends.arm.tosa_specification import Tosa_1_00, TosaSpecification - +from executorch.backends.arm.tosa_specification import ( + get_tosa_spec, + Tosa_1_00, + TosaSpecification, +) from executorch.exir import ExecutorchProgramManager, ExportedProgram from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.lowered_backend_module import LoweredBackendModule diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index b848af2d25c..f71a99a0398 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -34,7 +34,6 @@ from executorch.backends.arm.arm_backend import ( get_intermediate_path, - get_tosa_spec, is_ethosu, is_tosa, is_vgf, @@ -62,7 +61,7 @@ ) from executorch.backends.arm.tosa_mapping import extract_tensor_meta from executorch.backends.arm.tosa_partitioner import TOSAPartitioner -from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.backends.arm.tosa_specification import get_tosa_spec, TosaSpecification from executorch.backends.arm.vgf_partitioner import VgfPartitioner diff --git a/backends/arm/tosa_backend.py b/backends/arm/tosa_backend.py index d2d80cd885d..7062d68b944 100644 --- a/backends/arm/tosa_backend.py +++ b/backends/arm/tosa_backend.py @@ -14,8 +14,8 @@ from typing import cast, final, List import serializer.tosa_serializer as ts # type: ignore -from executorch.backends.arm.arm_backend import get_tosa_spec from executorch.backends.arm.operators.node_visitor import get_node_visitors +from executorch.backends.arm.tosa_specification import get_tosa_spec from executorch.backends.arm._passes import ( ArmPassManager, ) # usort: skip diff --git a/backends/arm/tosa_partitioner.py b/backends/arm/tosa_partitioner.py index 8c923568265..ad960036fcf 100644 --- a/backends/arm/tosa_partitioner.py +++ b/backends/arm/tosa_partitioner.py @@ -11,7 +11,6 @@ import torch from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.backends.arm.arm_backend import ( - get_tosa_spec, is_tosa, ) # usort: skip from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor @@ -19,6 +18,7 @@ tosa_support_factory, ) from executorch.backends.arm.tosa_backend import TOSABackend +from executorch.backends.arm.tosa_specification import get_tosa_spec from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import ( DelegationSpec, diff --git a/backends/arm/tosa_specification.py b/backends/arm/tosa_specification.py index 5f16605aa56..6bb22da7e79 100644 --- a/backends/arm/tosa_specification.py +++ b/backends/arm/tosa_specification.py @@ -15,6 +15,10 @@ import re from typing import List +from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found] + CompileSpec, +) + from packaging.version import Version @@ -188,3 +192,10 @@ def get_context_spec() -> TosaSpecification: return TosaLoweringContext.tosa_spec_var.get() except LookupError: raise RuntimeError("Function must be executed within a TosaLoweringContext") + + +def get_tosa_spec(compile_spec: List[CompileSpec]) -> TosaSpecification: + for spec in compile_spec: + if spec.key == "tosa_spec": + return TosaSpecification.create_from_string(spec.value.decode()) + raise ValueError("Could not find TOSA version in CompileSpec") diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index d6a1eab3205..daa35d3c6f9 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -19,7 +19,6 @@ from examples.devtools.scripts.export_bundled_program import save_bundled_program from executorch.backends.arm.arm_backend import ( ArmCompileSpecBuilder, - get_tosa_spec, is_ethosu, is_tosa, is_vgf, @@ -32,7 +31,7 @@ VgfQuantizer, ) from executorch.backends.arm.tosa_partitioner import TOSAPartitioner -from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.backends.arm.tosa_specification import get_tosa_spec, TosaSpecification from executorch.backends.arm.util.arm_model_evaluator import ( GenericModelEvaluator,