Commit ed78ae3

Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code sharing"
Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054) [ghstack-poisoned]
2 parents 6e8cff5 + 13c7da9 commit ed78ae3
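
To make the standardized interface in step 2 concrete, here is a minimal sketch of a decoupled KVCache/SDPA pair that both consume q, k, v in [B, n_heads, seq_len, head_dim] layout. The class names and bodies below are illustrative only; the actual ExecuTorch modules in the diffs further down have different signatures and extra arguments.

import torch
import torch.nn as nn
import torch.nn.functional as F

class KVCacheSketch(nn.Module):
    """Illustrative cache that stores k/v as [B, n_heads, max_seq_len, head_dim]."""

    def __init__(self, max_batch: int, n_heads: int, max_seq_len: int, head_dim: int):
        super().__init__()
        self.register_buffer("k_cache", torch.zeros(max_batch, n_heads, max_seq_len, head_dim))
        self.register_buffer("v_cache", torch.zeros(max_batch, n_heads, max_seq_len, head_dim))

    def update(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        # k, v arrive as [B, n_heads, seq_len, head_dim]; write them at input_pos
        # and hand back the full caches in the same layout.
        self.k_cache[:, :, input_pos] = k
        self.v_cache[:, :, input_pos] = v
        return self.k_cache, self.v_cache

class SDPASketch(nn.Module):
    """Illustrative SDPA that consumes q, k, v already in [B, n_heads, seq_len, head_dim]."""

    def forward(self, q, k, v, mask):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Simplified call site (the real Attention.forward in the diff passes more arguments):
#   k, v = kv_cache.update(input_pos, k, v)
#   out = sdpa(q, k, v, mask)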

File tree

2 files changed: +16 -5 lines changed


examples/models/llama/llama_transformer.py

Lines changed: 4 additions & 1 deletion
@@ -286,13 +286,15 @@ def update(
 class SDPA(nn.Module):
     def __init__(
         self,
+        kv_cache: KVCache,
         dim: int,
         head_dim: int,
         n_rep: int,
         max_seq_len: int,
         enable_dynamic_shape: bool,
     ):
         super().__init__()
+        self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
@@ -371,6 +373,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
                 args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
+                kv_cache=self.kv_cache,
                 dim=self.n_local_heads * self.head_dim,
                 head_dim=self.head_dim,
                 n_rep=self.n_rep,
@@ -403,7 +406,7 @@ def forward(
 
         if self.use_kv_cache:
             assert input_pos is not None
-            k, v = self.kv_cache.update(input_pos, k, v)
+            k, v = self.SDPA.kv_cache.update(input_pos, k, v)
             output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)
 
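
With the cache attached to the SDPA module, a source transformation can swap the whole attention kernel, cache handling included, by replacing a single child module. A hypothetical hand-rolled swap of one layer might look like the snippet below; the real flow uses the recursive replace_sdpa_with_* helpers in the next file, and the module path here is assumed for illustration.

# Hypothetical one-layer swap; "model.layers[0].attention" is an assumed path.
attn = model.layers[0].attention
attn.SDPA = SDPACustom(attn.SDPA.kv_cache, attn.SDPA.dim)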

examples/models/llama/source_transformation/sdpa.py

Lines changed: 12 additions & 4 deletions
@@ -19,9 +19,11 @@
 class SDPACustom(torch.nn.Module):
     def __init__(
         self,
+        kv_cache: KVCache,
         dim: int,
     ):
         super().__init__()
+        self.kv_cache = kv_cache
         self.dim = dim
 
     def forward(
@@ -63,7 +65,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPACustom(child.dim),
+                SDPACustom(child.kv_cache, child.dim),
             )
         else:
             _replace_sdpa_with_custom_op(child)
@@ -79,11 +81,13 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
 class SDPASimple(torch.nn.Module):
     def __init__(
         self,
+        kv_cache: KVCache,
         dim: int,
         head_dim: int,
         n_rep: int,
     ):
         super().__init__()
+        self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
@@ -131,10 +135,12 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 class SDPAFlex(torch.nn.Module):
     def __init__(
         self,
+        kv_cache: KVCache,
         dim: int,
         n_rep: int,
     ):
         super().__init__()
+        self.kv_cache = kv_cache
         self.dim = dim
         self.n_rep = n_rep
 
@@ -171,7 +177,7 @@ def replace_sdpa_with_simple_sdpa(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPASimple(child.dim, child.head_dim, child.n_rep),
+                SDPASimple(child.kv_cache, child.dim, child.head_dim, child.n_rep),
             )
         else:
             replace_sdpa_with_simple_sdpa(child)
@@ -184,7 +190,7 @@ def replace_sdpa_with_flex_sdpa(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPAFlex(child.dim, child.n_rep),
+                SDPAFlex(child.kv_cache, child.dim, child.n_rep),
             )
         else:
             replace_sdpa_with_flex_sdpa(child)
@@ -216,11 +222,13 @@ class SDPACoreML(torch.nn.Module):
 
     def __init__(
         self,
+        kv_cache: KVCache,
         dim: int,
         head_dim: int,
         n_rep: int,
     ):
         super().__init__()
+        self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
@@ -252,7 +260,7 @@ def replace_sdpa_with_coreml_sdpa(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPACoreML(child.dim, child.head_dim, child.n_rep),
+                SDPACoreML(child.kv_cache, child.dim, child.head_dim, child.n_rep),
             )
         else:
             replace_sdpa_with_coreml_sdpa(child)
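
Step 3 of the summary mentions a follow-up graph pass to undo the extra transposes that the custom-module layout introduces; that pass is not part of this commit. Below is only a rough torch.fx-style sketch of the idea (fold back-to-back transposes of the same dim pair), using an assumed op target rather than ExecuTorch's actual pass infrastructure.

import torch
import torch.fx

def cancel_adjacent_transposes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Sketch: rewrite transpose(transpose(x, d0, d1), d0, d1) back to x."""
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target == torch.ops.aten.transpose.int:
            inner = node.args[0]
            if (
                isinstance(inner, torch.fx.Node)
                and inner.op == "call_function"
                and inner.target == torch.ops.aten.transpose.int
                and node.args[1:] == inner.args[1:]
            ):
                # Both transposes swap the same pair of dims, so they cancel out.
                node.replace_all_uses_with(inner.args[0])
                gm.graph.erase_node(node)
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm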
