
Commit 84ef14b

Update on "[ExecuTorch][BE] Split kv cache and SDPA for better code sharing"
Summary:

Why? SDPA has been coupled to the KV cache for a while. Initially this was done when we implemented the sdpa_with_kv_cache custom op to reduce the multiple-copy overhead of KV cache updates. (It could also have been done with separate custom KV-cache-update and custom SDPA ops; recent changes enabled this.) Because the SDPA module owns the KV cache, we get a) a non-composable implementation and b) model definitions and components that are hard to reuse from repos like torchtune. The result is multiple definitions of the same model, llama, lying around in ExecuTorch, TorchChat, and torchtune. This diff and subsequent ones move toward decoupled, composable custom KV cache and custom SDPA modules, making module swaps against torchtune's model definition easier.

How? Earlier PRs decoupled the KV cache update from SDPA. So now:

1. Decouple the SDPA nn.Module from the KV cache.
2. Standardize the KVCache and SDPA interfaces: both operate on q, k, v tensors in [B, n_heads, seq_len, head_dim] format.
3. Step 2 introduces extra transposes when KVCache and SDPA are replaced by custom modules; we will write a graph pass to undo those.

Test Plan: Existing tests. Make sure perf doesn't regress.

Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054)

[ghstack-poisoned]
2 parents 3468f0c + f87afa4 commit 84ef14b
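
For illustration, here is a minimal sketch of the decoupled KVCache/SDPA interface this stack moves toward. The class bodies are assumptions for exposition, not the actual ExecuTorch implementations; only the calling convention (q, k, v in [B, n_heads, seq_len, head_dim], with the cache update happening outside SDPA) follows the summary and the test changes below.

import torch
from torch.nn.functional import scaled_dot_product_attention


class KVCache(torch.nn.Module):
    def __init__(self, max_batch_size, n_heads, max_seq_len, head_dim):
        super().__init__()
        # Caches live in the standardized [B, n_heads, seq_len, head_dim] layout.
        shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape))
        self.register_buffer("v_cache", torch.zeros(shape))

    def update(self, input_pos, k, v):
        # k, v: [B, n_heads, seq_len, head_dim]; input_pos: positions to write.
        self.k_cache[:, :, input_pos] = k
        self.v_cache[:, :, input_pos] = v
        return self.k_cache, self.v_cache


class SDPA(torch.nn.Module):
    def __init__(self, dim, head_dim, n_rep):
        super().__init__()
        self.dim = dim  # n_heads * head_dim
        self.head_dim = head_dim
        self.n_rep = n_rep  # k/v head repeats for grouped-query attention

    def forward(self, input_pos, q, k, v, bsz, seqlen, mask):
        # q: [B, n_heads, seq_len, head_dim]; k, v: the full caches, same layout.
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        attn_mask = mask[input_pos]  # mask rows for the positions being decoded
        y = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


# Caller side, mirroring the updated test: update the cache first, then attend.
#   k, v = kv_cache.update(input_pos, k, v)
#   out = sdpa(input_pos, q, k, v, bsz=bsz, seqlen=seqlen, mask=mask)

Because both modules share one tensor layout, either can be swapped for a quantized or custom-op variant independently; the extra transposes this introduces are what the planned graph pass will remove.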

File tree

4 files changed (+12, -11 lines)

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 5 additions & 2 deletions
@@ -207,8 +207,11 @@ def replace_kv_cache_with_quantized_kv_cache(module):
             setattr(
                 module,
                 name,
-                QuantizedKVCache.from_float(child, QuantizedCacheType.AffineAsymmetric),
-                use_custom_update_cache_op=True,
+                QuantizedKVCache.from_float(
+                    child,
+                    QuantizedCacheType.AffineAsymmetric,
+                    use_custom_update_cache_op=True,
+                ),
             )
         else:
             replace_kv_cache_with_quantized_kv_cache(child)
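
Note: the removed lines passed use_custom_update_cache_op=True as a keyword argument to the built-in setattr, which takes no keyword arguments and would raise a TypeError at runtime; the fix threads the flag into QuantizedKVCache.from_float instead.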

examples/models/llama/source_transformation/sdpa.py

Lines changed: 1 addition & 4 deletions
@@ -9,14 +9,11 @@
 # Example script for exporting Llama2 to flatbuffer

 import math
-from typing import Tuple, Union
+from typing import Tuple

 import torch

 from executorch.examples.models.llama.llama_transformer import KVCache, SDPA
-from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
-    QuantizedKVCache,
-)


 class SDPACustom(torch.nn.Module):

examples/models/llama/tests/test_simple_sdpa.py

Lines changed: 6 additions & 4 deletions
@@ -32,7 +32,6 @@ def test_simple_sdpa(self):
             enable_dynamic_shape=False,
         )
         sdpa = SDPA(
-            kv_cache=copy.deepcopy(kv_cache),
             dim=dim,
             head_dim=head_dim,
             n_rep=n_rep,
@@ -44,6 +43,11 @@ def test_simple_sdpa(self):
         key = torch.randn(1, 1, n_local_heads, head_dim)
         value = torch.randn(1, 1, n_local_heads, head_dim)
         mask = torch.randn(max_seq_length, max_seq_length)
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+        key, value = kv_cache.update(input_pos, key, value)
+
         sdpa_output = sdpa(
             input_pos,
             query,
@@ -54,9 +58,7 @@ def test_simple_sdpa(self):
             mask=mask,
         )

-        simple_sdpa = SDPASimple(
-            kv_cache=copy.deepcopy(kv_cache), dim=dim, head_dim=head_dim, n_rep=n_rep
-        )
+        simple_sdpa = SDPASimple(dim=dim, head_dim=head_dim, n_rep=n_rep)
         simple_sdpa_output = simple_sdpa(
             input_pos, query, key, value, bsz=bsz, seqlen=seqlen, mask=mask
         )

extension/llm/export/test_export_passes.py

Lines changed: 0 additions & 1 deletion
@@ -1,4 +1,3 @@
-import os
 import unittest

 import torch
