Skip to content

Commit 3d948c3

Browse files
committed
Update on "[ET-VK][ez] Enable dynamic shape support when using push constants"
## Changes * Call `encode_execute()` upon resize in `VulkanBackend.cpp` * Minor update to `DispatchNode` to store push constant data array as a persistent member of the class ## Motivation Passing in tensor metadata (i.e. sizes, strides) via push constants is typically more performant than passing them via a UBO (uniform buffer object). However, currently dynamic shapes do not work when push constants are used, as I realized that the tensor metadata contained in the push constants does not get updated. It appears that `vkCmdPushConstants` sets the push constants when encoding the command buffer; however, the push constants will not be updated if the command buffer is submitted for execution multiple times. Therefore, to update push constant values **the command buffer needs to be re-encoded**. ## Performance Impact This may add a small performance overhead (i.e. re-encoding the command buffer) when executing models with dynamic shapes. Models that do not trigger tensor resizing will not be impacted. However, I measured the impact on a llama 3.2 1B model and the impact of re-encoding a command buffer appears to be negligible. In any case, re-encoding the command buffer is a "necessary evil" when working with dynamic shapes, otherwise the tensor metadata seen by shaders may never get updated. Furthermore, re-encoding the command buffer can allow an opportunity to adjust global work group sizing to match current tensor sizes, which may have a huge performance impact when maximum tensor sizes far exceed what tensor sizes will realistically be during inference (one instance of this is for transformer models when the max sequence length is very long). Differential Revision: [D75686051](https://our.internmc.facebook.com/intern/diff/D75686051/) [ghstack-poisoned]
2 parents 455574f + 3af2111 commit 3d948c3

File tree

74 files changed

+1903
-587
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

74 files changed

+1903
-587
lines changed

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ jobs:
305305
# Install requirements
306306
${CONDA_RUN} sh install_requirements.sh
307307
${CONDA_RUN} sh backends/apple/coreml/scripts/install_requirements.sh
308-
${CONDA_RUN} python install_executorch.py --pybind coreml
308+
${CONDA_RUN} python install_executorch.py
309309
${CONDA_RUN} sh examples/models/llama/install_requirements.sh
310310
311311
# Test ANE llama

CMakePresets.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"CMAKE_TOOLCHAIN_FILE": "${sourceDir}/third-party/ios-cmake/ios.toolchain.cmake",
1616
"EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/macos.cmake",
1717
"PLATFORM": "MAC_ARM64",
18-
"DEPLOYMENT_TARGET": "10.15"
18+
"DEPLOYMENT_TARGET": "12.0"
1919
},
2020
"condition": {
2121
"lhs": "${hostSystemName}",
@@ -77,7 +77,7 @@
7777
"inherits": ["common"],
7878
"cacheVariables": {
7979
"EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/pybind.cmake",
80-
"CMAKE_OSX_DEPLOYMENT_TARGET": "10.15"
80+
"CMAKE_OSX_DEPLOYMENT_TARGET": "12.0"
8181
},
8282
"condition": {
8383
"type": "inList",
@@ -93,7 +93,7 @@
9393
],
9494
"cacheVariables": {
9595
"EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/llm.cmake",
96-
"CMAKE_OSX_DEPLOYMENT_TARGET": "10.15"
96+
"CMAKE_OSX_DEPLOYMENT_TARGET": "12.0"
9797
},
9898
"condition": {
9999
"type": "inList",

Package.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ let package = Package(
120120
name: "executorch",
121121
platforms: [
122122
.iOS(.v17),
123-
.macOS(.v10_15),
123+
.macOS(.v12),
124124
],
125125
products: packageProducts,
126126
targets: packageTargets + [

backends/apple/coreml/runtime/delegate/ETCoreMLDefaultModelExecutor.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ - (instancetype)initWithModel:(ETCoreMLModel *)model {
2727
eventLogger:(const executorchcoreml::ModelEventLogger* _Nullable __unused)eventLogger
2828
error:(NSError * __autoreleasing *)error {
2929
if (self.ignoreOutputBackings) {
30-
if (@available(macOS 11.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *)) {
30+
if (@available(iOS 16.0, tvOS 16.0, watchOS 9.0, *)) {
3131
predictionOptions.outputBackings = @{};
3232
}
3333
}

backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.mm

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ BOOL is_backed_by_same_buffer(MLMultiArray *array1, MLMultiArray *array2) {
9292
NSOrderedSet<NSString *> *output_names,
9393
NSError * __autoreleasing *error) {
9494
MLPredictionOptions *options = [MLPredictionOptions new];
95-
if (@available(macOS 11.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *)) {
95+
if (@available(iOS 16.0, tvOS 16.0, watchOS 9.0, *)) {
9696
NSMutableDictionary<NSString *, id> *output_backings = [NSMutableDictionary dictionary];
9797
NSEnumerator<NSString *> *enumerator = [output_names objectEnumerator];
9898
for (MLMultiArray *output in outputs) {
@@ -687,7 +687,7 @@ - (void)addPrewarmedAsset:(ETCoreMLAsset *)asset {
687687
eventLogger:eventLogger
688688
error:&localError];
689689
// Try without output backings.
690-
if (@available(macOS 11.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *)) {
690+
if (@available(iOS 16.0, tvOS 16.0, watchOS 9.0, *)) {
691691
if (!modelOutputs && predictionOptions.outputBackings.count > 0) {
692692
executor.ignoreOutputBackings = YES;
693693
localError = nil;

backends/apple/mps/setup.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ I 00:00:00.122615 executorch:mps_executor_runner.mm:501] Model verified successf
9191
### [Optional] Run the generated model directly using pybind
9292
1. Make sure `pybind` MPS support was installed:
9393
```bash
94-
./install_executorch.sh --pybind mps
94+
CMAKE_ARGS="-DEXECUTORCH_BUILD_MPS=ON" ./install_executorch.sh
9595
```
9696
2. Run the `mps_example` script to trace the model and run it directly from python:
9797
```bash

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
2424
from .decompose_div_pass import DecomposeDivPass # noqa
2525
from .decompose_gelu_pass import DecomposeGeluPass # noqa
26+
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
2627
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
2728
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
2829
from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
DecomposeCosineSimilarityPass,
2828
DecomposeDivPass,
2929
DecomposeGeluPass,
30+
DecomposeGroupNormPass,
3031
DecomposeLayerNormPass,
3132
DecomposeLeakyReLUPass,
3233
DecomposeLinearPass,
@@ -141,6 +142,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
141142
self.add_pass(ConvertMmToBmmPass())
142143
self.add_pass(DecomposeLinearPass())
143144
self.add_pass(DecomposeLeakyReLUPass())
145+
self.add_pass(DecomposeGroupNormPass())
144146
self.add_pass(DecomposeLayerNormPass())
145147
self.add_pass(DecomposeVarPass())
146148
self.add_pass(
@@ -208,6 +210,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
208210
self.add_pass(DecomposeScaledDotProductAttention())
209211
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
210212
self.add_pass(ScalarsToAttributePass())
213+
self.add_pass(DecomposeGroupNormPass())
211214
self.add_pass(DecomposeLayerNormPass())
212215
self.add_pass(DecomposeVarPass())
213216
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
import operator
9+
10+
import torch
11+
from executorch.backends.arm._passes import ArmPass
12+
from executorch.backends.arm._passes.arm_pass_utils import create_node
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from executorch.exir.pass_base import PassResult
15+
16+
17+
def get_group_norm_decomposition(op) -> tuple:
    """
    Return the operators used to decompose the given group_norm op.

    The edge-dialect op (``native_group_norm``) maps to edge-dialect
    operators and the aten op (``group_norm``) maps to aten operators, so the
    decomposed nodes stay in the same dialect as the node being replaced.

    Returns:
        An 8-tuple ordered as
        (mean, sub, var, full, add, rsqrt, mul, view_copy).

    Raises:
        RuntimeError: if ``op`` is neither of the two supported targets.
    """
    # Select the namespace holding the replacement operators.
    if op == exir_ops.edge.aten.native_group_norm.default:
        namespace = exir_ops.edge.aten
    elif op == torch.ops.aten.group_norm.default:
        namespace = torch.ops.aten
    else:
        raise RuntimeError(f"Can't get group_norm composition for op {op}")

    return (
        namespace.mean.dim,
        namespace.sub.Tensor,
        namespace.var.correction,
        namespace.full.default,
        namespace.add.Tensor,
        namespace.rsqrt.default,
        namespace.mul.Tensor,
        namespace.view_copy.default,
    )
41+
42+
43+
class DecomposeGroupNormPass(ArmPass):
44+
"""
45+
groupnorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
46+
Decompose groupnorm(x, weight, bias, N, C, HxW, group, eps) to a sequence of:
47+
mean = op_mean(x, dims) # E[x]
48+
var = op_var(x, dims) # Var[x]
49+
numerator = op_sub(x, mean) # (x - E[x])
50+
add = op_add(var, eps) # Var[x] + eps
51+
rsqrt = op_rsqrt(add) # 1 / sqrt(Var[x] + eps)
52+
mul = op_mul(numerator, rsqrt) # ((x - E[x]) / sqrt(Var[x] + eps))
53+
weigths = op_mul(mul, weigths) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths
54+
bias = op_add(weigths, bias) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths + bias
55+
where x can viewed with shape [N, group, C//group, HxW] dims=[C//group, HxW]
56+
57+
Source: https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html
58+
"""
59+
60+
def call(self, graph_module: torch.fx.GraphModule):
61+
modified = False
62+
for node in graph_module.graph.nodes:
63+
if node.op != "call_function" or node.target not in (
64+
exir_ops.edge.aten.native_group_norm.default,
65+
torch.ops.aten.group_norm.default,
66+
):
67+
continue
68+
69+
# epsilon default value
70+
eps = torch.finfo().eps
71+
weights = None
72+
bias = None
73+
args = node.args
74+
meta = node.meta
75+
if isinstance(meta["val"], tuple):
76+
shape = meta["val"][0].size()
77+
dtype = meta["val"][0].dtype
78+
else:
79+
shape = meta["val"].size()
80+
dtype = meta["val"].dtype
81+
match len(args):
82+
# MI profile always provides all the args: x, weight, bias, N, C, HxW, group, eps
83+
case 8:
84+
x, weights, bias, N, C, HxW, group, eps = args
85+
# BI profile: affine=[True|False], eps!=1e-5
86+
case 5:
87+
x, group, weights, bias, eps = args
88+
# BI profile: affine=True, eps=1e-5
89+
case 4:
90+
x, group, weights, bias = args
91+
# BI profile: affine=False, eps=1e=5
92+
case 2:
93+
x, group = args
94+
# Unsupported args
95+
case _:
96+
raise ValueError(
97+
f"Unsupported group_norm argument pattern with {len(args)} args"
98+
)
99+
N = shape[0]
100+
C = shape[1]
101+
HxW = 1
102+
for dim in shape[2:]:
103+
HxW *= dim
104+
channels_per_group = C // group
105+
grouped_shape = torch.Size([N, group, channels_per_group, HxW])
106+
dims = [2, 3]
107+
epsilon_reshaped_shape = torch.Size([1] * len(grouped_shape))
108+
weights_reshaped_shape = torch.Size([1, group, channels_per_group, 1])
109+
(
110+
mean_op,
111+
sub_op,
112+
var_op,
113+
full_op,
114+
add_op,
115+
rsqrt_op,
116+
mul_op,
117+
view_op,
118+
) = get_group_norm_decomposition(node.target)
119+
with graph_module.graph.inserting_before(node):
120+
keepdim = True
121+
x_reshaped = create_node(
122+
graph_module.graph,
123+
view_op,
124+
args=(x, grouped_shape),
125+
from_node=node,
126+
)
127+
mean = create_node(
128+
graph_module.graph, mean_op, args=(x_reshaped, dims, keepdim)
129+
)
130+
sub = create_node(graph_module.graph, sub_op, args=(x_reshaped, mean))
131+
var = create_node(
132+
graph_module.graph,
133+
var_op,
134+
args=(x_reshaped, dims),
135+
kwargs={"correction": 0, "keepdim": keepdim},
136+
from_node=node,
137+
)
138+
full = create_node(
139+
graph_module.graph,
140+
full_op,
141+
args=(epsilon_reshaped_shape, eps),
142+
kwargs={"dtype": dtype},
143+
from_node=node,
144+
)
145+
add0 = create_node(
146+
graph_module.graph, add_op, args=(var, full), from_node=node
147+
)
148+
rsqrt = create_node(
149+
graph_module.graph, rsqrt_op, args=(add0,), from_node=node
150+
)
151+
mul0 = create_node(
152+
graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node
153+
)
154+
if weights is not None:
155+
weights_reshaped = create_node(
156+
graph_module.graph,
157+
view_op,
158+
args=(weights, weights_reshaped_shape),
159+
from_node=node,
160+
)
161+
mul1 = create_node(
162+
graph_module.graph,
163+
mul_op,
164+
args=(
165+
mul0,
166+
weights_reshaped,
167+
),
168+
from_node=node,
169+
)
170+
else:
171+
mul1 = mul0
172+
if bias is not None:
173+
bias_reshaped_shape = weights_reshaped_shape
174+
bias_reshaped = create_node(
175+
graph_module.graph,
176+
view_op,
177+
args=(bias, bias_reshaped_shape),
178+
from_node=node,
179+
)
180+
output = create_node(
181+
graph_module.graph,
182+
add_op,
183+
args=(mul1, bias_reshaped),
184+
from_node=node,
185+
)
186+
else:
187+
output = mul1
188+
189+
output_reshaped = create_node(
190+
graph_module.graph,
191+
view_op,
192+
args=(output, shape),
193+
from_node=node,
194+
)
195+
196+
users = [user for user in node.users if node != user]
197+
node.replace_all_uses_with(output_reshaped)
198+
for user in users:
199+
if user.target == operator.getitem:
200+
user.replace_all_uses_with(output_reshaped)
201+
graph_module.graph.erase_node(node)
202+
graph_module.graph.eliminate_dead_code()
203+
modified = True
204+
if modified:
205+
graph_module.recompile()
206+
graph_module = super().call(graph_module).graph_module
207+
208+
return PassResult(graph_module, modified)

backends/arm/_passes/decompose_layernorm_pass.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
@@ -47,11 +46,12 @@ class DecomposeLayerNormPass(ArmPass):
4746
Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of:
4847
mean = op_mean(x, dims) # E[x]
4948
var = op_var(x, dims) # Var[x]
50-
denominator = op_sub(x, mean) # (x - E[x])
49+
numerator = op_sub(x, mean) # (x - E[x])
5150
add = op_add(var, eps) # Var[x] + eps
5251
rsqrt = op_rsqrt(add) # 1 / sqrt(Var[x] + eps)
53-
mul = op_mul(denominator, rsqrt) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths
54-
bias = op_add(mul, bias) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths + bias
52+
mul = op_mul(numerator, rsqrt) # ((x - E[x]) / sqrt(Var[x] + eps))
53+
weights = op_mul(mul, weights) # ((x - E[x]) / sqrt(Var[x] + eps)) * weights
54+
bias = op_add(weights, bias) # ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
5555
5656
Source: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
5757
"""

0 commit comments

Comments
 (0)