Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions backends/arm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,9 @@ It is possible to control the compilation flow to aid in development and debug o
Configuration of the EthosUBackend export flow is controlled by CompileSpec information (essentially used as compilation flags) to determine which of these outputs is produced. In particular this allows for use of the tosa_reference_model to run intermediate output to check for correctness and quantization accuracy without a full loop via hardware implementation.

As this is in active development see the EthosUBackend for accurate information on [compilation flags](https://github.com/pytorch/executorch/blob/29f6dc9353e90951ed3fae3c57ae416de0520067/backends/arm/arm_backend.py#L319-L324)

## Model specific and optional passes
The current TOSA version does not support int64. For LLMs, for example Llama, aten.embedding is often the first operator, and it requires int64 indices.
In order to lower this to TOSA, an int64->int32 cast needs to be injected. This pass needs to run very early in the lowering process and can be passed in to the to_edge_transform_and_lower() function call as an optional parameter. See the example in: backends/arm/test/models/test_llama.py.
By doing this, aten.embedding will be decomposed into aten.index_select, which can handle int32 indices.
Note that this additional step is only needed for pure float models. With quantization this is automatically handled during annotation before the export stage.
4 changes: 4 additions & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .convert_to_clamp import ConvertToClampPass # noqa
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
from .decompose_div_pass import DecomposeDivPass # noqa
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
from .decompose_gelu_pass import DecomposeGeluPass # noqa
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
Expand All @@ -46,6 +47,9 @@
from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa
from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
from .insert_int64_input_cast_pass import ( # noqa # noqa
InsertCastForOpsWithInt64InputPass,
)
from .insert_rescales_pass import InsertRescalePass # noqa
from .insert_table_ops import InsertTableOpsPass # noqa
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
Expand Down
6 changes: 5 additions & 1 deletion backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from executorch.backends.arm._passes import (
AnnotateChannelsLastDimOrder,
AnnotateDecomposedMatmulPass,
Expand All @@ -26,6 +25,7 @@
ConvertToClampPass,
DecomposeCosineSimilarityPass,
DecomposeDivPass,
DecomposeEmbeddingPass,
DecomposeGeluPass,
DecomposeGroupNormPass,
DecomposeLayerNormPass,
Expand All @@ -46,6 +46,7 @@
FuseConstantArgsPass,
FuseEqualPlaceholdersPass,
FuseQuantizedActivationPass,
InsertCastForOpsWithInt64InputPass,
InsertRescalePass,
InsertTableOpsPass,
MatchArgRanksPass,
Expand Down Expand Up @@ -139,6 +140,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(DecomposeSqrtPass())
self.add_pass(ConvertIntPowToMuls())
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
self.add_pass(DecomposeEmbeddingPass())
self.add_pass(FuseQuantizedActivationPass())
self.add_pass(RemoveGetItemPass())
self.add_pass(ConvertSplitToSlicePass())
Expand Down Expand Up @@ -211,6 +213,8 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
)

def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(InsertCastForOpsWithInt64InputPass())
self.add_pass(DecomposeEmbeddingPass())
self.add_pass(DecomposeScaledDotProductAttention())
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
self.add_pass(ScalarsToAttributePass())
Expand Down
120 changes: 120 additions & 0 deletions backends/arm/_passes/decompose_embedding_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import logging
from math import prod

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from .arm_pass_utils import create_node, get_first_fake_tensor

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class DecomposeEmbeddingPass(ExportPass):
    """
    This pass decomposes embedding into index_select.

    Example:
          o = embedding(w, i)
    Becomes:
          i = view_copy(i)  # flatten indices to 1-D
          o = index_select(w, 0, i)
          o = view_copy(o)  # reshape output back to indices.shape + [w.shape[1]]
    Note:
         i = indices is expected to be int32 before this pass
    """

    aten_ops = (torch.ops.aten.embedding.default,)
    edge_ops = (exir_ops.edge.aten.embedding.default,)

    def get_decomposition(self, op):
        """Return the ``(view_copy, index_select)`` op pair matching the dialect of ``op``.

        Raises:
            RuntimeError: if ``op`` is neither the aten nor the edge embedding op.
        """
        if op in self.aten_ops:
            return (
                torch.ops.aten.view_copy.default,
                torch.ops.aten.index_select.default,
            )

        if op in self.edge_ops:
            return (
                exir_ops.edge.aten.view_copy.default,
                exir_ops.edge.aten.index_select.default,
            )
        raise RuntimeError(
            f"[{self.__class__.__name__}] Can't get decomposition for op {op}"
        )

    def call(self, graph_module):
        """Replace every embedding node in ``graph_module`` with its decomposition.

        Returns:
            PassResult holding the (possibly retraced) graph module and a flag
            telling whether any embedding node was rewritten.

        Raises:
            RuntimeError: if the shape recorded for an embedding node differs
                from ``indices.shape + [weights.shape[1]]``.
        """
        graph = graph_module.graph
        modified_graph = False
        embedding_ops = self.aten_ops + self.edge_ops

        # Iterate over a snapshot: matching nodes are erased inside the loop.
        for node in list(graph.nodes):
            if node.op != "call_function":
                continue
            if node.target not in embedding_ops:
                continue

            weights = node.args[0]
            indices = node.args[1]

            weights_shape = get_first_fake_tensor(weights).shape
            indices_shape = get_first_fake_tensor(indices).shape

            output_shape = torch.Size(list(indices_shape) + [weights_shape[1]])
            actual_shape = get_first_fake_tensor(node).shape
            if output_shape != actual_shape:
                raise RuntimeError(
                    f"[{self.__class__.__name__}] Unexpected output shape mismatch "
                    f"{output_shape} != {actual_shape}"
                )

            view_copy_op, index_select_op = self.get_decomposition(node.target)

            # All three replacement nodes are created, in order, right before
            # the embedding node they replace.
            with graph.inserting_before(node):
                # index_select takes a 1-D index, so flatten the indices first.
                flattened_indices = create_node(
                    graph=graph,
                    op_target=view_copy_op,
                    args=(indices, [prod(list(indices_shape))]),
                )
                index_select = create_node(
                    graph=graph,
                    op_target=index_select_op,
                    args=(weights, 0, flattened_indices),
                )
                # Restore the shape the original embedding output had.
                restored_output = create_node(
                    graph=graph,
                    op_target=view_copy_op,
                    args=(index_select, output_shape),
                )

            node.replace_all_uses_with(restored_output)
            graph.erase_node(node)

            modified_graph = True

        if modified_graph:
            graph.eliminate_dead_code()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, modified_graph)
94 changes: 94 additions & 0 deletions backends/arm/_passes/insert_int64_input_cast_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import logging

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from .arm_pass_utils import create_node, get_first_fake_tensor

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class InsertCastForOpsWithInt64InputPass(ExportPass):
    """
    Insert a ``_to_copy`` cast to int32 in front of the indices input of
    aten.embedding.

    The cast is only inserted when the embedding table has fewer rows than the
    int32 maximum. This pass does not inspect the current dtype of the indices.
    """

    aten_ops = (torch.ops.aten.embedding.default,)
    edge_ops = (exir_ops.edge.aten.embedding.default,)

    def get_decomposition(self, op):
        """Return the ``_to_copy`` op in the same dialect (edge/aten) as ``op``.

        Raises:
            RuntimeError: if ``op`` is not one of the handled embedding ops.
        """
        if op in self.edge_ops:
            return exir_ops.edge.aten._to_copy.default

        if op in self.aten_ops:
            return torch.ops.aten._to_copy.default

        raise RuntimeError(
            f"[{self.__class__.__name__}] Can't get decomposition for op {op}"
        )

    def _check_aten_embedding_within_int32(self, weights, indices, node: torch.fx.Node):
        """Return True if the indices of ``node`` can be held in int32.

        ``indices`` is accepted for signature compatibility but is not read;
        only the first dimension of ``weights`` (the vocabulary size) matters.
        """
        weights_shape = get_first_fake_tensor(weights).shape
        vocab_size = weights_shape[0]

        # Essentially output = weight[indices] which means 0 <= indices[i] < vocab_size
        # So should be good if vocab size or number embeddings is below max int32
        if vocab_size >= torch.iinfo(torch.int32).max:
            logger.warning(
                "[%s] has size (%s) that exceeds int32 limit, "
                "so aten.embedding will not be lowered to TOSA.",
                node.name,
                vocab_size,
            )
            return False

        return True

    def call(self, graph_module):
        """Insert the int32 cast before each eligible embedding node.

        Returns:
            PassResult holding the (possibly retraced) graph module and a flag
            telling whether at least one cast was inserted.
        """
        graph = graph_module.graph
        modified_graph = False
        handled_ops = self.aten_ops + self.edge_ops

        # Iterate over a snapshot since new nodes are inserted inside the loop.
        for node in list(graph.nodes):
            if node.op != "call_function":
                continue
            if node.target not in handled_ops:
                continue

            weights = node.args[0]
            indices = node.args[1]

            # Every handled op is an embedding, so the int32 range check on
            # the vocabulary size is the only eligibility criterion.
            if not self._check_aten_embedding_within_int32(weights, indices, node):
                continue

            to_copy_op = self.get_decomposition(node.target)
            with graph.inserting_before(node):
                cast_before = create_node(
                    graph,
                    to_copy_op,
                    args=(indices,),
                    kwargs={
                        "dtype": torch.int32,
                        "memory_format": torch.preserve_format,
                    },
                )
                node.replace_input_with(indices, cast_before)

            modified_graph = True

        if modified_graph:
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module
        # Report the real outcome instead of unconditionally claiming a change.
        return PassResult(graph_module, modified_graph)
2 changes: 2 additions & 0 deletions backends/arm/operator_support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

from . import ( # noqa
convolution_support,
embedding_support,
ethos_u55_support,
index_select_support,
minmax_support,
pool_2d_support,
reduce_sum_support,
Expand Down
47 changes: 47 additions & 0 deletions backends/arm/operator_support/embedding_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch

import torch.fx as fx
from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops


@register_tosa_support_check
class EmbeddingSupported(SupportedTOSAOperatorCheck):
    """Support check for edge aten.embedding: accepted only with int32 indices."""

    targets = [exir_ops.edge.aten.embedding.default]

    tosa_specs = [
        TosaSpecification.create_from_string(spec_name)
        for spec_name in (
            "TOSA-0.80+BI",
            "TOSA-0.80+MI",
            "TOSA-1.0+INT",
            "TOSA-1.0+FP",
        )
    ]

    def is_node_tosa_supported(
        self, node: fx.Node, tosa_spec: TosaSpecification
    ) -> bool:  # type: ignore[override, misc]
        """Return True when the indices input of ``node`` is int32.

        A rejection is reported through ``self.reporter`` for any other dtype.
        """
        # Note aten.embedding.default requires int64 indices and TOSA does not support it.
        # Int32 indices here for aten.embedding.default is ok since it will be decomposed
        # into ops that can handle it.
        input_nodes = node.all_input_nodes
        assert len(input_nodes) == 2, "Number of inputs to aten.embedding is not 2"

        index_dtype = input_nodes[1].meta["val"].dtype
        if index_dtype == torch.int32:
            return True

        self.reporter.report_reject(
            node,
            f"Indices dtype {index_dtype} is not supported in {node.target}.",
        )
        return False
50 changes: 50 additions & 0 deletions backends/arm/operator_support/index_select_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.fx as fx
from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops


@register_tosa_support_check
class IndexSelectSupported(SupportedTOSAOperatorCheck):
    """Support check for edge aten.index_select.

    Accepted when the second input node is int32 and the first input node has
    shape (N, C) or (1, N, C).
    """

    targets = [exir_ops.edge.aten.index_select.default]

    tosa_specs = [
        TosaSpecification.create_from_string(spec_name)
        for spec_name in (
            "TOSA-0.80+BI",
            "TOSA-0.80+MI",
            "TOSA-1.0+INT",
            "TOSA-1.0+FP",
        )
    ]

    def is_node_tosa_supported(
        self, node: fx.Node, tosa_spec: TosaSpecification
    ) -> bool:  # type: ignore[override, misc]
        """Return True if ``node`` passes the dtype and shape constraints.

        Each failed constraint is reported through ``self.reporter``.
        """
        input_nodes = node.all_input_nodes
        source_shape = input_nodes[0].meta["val"].shape
        index_dtype = input_nodes[1].meta["val"].dtype

        if index_dtype != torch.int32:
            self.reporter.report_reject(
                node,
                f"Indices dtype {index_dtype} is not supported in {node.target}.",
            )
            return False

        rank = len(source_shape)
        # Rank 2 is accepted as-is; rank 3 only with a leading dimension of 1.
        if rank == 2 or (rank == 3 and source_shape[0] == 1):
            return True

        self.reporter.report_reject(
            node, f"{node.target} with weights shape {source_shape} not supported."
        )
        return False
Loading
Loading