Commit 7b41972

Update on "[ExecuTorch][BE] Split kv cache and SDPA for better code sharing"

Summary:

Why? We have coupled SDPA with the KV cache for a while. Initially this was done when we implemented the sdpa_with_kv_cache custom op to reduce the multiple-copy overhead of KV cache updates. (This could instead have been done with separate custom KV-cache-update and custom SDPA ops; recent changes enabled that.) Because the SDPA module owns the KV cache, we get a) a non-composable implementation and b) difficulty reusing the model definition and components from repos like tune. The result is that multiple definitions of the same model, llama, are lying around in ET, TorchChat, and Tune. This diff and subsequent ones move toward decoupling the custom KV cache and custom SDPA so they compose, making them more module-swap friendly with tune's model definition.

How? Earlier PRs decoupled the KV cache update from SDPA. So now:
1. Decouple the SDPA nn.Module from the KV cache.
2. Standardize the KVCache and SDPA interface: both KVCache and SDPA operate on q, k, v tensors in [B, # heads, seq_len, head_dim] layout.
3. Step 2 introduces extra transposes when KVCache and SDPA are replaced by custom modules, but we will write a graph pass to remove them.

Test Plan: Existing tests. Make sure perf doesn't regress.

Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054)

[ghstack-poisoned]
2 parents a10e400 + 1f8c183 commit 7b41972
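The standardized layout described in the summary can be read as: both the cache update and the attention computation take q, k, v in [B, n_heads, seq_len, head_dim]. Below is a minimal, illustrative sketch of such a decoupled pair of modules; the class names, constructor arguments, and the eager-mode cache update are assumptions for illustration, not the actual ExecuTorch or torchtune definitions.

```python
# Illustrative sketch only: both modules consume/produce [B, n_heads, seq_len, head_dim]
# tensors, so either can be swapped independently (e.g. a quantized cache or a custom SDPA op).
from typing import Optional, Tuple

import torch
import torch.nn as nn


class KVCache(nn.Module):
    def __init__(self, max_batch: int, n_heads: int, max_seq_len: int, head_dim: int):
        super().__init__()
        cache_shape = (max_batch, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape))
        self.register_buffer("v_cache", torch.zeros(cache_shape))

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # input_pos: [S]; k_val / v_val: [B, n_heads, S, head_dim]
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


class SDPA(nn.Module):
    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # All inputs are [B, n_heads, seq_len, head_dim]; no cache state lives here.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)


# In the attention block the two modules compose instead of SDPA owning the cache:
#   k, v = self.kv_cache.update(input_pos, k, v)
#   y = self.sdpa(q, k, v, mask)
```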

File tree

1 file changed: +4 -4 lines changed


examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 4 additions & 4 deletions
@@ -249,14 +249,14 @@ def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D]
-        k_val = k_val.transpose(1, 2).contiguous()
-        v_val = v_val.transpose(1, 2).contiguous()
+        k_val = k_val.transpose(1, 2)
+        v_val = v_val.transpose(1, 2)
         start_pos = input_pos[0].item()
         _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
         _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
         return (
-            self.k_cache.transpose(1, 2).contiguous(),
-            self.v_cache.transpose(1, 2).contiguous(),
+            self.k_cache.transpose(1, 2),
+            self.v_cache.transpose(1, 2),
         )
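As context for the diff above: transpose() in PyTorch only produces a strided view, while .contiguous() materializes a full copy, so dropping the .contiguous() calls removes four tensor copies per update (assuming, as the diff implies, that torch.ops.llama.update_cache and the downstream consumers can handle the non-contiguous views). A standalone snippet, not repo code, illustrating the view-vs-copy distinction:

```python
import torch

k_val = torch.randn(1, 8, 16, 64)            # [B, n_heads, seq_len, head_dim]
view = k_val.transpose(1, 2)                  # no data movement, just new strides
copy = view.contiguous()                      # allocates and copies the whole tensor

print(view.is_contiguous())                   # False
print(view.data_ptr() == k_val.data_ptr())    # True  -> shares storage with k_val
print(copy.data_ptr() == k_val.data_ptr())    # False -> the extra copy the diff removes
```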
