Commit 27d553a
Update base for Update on "[ET-VK][ez] Explicitly skip marking output nodes that are mutable buffers"
## Changes

* Move the logic skipping output nodes that are mutable buffers from the runtime to AOT.

## Context

A `fx.Graph` may return nodes that are mutable buffers:

```
class GraphModule(torch.nn.Module):
    def forward(self, p_wrapped_module_wq_weight: "f32[2048, 2048]", p_wrapped_module_wk_weight: "f32[512, 2048]", p_wrapped_module_wv_weight: "f32[512, 2048]", p_wrapped_module_wo_weight: "f32[2048, 2048]", b_wrapped_module_kv_cache_k_cache: "f32[1, 2048, 8, 64]", b_wrapped_module_kv_cache_v_cache: "f32[1, 2048, 8, 64]", x: "f32[1, s27, 2048]", freqs_cos: "f32[s27, 32]", freqs_sin: "f32[s27, 32]", input_pos: "i64[1]"):
        sym_size: "Sym(s27)" = torch.ops.aten.sym_size.int(x, 1)
        ...
        # b_wrapped_module_kv_cache_*_cache are mutable buffers
        # getitem_2 and getitem_3 are derived from mutable buffers, hence they are
        # themselves mutable buffers
        auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.llama.update_cache.default, value = getitem_1, cache = b_wrapped_module_kv_cache_k_cache, start_pos = _local_scalar_dense_1);  getitem_1 = b_wrapped_module_kv_cache_k_cache = None
        getitem_2: "f32[1, 2048, 8, 64]" = auto_functionalized[1];  auto_functionalized = None
        auto_functionalized_1 = torch.ops.higher_order.auto_functionalized(torch.ops.llama.update_cache.default, value = aten_view_copy_default_8, cache = b_wrapped_module_kv_cache_v_cache, start_pos = _local_scalar_dense_1);  aten_view_copy_default_8 = b_wrapped_module_kv_cache_v_cache = _local_scalar_dense_1 = None
        getitem_3: "f32[1, 2048, 8, 64]" = auto_functionalized_1[1];  auto_functionalized_1 = None
        ...
        aten_permute_copy_default_3: "f32[2048, 2048]" = executorch_exir_dialects_edge__ops_aten_permute_copy_default(p_wrapped_module_wo_weight, [1, 0]);  p_wrapped_module_wo_weight = None
        aten_view_copy_default_10: "f32[s27, 2048]" = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_view_copy_default_9, [sym_size, 2048]);  aten_view_copy_default_9 = None
        aten_mm_default_3: "f32[s27, 2048]" = executorch_exir_dialects_edge__ops_aten_mm_default(aten_view_copy_default_10, aten_permute_copy_default_3);  aten_view_copy_default_10 = aten_permute_copy_default_3 = None
        aten_view_copy_default_11: "f32[1, s27, 2048]" = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_mm_default_3, [1, sym_size, 2048]);  aten_mm_default_3 = sym_size = None
        # getitem_2 and getitem_3 are returned as outputs, presumably to prevent the
        # update_cache calls from being removed due to dead code elimination
        return (getitem_2, getitem_3, aten_view_copy_default_11, None)
```

In the graph signature of the `ExportedProgram` these show up as `BUFFER_MUTATION` outputs:

```
Graph signature:
    # inputs
    p_wrapped_module_wq_weight: PARAMETER target='wrapped_module.wq.weight'
    p_wrapped_module_wk_weight: PARAMETER target='wrapped_module.wk.weight'
    p_wrapped_module_wv_weight: PARAMETER target='wrapped_module.wv.weight'
    p_wrapped_module_wo_weight: PARAMETER target='wrapped_module.wo.weight'
    b_wrapped_module_kv_cache_k_cache: BUFFER target='wrapped_module.kv_cache.k_cache' persistent=True
    b_wrapped_module_kv_cache_v_cache: BUFFER target='wrapped_module.kv_cache.v_cache' persistent=True
    x: USER_INPUT
    freqs_cos: USER_INPUT
    freqs_sin: USER_INPUT
    input_pos: USER_INPUT

    # outputs
    getitem_2: BUFFER_MUTATION target='wrapped_module.kv_cache.k_cache'
    getitem_3: BUFFER_MUTATION target='wrapped_module.kv_cache.v_cache'
    aten_view_copy_default_11: USER_OUTPUT
    : USER_OUTPUT
```

Although these nodes are technically returned by the `fx.Graph`, `BUFFER_MUTATION` outputs are not included in the delegate call schema. Since the Vulkan delegate serialization uses the output node to mark which values are returned as outputs, this could result in a mismatch between the outputs of the Vulkan delegate and the outputs expected by the ExecuTorch runtime.

## Motivation

Previously, this mismatch was addressed in the runtime by skipping the processing of non-tensor outputs. However, that solution does not account for the fact that in some models, parameters of the model may be returned as outputs. In that case, those parameter outputs would be skipped, but the ExecuTorch runtime would still expect to receive them as outputs.

To solve the problem properly, this diff changes the serialization logic to check whether an output node is a mutable buffer, and skip marking it as an output if so. In the runtime, all output nodes are processed instead of only tensor outputs.

Differential Revision: [D77281491](https://our.internmc.facebook.com/intern/diff/D77281491/)

[ghstack-poisoned]
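For illustration, here is a minimal sketch of the kind of AOT-side check described above, written against the public `torch.export` graph-signature API. The helper name `is_mutable_buffer_output` is hypothetical, not the actual function used by the Vulkan serializer:

```python
# Hypothetical sketch, not the Vulkan serializer's actual code.
# Assumes the torch.export graph signature API (OutputKind, output_specs).
import torch.fx
from torch.export import ExportedProgram
from torch.export.graph_signature import OutputKind


def is_mutable_buffer_output(
    node: torch.fx.Node, exported_program: ExportedProgram
) -> bool:
    """Return True if `node` appears in the graph signature as a BUFFER_MUTATION output."""
    for spec in exported_program.graph_signature.output_specs:
        # Each output spec records the output's kind (USER_OUTPUT,
        # BUFFER_MUTATION, ...) and the name of the node producing it.
        if getattr(spec.arg, "name", None) == node.name:
            return spec.kind == OutputKind.BUFFER_MUTATION
    return False
```

With a check like this, serialization can skip marking `BUFFER_MUTATION` nodes as delegate outputs, so the delegate's output list lines up with what the ExecuTorch runtime expects.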
2 parents: 910cc4e + ecb85ce


61 files changed: +1914 −442 lines

.ci/scripts/test_model.sh (2 additions, 2 deletions)

```diff
@@ -102,15 +102,15 @@ test_model() {
     bash examples/models/llama/install_requirements.sh
     # Test export_llm script: python3 -m extension.llm.export.export_llm.
     # Use Llama random checkpoint with Qwen 2.5 1.5b model configuration.
-    "${PYTHON_EXECUTABLE}" -m extension.llm.export.export_llm base.model_class="${MODEL_NAME}" base.params=examples/models/qwen2_5/1_5b_config.json
+    "${PYTHON_EXECUTABLE}" -m extension.llm.export.export_llm base.model_class="${MODEL_NAME}" base.params=examples/models/qwen2_5/config/1_5b_config.json
     rm "./${MODEL_NAME}.pte"
     return # Skip running with portable executor runnner since portable doesn't support Qwen's biased linears.
   fi
   if [[ "${MODEL_NAME}" == "phi_4_mini" ]]; then
     # Install requirements for export_llama
     bash examples/models/llama/install_requirements.sh
     # Test export_llm script: python3 -m extension.llm.export.export_llm.
-    "${PYTHON_EXECUTABLE}" -m extension.llm.export.export_llm base.model_class="${MODEL_NAME}" base.params=examples/models/phi_4_mini/config.json
+    "${PYTHON_EXECUTABLE}" -m extension.llm.export.export_llm base.model_class="${MODEL_NAME}" base.params=examples/models/phi_4_mini/config/config.json
     run_portable_executor_runner
     rm "./${MODEL_NAME}.pte"
     return
```

.github/workflows/android-perf.yml (1 addition, 1 deletion)

```diff
@@ -317,7 +317,7 @@ jobs:
           DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "." --files "tokenizer.json")
           python -m extension.llm.export.export_llm \
             base.model_class=qwen3_0_6b \
-            base.params=examples/models/qwen3/0_6b_config.json \
+            base.params=examples/models/qwen3/config/0_6b_config.json \
             model.use_kv_cache=true \
             model.use_sdpa_with_kv_cache=true \
             model.dtype_override=fp32 \
```

.github/workflows/apple-perf.yml (1 addition, 1 deletion)

```diff
@@ -322,7 +322,7 @@ jobs:
           DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "." --files "tokenizer.json")
           ${CONDA_RUN} python -m extension.llm.export.export_llm \
             base.model_class=qwen3_0_6b \
-            base.params=examples/models/qwen3/0_6b_config.json \
+            base.params=examples/models/qwen3/config/0_6b_config.json \
             model.use_kv_cache=true \
             model.use_sdpa_with_kv_cache=true \
             model.dtype_override=fp32 \
```

backends/arm/_passes/__init__.py (1 addition, 0 deletions)

```diff
@@ -23,6 +23,7 @@
 from .convert_squeezes_to_view import ConvertSqueezesToViewPass  # noqa
 from .convert_to_clamp import ConvertToClampPass  # noqa
 from .decompose_avg_pool2d import DecomposeAvgPool2d  # noqa
+from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass  # noqa
 from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
 from .decompose_embedding_pass import DecomposeEmbeddingPass  # noqa # noqa
```

backends/arm/_passes/arm_pass_manager.py (2 additions, 0 deletions)

```diff
@@ -26,6 +26,7 @@
     ConvertSqueezesToViewPass,
     ConvertToClampPass,
     DecomposeAvgPool2d,
+    DecomposeBatchNormNoStatsPass,
     DecomposeCosineSimilarityPass,
     DecomposeDivPass,
     DecomposeEmbeddingPass,
@@ -164,6 +165,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(DecomposeLeakyReLUPass())
         self.add_pass(DecomposeGroupNormPass())
         self.add_pass(DecomposeLayerNormPass())
+        self.add_pass(DecomposeBatchNormNoStatsPass())
         self.add_pass(DecomposeVarPass())
         self.add_pass(
             DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
```
backends/arm/_passes/decompose_batch_norm_no_stats.py (new file, 219 additions)

```python
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import operator

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult


class DecomposeBatchNormNoStatsPass(ArmPass):
    """
    Decompose BatchNorm2d(track_running_stats=False) (aten._native_batch_norm_legit_no_training)
    into a sequence of elementwise operations:

        # let input = x, rm = running_mean, rv = running_var, eps: float
        rm_view = view(rm, weights_shape)
        rv_view = view(rv, weights_shape)
        centered = sub(x, rm_view)
        eps_full = full(eps_shape, eps)
        var_eps = add(rv_view, eps_full)
        inv_sqrt = rsqrt(var_eps)
        normed = mul(centered, inv_sqrt)
        weighted = mul(normed, w_view)
        biased = add(weighted, b_view)

    Source: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
        bn_ops = (
            exir_ops.edge.aten._native_batch_norm_legit.no_stats,
            exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
            torch.ops.aten._native_batch_norm_legit_no_training.default,
            torch.ops.aten.batch_norm.default,
            torch.ops.aten.native_batch_norm.default,
        )

        for node in graph_module.graph.nodes:
            if node.op != "call_function" or node.target not in bn_ops:
                continue

            if node.target in (
                torch.ops.aten.batch_norm.default,
                torch.ops.aten.native_batch_norm.default,
            ):
                # signature: (input, weight, bias, mean, var, training, momentum, eps, cudnn_enabled)
                # pos-arg 5 is training
                training = node.kwargs.get("training", False)
                if len(node.args) > 5:
                    training = node.args[5]
                if training:
                    # skip training-mode batchnorm
                    continue

            # Extract args
            args = node.args
            meta = node.meta

            # Default eps
            eps: float = torch.finfo().eps
            # weight and bias may be None
            x = args[0]
            weight = args[1] if len(args) > 1 else None
            bias = args[2] if len(args) > 2 else None
            running_mean = args[3]
            running_var = args[4]
            if len(args) > 6:
                eps = args[6]

            # Determine shapes
            val = meta.get("val")
            ref_tensor = val[0] if isinstance(val, tuple) else val
            shape = tuple(ref_tensor.size())
            dtype = ref_tensor.dtype
            rank = len(shape)

            # channel dimension is 1 for BatchNorm2d
            channel_axis = 1
            weights_shape = [1] * rank
            weights_shape[channel_axis] = shape[channel_axis]
            num_features = shape[channel_axis]

            # Ops to use
            sub_op = exir_ops.edge.aten.sub.Tensor
            view_op = exir_ops.edge.aten.view_copy.default
            full_op = exir_ops.edge.aten.full.default
            add_op = exir_ops.edge.aten.add.Tensor
            rsqrt_op = exir_ops.edge.aten.rsqrt.default
            mul_op = exir_ops.edge.aten.mul.Tensor

            # Begin decomposition
            with graph_module.graph.inserting_before(node):
                # reshape running stats
                rm_view = create_node(
                    graph_module.graph,
                    view_op,
                    args=(running_mean, weights_shape),
                    from_node=node,
                )
                rv_view = create_node(
                    graph_module.graph,
                    view_op,
                    args=(running_var, weights_shape),
                    from_node=node,
                )
                # center input
                centered = create_node(
                    graph_module.graph,
                    sub_op,
                    args=(x, rm_view),
                    from_node=node,
                )
                # epsilon tensor
                eps_shape = [1] * rank
                eps_full = create_node(
                    graph_module.graph,
                    full_op,
                    args=(eps_shape, eps),
                    kwargs={"dtype": dtype},
                    from_node=node,
                )
                # var + eps
                var_eps = create_node(
                    graph_module.graph,
                    add_op,
                    args=(rv_view, eps_full),
                    from_node=node,
                )
                # inverse sqrt
                inv_sqrt = create_node(
                    graph_module.graph,
                    rsqrt_op,
                    args=(var_eps,),
                    from_node=node,
                )
                # normalized
                normed = create_node(
                    graph_module.graph,
                    mul_op,
                    args=(centered, inv_sqrt),
                    from_node=node,
                )

                # weight
                if weight is None:
                    one = create_node(
                        graph_module.graph,
                        full_op,
                        args=([num_features], 1),
                        kwargs={"dtype": dtype},
                        from_node=node,
                    )
                    w_view = create_node(
                        graph_module.graph,
                        view_op,
                        args=(one, weights_shape),
                        from_node=node,
                    )
                else:
                    w_view = create_node(
                        graph_module.graph,
                        view_op,
                        args=(weight, weights_shape),
                        from_node=node,
                    )
                weighted = create_node(
                    graph_module.graph,
                    mul_op,
                    args=(normed, w_view),
                    from_node=node,
                )

                # bias
                if bias is None:
                    zero = create_node(
                        graph_module.graph,
                        full_op,
                        args=([num_features], 0),
                        kwargs={"dtype": dtype},
                        from_node=node,
                    )
                    b_view = create_node(
                        graph_module.graph,
                        view_op,
                        args=(zero, weights_shape),
                        from_node=node,
                    )
                else:
                    b_view = create_node(
                        graph_module.graph,
                        view_op,
                        args=(bias, weights_shape),
                        from_node=node,
                    )
                final_out = create_node(
                    graph_module.graph,
                    add_op,
                    args=(weighted, b_view),
                    from_node=node,
                )

            users = [u for u in node.users if u is not node]
            node.replace_all_uses_with(final_out)
            for u in users:
                if u.target == operator.getitem:
                    u.replace_all_uses_with(final_out)
            graph_module.graph.erase_node(node)
            graph_module.graph.eliminate_dead_code()

        graph_module.recompile()
        new_gm = super().call(graph_module).graph_module
        return PassResult(new_gm, True)
```
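As a quick sanity check of the arithmetic in the pass docstring (not part of this commit), the elementwise decomposition can be compared against PyTorch's inference-mode batch norm in eager mode; the shapes and eps below are illustrative:

```python
# Eager-mode check that the decomposition in the docstring matches
# F.batch_norm in inference mode. Illustrative only; not part of the diff.
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4, 4)
weight, bias = torch.randn(3), torch.randn(3)
running_mean = torch.randn(3)
running_var = torch.rand(3) + 0.1  # variance must be positive
eps = 1e-5

reference = F.batch_norm(
    x, running_mean, running_var, weight, bias, training=False, eps=eps
)

# The decomposition: view stats to (1, C, 1, 1), then sub, add, rsqrt, mul, add.
weights_shape = [1, 3, 1, 1]
centered = x - running_mean.view(weights_shape)
inv_sqrt = torch.rsqrt(running_var.view(weights_shape) + eps)
normed = centered * inv_sqrt
decomposed = normed * weight.view(weights_shape) + bias.view(weights_shape)

torch.testing.assert_close(decomposed, reference)
```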

backends/arm/test/ops/test_batch_norm.py (44 additions, 11 deletions)

```diff
@@ -224,6 +224,8 @@ class BatchNorm2dNoStats(torch.nn.Module):
     Decomposes into _native_batch_norm_legit.no_stats
     """
 
+    aten_ops = ["torch.ops.aten.batch_norm.default"]
+
     def __init__(
         self,
         num_features: int,
@@ -250,29 +252,60 @@ def forward(self, x):
         return self.batch_norm_2d(x)
 
 
-@pytest.mark.skip(
-    reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats."
-)
-def test_native_batch_norm_legit_no_stats_tosa_MI():
-    pass
+@common.parametrize("test_data", test_data_suite)
+def test_native_batch_norm_legit_no_stats_tosa_MI(test_data: Tuple):
+    test_data, model_params = test_data()
+    pipeline = TosaPipelineMI[input_t1](
+        BatchNorm2dNoStats(*model_params),
+        (test_data,),
+        aten_op=BatchNorm2dNoStats.aten_ops,
+    )
+    pipeline.run()
 
 
 @pytest.mark.skip(
     reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats."
 )
-def test_native_batch_norm_legit_no_stats_tosa_BI():
-    pass
+def test_native_batch_norm_legit_no_stats_tosa_BI(test_data: Tuple):
+    test_data, model_params = test_data()
+    pipeline = TosaPipelineBI[input_t1](
+        BatchNorm2dNoStats(*model_params),
+        (test_data,),
+        aten_op=BatchNorm2dNoStats.aten_ops,
+        qtol=1,
+    )
+    pipeline.run()
 
 
 @pytest.mark.skip(
     reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats."
 )
-def test_native_batch_norm_legit_no_stats_u55_BI():
-    pass
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone300
+def test_native_batch_norm_legit_no_stats_u55_BI(test_data: Tuple):
+    test_data, model_params = test_data()
+    pipeline = EthosU55PipelineBI[input_t1](
+        BatchNorm2dNoStats(*model_params),
+        (test_data,),
+        aten_op=BatchNorm2dNoStats.aten_ops,
+        run_on_fvp=True,
+        qtol=1,
+    )
+    pipeline.run()
 
 
 @pytest.mark.skip(
     reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats."
 )
-def test_native_batch_norm_legit_no_stats_u85_BI():
-    pass
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone320
+def test_native_batch_norm_legit_no_stats_u85_BI(test_data: Tuple):
+    test_data, model_params = test_data()
+    pipeline = EthosU85PipelineBI[input_t1](
+        BatchNorm2dNoStats(*model_params),
+        (test_data,),
+        aten_op=BatchNorm2dNoStats.aten_ops,
+        run_on_fvp=False,
+        qtol=1,
+    )
+    pipeline.run()
```
