
Commit 64ab3f5

Update on "[ExecuTorch][BE] Split kv cache and SDPA for better code sharing"
Summary:

Why? We have coupled SDPA with the KV cache for a while. Initially this was done when we implemented the sdpa_with_kv_cache custom op to reduce the multiple-copy overhead of KV cache updates. (It could instead have been done with separate custom KV-cache-update and custom SDPA ops; recent changes enabled this.) Because the SDPA module owns the KV cache, we get (a) a non-composable implementation and (b) difficulty reusing model definitions and components from repos like tune. The result is that we have multiple definitions of the same model, llama, lying around in ET, TorchChat, and Tune. This diff and subsequent ones move toward decoupling the custom KV cache and custom SDPA, making them composable and module-swap friendly with tune's model definition.

How? Earlier PRs decoupled the KV cache update from SDPA. Now:

1. Decouple the SDPA nn.Module from the KV cache.
2. Standardize the KVCache and SDPA interfaces: both operate on q, k, v tensors in [B, # heads, seq_len, head_dim] format.
3. Step 2 introduces multiple transposes when KVCache and SDPA are replaced by custom modules, but we will write a graph pass to undo them.

Test Plan: Existing tests. Make sure perf doesn't regress.

Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054)

[ghstack-poisoned]
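A minimal sketch of the decoupled design the summary describes (not the actual ExecuTorch code — class names and signatures here are illustrative): `KVCache` and `SDPA` are separate `nn.Module`s that both accept q/k/v in `[B, num_heads, seq_len, head_dim]` layout, so either can be swapped for a custom implementation independently.

```python
import torch
from torch import nn


class KVCache(nn.Module):
    """Owns the cache buffers; knows nothing about attention."""

    def __init__(self, max_seq_len, num_heads, head_dim, batch_size=1):
        super().__init__()
        shape = (batch_size, num_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape))
        self.register_buffer("v_cache", torch.zeros(shape))

    def update(self, start_pos, k, v):
        # k, v: [B, num_heads, seq_len, head_dim]
        seq_len = k.shape[2]
        self.k_cache[:, :, start_pos : start_pos + seq_len] = k
        self.v_cache[:, :, start_pos : start_pos + seq_len] = v
        return (
            self.k_cache[:, :, : start_pos + seq_len],
            self.v_cache[:, :, : start_pos + seq_len],
        )


class SDPA(nn.Module):
    """Stateless attention; no cache ownership."""

    def forward(self, q, k, v):
        # All inputs in [B, num_heads, seq_len, head_dim].
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)


class Attention(nn.Module):
    def __init__(self, max_seq_len, num_heads, head_dim):
        super().__init__()
        self.kv_cache = KVCache(max_seq_len, num_heads, head_dim)
        self.sdpa = SDPA()  # either submodule can be module-swapped on its own

    def forward(self, q, k, v, start_pos):
        k, v = self.kv_cache.update(start_pos, k, v)
        return self.sdpa(q, k, v)
```

With this split, replacing `KVCache` or `SDPA` with a custom-op-backed module is a local swap; any transposes the custom modules introduce around this interface are what the planned graph pass would remove.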
2 parents 6f79856 + 0937059 commit 64ab3f5

File tree

109 files changed: +1609 −3169 lines


.ci/scripts/utils.sh

Lines changed: 3 additions & 3 deletions
```diff
@@ -17,17 +17,17 @@ retry () {
 }


 clean_executorch_install_folders() {
-  ./install_requirements.sh --clean
+  ./install_executorch.sh --clean
 }

 install_executorch() {
   which pip
   # Install executorch, this assumes that Executorch is checked out in the
   # current directory.
   if [[ "${1:-}" == "use-pt-pinned-commit" ]]; then
-    ./install_requirements.sh --pybind xnnpack --use-pt-pinned-commit
+    ./install_executorch.sh --pybind xnnpack --use-pt-pinned-commit
   else
-    ./install_requirements.sh --pybind xnnpack
+    ./install_executorch.sh --pybind xnnpack
   fi
   # Just print out the list of packages for debugging
   pip list
```

.github/workflows/apple.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -9,7 +9,7 @@ on:
     paths:
       - .ci/scripts/setup-ios.sh
       - .github/workflows/apple.yml
-      - install_requirements.sh
+      - install_executorch.sh
       - backends/apple/**
      - build/build_apple_frameworks.sh
       - build/build_apple_llm_demo.sh
```

.github/workflows/pull.yml

Lines changed: 27 additions & 5 deletions
```diff
@@ -200,7 +200,7 @@ jobs:
       PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake"

       # install pybind
-      bash install_requirements.sh --pybind xnnpack
+      bash install_executorch.sh --pybind xnnpack

       # install Llava requirements
       bash examples/models/llama/install_requirements.sh
@@ -333,6 +333,9 @@ jobs:

   unittest-arm:
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     with:
       runner: linux.2xlarge
       docker-image: executorch-ubuntu-22.04-arm-sdk
@@ -395,6 +398,25 @@ jobs:
       # Test llama2
       PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh -model stories110M -build_tool "${BUILD_TOOL}" -mode "${MODE}" -dtype "${DTYPE}" -pt2e_quantize "${PT2E_QUANTIZE}"

+  test-qnn-models-linux:
+    name: test-qnn-models-linux
+    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    strategy:
+      fail-fast: false
+    with:
+      runner: linux.2xlarge
+      docker-image: executorch-ubuntu-22.04-qnn-sdk
+      submodules: 'true'
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      timeout: 180
+      script: |
+        # The generic Linux job chooses to use base env, not the one setup by the image
+        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
+        conda activate "${CONDA_ENV}"
+
+        # placeholder for running test_qnn_delegate.py, can use matrix such that we can trigger different jobs, refers to test-llama-runner-qnn-linux
+        # reminder: make sure each job runs fast
+
   test-phi-3-mini-runner-linux:
     name: test-phi-3-mini-runner-linux
     uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
@@ -414,7 +436,7 @@ jobs:
       PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake"

       # install pybind
-      bash install_requirements.sh --pybind xnnpack
+      bash install_executorch.sh --pybind xnnpack

       # install phi-3-mini requirements
       bash examples/models/phi-3-mini/install_requirements.sh
@@ -441,7 +463,7 @@ jobs:
       PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake"

       # install pybind
-      bash install_requirements.sh --pybind xnnpack
+      bash install_executorch.sh --pybind xnnpack

       # install llama requirements
       bash examples/models/llama/install_requirements.sh
@@ -468,7 +490,7 @@ jobs:
       PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake"

       # install pybind
-      bash install_requirements.sh --pybind xnnpack
+      bash install_executorch.sh --pybind xnnpack

       # install llama requirements
       bash examples/models/llama/install_requirements.sh
@@ -495,7 +517,7 @@ jobs:
       PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake"

       # install pybind
-      bash install_requirements.sh --pybind xnnpack
+      bash install_executorch.sh --pybind xnnpack

       # install llama requirements
       bash examples/models/llama/install_requirements.sh
```

.github/workflows/trunk.yml

Lines changed: 6 additions & 0 deletions
```diff
@@ -132,6 +132,9 @@ jobs:
   test-arm-backend-delegation:
     name: test-arm-backend-delegation
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     with:
       runner: linux.2xlarge
       docker-image: executorch-ubuntu-22.04-arm-sdk
@@ -159,6 +162,9 @@ jobs:
   test-arm-reference-delegation:
     name: test-arm-reference-delegation
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     with:
       runner: linux.2xlarge
       docker-image: executorch-ubuntu-22.04-arm-sdk
```

backends/apple/mps/setup.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -97,7 +97,7 @@ I 00:00:00.122615 executorch:mps_executor_runner.mm:501] Model verified successf
 ### [Optional] Run the generated model directly using pybind
 1. Make sure `pybind` MPS support was installed:
 ```bash
-./install_requirements.sh --pybind mps
+./install_executorch.sh --pybind mps
 ```
 2. Run the `mps_example` script to trace the model and run it directly from python:
 ```bash
````

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 2 additions & 5 deletions
```diff
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
@@ -15,7 +15,7 @@
     get_node_arg,
     insert_q_dq_pair,
 )
-from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op
+from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
 from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
@@ -43,9 +43,6 @@ def _transpose_impl(*args, **kwargs):
     return args[0]


-register_passable_op(torch.ops.passthrough_to_tosa._transpose)
-
-
 class AnnotateChannelsLastDimOrder(ExportPass):
     """
     Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
```

backends/arm/_passes/fuse_quantized_activation_pass.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -19,12 +19,13 @@ def _is_fuseable_quantized_activation(self, node: Node):
         is_fuseable = min_val == 0

         is_quantized = len(node.users) == 1 and next(iter(node.users)).target == q_op
-        if is_quantized:
+        if is_fuseable and is_quantized:
             quant_node = next(iter(node.users))
             zp = quant_node.args[2]
             qmin = quant_node.args[3]
-
-        return is_fuseable and is_quantized and zp == qmin
+            return zp == qmin
+        else:
+            return False

     def _is_fuseable_input(self, node: Node):
         return (
```
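The control-flow change in `_is_fuseable_quantized_activation` above can be illustrated in isolation. The function below is a hypothetical simplification (the real pass derives these values from graph nodes): the rewrite only inspects the quantize node's zero-point and qmin once both preconditions hold, and returns `False` explicitly otherwise, instead of relying on `and` short-circuiting over names bound inside the `if`.

```python
def is_fuseable_quantized_activation(is_fuseable: bool, is_quantized: bool,
                                     zp: int, qmin: int) -> bool:
    # Same decision table as the rewritten pass: fuse only when the activation
    # is fuseable, its single user is a quantize op, and zero_point == qmin.
    if is_fuseable and is_quantized:
        return zp == qmin
    else:
        return False
```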

backends/arm/operators/__init__.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -13,7 +13,6 @@
     op_bmm,
     op_cat,
     op_conv2d,
-    op_dequant,
     op_exp,
     op_full,
     op_get_item,
@@ -24,7 +23,6 @@
     op_min,
     op_mul,
     op_permute,
-    op_quant,
     op_reciprocal,
     op_relu,
     op_repeat,
```

backends/arm/operators/op_dequant.py

Lines changed: 0 additions & 35 deletions
This file was deleted.

backends/arm/operators/op_hardtanh.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -1,4 +1,4 @@
-# Copyright 2023-2024 Arm Limited and/or its affiliates.
+# Copyright 2023-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -19,7 +19,6 @@
 )
 from executorch.backends.arm.tosa_mapping import TosaArg

-from executorch.backends.arm.tosa_quant_utils import quantize_value
 from serializer.tosa_serializer import TosaOp


@@ -44,8 +43,8 @@ def define_node(
         input_qparams = get_input_qparams(node)  # pyre-ignore[16]
         qargs = input_qparams[0]
         # Convert to quantized representation
-        clamp_min_qs = quantize_value(inputs[1].number, qargs)
-        clamp_max_qs = quantize_value(inputs[2].number, qargs)
+        clamp_min_qs = qargs.quantize_value(inputs[1].number).item()
+        clamp_max_qs = qargs.quantize_value(inputs[2].number).item()
         # Set fp values to 0.0 since they are not used
         clamp_min_fp = 0.0
         clamp_max_fp = 0.0
```
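What "convert to quantized representation" means for the hardtanh clamp bounds can be sketched as follows. `QuantArgs` here is a hypothetical stand-in for the pass's `qargs` object (the real class lives in the Arm backend's quantization utils); the underlying affine-quantization formula `q = round(x / scale) + zero_point`, clamped to the integer range, is standard.

```python
import torch
from dataclasses import dataclass


@dataclass
class QuantArgs:
    scale: float
    zp: int
    qmin: int = -128  # int8 range
    qmax: int = 127

    def quantize_value(self, x: float) -> torch.Tensor:
        # Affine quantization: round, shift by zero point, clamp to range.
        q = round(x / self.scale) + self.zp
        return torch.tensor(max(self.qmin, min(self.qmax, q)), dtype=torch.int8)


qargs = QuantArgs(scale=0.05, zp=0)
clamp_min_qs = qargs.quantize_value(-1.0).item()  # quantized lower clamp bound
clamp_max_qs = qargs.quantize_value(1.0).item()   # quantized upper clamp bound
```

Making `quantize_value` a method on the qparams object, as the diff does, keeps the scale/zero-point pair and the conversion logic in one place instead of threading `qargs` through a free function.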
