21 changes: 13 additions & 8 deletions backends/arm/README.md
@@ -206,14 +206,6 @@ The current TOSA version does not support int64. However, int64 is commonly used
- For quantized models, these transformations will be automatically handled during annotation before the export stage.

List of model specific and optional passes:
- InsertCastForOpsWithInt64InputPass
- Functionality:
- For LLMs such as Llama, some operators like aten.embedding have int64 inputs. In order to lower these operators to TOSA, this pass will insert a casting node that converts the input from int64 to int32.
- Supported Ops:
- aten.embedding.default, aten.slice_copy.Tensor
- Example usage:
- backends/arm/test/models/test_llama.py

- ConvertInt64ConstOpsToInt32Pass
- Functionalities:
- Rewrites constant-producing ops that output int64 to instead output int32, when values are within int32 bounds.
@@ -244,3 +236,16 @@ List of model specific and optional passes:
- Example usage:
- (Functionality 1) backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
- (Functionality 2) backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py

- InsertInt32CastsAfterInt64PlaceholdersPass
- Functionalities:
- Inserts an int64 -> int32 cast immediately after each int64 placeholder (graph input).
- Redirects all uses of each int64 placeholder to its int32 cast output.
- Inserts local int32 -> int64 casts at call sites where an operator requires int64 inputs, e.g. `torch.nn.functional.one_hot`
- Pass ordering:
- When used with `ConvertInt64ConstOpsToInt32Pass` and `ConvertInt64OutputOpsToInt32Pass`, run this pass last (see the sketch after this list).
- Rationale: Those passes may cause retracing to re-infer some int64 placeholders as int32. Running this pass last casts only inputs that remain int64, minimizing inserted casts.
- Example usage:
- backends/arm/test/models/test_llama.py
- backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
- backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
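
As a rough sketch, the recommended ordering looks like this when the passes are supplied via the Arm test flow's `transform_passes` argument (mirroring the tests listed above; the surrounding test harness and model are omitted):

```python
from executorch.backends.arm._passes import (
    ConvertInt64ConstOpsToInt32Pass,
    ConvertInt64OutputOpsToInt32Pass,
    InsertInt32CastsAfterInt64PlaceholdersPass,
)

# Order matters: the placeholder-cast pass runs last so it only inserts casts
# for inputs that are still int64 after the two conversion passes have run.
transform_passes = [
    ConvertInt64ConstOpsToInt32Pass(),
    ConvertInt64OutputOpsToInt32Pass(),
    InsertInt32CastsAfterInt64PlaceholdersPass(),
]
```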
4 changes: 2 additions & 2 deletions backends/arm/_passes/__init__.py
@@ -75,8 +75,8 @@
from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa
from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
from .insert_int64_input_cast_pass import ( # noqa # noqa
InsertCastForOpsWithInt64InputPass,
from .insert_int32_casts_after_int64_placeholders import ( # noqa
InsertInt32CastsAfterInt64PlaceholdersPass,
)
from .insert_rescales_pass import InsertRescalePass # noqa
from .insert_table_ops import InsertTableOpsPass # noqa
4 changes: 2 additions & 2 deletions backends/arm/_passes/arm_pass_manager.py
@@ -76,7 +76,7 @@
FuseConstantArgsPass,
FuseEqualPlaceholdersPass,
FuseQuantizedActivationPass,
InsertCastForOpsWithInt64InputPass,
InsertInt32CastsAfterInt64PlaceholdersPass,
InsertRescalePass,
InsertTableOpsPass,
MatchArgDtypePass,
@@ -277,7 +277,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
) # ConvertInt64ConstOpsToInt32Pass requires this pass to remove the assertion in Graph
self.add_pass(ConvertInt64ConstOpsToInt32Pass())
self.add_pass(ConvertInt64OutputOpsToInt32Pass())
self.add_pass(InsertCastForOpsWithInt64InputPass())
self.add_pass(InsertInt32CastsAfterInt64PlaceholdersPass())
self.add_pass(DecomposeEmbeddingPass())
self.add_pass(DecomposeScaledDotProductAttention())
self.add_pass(DecomposeRoundPass())
122 changes: 122 additions & 0 deletions backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py
@@ -0,0 +1,122 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import logging

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import EdgeOpOverload, ExportPass, PassResult
from torch._subclasses.fake_tensor import FakeTensor


logger = logging.getLogger(__name__)


class InsertInt32CastsAfterInt64PlaceholdersPass(ExportPass):
    """
    Insert an int64->int32 cast after each int64 placeholder.

    Note: Overflow checks are not applied in this pass. It is the user's responsibility to ensure that values fit within
    the int32 range.
    """

    # Ops that require i64 inputs → positions of args to upcast.
    # Key: op overload; Value: zero-based indices of positional args that must be i64.
    I64_INPUT_ARG_POSITIONS = {
        torch.ops.aten.one_hot.default: (0,),
    }

    def _insert_callsite_i32_to_i64_casts(self, graph_module: torch.fx.GraphModule):
        """
        If an operator requires int64 inputs but dtype propagation (via call_operator)
        produced int32, insert a local int32→int64 cast at the call site to satisfy
        PyTorch's operator input validation.
        """
        modified = False
        graph = graph_module.graph
        for node in graph.nodes:
            if node.op != "call_function":
                continue
            if node.target not in self.I64_INPUT_ARG_POSITIONS:
                continue

            with graph.inserting_before(node):
                arg_positions = self.I64_INPUT_ARG_POSITIONS.get(node.target)
                args_list = list(node.args)
                for pos in arg_positions:  # type: ignore[union-attr]
                    input_arg = args_list[pos]
                    to_copy_op = self._get_decomposition(graph)
                    cast_node = graph_module.graph.create_node(
                        "call_function",
                        to_copy_op,
                        (input_arg,),
                        {"dtype": torch.int64},
                    )
                    cast_node.meta["val"] = node.meta["val"].to(torch.int64)
                    args_list[pos] = cast_node
                node.args = tuple(args_list)
                modified = True
        return modified

    def _graph_uses_edge_ops(self, graph: torch.fx.Graph) -> bool:
        for n in graph.nodes:
            if n.op == "call_function":
                if isinstance(n.target, EdgeOpOverload):
                    return True
        return False

    def _get_decomposition(self, graph: torch.fx.Graph):
        if self._graph_uses_edge_ops(graph):
            return exir_ops.edge.dim_order_ops._to_dim_order_copy.default
        else:
            return torch.ops.dim_order_ops._to_dim_order_copy.default

    def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool:
        return isinstance(node_val, FakeTensor) and node_val.dtype == dtype

    def _insert_placeholder_i64_to_i32_casts(self, graph_module: torch.fx.GraphModule):
        modified = False
        graph = graph_module.graph
        for node in graph.nodes:
            if node.op != "placeholder":
                continue
            node_val = node.meta["val"]
            if not self._is_tensor_of_dtype(node_val, torch.int64):
                continue

            to_copy_op = self._get_decomposition(graph)
            with graph.inserting_after(node):
                cast_after = create_node(
                    graph,
                    to_copy_op,
                    args=(node,),
                    kwargs={
                        "dtype": torch.int32,
                    },
                )
                users = [user for user in node.users if user != cast_after]
                for user in users:
                    user.replace_input_with(node, cast_after)
            logger.warning(
                f"Inserting a casting node {cast_after.name} after {node.name} to cast int64 placeholder"
                f" to int32 for {node.name} defined in {node.meta.get('stack_trace','[no stack trace found]')}"
            )
            modified = True
        return modified

    def call(self, graph_module: torch.fx.GraphModule):
        modified = False
        modified |= self._insert_placeholder_i64_to_i32_casts(graph_module)
        modified |= self._insert_callsite_i32_to_i64_casts(graph_module)

        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, modified)
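
For context, a small hypothetical module that exercises both cast directions handled by this pass: the int64 `indices` placeholder gets an int64 -> int32 cast inserted after it, while `torch.nn.functional.one_hot`, which requires int64 indices, gets a local int32 -> int64 upcast at its call site. The module name, shapes, and export call below are illustrative only:

```python
import torch


class OneHotModule(torch.nn.Module):
    """Hypothetical module with an int64 graph input feeding one_hot."""

    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        # `indices` arrives as an int64 placeholder in the exported graph.
        return torch.nn.functional.one_hot(indices, num_classes=16).to(torch.float32)


example_indices = torch.randint(0, 16, (8,), dtype=torch.int64)
exported = torch.export.export(OneHotModule(), (example_indices,))
# After this pass runs on the exported graph module, users of `indices` consume
# an int32 cast of it, and one_hot receives a local int32 -> int64 upcast.
```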
109 changes: 0 additions & 109 deletions backends/arm/_passes/insert_int64_input_cast_pass.py

This file was deleted.

@@ -11,7 +11,7 @@
from executorch.backends.arm._passes import (
ConvertInt64ConstOpsToInt32Pass,
ConvertInt64OutputOpsToInt32Pass,
InsertCastForOpsWithInt64InputPass,
InsertInt32CastsAfterInt64PlaceholdersPass,
)

from executorch.backends.arm.test import common
@@ -33,10 +33,9 @@ class TestCLIPTextModelWithProjection(unittest.TestCase):
# for that is some assert ops are removed by passes in the
# .to_executorch step, i.e. after Arm partitioner.
ops_after_partitioner = {
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 3,
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
"executorch_exir_dialects_edge__ops_aten_argmax_default": 1,
"torch.ops.higher_order.executorch_call_delegate": 1,
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
"torch.ops.higher_order.executorch_call_delegate": 2,
}

def _prepare_inputs(
@@ -71,9 +70,9 @@ def test_CLIPTextModelWithProjection_tosa_FP(self):
example_inputs=text_encoder_model_inputs,
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
transform_passes=[
InsertCastForOpsWithInt64InputPass(),
ConvertInt64ConstOpsToInt32Pass(),
ConvertInt64OutputOpsToInt32Pass(),
InsertInt32CastsAfterInt64PlaceholdersPass(),
],
)
.export()
@@ -22,18 +22,22 @@ class TestSD3Transformer2DModel(unittest.TestCase):
SD3Transformer2DModel is the transformer model used by Stable Diffusion 3.5 Medium
"""

# Adjust nbr below as we increase op support. Note: most of the delegates
# calls are directly consecutive to each other in the .pte. The reason
# for that is some assert ops are removed by passes in the
# .to_executorch step, i.e. after Arm partitioner.
ops_after_partitioner = {
# Adjust nbr below as we increase op support.
ops_after_partitioner_FP = {
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
"executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 1,
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
"torch.ops.higher_order.executorch_call_delegate": 1,
}

ops_after_partitioner_INT = {
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
"torch.ops.higher_order.executorch_call_delegate": 2,
}

def _prepare_inputs(
self,
batch_size=2,
@@ -102,7 +106,7 @@ def test_SD3Transformer2DModel_tosa_FP(self):
)
.export()
.to_edge_transform_and_lower()
.check_count(self.ops_after_partitioner)
.check_count(self.ops_after_partitioner_FP)
.to_executorch()
.run_method_and_compare_outputs(
inputs=sd35_transformer2D_model_inputs,
@@ -125,7 +129,7 @@ def test_SD3Transformer2DModel_tosa_INT(self):
.quantize()
.export()
.to_edge_transform_and_lower()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.check_count(self.ops_after_partitioner_INT)
.to_executorch()
.run_method_and_compare_outputs(
inputs=sd35_transformer2D_model_inputs,