Skip to content

Commit 6becf5d

Browse files
author
ssjia
committed
Update base for Update on "[ET-VK] buffer implementation of rotary positional embeddings"
Title says it all! Differential Revision: [D86340338](https://our.internmc.facebook.com/intern/diff/D86340338/) [ghstack-poisoned]
2 parents 31b4610 + d361573 commit 6becf5d

File tree

155 files changed

+3192
-2724
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

155 files changed

+3192
-2724
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
e6f766c7d750d40603eee3f66c5915bac606b3ea
1+
556fc09a9f67f24ca5591ec049c5d0c347c5f62a

.ci/scripts/test_qnn_static_llm.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ elif [[ "${TASK_NAME}" == "stories_260k_bc" ]]; then
8181
fi
8282

8383
elif [[ "${TASK_NAME}" == "smollm2_135m" ]]; then
84-
$PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_static_smollm2 --model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir ./static_smollm2 --enable_x86_64
84+
$PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_static_llm_model --model_name smollm2_135m --model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir ./static_smollm2 --enable_x86_64
8585
exit_code1=$?
8686
if [ $exit_code1 -ne 0 ]; then
8787
exit 1

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ jobs:
347347
elif [[ ${{ matrix.os}} == "zephyr-preset" ]]; then
348348
setup_script_args="--target-toolchain zephyr"
349349
toolchain_prefix=arm-zephyr-eabi-
350-
threshold="135240" # 132 KiB
350+
threshold="135656" # 132 KiB
351351
toolchain_cmake=examples/zephyr/x86_64-linux-arm-zephyr-eabi-gcc.cmake
352352
else
353353
echo "Fail unsupport OS selection ${{ matrix.os }}"

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
from .convert_int64_output_ops_to_int32 import ConvertInt64OutputOpsToInt32Pass # noqa
2222
from .convert_int_pow_to_mul import ConvertIntPowToMuls # noqa
2323
from .convert_minmax_pass import ConvertMinMaxPass # noqa
24+
from .convert_permute_singleton_to_view_pass import ( # noqa
25+
ConvertPermuteSingletonToViewPass,
26+
)
2427
from .convert_split_to_slice import ConvertSplitToSlicePass # noqa
2528
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
2629
from .convert_to_clamp import ConvertToClampPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ConvertIntPowToMuls,
2828
ConvertMinMaxPass,
2929
ConvertMmToBmmPass,
30+
ConvertPermuteSingletonToViewPass,
3031
ConvertSplitToSlicePass,
3132
ConvertSqueezesToViewPass,
3233
ConvertToClampPass,
@@ -234,6 +235,7 @@ def _tosa_pipeline(
234235
self.add_pass(CastToInt32Pass())
235236
self.add_pass(BroadcastArgsPass())
236237

238+
self.add_pass(ConvertPermuteSingletonToViewPass())
237239
self.add_pass(FuseViewCopyTransform())
238240
self.add_pass(FuseConstantArgsPass(exported_program))
239241
self.add_pass(DecomposeConv2dWithInt16ActivationPass())

backends/arm/_passes/convert_expand_copy_to_repeat.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121

2222
def calculate_multiples(args):
23+
"""Returns expand args converted to repeat args, and whether the expand changes the rank"""
2324
input_node_or_tensor = args[0]
2425

2526
if isinstance(input_node_or_tensor, torch.fx.node.Node):
@@ -45,7 +46,7 @@ def calculate_multiples(args):
4546
multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
4647
for i in range(expanded_rank)
4748
]
48-
return multiples
49+
return multiples, expanded_rank != len(input_shape)
4950

5051

5152
class ConvertExpandCopyToRepeatPass(ArmPass):
@@ -62,9 +63,9 @@ def call_operator(self, op, args, kwargs, meta):
6263
if op != self.expand_copy:
6364
return super().call_operator(op, args, kwargs, meta)
6465

65-
multiples = calculate_multiples(args)
66+
multiples, changes_rank = calculate_multiples(args)
6667

67-
if all((x == 1 for x in multiples)):
68+
if all((x == 1 for x in multiples)) and not changes_rank:
6869
# All dimensions/repetitions occur only once. Remove node
6970
# altogether since it's in practice just a copy.
7071
logger.warning("Found redundant expand node (no-op). Removing it.")
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from typing import Sequence, Set, Tuple, Type
8+
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
from executorch.exir.pass_base import ExportPass
11+
12+
from torch._ops import OpOverload
13+
14+
15+
_PERMUTE_TARGETS: Tuple[OpOverload, ...] = (
16+
exir_ops.edge.aten.permute.default,
17+
exir_ops.edge.aten.permute_copy.default,
18+
)
19+
20+
21+
class ConvertPermuteSingletonToViewPass(ExportPass):
22+
"""Replace permutations that only move singleton axes with a reshape.
23+
24+
Examples:
25+
x = rand(1,1,1,4)
26+
y = permute(x, (0,3,1,2))
27+
28+
becomes:
29+
x = rand(1,1,1,4)
30+
y = view_copy(x, (1,4,1,1))
31+
"""
32+
33+
_passes_required_after: Set[Type[ExportPass]] = set()
34+
35+
def call_operator(self, op, args, kwargs, meta):
36+
if op not in _PERMUTE_TARGETS:
37+
return super().call_operator(op, args, kwargs, meta)
38+
39+
input_tensor = args[0].data
40+
permutation = args[1]
41+
if not is_singleton_permutation(input_tensor.shape, permutation):
42+
return super().call_operator(op, args, kwargs, meta)
43+
44+
output_shape = meta["val"].shape
45+
view_args = (args[0], output_shape)
46+
return super().call_operator(
47+
exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta
48+
)
49+
50+
51+
def is_singleton_permutation(shape: Sequence[int], permutation: Sequence[int]) -> bool:
52+
"""
53+
Treat as a view only when non-singleton axes keep their order; singleton
54+
axes may move freely since they carry no data volume.
55+
"""
56+
rank = len(shape)
57+
normalized_perm = [d % rank for d in permutation]
58+
59+
non_singleton_axes = [i for i, size in enumerate(shape) if size != 1]
60+
permuted_non_singleton_axes = [axis for axis in normalized_perm if shape[axis] != 1]
61+
62+
return permuted_non_singleton_axes == non_singleton_axes

backends/arm/_passes/decompose_embedding_pass.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from .arm_pass_utils import create_node, get_first_fake_tensor
1818

1919
logger = logging.getLogger(__name__)
20-
logger.setLevel(logging.WARNING)
2120

2221

2322
class DecomposeEmbeddingPass(ArmPass):

backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class InsertInt32CastsAfterInt64PlaceholdersPass(ArmPass):
3636
# Key: op overload; Value: zero-based indices of positional args that must be i64.
3737
I64_INPUT_ARG_POSITIONS = {
3838
torch.ops.aten.one_hot.default: (0,),
39+
torch.ops.aten.index_copy_.default: (2,),
40+
torch.ops.aten.index_copy.default: (2,),
3941
}
4042

4143
def _insert_callsite_i32_to_i64_casts(self, graph_module: torch.fx.GraphModule):

backends/arm/ethosu/backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _compile_tosa_flatbuffer(
6363
binary = vela_compile(
6464
tosa_flatbuffer,
6565
compile_flags,
66-
verbose=logger.getEffectiveLevel() == logging.INFO,
66+
verbose=logger.getEffectiveLevel() <= logging.INFO,
6767
intermediate_path=compile_spec.get_intermediate_path(),
6868
)
6969
return binary

0 commit comments

Comments
 (0)