Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import executorch.backends.arm.tosa.dialect # noqa: unused
from executorch.backends.arm._passes import (
AddBiasPass,
AnnotateChannelsLastDimOrder,
Expand Down
62 changes: 62 additions & 0 deletions backends/arm/tosa/dialect/lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable

from executorch.exir.dialects._ops import _BACKEND_OP_LIB, ops as exir_ops
from torch.library import Library, register_fake
from torchgen.model import FunctionSchema

# create a torch library for the TOSA dialect
# This defines a library to include Backend Dialect Operators in Executorch
tosa_lib = Library("tosa", "DEF")


def register_tosa_dialect_op(op_schema, func) -> Callable:
if tosa_lib.ns not in _BACKEND_OP_LIB:
_BACKEND_OP_LIB.append(tosa_lib.ns)

if "::" in op_schema:
raise ValueError("The schema should not contain a namespace.")

# Parse the op_schema into a FunctionSchema
func_schema = FunctionSchema.parse(op_schema)
overload_name = func_schema.name.overload_name
if overload_name:
raise ValueError(
"The TOSA dialect does not support overload names in the op schema."
)

opname = func_schema.name.name.base
tosa_lib.define(op_schema)

overload_name = "default"
op_qualified_name = f"{tosa_lib.ns}::{opname}"

register_fake(op_qualified_name, func, lib=tosa_lib)

op = getattr(getattr(getattr(exir_ops.backend, tosa_lib.ns), opname), overload_name)

# For now, since the TOSA operators are only used for lowering and serialization in the backend
# the op doesn't need to be callable. This can be changed in the future if needed to support
# execution of TOSA ops directly.
def not_callable():
raise RuntimeError("TOSA dialect op is not callable")

op.__equvalent_callable__ = not_callable

return op


class TosaValueError(ValueError):
def __init__(self, message="A TOSA value error occurred", *args, **kwargs):
super().__init__(message, *args, **kwargs)
self.op = kwargs.get("op", None)

def __str__(self):
base_message = super().__str__()
if self.op is not None:
return f"{base_message} (TOSA op: {self.op})"
return base_message
68 changes: 68 additions & 0 deletions backends/arm/tosa/dialect/ops_registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Callable, Iterable, List, ParamSpec, TypeVar

from executorch.backends.arm.tosa.dialect.lib import register_tosa_dialect_op

from executorch.backends.arm.tosa_specification import (
get_context_spec,
TosaSpecification,
)

P = ParamSpec("P")
R = TypeVar("R")

# The list of registered ops are not yet used, except for registration
_tosa_registered_ops: dict[TosaSpecification, list[Callable]] = {
TosaSpecification.create_from_string("TOSA-1.0+FP"): [],
TosaSpecification.create_from_string("TOSA-1.0+INT"): [],
}

# Mapping to ensure we only register a given function once.
_registered_tosa_ops_by_func: dict[Callable, Callable] = {}


def register_tosa_op(
op_schema: str, tosa_specs: Iterable[TosaSpecification]
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
Decorator for registering a TOSA operation.

Parameters:
op_schema : A string that defines the operation schema.
tosa_specs : Iterable of TOSA specification strings,
e.g. ("TOSA-1.0+INT", "TOSA-1.0+FP").

The decorated function is registered with the given op_schema by calling
register_tosa_dialect_op(op_schema, func) only once per function. The resulting
callable is then inserted into _tosa_registered_ops for each spec.
"""

def decorator(func: Callable[P, R]) -> Callable[P, R]:
# Only call register_tosa_dialect_op if the function hasn't been registered yet.
if func not in _registered_tosa_ops_by_func:
op_callable = register_tosa_dialect_op(op_schema, func)
_registered_tosa_ops_by_func[func] = op_callable
else:
op_callable = _registered_tosa_ops_by_func[func]

# For each TOSA spec, ensure the operation is added only once.
for spec in tosa_specs:
if spec not in _tosa_registered_ops:
raise ValueError(f"TOSA spec {spec} not listed for registrations")
if op_callable not in _tosa_registered_ops[spec]:
_tosa_registered_ops[spec].append(op_callable)

# return the original function
return func

return decorator


def get_registered_tosa_ops() -> List[Callable]:
tosa_spec = get_context_spec()
return _tosa_registered_ops[tosa_spec]
Loading