Commit d80bffa
Update base for Update on "Arm backend: Add INT16 support to rescale operation"
Add INT16 support for RequantizeNode rescale operations in the ExecuTorch Arm backend. This follows the pattern established for the linear, mul, sigmoid, tanh, slice, view/transpose, cat, and FCNode operations, extending int16 support to RequantizeNode rescale operations.

Changes:
- Add INT16 dtype validation support in op_rescale.py
- Enable rescale operations for the 16A8W quantization configuration

The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency. RequantizeNode rescale operations are essential for proper quantization scaling in the 16A8W pipeline.

Differential Revision: [D80513725](https://our.internmc.facebook.com/intern/diff/D80513725/)

cc digantdesai freddan80 per zingo oscarandersson8218

[ghstack-poisoned]
2 parents c843848 + 4f414d7 commit d80bffa
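
op_rescale.py itself is not among the files shown below, so as a rough orientation, here is a minimal sketch of the kind of dtype-validation change the message describes. The helper name, dtype tuple, and serializer import are assumptions for illustration, not code from this commit:

```python
# Illustrative sketch only -- not the actual op_rescale.py diff.
import serializer.tosa_serializer as ts  # assumption: TOSA serializer package used by the Arm backend

# INT16 joins INT8/INT32 so 16A8W graphs can rescale 16-bit activations.
SUPPORTED_RESCALE_DTYPES = (ts.DType.INT8, ts.DType.INT16, ts.DType.INT32)


def validate_rescale_dtype(dtype) -> None:  # hypothetical helper name
    if dtype not in SUPPORTED_RESCALE_DTYPES:
        raise ValueError(f"rescale: unsupported dtype {dtype}")
```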

58 files changed (+2,314 −1,757 lines)


.github/workflows/pull.yml

Lines changed: 7 additions & 0 deletions

@@ -971,6 +971,13 @@ jobs:
         ./cmake-out/backends/vulkan/test/custom_ops/q4gsw_linear
         ./cmake-out/backends/vulkan/test/custom_ops/choose_qparams_per_row

+        # "Classic" Operator tests
+        PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_op.sh --build
+        # TODO(ssjia): figure out how to run custom op tests in CI. Currently, they are
+        # failing due to the libstdc++.so.6 installed with conda not supporting
+        # GLIBCXX_3.4.30. These tests are still run in Meta internal CI.
+        # ./cmake-out/backends/vulkan/test/op_tests/vulkan_sdpa_test
+
         # Run e2e testing for selected operators. More operators will be tested via this
         # route in the future.
         python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*pt2e*"

backends/arm/README.md

Lines changed: 13 additions & 8 deletions

@@ -206,14 +206,6 @@ The current TOSA version does not support int64. However, int64 is commonly used
 - For quantized models, these transformations will be automatically handled during annotation before the export stage.

 List of model specific and optional passes:
-- InsertCastForOpsWithInt64InputPass
-    - Functionality:
-        - For LLMs such as Llama, some operators like aten.embedding have int64 input. In order to lower these operators to TOSA, this pass inserts a casting node that converts the input from int64 to int32.
-    - Supported Ops:
-        - aten.embedding.default, aten.slice_copy.Tensor
-    - Example usage:
-        - backends/arm/test/models/test_llama.py
-
 - ConvertInt64ConstOpsToInt32Pass
     - Functionalities:
         - Rewrites constant-producing ops that output int64 to instead output int32, when values are within int32 bounds.
@@ -244,3 +236,16 @@ List of model specific and optional passes:
     - Example usage:
         - (Functionality 1) backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
         - (Functionality 2) backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
+
+- InsertInt32CastsAfterInt64PlaceholdersPass
+    - Functionalities:
+        - Inserts an int64 -> int32 cast immediately after each int64 placeholder (graph input).
+        - Redirects all uses of each int64 placeholder to its int32 cast output.
+        - Inserts local int32 -> int64 casts at call sites where an operator requires int64 inputs, e.g. `torch.nn.functional.one_hot`.
+    - Pass ordering:
+        - When used with `ConvertInt64ConstOpsToInt32Pass` and `ConvertInt64OutputOpsToInt32Pass`, run this pass last.
+        - Rationale: those passes may cause retracing to re-infer some int64 placeholders as int32. Running this pass last casts only inputs that remain int64, minimizing inserted casts. (A usage sketch follows this diff.)
+    - Example usage:
+        - backends/arm/test/models/test_llama.py
+        - backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
+        - backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
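
The ordering guidance above, expressed as code. The pass names are from this commit, and the wiring mirrors the test updated later in this diff (test_CLIPTextModelWithProjection.py):

```python
from executorch.backends.arm._passes import (
    ConvertInt64ConstOpsToInt32Pass,
    ConvertInt64OutputOpsToInt32Pass,
    InsertInt32CastsAfterInt64PlaceholdersPass,
)

# Run the placeholder-cast pass last: the two conversion passes may retrace
# the graph and re-infer some int64 placeholders as int32, so running last
# casts only the inputs that remain int64.
transform_passes = [
    ConvertInt64ConstOpsToInt32Pass(),
    ConvertInt64OutputOpsToInt32Pass(),
    InsertInt32CastsAfterInt64PlaceholdersPass(),
]
```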

backends/arm/_passes/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -75,8 +75,8 @@
 from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass  # noqa
 from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass  # noqa
 from .fuse_quantized_activation_pass import FuseQuantizedActivationPass  # noqa
-from .insert_int64_input_cast_pass import (  # noqa # noqa
-    InsertCastForOpsWithInt64InputPass,
+from .insert_int32_casts_after_int64_placeholders import (  # noqa
+    InsertInt32CastsAfterInt64PlaceholdersPass,
 )
 from .insert_rescales_pass import InsertRescalePass  # noqa
 from .insert_table_ops import InsertTableOpsPass  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 2 deletions

@@ -76,7 +76,7 @@
     FuseConstantArgsPass,
     FuseEqualPlaceholdersPass,
     FuseQuantizedActivationPass,
-    InsertCastForOpsWithInt64InputPass,
+    InsertInt32CastsAfterInt64PlaceholdersPass,
     InsertRescalePass,
     InsertTableOpsPass,
     MatchArgDtypePass,
@@ -277,7 +277,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         )  # ConvertInt64ConstOpsToInt32Pass requires this pass to remove the assertion in the graph
         self.add_pass(ConvertInt64ConstOpsToInt32Pass())
         self.add_pass(ConvertInt64OutputOpsToInt32Pass())
-        self.add_pass(InsertCastForOpsWithInt64InputPass())
+        self.add_pass(InsertInt32CastsAfterInt64PlaceholdersPass())
         self.add_pass(DecomposeEmbeddingPass())
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoundPass())
backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py

Lines changed: 122 additions & 0 deletions

@@ -0,0 +1,122 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+
+import logging
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import EdgeOpOverload, ExportPass, PassResult
+from torch._subclasses.fake_tensor import FakeTensor
+
+
+logger = logging.getLogger(__name__)
+
+
+class InsertInt32CastsAfterInt64PlaceholdersPass(ExportPass):
+    """
+    Insert an int64->int32 cast after each int64 placeholder.
+
+    Note: Overflow checks are not applied in this pass. It is the user's responsibility to ensure that values fit within
+    the int32 range.
+    """
+
+    # Ops that require i64 inputs → positions of args to upcast.
+    # Key: op overload; Value: zero-based indices of positional args that must be i64.
+    I64_INPUT_ARG_POSITIONS = {
+        torch.ops.aten.one_hot.default: (0,),
+    }
+
+    def _insert_callsite_i32_to_i64_casts(self, graph_module: torch.fx.GraphModule):
+        """
+        If an operator requires int64 inputs but dtype propagation (via call_operator)
+        produced int32, insert a local int32→int64 cast at the call site to satisfy
+        PyTorch's operator input validation.
+        """
+        modified = False
+        graph = graph_module.graph
+        for node in graph.nodes:
+            if node.op != "call_function":
+                continue
+            if node.target not in self.I64_INPUT_ARG_POSITIONS:
+                continue
+
+            with graph.inserting_before(node):
+                arg_positions = self.I64_INPUT_ARG_POSITIONS.get(node.target)
+                args_list = list(node.args)
+                for pos in arg_positions:  # type: ignore[union-attr]
+                    input_arg = args_list[pos]
+                    to_copy_op = self._get_decomposition(graph)
+                    cast_node = graph_module.graph.create_node(
+                        "call_function",
+                        to_copy_op,
+                        (input_arg,),
+                        {"dtype": torch.int64},
+                    )
+                    cast_node.meta["val"] = node.meta["val"].to(torch.int64)
+                    args_list[pos] = cast_node
+                node.args = tuple(args_list)
+                modified = True
+        return modified
+
+    def _graph_uses_edge_ops(self, graph: torch.fx.Graph) -> bool:
+        for n in graph.nodes:
+            if n.op == "call_function":
+                if isinstance(n.target, EdgeOpOverload):
+                    return True
+        return False
+
+    def _get_decomposition(self, graph: torch.fx.Graph):
+        if self._graph_uses_edge_ops(graph):
+            return exir_ops.edge.dim_order_ops._to_dim_order_copy.default
+        else:
+            return torch.ops.dim_order_ops._to_dim_order_copy.default
+
+    def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool:
+        return isinstance(node_val, FakeTensor) and node_val.dtype == dtype
+
+    def _insert_placeholder_i64_to_i32_casts(self, graph_module: torch.fx.GraphModule):
+        modified = False
+        graph = graph_module.graph
+        for node in graph.nodes:
+            if node.op != "placeholder":
+                continue
+            node_val = node.meta["val"]
+            if not self._is_tensor_of_dtype(node_val, torch.int64):
+                continue
+
+            to_copy_op = self._get_decomposition(graph)
+            with graph.inserting_after(node):
+                cast_after = create_node(
+                    graph,
+                    to_copy_op,
+                    args=(node,),
+                    kwargs={
+                        "dtype": torch.int32,
+                    },
+                )
+                users = [user for user in node.users if user != cast_after]
+                for user in users:
+                    user.replace_input_with(node, cast_after)
+                logger.warning(
+                    f"Inserting a casting node {cast_after.name} after {node.name} to cast int64 placeholder"
+                    f" to int32 for {node.name} defined in {node.meta.get('stack_trace','[no stack trace found]')}"
+                )
+                modified = True
+        return modified
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified = False
+        modified |= self._insert_placeholder_i64_to_i32_casts(graph_module)
+        modified |= self._insert_callsite_i32_to_i64_casts(graph_module)
+
+        if modified:
+            graph_module.graph.eliminate_dead_code()
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module, modified)
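
A standalone sketch of the pass in action; the toy model and `torch.export` usage are ours for illustration, not part of the commit:

```python
import torch
from executorch.backends.arm._passes import InsertInt32CastsAfterInt64PlaceholdersPass


class OneHot(torch.nn.Module):
    def forward(self, idx):
        # aten.one_hot requires an int64 index input.
        return torch.nn.functional.one_hot(idx, num_classes=4)


ep = torch.export.export(OneHot(), (torch.tensor([0, 2, 3]),))
result = InsertInt32CastsAfterInt64PlaceholdersPass()(ep.graph_module)
# The int64 placeholder now feeds an int64 -> int32 cast, and the one_hot
# call site gets a local int32 -> int64 cast back, as described above.
print(result.graph_module.graph)
```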

backends/arm/_passes/insert_int64_input_cast_pass.py

Lines changed: 0 additions & 109 deletions
This file was deleted.

backends/arm/operators/op_abs.py

Lines changed: 3 additions & 1 deletion

@@ -73,7 +73,9 @@ def define_node(
         abs_output = output

         # Do the INT32 Abs
-        tosa_graph.addOperator(
+        self._serialize_operator(
+            node,
+            tosa_graph,
             ts.TosaOp.Op().ABS,
             [
                 rescaled_inputs[0].name,

backends/arm/operators/op_sum.py

Lines changed: 6 additions & 2 deletions

@@ -67,7 +67,9 @@ def define_node(
             dtype=ts.DType.INT32,
         )

-        tosa_graph.addOperator(
+        self._serialize_operator(
+            node,
+            tosa_graph,
             ts.TosaOp.Op().REDUCE_SUM,
             [rescaled_inputs[0].name],
             [intermediate.name],
@@ -111,7 +113,9 @@ def define_node(
         attr = ts.TosaSerializerAttribute()
         attr.ReduceSumAttribute(tensor.dim_order.index(dim))

-        tosa_graph.addOperator(
+        self._serialize_operator(
+            node,
+            tosa_graph,
             ts.TosaOp.Op().REDUCE_SUM,
             [tensor.name],
             [output.name],
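
The op_abs.py and op_sum.py changes above replace direct `tosa_graph.addOperator(...)` calls with `self._serialize_operator(node, tosa_graph, ...)`, threading the originating FX node through serialization. A plausible shape for such a helper, purely as an assumption since the node-visitor base class is not part of this excerpt:

```python
# Hypothetical sketch of a node-visitor helper; the real method lives in the
# Arm backend's visitor base class and may differ.
def _serialize_operator(self, node, tosa_graph, tosa_op, inputs, outputs, attr=None):
    # Record which FX node produced this TOSA operator (e.g. for debug
    # info / location tracking), then delegate to the serializer as before.
    self._record_debug_info(node)  # assumed bookkeeping hook
    tosa_graph.addOperator(tosa_op, inputs, outputs, attr)
```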

backends/arm/test/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -44,6 +44,7 @@ python_library(
         "//executorch/backends/arm:ethosu_partitioner",
         "//executorch/backends/arm/quantizer:lib",
         "//executorch/backends/arm/tosa:mapping",
+        "//executorch/backends/arm:vgf_partitioner",
         "//executorch/devtools/backend_debug:delegation_info",
         "//executorch/exir/backend:operator_support",
         "fbsource//third-party/pypi/tabulate:tabulate",

backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py

Lines changed: 4 additions & 5 deletions

@@ -11,7 +11,7 @@
 from executorch.backends.arm._passes import (
     ConvertInt64ConstOpsToInt32Pass,
     ConvertInt64OutputOpsToInt32Pass,
-    InsertCastForOpsWithInt64InputPass,
+    InsertInt32CastsAfterInt64PlaceholdersPass,
 )

 from executorch.backends.arm.test import common
@@ -33,10 +33,9 @@ class TestCLIPTextModelWithProjection(unittest.TestCase):
     # for that is some assert ops are removed by passes in the
     # .to_executorch step, i.e. after Arm partitioner.
     ops_after_partitioner = {
-        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 3,
-        "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
         "executorch_exir_dialects_edge__ops_aten_argmax_default": 1,
-        "torch.ops.higher_order.executorch_call_delegate": 1,
+        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
+        "torch.ops.higher_order.executorch_call_delegate": 2,
     }

     def _prepare_inputs(
@@ -71,9 +70,9 @@ def test_CLIPTextModelWithProjection_tosa_FP(self):
             example_inputs=text_encoder_model_inputs,
             compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
             transform_passes=[
-                InsertCastForOpsWithInt64InputPass(),
                 ConvertInt64ConstOpsToInt32Pass(),
                 ConvertInt64OutputOpsToInt32Pass(),
+                InsertInt32CastsAfterInt64PlaceholdersPass(),
             ],
         )
         .export()
