
Commit 6e8cff5

Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code sharing"
Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054) [ghstack-poisoned]
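For context, here is a minimal, hypothetical sketch of what the decoupled interface described in step 2 could look like. The module names (`SimpleKVCache`, `SimpleSDPA`) and their exact signatures are illustrative assumptions, not the actual ExecuTorch definitions; the only detail taken from the commit message is that both modules operate on q, k, v tensors laid out as [B, n_heads, seq_len, head_dim].

```python
import torch
from torch import nn

# Hypothetical, simplified modules illustrating the decoupled interface.
# Names and signatures are assumptions; only the [B, n_heads, seq_len, head_dim]
# layout comes from the commit message.

class SimpleKVCache(nn.Module):
    def __init__(self, max_batch_size: int, n_heads: int, max_seq_len: int, head_dim: int):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape))
        self.register_buffer("v_cache", torch.zeros(cache_shape))

    def update(self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor):
        # k_val / v_val: [B, n_heads, seq_len, head_dim]; input_pos holds the
        # positions being written, e.g. torch.arange(start, start + seq_len).
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


class SimpleSDPA(nn.Module):
    # SDPA no longer owns the cache; it only consumes q, k, v.
    def forward(self, q, k, v, mask=None):
        # q, k, v: [B, n_heads, seq_len, head_dim]
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

An attention block composed this way would call `cache.update(...)` and then the SDPA module separately, which is what makes either piece individually swappable for a custom-op-backed implementation.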
2 parents 84ef14b + f6a87ee commit 6e8cff5

File tree: 6 files changed (+9, -7 lines changed)

examples/models/llama/llama_transformer.py

Lines changed: 0 additions & 2 deletions
@@ -309,8 +309,6 @@ def forward(
         seqlen,
         mask: torch.Tensor,
     ) -> torch.Tensor:
-        # TODO(kimishpatel): Move this slicing logic to Attention block so that
-        # SDPA does not have to take input_pos as arg
         if self.enable_dynamic_shape:
             start_pos = input_pos[-1].item()
             torch._check_is_size(start_pos)

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ def __init__(
         )
 
         # For now supporting int8 only
-        self.use_custom_update_cache_op = True
+        self.use_custom_update_cache_op = use_custom_update_cache_op
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
         cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)

examples/models/llama/tests/test_simple_sdpa.py

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import copy
 import unittest
 
 import torch

extension/llm/export/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ runtime.python_library(
     name = "export_lib",
     srcs = [
         "builder.py",
+        "export_passes.py",
         "partitioner_lib.py",
         "quantizer_lib.py",
     ],

extension/llm/export/builder.py

Lines changed: 6 additions & 2 deletions
@@ -34,7 +34,7 @@
 
 from executorch.extension.export_util.utils import export_to_edge, save_pte_program
 
-from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes
+from executorch.extension.llm.export.export_passes import RemoveRedundantPermutes
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import Quantizer
@@ -113,7 +113,7 @@ def __init__(
         self.calibration_seq_length = calibration_seq_length
         self.calibration_data = calibration_data
         self.tokenizer_path = tokenizer_path
-        self.canonical_passes = [RemoveRedundantTransposes()]
+        self.canonical_passes = [RemoveRedundantPermutes()]
 
     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
@@ -227,6 +227,10 @@ def export(self) -> "LLMEdgeManager":
         return self
 
     def run_canonical_optimizations(self):
+        """
+        Run canonical optimizations (at the moment removing redundant permutes) on the model.
+        """
+        assert self.pre_autograd_graph_module is not None, "Please run export() first"
         for pass_instance in self.canonical_passes:
             logging.info(f"Running canonical pass: {pass_instance.__class__.__name__}")
             res = pass_instance(self.pre_autograd_graph_module)
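The added assert encodes an ordering contract on the manager: `export()` must populate `pre_autograd_graph_module` before the canonical passes can run. A rough, hypothetical usage sketch follows; constructor arguments are elided because they are not part of this diff.

```python
# Hypothetical usage sketch; LLMEdgeManager constructor arguments are elided
# here because they are not shown in this diff.
manager = LLMEdgeManager(...)
manager.export()                       # populates pre_autograd_graph_module
manager.run_canonical_optimizations()  # runs RemoveRedundantPermutes on the exported graph
```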

extension/llm/export/export_passes.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ def _normalize_dims(tensor: FakeTensor, dim_0: int, dim_1: int):
     return dim_0, dim_1
 
 
-class RemoveRedundantTransposes(ExportPass):
+class RemoveRedundantPermutes(ExportPass):
     """
     This pass removes redundant transpose nodes in the graph.
     It checks if the next node is also a transpose node and if the two transpose nodes undo each other.
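The diff only shows the renamed class and its docstring, not the pass body. As a rough illustration of the idea the docstring describes (cancelling a transpose whose only consumer is another transpose over the same pair of dims), here is a simplified, standalone torch.fx sketch. It is not the actual `RemoveRedundantPermutes` implementation, which subclasses `ExportPass` and normalizes negative dims via `_normalize_dims`; this version assumes a symbolically traced graph where transposes show up as `call_method` nodes with non-negative, matching dims.

```python
import torch.fx as fx


def remove_redundant_transpose_pairs(gm: fx.GraphModule) -> fx.GraphModule:
    """Simplified sketch: fold x.transpose(a, b).transpose(a, b) back to x."""
    for node in list(gm.graph.nodes):
        if node.op == "call_method" and node.target == "transpose":
            parent, *dims = node.args
            if (
                isinstance(parent, fx.Node)
                and parent.op == "call_method"
                and parent.target == "transpose"
                and len(parent.users) == 1                    # only this node consumes it
                and sorted(dims) == sorted(parent.args[1:])   # the two swaps cancel out
            ):
                # Rewire consumers of the second transpose to the original tensor,
                # then drop both transpose nodes.
                node.replace_all_uses_with(parent.args[0])
                gm.graph.erase_node(node)
                gm.graph.erase_node(parent)
    gm.graph.lint()
    gm.recompile()
    return gm
```

For a traced model, this could be applied as `remove_redundant_transpose_pairs(fx.symbolic_trace(model))`; the real pass instead runs over the pre-autograd export graph via the manager's canonical passes.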
