Skip to content

Commit 028526f

Browse files
committed
Update on "Reuse GELU implementation from PyTorch core"
kernels/optimized doesn't need to support embedded systems, so it can just take a header-only dep on PyTorch. Note that, because we will pick up Sleef internally and ignore it externally thanks to ATen vec, this PR gets to enable optimized GELU in OSS. Testing: CI to make sure this doesn't break mobile build modes; happy to take advice on anything not currently covered that might break. Differential Revision: [D66335522](https://our.internmc.facebook.com/intern/diff/D66335522/) [ghstack-poisoned]
2 parents 011ee70 + c18e4eb commit 028526f

File tree

95 files changed

+2577
-1275
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

95 files changed

+2577
-1275
lines changed

.github/pytorch-probot.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# The schema is from https://github.com/pytorch/pytorch/blob/main/.github/pytorch-probot.yml
2+
tracking_issue: 7679
23
ciflow_push_tags:
34
- ciflow/android
45
- ciflow/apple

.github/workflows/android-perf.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ jobs:
260260
--output_name="${OUT_ET_MODEL_NAME}.pte"
261261
ls -lh "${OUT_ET_MODEL_NAME}.pte"
262262
elif [[ ${{ matrix.config }} == "llama3_qnn_htp" ]]; then
263-
export QNN_SDK_ROOT=/tmp/qnn/2.25.0.240728
263+
export QNN_SDK_ROOT=/tmp/qnn/2.28.0.241029
264264
export LD_LIBRARY_PATH=$QNN_SDK_ROOT/lib/x86_64-linux-clang/
265265
export PYTHONPATH=$(pwd)/..
266266
@@ -347,7 +347,7 @@ jobs:
347347
PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh
348348
349349
export ANDROID_ABIS="arm64-v8a"
350-
PYTHON_EXECUTABLE=python EXECUTORCH_BUILD_QNN=ON QNN_SDK_ROOT=/tmp/qnn/2.25.0.240728 bash build/build_android_llm_demo.sh ${ARTIFACTS_DIR_NAME}
350+
PYTHON_EXECUTABLE=python EXECUTORCH_BUILD_QNN=ON QNN_SDK_ROOT=/tmp/qnn/2.28.0.241029 bash build/build_android_llm_demo.sh ${ARTIFACTS_DIR_NAME}
351351
352352
# Let's see how expensive this job is, we might want to tone it down by running it periodically
353353
benchmark-on-device:

.github/workflows/pull.yml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,9 @@ jobs:
332332

333333
unittest-arm:
334334
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
335+
permissions:
336+
id-token: write
337+
contents: read
335338
with:
336339
runner: linux.2xlarge
337340
docker-image: executorch-ubuntu-22.04-arm-sdk
@@ -394,6 +397,25 @@ jobs:
394397
# Test llama2
395398
PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh -model stories110M -build_tool "${BUILD_TOOL}" -mode "${MODE}" -dtype "${DTYPE}" -pt2e_quantize "${PT2E_QUANTIZE}"
396399
400+
test-qnn-models-linux:
401+
name: test-qnn-models-linux
402+
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
403+
strategy:
404+
fail-fast: false
405+
with:
406+
runner: linux.2xlarge
407+
docker-image: executorch-ubuntu-22.04-qnn-sdk
408+
submodules: 'true'
409+
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
410+
timeout: 180
411+
script: |
412+
# The generic Linux job chooses to use base env, not the one setup by the image
413+
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
414+
conda activate "${CONDA_ENV}"
415+
416+
# Placeholder for running test_qnn_delegate.py; a matrix can be used to trigger different jobs (refer to test-llama-runner-qnn-linux)
417+
# reminder: make sure each job runs fast
418+
397419
test-phi-3-mini-runner-linux:
398420
name: test-phi-3-mini-runner-linux
399421
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main

.github/workflows/trunk.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ jobs:
132132
test-arm-backend-delegation:
133133
name: test-arm-backend-delegation
134134
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
135+
permissions:
136+
id-token: write
137+
contents: read
135138
with:
136139
runner: linux.2xlarge
137140
docker-image: executorch-ubuntu-22.04-arm-sdk
@@ -159,6 +162,9 @@ jobs:
159162
test-arm-reference-delegation:
160163
name: test-arm-reference-delegation
161164
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
165+
permissions:
166+
id-token: write
167+
contents: read
162168
with:
163169
runner: linux.2xlarge
164170
docker-image: executorch-ubuntu-22.04-arm-sdk

.lintrunner.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ include_patterns = [
298298
'build/**/*.py',
299299
'codegen/**/*.py',
300300
# 'devtools/**/*.py',
301+
'devtools/visualization/**/*.py',
301302
'docs/**/*.py',
302303
# 'examples/**/*.py',
303304
# 'exir/**/*.py',

backends/apple/coreml/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,14 @@ class Model(torch.nn.Module):
9393
source_model = Model()
9494
example_inputs = (torch.randn((1, 3, 256, 256)), )
9595

96-
pre_autograd_aten_dialect = export_for_training(model, example_inputs).module()
96+
pre_autograd_aten_dialect = export_for_training(source_model, example_inputs).module()
9797

9898
quantization_config = LinearQuantizerConfig.from_dict(
9999
{
100100
"global_config": {
101101
"quantization_scheme": QuantizationScheme.symmetric,
102-
"activation_dtype": torch.uint8,
103-
"weight_dtype": torch.int8,
102+
"activation_dtype": torch.quint8,
103+
"weight_dtype": torch.qint8,
104104
"weight_per_channel": True,
105105
}
106106
}

backends/arm/_passes/arm_pass_manager.py

Lines changed: 63 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
# pyre-unsafe
99

10-
import torch
1110
from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
1211
AnnotateChannelsLastDimOrder,
1312
)
@@ -47,7 +46,7 @@
4746
)
4847
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
4948
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
50-
ConvertMeanDimToAveragePool,
49+
ConvertMeanDimToAveragePoolPass,
5150
)
5251
from executorch.backends.arm._passes.mm_to_bmm_pass import ConvertMmToBmmPass
5352
from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass
@@ -61,86 +60,98 @@
6160
from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
6261
UnsqueezeScalarPlaceholdersPass,
6362
)
63+
from executorch.backends.arm.tosa_specification import TosaSpecification
6464
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
6565
from executorch.exir import ExportedProgram
66-
from executorch.exir.dialects._ops import ops as exir_ops
6766
from executorch.exir.pass_manager import PassManager
67+
from torch.fx import GraphModule
6868

6969

7070
class ArmPassManager(PassManager):
7171

72-
def _transform(self, graph_module: torch.fx.GraphModule):
72+
def __init__(self, tosa_spec: TosaSpecification) -> None:
73+
self.tosa_spec = tosa_spec
74+
super().__init__()
75+
76+
def _transform(self, graph_module: GraphModule):
7377
return self(graph_module).graph_module
7478

75-
def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
76-
"""Apply passes before transforming program to backend"""
79+
def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
7780
self.add_pass(FuseQuantizedActivationPass())
81+
self.add_pass(RemoveGetItemPass())
82+
self.add_pass(ConvertSplitToSlicePass())
83+
self.add_pass(ConvertMmToBmmPass())
7884
self.add_pass(DecomposeLinearPass())
85+
self.add_pass(ConvertMeanDimToAveragePoolPass())
86+
87+
self.add_pass(AnnotateDecomposedMatmulPass())
88+
self.add_pass(QuantizeFullArgument())
89+
self.add_pass(FoldAndAnnotateQParamsPass())
90+
self.add_pass(RetraceFoldedDtypesPass())
91+
self.add_pass(InsertTableOpsPass(exported_program))
92+
93+
self.add_pass(RemoveClonePass())
94+
self.add_pass(SizeAdjustConv2DPass())
95+
self.add_pass(ConvertExpandCopyToRepeatPass())
96+
self.add_pass(UnsqueezeBeforeRepeatPass())
97+
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
98+
self.add_pass(CastInt64ToInt32Pass(exported_program))
99+
self.add_pass(MatchArgRanksPass(exported_program))
100+
self.add_pass(KeepDimsFalseToSqueezePass())
101+
self.add_pass(Conv1dUnsqueezePass(exported_program))
102+
self.add_pass(DecomposeSelectPass())
103+
104+
self.add_pass(AnnotateChannelsLastDimOrder())
105+
106+
return self._transform(exported_program.graph_module)
107+
108+
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
109+
110+
self.add_pass(FuseQuantizedActivationPass())
79111
self.add_pass(RemoveGetItemPass())
112+
self.add_pass(ConvertSplitToSlicePass())
113+
self.add_pass(ConvertMmToBmmPass())
114+
self.add_pass(DecomposeLinearPass())
80115
self.add_pass(DecomposeLayerNormPass())
81116
self.add_pass(DecomposeVarPass())
82-
self.add_pass(ConvertMeanDimToAveragePool())
83117
self.add_pass(DecomposeMeanDimPass())
84-
self.add_pass(ConvertSplitToSlicePass())
85-
self.add_pass(ConvertMmToBmmPass())
86-
# TODO MLETORCH-558
118+
self.add_pass(ConvertMeanDimToAveragePoolPass())
119+
self.add_pass(DecomposeDivPass())
120+
self.add_pass(DecomposeSoftmaxesPass())
121+
87122
self.add_pass(AnnotateDecomposedMatmulPass())
88123
self.add_pass(QuantizeFullArgument())
89-
self.add_pass(
90-
FoldAndAnnotateQParamsPass(
91-
[
92-
exir_ops.edge.aten.minimum.default,
93-
exir_ops.edge.aten.maximum.default,
94-
exir_ops.edge.aten.add.Tensor,
95-
exir_ops.edge.aten.avg_pool2d.default,
96-
exir_ops.edge.aten.bmm.default,
97-
exir_ops.edge.aten.cat.default,
98-
exir_ops.edge.aten.convolution.default,
99-
exir_ops.edge.aten.clone.default,
100-
exir_ops.edge.aten.exp.default,
101-
exir_ops.edge.aten.expand_copy.default,
102-
exir_ops.edge.aten.full.default,
103-
exir_ops.edge.aten.hardtanh.default,
104-
exir_ops.edge.aten.log.default,
105-
exir_ops.edge.aten.max_pool2d.default,
106-
exir_ops.edge.aten.mul.Tensor,
107-
exir_ops.edge.aten.permute_copy.default,
108-
exir_ops.edge.aten.reciprocal.default,
109-
exir_ops.edge.aten.relu.default,
110-
exir_ops.edge.aten.repeat.default,
111-
exir_ops.edge.aten.rsqrt.default,
112-
exir_ops.edge.aten.select_copy.int,
113-
exir_ops.edge.aten.sigmoid.default,
114-
exir_ops.edge.aten.slice_copy.Tensor,
115-
exir_ops.edge.aten.squeeze_copy.dims,
116-
exir_ops.edge.aten.sub.Tensor,
117-
exir_ops.edge.aten.sum.dim_IntList,
118-
exir_ops.edge.aten.tanh.default,
119-
exir_ops.edge.aten.unsqueeze_copy.default,
120-
exir_ops.edge.aten.upsample_nearest2d.vec,
121-
exir_ops.edge.aten.view_copy.default,
122-
]
123-
)
124-
)
124+
self.add_pass(FoldAndAnnotateQParamsPass())
125125
self.add_pass(RetraceFoldedDtypesPass())
126126
self.add_pass(InsertTableOpsPass(exported_program))
127+
128+
self.add_pass(RemoveClonePass())
129+
self.add_pass(SizeAdjustConv2DPass())
127130
self.add_pass(ConvertExpandCopyToRepeatPass())
128131
self.add_pass(UnsqueezeBeforeRepeatPass())
129-
self.add_pass(CastInt64ToInt32Pass(exported_program))
130132
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
131-
self.add_pass(SizeAdjustConv2DPass())
132-
self.add_pass(RemoveClonePass())
133+
self.add_pass(CastInt64ToInt32Pass(exported_program))
133134
self.add_pass(MatchArgRanksPass(exported_program))
134-
self.add_pass(DecomposeDivPass())
135135
self.add_pass(KeepDimsFalseToSqueezePass())
136136
self.add_pass(Conv1dUnsqueezePass(exported_program))
137-
self.add_pass(DecomposeSoftmaxesPass())
138137
self.add_pass(DecomposeSelectPass())
138+
139139
self.add_pass(AnnotateChannelsLastDimOrder())
140140

141141
return self._transform(exported_program.graph_module)
142142

143-
def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
143+
def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
144+
"""Apply passes before transforming program to backend"""
145+
if self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+BI"):
146+
return self._tosa_080_BI_pipeline(exported_program)
147+
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+MI"):
148+
return self._tosa_080_MI_pipeline(exported_program)
149+
else:
150+
raise NotImplementedError(
151+
f"No pass pipeline implemented for {self.tosa_spec=}"
152+
)
153+
154+
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
144155
self.add_pass(ScalarsToAttributePass())
145156
self.add_pass(DecomposeLayerNormPass())
146157
self.add_pass(DecomposeVarPass())

backends/arm/_passes/cast_int64_pass.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -17,6 +17,10 @@
1717

1818

1919
class CastInt64ToInt32Pass(ExportPass):
20+
"""
21+
Cast int64 buffers to int32 if the int64 data is in int32 range.
22+
"""
23+
2024
def __init__(self, exported_program: torch.export.ExportedProgram):
2125
super(CastInt64ToInt32Pass, self).__init__()
2226
self.exported_program = exported_program

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

77
import copy
88

9-
from typing import cast, Dict, Iterable, Set, Tuple
9+
from typing import cast, Dict, Set, Tuple
1010

1111
from executorch.backends.arm.tosa_quant_utils import QuantArgs
1212

@@ -55,7 +55,7 @@ def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
5555
class FoldAndAnnotateQParamsPass(ExportPass):
5656
"""
5757
A pass that walks the graph and removes any DQ and Q nodes before and after the target
58-
node in the supplied list of operators.
58+
node.
5959
The quantization parameters from the DQ/Q nodes are stored as meta values to be
6060
accessible for later lowering and serialization passes.
6161
The assumption is that the quantization annotatation adds DQ nodes for all tensor
@@ -82,9 +82,8 @@ class FoldAndAnnotateQParamsPass(ExportPass):
8282
8383
"""
8484

85-
def __init__(self, targeted_ops: Iterable[EdgeOpOverload]) -> None:
85+
def __init__(self) -> None:
8686
super().__init__()
87-
self.targeted_ops = targeted_ops
8887

8988
def fold_and_annotate_arg(
9089
self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int
@@ -131,7 +130,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
131130
# Loop over the graph nodes and process every call_function node.
132131
for n in graph_module.graph.nodes:
133132
n = cast(Node, n)
134-
if n.op != "call_function" or n.target not in self.targeted_ops:
133+
if n.op != "call_function":
135134
continue
136135

137136
# Make sure we haven't already set qparams meta information on the node
@@ -180,7 +179,7 @@ class QuantizeFullArgument(ExportPass):
180179

181180
def call(self, graph_module: GraphModule) -> PassResult:
182181
modified = False
183-
# Loop over the graph nodes and find any node in the 'targeted_ops' list.
182+
# Loop over the graph nodes and find full.default nodes.
184183
for n in graph_module.graph.nodes:
185184
n = cast(Node, n)
186185
if n.target != exir_ops.edge.aten.full.default:

backends/arm/_passes/fuse_quantized_activation_pass.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@ def _is_fuseable_quantized_activation(self, node: Node):
1919
is_fuseable = min_val == 0
2020

2121
is_quantized = len(node.users) == 1 and next(iter(node.users)).target == q_op
22-
if is_quantized:
22+
if is_fuseable and is_quantized:
2323
quant_node = next(iter(node.users))
2424
zp = quant_node.args[2]
2525
qmin = quant_node.args[3]
26-
27-
return is_fuseable and is_quantized and zp == qmin
26+
return zp == qmin
27+
else:
28+
return False
2829

2930
def _is_fuseable_input(self, node: Node):
3031
return (

0 commit comments

Comments
 (0)