
Commit 011af55

Update base for Update on "Module support for multiple ptd files"
Support multiple PTD files in Module. This change updates the following private variables in Module:

```
std::string data_path                       --> std::unordered_set<std::string> data_files_
std::unique_ptr<DataLoader> data_map_loader --> std::vector<std::unique_ptr<DataLoader>> data_map_loaders_
std::unique_ptr<NamedDataMap> data_map      --> std::vector<std::unique_ptr<NamedDataMap>> named_data_maps_
```

It also introduces a new private variable. When there are multiple NamedDataMaps, they need to be merged into one, for use in `Method`, etc. This merging is not implemented yet.

```
std::unique_ptr<NamedDataMap> merged_data_map_
```

The process of using a PTD file is:

```
std::string file --> wrapped in DataLoader --> wrapped in NamedDataMap
```

At each stage we can have multiple. This diff also introduces a new Module constructor that takes a `std::unordered_set<std::string> named_data_map_paths_`.

Differential Revision: [D82059808](https://our.internmc.facebook.com/intern/diff/D82059808/)

[ghstack-poisoned]
2 parents: 0b78412 + 598ba46

File tree

142 files changed (+4598 / -2272 lines)


.github/workflows/pull.yml

Lines changed: 7 additions & 0 deletions
```diff
@@ -971,6 +971,13 @@ jobs:
       ./cmake-out/backends/vulkan/test/custom_ops/q4gsw_linear
       ./cmake-out/backends/vulkan/test/custom_ops/choose_qparams_per_row

+      # "Classic" Operator tests
+      PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_op.sh --build
+      # TODO(ssjia): figure out how to run custom op tests in CI. Currently, they are
+      # failing due to the libstdc++.so.6 installed with conda not supporting
+      # GLIBCXX_3.4.30. These tests are still run in Meta internal CI.
+      # ./cmake-out/backends/vulkan/test/op_tests/vulkan_sdpa_test
+
       # Run e2e testing for selected operators. More operators will be tested via this
       # route in the future.
       python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*pt2e*"
```

backends/apple/coreml/recipes/coreml_recipe_provider.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -69,6 +69,7 @@ def create_recipe(
                 recipe_type, activation_dtype=torch.float32, **kwargs
             )
         elif recipe_type == CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL:
+            self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao")
             return self._build_torchao_quantized_recipe(
                 recipe_type,
                 weight_dtype=torch.int4,
@@ -77,6 +78,7 @@ def create_recipe(
             )
         elif recipe_type == CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP:
             group_size = kwargs.pop("group_size", 32)
+            self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao")
             return self._build_torchao_quantized_recipe(
                 recipe_type,
                 weight_dtype=torch.int4,
@@ -85,11 +87,14 @@ def create_recipe(
                 **kwargs,
             )
         elif recipe_type == CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL:
+            self._validate_and_set_deployment_target(kwargs, ct.target.iOS16, "torchao")
             return self._build_torchao_quantized_recipe(
                 recipe_type, weight_dtype=torch.int8, is_per_channel=True, **kwargs
             )
         elif recipe_type == CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP:
             group_size = kwargs.pop("group_size", 32)
+            # override minimum_deployment_target to ios18 for torchao (GH issue #13122)
+            self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao")
             return self._build_torchao_quantized_recipe(
                 recipe_type,
                 weight_dtype=torch.int8,
@@ -312,8 +317,6 @@ def _build_torchao_quantized_recipe(
             ao_quantization_configs=[config],
         )

-        # override minimum_deployment_target to ios18 for torchao (GH issue #13122)
-        self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao")
         lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)

         return ExportRecipe(
```

backends/apple/coreml/test/test_coreml_quantizer.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -15,7 +15,7 @@
 )

 from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
-from torch.export import export_for_training
+from torch.export import export
 from torchao.quantization.pt2e.quantize_pt2e import (
     convert_pt2e,
     prepare_pt2e,
@@ -32,9 +32,7 @@ def quantize_and_compare(
     ) -> None:
         assert quantization_type in {"PTQ", "QAT"}

-        pre_autograd_aten_dialect = export_for_training(
-            model, example_inputs, strict=True
-        ).module()
+        pre_autograd_aten_dialect = export(model, example_inputs, strict=True).module()

         quantization_config = LinearQuantizerConfig.from_dict(
             {
```

backends/apple/coreml/test/test_coreml_recipes.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -501,7 +501,7 @@ def test_minimum_deployment_target_validation(self):
             (CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP, ct.target.iOS18, {}),
             (
                 CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL,
-                ct.target.iOS18,
+                ct.target.iOS16,
                 {},
             ),
             (CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP, ct.target.iOS18, {}),
```

backends/apple/mps/test/test_mps_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -206,7 +206,7 @@ def lower_module_and_test_output(

         expected_output = model(*sample_inputs)

-        model = torch.export.export_for_training(
+        model = torch.export.export(
             model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
         ).module()

```

backends/arm/README.md

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -206,14 +206,6 @@ The current TOSA version does not support int64. However, int64 is commonly used
206206
- For quantized models, these transformations will be automatically handled during annotation before the export stage.
207207

208208
List of model specific and optional passes:
209-
- InsertCastForOpsWithInt64InputPass
210-
- Functionality:
211-
- For LLMs such as LLama, some opeartors like aten.embedding have int64 input. In order to lower these operators to TOSA, this pass will insert a casting node that converts the input from int64 to int32.
212-
- Supported Ops:
213-
- aten.embedding.default, aten.slice_copy.Tensor
214-
- Example usage:
215-
- backends/arm/test/models/test_llama.py
216-
217209
- ConvertInt64ConstOpsToInt32Pass
218210
- Functionalities:
219211
- Rewrites constant-producing ops that output int64 to instead output int32, when values are within int32 bounds.
@@ -244,3 +236,16 @@ List of model specific and optional passes:
244236
- Example usage:
245237
- (Functionality 1) backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
246238
- (Functionality 2) backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
239+
240+
- InsertInt32CastsAfterInt64PlaceholdersPass
241+
- Functionalities:
242+
- Inserts an int64 -> int32 cast immediately after each int64 placeholder (graph input).
243+
- Redirects all uses of each int64 placeholder to its int32 cast output.
244+
- Inserts local int32 -> int64 casts at call sites where an operator requires int64 inputs, e.g. `torch.nn.functional.one_hot`
245+
- Pass ordering:
246+
- When used with `ConvertInt64ConstOpsToInt32Pass` and `ConvertInt64OutputOpsToInt32Pass`, run this pass last.
247+
- Rationale: Those passes may cause retracing to re-infer some int64 placeholders as int32. Running this pass last casts only inputs that remain int64, minimizing inserted casts.
248+
- Example usage:
249+
- backends/arm/test/models/test_llama.py
250+
- backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
251+
- backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
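
To make the ordering guidance above concrete, here is a minimal, hypothetical sketch (not part of this commit). It assumes all three passes are importable from `executorch.backends.arm._passes`, as `arm_pass_manager.py` does below, and that each pass instance is callable on a `torch.fx.GraphModule` and returns a `PassResult`:

```python
import torch

# Assumption: these names are exported by executorch.backends.arm._passes,
# matching the imports shown in the arm_pass_manager.py hunk below.
from executorch.backends.arm._passes import (
    ConvertInt64ConstOpsToInt32Pass,
    ConvertInt64OutputOpsToInt32Pass,
    InsertInt32CastsAfterInt64PlaceholdersPass,
)


def lower_int64(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Apply the int64-lowering passes in the recommended order."""
    for pass_cls in (
        ConvertInt64ConstOpsToInt32Pass,
        ConvertInt64OutputOpsToInt32Pass,
        # Run last: the two passes above may retrace and re-infer some int64
        # placeholders as int32, so this pass casts only inputs still int64.
        InsertInt32CastsAfterInt64PlaceholdersPass,
    ):
        gm = pass_cls()(gm).graph_module
    return gm
```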

backends/arm/_passes/__init__.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -75,8 +75,8 @@
 from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass  # noqa
 from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass  # noqa
 from .fuse_quantized_activation_pass import FuseQuantizedActivationPass  # noqa
-from .insert_int64_input_cast_pass import (  # noqa
-    InsertCastForOpsWithInt64InputPass,
+from .insert_int32_casts_after_int64_placeholders import (  # noqa
+    InsertInt32CastsAfterInt64PlaceholdersPass,
 )
 from .insert_rescales_pass import InsertRescalePass  # noqa
 from .insert_table_ops import InsertTableOpsPass  # noqa
```

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -76,7 +76,7 @@
     FuseConstantArgsPass,
     FuseEqualPlaceholdersPass,
     FuseQuantizedActivationPass,
-    InsertCastForOpsWithInt64InputPass,
+    InsertInt32CastsAfterInt64PlaceholdersPass,
     InsertRescalePass,
     InsertTableOpsPass,
     MatchArgDtypePass,
@@ -277,7 +277,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         )  # ConvertInt64ConstOpsToInt32Pass requires this pass to remove the assertion in Graph
         self.add_pass(ConvertInt64ConstOpsToInt32Pass())
         self.add_pass(ConvertInt64OutputOpsToInt32Pass())
-        self.add_pass(InsertCastForOpsWithInt64InputPass())
+        self.add_pass(InsertInt32CastsAfterInt64PlaceholdersPass())
         self.add_pass(DecomposeEmbeddingPass())
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoundPass())
```
backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py

Lines changed: 122 additions & 0 deletions (new file)

```python
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import logging

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import EdgeOpOverload, ExportPass, PassResult
from torch._subclasses.fake_tensor import FakeTensor


logger = logging.getLogger(__name__)


class InsertInt32CastsAfterInt64PlaceholdersPass(ExportPass):
    """
    Insert an int64->int32 cast after each int64 placeholder.

    Note: Overflow checks are not applied in this pass. It is the user's
    responsibility to ensure that values fit within the int32 range.
    """

    # Ops that require i64 inputs → positions of args to upcast.
    # Key: op overload; Value: zero-based indices of positional args that must be i64.
    I64_INPUT_ARG_POSITIONS = {
        torch.ops.aten.one_hot.default: (0,),
    }

    def _insert_callsite_i32_to_i64_casts(self, graph_module: torch.fx.GraphModule):
        """
        If an operator requires int64 inputs but dtype propagation (via call_operator)
        produced int32, insert a local int32->int64 cast at the call site to satisfy
        PyTorch's operator input validation.
        """
        modified = False
        graph = graph_module.graph
        for node in graph.nodes:
            if node.op != "call_function":
                continue
            if node.target not in self.I64_INPUT_ARG_POSITIONS:
                continue

            with graph.inserting_before(node):
                arg_positions = self.I64_INPUT_ARG_POSITIONS.get(node.target)
                args_list = list(node.args)
                for pos in arg_positions:  # type: ignore[union-attr]
                    input_arg = args_list[pos]
                    to_copy_op = self._get_decomposition(graph)
                    cast_node = graph_module.graph.create_node(
                        "call_function",
                        to_copy_op,
                        (input_arg,),
                        {"dtype": torch.int64},
                    )
                    cast_node.meta["val"] = node.meta["val"].to(torch.int64)
                    args_list[pos] = cast_node
                node.args = tuple(args_list)
                modified = True
        return modified

    def _graph_uses_edge_ops(self, graph: torch.fx.Graph) -> bool:
        for n in graph.nodes:
            if n.op == "call_function":
                if isinstance(n.target, EdgeOpOverload):
                    return True
        return False

    def _get_decomposition(self, graph: torch.fx.Graph):
        if self._graph_uses_edge_ops(graph):
            return exir_ops.edge.dim_order_ops._to_dim_order_copy.default
        else:
            return torch.ops.dim_order_ops._to_dim_order_copy.default

    def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool:
        return isinstance(node_val, FakeTensor) and node_val.dtype == dtype

    def _insert_placeholder_i64_to_i32_casts(self, graph_module: torch.fx.GraphModule):
        modified = False
        graph = graph_module.graph
        for node in graph.nodes:
            if node.op != "placeholder":
                continue
            node_val = node.meta["val"]
            if not self._is_tensor_of_dtype(node_val, torch.int64):
                continue

            to_copy_op = self._get_decomposition(graph)
            with graph.inserting_after(node):
                cast_after = create_node(
                    graph,
                    to_copy_op,
                    args=(node,),
                    kwargs={
                        "dtype": torch.int32,
                    },
                )
                users = [user for user in node.users if user != cast_after]
                for user in users:
                    user.replace_input_with(node, cast_after)
                logger.warning(
                    f"Inserting a casting node {cast_after.name} after {node.name} to cast int64 placeholder"
                    f" to int32 for {node.name} defined in {node.meta.get('stack_trace','[no stack trace found]')}"
                )
                modified = True
        return modified

    def call(self, graph_module: torch.fx.GraphModule):
        modified = False
        modified |= self._insert_placeholder_i64_to_i32_casts(graph_module)
        modified |= self._insert_callsite_i32_to_i64_casts(graph_module)

        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, modified)
```
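
For orientation, a minimal, hypothetical usage sketch (not part of the commit): it exports a toy module whose only input is an int64 tensor and runs the pass standalone. The `Toy` module, the shapes, and the standalone invocation are assumptions for illustration:

```python
import torch

from executorch.backends.arm._passes import InsertInt32CastsAfterInt64PlaceholdersPass


class Toy(torch.nn.Module):
    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        return idx + 1


# Export a graph whose only placeholder is an int64 tensor.
gm = torch.export.export(
    Toy(), (torch.zeros(4, dtype=torch.int64),), strict=True
).module()

# The pass logs a warning for each rewritten placeholder and returns a
# PassResult; the int64 placeholder now feeds an int32 cast.
result = InsertInt32CastsAfterInt64PlaceholdersPass()(gm)
print(result.graph_module.graph)
```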

backends/arm/_passes/insert_int64_input_cast_pass.py

Lines changed: 0 additions & 109 deletions
This file was deleted.
