Skip to content

Commit 2ffaef2

Browse files
authored
Merge branch 'main' into export-D81519735
2 parents 20e4ab8 + 86e61bf commit 2ffaef2

File tree

139 files changed

+4631
-2358
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

139 files changed

+4631
-2358
lines changed

.github/workflows/pull.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,13 @@ jobs:
971971
./cmake-out/backends/vulkan/test/custom_ops/q4gsw_linear
972972
./cmake-out/backends/vulkan/test/custom_ops/choose_qparams_per_row
973973
974+
# "Classic" Operator tests
975+
PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_op.sh --build
976+
# TODO(ssjia): figure out how to run custom op tests in CI. Currently, they are
977+
# failing due to to the libstdc++.so.6 installed with conda not supporting
978+
# GLIBCXX_3.4.30. These tests are still run in Meta internal CI.
979+
# ./cmake-out/backends/vulkan/test/op_tests/vulkan_sdpa_test
980+
974981
# Run e2e testing for selected operators. More operators will be tested via this
975982
# route in the future.
976983
python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*pt2e*"

backends/apple/coreml/recipes/coreml_recipe_provider.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def create_recipe(
6969
recipe_type, activation_dtype=torch.float32, **kwargs
7070
)
7171
elif recipe_type == CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL:
72+
self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao")
7273
return self._build_torchao_quantized_recipe(
7374
recipe_type,
7475
weight_dtype=torch.int4,
@@ -77,6 +78,7 @@ def create_recipe(
7778
)
7879
elif recipe_type == CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP:
7980
group_size = kwargs.pop("group_size", 32)
81+
self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao")
8082
return self._build_torchao_quantized_recipe(
8183
recipe_type,
8284
weight_dtype=torch.int4,
@@ -85,11 +87,14 @@ def create_recipe(
8587
**kwargs,
8688
)
8789
elif recipe_type == CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL:
90+
self._validate_and_set_deployment_target(kwargs, ct.target.iOS16, "torchao")
8891
return self._build_torchao_quantized_recipe(
8992
recipe_type, weight_dtype=torch.int8, is_per_channel=True, **kwargs
9093
)
9194
elif recipe_type == CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP:
9295
group_size = kwargs.pop("group_size", 32)
96+
# override minimum_deployment_target to ios18 for torchao (GH issue #13122)
97+
self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao")
9398
return self._build_torchao_quantized_recipe(
9499
recipe_type,
95100
weight_dtype=torch.int8,
@@ -312,8 +317,6 @@ def _build_torchao_quantized_recipe(
312317
ao_quantization_configs=[config],
313318
)
314319

315-
# override minimum_deployment_target to ios18 for torchao (GH issue #13122)
316-
self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao")
317320
lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
318321

319322
return ExportRecipe(

backends/apple/coreml/test/test_coreml_quantizer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
)
1616

1717
from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
18-
from torch.export import export_for_training
18+
from torch.export import export
1919
from torchao.quantization.pt2e.quantize_pt2e import (
2020
convert_pt2e,
2121
prepare_pt2e,
@@ -32,9 +32,7 @@ def quantize_and_compare(
3232
) -> None:
3333
assert quantization_type in {"PTQ", "QAT"}
3434

35-
pre_autograd_aten_dialect = export_for_training(
36-
model, example_inputs, strict=True
37-
).module()
35+
pre_autograd_aten_dialect = export(model, example_inputs, strict=True).module()
3836

3937
quantization_config = LinearQuantizerConfig.from_dict(
4038
{

backends/apple/coreml/test/test_coreml_recipes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ def test_minimum_deployment_target_validation(self):
501501
(CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP, ct.target.iOS18, {}),
502502
(
503503
CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL,
504-
ct.target.iOS18,
504+
ct.target.iOS16,
505505
{},
506506
),
507507
(CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP, ct.target.iOS18, {}),

backends/apple/mps/test/test_mps_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def lower_module_and_test_output(
206206

207207
expected_output = model(*sample_inputs)
208208

209-
model = torch.export.export_for_training(
209+
model = torch.export.export(
210210
model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
211211
).module()
212212

backends/arm/README.md

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -206,14 +206,6 @@ The current TOSA version does not support int64. However, int64 is commonly used
206206
- For quantized models, these transformations will be automatically handled during annotation before the export stage.
207207

208208
List of model specific and optional passes:
209-
- InsertCastForOpsWithInt64InputPass
210-
- Functionality:
211-
- For LLMs such as LLama, some opeartors like aten.embedding have int64 input. In order to lower these operators to TOSA, this pass will insert a casting node that converts the input from int64 to int32.
212-
- Supported Ops:
213-
- aten.embedding.default, aten.slice_copy.Tensor
214-
- Example usage:
215-
- backends/arm/test/models/test_llama.py
216-
217209
- ConvertInt64ConstOpsToInt32Pass
218210
- Functionalities:
219211
- Rewrites constant-producing ops that output int64 to instead output int32, when values are within int32 bounds.
@@ -244,3 +236,16 @@ List of model specific and optional passes:
244236
- Example usage:
245237
- (Functionality 1) backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
246238
- (Functionality 2) backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
239+
240+
- InsertInt32CastsAfterInt64PlaceholdersPass
241+
- Functionalities:
242+
- Inserts an int64 -> int32 cast immediately after each int64 placeholder (graph input).
243+
- Redirects all uses of each int64 placeholder to its int32 cast output.
244+
- Inserts local int32 -> int64 casts at call sites where an operator requires int64 inputs, e.g. `torch.nn.functional.one_hot`
245+
- Pass ordering:
246+
- When used with `ConvertInt64ConstOpsToInt32Pass` and `ConvertInt64OutputOpsToInt32Pass`, run this pass last.
247+
- Rationale: Those passes may cause retracing to re-infer some int64 placeholders as int32. Running this pass last casts only inputs that remain int64, minimizing inserted casts.
248+
- Example usage:
249+
- backends/arm/test/models/test_llama.py
250+
- backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
251+
- backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py

backends/arm/_passes/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@
7575
from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa
7676
from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa
7777
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
78-
from .insert_int64_input_cast_pass import ( # noqa # noqa
79-
InsertCastForOpsWithInt64InputPass,
78+
from .insert_int32_casts_after_int64_placeholders import ( # noqa
79+
InsertInt32CastsAfterInt64PlaceholdersPass,
8080
)
8181
from .insert_rescales_pass import InsertRescalePass # noqa
8282
from .insert_table_ops import InsertTableOpsPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
FuseConstantArgsPass,
7777
FuseEqualPlaceholdersPass,
7878
FuseQuantizedActivationPass,
79-
InsertCastForOpsWithInt64InputPass,
79+
InsertInt32CastsAfterInt64PlaceholdersPass,
8080
InsertRescalePass,
8181
InsertTableOpsPass,
8282
MatchArgDtypePass,
@@ -277,7 +277,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
277277
) # ConvertInt64ConstOpsToInt32Pass requires this pass to remove the assertation in Graph
278278
self.add_pass(ConvertInt64ConstOpsToInt32Pass())
279279
self.add_pass(ConvertInt64OutputOpsToInt32Pass())
280-
self.add_pass(InsertCastForOpsWithInt64InputPass())
280+
self.add_pass(InsertInt32CastsAfterInt64PlaceholdersPass())
281281
self.add_pass(DecomposeEmbeddingPass())
282282
self.add_pass(DecomposeScaledDotProductAttention())
283283
self.add_pass(DecomposeRoundPass())
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
9+
import logging
10+
11+
import torch
12+
from executorch.backends.arm._passes.arm_pass_utils import create_node
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from executorch.exir.pass_base import EdgeOpOverload, ExportPass, PassResult
15+
from torch._subclasses.fake_tensor import FakeTensor
16+
17+
18+
logger = logging.getLogger(__name__)
19+
20+
21+
class InsertInt32CastsAfterInt64PlaceholdersPass(ExportPass):
22+
"""
23+
Insert an int64->int32 cast after each int64 placeholder.
24+
25+
Note: Overflow checks are not applied in this pass. It is the user's responsibility to ensure that values fit within
26+
the int32 range.
27+
"""
28+
29+
# Ops that require i64 inputs → positions of args to upcast.
30+
# Key: op overload; Value: zero-based indices of positional args that must be i64.
31+
I64_INPUT_ARG_POSITIONS = {
32+
torch.ops.aten.one_hot.default: (0,),
33+
}
34+
35+
def _insert_callsite_i32_to_i64_casts(self, graph_module: torch.fx.GraphModule):
36+
"""
37+
If an operator requires int64 inputs but dtype propagation (via call_operator)
38+
produced int32, insert a local int32→int64 cast at the call site to satisfy
39+
PyTorch's operator input validation.
40+
"""
41+
modified = False
42+
graph = graph_module.graph
43+
for node in graph.nodes:
44+
if node.op != "call_function":
45+
continue
46+
if node.target not in self.I64_INPUT_ARG_POSITIONS:
47+
continue
48+
49+
with graph.inserting_before(node):
50+
arg_positions = self.I64_INPUT_ARG_POSITIONS.get(node.target)
51+
args_list = list(node.args)
52+
for pos in arg_positions: # type: ignore[union-attr]
53+
input_arg = args_list[pos]
54+
to_copy_op = self._get_decomposition(graph)
55+
cast_node = graph_module.graph.create_node(
56+
"call_function",
57+
to_copy_op,
58+
(input_arg,),
59+
{"dtype": torch.int64},
60+
)
61+
cast_node.meta["val"] = node.meta["val"].to(torch.int64)
62+
args_list[pos] = cast_node
63+
node.args = tuple(args_list)
64+
modified = True
65+
return modified
66+
67+
def _graph_uses_edge_ops(self, graph: torch.fx.Graph) -> bool:
68+
for n in graph.nodes:
69+
if n.op == "call_function":
70+
if isinstance(n.target, EdgeOpOverload):
71+
return True
72+
return False
73+
74+
def _get_decomposition(self, graph: torch.fx.Graph):
75+
if self._graph_uses_edge_ops(graph):
76+
return exir_ops.edge.dim_order_ops._to_dim_order_copy.default
77+
else:
78+
return torch.ops.dim_order_ops._to_dim_order_copy.default
79+
80+
def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool:
81+
return isinstance(node_val, FakeTensor) and node_val.dtype == dtype
82+
83+
def _insert_placeholder_i64_to_i32_casts(self, graph_module: torch.fx.GraphModule):
84+
modified = False
85+
graph = graph_module.graph
86+
for node in graph.nodes:
87+
if node.op != "placeholder":
88+
continue
89+
node_val = node.meta["val"]
90+
if not self._is_tensor_of_dtype(node_val, torch.int64):
91+
continue
92+
93+
to_copy_op = self._get_decomposition(graph)
94+
with graph.inserting_after(node):
95+
cast_after = create_node(
96+
graph,
97+
to_copy_op,
98+
args=(node,),
99+
kwargs={
100+
"dtype": torch.int32,
101+
},
102+
)
103+
users = [user for user in node.users if user != cast_after]
104+
for user in users:
105+
user.replace_input_with(node, cast_after)
106+
logger.warning(
107+
f"Inserting a casting node {cast_after.name} after {node.name} to cast int64 placeholder"
108+
f" to int32 for {node.name} defined in {node.meta.get('stack_trace','[no stack trace found]')}"
109+
)
110+
modified = True
111+
return modified
112+
113+
def call(self, graph_module: torch.fx.GraphModule):
114+
modified = False
115+
modified |= self._insert_placeholder_i64_to_i32_casts(graph_module)
116+
modified |= self._insert_callsite_i32_to_i64_casts(graph_module)
117+
118+
if modified:
119+
graph_module.graph.eliminate_dead_code()
120+
graph_module.recompile()
121+
graph_module = super().call(graph_module).graph_module
122+
return PassResult(graph_module, modified)

backends/arm/_passes/insert_int64_input_cast_pass.py

Lines changed: 0 additions & 109 deletions
This file was deleted.

0 commit comments

Comments
 (0)