diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 261ee045790..dd9f8b7bed0 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -166,12 +166,22 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul return self._transform(exported_program.graph_module) + def _tosa_1_0_int_quantized_pipeline(self, exported_program: ExportedProgram): + return self._tosa_080_BI_pipeline(exported_program) + + def _tosa_1_0_fp_pipeline(self, exported_program: ExportedProgram): + return self._tosa_080_MI_pipeline(exported_program) + def transform_to_backend_pipeline(self, exported_program: ExportedProgram): """Apply passes before transforming program to backend""" if self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+BI"): return self._tosa_080_BI_pipeline(exported_program) elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+MI"): return self._tosa_080_MI_pipeline(exported_program) + elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"): + return self._tosa_1_0_fp_pipeline(exported_program) + elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"): + return self._tosa_1_0_int_quantized_pipeline(exported_program) else: raise NotImplementedError( f"No pass pipeline implemented for {self.tosa_spec=}" diff --git a/backends/arm/operator_support/convolution_support.py b/backends/arm/operator_support/convolution_support.py index 9e13babe23a..75899eb7425 100644 --- a/backends/arm/operator_support/convolution_support.py +++ b/backends/arm/operator_support/convolution_support.py @@ -22,6 +22,8 @@ class ConvolutionSupported(SupportedTOSAOperatorCheck): tosa_specs = [ TosaSpecification.create_from_string("TOSA-0.80+BI"), TosaSpecification.create_from_string("TOSA-0.80+MI"), + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): diff --git a/backends/arm/operator_support/minmax_support.py b/backends/arm/operator_support/minmax_support.py index bdff368a5ce..86b949082eb 100644 --- a/backends/arm/operator_support/minmax_support.py +++ b/backends/arm/operator_support/minmax_support.py @@ -22,6 +22,7 @@ class MinMaxSupported(SupportedTOSAOperatorCheck): # TODO : "MLETORCH-718 : Quantization of indices in arm_quantizer" tosa_specs = [ TosaSpecification.create_from_string("TOSA-0.80+MI"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): diff --git a/backends/arm/operator_support/pool_2d_support.py b/backends/arm/operator_support/pool_2d_support.py index 8291ede8ad9..750fab2730d 100644 --- a/backends/arm/operator_support/pool_2d_support.py +++ b/backends/arm/operator_support/pool_2d_support.py @@ -41,6 +41,8 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck): tosa_specs = [ TosaSpecification.create_from_string("TOSA-0.80+BI"), TosaSpecification.create_from_string("TOSA-0.80+MI"), + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): @@ -94,6 +96,8 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck): tosa_specs = [ TosaSpecification.create_from_string("TOSA-0.80+BI"), TosaSpecification.create_from_string("TOSA-0.80+MI"), + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): diff --git a/backends/arm/operator_support/reduce_sum_support.py b/backends/arm/operator_support/reduce_sum_support.py index 37a71d7264c..a50bcbceab7 100644 --- a/backends/arm/operator_support/reduce_sum_support.py +++ b/backends/arm/operator_support/reduce_sum_support.py @@ -21,6 +21,8 @@ class SumSupported(SupportedTOSAOperatorCheck): tosa_specs = [ TosaSpecification.create_from_string("TOSA-0.80+BI"), TosaSpecification.create_from_string("TOSA-0.80+MI"), + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): diff --git a/backends/arm/operator_support/right_shift_support.py b/backends/arm/operator_support/right_shift_support.py index 212b877a59f..49976b2346f 100644 --- a/backends/arm/operator_support/right_shift_support.py +++ b/backends/arm/operator_support/right_shift_support.py @@ -29,6 +29,8 @@ class RightShiftSupported(SupportedTOSAOperatorCheck): tosa_specs = [ TosaSpecification.create_from_string("TOSA-0.80+BI"), TosaSpecification.create_from_string("TOSA-0.80+MI"), + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): diff --git a/backends/arm/operator_support/slice_copy_support.py b/backends/arm/operator_support/slice_copy_support.py index 477b3bfae4a..ea18c408149 100644 --- a/backends/arm/operator_support/slice_copy_support.py +++ b/backends/arm/operator_support/slice_copy_support.py @@ -25,6 +25,8 @@ class SliceCopySupported(SupportedTOSAOperatorCheck): tosa_specs = [ TosaSpecification.create_from_string("TOSA-0.80+BI"), TosaSpecification.create_from_string("TOSA-0.80+MI"), + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: # type: ignore[override, misc] diff --git a/backends/arm/operator_support/to_copy_support.py b/backends/arm/operator_support/to_copy_support.py index 7926b3dc053..aa0be8cfcd0 100644 --- a/backends/arm/operator_support/to_copy_support.py +++ b/backends/arm/operator_support/to_copy_support.py @@ -30,6 +30,8 @@ class ToCopySupported(SupportedTOSAOperatorCheck): tosa_specs = [ TosaSpecification.create_from_string("TOSA-0.80+BI"), TosaSpecification.create_from_string("TOSA-0.80+MI"), + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] SupportedTypeDict = dict[torch.dtype, list[torch.dtype]] diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index f84bde7fadc..0f500bd759e 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -66,6 +66,8 @@ def is_node_tosa_supported( _tosa_spec_support: dict[TosaSpecification, list[Type[SupportedTOSAOperatorCheck]]] = { TosaSpecification.create_from_string("TOSA-0.80+BI"): [], TosaSpecification.create_from_string("TOSA-0.80+MI"): [], + TosaSpecification.create_from_string("TOSA-1.0+INT"): [], + TosaSpecification.create_from_string("TOSA-1.0+FP"): [], } diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index 72fb58f582c..5056c5f7f54 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -5,11 +5,10 @@ # pyre-unsafe -from typing import Dict, List +from typing import Any, Dict, List import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.export import ExportedProgram @@ -25,11 +24,18 @@ class NodeVisitor: # a specific TOSA version. # When all node_visitors has been refactored to target a specific # version, this list should be removed. - tosa_specs = [ + tosa_specs_1_00 = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + tosa_specs_0_80 = [ TosaSpecification.create_from_string("TOSA-0.80+BI"), TosaSpecification.create_from_string("TOSA-0.80+MI"), ] + tosa_specs = tosa_specs_0_80 + tosa_specs_1_00 + def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification): self._exported_program = exported_program self.tosa_spec = tosa_spec @@ -37,7 +43,7 @@ def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecificati def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: @@ -48,6 +54,8 @@ def define_node( _node_visitor_dicts: Dict[TosaSpecification, Dict] = { TosaSpecification.create_from_string("TOSA-0.80+BI"): {}, TosaSpecification.create_from_string("TOSA-0.80+MI"): {}, + TosaSpecification.create_from_string("TOSA-1.0+INT"): {}, + TosaSpecification.create_from_string("TOSA-1.0+FP"): {}, } diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 930c834f232..6692b75c892 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -5,15 +5,18 @@ # # pyre-unsafe -from typing import cast, Dict +from typing import Any, cast, Dict import numpy as np import torch import torch.fx -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import NodeVisitor from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.backends.arm.tosa_specification import ( + Tosa_0_80, + Tosa_1_00, + TosaSpecification, +) from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape from torch._export.utils import ( get_buffer, @@ -28,7 +31,7 @@ def process_call_function( node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, node_visitors: Dict[str, NodeVisitor], tosa_spec: TosaSpecification, ): @@ -63,7 +66,7 @@ def process_call_function( def process_inputs( node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, tosa_spec: TosaSpecification, ): """Serialize an input node""" @@ -81,6 +84,14 @@ def process_inputs( f"Failed processing input placeholder: {node.name}. " "Is the original torch function supported?" ) from e + + if isinstance(tosa_spec, Tosa_0_80): + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + elif isinstance(tosa_spec, Tosa_1_00): + import serializer.tosa_serializer as ts + else: + raise ValueError(f"Unsupported TOSA spec: {tosa_spec}") + input_shape = tosa_arg.shape input_dim_order = tosa_arg.dim_order tensor = ts.TosaSerializerTensor( @@ -95,7 +106,7 @@ def process_inputs( def process_inputs_to_parameters( node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, edge_program: ExportedProgram, tosa_spec: TosaSpecification, ): @@ -124,7 +135,7 @@ def process_inputs_to_parameters( def process_inputs_to_buffers( node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, edge_program: ExportedProgram, ): """Serialize quantized weights""" @@ -152,7 +163,7 @@ def process_inputs_to_buffers( def process_inputs_to_lifted_tensor_constants( node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, edge_program: ExportedProgram, ): try: @@ -172,7 +183,7 @@ def process_inputs_to_lifted_tensor_constants( def process_placeholder( node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, edge_program: ExportedProgram, tosa_spec: TosaSpecification, ): @@ -198,7 +209,7 @@ def process_placeholder( def process_output( node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, ): for output in cast(tuple[torch.fx.Node, ...], node.args[0]): tosa_graph.addOutputTensor( diff --git a/backends/arm/scripts/install_reference_model.sh b/backends/arm/scripts/install_reference_model.sh index 796a1ed418e..0141b195a0d 100755 --- a/backends/arm/scripts/install_reference_model.sh +++ b/backends/arm/scripts/install_reference_model.sh @@ -13,7 +13,7 @@ tosa_reference_model_url="https://git.gitlab.arm.com/tosa/tosa-reference-model.g tosa_reference_model_0_80_branch="v0.80" tosa_reference_model_0_80_rev="70ed0b40fa831387e36abdb4f7fb9670a3464f5a" tosa_serialization_lib_0_80_rev="v0.80.1" -tosa_reference_model_1_0_rev="v1.0" +tosa_reference_model_1_0_rev="f9b4ceb850964be03a39e965ad7a0546dc6c57ae" script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) @@ -47,6 +47,9 @@ function setup_tosa_reference_model() { # Vela's flatbuffer requirement is expected to loosen, then remove this. MLETORCH-565 CMAKE_POLICY_VERSION_MINIMUM=3.5 pip install . --no-dependencies flatbuffers popd + + # Install the 1.0 branch from upstream + CMAKE_POLICY_VERSION_MINIMUM=3.5 BUILD_PYBIND=1 pip install "tosa-tools@git+${tosa_reference_model_url}@${tosa_reference_model_1_0_rev}" ml_dtypes==0.5.1 --no-dependencies flatbuffers } setup_tosa_reference_model $1 diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 16f2e85e1d2..4481a9c7cc2 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -13,27 +13,24 @@ from pathlib import Path -from typing import cast, Dict, List, Literal, Optional, Tuple +from typing import Any, cast, Dict, List, Literal, Optional, Tuple import numpy as np import torch -try: - import tosa_tools.v0_80.tosa_reference_model as tosa_reference_model -except ImportError: - tosa_reference_model = None from executorch.backends.arm.arm_backend import get_tosa_spec, is_tosa - from executorch.backends.arm.test.conftest import is_option_enabled -from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.backends.arm.tosa_specification import ( + Tosa_0_80, + Tosa_1_00, + TosaSpecification, +) from executorch.exir import ExecutorchProgramManager, ExportedProgram from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.lowered_backend_module import LoweredBackendModule -from packaging.version import Version from torch.fx.node import Node from torch.overrides import TorchFunctionMode -from tosa_tools.v0_80.tosa import TosaGraph logger = logging.getLogger(__name__) @@ -566,7 +563,7 @@ def arm_executor_runner_exists(target_board): def run_tosa_graph( - graph: TosaGraph, + graph: Any, tosa_version: TosaSpecification, inputs: list[torch.Tensor], ) -> list[torch.Tensor]: @@ -574,25 +571,38 @@ def run_tosa_graph( inputs_np = [input.numpy() for input in inputs] transpose_data_format(inputs_np, to="NHWC") - tosa_release = tosa_version.version - - if tosa_release > Version("0.80"): - logger.warning("The reference model is only tested for TOSA v0.80") - - # tosa_profile: 0 = Base Inference, 1 = Main Inference, 2 = Main Training. - tosa_profile = 1 if tosa_version.support_float() else 0 - debug_mode = "ALL" if logger.level <= logging.DEBUG else None - outputs_np, status = tosa_reference_model.run( - graph, - inputs_np, - verbosity=_tosa_refmodel_loglevel(logger.level), - tosa_profile=tosa_profile, - initialize_variable_tensor_from_numpy=1, # True - debug_mode=debug_mode, - ) + if isinstance(tosa_version, Tosa_0_80): + import tosa_tools.v0_80.tosa_reference_model as reference_model + + # tosa_profile: 0 = Base Inference, 1 = Main Inference, 2 = Main Training. + tosa_profile = 1 if tosa_version.support_float() else 0 + debug_mode = "ALL" if logger.level <= logging.DEBUG else None + outputs_np, status = reference_model.run( + graph, + inputs_np, + verbosity=_tosa_refmodel_loglevel(logger.level), + tosa_profile=tosa_profile, + initialize_variable_tensor_from_numpy=True, + debug_mode=debug_mode, + ) + elif isinstance(tosa_version, Tosa_1_00): + import tosa_reference_model as reference_model + + debug_mode = "ALL" if logger.level <= logging.DEBUG else None + outputs_np, status = reference_model.run( + graph, + inputs_np, + verbosity=_tosa_refmodel_loglevel(logger.level), + initialize_variable_tensor_from_numpy=True, + debug_mode=debug_mode, + ) + else: + raise ValueError( + f"Unknown TOSA specification: {tosa_version}. No refererence model available to run for this specification version" + ) assert ( - status == tosa_reference_model.GraphStatus.TOSA_VALID + status == reference_model.GraphStatus.TOSA_VALID ), "Non-valid TOSA given to reference model." transpose_data_format(outputs_np, to="NCHW") diff --git a/backends/arm/tosa_backend.py b/backends/arm/tosa_backend.py index adb4fba1fc8..ad16b0d84df 100644 --- a/backends/arm/tosa_backend.py +++ b/backends/arm/tosa_backend.py @@ -13,7 +13,7 @@ import logging from typing import cast, final, List -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore +import executorch.backends.arm.tosa_specification as tosa_specification from executorch.backends.arm.arm_backend import get_tosa_spec from executorch.backends.arm.operators.node_visitor import get_node_visitors @@ -88,7 +88,22 @@ def preprocess( # noqa: C901 # Converted output for this subgraph, serializer needs path early as it emits # const data directly. Path created and data written only in debug builds. + if isinstance(tosa_spec, tosa_specification.Tosa_0_80): + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + elif isinstance(tosa_spec, tosa_specification.Tosa_1_00): + import serializer.tosa_serializer as ts # type: ignore + else: + raise RuntimeError( + f"Unknown TOSA version {tosa_spec}, no pip package installed to handle serialization to that version." + ) + tosa_graph = ts.TosaSerializer(artifact_path) + + assert ( + tosa_spec.version.major == ts.TOSA_VERSION_MAJOR + and tosa_spec.version.minor == ts.TOSA_VERSION_MINOR + ), f"TOSA serializer version ({ts.TOSA_VERSION_MAJOR}.{ts.TOSA_VERSION_MINOR}) doesn't match specification {tosa_spec}" + graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline( # type: ignore exported_program=edge_program ) diff --git a/backends/arm/tosa_specification.py b/backends/arm/tosa_specification.py index 94c307d440c..640361e059c 100644 --- a/backends/arm/tosa_specification.py +++ b/backends/arm/tosa_specification.py @@ -142,7 +142,7 @@ class Tosa_1_00(TosaSpecification): available_profiles = ["INT", "FP"] valid_extensions = { - "INT": ["int16", "int4", "var", "cf"], + "INT": ["int16", "int4", "var", "cf", "u55"], "FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"], }