Skip to content

Commit ce8c770

Browse files
committed
Update on "Reuse GELU implementation from PyTorch core"
kernels/optimized doesn't need to support embedded systems, so it can just take a header-only dep on PyTorch. Note that, because we will pick up Sleef internally and ignore it externally thanks to ATen vec, this PR gets to enable optimized GELU in OSS. Testing: CI to make sure this doesn't break mobile build modes; happy to take advice on anything not currently covered that might break. Differential Revision: [D66335522](https://our.internmc.facebook.com/intern/diff/D66335522/) [ghstack-poisoned]
2 parents 732b825 + 283798c commit ce8c770

File tree

137 files changed

+2511
-1177
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

137 files changed

+2511
-1177
lines changed

.ci/docker/requirements-ci.txt

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
11
mpmath==1.3.0
2-
numpy==1.21.3; python_version == '3.10'
3-
numpy==1.23.2; python_version == '3.11'
4-
numpy; python_version >= '3.12'
2+
numpy==2.0.0; python_version >= '3.10'
53
PyYAML==6.0.1
64
ruamel.yaml==0.17.32
75
sympy==1.12
86
timm==0.6.13
97
tomli==2.0.1
108
torchsr==1.0.4
11-
transformers==4.38.0
9+
transformers==4.47.1
1210
zstd==1.5.5.1
13-
pandas==2.0.3; python_version == '3.10'
14-
pandas; python_version >= '3.11'
11+
pandas==2.2.2; python_version >= '3.10'
1512
pytest==7.2.0
1613
pytest-cov==4.1.0
1714
expecttest==0.1.6
@@ -24,7 +21,7 @@ sphinx-gallery==0.14.0
2421
breathe==4.34.0
2522
exhale==0.2.3
2623
docutils==0.16
27-
matplotlib==3.7.2
24+
matplotlib==3.9.4
2825
# PyTorch Theme
2926
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
3027
myst-parser==0.18.1

.ci/scripts/build-qnn-sdk.sh

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/bin/bash
22
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
34
# All rights reserved.
45
#
56
# This source code is licensed under the BSD-style license found in the
@@ -11,10 +12,16 @@ set -o xtrace
1112
build_qnn_backend() {
1213
echo "Start building qnn backend."
1314
export ANDROID_NDK_ROOT=/opt/ndk
14-
export QNN_SDK_ROOT=/tmp/qnn/2.25.0.240728
15+
export QNN_SDK_ROOT=/tmp/qnn/2.28.0.241029
1516
export EXECUTORCH_ROOT="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../.." && pwd)"
1617

17-
bash backends/qualcomm/scripts/build.sh --skip_aarch64 --job_number 2 --release
18+
# Workaround to avoid issues around missing flatccrt library (depending on the
19+
# number of jobs used), see issue #7300:
20+
# Build twice (second time with `--no_clean`) to make sure libflatccrt.a is
21+
# available.
22+
# TODO: Remove this workaround once the underlying issue is fixed.
23+
bash backends/qualcomm/scripts/build.sh --skip_aarch64 --job_number 2 --release || \
24+
bash backends/qualcomm/scripts/build.sh --skip_aarch64 --job_number 2 --release --no_clean
1825
}
1926

2027
set_up_aot() {

.ci/scripts/setup-qnn-deps.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ install_qnn() {
1616
QNN_INSTALLATION_DIR=/tmp/qnn
1717
mkdir -p "${QNN_INSTALLATION_DIR}"
1818

19-
curl -Lo /tmp/v2.25.0.24.07.28.zip "https://softwarecenter.qualcomm.com/api/download/software/qualcomm_neural_processing_sdk/v2.25.0.240728.zip"
19+
curl -Lo /tmp/v2.28.0.24.10.29.zip "https://softwarecenter.qualcomm.com/api/download/software/qualcomm_neural_processing_sdk/v2.28.0.241029.zip"
2020
echo "Finishing downloading qnn sdk."
21-
unzip -qo /tmp/v2.25.0.24.07.28.zip -d /tmp
21+
unzip -qo /tmp/v2.28.0.24.10.29.zip -d /tmp
2222
echo "Finishing unzip qnn sdk."
2323

2424

.ci/scripts/test_llama.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ echo "COREML option ${COREML}"
121121
if [[ "${MODE}" =~ .*qnn.* ]]; then
122122
QNN=ON
123123
export EXECUTORCH_ROOT="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/.." && pwd)"
124-
export QNN_SDK_ROOT=/tmp/qnn/2.25.0.240728
124+
export QNN_SDK_ROOT=/tmp/qnn/2.28.0.241029
125125
export LD_LIBRARY_PATH="${QNN_SDK_ROOT}/lib/x86_64-linux-clang"
126126
export PYTHONPATH=".."
127127
cp schema/program.fbs exir/_serialize/program.fbs

backends/apple/coreml/scripts/install_requirements.sh

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,7 @@ cmake --build "$COREMLTOOLS_DIR_PATH/build" --parallel
4747

4848
echo "${green}ExecuTorch: Installing coremltools."
4949
pip install "$COREMLTOOLS_DIR_PATH"
50-
# CoreMLTools have started supporting numpy 2.0,
51-
# but ExecuTorch example model test env is still using older transformers,
52-
# so for now we will need to downgrade numpy to 1.x
53-
# TODO: Remove this numpy downgrade once later transformers starts to be used
54-
pip install numpy==1.26.4
50+
5551
STATUS=$?
5652
if [ $STATUS -ne 0 ]; then
5753
echo "${red}ExecuTorch: Failed to install coremltools."

backends/arm/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ backends/arm/test/setup_testing.sh
119119
Then you can run the tests with
120120

121121
```
122-
pytest -c /dev/null -v -n auto backends/arm/test --arm_quantize_io --arm_run_corstoneFVP
122+
pytest -c /dev/null -v -n auto backends/arm/test --arm_run_corstoneFVP
123123
```
124124

125125
### Code coverage

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929
from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
3030
from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
31+
from executorch.backends.arm._passes.decompose_select import DecomposeSelectPass
3132
from executorch.backends.arm._passes.decompose_softmaxes_pass import (
3233
DecomposeSoftmaxesPass,
3334
)
@@ -62,7 +63,6 @@
6263
)
6364
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
6465
from executorch.exir import ExportedProgram
65-
from executorch.exir.backend.compile_spec_schema import CompileSpec
6666
from executorch.exir.dialects._ops import ops as exir_ops
6767
from executorch.exir.pass_manager import PassManager
6868

@@ -72,9 +72,7 @@ class ArmPassManager(PassManager):
7272
def _transform(self, graph_module: torch.fx.GraphModule):
7373
return self(graph_module).graph_module
7474

75-
def transform_to_backend_pipeline(
76-
self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
77-
):
75+
def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
7876
"""Apply passes before transforming program to backend"""
7977
self.add_pass(FuseQuantizedActivationPass())
8078
self.add_pass(DecomposeLinearPass())
@@ -137,11 +135,8 @@ def transform_to_backend_pipeline(
137135
self.add_pass(KeepDimsFalseToSqueezePass())
138136
self.add_pass(Conv1dUnsqueezePass(exported_program))
139137
self.add_pass(DecomposeSoftmaxesPass())
140-
for spec in compile_spec:
141-
if spec.key == "permute_memory_format":
142-
memory_format = spec.value.decode()
143-
if memory_format == "nhwc":
144-
self.add_pass(AnnotateChannelsLastDimOrder())
138+
self.add_pass(DecomposeSelectPass())
139+
self.add_pass(AnnotateChannelsLastDimOrder())
145140

146141
return self._transform(exported_program.graph_module)
147142

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import torch
10+
from executorch.backends.arm._passes.arm_pass_utils import create_node
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import ExportPass, PassResult
13+
14+
15+
class DecomposeSelectPass(ExportPass):
16+
"""
17+
This pass decomposes select into slice + squeeze to ensure that Aten and TOSA outputs have the same rank (input rank -1)
18+
"""
19+
20+
def call(self, graph_module: torch.fx.GraphModule):
21+
for node in graph_module.graph.nodes:
22+
23+
if node.op != "call_function":
24+
continue
25+
26+
if node.target in (
27+
exir_ops.edge.aten.select.int,
28+
exir_ops.edge.aten.select_copy.int,
29+
):
30+
slice_op = exir_ops.edge.aten.slice_copy.Tensor
31+
squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
32+
else:
33+
continue
34+
35+
input_node, dim, index = node.args
36+
37+
rank = len(input_node.meta["val"].size())
38+
dim = dim % rank if dim < 0 else dim
39+
index = index % rank if index < 0 else index
40+
dim_list = list(range(rank))
41+
42+
with graph_module.graph.inserting_before(node):
43+
slice_node = create_node(
44+
graph_module.graph, slice_op, (input_node, dim, index, index + 1)
45+
)
46+
squeeze_node = create_node(
47+
graph_module.graph, squeeze_op, (slice_node, dim_list)
48+
)
49+
50+
node.replace_all_uses_with(squeeze_node)
51+
graph_module.graph.erase_node(node)
52+
53+
graph_module.graph.eliminate_dead_code()
54+
graph_module.recompile()
55+
graph_module = super().call(graph_module).graph_module
56+
return PassResult(graph_module, True)

backends/arm/_passes/tag_io_quant_pass.py

Lines changed: 0 additions & 51 deletions
This file was deleted.

backends/arm/arm_backend.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023-2024 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -49,8 +49,6 @@ def __init__(self):
4949
self.compiler_flags = []
5050
self.output_format = None
5151
self.path_for_intermediates = None
52-
# TODO MLETORCH-265 Remove permute_nhwc flag
53-
self.permute_nhwc = False
5452
self.quantize_io = False
5553
self.tosa_version = None
5654
self.input_order = None
@@ -118,16 +116,6 @@ def dump_intermediate_artifacts_to(
118116
self.path_for_intermediates = output_path
119117
return self
120118

121-
def set_permute_memory_format(
122-
self, set_nhwc_permutation: bool = True
123-
) -> "ArmCompileSpecBuilder":
124-
"""
125-
Permute to channel last in compiler and runtime. Compilation and
126-
runtime will convert rank 4 inputs to channel last for each sub-graph.
127-
"""
128-
self.permute_nhwc = set_nhwc_permutation
129-
return self
130-
131119
def set_quantize_io(self, quantize_io: bool = False) -> "ArmCompileSpecBuilder":
132120
"""
133121
Quantization of inputs and dequantization of outputs for cases where
@@ -170,11 +158,6 @@ def build(self) -> List[CompileSpec]:
170158
CompileSpec("debug_artifact_path", self.path_for_intermediates.encode())
171159
)
172160

173-
if self.permute_nhwc:
174-
self.compile_spec.append(
175-
CompileSpec("permute_memory_format", "nhwc".encode())
176-
)
177-
178161
if self.input_order:
179162
self.compile_spec.append(
180163
CompileSpec(
@@ -188,20 +171,27 @@ def build(self) -> List[CompileSpec]:
188171
return self.compile_spec
189172

190173

191-
def is_permute_memory(compile_spec: List[CompileSpec]) -> bool:
192-
for spec in compile_spec:
193-
if spec.key == "permute_memory_format":
194-
return spec.value.decode() == "nhwc"
195-
return False
196-
197-
198174
def is_tosa(compile_spec: List[CompileSpec]) -> bool:
199175
for spec in compile_spec:
200176
if spec.key == "output_format":
201177
return spec.value.decode() == "tosa"
202178
return False
203179

204180

181+
def is_quantize_io(compile_specs: List[CompileSpec]) -> bool:
182+
for spec in compile_specs:
183+
if spec.key == "quantize_io" and spec.value.decode() == "True":
184+
return True
185+
return False
186+
187+
188+
def get_tosa_version(compile_spec: List[CompileSpec]) -> TosaSpecification:
189+
for spec in compile_spec:
190+
if spec.key == "tosa_version":
191+
return TosaSpecification.create_from_string(spec.value.decode())
192+
raise RuntimeError("Could not find TOSA version in CompileSpec")
193+
194+
205195
def get_intermediate_path(compile_spec: List[CompileSpec]) -> Optional[str]:
206196
for spec in compile_spec:
207197
if spec.key == "debug_artifact_path":
@@ -264,7 +254,7 @@ def preprocess( # noqa: C901
264254
# const data directly. Path created and data written only in debug builds.
265255
tosa_graph = ts.TosaSerializer(artifact_path)
266256
graph_module = ArmPassManager().transform_to_backend_pipeline(
267-
exported_program=edge_program, compile_spec=compile_spec
257+
exported_program=edge_program
268258
)
269259

270260
node_visitors = get_node_visitors(edge_program, tosa_spec)

0 commit comments

Comments
 (0)