Commit 618ade4

Merge branch 'main' into export-D82995994
2 parents: 8e2b8ed + 16ced4e


57 files changed: +1089 −136 lines

.ci/scripts/test_llama.sh

Lines changed: 1 addition & 0 deletions
@@ -159,6 +159,7 @@ cmake_install_executorch_libraries() {
         -DCMAKE_INSTALL_PREFIX=cmake-out \
         -DCMAKE_BUILD_TYPE="$CMAKE_BUILD_TYPE" \
         -DEXECUTORCH_BUILD_QNN="$QNN" \
+        -DEXECUTORCH_ENABLE_LOGGING=ON \
         -DQNN_SDK_ROOT="$QNN_SDK_ROOT"
     cmake --build cmake-out -j9 --target install --config "$CMAKE_BUILD_TYPE"
 }

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -46,6 +46,9 @@
 from .decompose_glu_pass import DecomposeGluPass  # noqa
 from .decompose_grouped_conv import DecomposeGroupedConv  # noqa
 from .decompose_groupnorm_pass import DecomposeGroupNormPass  # noqa
+from .decompose_int16_activation_conv2d_pass import (  # noqa
+    DecomposeConv2dWithInt16ActivationPass,
+)
 from .decompose_layernorm_pass import DecomposeLayerNormPass  # noqa
 from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass  # noqa
 from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass  # noqa

backends/arm/_passes/add_bias_pass.py

Lines changed: 5 additions & 0 deletions
@@ -8,6 +8,7 @@
 import torch
 from executorch.backends.arm._passes import ArmPass
 from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
+from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
 from executorch.backends.transforms.utils import create_constant_placeholder
 
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -59,6 +60,10 @@ def call(self, graph_module):
                     persistent_buffer=True,
                     name=f"{node.name}_bias",
                 )
+                if node.args[0].meta["val"].dtype == torch.int16:
+                    bias_node.meta[TosaSpecialDtype.meta_key()] = (
+                        TosaSpecialDtype.INT48
+                    )
                 node.update_arg(2, bias_node)
 
         if modified:
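
The tag added above marks the newly created bias placeholder so that later serialization can widen it to int48 when the activation is int16 (TOSA requires an int48 bias in that case). A minimal, self-contained sketch of the same tagging pattern, using hypothetical stand-in names rather than the real TosaSpecialDtype and FX node APIs:

from enum import Enum

class SpecialDtype(Enum):
    # Hypothetical stand-in for executorch.backends.arm.tosa.mapping.TosaSpecialDtype
    INT48 = "int48"

    @staticmethod
    def meta_key() -> str:
        return "tosa_special_dtype"

def tag_bias_meta(activation_dtype: str, bias_meta: dict) -> dict:
    # int16 activations use an int48 bias in TOSA, so the bias constant is
    # tagged for special handling instead of the default int32.
    if activation_dtype == "int16":
        bias_meta[SpecialDtype.meta_key()] = SpecialDtype.INT48
    return bias_meta

print(tag_bias_meta("int16", {}))  # {'tosa_special_dtype': <SpecialDtype.INT48: 'int48'>}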

backends/arm/_passes/arm_pass_manager.py

Lines changed: 8 additions & 1 deletion
@@ -42,6 +42,7 @@
     DecomposeAtanPass,
     DecomposeAvgPool2d,
     DecomposeBatchNormNoStatsPass,
+    DecomposeConv2dWithInt16ActivationPass,
     DecomposeCoshPass,
     DecomposeCosineSimilarityPass,
     DecomposeCumsumPass,
@@ -183,6 +184,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ComputeConstantOpsAOT(exported_program))
 
         self.add_pass(DecomposeGroupedConv())
+
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
@@ -196,9 +198,14 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
 
         self.add_pass(FuseViewCopyTransform())
         self.add_pass(FuseConstantArgsPass(exported_program))
+        self.add_pass(InsertTableOpsPass(exported_program))
+        # If we have a conv2d with int16 activation split up into a convolution
+        # and an addition, to work-around the lack of support for int48 in torch
+        # needs to happen before AddBiasPass, but after the table ops are inserted
+        # to be able to validate that conv2d has right dtype arguments.
+        self.add_pass(DecomposeConv2dWithInt16ActivationPass())
         self.add_pass(AddBiasPass(exported_program))
 
-        self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(FuseEqualPlaceholdersPass(exported_program))
         self.add_pass(ToTosaMemoryFormatPass(exported_program))
         self.add_pass(RemoveNoopPass())
backends/arm/_passes/decompose_int16_activation_conv2d_pass.py

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import cast
+
+import torch
+from executorch.backends.arm._passes.quant_args import QuantArgs
+
+from executorch.backends.arm.tosa.specification import get_context_spec, Tosa_1_00
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+
+class DecomposeConv2dWithInt16ActivationPass(ExportPass):
+    """
+    This pass decomposes a convolution with input dtype int16 and bias
+    into a convolution without bias followed by an addition of the bias
+    since the TOSA op requires the bias to be int48 which is hard to represent
+    in torch. Instead rescale the int48 output to int16 and add the bias in int16.
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op != exir_ops.edge.aten.convolution.default:
+            return super().call_operator(op, args, kwargs, meta)
+
+        tosa_spec = get_context_spec()
+        if not tosa_spec.support_integer():
+            return super().call_operator(op, args, kwargs, meta)
+
+        # return if no bias
+        if args[2] is None:
+            return super().call_operator(op, args, kwargs, meta)
+
+        if args[0].data.dtype == torch.int8:
+            return super().call_operator(op, args, kwargs, meta)
+        elif args[0].data.dtype == torch.int16:
+            if isinstance(tosa_spec, Tosa_1_00) and not tosa_spec.support_extension(
+                "int16"
+            ):
+                raise ValueError(
+                    "int16 activation for convolution requires TOSA int16 extension"
+                )
+        else:
+            raise NotImplementedError(
+                "Decomposition to conv+add only implemented for activation of int16 type"
+            )
+
+        # convolution with bias and activation is int16
+        # The bias is assumed to be quantized with the same quantization parameters
+        # as the output of the convolution
+        bias = args[2]
+        assert (
+            meta.data["output_qparams"][0].dtype == bias.data.dtype
+        ), "Bias needs to have same type as quantized output type"
+        no_bias_args = list(args)
+        no_bias_args[2] = None
+        # split up to convolution + bias
+        convolution = super().call_operator(op, tuple(no_bias_args), kwargs, meta)
+
+        # create a copy of the meta without the qparams, to be used with the new nodes
+        new_meta = meta.copy()
+        new_meta.data.pop("output_qparams", None)
+        new_meta.data.pop("input_qparams", None)
+
+        # reshape the tensor to the same rank as the convolution output to add the bias to the channels
+        channel_bias = super().call_operator(
+            exir_ops.edge.aten.view_copy.default,
+            (bias, [1, len(bias.data), 1, 1]),
+            {},
+            new_meta,
+        )
+
+        output_dtype = meta.data["output_qparams"][0].dtype
+
+        if output_dtype == torch.int16:
+            # The conv will get the output int48 scaled to int32 in serialization step.
+            # To be able to add the bias we need to first scale (cast?) the output to int32.
+            # The resulting i32 sum will then need to be scaled back to the output dtype.
+
+            # calculate common rescale factor from convolution output and bias quantization
+            output_qparams = cast(QuantArgs, meta.data["output_qparams"][0])
+            conv_output_scale = output_qparams.scale
+            bias_qparams = cast(QuantArgs, meta.data["input_qparams"][2])
+            bias_scale = bias_qparams.scale
+
+            common_scale = max(bias_scale, conv_output_scale)
+
+            # calculate how we can rescale bias and conv to a common scale and maximize the output range
+            bias_rescale_factor = bias_scale / common_scale
+            conv_rescale_factor = conv_output_scale / common_scale
+
+            # Either of conv output or bias now covers the full int16 range and the other one a smaller range.
+            # Since we are upscaling to int32 we have 16 additional bits to work with to maximize the output range.
+            # Worst case here is that both bias and conv output covers the full int16 range so we leave one bit
+            # and then one for the sign bit.
+            bits_left_to_shift = 14
+
+            # update rescale factors
+            bias_rescale_factor *= 1 << bits_left_to_shift
+            conv_rescale_factor *= 1 << bits_left_to_shift
+
+            conv_output = super().call_operator(
+                exir_ops.backend.tosa.RESCALE.default,
+                (convolution, torch.int32, conv_rescale_factor, 0, 0),
+                {},
+                new_meta,
+            )
+
+            bias_rescaled = super().call_operator(
+                exir_ops.backend.tosa.RESCALE.default,
+                (channel_bias, torch.int32, bias_rescale_factor, 0, 0),
+                {},
+                new_meta,
+            )
+
+            add = super().call_operator(
+                exir_ops.edge.aten.add.Tensor,
+                (conv_output, bias_rescaled),
+                {},
+                new_meta,
+            )
+
+            res_rescale = super().call_operator(
+                exir_ops.backend.tosa.RESCALE.default,
+                (
+                    add,
+                    output_dtype,
+                    (common_scale / (conv_output_scale * (1 << bits_left_to_shift))),
+                    0,
+                    0,
+                ),
+                {},
+                new_meta,
+            )
+
+        else:
+            raise NotImplementedError(
+                f"Decomposition to conv+add only implemented for activation of int16 type, not for {output_dtype}"
+            )
+
+        return res_rescale
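
A small worked example (with made-up quantization scales) of the rescale arithmetic above: the conv output and the bias are first brought to a shared scale with 14 bits of int32 headroom, added, and then rescaled back so the sum is expressed at the convolution's output scale:

# Worked example of the rescale factors computed by the pass above.
# The scales below are made up for illustration.
conv_output_scale = 0.02           # quantization scale of the conv output (int16)
bias_scale = 0.005                 # quantization scale of the bias (int16)
bits_left_to_shift = 14            # headroom used when moving to int32

common_scale = max(bias_scale, conv_output_scale)                                   # 0.02
bias_rescale_factor = bias_scale / common_scale * (1 << bits_left_to_shift)         # 4096.0
conv_rescale_factor = conv_output_scale / common_scale * (1 << bits_left_to_shift)  # 16384.0
back_factor = common_scale / (conv_output_scale * (1 << bits_left_to_shift))        # 1 / 16384

q_conv, q_bias = 1200, 80          # example quantized conv output and bias values
acc_i32 = q_conv * conv_rescale_factor + q_bias * bias_rescale_factor
q_out = acc_i32 * back_factor      # 1220.0, expressed in units of the output scale

# Dequantized result equals conv output plus bias: 1220 * 0.02 == 1200 * 0.02 + 80 * 0.005
assert abs(q_out * conv_output_scale - (q_conv * conv_output_scale + q_bias * bias_scale)) < 1e-9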

backends/arm/_passes/fuse_equal_placeholders_pass.py

Lines changed: 7 additions & 0 deletions
@@ -8,11 +8,13 @@
 from typing import Set, Type
 
 import torch
+
 from executorch.backends.arm._passes.arm_pass_utils import (
     get_constant_placeholder_kind,
     get_param_tensor,
     is_param_node,
 )
+from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
 from executorch.backends.transforms.utils import (
     create_constant_placeholder,
     delete_constant_placeholder,
@@ -47,9 +49,14 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                 continue
             # Create a lightweight fingerprint: dtype + shape + SHA1 of raw bytes
             # Ensure tensor is on CPU and contiguous
+
+            # ensure we don't merge any special case int48_t tensors with int32_t tensors
+            # since int48_t tensors needs to be instantiated separately.
+            is_int48 = node.meta.get(TosaSpecialDtype.meta_key(), None)
             t_cpu = tensor.detach().cpu().contiguous()
             data_bytes = t_cpu.numpy().tobytes()
             key = (
+                is_int48,
                 str(t_cpu.dtype),
                 tuple(t_cpu.shape),
                 hashlib.sha1(data_bytes).hexdigest(),
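
A minimal sketch (hypothetical tensors, real torch and hashlib calls) of why the int48 tag is included in the fingerprint: two byte-identical constants only fuse when their tag matches as well, so an int48-tagged bias is never merged with an ordinary int32 constant:

import hashlib

import torch

def fingerprint(t: torch.Tensor, special_dtype) -> tuple:
    # Mirrors the key built above: special-dtype tag + dtype + shape + SHA1 of raw bytes.
    t_cpu = t.detach().cpu().contiguous()
    return (
        special_dtype,
        str(t_cpu.dtype),
        tuple(t_cpu.shape),
        hashlib.sha1(t_cpu.numpy().tobytes()).hexdigest(),
    )

a = torch.tensor([1, 2, 3], dtype=torch.int32)
b = torch.tensor([1, 2, 3], dtype=torch.int32)

assert fingerprint(a, None) == fingerprint(b, None)     # identical constants can fuse
assert fingerprint(a, "INT48") != fingerprint(b, None)  # an int48-tagged bias stays separate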

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -19,4 +19,5 @@
     slice_copy_support,
     to_dim_order_copy_support,
     tosa_supported_operators,
+    where_support,
 )

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 1 addition & 1 deletion
@@ -128,7 +128,7 @@ class EthosU55NotSupported(OperatorSupportBase):
         exir_ops.edge.aten.bitwise_and.Scalar,
         exir_ops.edge.aten.bitwise_or.Scalar,
         exir_ops.edge.aten.bitwise_xor.Scalar,
-        exir_ops.edge.aten.bitwise_not,
+        exir_ops.edge.aten.bitwise_not.default,
         exir_ops.edge.aten.logical_and.default,
         exir_ops.edge.aten.logical_or.default,
         exir_ops.edge.aten.logical_xor.default,

backends/arm/operator_support/index_tensor_support.py

Lines changed: 17 additions & 1 deletion
@@ -2,6 +2,12 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+"""Provide TOSA support checks for ``aten.index.Tensor``.
+
+Reject unsupported patterns such as high-rank index tensors, front-positioned
+slice/ellipsis/None markers, and cases that exceed ``int32`` element limits.
+
+"""
 
 import math
 
@@ -18,7 +24,8 @@
 
 @register_tosa_support_check
 class IndexTensorSupported(SupportedTOSAOperatorCheck):
-    """
+    """Prevent partitioning of unsupported ``index.Tensor`` usages.
+
     This support check is intended to prevent the partitioning of
     currently unsupported usages of the index.Tensor operator.
 
@@ -95,6 +102,7 @@ class IndexTensorSupported(SupportedTOSAOperatorCheck):
            t[1:3, torch.arange(5), 2:3, torch.arange(3).reshape(3,1)]
        are also possible and can result in some unintuitive behaviors
        where batching and indexing are mixed together.
+
    """
 
    targets = [exir_ops.edge.aten.index.Tensor]
@@ -107,6 +115,14 @@ class IndexTensorSupported(SupportedTOSAOperatorCheck):
     def is_node_tosa_supported(
         self, node: fx.Node, tosa_spec: TosaSpecification
     ) -> bool:  # type: ignore[override, misc]
+        """Return True if ``aten.index.Tensor`` usage fits supported patterns.
+
+        Enforces the following constraints:
+        - No ``None`` (unsqueeze), slice, or ellipsis before an indexing tensor.
+        - Indexing tensors have rank <= 3.
+        - The value tensor element count fits in ``int32``.
+
+        """
         indices = node.args[1]
         for index in indices:  # type: ignore[union-attr]
             # Usage 2 guard
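
Illustrative (made-up) indexing patterns for the constraints listed in the new docstrings; each expression is valid PyTorch, and the comments state whether this support check would allow the node to be partitioned:

import torch

t = torch.rand(8, 8, 8)

t[torch.arange(5)]        # index tensor comes first, rank 1: candidate for partitioning
t[..., torch.arange(5)]   # ellipsis before the index tensor: rejected
t[None, torch.arange(5)]  # None (unsqueeze) before the index tensor: rejected
t[1:3, torch.arange(5)]   # slice before the index tensor: rejected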

backends/arm/operator_support/tosa_profile_supported_op_lists.py

Lines changed: 1 addition & 2 deletions
@@ -104,7 +104,6 @@
     exir_ops.edge.aten.squeeze_copy.dims,
     exir_ops.edge.aten.pow.Tensor_Scalar,
     exir_ops.edge.aten.pow.Tensor_Tensor,
-    exir_ops.edge.aten.where.self,
     operator.getitem,
     exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
     exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
@@ -136,6 +135,7 @@
     exir_ops.edge.aten.logit.default,
     exir_ops.edge.aten.acos.default,
     exir_ops.edge.aten.elu.default,
+    exir_ops.edge.aten.bitwise_not.default,
 }
 
 
@@ -220,7 +220,6 @@
     exir_ops.edge.aten.squeeze_copy.dims,
     exir_ops.edge.aten.pow.Tensor_Scalar,
     exir_ops.edge.aten.pow.Tensor_Tensor,
-    exir_ops.edge.aten.where.self,
     operator.getitem,
     exir_ops.edge.aten.constant_pad_nd.default,
     exir_ops.edge.aten.amax.default,
