Commit 75044ad
Update base for Update on "[ExecuTorch][BE] Split kv cache and SDPA for better code sharing"

Summary:

Why? We have coupled SDPA with the KV cache for a while. This was originally done when we implemented the sdpa_with_kv_cache custom op to reduce the multiple-copy overhead of KV cache updates. (It could instead have been done with separate custom KV-cache-update and custom SDPA ops, and recent changes have enabled exactly that.) Because the SDPA module owns the KV cache, we get (a) a non-composable implementation and (b) model definitions and components that are hard to reuse from repos like torchtune. The result is that multiple definitions of the same model, Llama, are lying around in ExecuTorch, TorchChat, and torchtune. This diff and subsequent ones move toward decoupled, composable custom KV cache and custom SDPA modules, making the model definition more module-swap friendly with torchtune's.

How? Earlier PRs decoupled the KV cache update from SDPA. So now:
1. Decouple the SDPA nn.Module from the KV cache.
2. Standardize the KVCache and SDPA interfaces: both operate on q, k, v tensors in [B, n_heads, seq_len, head_dim] format.
3. Step 2 introduces multiple transposes when KVCache and SDPA are replaced by custom modules, but we will write a graph pass to undo them.

Test Plan: Existing tests. Make sure perf doesn't regress.

Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054)

[ghstack-poisoned]
2 parents 5eb26e8 + d1b33cb commit 75044ad
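The decoupling described in the summary is easiest to see as code. Below is a minimal sketch of the composable interface it describes: both modules operate on the standardized [B, n_heads, seq_len, head_dim] layout, but the class and argument names here are illustrative assumptions, not the exact ExecuTorch definitions.

```python
import torch
from torch import nn


class KVCache(nn.Module):
    """Owns the cache state; k/v are [B, n_heads, seq_len, head_dim]."""

    def __init__(self, max_batch: int, n_heads: int, max_seq_len: int, head_dim: int):
        super().__init__()
        shape = (max_batch, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape))
        self.register_buffer("v_cache", torch.zeros(shape))

    def update(self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor):
        # Write the new k/v entries at the given positions and return the full
        # caches; a custom kv-cache-update op can replace this module wholesale
        # to avoid extra copies.
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


class SDPA(nn.Module):
    """Stateless attention over the same [B, n_heads, seq_len, head_dim] layout."""

    def forward(self, q, k, v, mask=None):
        return nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)


# Composition: the attention block wires the two together, so either module can
# be swapped independently (e.g. for a custom-op-backed implementation):
#     k, v = self.kv_cache.update(input_pos, k, v)
#     y = self.sdpa(q, k, v, mask)
```

Because custom implementations may prefer a different internal layout, swapping these modules in can insert transposes around them; per point 3 of the summary, a later graph pass is expected to cancel the adjacent transpose pairs.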

165 files changed: +3734 −2006 lines

.ci/docker/requirements-ci.txt

Lines changed: 4 additions & 7 deletions
@@ -1,17 +1,14 @@
 mpmath==1.3.0
-numpy==1.21.3; python_version == '3.10'
-numpy==1.23.2; python_version == '3.11'
-numpy; python_version >= '3.12'
+numpy==2.0.0; python_version >= '3.10'
 PyYAML==6.0.1
 ruamel.yaml==0.17.32
 sympy==1.12
 timm==0.6.13
 tomli==2.0.1
 torchsr==1.0.4
-transformers==4.38.0
+transformers==4.47.1
 zstd==1.5.5.1
-pandas==2.0.3; python_version == '3.10'
-pandas; python_version >= '3.11'
+pandas==2.2.2; python_version >= '3.10'
 pytest==7.2.0
 pytest-cov==4.1.0
 expecttest==0.1.6
@@ -24,7 +21,7 @@ sphinx-gallery==0.14.0
 breathe==4.34.0
 exhale==0.2.3
 docutils==0.16
-matplotlib==3.7.2
+matplotlib==3.9.4
 # PyTorch Theme
 -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
 myst-parser==0.18.1

.ci/scripts/build-qnn-sdk.sh

Lines changed: 9 additions & 2 deletions
@@ -1,5 +1,6 @@
 #!/bin/bash
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2025 Arm Limited and/or its affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
@@ -11,10 +12,16 @@ set -o xtrace
 build_qnn_backend() {
 echo "Start building qnn backend."
 export ANDROID_NDK_ROOT=/opt/ndk
-export QNN_SDK_ROOT=/tmp/qnn/2.25.0.240728
+export QNN_SDK_ROOT=/tmp/qnn/2.28.0.241029
 export EXECUTORCH_ROOT="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../.." && pwd)"
 
-bash backends/qualcomm/scripts/build.sh --skip_aarch64 --job_number 2 --release
+# Workaround to avoid issues around missing flatccrt library (depending on the
+# number of jobs used), see issue #7300:
+# Build twice (second time with `--no_clean`) to make sure libflatccrt.a is
+# available.
+# TODO: Remove this workaround once the underlying issue is fixed.
+bash backends/qualcomm/scripts/build.sh --skip_aarch64 --job_number 2 --release || \
+bash backends/qualcomm/scripts/build.sh --skip_aarch64 --job_number 2 --release --no_clean
 }
 
 set_up_aot() {

.ci/scripts/setup-qnn-deps.sh

Lines changed: 2 additions & 2 deletions
@@ -16,9 +16,9 @@ install_qnn() {
 QNN_INSTALLATION_DIR=/tmp/qnn
 mkdir -p "${QNN_INSTALLATION_DIR}"
 
-curl -Lo /tmp/v2.25.0.24.07.28.zip "https://softwarecenter.qualcomm.com/api/download/software/qualcomm_neural_processing_sdk/v2.25.0.240728.zip"
+curl -Lo /tmp/v2.28.0.24.10.29.zip "https://softwarecenter.qualcomm.com/api/download/software/qualcomm_neural_processing_sdk/v2.28.0.241029.zip"
 echo "Finishing downloading qnn sdk."
-unzip -qo /tmp/v2.25.0.24.07.28.zip -d /tmp
+unzip -qo /tmp/v2.28.0.24.10.29.zip -d /tmp
 echo "Finishing unzip qnn sdk."
.ci/scripts/test_llama.sh

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ echo "COREML option ${COREML}"
 if [[ "${MODE}" =~ .*qnn.* ]]; then
 QNN=ON
 export EXECUTORCH_ROOT="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/.." && pwd)"
-export QNN_SDK_ROOT=/tmp/qnn/2.25.0.240728
+export QNN_SDK_ROOT=/tmp/qnn/2.28.0.241029
 export LD_LIBRARY_PATH="${QNN_SDK_ROOT}/lib/x86_64-linux-clang"
 export PYTHONPATH=".."
 cp schema/program.fbs exir/_serialize/program.fbs

.github/pytorch-probot.yml

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 # The schema is from https://github.com/pytorch/pytorch/blob/main/.github/pytorch-probot.yml
+tracking_issue: 7679
 ciflow_push_tags:
 - ciflow/android
 - ciflow/apple

.github/workflows/android-perf.yml

Lines changed: 2 additions & 2 deletions
@@ -260,7 +260,7 @@ jobs:
 --output_name="${OUT_ET_MODEL_NAME}.pte"
 ls -lh "${OUT_ET_MODEL_NAME}.pte"
 elif [[ ${{ matrix.config }} == "llama3_qnn_htp" ]]; then
-export QNN_SDK_ROOT=/tmp/qnn/2.25.0.240728
+export QNN_SDK_ROOT=/tmp/qnn/2.28.0.241029
 export LD_LIBRARY_PATH=$QNN_SDK_ROOT/lib/x86_64-linux-clang/
 export PYTHONPATH=$(pwd)/..
@@ -347,7 +347,7 @@ jobs:
 PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh
 
 export ANDROID_ABIS="arm64-v8a"
-PYTHON_EXECUTABLE=python EXECUTORCH_BUILD_QNN=ON QNN_SDK_ROOT=/tmp/qnn/2.25.0.240728 bash build/build_android_llm_demo.sh ${ARTIFACTS_DIR_NAME}
+PYTHON_EXECUTABLE=python EXECUTORCH_BUILD_QNN=ON QNN_SDK_ROOT=/tmp/qnn/2.28.0.241029 bash build/build_android_llm_demo.sh ${ARTIFACTS_DIR_NAME}
 
 # Let's see how expensive this job is, we might want to tone it down by running it periodically
 benchmark-on-device:

backends/apple/coreml/README.md

Lines changed: 3 additions & 3 deletions
@@ -93,14 +93,14 @@ class Model(torch.nn.Module):
 source_model = Model()
 example_inputs = (torch.randn((1, 3, 256, 256)), )
 
-pre_autograd_aten_dialect = export_for_training(model, example_inputs).module()
+pre_autograd_aten_dialect = export_for_training(source_model, example_inputs).module()
 
 quantization_config = LinearQuantizerConfig.from_dict(
     {
         "global_config": {
             "quantization_scheme": QuantizationScheme.symmetric,
-            "activation_dtype": torch.uint8,
-            "weight_dtype": torch.int8,
+            "activation_dtype": torch.quint8,
+            "weight_dtype": torch.qint8,
             "weight_per_channel": True,
         }
     }

backends/apple/coreml/scripts/install_requirements.sh

Lines changed: 1 addition & 5 deletions
@@ -47,11 +47,7 @@ cmake --build "$COREMLTOOLS_DIR_PATH/build" --parallel
 
 echo "${green}ExecuTorch: Installing coremltools."
 pip install "$COREMLTOOLS_DIR_PATH"
-# CoreMLTools have started supporting numpy 2.0,
-# but ExecuTorch example model test env is still using older transformers,
-# so for now we will need to downgrade numpy to 1.x
-# TODO: Remove this numpy downgrade once later transformers starts to be used
-pip install numpy==1.26.4
+
 STATUS=$?
 if [ $STATUS -ne 0 ]; then
 echo "${red}ExecuTorch: Failed to install coremltools."

backends/arm/README.md

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ backends/arm/test/setup_testing.sh
 The you can run the tests with
 
 ```
-pytest -c /dev/null -v -n auto backends/arm/test --arm_quantize_io --arm_run_corstoneFVP
+pytest -c /dev/null -v -n auto backends/arm/test --arm_run_corstoneFVP
 ```
 
 ### Code coverage

backends/arm/_passes/arm_pass_manager.py

Lines changed: 66 additions & 60 deletions
@@ -7,7 +7,6 @@
 
 # pyre-unsafe
 
-import torch
 from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
     AnnotateChannelsLastDimOrder,
 )
@@ -28,6 +27,7 @@
 )
 from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
 from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
+from executorch.backends.arm._passes.decompose_select import DecomposeSelectPass
 from executorch.backends.arm._passes.decompose_softmaxes_pass import (
     DecomposeSoftmaxesPass,
 )
@@ -46,7 +46,7 @@
 )
 from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
 from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
-    ConvertMeanDimToAveragePool,
+    ConvertMeanDimToAveragePoolPass,
 )
 from executorch.backends.arm._passes.mm_to_bmm_pass import ConvertMmToBmmPass
 from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass
@@ -60,92 +60,98 @@
 from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
     UnsqueezeScalarPlaceholdersPass,
 )
+from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
-from executorch.exir.backend.compile_spec_schema import CompileSpec
-from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_manager import PassManager
+from torch.fx import GraphModule
 
 
 class ArmPassManager(PassManager):
 
-    def _transform(self, graph_module: torch.fx.GraphModule):
+    def __init__(self, tosa_spec: TosaSpecification) -> None:
+        self.tosa_spec = tosa_spec
+        super().__init__()
+
+    def _transform(self, graph_module: GraphModule):
         return self(graph_module).graph_module
 
-    def transform_to_backend_pipeline(
-        self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
-    ):
-        """Apply passes before transforming program to backend"""
+    def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(FuseQuantizedActivationPass())
+        self.add_pass(RemoveGetItemPass())
+        self.add_pass(ConvertSplitToSlicePass())
+        self.add_pass(ConvertMmToBmmPass())
         self.add_pass(DecomposeLinearPass())
+        self.add_pass(ConvertMeanDimToAveragePoolPass())
+
+        self.add_pass(AnnotateDecomposedMatmulPass())
+        self.add_pass(QuantizeFullArgument())
+        self.add_pass(FoldAndAnnotateQParamsPass())
+        self.add_pass(RetraceFoldedDtypesPass())
+        self.add_pass(InsertTableOpsPass(exported_program))
+
+        self.add_pass(RemoveClonePass())
+        self.add_pass(SizeAdjustConv2DPass())
+        self.add_pass(ConvertExpandCopyToRepeatPass())
+        self.add_pass(UnsqueezeBeforeRepeatPass())
+        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
+        self.add_pass(CastInt64ToInt32Pass(exported_program))
+        self.add_pass(MatchArgRanksPass(exported_program))
+        self.add_pass(KeepDimsFalseToSqueezePass())
+        self.add_pass(Conv1dUnsqueezePass(exported_program))
+        self.add_pass(DecomposeSelectPass())
+
+        self.add_pass(AnnotateChannelsLastDimOrder())
+
+        return self._transform(exported_program.graph_module)
+
+    def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
+
+        self.add_pass(FuseQuantizedActivationPass())
         self.add_pass(RemoveGetItemPass())
+        self.add_pass(ConvertSplitToSlicePass())
+        self.add_pass(ConvertMmToBmmPass())
+        self.add_pass(DecomposeLinearPass())
         self.add_pass(DecomposeLayerNormPass())
        self.add_pass(DecomposeVarPass())
-        self.add_pass(ConvertMeanDimToAveragePool())
         self.add_pass(DecomposeMeanDimPass())
-        self.add_pass(ConvertSplitToSlicePass())
-        self.add_pass(ConvertMmToBmmPass())
-        # TODO MLETORCH-558
+        self.add_pass(ConvertMeanDimToAveragePoolPass())
+        self.add_pass(DecomposeDivPass())
+        self.add_pass(DecomposeSoftmaxesPass())
+
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeFullArgument())
-        self.add_pass(
-            FoldAndAnnotateQParamsPass(
-                [
-                    exir_ops.edge.aten.minimum.default,
-                    exir_ops.edge.aten.maximum.default,
-                    exir_ops.edge.aten.add.Tensor,
-                    exir_ops.edge.aten.avg_pool2d.default,
-                    exir_ops.edge.aten.bmm.default,
-                    exir_ops.edge.aten.cat.default,
-                    exir_ops.edge.aten.convolution.default,
-                    exir_ops.edge.aten.clone.default,
-                    exir_ops.edge.aten.exp.default,
-                    exir_ops.edge.aten.expand_copy.default,
-                    exir_ops.edge.aten.full.default,
-                    exir_ops.edge.aten.hardtanh.default,
-                    exir_ops.edge.aten.log.default,
-                    exir_ops.edge.aten.max_pool2d.default,
-                    exir_ops.edge.aten.mul.Tensor,
-                    exir_ops.edge.aten.permute_copy.default,
-                    exir_ops.edge.aten.reciprocal.default,
-                    exir_ops.edge.aten.relu.default,
-                    exir_ops.edge.aten.repeat.default,
-                    exir_ops.edge.aten.rsqrt.default,
-                    exir_ops.edge.aten.select_copy.int,
-                    exir_ops.edge.aten.sigmoid.default,
-                    exir_ops.edge.aten.slice_copy.Tensor,
-                    exir_ops.edge.aten.squeeze_copy.dims,
-                    exir_ops.edge.aten.sub.Tensor,
-                    exir_ops.edge.aten.sum.dim_IntList,
-                    exir_ops.edge.aten.tanh.default,
-                    exir_ops.edge.aten.unsqueeze_copy.default,
-                    exir_ops.edge.aten.upsample_nearest2d.vec,
-                    exir_ops.edge.aten.view_copy.default,
-                ]
-            )
-        )
+        self.add_pass(FoldAndAnnotateQParamsPass())
         self.add_pass(RetraceFoldedDtypesPass())
         self.add_pass(InsertTableOpsPass(exported_program))
+
+        self.add_pass(RemoveClonePass())
+        self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
-        self.add_pass(CastInt64ToInt32Pass(exported_program))
         self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
-        self.add_pass(SizeAdjustConv2DPass())
-        self.add_pass(RemoveClonePass())
+        self.add_pass(CastInt64ToInt32Pass(exported_program))
         self.add_pass(MatchArgRanksPass(exported_program))
-        self.add_pass(DecomposeDivPass())
         self.add_pass(KeepDimsFalseToSqueezePass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
-        self.add_pass(DecomposeSoftmaxesPass())
-        for spec in compile_spec:
-            if spec.key == "permute_memory_format":
-                memory_format = spec.value.decode()
-                if memory_format == "nhwc":
-                    self.add_pass(AnnotateChannelsLastDimOrder())
+        self.add_pass(DecomposeSelectPass())
+
+        self.add_pass(AnnotateChannelsLastDimOrder())
 
         return self._transform(exported_program.graph_module)
 
-    def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
+    def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
+        """Apply passes before transforming program to backend"""
+        if self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+BI"):
+            return self._tosa_080_BI_pipeline(exported_program)
+        elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+MI"):
+            return self._tosa_080_MI_pipeline(exported_program)
+        else:
+            raise NotImplementedError(
                f"No pass pipeline implemented for {self.tosa_spec=}"
            )
+
+    def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(ScalarsToAttributePass())
         self.add_pass(DecomposeLayerNormPass())
         self.add_pass(DecomposeVarPass())
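For context on the pass-manager refactor above, here is a hedged usage sketch of how a caller would now drive the pipeline. The construction of the exported program is an assumption for illustration (in the real flow the program is lowered toward the edge dialect before these passes run); only `ArmPassManager` and `TosaSpecification` come from the diff itself.

```python
import torch
from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager
from executorch.backends.arm.tosa_specification import TosaSpecification


# Hypothetical toy model, only to make the sketch self-contained.
class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


exported_program = torch.export.export(TinyModel(), (torch.randn(1, 4),))

# The TOSA spec now selects the pipeline at construction time, replacing the
# old inspection of a CompileSpec list inside transform_to_backend_pipeline().
tosa_spec = TosaSpecification.create_from_string("TOSA-0.80.0+BI")
pass_manager = ArmPassManager(tosa_spec)

# Dispatches internally to the BI or MI pipeline; raises NotImplementedError
# for any other spec.
graph_module = pass_manager.transform_to_backend_pipeline(exported_program)
```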
