From 861e98aad0563f03ab917c229fd3ca685f230a22 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Thu, 11 Sep 2025 15:32:09 -0700 Subject: [PATCH 1/4] Arm backend: Refactor compile spec handling (Try 2) (#14191) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/14191 Redoing https://github.com/pytorch/executorch/pull/14111 with additional fixes Reviewed By: digantdesai Differential Revision: D82171193 --- backends/arm/TARGETS | 60 +++-- backends/arm/arm_backend.py | 245 ------------------ backends/arm/common/arm_compile_spec.py | 195 ++++++++++++++ backends/arm/debug/schema.py | 6 +- backends/arm/ethosu/__init__.py | 6 +- backends/arm/ethosu/compile_spec.py | 101 ++++++++ backends/arm/ethosu/partitioner.py | 18 +- backends/arm/operators/node_visitor.py | 4 +- backends/arm/quantizer/arm_quantizer.py | 52 ++-- backends/arm/runtime/VelaBinStream.cpp | 2 +- backends/arm/runtime/VelaBinStream.h | 4 +- .../arm/scripts/TOSA_minimal_example.ipynb | 25 +- backends/arm/test/TARGETS | 10 +- backends/arm/test/common.py | 193 ++++---------- backends/arm/test/misc/test_compile_spec.py | 50 ++++ backends/arm/test/misc/test_debug_feats.py | 6 +- backends/arm/test/misc/test_debug_hook.py | 6 +- .../test/misc/test_extract_io_params_tosa.py | 22 +- backends/arm/test/misc/test_outputs_order.py | 11 +- backends/arm/test/ops/test_add.py | 10 +- backends/arm/test/runner_utils.py | 37 +-- backends/arm/test/targets.bzl | 15 ++ .../arm/test/tester/analyze_output_utils.py | 3 +- backends/arm/test/tester/arm_tester.py | 70 +++-- backends/arm/test/tester/test_pipeline.py | 15 +- backends/arm/tosa/TARGETS | 16 +- backends/arm/tosa/backend.py | 6 +- backends/arm/tosa/compile_spec.py | 25 ++ backends/arm/tosa/partitioner.py | 20 +- backends/arm/vgf/__init__.py | 6 +- backends/arm/vgf/compile_spec.py | 66 +++++ backends/arm/vgf/partitioner.py | 18 +- examples/arm/aot_arm_compiler.py | 49 ++-- examples/arm/ethos_u_minimal_example.ipynb | 11 +- examples/arm/vgf_minimal_example.ipynb | 19 +- 35 files changed, 735 insertions(+), 667 deletions(-) delete mode 100644 backends/arm/arm_backend.py create mode 100644 backends/arm/common/arm_compile_spec.py create mode 100644 backends/arm/ethosu/compile_spec.py create mode 100644 backends/arm/test/misc/test_compile_spec.py create mode 100644 backends/arm/tosa/compile_spec.py create mode 100644 backends/arm/vgf/compile_spec.py diff --git a/backends/arm/TARGETS b/backends/arm/TARGETS index 0d947d63903..b00e8057df6 100644 --- a/backends/arm/TARGETS +++ b/backends/arm/TARGETS @@ -6,29 +6,6 @@ # @noautodeps load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") -runtime.python_library( - name = "ethosu_partitioner", - srcs = [ - "ethosu/__init__.py", - "ethosu/backend.py", - "ethosu/partitioner.py" - ], - deps = [ - ":arm_vela", - "//executorch/backends/arm/tosa:arm_partitioner", - ] -) -runtime.python_library( - name = "vgf_partitioner", - srcs = [ - "vgf/__init__.py", - "vgf/backend.py", - "vgf/partitioner.py" - ], - deps = [ - "//executorch/backends/arm/tosa:arm_partitioner", - ] -) runtime.python_library( name = "constants", srcs = [ @@ -51,10 +28,11 @@ runtime.python_library( "//executorch/exir:lib", ], ) + runtime.python_library( - name = "arm_backend", + name = "arm_compile_spec", srcs = [ - "arm_backend.py", + "common/arm_compile_spec.py", ], deps = [ "fbsource//third-party/pypi/flatbuffers:flatbuffers", @@ -64,11 +42,43 @@ runtime.python_library( "fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/tosa:tosa", 
"fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/tosa:tosa", ":process_node", + "//executorch/exir/backend:compile_spec_schema", "//executorch/backends/arm/operators:lib", "//executorch/backends/arm/operators:node_visitor", "//executorch/backends/arm/_passes:passes", ], ) +runtime.python_library( + name = "ethosu", + srcs = [ + "ethosu/__init__.py", + "ethosu/backend.py", + "ethosu/compile_spec.py", + "ethosu/partitioner.py", + ], + deps = [ + ":arm_compile_spec", + ":arm_vela", + "//executorch/backends/arm/tosa:specification", + "//executorch/backends/arm/tosa:partitioner", + ], +) + +runtime.python_library( + name = "vgf", + srcs = [ + "vgf/__init__.py", + "vgf/backend.py", + "vgf/compile_spec.py", + "vgf/partitioner.py", + ], + deps = [ + ":arm_compile_spec", + "//executorch/backends/arm/tosa:specification", + "//executorch/backends/arm/tosa:partitioner", + ], +) + runtime.python_library( name = "process_node", srcs = ["process_node.py"], diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py deleted file mode 100644 index 2e71f91dbb6..00000000000 --- a/backends/arm/arm_backend.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -# -# Main implementation of AoT flow to partition and preprocess for Arm target -# backends. Converts via TOSA as an intermediate form supported by AoT and -# JIT compiler flows. -# -from enum import Enum -from typing import List, Optional - -from executorch.backends.arm.tosa import TosaSpecification - -from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found] - CompileSpec, -) - - -class ArmCompileSpecBuilder: - class DebugMode(Enum): - JSON = 1 - TOSA = 2 - - def __init__(self): - self.compile_spec: List[CompileSpec] = [] - self.compiler_flags = [] - self.output_format = None - self.path_for_intermediates = None - self.tosa_spec = None - self.tosa_debug_mode = None - - def vgf_compile_spec( - self, - tosa_spec: TosaSpecification = None, # type: ignore[assignment] - compiler_flags: Optional[str] = "", - ) -> "ArmCompileSpecBuilder": - """ - Generate compile spec for VGF compatible targets - - Args: - compiler_flags: Extra compiler flags for converter_backend - """ - self.output_format = "vgf" - self.compiler_flags = [ - compiler_flags, - ] - - if tosa_spec is None: - tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP") - - tosa_version = tosa_spec.version # type: ignore[attr-defined] - tosa_profiles = tosa_spec.profiles # type: ignore[attr-defined] - - if tosa_version.major != 1: - raise ValueError( - "Arm backend only supports converter-backend for TOSA version 1. " - f"Invalid TOSA version: {tosa_version}" - ) - - if "FP" not in tosa_profiles and "INT" not in tosa_profiles: - raise ValueError( - "Arm backend only supports converter-backend for FP or INT. " - f"Invalid TOSA profile: {tosa_profiles}" - ) - - if len(tosa_profiles) != 1: - raise ValueError( - "For now Arm backend only supports converter-backend for either FP or INT. 
" - f"Invalid TOSA profile: {tosa_profiles}" - ) - - self.tosa_spec = tosa_spec - - return self - - def ethosu_compile_spec( - self, - target: str, - system_config: Optional[str] = None, - memory_mode: Optional[str] = None, - extra_flags: Optional[str] = None, - config_ini: Optional[str] = "Arm/vela.ini", - ) -> "ArmCompileSpecBuilder": - """ - Generate compile spec for Ethos-U NPU - - Args: - target: Ethos-U accelerator configuration, e.g. ethos-u55-128 - system_config: System configuration to select from the Vel - configuration file - memory_mode: Memory mode to select from the Vela configuration file - extra_flags: Extra flags for the Vela compiler - config_ini: Vela configuration file(s) in Python ConfigParser .ini - file format - """ - assert ( - self.output_format is None - ), f"Output format already set to f{self.output_format}" - self.output_format = "vela" - self.compiler_flags = [ - f"--accelerator-config={target}", - f"--config={config_ini}", - ] - - # default system config and memory mode - if "ethos-u55" in target: - if system_config is None: - system_config = "Ethos_U55_High_End_Embedded" - if memory_mode is None: - memory_mode = "Shared_Sram" - elif "ethos-u85" in target: - if system_config is None: - system_config = "Ethos_U85_SYS_DRAM_Mid" - if memory_mode is None: - memory_mode = "Sram_Only" - else: - raise RuntimeError(f"Unknown ethos target: {target}") - - if system_config is not None: - self.compiler_flags.append(f"--system-config={system_config}") - if memory_mode is not None: - self.compiler_flags.append(f"--memory-mode={memory_mode}") - if extra_flags is not None: - self.compiler_flags.append(extra_flags) - - # We require raw output and regor, so add these flags if absent. This - # overrides any other output setting. - self.compiler_flags.append("--output-format=raw") - self.compiler_flags.append("--debug-force-regor") - - base_tosa_version = "TOSA-1.0+INT+int16" - if "u55" in target: - # Add the Ethos-U55 extension marker - base_tosa_version += "+u55" - self.tosa_spec = TosaSpecification.create_from_string(base_tosa_version) - - return self - - def tosa_compile_spec( - self, tosa_spec: str | TosaSpecification - ) -> "ArmCompileSpecBuilder": - """ - Generate compile spec for TOSA flatbuffer output - """ - assert ( - self.output_format is None - ), f"Output format already set: {self.output_format}" - self.output_format = "tosa" - if isinstance(tosa_spec, TosaSpecification): - self.tosa_spec = tosa_spec - elif isinstance(tosa_spec, str): - self.tosa_spec = TosaSpecification.create_from_string(tosa_spec) - else: - raise RuntimeError(f"Invalid type for {tosa_spec}!") - return self - - def dump_intermediate_artifacts_to( - self, output_path: str - ) -> "ArmCompileSpecBuilder": - """ - Sets a path for dumping intermediate results during such as tosa and pte. - """ - self.path_for_intermediates = output_path - return self - - def dump_debug_info(self, debug_mode: DebugMode) -> "ArmCompileSpecBuilder": - """ - Dump debugging information into the intermediates path - """ - self.tosa_debug_mode = debug_mode.name - return self - - def build(self) -> List[CompileSpec]: - """ - Generate a list of compile spec objects from the builder - """ - assert self.tosa_spec - - # Always supply a TOSA version - self.compile_spec = [CompileSpec("tosa_spec", str(self.tosa_spec).encode())] - - # Add compile flags, these are backend specific, refer to the backend - # documentation. 
- self.compile_spec += [ - CompileSpec("compile_flags", " ".join(self.compiler_flags).encode()), - ] - - # encode output format - self.compile_spec.append( - CompileSpec("output_format", self.output_format.encode()) - ) - - if self.path_for_intermediates is not None: - self.compile_spec.append( - CompileSpec("debug_artifact_path", self.path_for_intermediates.encode()) - ) - - if self.tosa_debug_mode is not None: - if not self.path_for_intermediates: - raise ValueError( - "dump_debug_info() must be used in conjunction with dump_intermediate_artifacts_to()" - ) - - self.compile_spec.append( - CompileSpec("dump_debug_info", self.tosa_debug_mode.encode()) - ) - - return self.compile_spec - - -def is_tosa(compile_spec: List[CompileSpec]) -> bool: - has_tosa_output = False - has_tosa_spec = False - for spec in compile_spec: - if spec.key == "output_format": - has_tosa_output = spec.value.decode() == "tosa" - if spec.key == "tosa_spec": - has_tosa_spec = True - - return has_tosa_output and has_tosa_spec - - -def is_ethosu(compile_spec: List[CompileSpec]) -> bool: - for spec in compile_spec: - if spec.key == "output_format": - return spec.value.decode() == "vela" - return False - - -def is_vgf(compile_spec: List[CompileSpec]) -> bool: - for spec in compile_spec: - if spec.key == "output_format": - return spec.value.decode() == "vgf" - return False - - -def get_intermediate_path(compile_spec: List[CompileSpec]) -> Optional[str]: - for spec in compile_spec: - if spec.key == "debug_artifact_path": - return spec.value.decode() - return None diff --git a/backends/arm/common/arm_compile_spec.py b/backends/arm/common/arm_compile_spec.py new file mode 100644 index 00000000000..c6818e2716a --- /dev/null +++ b/backends/arm/common/arm_compile_spec.py @@ -0,0 +1,195 @@ +# Copyright 2023-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +# +# Main implementation of AoT flow to partition and preprocess for Arm target +# backends. Converts via TOSA as an intermediate form supported by AoT and +# JIT compiler flows. 
+# + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum + +from executorch.backends.arm.tosa import TosaSpecification + +from executorch.exir.backend.compile_spec_schema import CompileSpec + + +@dataclass(init=False) +class ArmCompileSpec(ABC): + class DebugMode(Enum): + JSON = 1 + TOSA = 2 + + tosa_spec: TosaSpecification + compiler_flags: list[str] = field(default_factory=list) + path_for_intermediates: str | None = None + tosa_debug_mode: DebugMode | None = None + + _TOSA_SPEC_KEY = "tosa_spec" + _COMPILE_FLAGS_KEY = "compile_flags" + _OUTPUT_FORMAT_KEY = "output_format" + _DEBUG_ARTIFACT_KEY = "debug_artifact_path" + _DEBUG_MODE_KEY = "dump_debug_info" + + def _set_compile_specs( + self, + tosa_spec: TosaSpecification, + compiler_flags: list[str], + path_for_intermediates: str | None = None, + tosa_debug_mode: DebugMode | None = None, + ): + """Set all values of dataclass directly.""" + self.tosa_spec = tosa_spec + self.compiler_flags = compiler_flags + self.path_for_intermediates = path_for_intermediates + self.tosa_debug_mode = tosa_debug_mode + + @classmethod + def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 + tosa_spec: TosaSpecification | None = None + output_format: str | None = None + compiler_flags: list[str] | None = None + path_for_intermediates: str | None = None + tosa_debug_mode: ArmCompileSpec.DebugMode | None = None + unknown_specs: dict[str, str] = {} + for spec in compile_specs: + key = spec.key + val = spec.value.decode() + if key == ArmCompileSpec._TOSA_SPEC_KEY: + if tosa_spec is not None: + raise ValueError("More than one tosa_spec entry in compile spec.") + tosa_spec = TosaSpecification.create_from_string(val) + elif key == ArmCompileSpec._COMPILE_FLAGS_KEY: + if compiler_flags is not None: + raise ValueError( + "More than one compiler flags entry in compile spec." + ) + compiler_flags = val.split(" ") + elif key == ArmCompileSpec._OUTPUT_FORMAT_KEY: + if output_format is not None: + raise ValueError( + "More than one output format entry in compile spec." + ) + output_format = val + elif key == ArmCompileSpec._DEBUG_ARTIFACT_KEY: + if path_for_intermediates is not None: + raise ValueError( + "More than one debug artifact path entry in compile spec." + ) + path_for_intermediates = val + elif key == ArmCompileSpec._DEBUG_MODE_KEY: + if tosa_debug_mode is not None: + raise ValueError( + "More than one tosa_debug_mode entry in compile spec." + ) + tosa_debug_mode = ArmCompileSpec.DebugMode[val] + else: + unknown_specs[key] = val + + if tosa_spec is None: + raise ValueError("No tosa_spec in compile spec.") + if output_format is None: + raise ValueError("No output_format in compile spec.") + if output_format != cls.get_output_format(): + raise ValueError( + f"Incorrect output format '{output_format}' for {cls.__name__}, expected '{cls.get_output_format()}'" + ) + if compiler_flags is None: + compiler_flags = [] + + # Create new object from class, but bypass __init__ and use _set_compile_specs instead. 
+        compile_spec = cls.__new__(cls)
+        compile_spec._set_compile_specs(
+            tosa_spec=tosa_spec,
+            compiler_flags=compiler_flags,
+            path_for_intermediates=path_for_intermediates,
+            tosa_debug_mode=tosa_debug_mode,
+        )
+        cls.from_list_hook(compile_spec, unknown_specs)
+        compile_spec.validate()
+        return compile_spec
+
+    @classmethod
+    def from_list_hook(cls, compile_spec, specs: dict[str, str]):  # noqa: B027
+        """Allows subclasses to hook into parsing compile spec lists."""
+        pass
+
+    @abstractmethod
+    def validate(self):
+        """Throws an error if the compile spec is not valid."""
+
+    def to_list(self):
+        """Get the ArmCompileSpec in list form."""
+        assert self.tosa_spec
+
+        # Always supply a TOSA version
+        compile_spec = [
+            CompileSpec(ArmCompileSpec._TOSA_SPEC_KEY, str(self.tosa_spec).encode())
+        ]
+
+        # Add compile flags, these are backend specific, refer to the backend
+        # documentation.
+        if len(self.compiler_flags) > 0:
+            compile_spec += [
+                CompileSpec(
+                    ArmCompileSpec._COMPILE_FLAGS_KEY,
+                    " ".join(self.compiler_flags).encode(),
+                ),
+            ]
+
+        # Add output format to identify kind of compile spec.
+        compile_spec.append(
+            CompileSpec(
+                ArmCompileSpec._OUTPUT_FORMAT_KEY, self.get_output_format().encode()
+            )
+        )
+
+        if self.path_for_intermediates is not None:
+            compile_spec.append(
+                CompileSpec(
+                    ArmCompileSpec._DEBUG_ARTIFACT_KEY,
+                    self.path_for_intermediates.encode(),
+                )
+            )
+
+        if self.tosa_debug_mode is not None:
+            if not self.path_for_intermediates:
+                raise ValueError(
+                    "dump_debug_info() must be used in conjunction with dump_intermediate_artifacts_to()"
+                )
+
+            compile_spec.append(
+                CompileSpec(
+                    ArmCompileSpec._DEBUG_MODE_KEY, self.tosa_debug_mode.name.encode()
+                )
+            )
+
+        return compile_spec
+
+    def get_intermediate_path(self) -> str | None:
+        return self.path_for_intermediates
+
+    def dump_intermediate_artifacts_to(self, output_path: str | None):
+        """
+        Sets a path for dumping intermediate results during compilation, such as tosa and pte files.
+ """ + self.path_for_intermediates = output_path + return self + + def dump_debug_info(self, debug_mode: DebugMode | None): + """ + Dump debugging information into the intermediates path + """ + self.tosa_debug_mode = debug_mode + return self + + @classmethod + @abstractmethod + def get_output_format(cls) -> str: + """Returns a constant string that is the output format of the class.""" diff --git a/backends/arm/debug/schema.py b/backends/arm/debug/schema.py index 82f0fd6bf7e..46742a8ce61 100644 --- a/backends/arm/debug/schema.py +++ b/backends/arm/debug/schema.py @@ -13,7 +13,7 @@ import serializer.tosa_serializer as ts # type: ignore import torch -from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from torch.fx.traceback import NodeSource @@ -112,7 +112,7 @@ def to_dict(self) -> dict[str, Any]: class DebugHook: - def __init__(self, debug_mode: ArmCompileSpecBuilder.DebugMode) -> None: + def __init__(self, debug_mode: ArmCompileSpec.DebugMode) -> None: self._debug_events: list[DebugSchema] = [] self.__op_id_to_name = {} self.mode = debug_mode @@ -126,7 +126,7 @@ def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> DebugSchema # If the debug data is being embedded into the TOSA flatbuffer # do not collect TOSADebugSchema data, it's redundent - if self.mode != ArmCompileSpecBuilder.DebugMode.TOSA: + if self.mode != ArmCompileSpec.DebugMode.TOSA: tosa_debug_info = TosaDebugSchema( node_name=str(tosa_op), operator_name=self.__op_id_to_name[tosa_op_id], diff --git a/backends/arm/ethosu/__init__.py b/backends/arm/ethosu/__init__.py index f6cc1329dfe..25a91dc5929 100644 --- a/backends/arm/ethosu/__init__.py +++ b/backends/arm/ethosu/__init__.py @@ -6,9 +6,7 @@ # pyre-unsafe from .backend import EthosUBackend # noqa: F401 +from .compile_spec import EthosUCompileSpec # noqa: F401 from .partitioner import EthosUPartitioner # noqa: F401 -__all__ = [ - "EthosUBackend", - "EthosUPartitioner", -] +__all__ = ["EthosUBackend", "EthosUPartitioner", "EthosUCompileSpec"] diff --git a/backends/arm/ethosu/compile_spec.py b/backends/arm/ethosu/compile_spec.py new file mode 100644 index 00000000000..5f3f92fdd0e --- /dev/null +++ b/backends/arm/ethosu/compile_spec.py @@ -0,0 +1,101 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec + +from executorch.backends.arm.tosa import ( # type: ignore[import-not-found] + TosaSpecification, +) + +from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found] + CompileSpec, +) + + +class EthosUCompileSpec(ArmCompileSpec): + + _TARGET_KEY = "target" + + def __init__( + self, + target: str, + system_config: str | None = None, + memory_mode: str | None = None, + extra_flags: list[str] | None = None, + config_ini: str | None = "Arm/vela.ini", + ): + """Generate compile spec for Ethos-U NPU + Args: + target: Ethos-U accelerator configuration, e.g. 
ethos-u55-128 + system_config: System configuration to select from the Vela + configuration file + memory_mode: Memory mode to select from the Vela configuration file + extra_flags: Extra flags for the Vela compiler + config_ini: Vela configuration file(s) in Python ConfigParser .ini + file format + """ + self.target = target + + # Set vela compiler flags + if config_ini is None: + config_ini = "Arm/vela.ini" + compiler_flags = [] if extra_flags is None else extra_flags + compiler_flags.extend( + [ + f"--accelerator-config={target}", + f"--config={config_ini}", + "--output-format=raw", + "--debug-force-regor", + ] + ) + # default system config and memory mode + if "ethos-u55" in self.target: + if system_config is None: + system_config = "Ethos_U55_High_End_Embedded" + if memory_mode is None: + memory_mode = "Shared_Sram" + elif "ethos-u85" in self.target: + if system_config is None: + system_config = "Ethos_U85_SYS_DRAM_Mid" + if memory_mode is None: + memory_mode = "Sram_Only" + else: + raise RuntimeError(f"Unknown ethos target: {self.target}") + + compiler_flags.append(f"--system-config={system_config}") + compiler_flags.append(f"--memory-mode={memory_mode}") + + # Set TOSA version. + base_tosa_version = "TOSA-1.0+INT+int16" + if "u55" in self.target: + # Add the Ethos-U55 extension marker + base_tosa_version += "+u55" + tosa_spec = TosaSpecification.create_from_string(base_tosa_version) + + self._set_compile_specs(tosa_spec, compiler_flags) + self.validate() + + def to_list(self): + compile_specs = super().to_list() + compile_specs.append(CompileSpec(self._TARGET_KEY, self.target.encode())) + return compile_specs + + @classmethod + def from_list_hook(cls, compile_spec, specs: dict[str, str]): + compile_spec.target = specs.get(cls._TARGET_KEY, None) + + def validate(self): + if len(self.compiler_flags) == 0: + raise ValueError( + "compile_flags are required in the CompileSpec list for EthosUBackend" + ) + if "u55" in self.target and not self.tosa_spec.is_U55_subset: + raise ValueError( + f"Target was {self.target} but tosa spec was not u55 subset." 
+ ) + + @classmethod + def get_output_format(cls) -> str: + return "vela" diff --git a/backends/arm/ethosu/partitioner.py b/backends/arm/ethosu/partitioner.py index d76b29eb1d9..d2fad094c03 100644 --- a/backends/arm/ethosu/partitioner.py +++ b/backends/arm/ethosu/partitioner.py @@ -5,14 +5,10 @@ # pyre-unsafe -from typing import final, List, Optional, Sequence +from typing import final, Optional, Sequence -from executorch.backends.arm.arm_backend import ( - is_ethosu, -) # usort: skip -from executorch.backends.arm.ethosu import EthosUBackend +from executorch.backends.arm.ethosu import EthosUBackend, EthosUCompileSpec from executorch.backends.arm.tosa.partitioner import TOSAPartitioner -from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import DelegationSpec from torch.fx.passes.operator_support import OperatorSupportBase @@ -21,12 +17,12 @@ class EthosUPartitioner(TOSAPartitioner): def __init__( self, - compile_spec: List[CompileSpec], + compile_spec: EthosUCompileSpec, additional_checks: Optional[Sequence[OperatorSupportBase]] = None, ) -> None: - if not is_ethosu(compile_spec): - raise RuntimeError("compile spec is not targeting Ethos-U") - # Override the delegation spec for Ethos-U - self.delegation_spec = DelegationSpec(EthosUBackend.__name__, compile_spec) + self.delegation_spec = DelegationSpec( + EthosUBackend.__name__, compile_spec.to_list() + ) self.additional_checks = additional_checks + self.tosa_spec = compile_spec.tosa_spec diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index 54a81bdaaff..172adbc7c78 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -10,7 +10,7 @@ import torch -from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.debug.schema import DebugHook from executorch.backends.arm.tosa.mapping import TosaArg from executorch.backends.arm.tosa.specification import TosaSpecification @@ -59,7 +59,7 @@ def _serialize_operator( tosa_op_id=tosa_op, ) - if self.debug_hook.mode == ArmCompileSpecBuilder.DebugMode.TOSA: + if self.debug_hook.mode == ArmCompileSpec.DebugMode.TOSA: op_location = json.dumps(debug_info.to_dict()) tosa_graph.addOperator( diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index ae7c8255428..e6240a08c8e 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -14,21 +14,17 @@ from __future__ import annotations import functools -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional import torch +from executorch.backends.arm.ethosu import EthosUCompileSpec from executorch.backends.arm.quantizer import QuantizationConfig from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.arm.tosa.specification import get_tosa_spec - -from .arm_quantizer_utils import is_annotated, mark_node_as_annotated -from .quantization_annotator import annotate_graph -from executorch.backends.arm.arm_backend import ( - is_ethosu, - is_vgf, -) # usort: skip -from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.backends.arm.common.arm_compile_spec import ( + ArmCompileSpec, +) # isort: skip +from executorch.backends.arm.vgf import VgfCompileSpec from torch.fx import GraphModule, Node from 
torchao.quantization.pt2e import ( @@ -49,6 +45,9 @@ Quantizer, ) +from .arm_quantizer_utils import is_annotated, mark_node_as_annotated +from .quantization_annotator import annotate_graph + __all__ = [ "TOSAQuantizer", "EthosUQuantizer", @@ -300,27 +299,16 @@ def not_module_type_or_name_filter(n: Node) -> bool: class TOSAQuantizer(Quantizer): def __init__( - self, compile_spec_or_tosa_spec: Union[TosaSpecification, List[CompileSpec]] + self, compile_spec_or_tosa_spec: TosaSpecification | ArmCompileSpec ) -> None: super().__init__() if isinstance(compile_spec_or_tosa_spec, TosaSpecification): self.tosa_spec = compile_spec_or_tosa_spec self.compile_spec = None - elif isinstance(compile_spec_or_tosa_spec, list): + elif isinstance(compile_spec_or_tosa_spec, ArmCompileSpec): self.compile_spec = compile_spec_or_tosa_spec - # find entry that is 'tosa_spec' - for cs in compile_spec_or_tosa_spec: - if cs.key == "tosa_spec": - spec_val = ( - cs.value.decode() if isinstance(cs.value, bytes) else cs.value - ) - self.tosa_spec = TosaSpecification.create_from_string(spec_val) - break - else: - raise ValueError( - "compile_spec list did not contain a 'tosa_spec' entry" - ) + self.tosa_spec = self.compile_spec.tosa_spec else: raise TypeError( f"TOSAQuantizer constructor expects " @@ -466,18 +454,10 @@ def validate(self, model: GraphModule) -> None: class EthosUQuantizer(TOSAQuantizer): - def __init__(self, compile_spec: list[CompileSpec]) -> None: - if not is_ethosu(compile_spec): - raise RuntimeError("compile spec is not targeting Ethos-U") - - tosa_spec = get_tosa_spec(compile_spec) - super().__init__(tosa_spec) + def __init__(self, compile_spec: EthosUCompileSpec) -> None: + super().__init__(compile_spec) class VgfQuantizer(TOSAQuantizer): - def __init__(self, compile_spec: list[CompileSpec]) -> None: - if not is_vgf(compile_spec): - raise RuntimeError("compile spec is not targeting VGF") - - tosa_spec = get_tosa_spec(compile_spec) - super().__init__(tosa_spec) + def __init__(self, compile_spec: VgfCompileSpec) -> None: + super().__init__(compile_spec) diff --git a/backends/arm/runtime/VelaBinStream.cpp b/backends/arm/runtime/VelaBinStream.cpp index 180219c75b5..c8d568499c9 100644 --- a/backends/arm/runtime/VelaBinStream.cpp +++ b/backends/arm/runtime/VelaBinStream.cpp @@ -6,7 +6,7 @@ */ /* - * Warning: Do not change this without changing arm_backend.py::vela_compile + * Warning: Do not change this without changing arm_vela.py::vela_compile * as that function emits this format and the two need to align. */ diff --git a/backends/arm/runtime/VelaBinStream.h b/backends/arm/runtime/VelaBinStream.h index 04b8b2ada00..7a7ea9b6266 100644 --- a/backends/arm/runtime/VelaBinStream.h +++ b/backends/arm/runtime/VelaBinStream.h @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 Arm Limited and/or its affiliates. + * Copyright 2023-2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -7,7 +7,7 @@ /* * Minimal reading function for vela_bin_stream wire format. This is an - * implementation detail of the arm_backend AoT flow and ArmBackendEthosU + * implementation detail of the arm backend AoT flow and ArmBackendEthosU * and subject to change. * This format captures the command stream, I/O and memory layout data to * enable execution of the command stream on Ethos-U hardware. 
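Reviewer note: a minimal sketch of the typed, builder-free API this patch introduces, replacing the deleted `ArmCompileSpecBuilder` flow. The extra flag and output path are illustrative placeholders (the `--my-flag` value mirrors the new unit tests), not recommended settings:

```python
from executorch.backends.arm.ethosu import EthosUCompileSpec

# Construct a typed compile spec directly; system_config and memory_mode
# defaults are filled in per target, as in the old builder.
compile_spec = (
    EthosUCompileSpec("ethos-u55-128", extra_flags=["--my-flag"])
    .dump_intermediate_artifacts_to("out/u55_artifacts")  # illustrative path
    .dump_debug_info(EthosUCompileSpec.DebugMode.JSON)
)

# The legacy list[CompileSpec] form is still what gets attached to the
# DelegationSpec, and it round-trips back into the typed spec.
spec_list = compile_spec.to_list()
assert EthosUCompileSpec.from_list(spec_list) == compile_spec
```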
diff --git a/backends/arm/scripts/TOSA_minimal_example.ipynb b/backends/arm/scripts/TOSA_minimal_example.ipynb index 785affc657b..b79780c6a07 100644 --- a/backends/arm/scripts/TOSA_minimal_example.ipynb +++ b/backends/arm/scripts/TOSA_minimal_example.ipynb @@ -86,10 +86,7 @@ "metadata": {}, "outputs": [], "source": [ - "from executorch.backends.arm.arm_backend import (\n", - " ArmCompileSpecBuilder,\n", - ")\n", - "from executorch.backends.arm.tosa.specification import TosaSpecification\n", + "from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec\n", "from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n", "from pathlib import Path\n", "\n", @@ -99,11 +96,7 @@ "\n", "# Create a compilation spec describing the target for configuring the quantizer\n", "# Dump intermediate artifacts (in this case TOSA flat buffers) to specified location\n", - "tosa_spec = TosaSpecification.create_from_string(target)\n", - "spec_builder = (ArmCompileSpecBuilder()\n", - " .tosa_compile_spec(tosa_spec)\n", - " .dump_intermediate_artifacts_to(str(cwd_dir / base_name)))\n", - "compile_spec = spec_builder.build()\n", + "compile_spec = TosaCompileSpec(target).dump_intermediate_artifacts_to(str(cwd_dir / base_name))\n", "\n", "_ = graph_module.print_readable()\n", "\n", @@ -130,15 +123,11 @@ "metadata": {}, "outputs": [], "source": [ - "from executorch.backends.arm.arm_backend import (\n", - " ArmCompileSpecBuilder,\n", - " get_tosa_spec,\n", - ")\n", + "from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec\n", "from executorch.backends.arm.quantizer import (\n", " TOSAQuantizer,\n", " get_symmetric_quantization_config,\n", ")\n", - "from executorch.backends.arm.tosa.specification import TosaSpecification\n", "from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n", "from pathlib import Path\n", "\n", @@ -148,14 +137,10 @@ "\n", "# Create a compilation spec describing the target for configuring the quantizer\n", "# Dump intermediate artifacts (in this case TOSA flat buffers) to specified location\n", - "tosa_spec = TosaSpecification.create_from_string(target)\n", - "spec_builder = (ArmCompileSpecBuilder()\n", - " .tosa_compile_spec(tosa_spec)\n", - " .dump_intermediate_artifacts_to(str(cwd_dir / base_name)))\n", - "compile_spec = spec_builder.build()\n", + "compile_spec = TosaCompileSpec(target).dump_intermediate_artifacts_to(str(cwd_dir / base_name))\n", "\n", "# Create and configure quantizer to use a symmetric quantization config globally on all nodes\n", - "quantizer = TOSAQuantizer(get_tosa_spec(compile_spec))\n", + "quantizer = TOSAQuantizer(compile_spec)\n", "operator_config = get_symmetric_quantization_config()\n", "quantizer.set_global(operator_config)\n", "\n", diff --git a/backends/arm/test/TARGETS b/backends/arm/test/TARGETS index c27d00590f3..8ffad640d5a 100644 --- a/backends/arm/test/TARGETS +++ b/backends/arm/test/TARGETS @@ -19,7 +19,11 @@ runtime.python_library( srcs = ["runner_utils.py"], deps = [ ":conftest", - "//executorch/backends/arm:arm_backend", + "//executorch/backends/arm:arm_compile_spec", + "//executorch/backends/arm:ethosu", + "//executorch/backends/arm/tosa:compile_spec", + "//executorch/backends/arm:vgf", + "//executorch/backends/arm/tosa:specification", "//executorch/exir:lib", "//executorch/exir/backend:compile_spec_schema", ] @@ -41,10 +45,10 @@ runtime.python_library( deps = [ ":common", "//executorch/backends/xnnpack/test/tester:tester", - "//executorch/backends/arm:ethosu_partitioner", + 
"//executorch/backends/arm:ethosu", "//executorch/backends/arm/quantizer:lib", "//executorch/backends/arm/tosa:mapping", - "//executorch/backends/arm:vgf_partitioner", + "//executorch/backends/arm:vgf", "//executorch/devtools/backend_debug:delegation_info", "//executorch/exir/backend:operator_support", "fbsource//third-party/pypi/tabulate:tabulate", diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index 608c273b2ef..963084d6091 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -13,7 +13,7 @@ from typing import Any, Optional import pytest -from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder +from executorch.backends.arm.ethosu import EthosUCompileSpec from executorch.backends.arm.test.runner_utils import ( arm_executor_runner_exists, corstone300_installed, @@ -22,7 +22,8 @@ vkml_emulation_layer_installed, ) from executorch.backends.arm.tosa import TosaSpecification -from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from executorch.backends.arm.vgf import VgfCompileSpec def get_time_formatted_path(path: str, log_prefix: str) -> str: @@ -64,43 +65,21 @@ def maybe_get_tosa_collate_path() -> str | None: def get_tosa_compile_spec( tosa_spec: str | TosaSpecification, - custom_path: Optional[str] = None, - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, -) -> list[CompileSpec]: - """ - Default compile spec for TOSA tests. - """ - return get_tosa_compile_spec_unbuilt( - tosa_spec, - custom_path, - tosa_debug_mode, - ).build() - - -def get_tosa_compile_spec_unbuilt( - tosa_spec: str | TosaSpecification, - custom_path: Optional[str], - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode], -) -> ArmCompileSpecBuilder: - """Get the ArmCompileSpecBuilder for the default TOSA tests, to modify - the compile spec before calling .build() to finalize it. - """ + custom_path=None, + tosa_debug_mode: TosaCompileSpec.DebugMode | None = None, +) -> TosaCompileSpec: + """Get the compile spec for default TOSA tests.""" if not custom_path: custom_path = maybe_get_tosa_collate_path() - if custom_path is not None: os.makedirs(custom_path, exist_ok=True) - compile_spec_builder = ( - ArmCompileSpecBuilder() - .tosa_compile_spec(tosa_spec) + compile_spec = ( + TosaCompileSpec(tosa_spec) .dump_intermediate_artifacts_to(custom_path) + .dump_debug_info(tosa_debug_mode) ) - - if tosa_debug_mode is not None: - compile_spec_builder.dump_debug_info(tosa_debug_mode) - - return compile_spec_builder + return compile_spec def get_u55_compile_spec( @@ -109,72 +88,10 @@ def get_u55_compile_spec( memory_mode: str = "Shared_Sram", extra_flags: str = "--debug-force-regor --output-format=raw", custom_path: Optional[str] = None, - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, - config: Optional[str] = "Arm/vela.ini", -) -> list[CompileSpec]: - """ - Compile spec for Ethos-U55. 
- """ - return get_u55_compile_spec_unbuilt( - macs=macs, - system_config=system_config, - memory_mode=memory_mode, - extra_flags=extra_flags, - custom_path=custom_path, - tosa_debug_mode=tosa_debug_mode, - config=config, - ).build() - - -def get_u85_compile_spec( - macs: int = 128, - system_config: str = "Ethos_U85_SYS_DRAM_Mid", - memory_mode: str = "Shared_Sram", - extra_flags: str = "--output-format=raw", - custom_path: Optional[str] = None, - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, - config: Optional[str] = "Arm/vela.ini", -) -> list[CompileSpec]: - """ - Compile spec for Ethos-U85. - """ - return get_u85_compile_spec_unbuilt( # type: ignore[attr-defined] - macs=macs, - system_config=system_config, - memory_mode=memory_mode, - extra_flags=extra_flags, - custom_path=custom_path, - tosa_debug_mode=tosa_debug_mode, - config=config, - ).build() - - -def get_vgf_compile_spec( - tosa_spec: str | TosaSpecification, - compiler_flags: Optional[str] = "", - custom_path: Optional[str] = "", - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, -) -> list[CompileSpec]: - """ - Default compile spec for VGF tests. - """ - return get_vgf_compile_spec_unbuilt( - tosa_spec, compiler_flags, custom_path, tosa_debug_mode - ).build() - - -def get_u55_compile_spec_unbuilt( - macs: int, - system_config: str, - memory_mode: str, - extra_flags: str, - custom_path: Optional[str], - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode], - config: Optional[str], -) -> ArmCompileSpecBuilder: - """Get the ArmCompileSpecBuilder for the Ethos-U55 tests, to modify - the compile spec before calling .build() to finalize it. - """ + config: Optional[str] = None, + tosa_debug_mode: EthosUCompileSpec.DebugMode | None = None, +) -> EthosUCompileSpec: + """Default compile spec for Ethos-U55 tests.""" artifact_path = custom_path or tempfile.mkdtemp(prefix="arm_u55_") if not os.path.exists(artifact_path): os.makedirs(artifact_path, exist_ok=True) @@ -182,67 +99,67 @@ def get_u55_compile_spec_unbuilt( # https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/blob/main/OPTIONS.md assert macs in [32, 64, 128, 256], "Unsupported MACs value" + if extra_flags is not None: + extra_flags_list = extra_flags.split(" ") + else: + extra_flags_list = [] compile_spec = ( - ArmCompileSpecBuilder() - .ethosu_compile_spec( + EthosUCompileSpec( f"ethos-u55-{macs}", system_config=system_config, memory_mode=memory_mode, - extra_flags=extra_flags, + extra_flags=extra_flags_list, config_ini=config, ) .dump_intermediate_artifacts_to(artifact_path) + .dump_debug_info(tosa_debug_mode) ) - - if tosa_debug_mode is not None: - compile_spec.dump_debug_info(tosa_debug_mode) - return compile_spec -def get_u85_compile_spec_unbuilt( - macs: int, - system_config: str, - memory_mode: str, - extra_flags: str, - custom_path: Optional[str], - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode], - config: Optional[str], -) -> list[CompileSpec]: - """Get the ArmCompileSpecBuilder for the Ethos-U85 tests, to modify - the compile spec before calling .build() to finalize it. 
- """ +def get_u85_compile_spec( + macs: int = 128, + system_config="Ethos_U85_SYS_DRAM_Mid", + memory_mode="Shared_Sram", + extra_flags="--output-format=raw", + custom_path: Optional[str] = None, + config: Optional[str] = None, + tosa_debug_mode: EthosUCompileSpec.DebugMode | None = None, +) -> EthosUCompileSpec: + """Default compile spec for Ethos-U85 tests.""" + artifact_path = custom_path or tempfile.mkdtemp(prefix="arm_u85_") if not os.path.exists(artifact_path): os.makedirs(artifact_path, exist_ok=True) assert macs in [128, 256, 512, 1024, 2048], "Unsupported MACs value" + if extra_flags is not None: + extra_flags_list = extra_flags.split(" ") + else: + extra_flags_list = [] + compile_spec = ( - ArmCompileSpecBuilder() - .ethosu_compile_spec( + EthosUCompileSpec( f"ethos-u85-{macs}", system_config=system_config, memory_mode=memory_mode, - extra_flags=extra_flags, + extra_flags=extra_flags_list, config_ini=config, ) .dump_intermediate_artifacts_to(artifact_path) + .dump_debug_info(tosa_debug_mode) ) - - if tosa_debug_mode is not None: - compile_spec.dump_debug_info(tosa_debug_mode) - return compile_spec # type: ignore[return-value] -def get_vgf_compile_spec_unbuilt( +def get_vgf_compile_spec( tosa_spec: str | TosaSpecification, - compiler_flags: Optional[str], - custom_path: Optional[str], - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode], -) -> ArmCompileSpecBuilder: - """Get the ArmCompileSpecBuilder for the default VGF tests, to modify + compiler_flags: Optional[str] = "", + custom_path=None, + tosa_debug_mode: VgfCompileSpec.DebugMode | None = None, +) -> VgfCompileSpec: + """Get the ArmCompileSpec for the default VGF tests, to modify the compile spec before calling .build() to finalize it. """ if "FP" in repr(tosa_spec): @@ -255,16 +172,18 @@ def get_vgf_compile_spec_unbuilt( if not os.path.exists(artifact_path): os.makedirs(artifact_path, exist_ok=True) - compile_spec_builder = ( - ArmCompileSpecBuilder() - .vgf_compile_spec(tosa_spec, compiler_flags) + if compiler_flags is not None: + compiler_flags_list = compiler_flags.split(" ") + else: + compiler_flags_list = [] + + compile_spec = ( + VgfCompileSpec(tosa_spec, compiler_flags_list) .dump_intermediate_artifacts_to(artifact_path) + .dump_debug_info(tosa_debug_mode) ) - if tosa_debug_mode is not None: - compile_spec_builder.dump_debug_info(tosa_debug_mode) - - return compile_spec_builder + return compile_spec XfailIfNoCorstone300 = pytest.mark.xfail( diff --git a/backends/arm/test/misc/test_compile_spec.py b/backends/arm/test/misc/test_compile_spec.py new file mode 100644 index 00000000000..a1b42cd22b5 --- /dev/null +++ b/backends/arm/test/misc/test_compile_spec.py @@ -0,0 +1,50 @@ +from executorch.backends.arm.ethosu import EthosUCompileSpec +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from executorch.backends.arm.vgf import VgfCompileSpec +from pytest import raises + + +def test_ethos_u_compile_spec(): + compile_spec = ( + EthosUCompileSpec("ethos-u55", extra_flags=["--my-flag"]) + .dump_intermediate_artifacts_to("my_path") + .dump_debug_info(EthosUCompileSpec.DebugMode.TOSA) + ) + spec_list = compile_spec.to_list() + + assert EthosUCompileSpec.from_list(spec_list) == compile_spec + assert "--my-flag" in compile_spec.compiler_flags + assert "--output-format=raw" in compile_spec.compiler_flags + with raises(ValueError, match="Incorrect output format"): + VgfCompileSpec.from_list(spec_list) + + spec_list.pop(0) + with raises(ValueError, match="No tosa_spec in compile spec."): + 
EthosUCompileSpec.from_list(spec_list) + + +def test_vgf_compile_spec(): + compile_spec = ( + VgfCompileSpec(compiler_flags=["--my-flag"]) + .dump_intermediate_artifacts_to("my_path") + .dump_debug_info(None) + ) + compile_spec2 = VgfCompileSpec( + compiler_flags=["--my-flag2"] + ).dump_intermediate_artifacts_to("my_path") + + spec_list = compile_spec.to_list() + + assert VgfCompileSpec.from_list(spec_list) == compile_spec + assert VgfCompileSpec.from_list(spec_list) != compile_spec2 + with raises(ValueError, match="Incorrect output format"): + EthosUCompileSpec.from_list(spec_list) + + +def test_tosa_compile_spec(): + compile_spec = TosaCompileSpec("TOSA-1.0+INT") + spec_list = compile_spec.to_list() + + assert TosaCompileSpec.from_list(spec_list) == compile_spec + with raises(ValueError, match="Incorrect output format"): + VgfCompileSpec.from_list(spec_list) diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index 3e10a9336f9..3796d3dce4a 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -14,7 +14,7 @@ import pytest import torch -from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, @@ -200,7 +200,7 @@ def test_dump_tosa_debug_json(test_data: input_t1): aten_op=[], exir_op=[], custom_path=tmpdir, - tosa_debug_mode=ArmCompileSpecBuilder.DebugMode.JSON, + tosa_debug_mode=ArmCompileSpec.DebugMode.JSON, ) pipeline.pop_stage("run_method_and_compare_outputs") @@ -231,7 +231,7 @@ def test_dump_tosa_debug_tosa(test_data: input_t1): aten_op=[], exir_op=[], custom_path=tmpdir, - tosa_debug_mode=ArmCompileSpecBuilder.DebugMode.TOSA, + tosa_debug_mode=ArmCompileSpec.DebugMode.TOSA, ) pipeline.pop_stage("run_method_and_compare_outputs") diff --git a/backends/arm/test/misc/test_debug_hook.py b/backends/arm/test/misc/test_debug_hook.py index 935f3984403..376c65ff093 100644 --- a/backends/arm/test/misc/test_debug_hook.py +++ b/backends/arm/test/misc/test_debug_hook.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from types import SimpleNamespace -from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.debug.schema import DebugHook, DebugSchema from executorch.backends.arm.test import common @@ -158,7 +158,7 @@ def _compare_node_and_schema(debug_event: DebugSchema, mocked_node): @common.parametrize("test_data", TESTCASES) def test_debug_hook_add_json(test_data: DebugHookTestCase): - hook = DebugHook(ArmCompileSpecBuilder.DebugMode.JSON) + hook = DebugHook(ArmCompileSpec.DebugMode.JSON) hook.add(test_data.mock_node, test_data.tosa_op, test_data.op_id) debug_events = hook._debug_events @@ -171,7 +171,7 @@ def test_debug_hook_add_json(test_data: DebugHookTestCase): @common.parametrize("test_data", TESTCASES) def test_debug_hook_add_tosa(test_data: DebugHookTestCase): - hook = DebugHook(ArmCompileSpecBuilder.DebugMode.TOSA) + hook = DebugHook(ArmCompileSpec.DebugMode.TOSA) hook.add(test_data.mock_node, test_data.tosa_op, test_data.op_id) debug_events = hook._debug_events diff --git a/backends/arm/test/misc/test_extract_io_params_tosa.py b/backends/arm/test/misc/test_extract_io_params_tosa.py index da471b0bb74..90104c54899 100644 --- 
a/backends/arm/test/misc/test_extract_io_params_tosa.py +++ b/backends/arm/test/misc/test_extract_io_params_tosa.py @@ -7,7 +7,6 @@ import pytest import torch -from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.quantizer import VgfQuantizer from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_quantization_config, @@ -15,9 +14,9 @@ ) from executorch.backends.arm.test.common import SkipIfNoModelConverter -from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.backends.arm.tosa.partitioner import TOSAPartitioner -from executorch.backends.arm.vgf import VgfPartitioner +from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner from executorch.exir import to_edge_transform_and_lower from executorch.exir.passes.quantize_io_pass import extract_io_quant_params from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -29,11 +28,11 @@ def forward(self, x, y): @pytest.mark.parametrize( - "builder_method, quantizer_cls, partitioner_cls", + "compile_spec_cls, quantizer_cls, partitioner_cls", [ - ("tosa_compile_spec", TOSAQuantizer, TOSAPartitioner), + (TosaCompileSpec, TOSAQuantizer, TOSAPartitioner), pytest.param( - "vgf_compile_spec", + VgfCompileSpec, VgfQuantizer, VgfPartitioner, marks=SkipIfNoModelConverter, @@ -41,7 +40,11 @@ def forward(self, x, y): ), ], ) -def test_roundtrip_extracts_io_params(builder_method, quantizer_cls, partitioner_cls): +def test_roundtrip_extracts_io_params( + compile_spec_cls: type[TosaCompileSpec] | type[VgfCompileSpec], + quantizer_cls, + partitioner_cls, +): """ Validates that IO quantization parameters round-trip for both flows. """ @@ -51,10 +54,7 @@ def test_roundtrip_extracts_io_params(builder_method, quantizer_cls, partitioner ) mod = SimpleAdd().eval() - base_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") - compile_spec = getattr(ArmCompileSpecBuilder(), builder_method)( - tosa_spec=base_spec - ).build() + compile_spec = compile_spec_cls("TOSA-1.0+INT") quantizer = quantizer_cls(compile_spec) operator_config = get_symmetric_quantization_config(is_qat=True) diff --git a/backends/arm/test/misc/test_outputs_order.py b/backends/arm/test/misc/test_outputs_order.py index 43d35b6d13c..ff02ffc360a 100644 --- a/backends/arm/test/misc/test_outputs_order.py +++ b/backends/arm/test/misc/test_outputs_order.py @@ -9,11 +9,11 @@ import pytest import torch -from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_quantization_config, TOSAQuantizer, ) +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.backends.arm.tosa.partitioner import TOSAPartitioner from executorch.backends.arm.tosa.specification import TosaSpecification from executorch.exir import to_edge_transform_and_lower @@ -81,7 +81,7 @@ def test_network_output_order_and_restore(tmp_path, batch_size): model = Network(batch_norm=True).eval() # Prepare spec spec = TosaSpecification.create_from_string("TOSA-1.0+INT") - compile_spec = ArmCompileSpecBuilder().tosa_compile_spec(tosa_spec=spec).build() + compile_spec = TosaCompileSpec(tosa_spec=spec) # Setup quantizer quantizer = TOSAQuantizer(compile_spec) quantizer.set_global( @@ -89,7 +89,7 @@ def test_network_output_order_and_restore(tmp_path, batch_size): ) # Trace the model dummy = torch.randn(batch_size, 1, 28, 28) - fx_mod = 
torch.export.export_for_training(model, (dummy,)).module() + fx_mod = torch.export.export(model, (dummy,)).module() model = prepare_pt2e(fx_mod, quantizer) model(dummy) model = convert_pt2e(model) @@ -98,10 +98,7 @@ def test_network_output_order_and_restore(tmp_path, batch_size): with tempfile.TemporaryDirectory() as tmpdir: art_dir = Path(tmpdir) part = TOSAPartitioner( - ArmCompileSpecBuilder() - .tosa_compile_spec(spec) - .dump_intermediate_artifacts_to(str(art_dir)) - .build() + TosaCompileSpec(spec).dump_intermediate_artifacts_to(str(art_dir)) ) _ = to_edge_transform_and_lower(aten_gm, partitioner=[part]) # Expect exactly one .tosa file in the artefact dir diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 2eabd302df6..24fdfbb5457 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -5,7 +5,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import cast, Tuple import pytest import torch @@ -23,7 +23,6 @@ VgfPipeline, ) from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.arm.tosa.specification import get_tosa_spec from executorch.backends.xnnpack.test.tester import Quantize from torchao.quantization.pt2e import HistogramObserver from torchao.quantization.pt2e.quantizer import QuantizationSpec @@ -103,14 +102,13 @@ def test_add_tensor_tosa_INT(test_data: input_t1): @common.parametrize("test_data", Add.test_data) def test_add_tensor_tosa_INT_i32(test_data: input_t1): pipeline = TosaPipelineINT[input_t1](Add(), test_data(), aten_op, exir_op) - tosa_version = conftest.get_option("tosa_version") + tosa_version = cast(str, conftest.get_option("tosa_version")) tosa_profiles = { "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT"), } # Create a quantizer with int8 quantization on the input and output but int32 on everything else. 
- quantizer = arm_quantizer.TOSAQuantizer( - get_tosa_spec(common.get_tosa_compile_spec(tosa_profiles[tosa_version])) - ) + quantizer = arm_quantizer.TOSAQuantizer(tosa_profiles[tosa_version]) + quantizer.set_io(arm_quantizer.get_symmetric_quantization_config()) observer_options = {"eps": 2**-16} observer = HistogramObserver.with_args(**observer_options) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index aeb0e3a56bd..1b59b186a2e 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -17,16 +17,14 @@ import numpy as np import torch +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec -from executorch.backends.arm.arm_backend import is_tosa, is_vgf +from executorch.backends.arm.ethosu import EthosUCompileSpec from executorch.backends.arm.test.conftest import is_option_enabled -from executorch.backends.arm.tosa.specification import ( - get_tosa_spec, - Tosa_1_00, - TosaSpecification, -) +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification +from executorch.backends.arm.vgf import VgfCompileSpec from executorch.exir import ExecutorchProgramManager, ExportedProgram -from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.lowered_backend_module import LoweredBackendModule from torch.fx.node import Node @@ -168,14 +166,9 @@ def __init__(self): def _tosa_dispatch(self, lowered_backend_module: LoweredBackendModule, inputs): tosa_buffer = lowered_backend_module.processed_bytes - compile_specs = lowered_backend_module.compile_specs - if not is_tosa(compile_specs): - raise RuntimeError( - "Model needs to be compiled to tosa to run reference model." 
- ) - tosa_spec = get_tosa_spec(compile_specs) + compile_spec = TosaCompileSpec.from_list(lowered_backend_module.compile_specs) - return run_tosa_graph(tosa_buffer, tosa_spec, inputs) + return run_tosa_graph(tosa_buffer, compile_spec.tosa_spec, inputs) def __exit__(self, exc_type, exc_val, exc_tb): super().__exit__(exc_type, exc_val, exc_tb) @@ -725,14 +718,12 @@ def run_tosa_graph( return [torch.from_numpy(output) for output in outputs_np] -def get_target_board(compile_spec: list[CompileSpec]) -> str | None: - if is_vgf(compile_spec): +def get_target_board(compile_spec: ArmCompileSpec) -> str | None: + if isinstance(compile_spec, VgfCompileSpec): return "vkml_emulation_layer" - for spec in compile_spec: - if spec.key == "compile_flags": - flags = spec.value.decode() - if "u55" in flags: - return "corstone-300" - elif "u85" in flags: - return "corstone-320" + if isinstance(compile_spec, EthosUCompileSpec): + if "u55" in compile_spec.target: + return "corstone-300" + if "u85" in compile_spec.target: + return "corstone-320" return None diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl index a6181cf34ce..62bc5aef57a 100644 --- a/backends/arm/test/targets.bzl +++ b/backends/arm/test/targets.bzl @@ -31,6 +31,17 @@ def define_arm_tests(): "quantizer/test_generic_annotater.py", ] + # Misc tests + test_files += [ + "misc/test_compile_spec.py", + "misc/test_tosa_spec.py", + "misc/test_bn_relu_folding_qat.py", + "misc/test_custom_partition.py", + "misc/test_debug_hook.py", + "misc/test_dim_order_guards.py", + "misc/test_outputs_order.py", + ] + TESTS = {} for test_file in test_files: @@ -50,6 +61,10 @@ def define_arm_tests(): deps = [ "//executorch/backends/arm/test:arm_tester", "//executorch/backends/arm/test:conftest", + "//executorch/backends/arm:ethosu", + "//executorch/backends/arm/tosa:compile_spec", + "//executorch/backends/arm/tosa:partitioner", + "//executorch/backends/arm:vgf", "//executorch/exir:lib", "fbsource//third-party/pypi/pytest:pytest", "fbsource//third-party/pypi/parameterized:parameterized", diff --git a/backends/arm/test/tester/analyze_output_utils.py b/backends/arm/test/tester/analyze_output_utils.py index 82d4f5d9837..c707eed8013 100644 --- a/backends/arm/test/tester/analyze_output_utils.py +++ b/backends/arm/test/tester/analyze_output_utils.py @@ -7,7 +7,6 @@ import tempfile import torch -from executorch.backends.arm.arm_backend import get_intermediate_path from executorch.backends.arm.test.runner_utils import ( get_input_quantization_params, get_output_quantization_params, @@ -245,7 +244,7 @@ def dump_error_output( # Capture assertion error and print more info banner = "=" * 40 + "TOSA debug info" + "=" * 40 logger.error(banner) - path_to_tosa_files = get_intermediate_path(tester.compile_spec) + path_to_tosa_files = tester.compile_spec.get_intermediate_path() if path_to_tosa_files is None: path_to_tosa_files = tempfile.mkdtemp(prefix="executorch_result_dump_") diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index fe17bd3f448..284d4d6d1c4 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -32,13 +32,8 @@ from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager -from executorch.backends.arm.arm_backend import ( - get_intermediate_path, - is_ethosu, - is_tosa, - is_vgf, -) -from executorch.backends.arm.ethosu import EthosUPartitioner +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from 
executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner from executorch.backends.arm.quantizer import ( EthosUQuantizer, get_symmetric_quantization_config, @@ -59,11 +54,11 @@ print_error_diffs, ) from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.backends.arm.tosa.mapping import extract_tensor_meta from executorch.backends.arm.tosa.partitioner import TOSAPartitioner -from executorch.backends.arm.tosa.specification import get_tosa_spec -from executorch.backends.arm.vgf import VgfPartitioner +from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner from executorch.backends.test.harness.stages import Stage, StageType from executorch.backends.xnnpack.test.tester import Tester @@ -77,7 +72,6 @@ to_edge_transform_and_lower, ) from executorch.exir.backend.backend_api import validation_disabled -from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.operator_support import ( DontPartition, DontPartitionModule, @@ -131,7 +125,7 @@ def get_output_format(lowered_module) -> str | None: to_print = dbg_tosa_fb_to_json(tosa_fb) to_print = pformat(to_print, compact=True, indent=1) output += f"\nTOSA deserialized {node.name}: \n{to_print}\n" - elif output_format == "vela": + elif output_format == EthosUCompileSpec.get_output_format(): vela_cmd_stream = lowered_module.processed_bytes output += f"\nVela command stream {node.name}: \n{vela_cmd_stream}\n" else: @@ -186,7 +180,7 @@ def run( class Serialize(tester.Serialize): - def __init__(self, compile_spec: list[CompileSpec], timeout): + def __init__(self, compile_spec: ArmCompileSpec, timeout): super().__init__() self.timeout = timeout self.executorch_program_manager: ExecutorchProgramManager | None @@ -203,7 +197,7 @@ def run_artifact(self, inputs): "Tried running artifact from Serialize stage without running the stage." 
) inputs_flattened, _ = tree_flatten(inputs) - intermediate_path = get_intermediate_path(self.compile_spec) + intermediate_path = self.compile_spec.get_intermediate_path() target_board = get_target_board(self.compile_spec) elf_path = get_elf_path(target_board) @@ -297,7 +291,7 @@ def __init__( self, model: torch.nn.Module, example_inputs: Tuple, - compile_spec: List[CompileSpec], + compile_spec: ArmCompileSpec, tosa_ref_model_path: str | None = None, dynamic_shapes: Optional[Tuple[Any]] = None, constant_methods: Optional[Dict[str, Any]] = None, @@ -331,12 +325,11 @@ def quantize( ): if quantize_stage is None: quantizer = None - if is_tosa(self.compile_spec): - tosa_spec = get_tosa_spec(self.compile_spec) - quantizer = TOSAQuantizer(tosa_spec) - elif is_ethosu(self.compile_spec): + if isinstance(self.compile_spec, TosaCompileSpec): + quantizer = TOSAQuantizer(self.compile_spec) + elif isinstance(self.compile_spec, EthosUCompileSpec): quantizer = EthosUQuantizer(self.compile_spec) - elif is_vgf(self.compile_spec): + elif isinstance(self.compile_spec, VgfCompileSpec): quantizer = VgfQuantizer(self.compile_spec) quantize_stage = tester.Quantize( quantizer, @@ -359,10 +352,12 @@ def to_edge( def partition(self, partition_stage: Optional[Partition] = None): if partition_stage is None: - if is_tosa(self.compile_spec): - arm_partitioner = TOSAPartitioner(compile_spec=self.compile_spec) - elif is_ethosu(self.compile_spec): - arm_partitioner = EthosUPartitioner(compile_spec=self.compile_spec) + if isinstance(self.compile_spec, TosaCompileSpec): + arm_partitioner = TOSAPartitioner(self.compile_spec) + elif isinstance(self.compile_spec, EthosUCompileSpec): + arm_partitioner = EthosUPartitioner(self.compile_spec) + elif isinstance(self.compile_spec, VgfCompileSpec): + arm_partitioner = VgfPartitioner(self.compile_spec) else: raise ValueError("compile spec doesn't target any Arm Partitioner") partition_stage = Partition(arm_partitioner) @@ -380,23 +375,24 @@ def to_edge_transform_and_lower( Union[Sequence[PassType], Dict[str, Sequence[PassType]]] ] = None, ): + if transform_passes is not None: + raise RuntimeError( + "transform passes are given to ArmTester at construction." + ) + if to_edge_and_lower_stage is None: if partitioners is None: - arm_partitioner = None - if is_tosa(self.compile_spec): + if isinstance(self.compile_spec, TosaCompileSpec): arm_partitioner = TOSAPartitioner( - compile_spec=self.compile_spec, - additional_checks=additional_checks, + self.compile_spec, additional_checks ) - elif is_ethosu(self.compile_spec): + elif isinstance(self.compile_spec, EthosUCompileSpec): arm_partitioner = EthosUPartitioner( - compile_spec=self.compile_spec, - additional_checks=additional_checks, + self.compile_spec, additional_checks ) - elif is_vgf(self.compile_spec): + elif isinstance(self.compile_spec, VgfCompileSpec): arm_partitioner = VgfPartitioner( - compile_spec=self.compile_spec, - additional_checks=additional_checks, + self.compile_spec, additional_checks ) else: raise ValueError("compile spec doesn't target any Arm Partitioner") @@ -425,7 +421,7 @@ def serialize( if serialize_stage is None: serialize_stage = Serialize(self.compile_spec, timeout) assert ( - get_intermediate_path(self.compile_spec) is not None + self.compile_spec.get_intermediate_path() is not None ), "Can't dump serialized file when compile specs do not contain an artifact path." 
return super().serialize(serialize_stage) @@ -621,7 +617,7 @@ def dump_dtype_distribution( to_print = f"{line} {self.cur} Placeholder Dtype Distribution {line}\n" graph = self.get_graph(self.cur) - tosa_spec = get_tosa_spec(self.compile_spec) + tosa_spec = self.compile_spec.tosa_spec dtype_dist_placeholders, dtype_dirst_tensors = _get_dtype_distribution( graph, tosa_spec ) @@ -668,7 +664,7 @@ def run_transform_for_annotation_pipeline( # We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run. artifact = self.get_artifact(stage) if self.cur == StageType.EXPORT: - new_gm = ArmPassManager(get_tosa_spec(self.compile_spec)).transform_for_annotation_pipeline( # type: ignore[arg-type] + new_gm = ArmPassManager(self.compile_spec.tosa_spec).transform_for_annotation_pipeline( # type: ignore[arg-type] graph_module=artifact.graph_module ) else: @@ -784,7 +780,7 @@ def _get_tosa_operator_distribution( [operator["op"] for operator in block["operators"]] ) break - elif spec.value == b"vela": + elif spec.value == EthosUCompileSpec.get_output_format().encode(): return "Can not get operator distribution for Vela command stream." else: return f"Unknown output format '{spec.value}'." diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index 102ccd209e9..123c1af44c3 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -21,8 +21,8 @@ ) import torch +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec -from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.quantizer import ( EthosUQuantizer, get_symmetric_quantization_config, @@ -37,7 +37,6 @@ ) from executorch.backends.xnnpack.test.tester.tester import Quantize -from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.pass_base import ExportPass from torch._export.pass_base import PassType @@ -104,7 +103,7 @@ def __init__( module: torch.nn.Module, test_data: T, aten_ops: str | List[str], - compile_spec: List[CompileSpec], + compile_spec: ArmCompileSpec, exir_ops: Optional[str | List[str]] = None, use_to_edge_transform_and_lower: bool = True, dynamic_shapes: Optional[Tuple[Any]] = None, @@ -340,7 +339,7 @@ def __init__( per_channel_quantization: bool = True, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, + tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 1, @@ -445,7 +444,7 @@ def __init__( run_on_tosa_ref_model: bool = True, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, + tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 0, @@ -526,7 +525,7 @@ def __init__( per_channel_quantization: bool = True, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, + tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 1, @@ -617,7 +616,7 @@ def __init__( per_channel_quantization: bool = True, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, + tosa_debug_mode: 
Optional[ArmCompileSpec.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 1, @@ -930,7 +929,7 @@ def __init__( per_channel_quantization: bool = True, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, + tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 1, diff --git a/backends/arm/tosa/TARGETS b/backends/arm/tosa/TARGETS index 18868054259..df32689bc3e 100644 --- a/backends/arm/tosa/TARGETS +++ b/backends/arm/tosa/TARGETS @@ -61,13 +61,25 @@ runtime.python_library( ) runtime.python_library( - name = "arm_partitioner", + name = "compile_spec", + srcs = [ + "compile_spec.py", + ], + deps = [ + ":tosa", + ":specification", + "//executorch/backends/arm:arm_compile_spec", + ], +) + +runtime.python_library( + name = "partitioner", srcs = [ "backend.py", "partitioner.py", ], deps = [ - "//executorch/backends/arm:arm_backend", + ":compile_spec", "//executorch/backends/arm:constants", "//executorch/backends/arm:process_node", "//executorch/backends/arm/debug:schema", diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index ce2b7a27487..08b0d55aaeb 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -16,7 +16,7 @@ from typing import cast, Dict, final, List, Set import serializer.tosa_serializer as ts # type: ignore -from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.common.debug import debug_fail, debug_tosa_dump from executorch.backends.arm.debug.schema import DebugHook from executorch.backends.arm.process_node import ( @@ -132,7 +132,7 @@ def preprocess( # noqa: C901 debug_hook = None if dump_debug_info is not None: - debug_hook = DebugHook(ArmCompileSpecBuilder.DebugMode[dump_debug_info]) + debug_hook = DebugHook(ArmCompileSpec.DebugMode[dump_debug_info]) # TODO: Fix the need to lazily import this. from executorch.backends.arm.operators.node_visitor import get_node_visitors @@ -192,7 +192,7 @@ def _sort_key(t: Node) -> int: ) if debug_hook is not None: - if debug_hook.mode == ArmCompileSpecBuilder.DebugMode.JSON: + if debug_hook.mode == ArmCompileSpec.DebugMode.JSON: json_output = debug_hook.serialize() with open(f"{artifact_path}/debug.json", "w") as f: f.write(json_output) diff --git a/backends/arm/tosa/compile_spec.py b/backends/arm/tosa/compile_spec.py new file mode 100644 index 00000000000..39403c867d7 --- /dev/null +++ b/backends/arm/tosa/compile_spec.py @@ -0,0 +1,25 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
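+#
+# Example usage (a sketch; TOSAPartitioner is assumed to be imported from
+# executorch.backends.arm.tosa.partitioner):
+#
+#   compile_spec = TosaCompileSpec("TOSA-1.0+INT")
+#   partitioner = TOSAPartitioner(compile_spec)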
+
+from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
+from executorch.backends.arm.tosa import TosaSpecification
+
+
+class TosaCompileSpec(ArmCompileSpec):
+    def __init__(self, tosa_spec: TosaSpecification | str):
+        if isinstance(tosa_spec, str):
+            tosa_spec = TosaSpecification.create_from_string(tosa_spec)
+        self._set_compile_specs(tosa_spec, [])
+
+    def validate(self):
+        if len(self.compiler_flags) != 0:
+            raise ValueError(
+                f"TosaCompileSpec can't have compiler flags, got {self.compiler_flags}"
+            )
+
+    @classmethod
+    def get_output_format(cls) -> str:
+        return "tosa"
diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py
index c0f546fe50a..ab381470968 100644
--- a/backends/arm/tosa/partitioner.py
+++ b/backends/arm/tosa/partitioner.py
@@ -18,8 +18,7 @@
     tosa_support_factory,
 )
 from executorch.backends.arm.tosa.backend import TOSABackend
-from executorch.backends.arm.tosa.specification import get_tosa_spec
-from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
     Partitioner,
@@ -38,7 +37,7 @@ def is_noop_clone(node: torch.fx.node.Node) -> bool:
     return node.target == exir_ops.edge.dim_order_ops._clone_dim_order.default
 
 
-def is_noop_alias_copy(node: torch.fx.node.Node) -> bool:
+def is_noop_alias_copy(node: torch.fx.Node) -> bool:
     return node.target == exir_ops.edge.aten.alias_copy.default
 
 
@@ -60,15 +59,14 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool:
 class TOSAPartitioner(Partitioner):
     def __init__(
         self,
-        compile_spec: List[CompileSpec],
+        compile_spec: TosaCompileSpec,
         additional_checks: Optional[Sequence[OperatorSupportBase]] = None,
     ) -> None:
-        from executorch.backends.arm.arm_backend import is_tosa
-
-        if not is_tosa(compile_spec):
-            raise RuntimeError("compile spec is not targeting TOSA")
-        self.delegation_spec = DelegationSpec(TOSABackend.__name__, compile_spec)
+        self.delegation_spec = DelegationSpec(
+            TOSABackend.__name__, compile_spec.to_list()
+        )
         self.additional_checks = additional_checks
+        self.tosa_spec = compile_spec.tosa_spec
 
     def partition(self, exported_program: ExportedProgram) -> PartitionResult:  # noqa
         # Run the CapabilityBasedPartitioner to return the largest possible
@@ -77,7 +75,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:  # no
         logger.info("TOSAPartitioner::partition")
         partition_tags: dict[str, DelegationSpec] = {}
 
-        tosa_spec = get_tosa_spec(self.delegation_spec.compile_specs)
+        tosa_spec = self.tosa_spec
 
         logger.info(f"Partitioning for {self.delegation_spec.backend_id}: {tosa_spec}")
 
@@ -215,7 +213,7 @@ def filter_fn(node: torch.fx.Node) -> bool:
             torch.ops.aten.logit.default,
         ] + ops_to_not_decompose_if_quant_op
 
-        tosa_spec = get_tosa_spec(self.delegation_spec.compile_specs)
+        tosa_spec = self.tosa_spec
         if not tosa_spec.is_U55_subset:
             # Tosa operator "RESIZE" is not supported on U55. 
Since upsample_bilinear2d # and upsample_nearest2d decompose into that it will not be possible to diff --git a/backends/arm/vgf/__init__.py b/backends/arm/vgf/__init__.py index 4ab8144cbd6..f4ce8f5d1a4 100644 --- a/backends/arm/vgf/__init__.py +++ b/backends/arm/vgf/__init__.py @@ -6,9 +6,7 @@ # pyre-unsafe from .backend import VgfBackend # noqa: F401 +from .compile_spec import VgfCompileSpec # noqa: F401 from .partitioner import VgfPartitioner # noqa: F401 -__all__ = [ - "VgfBackend", - "VgfPartitioner", -] +__all__ = ["VgfBackend", "VgfPartitioner", "VgfCompileSpec"] diff --git a/backends/arm/vgf/compile_spec.py b/backends/arm/vgf/compile_spec.py new file mode 100644 index 00000000000..452ea5c1956 --- /dev/null +++ b/backends/arm/vgf/compile_spec.py @@ -0,0 +1,66 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.tosa import ( # type: ignore[import-not-found] + TosaSpecification, +) + +# debug functionality +logger = logging.getLogger(__name__) + + +class VgfCompileSpec(ArmCompileSpec): + + def __init__( + self, + tosa_spec: TosaSpecification | str | None = None, + compiler_flags: list[str] | None = None, + ): + """ + Generate compile spec for VGF compatible targets + + Args: + compiler_flags: Extra compiler flags for converter_backend + """ + + if tosa_spec is None: + tosa_spec = "TOSA-1.0+FP" + if isinstance(tosa_spec, str): + tosa_spec = TosaSpecification.create_from_string(tosa_spec) + + if compiler_flags is None: + compiler_flags = [] + self._set_compile_specs(tosa_spec, compiler_flags) + self.validate() + + def validate(self): + """Throws an error if the compile spec is not valid.""" + tosa_version = self.tosa_spec.version # type: ignore[attr-defined] + tosa_profiles = self.tosa_spec.profiles # type: ignore[attr-defined] + + if tosa_version.major != 1: + raise ValueError( + "Arm backend only supports converter-backend for TOSA version 1. " + f"Invalid TOSA version: {tosa_version}" + ) + + if "FP" not in tosa_profiles and "INT" not in tosa_profiles: + raise ValueError( + "Arm backend only supports converter-backend for FP or INT. " + f"Invalid TOSA profile: {tosa_profiles}" + ) + + if len(tosa_profiles) != 1: + raise ValueError( + "For now Arm backend only supports converter-backend for either FP or INT. 
" + f"Invalid TOSA profile: {tosa_profiles}" + ) + + @classmethod + def get_output_format(cls) -> str: + return "vgf" diff --git a/backends/arm/vgf/partitioner.py b/backends/arm/vgf/partitioner.py index f6dab597487..ea10730e810 100644 --- a/backends/arm/vgf/partitioner.py +++ b/backends/arm/vgf/partitioner.py @@ -5,14 +5,10 @@ # pyre-unsafe -from typing import final, List, Optional, Sequence +from typing import final, Optional, Sequence -from executorch.backends.arm.arm_backend import ( - is_vgf, -) # usort: skip from executorch.backends.arm.tosa.partitioner import TOSAPartitioner -from executorch.backends.arm.vgf import VgfBackend -from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.backends.arm.vgf import VgfBackend, VgfCompileSpec from executorch.exir.backend.partitioner import DelegationSpec from torch.fx.passes.operator_support import OperatorSupportBase @@ -21,12 +17,12 @@ class VgfPartitioner(TOSAPartitioner): def __init__( self, - compile_spec: List[CompileSpec], + compile_spec: VgfCompileSpec, additional_checks: Optional[Sequence[OperatorSupportBase]] = None, ) -> None: - if not is_vgf(compile_spec): - raise RuntimeError("compile spec is not targeting Vgf") - # Override the delegation spec for Vgf - self.delegation_spec = DelegationSpec(VgfBackend.__name__, compile_spec) + self.delegation_spec = DelegationSpec( + VgfBackend.__name__, compile_spec.to_list() + ) self.additional_checks = additional_checks + self.tosa_spec = compile_spec.tosa_spec diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index d7e1b64e3ca..8132751f6f0 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -18,13 +18,7 @@ import torch from examples.devtools.scripts.export_bundled_program import save_bundled_program -from executorch.backends.arm.arm_backend import ( - ArmCompileSpecBuilder, - is_ethosu, - is_tosa, - is_vgf, -) -from executorch.backends.arm.ethosu import EthosUPartitioner +from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner from executorch.backends.arm.quantizer import ( EthosUQuantizer, get_symmetric_quantization_config, @@ -32,15 +26,15 @@ VgfQuantizer, ) from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.backends.arm.tosa.partitioner import TOSAPartitioner -from executorch.backends.arm.tosa.specification import get_tosa_spec from executorch.backends.arm.util.arm_model_evaluator import ( GenericModelEvaluator, MobileNetV2Evaluator, ) -from executorch.backends.arm.vgf import VgfPartitioner +from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner # To use Cortex-M backend from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import ( @@ -60,7 +54,6 @@ ExecutorchBackendConfig, to_edge_transform_and_lower, ) -from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.extension.export_util.utils import save_pte_program from tabulate import tabulate from torch.utils.data import DataLoader @@ -149,7 +142,7 @@ def get_model_and_inputs_from_name( def quantize( model: torch.nn.Module, model_name: str, - compile_specs: list[CompileSpec], + compile_specs: EthosUCompileSpec | VgfCompileSpec | TosaCompileSpec, example_inputs: Tuple[torch.Tensor], evaluator_name: str | None, evaluator_config: Dict[str, Any] | None, @@ -158,11 +151,11 @@ def quantize( logging.info("Quantizing Model...") logging.debug(f"Original model: {model}") quantizer = 
None
-    if is_ethosu(compile_specs):
+    if isinstance(compile_specs, EthosUCompileSpec):
         quantizer = EthosUQuantizer(compile_specs)
-    elif is_tosa(compile_specs):
-        quantizer = TOSAQuantizer(get_tosa_spec(compile_specs))
-    elif is_vgf(compile_specs):
+    elif isinstance(compile_specs, TosaCompileSpec):
+        quantizer = TOSAQuantizer(compile_specs)
+    elif isinstance(compile_specs, VgfCompileSpec):
         quantizer = VgfQuantizer(compile_specs)
     else:
         raise RuntimeError("Unsupported compilespecs for quantization!")
@@ -393,20 +386,20 @@ def get_compile_spec(
     memory_mode: Optional[str] = None,
     quantize: bool = False,
     config: Optional[str] = None,
-) -> list[CompileSpec]:
-    spec_builder = None
+) -> TosaCompileSpec | EthosUCompileSpec | VgfCompileSpec:
+    compile_spec = None
     if target.startswith("TOSA"):
         try:
             tosa_spec = TosaSpecification.create_from_string(target)
-        except:
+        except Exception:
             tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
-        spec_builder = ArmCompileSpecBuilder().tosa_compile_spec(tosa_spec)
+        compile_spec = TosaCompileSpec(tosa_spec)
     elif "ethos-u" in target:
-        spec_builder = ArmCompileSpecBuilder().ethosu_compile_spec(
+        compile_spec = EthosUCompileSpec(
             target,
             system_config=system_config,
             memory_mode=memory_mode,
-            extra_flags="--verbose-operators --verbose-cycle-estimate",
+            extra_flags=["--verbose-operators", "--verbose-cycle-estimate"],
             config_ini=config,
         )
     elif "vgf" in target:
@@ -414,12 +407,14 @@ def get_compile_spec(
             tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
         else:
             tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP")
-        spec_builder = ArmCompileSpecBuilder().vgf_compile_spec(tosa_spec)
+        compile_spec = VgfCompileSpec(tosa_spec)
+    else:
+        raise RuntimeError(f"Unknown target {target}")
 
     if intermediates is not None:
-        spec_builder.dump_intermediate_artifacts_to(intermediates)
+        compile_spec.dump_intermediate_artifacts_to(intermediates)
 
-    return spec_builder.build()
+    return compile_spec
 
 
 def evaluate_model(
@@ -749,11 +744,11 @@ def to_edge_TOSA_delegate(
         )
         model = model_int8
 
-    if is_ethosu(compile_spec):
+    if isinstance(compile_spec, EthosUCompileSpec):
         partitioner = EthosUPartitioner(compile_spec)
-    elif is_tosa(compile_spec):
+    elif isinstance(compile_spec, TosaCompileSpec):
         partitioner = TOSAPartitioner(compile_spec)
-    elif is_vgf(compile_spec):
+    elif isinstance(compile_spec, VgfCompileSpec):
         partitioner = VgfPartitioner(compile_spec)
     else:
         raise RuntimeError(f"Unhandled compile spec: {compile_spec}")
diff --git a/examples/arm/ethos_u_minimal_example.ipynb b/examples/arm/ethos_u_minimal_example.ipynb
index e63a7d37e58..dc8ea7193aa 100644
--- a/examples/arm/ethos_u_minimal_example.ipynb
+++ b/examples/arm/ethos_u_minimal_example.ipynb
@@ -80,7 +80,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder\n",
+    "from executorch.backends.arm.ethosu import EthosUCompileSpec\n",
     "from executorch.backends.arm.quantizer import (\n",
     "    EthosUQuantizer,\n",
     "    get_symmetric_quantization_config,\n",
@@ -90,13 +90,12 @@
     "# Create a compilation spec describing the target for configuring the quantizer\n",
     "# Some args are used by the Arm Vela graph compiler later in the example. 
Refer to Arm Vela documentation for an\n", "# explanation of its flags: https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/blob/main/OPTIONS.md\n", - "spec_builder = ArmCompileSpecBuilder().ethosu_compile_spec(\n", + "compile_spec = EthosUCompileSpec(\n", " target=\"ethos-u55-128\",\n", " system_config=\"Ethos_U55_High_End_Embedded\",\n", " memory_mode=\"Shared_Sram\",\n", - " extra_flags=\"--output-format=raw --debug-force-regor\"\n", + " extra_flags=[\"--output-format=raw\", \"--debug-force-regor\"]\n", " )\n", - "compile_spec = spec_builder.build()\n", "\n", "# Create and configure quantizer to use a symmetric quantization config globally on all nodes\n", "quantizer = EthosUQuantizer(compile_spec)\n", @@ -242,7 +241,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv (3.10.15)", + "display_name": "et_env", "language": "python", "name": "python3" }, @@ -256,7 +255,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.15" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/arm/vgf_minimal_example.ipynb b/examples/arm/vgf_minimal_example.ipynb index 35378817a7d..36004f2c7cd 100644 --- a/examples/arm/vgf_minimal_example.ipynb +++ b/examples/arm/vgf_minimal_example.ipynb @@ -82,21 +82,15 @@ "metadata": {}, "outputs": [], "source": [ - "from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder\n", - "from executorch.backends.arm.tosa import ( \n", - " TosaSpecification,\n", - ")\n", + "from executorch.backends.arm.vgf import VgfCompileSpec\n", "\n", "# Create a compilation spec describing the floating point target.\n", - "tosa_spec = TosaSpecification.create_from_string(\"TOSA-1.0+FP\")\n", - "\n", - "spec_builder = ArmCompileSpecBuilder().vgf_compile_spec(tosa_spec)\n", - "compile_spec = spec_builder.build()\n", + "compile_spec = VgfCompileSpec(\"TOSA-1.0+FP\")\n", "\n", "_ = graph_module.print_readable()\n", "\n", "# Create a new exported program using the graph_module\n", - "exported_program = torch.export.export_for_training(graph_module, example_inputs)" + "exported_program = torch.export.export(graph_module, example_inputs)" ] }, { @@ -125,10 +119,7 @@ "from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n", "\n", "# Create a compilation spec describing the target for configuring the quantizer\n", - "tosa_spec = TosaSpecification.create_from_string(\"TOSA-1.0+INT\")\n", - "\n", - "spec_builder = ArmCompileSpecBuilder().vgf_compile_spec(tosa_spec)\n", - "compile_spec = spec_builder.build()\n", + "compile_spec = VgfCompileSpec(\"TOSA-1.0+INT\")\n", "\n", "# Create and configure quantizer to use a symmetric quantization config globally on all nodes\n", "quantizer = VgfQuantizer(compile_spec)\n", @@ -143,7 +134,7 @@ "_ = quantized_graph_module.print_readable()\n", "\n", "# Create a new exported program using the quantized_graph_module\n", - "quantized_exported_program = torch.export.export_for_training(quantized_graph_module, example_inputs)" + "quantized_exported_program = torch.export.export(quantized_graph_module, example_inputs)" ] }, { From c996232bfbe76c69706b5bd59a3a6144ce44457f Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Wed, 6 Aug 2025 18:15:14 +0200 Subject: [PATCH 2/4] Arm backend: Support channels-last input and output - Insert transposes for input/output iff the incoming/outgoing data is in channels first format. 
- For testing using the TOSA reference model, transpose numpy arrays to
  and from the correct data format, since numpy doesn't have the concept
  of dim_order.
- Remove checks for channels_first-only input.
- Remove the check that the dim_order does not change before the
  to_tosa_memory_format pass, since the behaviour of channels-last
  tensors is unpredictable.
- Add dim order testing of example networks and mv2.
- Add a section to the documentation about memory formats.

Signed-off-by: Adrian Lundell
Change-Id: I05548b9f3b4671da6faad90a9dd7366fda4498d6
---
 .../arm/_passes/to_tosa_memory_format_pass.py | 111 +++++++---------
 backends/arm/constants.py                     |  12 ++
 .../to_dim_order_copy_support.py              |   1 +
 backends/arm/process_node.py                  |   7 -
 backends/arm/runtime/EthosUBackend.cpp        |   9 --
 backends/arm/test/misc/test_dim_order.py      | 123 ++++++++++++++++++
 .../arm/test/misc/test_dim_order_guards.py    |  67 ----------
 .../arm/test/models/test_mobilenet_v2_arm.py  |  17 +++
 .../arm/test/models/test_torch_functions.py   |   1 -
 .../test/passes/test_to_tosa_memory_format.py |  10 +-
 backends/arm/test/runner_utils.py             | 108 ++++++++-----
 docs/source/backends-arm-ethos-u.md           |   9 ++
 12 files changed, 295 insertions(+), 180 deletions(-)
 create mode 100644 backends/arm/test/misc/test_dim_order.py
 delete mode 100644 backends/arm/test/misc/test_dim_order_guards.py

diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py
index e4436d638f4..9294b54314c 100644
--- a/backends/arm/_passes/to_tosa_memory_format_pass.py
+++ b/backends/arm/_passes/to_tosa_memory_format_pass.py
@@ -9,13 +9,23 @@
 import logging
 
 import torch
-from executorch.backends.arm._passes import AnnotateOutputDimOrderPass
+from executorch.backends.arm._passes.annotate_decomposed_matmul import (
+    AnnotateDecomposedMatmulPass,
+)
 from executorch.backends.arm._passes.arm_pass_utils import (
     create_node,
     get_first_fake_tensor,
-    get_output_dim_orders,
     is_param_node,
 )
+from executorch.backends.arm.constants import (
+    HWCM_ORDER,
+    NCHW_ORDER,
+    NHWC_INVERSE_ORDER,
+    NHWC_ORDER,
+    NNCHW_ORDER,
+    NNHWC_INVERSE_ORDER,
+    NNHWC_ORDER,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
@@ -38,12 +48,6 @@ class ToTosaMemoryFormatPass(ExportPass):
     The annotated tosa_dim_order is used to permute the node's shape such that it
     gives a TOSA-compliant shape. 
""" - NHWC_order = (0, 2, 3, 1) - NHWC_inverse_order = (0, 3, 1, 2) - HWCM_order = (2, 3, 0, 1) - NNHWC_order = (0, 1, 3, 4, 2) - NNHWC_inverse_order = (0, 1, 4, 2, 3) - def __init__(self, exported_program: ExportedProgram) -> None: self.exported_program = exported_program super().__init__() @@ -135,9 +139,9 @@ def insert_input_transpose(node, input_node, graph_module): args=( input_node, list( - ToTosaMemoryFormatPass.NNHWC_inverse_order + NNHWC_INVERSE_ORDER if len(get_first_fake_tensor(input_node).size()) == 5 - else ToTosaMemoryFormatPass.NHWC_inverse_order + else NHWC_INVERSE_ORDER ), ), from_node=node, @@ -157,18 +161,18 @@ def insert_output_transpose(node, graph_module): args=( node, list( - ToTosaMemoryFormatPass.NNHWC_order + NNHWC_ORDER if len(get_first_fake_tensor(node).size()) == 5 - else ToTosaMemoryFormatPass.NHWC_order + else NHWC_ORDER ), ), from_node=node, ) permute_node.meta["tosa_dim_order"] = ( - ToTosaMemoryFormatPass.NNHWC_order + NNHWC_ORDER if len(get_first_fake_tensor(node).size()) == 5 - else ToTosaMemoryFormatPass.NHWC_order + else NHWC_ORDER ) node.meta["tosa_dim_order"] = tuple( range(len(get_first_fake_tensor(node).size())) @@ -218,7 +222,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: # call_function and placeholder allowed due to # index.Tensor being able to come in as both - if node.op not in ["call_function", "placeholder", "output"]: + if node.op != "call_function": continue # Transpose views @@ -240,21 +244,33 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): graph_module, ) - # Transpose inputs - elif _is_input(node, self.exported_program): - input_shape = get_first_fake_tensor(node).size() - if len(input_shape) in (4, 5): - ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module) + output_node = graph_module.graph.output_node() - # Transpose outputs - elif node.op == "output": - output_shape = get_first_fake_tensor(node).size() + # Transpose inputs if they are in (N)NCHW format + inputs = [ + n for n in graph_module.graph.nodes if _is_input(n, self.exported_program) + ] + for input_node in inputs: + input_dim_order = get_first_fake_tensor(input_node).dim_order() + if input_dim_order in (NCHW_ORDER, NNCHW_ORDER): + ToTosaMemoryFormatPass.insert_output_transpose(input_node, graph_module) + + # Transpose outputs if they are in (N)NCHW format + outputs = output_node.args[0] + output_dim_orders = output_node.meta.get("original_dim_orders") + if output_dim_orders is None: + raise RuntimeError( + f"{AnnotateDecomposedMatmulPass.__name__} is required to run at the beginning of the pass pipeline when using {ToTosaMemoryFormatPass.__name__}." 
+ ) - if len(output_shape) in (4, 5): - for input_node in node.all_input_nodes: - ToTosaMemoryFormatPass.insert_input_transpose( - node, input_node, graph_module - ) + for output_node_input, output_dim_order in zip(outputs, output_dim_orders): # type: ignore[arg-type] + if output_dim_order in ( + NCHW_ORDER, + NNCHW_ORDER, + ): + ToTosaMemoryFormatPass.insert_input_transpose( + output_node, output_node_input, graph_module + ) def remove_dim_order_kwargs( self, graph_module: torch.fx.GraphModule, node: torch.fx.Node @@ -277,17 +293,17 @@ def call(self, graph_module: torch.fx.GraphModule): node_data = get_first_fake_tensor(node).data self.remove_dim_order_kwargs(graph_module, node) - # Inputs and outputs are always in (N)NCHW format + # Inputs and outputs may vary in dim_order if _is_input(node, self.exported_program) or node.op == "output": - dim_order = tuple(range(node_data.dim())) + dim_order = node_data.dim_order() elif node_data.dim() == 4: - dim_order = self.NHWC_order + dim_order = NHWC_ORDER if self.is_weight_node_for_depthwise_conv2d(node): # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d). - dim_order = self.HWCM_order + dim_order = HWCM_ORDER elif node_data.dim() == 5: - dim_order = self.NNHWC_order + dim_order = NNHWC_ORDER else: dim_order = tuple(range(node_data.dim())) # type: ignore[assignment] @@ -300,32 +316,3 @@ def call(self, graph_module: torch.fx.GraphModule): graph_module = super().call(graph_module).graph_module return PassResult(graph_module, True) - - def requires(self, graph_module) -> None: - """ - This is the only pass which handles dim_orders, so verify that the output dim_orders has not changed since the beginning of the lowering pipeline. - """ - - dim_orders = get_output_dim_orders(graph_module) - original_dim_orders = graph_module.graph.output_node().meta.get( - "original_dim_orders" - ) - output_node = graph_module.graph.output_node() - - if original_dim_orders is None: - raise RuntimeError( - f"{AnnotateOutputDimOrderPass.__name__} must be run in the beginning of the pass pipeline to verify that the dim order has not changed unexpectedly during its run." - ) - - if len(dim_orders) != len(original_dim_orders): - raise RuntimeError( - f"The number of outputs has changed since {AnnotateOutputDimOrderPass.__name__} was run." - ) - - for node, dim_order, original_dim_order in zip( - output_node.args[0], dim_orders, original_dim_orders - ): - if dim_order != original_dim_order: - raise RuntimeError( - f"The dim order of output {node.name} has changed from {original_dim_order} to {dim_order} since {AnnotateOutputDimOrderPass.__name__} was run." 
- ) diff --git a/backends/arm/constants.py b/backends/arm/constants.py index fd8710d3ead..b9995410b23 100644 --- a/backends/arm/constants.py +++ b/backends/arm/constants.py @@ -29,3 +29,15 @@ DEQUANT_PER_TENSOR_OP_T, ) PER_CHANNEL_QDQ_OPS: Final = (QUANT_PER_CHANNEL_OP, DEQUANT_PER_CHANNEL_OP) + +NHWC_ORDER: Final = (0, 2, 3, 1) +NHWC_INVERSE_ORDER: Final = (0, 3, 1, 2) +NNHWC_ORDER: Final = (0, 1, 3, 4, 2) +NNHWC_INVERSE_ORDER: Final = (0, 1, 4, 2, 3) + +NCHW_ORDER: Final = (0, 1, 2, 3) +NCHW_INVERSE_ORDER: Final = (0, 2, 3, 1) +NNCHW_ORDER: Final = (0, 1, 2, 3, 4) +NNCHW_INVERSE_ORDER: Final = (0, 1, 3, 4, 2) + +HWCM_ORDER: Final = (2, 3, 0, 1) diff --git a/backends/arm/operator_support/to_dim_order_copy_support.py b/backends/arm/operator_support/to_dim_order_copy_support.py index e21f8a68ad6..ced9b7c5afc 100644 --- a/backends/arm/operator_support/to_dim_order_copy_support.py +++ b/backends/arm/operator_support/to_dim_order_copy_support.py @@ -89,6 +89,7 @@ def _merge_supported_types( torch.int32, torch.bfloat16, torch.float16, + torch.float32, ], } ALL_SUPPORTED_TYPES = _merge_supported_types( diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 9ca435c60c5..5093ea32d4c 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -70,13 +70,6 @@ def process_inputs( tosa_spec: TosaSpecification, ): """Serialize an input node""" - # inputs need to be in default dim_order (contiguous memory format) - meta = node.meta["val"] - if meta.dim_order() != tuple(range(meta.dim())): - raise RuntimeError( - f"Arm backend only supports contiguous memory format for inputs. " - f"Expected dim_order: {tuple(range(meta.dim()))}, but got: {meta.dim_order()} for node {node.name}" - ) try: tosa_arg = TosaArg(node, tosa_spec) except ValueError as e: diff --git a/backends/arm/runtime/EthosUBackend.cpp b/backends/arm/runtime/EthosUBackend.cpp index bff5ff69284..46424dd97a8 100644 --- a/backends/arm/runtime/EthosUBackend.cpp +++ b/backends/arm/runtime/EthosUBackend.cpp @@ -249,15 +249,6 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { handles.inputs->io[i].elem_size); return Error::InvalidProgram; } - supported = executorch::runtime::is_contiguous_dim_order( - tensor_in.dim_order().data(), tensor_in.dim()); - if (!supported) { - ET_LOG( - Error, - "Input %d expected contiguous dim_order, but got non-contiguous dim_order", - i); - return Error::InvalidProgram; - } // Select a compatible copy routine including checking for input layouts // which require permutation. diff --git a/backends/arm/test/misc/test_dim_order.py b/backends/arm/test/misc/test_dim_order.py new file mode 100644 index 00000000000..6b0b79add99 --- /dev/null +++ b/backends/arm/test/misc/test_dim_order.py @@ -0,0 +1,123 @@ +# Copyright 2024-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common + +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + TosaPipelineFP, + TosaPipelineINT, +) + + +input_t1 = Tuple[torch.Tensor] # Input x + + +class ChannelsLastInput(torch.nn.Module): + """ + Test a complex case with (channels last, channels first) input, + and (channels first, channels last) output. 
+    """
+
+    inputs: input_t1 = (
+        torch.arange(1, 25, dtype=torch.float32)
+        .reshape((1, 2, 3, 4))
+        .to(memory_format=torch.channels_last),
+        torch.arange(1, 25, dtype=torch.float32).reshape((1, 2, 3, 4)),
+    )
+
+    def forward(self, x, y):
+        x = x * x
+        return y, x
+
+
+class ChannelsFirstOutput(torch.nn.Module):
+    """
+    Test converting to channels_first inside the delegate.
+    """
+
+    inputs: input_t1 = (
+        torch.arange(1, 25, dtype=torch.float32)
+        .reshape((1, 2, 3, 4))
+        .to(memory_format=torch.channels_last),
+    )
+
+    def forward(self, x):
+        x = x.clone(memory_format=torch.contiguous_format) * x
+        return x
+
+
+class ChannelsLastOutput(torch.nn.Module):
+    """
+    Test changing of dim_order inside the delegate.
+    """
+
+    inputs: input_t1 = (torch.arange(1, 9, dtype=torch.float32).reshape((1, 2, 2, 2)),)
+
+    def forward(self, x):
+        x = x * x
+        x = x.clone(memory_format=torch.channels_last)
+        return x
+
+
+class ChannelsLastInsidePartition(torch.nn.Module):
+    """
+    Test dim_order changes inside the partition, but no dim_order changes at input/output.
+    """
+
+    inputs: input_t1 = (torch.randn((1, 2, 3, 3)),)
+
+    def __init__(self):
+        super().__init__()
+        self.conv2d = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=(3, 3))
+
+    def forward(self, x):
+        return (
+            self.conv2d(x.clone(memory_format=torch.channels_last)).clone(
+                memory_format=torch.contiguous_format
+            )
+            * 1
+        )
+
+
+test_modules = {
+    "channels_last_input": ChannelsLastInput,
+    "channels_first_output": ChannelsFirstOutput,
+    "channels_last_output": ChannelsLastOutput,
+    "channels_last_inside_partition": ChannelsLastInsidePartition,
+}
+
+
+@common.parametrize("module", test_modules)
+def test_dim_order_tosa_FP(module):
+    pipeline = TosaPipelineFP[input_t1](module(), module.inputs, [])
+    pipeline.run()
+
+
+@common.parametrize("module", test_modules)
+def test_dim_order_tosa_INT(module):
+    pipeline = TosaPipelineINT[input_t1](
+        module(), module.inputs, [], symmetric_io_quantization=True
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone300
+@common.parametrize("module", test_modules)
+def test_dim_order_u55_INT(module):
+    pipeline = EthosU55PipelineINT[input_t1](module(), module.inputs, [])
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone320
+@common.parametrize("module", test_modules)
+def test_dim_order_u85_INT(module):
+    pipeline = EthosU85PipelineINT[input_t1](module(), module.inputs, [])
+    pipeline.run()
diff --git a/backends/arm/test/misc/test_dim_order_guards.py b/backends/arm/test/misc/test_dim_order_guards.py
deleted file mode 100644
index 80a3c014abc..00000000000
--- a/backends/arm/test/misc/test_dim_order_guards.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree. 
- - -from typing import Tuple - -import pytest - -import torch -from executorch.backends.arm.test import common - -from executorch.backends.arm.test.tester.test_pipeline import ( - TosaPipelineFP, - TosaPipelineINT, -) - - -input_t1 = Tuple[torch.Tensor] # Input x - - -class Conv2D(torch.nn.Module): - inputs: dict[str, input_t1] = { - "randn": (torch.randn(1, 2, 20, 20).to(memory_format=torch.channels_last),), - } - - def __init__(self): - super().__init__() - self.conv2d = torch.nn.Conv2d(in_channels=2, out_channels=3, kernel_size=(3, 3)) - - def forward(self, x): - return self.conv2d(x) - - -@common.parametrize("test_data", Conv2D.inputs) -def test_tosa_FP_pipeline(test_data: input_t1): - module = Conv2D() - pipeline = TosaPipelineFP[input_t1]( - module, - test_data, - [], - [], - use_to_edge_transform_and_lower=False, - ) - pos = pipeline.find_pos("partition") - pipeline._stages = pipeline._stages[:pos] - pipeline.run() - with pytest.raises(RuntimeError): - pipeline.tester.partition() - - -@common.parametrize("test_data", Conv2D.inputs) -def test_tosa_INT_pipeline(test_data: input_t1): - module = Conv2D() - pipeline = TosaPipelineINT[input_t1]( - module, - test_data, - [], - [], - use_to_edge_transform_and_lower=False, - ) - pos = pipeline.find_pos("partition") - pipeline._stages = pipeline._stages[:pos] - pipeline.run() - with pytest.raises(RuntimeError): - pipeline.tester.partition() diff --git a/backends/arm/test/models/test_mobilenet_v2_arm.py b/backends/arm/test/models/test_mobilenet_v2_arm.py index d4e3bbc8e28..84de432155e 100644 --- a/backends/arm/test/models/test_mobilenet_v2_arm.py +++ b/backends/arm/test/models/test_mobilenet_v2_arm.py @@ -46,6 +46,23 @@ def test_mv2_tosa_FP(): pipeline.run() +def test_mv2_tosa_FP_channels_last(): + input_tensor = model_inputs[0].to(memory_format=torch.channels_last) + pipeline = TosaPipelineFP[input_t]( + mv2, + (input_tensor,), + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + ) + # Changing memory format leads to an unsupported as_strided_copy op being inserted into the graph, + # leading to a graph break. 
+ pipeline.change_args( + "check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 2} + ) + pipeline.run() + + @common.parametrize("per_channel_quantization", quant_test_data) def test_mv2_tosa_INT(per_channel_quantization): pipeline = TosaPipelineINT[input_t]( diff --git a/backends/arm/test/models/test_torch_functions.py b/backends/arm/test/models/test_torch_functions.py index 580438f6da8..de45dbe0356 100644 --- a/backends/arm/test/models/test_torch_functions.py +++ b/backends/arm/test/models/test_torch_functions.py @@ -101,7 +101,6 @@ def forward(self, *args): "Requires dynamic output shape.", "topk": "NotImplementedError: No registered serialization name for found", "sort": "NotImplementedError: No registered serialization name for found", - "norm": "An error occurred when running the 'KeepDimsFalseToSqueezePass' pass after the following passes:", }, ) def test_torch_fns_FP(test_data): diff --git a/backends/arm/test/passes/test_to_tosa_memory_format.py b/backends/arm/test/passes/test_to_tosa_memory_format.py index 1e9b8ffc63d..643a3bf5733 100644 --- a/backends/arm/test/passes/test_to_tosa_memory_format.py +++ b/backends/arm/test/passes/test_to_tosa_memory_format.py @@ -6,7 +6,10 @@ from typing import Tuple import torch -from executorch.backends.arm._passes import ToTosaMemoryFormatPass +from executorch.backends.arm._passes import ( + AnnotateOutputDimOrderPass, + ToTosaMemoryFormatPass, +) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( @@ -177,7 +180,10 @@ def test_to_tosa_memory_format_tosa_INT(module): ops_after_pass=module.ops_after_pass, ops_not_after_pass=module.ops_not_after_pass, pass_list=[RemoveGetItemPass], - passes_with_exported_program=[ToTosaMemoryFormatPass], + passes_with_exported_program=[ + AnnotateOutputDimOrderPass, + ToTosaMemoryFormatPass, + ], ) pipeline.pop_stage( "run_method_and_compare_outputs" diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 1b59b186a2e..3d002eff25e 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -13,11 +13,19 @@ from pathlib import Path +from types import NoneType from typing import Any, cast, Dict, List, Literal, Optional, Tuple import numpy as np import torch +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.constants import ( + NHWC_INVERSE_ORDER, + NHWC_ORDER, + NNHWC_INVERSE_ORDER, + NNHWC_ORDER, +) from executorch.backends.arm.ethosu import EthosUCompileSpec from executorch.backends.arm.test.conftest import is_option_enabled @@ -157,6 +165,36 @@ def get_output_quantization_params( return quant_params +def torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray: + dtype = _torch_to_numpy_dtype_dict[tensor.dtype] + array = tensor.detach().numpy().astype(dtype) + dim_order = tensor.dim_order() + if dim_order == NHWC_ORDER: + a = array.transpose(NHWC_ORDER) + return a + elif dim_order == NNHWC_ORDER: + return array.transpose(NNHWC_ORDER) + else: + return array + + +def numpy_to_torch_tensor(array: np.ndarray, output_node: Node) -> torch.Tensor: + output_tensor = get_first_fake_tensor(output_node) + shape = output_tensor.shape + dim_order = output_tensor.dim_order() + if dim_order == NHWC_ORDER: + shape_with_dim_order = [shape[i] for i in NHWC_ORDER] + tensor = torch.from_numpy(array).reshape(shape_with_dim_order) + return 
tensor.permute(NHWC_INVERSE_ORDER).to(memory_format=torch.channels_last)
+    elif dim_order == NNHWC_ORDER:
+        shape_with_dim_order = [shape[i] for i in NNHWC_ORDER]
+        tensor = torch.from_numpy(array).reshape(shape_with_dim_order)
+        return tensor.permute(NNHWC_INVERSE_ORDER).to(memory_format=torch.channels_last)
+    else:
+        tensor = torch.from_numpy(array).reshape(shape)
+        return tensor
+
+
 class TosaReferenceModelDispatch(TorchFunctionMode):
     """A context manager for executing call_delegate nodes using the reference model"""
 
@@ -168,7 +206,8 @@ def _tosa_dispatch(self, lowered_backend_module: LoweredBackendModule, inputs):
         tosa_buffer = lowered_backend_module.processed_bytes
         compile_spec = TosaCompileSpec.from_list(lowered_backend_module.compile_specs)
 
-        return run_tosa_graph(tosa_buffer, compile_spec.tosa_spec, inputs)
+        output_node = lowered_backend_module.original_module.graph.output_node()
+        return run_tosa_graph(tosa_buffer, compile_spec.tosa_spec, inputs, output_node)
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         super().__exit__(exc_type, exc_val, exc_tb)
@@ -190,6 +229,22 @@ def __torch_function__(self, func, types, args=..., kwargs=None):
             )
 
         kwargs = kwargs or {}
+
+        # This is a hack since Q/DQ ops do not handle channels-last input correctly: the simplest and most robust
+        # workaround is to run them in channels-first format and then convert back to channels-last.
+        if func in (
+            torch.ops.quantized_decomposed.quantize_per_tensor.out,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.out,
+            torch.ops.quantized_decomposed.quantize_per_channel.out,
+            torch.ops.quantized_decomposed.dequantize_per_channel.out,
+        ):
+
+            input_dim_order = args[0].dim_order()
+            if input_dim_order in (NHWC_ORDER, NNHWC_ORDER):
+                args = [args[0].to(memory_format=torch.contiguous_format), *args[1:]]
+                res = func(*args, **kwargs)
+                return res.to(memory_format=torch.channels_last)
+
         return func(*args, **kwargs)
@@ -244,14 +299,13 @@ def get_output_from_file(
     output_np = []
     output_node = exported_program.graph_module.graph.output_node()
     for i, node in enumerate(output_node.args[0]):
-        output_shape = node.meta["val"].shape
         output_dtype = node.meta["val"].dtype
         tosa_ref_output = np.fromfile(
             os.path.join(intermediate_path, f"{output_base_name}-{i}.bin"),
             _torch_to_numpy_dtype_dict[output_dtype],
         )
 
-        output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape))
+        output_np.append(numpy_to_torch_tensor(tosa_ref_output, node))
     return tuple(output_np)
 
 
@@ -437,11 +491,14 @@ def prep_data_for_save(
     quant_param: Optional[QuantizationParams] = None,
 ):
     if isinstance(data, torch.Tensor):
-        data_np = np.array(data.detach(), order="C").astype(
-            _torch_to_numpy_dtype_dict[data.dtype]
-        )
+        data_np = torch_tensor_to_numpy(data)
+    elif isinstance(data, (int, float, bool, NoneType)):
+        return np.array(data)
     else:
-        data_np = np.array(data)
+        raise RuntimeError(
+            f"Input dtype {type(data)} could not be converted to numpy array."
+        )
+
     if quant_param is not None:
         assert quant_param.node_name in input_name, (
             f"The quantization params name '{quant_param.node_name}' does not "
@@ -455,30 +512,8 @@ def prep_data_for_save(
             f"{quant_param.dtype}".replace("torch.", "")
         )  # Use string format of dtype to convert to numpy dtype
     )
-    return data_np
-
-
-def save_npy(
-    path: str,
-    data,
-    input_name: str,
-    quant_param: Optional[QuantizationParams] = None,
-) -> str:
-    """Serializes and saves 'data' as a .npy file, possibly quantizing it before.
-
-    Parameters:
-        path: the directory where to save the data. 
-        data: the data to save.
-        input_name: the name of the file, without file-ending.
-        quant_param: the parameters to use for quantization.
-    Returns:
-        the full file path of the output.
-    """
-    data_np = prep_data_for_save(data, input_name, quant_param)
-    file_path = os.path.join(path, input_name + ".npy")
-    np.save(file_path, data_np, allow_pickle=False)
-
-    return file_path
+    return data_np
 
 
 def save_bytes(
@@ -691,9 +726,12 @@ def run_tosa_graph(
     graph: Any,
     tosa_version: TosaSpecification,
     inputs: list[torch.Tensor],
+    output_node: Node,
 ) -> list[torch.Tensor]:
     """Runs the TOSA reference model with inputs and returns the result."""
+
+    # Convert tensors to numpy arrays with correct dim_order
-    inputs_np = [input.numpy() for input in inputs]
+    inputs_np = [torch_tensor_to_numpy(input_tensor) for input_tensor in inputs]
 
     if isinstance(tosa_version, Tosa_1_00):
         import tosa_reference_model as reference_model
@@ -715,7 +753,13 @@ def run_tosa_graph(
         status == reference_model.GraphStatus.TOSA_VALID
     ), "Non-valid TOSA given to reference model."
 
-    return [torch.from_numpy(output) for output in outputs_np]
+    # Convert output numpy arrays to tensors with same dim_order as the output nodes
+    result = [
+        numpy_to_torch_tensor(output_array, node)
+        for output_array, node in zip(outputs_np, output_node.args[0])
+    ]
+
+    return result
 
 
 def get_target_board(compile_spec: ArmCompileSpec) -> str | None:
diff --git a/docs/source/backends-arm-ethos-u.md b/docs/source/backends-arm-ethos-u.md
index ae14cb9901f..a3268bb2b0a 100644
--- a/docs/source/backends-arm-ethos-u.md
+++ b/docs/source/backends-arm-ethos-u.md
@@ -193,5 +193,14 @@ Then build the arm executorch runtime using the script
 Finally, run the elf file on FVP using the script
 `executorch/backends/arm/scripts/run_fvp.sh --elf=executorch/mv2_arm_ethos_u55/cmake-out/arm_executor_runner --target=ethos-u55-128`.
 
+## Memory formats
+
+Tensors of rank 4 and higher can use two differing [memory format](https://pytorch.org/blog/tensor-memory-format-matters/) standards.
+PyTorch defaults to the contiguous/channels-first/NCHW memory format, whereas TOSA only supports the channels-last/NHWC memory format.
+To bridge this, the backend inserts a transpose at the beginning of the graph if the incoming memory format is contiguous, and correspondingly a
+transpose at the end if the outgoing memory format is contiguous. This means that you can avoid transposing the data unnecessarily if the runtime integration and
+the full network are converted to use channels last. A word of caution, however: changing the memory format has been noted to have side effects, such as
+unsupported ops being inserted into the graph, and it is currently not widely tested, so the feature should for now be viewed as experimental. 
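+
+As a minimal sketch of how to take advantage of this (the convolution module below is only a placeholder; any model with rank-4 inputs works the same way), a model can be exported with channels-last example inputs so that the backend does not need to insert an input transpose:
+
+```python
+import torch
+
+# Placeholder model; substitute your own module with rank-4 inputs.
+model = torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3).eval()
+x = torch.randn(1, 3, 224, 224).to(memory_format=torch.channels_last)
+
+# The example input is already channels last, so the lowered graph sees NHWC
+# data directly and no input transpose is required.
+exported_program = torch.export.export(model, (x,))
+```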
+ ## See Also - [Arm Ethos-U Backend Tutorial](tutorial-arm.md) From cc7d8995baba75b56a3f9bff90766e68b1356ebb Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Tue, 16 Sep 2025 09:59:21 +0200 Subject: [PATCH 3/4] Fix upsteam review comments Signed-off-by: Adrian Lundell --- backends/arm/_passes/to_tosa_memory_format_pass.py | 6 +++--- backends/arm/test/targets.bzl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index 9294b54314c..73cd2e2954e 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -253,14 +253,14 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): for input_node in inputs: input_dim_order = get_first_fake_tensor(input_node).dim_order() if input_dim_order in (NCHW_ORDER, NNCHW_ORDER): - ToTosaMemoryFormatPass.insert_output_transpose(input_node, graph_module) + self.insert_output_transpose(input_node, graph_module) # Transpose outputs if they are in (N)NCHW format outputs = output_node.args[0] output_dim_orders = output_node.meta.get("original_dim_orders") if output_dim_orders is None: raise RuntimeError( - f"{AnnotateDecomposedMatmulPass.__name__} is required to run at the beginning of the pass pipeline when using {ToTosaMemoryFormatPass.__name__}." + f"{AnnotateDecomposedMatmulPass.__name__} is required to run at the beginning of the pass pipeline when using {self.__name__}." ) for output_node_input, output_dim_order in zip(outputs, output_dim_orders): # type: ignore[arg-type] @@ -268,7 +268,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): NCHW_ORDER, NNCHW_ORDER, ): - ToTosaMemoryFormatPass.insert_input_transpose( + self.insert_input_transpose( output_node, output_node_input, graph_module ) diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl index f240855cdf4..7634eed7a53 100644 --- a/backends/arm/test/targets.bzl +++ b/backends/arm/test/targets.bzl @@ -39,7 +39,7 @@ def define_arm_tests(): "misc/test_bn_relu_folding_qat.py", "misc/test_custom_partition.py", "misc/test_debug_hook.py", - "misc/test_dim_order_guards.py", + "misc/test_dim_order.py", "misc/test_outputs_order.py", ] From 843e600c38a7b3c44e2a1353f636c3c3a635ec8d Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Tue, 16 Sep 2025 11:44:08 +0200 Subject: [PATCH 4/4] Fix mypy linter error --- backends/arm/_passes/to_tosa_memory_format_pass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index 73cd2e2954e..ac16cbaf8cb 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -260,7 +260,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): output_dim_orders = output_node.meta.get("original_dim_orders") if output_dim_orders is None: raise RuntimeError( - f"{AnnotateDecomposedMatmulPass.__name__} is required to run at the beginning of the pass pipeline when using {self.__name__}." + f"{AnnotateDecomposedMatmulPass.__name__} is required to run at the beginning of the pass pipeline when using {ToTosaMemoryFormatPass.__name__}." ) for output_node_input, output_dim_order in zip(outputs, output_dim_orders): # type: ignore[arg-type]