Commit a58a23c

Merge branch 'pytorch:main' into pt/py313
2 parents 9f1a282 + 488d761 commit a58a23c

File tree

128 files changed: +5706 additions, -830 deletions


backends/arm/_passes/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -67,6 +67,7 @@
 from .decompose_round_pass import DecomposeRoundPass  # noqa
 from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass  # noqa
 from .decompose_select import DecomposeSelectPass  # noqa
+from .decompose_select_scatter_pass import DecomposeSelectScatterPass  # noqa
 from .decompose_sign_pass import DecomposeSignPass  # noqa
 from .decompose_silu_pass import DecomposeSiluPass  # noqa
 from .decompose_sinh_pass import DecomposeSinhPass  # noqa
@@ -116,5 +117,7 @@
 from .to_tosa_memory_format_pass import ToTosaMemoryFormatPass  # noqa
 from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass  # noqa
 from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass  # noqa
-from .replace_inf_values_pass import ReplaceInfValuesPass  # noqa  # usort: skip
+from .replace_inf_and_limit_values_pass import (  # noqa  # usort: skip
+    ReplaceInfAndLimitValuesPass,
+)
 from .arm_pass_manager import ArmPassManager  # noqa  # usort: skip

backends/arm/_passes/arm_pass_manager.py

Lines changed: 9 additions & 7 deletions
@@ -70,6 +70,7 @@
     DecomposeRoundPass,
     DecomposeScaledDotProductAttentionPass,
     DecomposeSelectPass,
+    DecomposeSelectScatterPass,
     DecomposeSignPass,
     DecomposeSiluPass,
     DecomposeSinhPass,
@@ -98,7 +99,7 @@
     RemoveGetItemPass,
     RemoveGraphAssertsPass,
     RemoveNoopPass,
-    ReplaceInfValuesPass,
+    ReplaceInfAndLimitValuesPass,
     ReplaceScalarWithTensorByProfilePass,
     RewriteConv2dPass,
     RewriteMatmulPass,
@@ -174,18 +175,14 @@ def _tosa_pipeline(
         self.add_passes(
             [
                 FuseQuantizedActivationPass(),
-                RemoveGetItemPass(),
                 ConvertToClampPass(),
                 DecomposeInt32ClampPass(),
                 DecomposeGroupNormPass(),
                 DecomposeLayerNormPass(),
-                DecomposeBatchNormNoStatsPass(),
                 DecomposeVarPass(),
                 DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec),
                 AnnotateDecomposedMatmulPass(),
                 ConvertELUParamsPass(),
-                ConvertSplitToSlicePass(),
-                QuantizeClampArgumentsPass(),
             ]
         )

@@ -207,6 +204,10 @@ def _tosa_pipeline(
         # Node transformation passes (post q/dq folding)
         self.add_passes(
             [
+                ConvertSplitToSlicePass(),
+                QuantizeClampArgumentsPass(),
+                RemoveGetItemPass(),
+                DecomposeBatchNormNoStatsPass(),
                 DecomposeLogitPass(),
                 DecomposeMaskedFillPass(),
                 DecomposeRoundPass(),
@@ -243,7 +244,6 @@ def _tosa_pipeline(
                 # passes. Ticket: MLETORCH-1540
                 DecomposeNotEqualPass(),
                 MatchArgRanksPass(exported_program),
-                FuseConstantArgsPass(exported_program),
             ]
         )

@@ -265,6 +265,7 @@ def _tosa_pipeline(
                 DecomposeAvgPool2dPass(),
                 DecorateFp32toInt32CastingPass(),
                 ComputeConstantOpsAOTPass(exported_program),
+                FuseConstantArgsPass(exported_program),
                 ConvertExpandCopyToRepeatPass(),
                 UnsqueezeBeforeRepeatPass(),
                 DecomposeCumsumPass(exported_program),
@@ -330,6 +331,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         # Transformation passes (pre scalar -> tensor)
         self.add_passes(
             [
+                DecomposeSelectScatterPass(),
                 ConvertInt64ConstOpsToInt32Pass(),
                 ConvertInt64OutputOpsToInt32Pass(),
                 InsertInt32CastsAfterInt64PlaceholdersPass(),
@@ -383,7 +385,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         # Postprocessing passes
         self.add_passes(
             [
-                ReplaceInfValuesPass(),
+                ReplaceInfAndLimitValuesPass(),
                 DecomposeMaskedFillPass() if not self.tosa_spec.is_U55_subset else None,
             ]
         )

backends/arm/_passes/convert_split_to_slice.py

Lines changed: 39 additions & 2 deletions
@@ -85,11 +85,48 @@ def call(self, graph_module: torch.fx.GraphModule):
                 graph,
                 self.slice,
                 (input_node, dim, starts[index], ends[index]),
+                from_node=node,
+            )
+            slice_node.meta = _copy_user_node_qparams(
+                split_node, output_node, index
             )
-            slice_node.meta = split_node.meta.copy()
-            slice_node.meta["val"] = slice_node.meta["val"][index]
             output_node.replace_all_uses_with(slice_node)
         graph.eliminate_dead_code()
         graph_module.recompile()
         graph_module = super().call(graph_module).graph_module
         return PassResult(graph_module, True)
+
+
+def _copy_user_node_qparams(
+    split_node: torch.fx.Node, output_node: torch.fx.Node, index: int
+) -> dict:
+    """
+    Construct metadata for the slice node that will replace the split output.
+
+    Note that output quantization parameters are copied from the user nodes
+    of the split node. The split node itself does not have output quantization
+    parameters.
+
+    Args:
+        split_node: The split node being replaced.
+        output_node: The getitem node that is a user of the split node.
+        index: The index of the output being processed.
+    Returns:
+        Updated metadata dictionary for the slice node.
+    """
+
+    def _select_index(value):
+        if isinstance(value, (list, tuple)):
+            return value[index]
+        return value
+
+    meta = split_node.meta.copy()
+    if "val" in meta:
+        meta["val"] = _select_index(meta["val"])
+    if "tensor_meta" in meta:
+        meta["tensor_meta"] = _select_index(meta["tensor_meta"])
+    if "input_qparams" in meta:
+        meta["input_qparams"] = dict(meta["input_qparams"])
+    if "output_qparams" in meta:
+        meta["output_qparams"] = dict(output_node.meta["output_qparams"])
+    return meta
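
For reference, the rewrite performed by ConvertSplitToSlicePass amounts to replacing each output of a split with a single slice along the same dimension, with start/end offsets accumulated from the split sizes. A standalone sketch in eager PyTorch (not part of the commit; torch.narrow stands in for the slice op created in the graph):

import torch

x = torch.arange(10).reshape(2, 5)
split_sizes = [2, 3]
dim = 1

outputs = torch.split(x, split_sizes, dim=dim)

start = 0
for index, size in enumerate(split_sizes):
    # One slice per split output: slice(input, dim, start, end) in the graph,
    # which torch.narrow(x, dim, start, length) reproduces in eager mode.
    sliced = torch.narrow(x, dim, start, size)
    assert torch.equal(sliced, outputs[index])
    start += size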
backends/arm/_passes/decompose_select_scatter_pass.py

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Set, Type
+
+import torch
+
+from executorch.backends.arm._passes import ArmPass
+from executorch.backends.arm._passes.convert_int64_const_ops_to_int32 import (
+    ConvertInt64ConstOpsToInt32Pass,
+)
+from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
+    ReplaceScalarWithTensorByProfilePass,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+edge_scatter_ops = (exir_ops.edge.aten.select_scatter.default,)
+aten_scatter_ops = (torch.ops.aten.select_scatter.default,)
+
+
+def get_select_scatter_decomposition(op) -> tuple:
+    if op in edge_scatter_ops:
+        return (
+            exir_ops.edge.aten.arange.start_step,
+            exir_ops.edge.aten.eq.Scalar,
+            exir_ops.edge.aten.where.self,
+            exir_ops.edge.aten.expand_copy.default,
+            exir_ops.edge.aten.unsqueeze_copy.default,
+            exir_ops.edge.aten.view_copy.default,
+        )
+    if op in aten_scatter_ops:
+        return (
+            torch.ops.aten.arange.start_step,
+            torch.ops.aten.eq.Scalar,
+            torch.ops.aten.where.self,
+            torch.ops.aten.expand_copy.default,
+            torch.ops.aten.unsqueeze_copy.default,
+            torch.ops.aten.view_copy.default,
+        )
+
+    raise RuntimeError(f"Can't get select_scatter decomposition for op {op}")
+
+
+class DecomposeSelectScatterPass(ArmPass):
+    """select_scatter is decomposed into other ops during export; however, this is only
+    supported for the fp profile, so for the int profile we need to decompose it here.
+
+    The decomposition is as follows:
+    - Build a boolean mask the size of x:
+        eq(view(arange(0, dim_size), mask_shape), index)
+    - Broadcast source to x:
+        expand(unsqueeze(source, dim), shape)
+    - Route the updated slice while keeping the untouched lanes:
+        where(mask, expanded_source, x)
+
+    This reflects the decomposition for the fp profile implemented in torch._refs.
+    """
+
+    _passes_required_after: Set[Type[ExportPass]] = {
+        ReplaceScalarWithTensorByProfilePass,
+        ConvertInt64ConstOpsToInt32Pass,
+    }
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in (edge_scatter_ops + aten_scatter_ops):
+            return super().call_operator(op, args, kwargs, meta, updated=False)
+
+        (
+            arange_op,
+            eq_op,
+            where_op,
+            expand_op,
+            unsqueeze_op,
+            view_op,
+        ) = get_select_scatter_decomposition(op)
+
+        input_tensor = args[0]
+        src_tensor = args[1]
+        dim = int(args[2])
+        index = int(args[3])
+
+        shape = input_tensor.data.size()
+        rank = len(shape)
+        dim = dim % rank if dim < 0 else dim
+        dim_size = shape[dim]
+        if index < 0:
+            index = index + dim_size
+
+        mask_shape = [1] * rank
+        mask_shape[dim] = -1
+
+        arange_node = super().call_operator(
+            arange_op,
+            (0, dim_size, 1),
+            {},
+            meta,
+            updated=False,
+        )
+
+        view_node = super().call_operator(
+            view_op,
+            (arange_node, mask_shape),
+            {},
+            meta,
+            updated=False,
+        )
+
+        mask_node = super().call_operator(
+            eq_op,
+            (view_node, index),
+            {},
+            meta,
+            updated=False,
+        )
+
+        unsqueeze_node = super().call_operator(
+            unsqueeze_op,
+            (src_tensor, dim),
+            {},
+            meta,
+            updated=False,
+        )
+
+        expand_node = super().call_operator(
+            expand_op,
+            (unsqueeze_node, shape),
+            {},
+            meta,
+            updated=False,
+        )
+
+        where_node = super().call_operator(
+            where_op,
+            (mask_node, expand_node, input_tensor),
+            {},
+            meta,
+            updated=True,
+        )
+
+        return where_node
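
The mask/expand/where decomposition above can be sanity-checked against eager PyTorch, since the same sequence should reproduce torch.select_scatter. A standalone sketch under that assumption (not part of the commit; the helper name is illustrative):

import torch

def select_scatter_decomposed(x, src, dim, index):
    # Normalize negative dim/index the same way the pass does.
    dim = dim % x.dim()
    index = index % x.size(dim)
    # Boolean mask that is True only at `index` along `dim`.
    mask_shape = [1] * x.dim()
    mask_shape[dim] = -1
    mask = torch.arange(x.size(dim)).view(mask_shape) == index
    # Broadcast the source slice back to the full shape of x.
    expanded_src = src.unsqueeze(dim).expand(x.shape)
    # Take the updated slice where the mask is set, keep x elsewhere.
    return torch.where(mask, expanded_src, x)

x = torch.zeros(2, 3, 4)
src = torch.ones(2, 4)
assert torch.equal(
    select_scatter_decomposed(x, src, dim=1, index=2),
    torch.select_scatter(x, src, 1, 2),
)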

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 13 additions & 5 deletions
@@ -334,7 +334,7 @@ class QuantizeClampArgumentsPass(ArmPass):
     - Makes sure the min and max values to clamp.default are quantized, if it's a quantized operator.
     """

-    _passes_required_after: Set[Type[ExportPass]] = {FoldAndAnnotateQParamsPass}
+    _passes_required_after: Set[Type[ExportPass]] = set()

     def call(self, graph_module: GraphModule) -> PassResult:
         modified = False
@@ -346,12 +346,15 @@ def call(self, graph_module: GraphModule) -> PassResult:
             }:
                 continue

-            # Make sure we have a quantized operator
-            user = list(n.users)[0]
-            if user.target not in Q_OPS:
+            try:
+                output_qparams = get_output_qparams(n)
+            except ValueError:
+                continue
+            if len(output_qparams) == 0:
                 continue

-            qargs = QuantArgs.from_operator(user.target, user.args)
+            # Qparams are stored per user index; use the first entry.
+            qargs = next(iter(output_qparams.values()))

             if n.target == exir_ops.edge.aten.clamp.default:
                 # Quantize the min and max arguments of clamp, if they are not None
@@ -368,4 +371,9 @@ def call(self, graph_module: GraphModule) -> PassResult:

                 modified = True

+        if modified:
+            # Retrace to refresh fake tensor metadata after updating clamp min/max.
+            graph_module = super().call(graph_module).graph_module
+            graph_module.recompile()
+
         return PassResult(graph_module, modified)
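
QuantizeClampArgumentsPass now reads the output quantization parameters recorded by q/dq folding (get_output_qparams) instead of peeking at a following Q op. The numeric effect on a clamp bound is ordinary affine quantization; a minimal sketch of that arithmetic (illustrative only — scale, zero_point, qmin and qmax here are assumed field names standing in for the commit's QuantArgs, not a verified API):

def quantize_clamp_bound(value, scale, zero_point, qmin, qmax):
    # Affine quantization: q = round(value / scale) + zero_point,
    # clipped to the representable integer range.
    q = round(value / scale) + zero_point
    return max(qmin, min(qmax, q))

# e.g. clamp(max=6.0) on an int8 output quantized with scale=0.25, zero_point=0
assert quantize_clamp_bound(6.0, scale=0.25, zero_point=0, qmin=-128, qmax=127) == 24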

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 4 additions & 1 deletion
@@ -178,7 +178,10 @@ def f(node_name_pre_computed):
         return node_name_pre_computed
     """

-    _passes_required_after: Set[Type[ExportPass]] = {FuseEqualPlaceholdersPass}
+    _passes_required_after: Set[Type[ExportPass]] = {
+        FuseEqualPlaceholdersPass,
+        FuseConstantArgsPass,
+    }

     targeted_ops = [
         exir_ops.edge.aten.full.default,

backends/arm/_passes/replace_inf_values_pass.py renamed to backends/arm/_passes/replace_inf_and_limit_values_pass.py

Lines changed: 8 additions & 6 deletions
@@ -14,9 +14,11 @@
 from executorch.exir.pass_base import ExportPass, PassResult


-class ReplaceInfValuesPass(ArmPass):
+class ReplaceInfAndLimitValuesPass(ArmPass):
     """
-    Due to limitation in Quantizer, we need to change inf/-inf to more quantizable values.
+    Rewrites +inf/-inf and floating-point limit values (e.g., torch.finfo(...).min/max)
+    to quantization-friendly values (±255 by default), improving quantizer stability
+    (notably for attention mask paths).
     """

     _passes_required_after: Set[Type[ExportPass]] = set()
@@ -34,12 +36,12 @@ def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             arg_list = list(node.args)
             for index, arg in enumerate(arg_list):
-                if arg == float("-inf"):
+                if arg == float("-inf") or arg == torch.finfo(torch.float32).min:
                     modified = True
-                    arg_list[index] = -255
-                elif arg == float("inf"):
+                    arg_list[index] = -255.0
+                elif arg == float("inf") or arg == torch.finfo(torch.float32).max:
                     modified = True
-                    arg_list[index] = +255
+                    arg_list[index] = +255.0
             node.args = tuple(arg_list)

         if modified:
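
The renamed pass additionally matches the float32 finfo limits, which often appear as attention-mask fill values in place of literal infinities. A standalone sketch of the value rewrite on plain Python floats (the FX pass applies the same comparisons to node arguments):

import torch

F32_MIN = torch.finfo(torch.float32).min
F32_MAX = torch.finfo(torch.float32).max

def replace_inf_or_limit(arg):
    # Map -inf / float32 min to -255.0 and +inf / float32 max to +255.0,
    # mirroring the replacements made by ReplaceInfAndLimitValuesPass.
    if arg == float("-inf") or arg == F32_MIN:
        return -255.0
    if arg == float("inf") or arg == F32_MAX:
        return 255.0
    return arg

assert replace_inf_or_limit(float("-inf")) == -255.0
assert replace_inf_or_limit(F32_MAX) == 255.0
assert replace_inf_or_limit(0.5) == 0.5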
