Commit 00af94d

Update base for Update on "[Executorch][llm] Add ring buffer based kv cache and mask calculation to MHA"
Leveraging previous work, we now allow MHA to use a ring buffer based KV cache. When the ring buffer cache is used, the attention mask is queried from the KV cache and used for SDPA instead of a precalculated mask. In the process, the ring buffer implementation had to be adjusted so that it keeps the context of the full sliding window; see the code comments.

Differential Revision: [D73891425](https://our.internmc.facebook.com/intern/diff/D73891425/)

[ghstack-poisoned]
2 parents 59663be + 7e1f3e3 commit 00af94d
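To make the idea concrete, here is a minimal sketch of a ring buffer KV cache that derives the SDPA mask from its own bookkeeping instead of relying on a precomputed mask. The names (`RingKVCache`, `update`, `get_attn_mask`) are illustrative only and are not the ExecuTorch API, and the sketch glosses over the full-sliding-window context handling mentioned above.

```python
import torch


class RingKVCache:
    """Illustrative ring-buffer KV cache; not the ExecuTorch implementation."""

    def __init__(self, window: int, n_heads: int, head_dim: int):
        self.window = window
        self.k = torch.zeros(1, n_heads, window, head_dim)
        self.v = torch.zeros(1, n_heads, window, head_dim)
        # Original sequence position held by each slot; -1 means the slot is empty.
        self.slot_pos = torch.full((window,), -1, dtype=torch.long)

    def update(self, pos: int, k: torch.Tensor, v: torch.Tensor) -> None:
        # New tokens overwrite the oldest slot once the window is full.
        slot = pos % self.window
        self.k[:, :, slot] = k
        self.v[:, :, slot] = v
        self.slot_pos[slot] = pos

    def get_attn_mask(self, pos: int) -> torch.Tensor:
        # A slot may be attended to if it is filled and not in the future of the
        # current position; the mask is derived from cache state, not precalculated.
        valid = (self.slot_pos >= 0) & (self.slot_pos <= pos)
        mask = torch.zeros(1, 1, 1, self.window)
        mask[..., ~valid] = float("-inf")
        return mask


cache = RingKVCache(window=4, n_heads=1, head_dim=8)
q = torch.randn(1, 1, 1, 8)  # (batch, heads, q_len, head_dim)
for pos in range(6):  # positions 4 and 5 wrap around and reuse slots 0 and 1
    cache.update(pos, torch.randn(1, 1, 8), torch.randn(1, 1, 8))
    out = torch.nn.functional.scaled_dot_product_attention(
        q, cache.k, cache.v, attn_mask=cache.get_attn_mask(pos)
    )
```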

File tree

195 files changed (+9325, -8028 lines)

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-2024-12-16
+2025-05-06

.github/workflows/apple.yml

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,8 @@ on:
     branches:
       - main
       - release/*
+    tags:
+      - ciflow/trunk/*
   pull_request:
     paths:
       - .ci/scripts/setup-ios.sh

.github/workflows/build-presets.yml

Lines changed: 17 additions & 0 deletions
@@ -11,3 +11,20 @@ on:
 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
   cancel-in-progress: true
+
+jobs:
+  apple:
+    uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
+    strategy:
+      matrix:
+        preset: [macos-arm64]
+    with:
+      job-name: build
+      runner: macos-latest-xlarge
+      python-version: 3.12
+      submodules: recursive
+      script: |
+        set -eux
+        ${CONDA_RUN} ./install_requirements.sh > /dev/null
+        ${CONDA_RUN} cmake --preset ${{ matrix.preset }}
+        ${CONDA_RUN} cmake --build cmake-out --parallel

.github/workflows/pull.yml

Lines changed: 1 addition & 3 deletions
@@ -434,9 +434,7 @@ jobs:
         output=$(ls -la cmake-out/test/size_test)
         arr=($output)
         size=${arr[4]}
-        # threshold=48120 on devserver with gcc11.4
-        # todo(lfq): update once binary size is below 50kb.
-        threshold="47552"
+        threshold="47560"
         if [[ "$size" -le "$threshold" ]]; then
           echo "Success $size <= $threshold"
         else

CMakeLists.txt

Lines changed: 13 additions & 5 deletions
@@ -44,6 +44,19 @@
 
 cmake_minimum_required(VERSION 3.24)
 project(executorch)
+
+# MARK: - Start EXECUTORCH_H12025_BUILD_MIGRATION --------------------------------------------------
+
+include(${PROJECT_SOURCE_DIR}/tools/cmake/common/preset.cmake)
+
+load_build_preset()
+include(${PROJECT_SOURCE_DIR}/tools/cmake/preset/default.cmake)
+
+# Print all the configs that were called with announce_configured_options.
+print_configured_options()
+
+# MARK: - End EXECUTORCH_H12025_BUILD_MIGRATION ----------------------------------------------------
+
 include(tools/cmake/Utils.cmake)
 include(CMakeDependentOption)
 
@@ -96,9 +109,6 @@ set(EXECUTORCH_PAL_DEFAULT
   "Which PAL default implementation to use: one of {posix, minimal}"
 )
 
-option(EXECUTORCH_ENABLE_LOGGING "Build with ET_LOG_ENABLED"
-       ${_default_release_disabled_options}
-)
 if(NOT EXECUTORCH_ENABLE_LOGGING)
   # Avoid pulling in the logging strings, which can be large. Note that this
   # will set the compiler flag for all targets in this directory, and for all
@@ -170,8 +180,6 @@ option(EXECUTORCH_BUILD_ARM_BAREMETAL
   "Build the Arm Baremetal flow for Cortex-M and Ethos-U" OFF
 )
 
-option(EXECUTORCH_BUILD_COREML "Build the Core ML backend" OFF)
-
 option(EXECUTORCH_BUILD_KERNELS_CUSTOM "Build the custom kernels" OFF)
 
 option(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT "Build the custom ops lib for AOT"

CMakePresets.json

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+{
+  "version": 10,
+  "cmakeMinimumRequired": {
+    "major": 3,
+    "minor": 31,
+    "patch": 0
+  },
+  "$comment": "On-device AI across mobile, embedded and edge for PyTorch.",
+  "configurePresets": [
+    {
+      "name": "common",
+      "hidden": true,
+      "binaryDir": "${sourceDir}/cmake-out",
+      "generator": "Unix Makefiles"
+    },
+    {
+      "name": "macos-arm64",
+      "inherits": ["common"],
+      "generator": "Xcode",
+      "cacheVariables": {
+        "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/third-party/ios-cmake/ios.toolchain.cmake",
+        "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/macos-arm64.cmake",
+        "PLATFORM": "MAC_ARM64",
+        "DEPLOYMENT_TARGET": "10.15"
+      },
+      "condition": {
+        "lhs": "${hostSystemName}",
+        "type": "equals",
+        "rhs": "Darwin"
+      }
+    }
+  ]
+}

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
 from .convert_squeezes_to_view import ConvertSqueezesToViewPass  # noqa
 from .convert_to_clamp import ConvertToClampPass  # noqa
 from .decompose_batchnorm_pass import DecomposeBatchNormPass  # noqa
+from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
 from .decompose_gelu_pass import DecomposeGeluPass  # noqa
 from .decompose_layernorm_pass import DecomposeLayerNormPass  # noqa

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 3 additions & 3 deletions
@@ -1,13 +1,12 @@
 # Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
 # pyre-unsafe
 
 import itertools
-
+import operator
 from typing import List
 
 import torch
@@ -22,7 +21,7 @@
 
 class AnnotateDecomposedMatmulPass(ExportPass):
     """
-    torch.matmul can be decomposed in many ways, for instance:
+    torch.matmul and its equivalent operator @ can be decomposed in many ways, for instance:
     dq -> matmul -> q can become
     dq -> repeat -> view -> bmm -> view -> dq which makes quantization folding
     difficult. This helper function find all matmul partitions and annotate its
@@ -50,6 +49,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
             graph_module.graph,
             [
                 torch.matmul,
+                operator.matmul,
             ],
             None,
         )
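For context on why `operator.matmul` is added to the partition search above: when a model uses the `@` operator, the traced graph records `operator.matmul` rather than `torch.matmul` as the node's source function, so a search keyed only on `torch.matmul` misses those partitions. A rough sketch of the distinction, assuming a recent PyTorch with `torch.export.export_for_training` (not part of this commit):

```python
import operator

import torch
from torch.export import export_for_training
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class MatMulModule(torch.nn.Module):
    def forward(self, x, y):
        return x @ y  # traced with source function operator.matmul, not torch.matmul


ep = export_for_training(MatMulModule(), (torch.randn(2, 3), torch.randn(3, 4)))
# Searching for both sources, as the pass now does, also finds the `@` partition.
partitions = get_source_partitions(ep.module().graph, [torch.matmul, operator.matmul])
print(partitions.keys())
```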

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@
     ConvertSqueezesToViewPass,
     ConvertToClampPass,
     DecomposeBatchNormPass,
+    DecomposeCosineSimilarityPass,
     DecomposeDivPass,
     DecomposeGeluPass,
     DecomposeLayerNormPass,
@@ -205,6 +206,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeVarPass())
         self.add_pass(DecomposeMeanDimPass())
         self.add_pass(DecomposeNotEqualPass())
+        self.add_pass(DecomposeCosineSimilarityPass())
         self.add_pass(DecomposeDivPass())
         self.add_pass(DecomposeLeakyReLUPass())
         self.add_pass(DecomposeSqrtPass())
backends/arm/_passes/decompose_cosine_similarity_pass.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.pass_base import ExportPass
+
+torch_cosine_similarity = (torch.ops.aten.cosine_similarity.default,)
+
+
+class DecomposeCosineSimilarityPass(ExportPass):
+    """
+    Decomposition of aten.cosine_similarity:
+
+        dot   = sum(mul(x1, x2), dims, keepdim=False)
+        norm  = pow( sum(mul(x, x), dims, keepdim=False), 0.5 )
+        eps   = full( (), eps_scalar )
+        n1c   = max(norm1, eps)
+        n2c   = max(norm2, eps)
+        denom = mul(n1c, n2c)
+        out   = div(dot, denom)
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in torch_cosine_similarity:
+            return super().call_operator(op, args, kwargs, meta)
+
+        x1, x2 = args[0], args[1]
+        dim = kwargs.get("dim", 1)
+        eps = kwargs.get("eps", 1e-8)
+        dims = [dim] if isinstance(dim, int) else list(dim)
+
+        # 1) dot
+        prod = super().call_operator(torch.ops.aten.mul.Tensor, (x1, x2), {}, meta)
+        dot = super().call_operator(
+            torch.ops.aten.sum.dim_IntList, (prod, dims, False), {}, meta
+        )
+
+        # 2a) norm1 = pow(sum(x1*x1), 0.5)
+        x1_sq = super().call_operator(torch.ops.aten.mul.Tensor, (x1, x1), {}, meta)
+        s1 = super().call_operator(
+            torch.ops.aten.sum.dim_IntList, (x1_sq, dims, False), {}, meta
+        )
+        norm1 = super().call_operator(
+            torch.ops.aten.pow.Tensor_Scalar, (s1, 0.5), {}, meta
+        )
+
+        # 2b) norm2 = pow(sum(x2*x2), 0.5)
+        x2_sq = super().call_operator(torch.ops.aten.mul.Tensor, (x2, x2), {}, meta)
+        s2 = super().call_operator(
+            torch.ops.aten.sum.dim_IntList, (x2_sq, dims, False), {}, meta
+        )
+        norm2 = super().call_operator(
+            torch.ops.aten.pow.Tensor_Scalar, (s2, 0.5), {}, meta
+        )
+
+        # 3) eps scalar - we need to broadcast ourselves as TOSA doesn't do this for scalars
+        eps_t = super().call_operator(
+            torch.ops.aten.full_like.default, (norm1, eps), {}, meta
+        )
+
+        # 4) clamp to avoid zero division
+        n1c = super().call_operator(
+            torch.ops.aten.maximum.default, (norm1, eps_t), {}, meta
+        )
+        n2c = super().call_operator(
+            torch.ops.aten.maximum.default, (norm2, eps_t), {}, meta
+        )
+
+        # 5) denom and divide
+        denom = super().call_operator(torch.ops.aten.mul.Tensor, (n1c, n2c), {}, meta)
+        out = super().call_operator(torch.ops.aten.div.Tensor, (dot, denom), {}, meta)
+
+        return out
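As a sanity check on the decomposition spelled out in the docstring above, the same op sequence can be written in eager PyTorch and compared against `torch.nn.functional.cosine_similarity`. This sketch is for verification only and is not part of the commit:

```python
import torch


def decomposed_cosine_similarity(x1, x2, dim=1, eps=1e-8):
    # Mirrors the pass: dot / (max(norm1, eps) * max(norm2, eps))
    dot = (x1 * x2).sum(dim)
    norm1 = (x1 * x1).sum(dim).pow(0.5)
    norm2 = (x2 * x2).sum(dim).pow(0.5)
    eps_t = torch.full_like(norm1, eps)
    return dot / (torch.maximum(norm1, eps_t) * torch.maximum(norm2, eps_t))


x1, x2 = torch.randn(4, 16), torch.randn(4, 16)
# For inputs whose norms are well above eps the decomposition matches the
# reference; the eps clamping only matters for near-zero vectors.
assert torch.allclose(
    decomposed_cosine_similarity(x1, x2),
    torch.nn.functional.cosine_similarity(x1, x2, dim=1),
    atol=1e-6,
)
```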
