Skip to content

Commit ff5fe0f

Browse files
author
ssjia
committed
Update base for Update on "[ET-VK] Implementation of to_dim_order_copy"
Title says it all! Previously, to_dim_order_copy was handled by removing the op. However, this is not possible if the op is modifying the dtype of the original tensor, so these instances of the op would be skipped by the partitioner. This diff adds an implementation of dtype conversion, which allows to_dim_order_copy to be lowered. Differential Revision: [D86340341](https://our.internmc.facebook.com/intern/diff/D86340341/) [ghstack-poisoned]
2 parents 1f826ea + d361573 commit ff5fe0f

File tree

151 files changed

+3155
-2712
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

151 files changed

+3155
-2712
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
e6f766c7d750d40603eee3f66c5915bac606b3ea
1+
556fc09a9f67f24ca5591ec049c5d0c347c5f62a

.ci/scripts/test_qnn_static_llm.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ elif [[ "${TASK_NAME}" == "stories_260k_bc" ]]; then
8181
fi
8282

8383
elif [[ "${TASK_NAME}" == "smollm2_135m" ]]; then
84-
$PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_static_smollm2 --model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir ./static_smollm2 --enable_x86_64
84+
$PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_static_llm_model --model_name smollm2_135m --model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir ./static_smollm2 --enable_x86_64
8585
exit_code1=$?
8686
if [ $exit_code1 -ne 0 ]; then
8787
exit 1

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ jobs:
347347
elif [[ ${{ matrix.os}} == "zephyr-preset" ]]; then
348348
setup_script_args="--target-toolchain zephyr"
349349
toolchain_prefix=arm-zephyr-eabi-
350-
threshold="135240" # 132 KiB
350+
threshold="135656" # 132 KiB
351351
toolchain_cmake=examples/zephyr/x86_64-linux-arm-zephyr-eabi-gcc.cmake
352352
else
353353
echo "Fail unsupport OS selection ${{ matrix.os }}"

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
from .convert_int64_output_ops_to_int32 import ConvertInt64OutputOpsToInt32Pass # noqa
2222
from .convert_int_pow_to_mul import ConvertIntPowToMuls # noqa
2323
from .convert_minmax_pass import ConvertMinMaxPass # noqa
24+
from .convert_permute_singleton_to_view_pass import ( # noqa
25+
ConvertPermuteSingletonToViewPass,
26+
)
2427
from .convert_split_to_slice import ConvertSplitToSlicePass # noqa
2528
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
2629
from .convert_to_clamp import ConvertToClampPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ConvertIntPowToMuls,
2828
ConvertMinMaxPass,
2929
ConvertMmToBmmPass,
30+
ConvertPermuteSingletonToViewPass,
3031
ConvertSplitToSlicePass,
3132
ConvertSqueezesToViewPass,
3233
ConvertToClampPass,
@@ -234,6 +235,7 @@ def _tosa_pipeline(
234235
self.add_pass(CastToInt32Pass())
235236
self.add_pass(BroadcastArgsPass())
236237

238+
self.add_pass(ConvertPermuteSingletonToViewPass())
237239
self.add_pass(FuseViewCopyTransform())
238240
self.add_pass(FuseConstantArgsPass(exported_program))
239241
self.add_pass(DecomposeConv2dWithInt16ActivationPass())

backends/arm/_passes/convert_expand_copy_to_repeat.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121

2222
def calculate_multiples(args):
23+
"""Returns expand args converted to repeat args, and whether the expand changes the rank"""
2324
input_node_or_tensor = args[0]
2425

2526
if isinstance(input_node_or_tensor, torch.fx.node.Node):
@@ -45,7 +46,7 @@ def calculate_multiples(args):
4546
multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
4647
for i in range(expanded_rank)
4748
]
48-
return multiples
49+
return multiples, expanded_rank != len(input_shape)
4950

5051

5152
class ConvertExpandCopyToRepeatPass(ArmPass):
@@ -62,9 +63,9 @@ def call_operator(self, op, args, kwargs, meta):
6263
if op != self.expand_copy:
6364
return super().call_operator(op, args, kwargs, meta)
6465

65-
multiples = calculate_multiples(args)
66+
multiples, changes_rank = calculate_multiples(args)
6667

67-
if all((x == 1 for x in multiples)):
68+
if all((x == 1 for x in multiples)) and not changes_rank:
6869
# All dimensions/repetitions occur only once. Remove node
6970
# altogether since it's in practice just a copy.
7071
logger.warning("Found redundant expand node (no-op). Removing it.")
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from typing import Sequence, Set, Tuple, Type
8+
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
from executorch.exir.pass_base import ExportPass
11+
12+
from torch._ops import OpOverload
13+
14+
15+
_PERMUTE_TARGETS: Tuple[OpOverload, ...] = (
16+
exir_ops.edge.aten.permute.default,
17+
exir_ops.edge.aten.permute_copy.default,
18+
)
19+
20+
21+
class ConvertPermuteSingletonToViewPass(ExportPass):
    """Rewrite permutes that only relocate size-1 axes as ``view_copy`` ops.

    A permute/permute_copy call is replaced when ``is_singleton_permutation``
    accepts the input shape and the requested permutation; every other
    operator call is forwarded to the parent pass unchanged.

    Example:
        x = rand(1,1,1,4)
        y = permute(x, (0,3,1,2))

    is rewritten to:
        x = rand(1,1,1,4)
        y = view_copy(x, (1,4,1,1))
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    def call_operator(self, op, args, kwargs, meta):
        """Swap a qualifying permute for ``view_copy``; otherwise defer to the parent.

        The view target shape is taken from ``meta["val"].shape``.
        """
        if op in _PERMUTE_TARGETS:
            source, dims = args[0], args[1]
            if is_singleton_permutation(source.data.shape, dims):
                return super().call_operator(
                    exir_ops.edge.aten.view_copy.default,
                    (source, meta["val"].shape),
                    kwargs,
                    meta,
                )

        return super().call_operator(op, args, kwargs, meta)
49+
50+
51+
def is_singleton_permutation(shape: Sequence[int], permutation: Sequence[int]) -> bool:
    """Return True when ``permutation`` only relocates size-1 axes of ``shape``.

    The permutation is treated as a plain reshape only when the axes whose
    size is not 1 keep their relative order; size-1 axes may move freely.

    Args:
        shape: Sizes of the input axes.
        permutation: Axis order requested by the permute; negative indices
            are normalized modulo the rank.

    Returns:
        True if the non-singleton axes appear in the same order before and
        after the permutation. False otherwise, including when
        ``permutation`` does not have one entry per axis of ``shape``.
    """
    rank = len(shape)

    # A permutation must name every axis exactly once. Rejecting a length
    # mismatch keeps the caller on the regular permute path and also avoids
    # ``d % 0`` when the shape is rank 0 but the permutation is not empty.
    if len(permutation) != rank:
        return False

    normalized_perm = [d % rank for d in permutation]

    non_singleton_axes = [i for i, size in enumerate(shape) if size != 1]
    permuted_non_singleton_axes = [
        axis for axis in normalized_perm if shape[axis] != 1
    ]

    return permuted_non_singleton_axes == non_singleton_axes

backends/arm/_passes/decompose_embedding_pass.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from .arm_pass_utils import create_node, get_first_fake_tensor
1818

1919
logger = logging.getLogger(__name__)
20-
logger.setLevel(logging.WARNING)
2120

2221

2322
class DecomposeEmbeddingPass(ArmPass):

backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class InsertInt32CastsAfterInt64PlaceholdersPass(ArmPass):
3636
# Key: op overload; Value: zero-based indices of positional args that must be i64.
3737
I64_INPUT_ARG_POSITIONS = {
3838
torch.ops.aten.one_hot.default: (0,),
39+
torch.ops.aten.index_copy_.default: (2,),
40+
torch.ops.aten.index_copy.default: (2,),
3941
}
4042

4143
def _insert_callsite_i32_to_i64_casts(self, graph_module: torch.fx.GraphModule):

backends/arm/ethosu/backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _compile_tosa_flatbuffer(
6363
binary = vela_compile(
6464
tosa_flatbuffer,
6565
compile_flags,
66-
verbose=logger.getEffectiveLevel() == logging.INFO,
66+
verbose=logger.getEffectiveLevel() <= logging.INFO,
6767
intermediate_path=compile_spec.get_intermediate_path(),
6868
)
6969
return binary

0 commit comments

Comments
 (0)