Skip to content

Commit 7198aa0

Browse files
authored
Merge branch 'main' into initial_cond
2 parents c709e67 + 4ea9ddf commit 7198aa0

File tree

116 files changed

+3221
-1596
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

116 files changed

+3221
-1596
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
e6f766c7d750d40603eee3f66c5915bac606b3ea
1+
556fc09a9f67f24ca5591ec049c5d0c347c5f62a

.ci/scripts/test_qnn_static_llm.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ elif [[ "${TASK_NAME}" == "stories_260k_bc" ]]; then
8181
fi
8282

8383
elif [[ "${TASK_NAME}" == "smollm2_135m" ]]; then
84-
$PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_static_smollm2 --model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir ./static_smollm2 --enable_x86_64
84+
$PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_static_llm_model --model_name smollm2_135m --model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir ./static_smollm2 --enable_x86_64
8585
exit_code1=$?
8686
if [ $exit_code1 -ne 0 ]; then
8787
exit 1

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ jobs:
347347
elif [[ ${{ matrix.os}} == "zephyr-preset" ]]; then
348348
setup_script_args="--target-toolchain zephyr"
349349
toolchain_prefix=arm-zephyr-eabi-
350-
threshold="135240" # 132 KiB
350+
threshold="135656" # 132 KiB
351351
toolchain_cmake=examples/zephyr/x86_64-linux-arm-zephyr-eabi-gcc.cmake
352352
else
353353
echo "Fail unsupport OS selection ${{ matrix.os }}"

backends/arm/_passes/convert_expand_copy_to_repeat.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121

2222
def calculate_multiples(args):
23+
"""Returns expand args converted to repeat args, and whether the expand changes the rank"""
2324
input_node_or_tensor = args[0]
2425

2526
if isinstance(input_node_or_tensor, torch.fx.node.Node):
@@ -45,7 +46,7 @@ def calculate_multiples(args):
4546
multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
4647
for i in range(expanded_rank)
4748
]
48-
return multiples
49+
return multiples, expanded_rank != len(input_shape)
4950

5051

5152
class ConvertExpandCopyToRepeatPass(ArmPass):
@@ -62,9 +63,9 @@ def call_operator(self, op, args, kwargs, meta):
6263
if op != self.expand_copy:
6364
return super().call_operator(op, args, kwargs, meta)
6465

65-
multiples = calculate_multiples(args)
66+
multiples, changes_rank = calculate_multiples(args)
6667

67-
if all((x == 1 for x in multiples)):
68+
if all((x == 1 for x in multiples)) and not changes_rank:
6869
# All dimensions/repetitions occur only once. Remove node
6970
# altogether since it's in practice just a copy.
7071
logger.warning("Found redundant expand node (no-op). Removing it.")

backends/arm/_passes/decompose_embedding_pass.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from .arm_pass_utils import create_node, get_first_fake_tensor
1818

1919
logger = logging.getLogger(__name__)
20-
logger.setLevel(logging.WARNING)
2120

2221

2322
class DecomposeEmbeddingPass(ArmPass):

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass
1414
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
1515
from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass
16+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
1617
from executorch.exir.backend.utils import WhyNoPartitionReporter
1718
from executorch.exir.dialects._ops import ops as exir_ops
1819
from executorch.exir.pass_base import ExportPass
@@ -50,6 +51,15 @@ def get_view(op):
5051
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
5152

5253

54+
def get_quantization(op):
55+
"""Returns quant and dequant op of same type (per_channel/ tensor) as op if op is a dequant node, None otherwise."""
56+
if op in DQ_OPS:
57+
# Input of op can be placeholder, can't use that to get quant node directly.
58+
quant_type_index = DQ_OPS.index(op)
59+
return Q_OPS[quant_type_index], op
60+
return None
61+
62+
5363
class DecomposeMeanDimPass(ArmPass):
5464
"""
5565
Decomposes a meandim into avg_pool and/or sum + mul (1/N) depending on which dims the mean is taken for:
@@ -121,6 +131,7 @@ def call_operator(self, op, args, kwargs, meta):
121131
dims_to_reduce = [dim - 1 for dim in dims_to_reduce]
122132

123133
x = super().call_operator(view_op, (x, new_shape), {}, meta, True)
134+
x = self._maybe_insert_q_dq_after(x, meta)
124135

125136
# Reduce (h,w) dims by avg pool if possible
126137
x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta)
@@ -133,7 +144,7 @@ def call_operator(self, op, args, kwargs, meta):
133144
dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce]
134145

135146
x = super().call_operator(view_op, (x, temp_shape), {}, meta, True)
136-
147+
x = self._maybe_insert_q_dq_after(x, meta)
137148
# Reduce remaining dims by sum
138149
x = self._reduce_by_sum(op, x, dims_to_reduce, meta, dtype)
139150

@@ -156,6 +167,45 @@ def _reduce_by_sum(self, op, input_node, dims, meta, dtype):
156167
full = super().call_operator(
157168
full_op, ([1] * len(output_shape), 1 / N), {"dtype": dtype}, meta, True
158169
)
170+
if (quant_ops := get_quantization(input_node.node.target)) is not None:
171+
# Insert Q and DQ nodes after full op.
172+
# Since the value of full is known, we can compute quant params such that dq(q_max_value)
173+
q_op, dq_op = quant_ops
174+
qmax = input_node.node.args[4]
175+
full_quant_args = (
176+
1 / (N * qmax), # Scale to map qmax to 1/N
177+
0, # Zero point
178+
*input_node.node.args[3:],
179+
)
180+
q_args = (full, *full_quant_args)
181+
full = super().call_operator(
182+
q_op,
183+
q_args,
184+
kwargs={},
185+
meta=meta,
186+
updated=True,
187+
)
188+
dq_args = (full, *full_quant_args)
189+
full = super().call_operator(
190+
dq_op, dq_args, kwargs={}, meta=meta, updated=True
191+
)
192+
193+
# Insert Q and DQ nodes after sum op.
194+
# Scale needs to be adjusted with N, since it was computed on data after the division with N.
195+
sum_quant_args = (input_node.node.args[1] * N, *input_node.node.args[2:])
196+
q_args = (sum, *sum_quant_args)
197+
sum = super().call_operator(
198+
q_op,
199+
q_args,
200+
kwargs={},
201+
meta=meta,
202+
updated=True,
203+
)
204+
dq_args = (sum, *sum_quant_args)
205+
sum = super().call_operator(
206+
dq_op, dq_args, kwargs={}, meta=meta, updated=True
207+
)
208+
159209
return super().call_operator(mul_op, (sum, full), {}, meta, True)
160210

161211
def _reduce_by_average_pool(self, op, input_node, dims, meta):
@@ -190,10 +240,38 @@ def _reduce_by_average_pool(self, op, input_node, dims, meta):
190240
)
191241

192242
if is_supported:
243+
out = super().call_operator(avgpool_op, args, {}, meta, True)
244+
out = self._maybe_insert_q_dq_after(out, meta)
193245
return (
194-
super().call_operator(avgpool_op, args, {}, meta, True),
246+
out,
195247
dims_to_reduce_by_sum,
196248
)
197249

198250
else:
199251
return input_node, dims
252+
253+
def _maybe_insert_q_dq_after(self, op, meta):
254+
"""If the input node of op is a dequant node, insert a q-dq pair after op with identical quantization parameters."""
255+
256+
if len(op.node.all_input_nodes) > 1:
257+
raise ValueError(
258+
f"Expected one input to {op.node}, got inputs {op.node.all_input_nodes}"
259+
)
260+
input_node = op.node.all_input_nodes[0]
261+
if (quant_ops := get_quantization(input_node.target)) is not None:
262+
q_op, dq_op = quant_ops
263+
quant_args = list(input_node.args[1:])
264+
q_args = (op, *quant_args)
265+
out = super().call_operator(
266+
q_op,
267+
q_args,
268+
kwargs={},
269+
meta=meta,
270+
updated=True,
271+
)
272+
dq_args = (out, *quant_args)
273+
return super().call_operator(
274+
dq_op, dq_args, kwargs={}, meta=meta, updated=True
275+
)
276+
else:
277+
return op

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ def resolve_arg(arg):
6565
if isinstance(arg, torch.fx.Node) and arg in input_nodes:
6666
idx = input_nodes.index(arg)
6767
t = get_param_tensor(self.exported_program, arg)
68-
if qparams:
68+
# Check if qparams exist for this arg
69+
if qparams and idx in qparams.keys():
6970
t = qparams[idx].dequantize_value(t)
7071
return t
7172
if isinstance(arg, tuple):

backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class InsertInt32CastsAfterInt64PlaceholdersPass(ArmPass):
3636
# Key: op overload; Value: zero-based indices of positional args that must be i64.
3737
I64_INPUT_ARG_POSITIONS = {
3838
torch.ops.aten.one_hot.default: (0,),
39+
torch.ops.aten.index_copy_.default: (2,),
40+
torch.ops.aten.index_copy.default: (2,),
3941
}
4042

4143
def _insert_callsite_i32_to_i64_casts(self, graph_module: torch.fx.GraphModule):

backends/arm/ethosu/backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _compile_tosa_flatbuffer(
6363
binary = vela_compile(
6464
tosa_flatbuffer,
6565
compile_flags,
66-
verbose=logger.getEffectiveLevel() == logging.INFO,
66+
verbose=logger.getEffectiveLevel() <= logging.INFO,
6767
intermediate_path=compile_spec.get_intermediate_path(),
6868
)
6969
return binary

backends/arm/operator_support/minmax_support.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
"""Declare operator support for min/max along a dimension in TOSA.
6+
7+
Provide support checks ensuring that argmax/argmin indices are not consumed,
8+
restricting to float profiles until index quantization is supported.
9+
10+
"""
511

612
import torch.fx as fx
713
from executorch.backends.arm.operator_support.tosa_supported_operators import (
@@ -14,6 +20,8 @@
1420

1521
@register_tosa_support_check
1622
class MinMaxSupported(SupportedTOSAOperatorCheck):
23+
"""Provide TOSA support check for ``aten.max.dim`` and ``aten.min.dim``."""
24+
1725
targets = [
1826
exir_ops.edge.aten.max.dim,
1927
exir_ops.edge.aten.min.dim,
@@ -24,7 +32,16 @@ class MinMaxSupported(SupportedTOSAOperatorCheck):
2432
TosaSpecification.create_from_string("TOSA-1.0+FP"),
2533
]
2634

27-
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
35+
def is_node_tosa_supported(
36+
self, node: fx.Node, tosa_spec: TosaSpecification
37+
) -> bool:
38+
"""Return True if the node is supported by TOSA.
39+
40+
Allow max/min when the argmax/argmin output is unused or dropped (i.e.,
41+
only the value is consumed). Disallow cases where arg indices are
42+
further used.
43+
44+
"""
2845
if node.target in [exir_ops.edge.aten.max.dim, exir_ops.edge.aten.min.dim]:
2946
no_argmax = len(node.users) == 1
3047
no_argmax_users = (len(node.users) == 2) and (

0 commit comments

Comments
 (0)