From a5ec1b25d0cdefd81ad313a8a00b2847d1810903 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Wed, 2 Apr 2025 12:35:23 +0200 Subject: [PATCH] Arm backend: Introduce TOSA backend dialect MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a TOSA backend dialect so it's possible to convert Edge IR to specific TOSA operators in the ARM backends. This is done to enable control of types and additional arguments available for TOSA operators in comparison with the Edge IR. The operators are registered into the exir.backend.tosa namespace and is only traceable, not executable since it's only used in the lowering step to a TOSA serialization format. Signed-off-by: Per Åstrand Change-Id: I79f1fd4dbc00465f329df5b0f31d0788d729e0a9 --- backends/arm/_passes/arm_pass_manager.py | 2 + backends/arm/tosa/dialect/lib.py | 62 +++++++++++++++++ backends/arm/tosa/dialect/ops_registration.py | 68 +++++++++++++++++++ 3 files changed, 132 insertions(+) create mode 100644 backends/arm/tosa/dialect/lib.py create mode 100644 backends/arm/tosa/dialect/ops_registration.py diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index f4a8af27ff8..278e26c09ea 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -6,6 +6,8 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe + +import executorch.backends.arm.tosa.dialect # noqa: unused from executorch.backends.arm._passes import ( AddBiasPass, AnnotateChannelsLastDimOrder, diff --git a/backends/arm/tosa/dialect/lib.py b/backends/arm/tosa/dialect/lib.py new file mode 100644 index 00000000000..3c965418c72 --- /dev/null +++ b/backends/arm/tosa/dialect/lib.py @@ -0,0 +1,62 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable + +from executorch.exir.dialects._ops import _BACKEND_OP_LIB, ops as exir_ops +from torch.library import Library, register_fake +from torchgen.model import FunctionSchema + +# create a torch library for the TOSA dialect +# This defines a library to include Backend Dialect Operators in Executorch +tosa_lib = Library("tosa", "DEF") + + +def register_tosa_dialect_op(op_schema, func) -> Callable: + if tosa_lib.ns not in _BACKEND_OP_LIB: + _BACKEND_OP_LIB.append(tosa_lib.ns) + + if "::" in op_schema: + raise ValueError("The schema should not contain a namespace.") + + # Parse the op_schema into a FunctionSchema + func_schema = FunctionSchema.parse(op_schema) + overload_name = func_schema.name.overload_name + if overload_name: + raise ValueError( + "The TOSA dialect does not support overload names in the op schema." + ) + + opname = func_schema.name.name.base + tosa_lib.define(op_schema) + + overload_name = "default" + op_qualified_name = f"{tosa_lib.ns}::{opname}" + + register_fake(op_qualified_name, func, lib=tosa_lib) + + op = getattr(getattr(getattr(exir_ops.backend, tosa_lib.ns), opname), overload_name) + + # For now, since the TOSA operators are only used for lowering and serialization in the backend + # the op doesn't need to be callable. This can be changed in the future if needed to support + # execution of TOSA ops directly. + def not_callable(): + raise RuntimeError("TOSA dialect op is not callable") + + op.__equvalent_callable__ = not_callable + + return op + + +class TosaValueError(ValueError): + def __init__(self, message="A TOSA value error occurred", *args, **kwargs): + super().__init__(message, *args, **kwargs) + self.op = kwargs.get("op", None) + + def __str__(self): + base_message = super().__str__() + if self.op is not None: + return f"{base_message} (TOSA op: {self.op})" + return base_message diff --git a/backends/arm/tosa/dialect/ops_registration.py b/backends/arm/tosa/dialect/ops_registration.py new file mode 100644 index 00000000000..865eca6b21b --- /dev/null +++ b/backends/arm/tosa/dialect/ops_registration.py @@ -0,0 +1,68 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Callable, Iterable, List, ParamSpec, TypeVar + +from executorch.backends.arm.tosa.dialect.lib import register_tosa_dialect_op + +from executorch.backends.arm.tosa_specification import ( + get_context_spec, + TosaSpecification, +) + +P = ParamSpec("P") +R = TypeVar("R") + +# The list of registered ops are not yet used, except for registration +_tosa_registered_ops: dict[TosaSpecification, list[Callable]] = { + TosaSpecification.create_from_string("TOSA-1.0+FP"): [], + TosaSpecification.create_from_string("TOSA-1.0+INT"): [], +} + +# Mapping to ensure we only register a given function once. +_registered_tosa_ops_by_func: dict[Callable, Callable] = {} + + +def register_tosa_op( + op_schema: str, tosa_specs: Iterable[TosaSpecification] +) -> Callable[[Callable[P, R]], Callable[P, R]]: + """ + Decorator for registering a TOSA operation. + + Parameters: + op_schema : A string that defines the operation schema. + tosa_specs : Iterable of TOSA specification strings, + e.g. ("TOSA-1.0+INT", "TOSA-1.0+FP"). + + The decorated function is registered with the given op_schema by calling + register_tosa_dialect_op(op_schema, func) only once per function. The resulting + callable is then inserted into _tosa_registered_ops for each spec. + """ + + def decorator(func: Callable[P, R]) -> Callable[P, R]: + # Only call register_tosa_dialect_op if the function hasn't been registered yet. + if func not in _registered_tosa_ops_by_func: + op_callable = register_tosa_dialect_op(op_schema, func) + _registered_tosa_ops_by_func[func] = op_callable + else: + op_callable = _registered_tosa_ops_by_func[func] + + # For each TOSA spec, ensure the operation is added only once. + for spec in tosa_specs: + if spec not in _tosa_registered_ops: + raise ValueError(f"TOSA spec {spec} not listed for registrations") + if op_callable not in _tosa_registered_ops[spec]: + _tosa_registered_ops[spec].append(op_callable) + + # return the original function + return func + + return decorator + + +def get_registered_tosa_ops() -> List[Callable]: + tosa_spec = get_context_spec() + return _tosa_registered_ops[tosa_spec]