Skip to content

Commit 69a6a48

Browse files
authored
Merge branch 'main' into HDCharles-patch-1
2 parents c38c4e3 + 8e2737c commit 69a6a48

File tree

28 files changed

+1187
-236
lines changed

28 files changed

+1187
-236
lines changed

.github/scripts/label_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@
2222

2323
LABEL_ERR_MSG_TITLE = "This PR needs a `release notes:` label"
2424
LABEL_ERR_MSG = f"""# {LABEL_ERR_MSG_TITLE}
25-
If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with `release notes:`.
26-
27-
If not, please add the `release notes: none` label.
25+
If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with `release notes:`. This helps us keep track and include your important work in the next release notes.
2826
2927
To add a label, you can comment to pytorchbot, for example
3028
`@pytorchbot label "release notes: none"`

.github/scripts/trymerge.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,7 @@
5959
patterns_to_regex,
6060
retries_decorator,
6161
)
62-
from label_utils import (
63-
gh_add_labels,
64-
gh_remove_label,
65-
has_required_labels,
66-
LABEL_ERR_MSG,
67-
)
62+
from label_utils import gh_add_labels, gh_remove_label
6863
from trymerge_explainer import get_revert_message, TryMergeExplainer
6964

7065
# labels
@@ -2116,9 +2111,6 @@ def merge(
21162111
# Check for approvals
21172112
find_matching_merge_rule(pr, repo, skip_mandatory_checks=True)
21182113

2119-
if not has_required_labels(pr):
2120-
raise RuntimeError(LABEL_ERR_MSG.lstrip(" #"))
2121-
21222114
if ignore_current:
21232115
checks = pr.get_checkrun_conclusions()
21242116
_, failing, _ = categorize_checks(

.github/workflows/check-labels.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,4 @@ jobs:
5151
PR_NUM: ${{ github.event.number || github.event.inputs.pr_number }}
5252
run: |
5353
set -ex
54-
python3 .github/scripts/check_labels.py --exit-non-zero "${PR_NUM}"
54+
python3 .github/scripts/check_labels.py "${PR_NUM}"

backends/apple/coreml/runtime/delegate/ETCoreMLDefaultModelExecutor.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ - (instancetype)initWithModel:(ETCoreMLModel *)model {
2727
eventLogger:(const executorchcoreml::ModelEventLogger* _Nullable __unused)eventLogger
2828
error:(NSError * __autoreleasing *)error {
2929
if (self.ignoreOutputBackings) {
30-
if (@available(macOS 11.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *)) {
30+
if (@available(iOS 16.0, tvOS 16.0, watchOS 9.0, *)) {
3131
predictionOptions.outputBackings = @{};
3232
}
3333
}

backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.mm

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ BOOL is_backed_by_same_buffer(MLMultiArray *array1, MLMultiArray *array2) {
9292
NSOrderedSet<NSString *> *output_names,
9393
NSError * __autoreleasing *error) {
9494
MLPredictionOptions *options = [MLPredictionOptions new];
95-
if (@available(macOS 11.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *)) {
95+
if (@available(iOS 16.0, tvOS 16.0, watchOS 9.0, *)) {
9696
NSMutableDictionary<NSString *, id> *output_backings = [NSMutableDictionary dictionary];
9797
NSEnumerator<NSString *> *enumerator = [output_names objectEnumerator];
9898
for (MLMultiArray *output in outputs) {
@@ -687,7 +687,7 @@ - (void)addPrewarmedAsset:(ETCoreMLAsset *)asset {
687687
eventLogger:eventLogger
688688
error:&localError];
689689
// Try without output backings.
690-
if (@available(macOS 11.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *)) {
690+
if (@available(iOS 16.0, tvOS 16.0, watchOS 9.0, *)) {
691691
if (!modelOutputs && predictionOptions.outputBackings.count > 0) {
692692
executor.ignoreOutputBackings = YES;
693693
localError = nil;

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
2424
from .decompose_div_pass import DecomposeDivPass # noqa
2525
from .decompose_gelu_pass import DecomposeGeluPass # noqa
26+
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
2627
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
2728
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
2829
from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
DecomposeCosineSimilarityPass,
2828
DecomposeDivPass,
2929
DecomposeGeluPass,
30+
DecomposeGroupNormPass,
3031
DecomposeLayerNormPass,
3132
DecomposeLeakyReLUPass,
3233
DecomposeLinearPass,
@@ -141,6 +142,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
141142
self.add_pass(ConvertMmToBmmPass())
142143
self.add_pass(DecomposeLinearPass())
143144
self.add_pass(DecomposeLeakyReLUPass())
145+
self.add_pass(DecomposeGroupNormPass())
144146
self.add_pass(DecomposeLayerNormPass())
145147
self.add_pass(DecomposeVarPass())
146148
self.add_pass(
@@ -208,6 +210,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
208210
self.add_pass(DecomposeScaledDotProductAttention())
209211
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
210212
self.add_pass(ScalarsToAttributePass())
213+
self.add_pass(DecomposeGroupNormPass())
211214
self.add_pass(DecomposeLayerNormPass())
212215
self.add_pass(DecomposeVarPass())
213216
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
import operator
9+
10+
import torch
11+
from executorch.backends.arm._passes import ArmPass
12+
from executorch.backends.arm._passes.arm_pass_utils import create_node
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from executorch.exir.pass_base import PassResult
15+
16+
17+
def get_group_norm_decomposition(op) -> tuple:
18+
if op == exir_ops.edge.aten.native_group_norm.default:
19+
return (
20+
exir_ops.edge.aten.mean.dim,
21+
exir_ops.edge.aten.sub.Tensor,
22+
exir_ops.edge.aten.var.correction,
23+
exir_ops.edge.aten.full.default,
24+
exir_ops.edge.aten.add.Tensor,
25+
exir_ops.edge.aten.rsqrt.default,
26+
exir_ops.edge.aten.mul.Tensor,
27+
exir_ops.edge.aten.view_copy.default,
28+
)
29+
if op == torch.ops.aten.group_norm.default:
30+
return (
31+
torch.ops.aten.mean.dim,
32+
torch.ops.aten.sub.Tensor,
33+
torch.ops.aten.var.correction,
34+
torch.ops.aten.full.default,
35+
torch.ops.aten.add.Tensor,
36+
torch.ops.aten.rsqrt.default,
37+
torch.ops.aten.mul.Tensor,
38+
torch.ops.aten.view_copy.default,
39+
)
40+
raise RuntimeError(f"Can't get group_norm composition for op {op}")
41+
42+
43+
class DecomposeGroupNormPass(ArmPass):
44+
"""
45+
groupnorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
46+
Decompose groupnorm(x, weight, bias, N, C, HxW, group, eps) to a sequence of:
47+
mean = op_mean(x, dims) # E[x]
48+
var = op_var(x, dims) # Var[x]
49+
numerator = op_sub(x, mean) # (x - E[x])
50+
add = op_add(var, eps) # Var[x] + eps
51+
rsqrt = op_rsqrt(add) # 1 / sqrt(Var[x] + eps)
52+
mul = op_mul(numerator, rsqrt) # ((x - E[x]) / sqrt(Var[x] + eps))
53+
weigths = op_mul(mul, weigths) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths
54+
bias = op_add(weigths, bias) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths + bias
55+
where x can viewed with shape [N, group, C//group, HxW] dims=[C//group, HxW]
56+
57+
Source: https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html
58+
"""
59+
60+
def call(self, graph_module: torch.fx.GraphModule):
61+
modified = False
62+
for node in graph_module.graph.nodes:
63+
if node.op != "call_function" or node.target not in (
64+
exir_ops.edge.aten.native_group_norm.default,
65+
torch.ops.aten.group_norm.default,
66+
):
67+
continue
68+
69+
# epsilon default value
70+
eps = torch.finfo().eps
71+
weights = None
72+
bias = None
73+
args = node.args
74+
meta = node.meta
75+
if isinstance(meta["val"], tuple):
76+
shape = meta["val"][0].size()
77+
dtype = meta["val"][0].dtype
78+
else:
79+
shape = meta["val"].size()
80+
dtype = meta["val"].dtype
81+
match len(args):
82+
# MI profile always provides all the args: x, weight, bias, N, C, HxW, group, eps
83+
case 8:
84+
x, weights, bias, N, C, HxW, group, eps = args
85+
# BI profile: affine=[True|False], eps!=1e-5
86+
case 5:
87+
x, group, weights, bias, eps = args
88+
# BI profile: affine=True, eps=1e-5
89+
case 4:
90+
x, group, weights, bias = args
91+
# BI profile: affine=False, eps=1e=5
92+
case 2:
93+
x, group = args
94+
# Unsupported args
95+
case _:
96+
raise ValueError(
97+
f"Unsupported group_norm argument pattern with {len(args)} args"
98+
)
99+
N = shape[0]
100+
C = shape[1]
101+
HxW = 1
102+
for dim in shape[2:]:
103+
HxW *= dim
104+
channels_per_group = C // group
105+
grouped_shape = torch.Size([N, group, channels_per_group, HxW])
106+
dims = [2, 3]
107+
epsilon_reshaped_shape = torch.Size([1] * len(grouped_shape))
108+
weights_reshaped_shape = torch.Size([1, group, channels_per_group, 1])
109+
(
110+
mean_op,
111+
sub_op,
112+
var_op,
113+
full_op,
114+
add_op,
115+
rsqrt_op,
116+
mul_op,
117+
view_op,
118+
) = get_group_norm_decomposition(node.target)
119+
with graph_module.graph.inserting_before(node):
120+
keepdim = True
121+
x_reshaped = create_node(
122+
graph_module.graph,
123+
view_op,
124+
args=(x, grouped_shape),
125+
from_node=node,
126+
)
127+
mean = create_node(
128+
graph_module.graph, mean_op, args=(x_reshaped, dims, keepdim)
129+
)
130+
sub = create_node(graph_module.graph, sub_op, args=(x_reshaped, mean))
131+
var = create_node(
132+
graph_module.graph,
133+
var_op,
134+
args=(x_reshaped, dims),
135+
kwargs={"correction": 0, "keepdim": keepdim},
136+
from_node=node,
137+
)
138+
full = create_node(
139+
graph_module.graph,
140+
full_op,
141+
args=(epsilon_reshaped_shape, eps),
142+
kwargs={"dtype": dtype},
143+
from_node=node,
144+
)
145+
add0 = create_node(
146+
graph_module.graph, add_op, args=(var, full), from_node=node
147+
)
148+
rsqrt = create_node(
149+
graph_module.graph, rsqrt_op, args=(add0,), from_node=node
150+
)
151+
mul0 = create_node(
152+
graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node
153+
)
154+
if weights is not None:
155+
weights_reshaped = create_node(
156+
graph_module.graph,
157+
view_op,
158+
args=(weights, weights_reshaped_shape),
159+
from_node=node,
160+
)
161+
mul1 = create_node(
162+
graph_module.graph,
163+
mul_op,
164+
args=(
165+
mul0,
166+
weights_reshaped,
167+
),
168+
from_node=node,
169+
)
170+
else:
171+
mul1 = mul0
172+
if bias is not None:
173+
bias_reshaped_shape = weights_reshaped_shape
174+
bias_reshaped = create_node(
175+
graph_module.graph,
176+
view_op,
177+
args=(bias, bias_reshaped_shape),
178+
from_node=node,
179+
)
180+
output = create_node(
181+
graph_module.graph,
182+
add_op,
183+
args=(mul1, bias_reshaped),
184+
from_node=node,
185+
)
186+
else:
187+
output = mul1
188+
189+
output_reshaped = create_node(
190+
graph_module.graph,
191+
view_op,
192+
args=(output, shape),
193+
from_node=node,
194+
)
195+
196+
users = [user for user in node.users if node != user]
197+
node.replace_all_uses_with(output_reshaped)
198+
for user in users:
199+
if user.target == operator.getitem:
200+
user.replace_all_uses_with(output_reshaped)
201+
graph_module.graph.erase_node(node)
202+
graph_module.graph.eliminate_dead_code()
203+
modified = True
204+
if modified:
205+
graph_module.recompile()
206+
graph_module = super().call(graph_module).graph_module
207+
208+
return PassResult(graph_module, modified)

backends/arm/_passes/decompose_layernorm_pass.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
@@ -47,11 +46,12 @@ class DecomposeLayerNormPass(ArmPass):
4746
Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of:
4847
mean = op_mean(x, dims) # E[x]
4948
var = op_var(x, dims) # Var[x]
50-
denominator = op_sub(x, mean) # (x - E[x])
49+
numerator = op_sub(x, mean) # (x - E[x])
5150
add = op_add(var, eps) # Var[x] + eps
5251
rsqrt = op_rsqrt(add) # 1 / sqrt(Var[x] + eps)
53-
mul = op_mul(denominator, rsqrt) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths
54-
bias = op_add(mul, bias) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths + bias
52+
mul = op_mul(numerator, rsqrt) # ((x - E[x]) / sqrt(Var[x] + eps))
53+
weigths = op_mul(mul, weigths) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths
54+
bias = op_add(weigths, bias) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths + bias
5555
5656
Source: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
5757
"""

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def is_node_supported(
198198
exir_ops.edge.aten.div.Scalar,
199199
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
200200
exir_ops.edge.aten.native_layer_norm.default,
201+
exir_ops.edge.aten.native_group_norm.default,
201202
exir_ops.edge.aten.sigmoid.default,
202203
exir_ops.edge.aten.mean.dim,
203204
exir_ops.edge.aten.mm.default,
@@ -264,6 +265,7 @@ def is_node_supported(
264265
exir_ops.edge.aten.div.Tensor: None,
265266
exir_ops.edge.aten._native_batch_norm_legit_no_training.default: "BatchNorm2D with track_running_stats==True not immediately following a convolution is not supported for quantized TOSA backends.",
266267
exir_ops.edge.aten.native_layer_norm.default: None,
268+
exir_ops.edge.aten.native_group_norm.default: None,
267269
exir_ops.edge.aten._softmax.default: None,
268270
exir_ops.edge.aten._log_softmax.default: None,
269271
exir_ops.edge.aten.var.correction: None,

0 commit comments

Comments
 (0)