Skip to content

Commit 59663be

Browse files
committed
Update base for Update on "[Executorch][llm] Add ring buffer based kv cache and mask calculation to MHA"
Leveraging previous work now we allow MHA to have ring buffer cache. If ring buffer cache is used then we query the mask from kv cache and use that for sdpa instead of using precalculated mask. In this process we had to adjsut ring buffer implementation to allow keeping the context of full sliding window. See code for comment. Differential Revision: [D73891425](https://our.internmc.facebook.com/intern/diff/D73891425/) [ghstack-poisoned]
2 parents 51ae252 + 1ae8c2c commit 59663be

File tree

171 files changed

+2243
-675
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

171 files changed

+2243
-675
lines changed

.github/workflows/_link_check.yml

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,35 +7,51 @@ on:
77

88
jobs:
99
lint-urls:
10+
if: ${{ github.event_name != 'pull_request' || !contains(github.event.pull_request.labels.*.name, 'skip-url-lint') }}
1011
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
1112
with:
1213
runner: linux.2xlarge
1314
docker-image: executorch-ubuntu-22.04-linter
14-
submodules: 'none'
15+
submodules: false
1516
fetch-depth: 0
1617
ref: ${{ inputs.ref }}
17-
timeout: 90
18+
timeout: 120
1819
script: |
1920
./scripts/lint_urls.sh $(
20-
[ "${{ github.event_name }}" = "pull_request" ] \
21-
&& git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} \
22-
|| [ "${{ github.event_name }}" = "push" ] \
23-
&& git diff --name-only ${{ github.event.before }} ${{ github.sha }}
24-
)
21+
{ [ "${{ github.event_name }}" = "pull_request" ] \
22+
&& git diff --name-only "${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }}"; } \
23+
|| \
24+
{ [ "${{ github.event_name }}" = "push" ] \
25+
&& git diff --name-only "${{ github.event.before }}...${{ github.sha }}"; }
26+
) || {
27+
echo
28+
echo "URL lint failed."
29+
echo "If this is a transient outage, you can bypass it by adding the \`skip-url-lint\` label to your PR."
30+
echo "Or add \`@lint-ignore\` somewhere on the same line as the URL you want to skip checking."
31+
exit 1
32+
}
2533
2634
lint-xrefs:
35+
if: ${{ github.event_name != 'pull_request' || !contains(github.event.pull_request.labels.*.name, 'skip-xref-lint') }}
2736
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
2837
with:
2938
runner: linux.2xlarge
3039
docker-image: executorch-ubuntu-22.04-linter
31-
submodules: 'none'
40+
submodules: false
3241
fetch-depth: 0
3342
ref: ${{ inputs.ref }}
34-
timeout: 90
43+
timeout: 60
3544
script: |
3645
./scripts/lint_xrefs.sh $(
37-
[ "${{ github.event_name }}" = "pull_request" ] \
38-
&& git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} \
39-
|| [ "${{ github.event_name }}" = "push" ] \
40-
&& git diff --name-only ${{ github.event.before }} ${{ github.sha }}
41-
)
46+
{ [ "${{ github.event_name }}" = "pull_request" ] \
47+
&& git diff --name-only "${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }}"; } \
48+
|| \
49+
{ [ "${{ github.event_name }}" = "push" ] \
50+
&& git diff --name-only "${{ github.event.before }}...${{ github.sha }}"; }
51+
) || {
52+
echo
53+
echo "Xref lint failed."
54+
echo "If this is a transient outage, you can bypass it by adding the \`skip-xref-lint\` label to your PR."
55+
echo "Or add \`@lint-ignore\` somewhere on the same line as the reference you want to skip checking."
56+
exit 1
57+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
name: Build Presets
2+
3+
on:
4+
pull_request:
5+
push:
6+
branches:
7+
- main
8+
- release/*
9+
workflow_dispatch:
10+
11+
concurrency:
12+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
13+
cancel-in-progress: true

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ endif()
608608
# any backends.
609609
#
610610
add_library(executorch ${_executorch__srcs})
611-
target_link_libraries(executorch PUBLIC executorch_core)
611+
target_link_libraries(executorch PRIVATE executorch_core)
612612
target_include_directories(executorch PUBLIC ${_common_include_directories})
613613
target_compile_definitions(executorch PUBLIC C10_USING_CUSTOM_GENERATED_MACROS)
614614
target_compile_options(executorch PUBLIC ${_common_compile_options})

backends/arm/_passes/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ python_library(
1111
"//executorch/backends/xnnpack/_passes:xnnpack_passes",
1212
"//executorch/exir:lib",
1313
"//executorch/backends/transforms:utils",
14+
"//executorch/backends/transforms:decompose_sdpa",
1415
],
1516
)

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,5 @@
5757
from .size_adjust_conv2d_pass import SizeAdjustConv2DPass # noqa
5858
from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa
5959
from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass # noqa
60+
from .replace_inf_values_pass import ReplaceInfValues # noqa # usort: skip
6061
from .arm_pass_manager import ArmPassManager # noqa # usort: skip

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,14 @@ def call(self, graph_module: GraphModule) -> PassResult:
7070
if quantized_input:
7171
matmul_args = matmul_node.all_input_nodes
7272
for node in matmul_args:
73+
# Find the dq-node connected to this mm/bmm arg
7374
input_node = self._match_partition_to_node(
7475
node, partition.input_nodes
7576
)
76-
77-
# Remove partition input dq-node
78-
input_node.replace_all_uses_with(input_node.all_input_nodes[0])
79-
graph_module.graph.erase_node(input_node)
8077
input_node_qargs = QuantArgs.from_operator(
8178
input_node.target, input_node.args
8279
)
83-
80+
# Insert new dq-node just before the mm/bmm with input_node's qparams
8481
with graph_module.graph.inserting_before(matmul_node):
8582
# Create new dq-node before matmul
8683
dq_node = create_node(
@@ -90,6 +87,13 @@ def call(self, graph_module: GraphModule) -> PassResult:
9087
dq_node.args = (node, *input_node_qargs)
9188
matmul_node.replace_input_with(node, dq_node)
9289

90+
for partition_input in partition.input_nodes:
91+
# Remove partition input dq-node
92+
partition_input.replace_all_uses_with(
93+
partition_input.all_input_nodes[0]
94+
)
95+
graph_module.graph.erase_node(partition_input)
96+
9397
partition_output = list(partition.output_nodes[0].users)[0]
9498
quantized_output = partition_output.target == q_op
9599
if quantized_output:

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
MatchWhereSelfDtypePass,
5050
QuantizeOperatorArguments,
5151
RemoveClonePass,
52+
ReplaceInfValues,
5253
ReplaceScalarWithTensorArgPassTOSABI,
5354
ReplaceScalarWithTensorArgPassTOSAMI,
5455
RetraceFoldedDtypesPass,
@@ -216,4 +217,5 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
216217
self.add_pass(DecomposeSoftmaxPass())
217218

218219
self.add_pass(ConvertMinMaxPass())
220+
self.add_pass(ReplaceInfValues())
219221
return self._transform(graph_module)

backends/arm/_passes/convert_split_to_slice.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
2-
# All rights reserved.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
65

76
# pyre-unsafe
87

98
import torch.fx
10-
from executorch.backends.arm._passes.arm_pass_utils import create_node
11-
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
9+
from executorch.backends.arm._passes.arm_pass_utils import (
10+
create_node,
11+
get_first_fake_tensor,
12+
)
1213
from executorch.exir.dialects._ops import ops as exir_ops
1314
from executorch.exir.pass_base import ExportPass, PassResult
1415

@@ -34,7 +35,7 @@ def call(self, graph_module: torch.fx.GraphModule):
3435
split_node = node
3536
input_node = split_node.all_input_nodes[0]
3637
output_nodes = split_node.users.copy()
37-
_, shape, _ = extract_tensor_meta(input_node.meta)
38+
shape = get_first_fake_tensor(input_node).shape
3839
rank = len(shape)
3940
split_lengths = split_node.args[1]
4041
dim = split_node.args[2] if len(split_node.args) > 2 else 0
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# Copyright 2025 Arm Limited and/or its affiliates.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# This pass is based on backends/qualcomm/_passes/replace_inf_values.py
8+
# with some modification to replaced inf values.
9+
10+
import torch
11+
from executorch.exir.pass_base import ExportPass, PassResult
12+
13+
14+
class ReplaceInfValues(ExportPass):
15+
"""
16+
Due to limitation in Quantizer, we need to change inf/-inf to more quantizable values.
17+
"""
18+
19+
def __init__(self):
20+
super(ReplaceInfValues, self).__init__()
21+
22+
def call(self, graph_module: torch.fx.GraphModule):
23+
modified = False
24+
for buf_name, tensor in graph_module.named_buffers():
25+
if tensor.is_floating_point():
26+
modified = True
27+
# 255 here is mainly for attention_mask in Llama for reasonable quant scale
28+
tensor[tensor == float("inf")] = 255
29+
tensor[tensor == float("-inf")] = -255
30+
setattr(graph_module, buf_name, tensor)
31+
32+
for node in graph_module.graph.nodes:
33+
arg_list = list(node.args)
34+
for index, arg in enumerate(arg_list):
35+
if arg == float("-inf"):
36+
modified = True
37+
arg_list[index] = -255
38+
elif arg == float("inf"):
39+
modified = True
40+
arg_list[index] = +255
41+
node.args = tuple(arg_list)
42+
43+
if modified:
44+
graph_module.recompile()
45+
return PassResult(graph_module, modified)

backends/arm/operator_support/slice_copy_support.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
SupportedTOSAOperatorCheck,
1313
)
1414
from executorch.backends.arm.tosa_specification import TosaSpecification
15-
from executorch.backends.arm.tosa_utils import getNodeArgs
1615
from executorch.exir.dialects._ops import ops as exir_ops
1716

1817
logger = logging.getLogger(__name__)
@@ -33,8 +32,8 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) ->
3332
if tosa_spec not in self.tosa_specs:
3433
return False
3534

36-
inputs = getNodeArgs(node)
37-
if len(inputs) == 5 and (step := inputs[4].number) != 1:
35+
args = node.args
36+
if len(args) == 5 and (step := args[4]) != 1:
3837
logging.warning(f"{node.target} with step size of {step} not supported.")
3938
return False
4039
return True

0 commit comments

Comments
 (0)