diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index dc31fb12cde..37cd1256154 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -6,6 +6,8 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe + +import executorch.backends.arm.tosa.dialect # noqa: unused from executorch.backends.arm._passes import ( AddBiasPass, AnnotateChannelsLastDimOrder, diff --git a/backends/arm/tosa/dialect/lib.py b/backends/arm/tosa/dialect/lib.py new file mode 100644 index 00000000000..3c965418c72 --- /dev/null +++ b/backends/arm/tosa/dialect/lib.py @@ -0,0 +1,62 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable + +from executorch.exir.dialects._ops import _BACKEND_OP_LIB, ops as exir_ops +from torch.library import Library, register_fake +from torchgen.model import FunctionSchema + +# create a torch library for the TOSA dialect +# This defines a library to include Backend Dialect Operators in Executorch +tosa_lib = Library("tosa", "DEF") + + +def register_tosa_dialect_op(op_schema, func) -> Callable: + if tosa_lib.ns not in _BACKEND_OP_LIB: + _BACKEND_OP_LIB.append(tosa_lib.ns) + + if "::" in op_schema: + raise ValueError("The schema should not contain a namespace.") + + # Parse the op_schema into a FunctionSchema + func_schema = FunctionSchema.parse(op_schema) + overload_name = func_schema.name.overload_name + if overload_name: + raise ValueError( + "The TOSA dialect does not support overload names in the op schema." + ) + + opname = func_schema.name.name.base + tosa_lib.define(op_schema) + + overload_name = "default" + op_qualified_name = f"{tosa_lib.ns}::{opname}" + + register_fake(op_qualified_name, func, lib=tosa_lib) + + op = getattr(getattr(getattr(exir_ops.backend, tosa_lib.ns), opname), overload_name) + + # For now, since the TOSA operators are only used for lowering and serialization in the backend + # the op doesn't need to be callable. This can be changed in the future if needed to support + # execution of TOSA ops directly. + def not_callable(): + raise RuntimeError("TOSA dialect op is not callable") + + op.__equvalent_callable__ = not_callable + + return op + + +class TosaValueError(ValueError): + def __init__(self, message="A TOSA value error occurred", *args, **kwargs): + super().__init__(message, *args, **kwargs) + self.op = kwargs.get("op", None) + + def __str__(self): + base_message = super().__str__() + if self.op is not None: + return f"{base_message} (TOSA op: {self.op})" + return base_message diff --git a/backends/arm/tosa/dialect/ops_registration.py b/backends/arm/tosa/dialect/ops_registration.py new file mode 100644 index 00000000000..865eca6b21b --- /dev/null +++ b/backends/arm/tosa/dialect/ops_registration.py @@ -0,0 +1,68 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Callable, Iterable, List, ParamSpec, TypeVar + +from executorch.backends.arm.tosa.dialect.lib import register_tosa_dialect_op + +from executorch.backends.arm.tosa_specification import ( + get_context_spec, + TosaSpecification, +) + +P = ParamSpec("P") +R = TypeVar("R") + +# The list of registered ops are not yet used, except for registration +_tosa_registered_ops: dict[TosaSpecification, list[Callable]] = { + TosaSpecification.create_from_string("TOSA-1.0+FP"): [], + TosaSpecification.create_from_string("TOSA-1.0+INT"): [], +} + +# Mapping to ensure we only register a given function once. +_registered_tosa_ops_by_func: dict[Callable, Callable] = {} + + +def register_tosa_op( + op_schema: str, tosa_specs: Iterable[TosaSpecification] +) -> Callable[[Callable[P, R]], Callable[P, R]]: + """ + Decorator for registering a TOSA operation. + + Parameters: + op_schema : A string that defines the operation schema. + tosa_specs : Iterable of TOSA specification strings, + e.g. ("TOSA-1.0+INT", "TOSA-1.0+FP"). + + The decorated function is registered with the given op_schema by calling + register_tosa_dialect_op(op_schema, func) only once per function. The resulting + callable is then inserted into _tosa_registered_ops for each spec. + """ + + def decorator(func: Callable[P, R]) -> Callable[P, R]: + # Only call register_tosa_dialect_op if the function hasn't been registered yet. + if func not in _registered_tosa_ops_by_func: + op_callable = register_tosa_dialect_op(op_schema, func) + _registered_tosa_ops_by_func[func] = op_callable + else: + op_callable = _registered_tosa_ops_by_func[func] + + # For each TOSA spec, ensure the operation is added only once. + for spec in tosa_specs: + if spec not in _tosa_registered_ops: + raise ValueError(f"TOSA spec {spec} not listed for registrations") + if op_callable not in _tosa_registered_ops[spec]: + _tosa_registered_ops[spec].append(op_callable) + + # return the original function + return func + + return decorator + + +def get_registered_tosa_ops() -> List[Callable]: + tosa_spec = get_context_spec() + return _tosa_registered_ops[tosa_spec]