Commit b6a4eb5

Update on "[ExecuTorch][BE] Split kv cache and SDPA for better code sharing"

Summary:

Why? We have coupled SDPA with the KV cache for a while. Initially this was done when we implemented the sdpa_with_kv_cache custom op to reduce the multiple-copy overhead of KV cache updates. (This could instead have been achieved with separate custom KV-cache-update and custom SDPA ops; recent changes enabled that.) Because the SDPA module owns the KV cache, we get (a) a non-composable implementation and (b) difficulty reusing model definitions and components from repos like tune. The result is that we have multiple definitions of the same model, llama, lying around in ET, TorchChat, and tune. This diff and subsequent ones move toward decoupled, composable custom KV cache and custom SDPA modules, making the code more module-swap friendly with tune's model definition.

How? Earlier PRs decoupled the KV cache update from SDPA. So now:
1. Decouple the SDPA nn.Module from the KV cache.
2. Standardize the KVCache and SDPA interfaces: both operate on q, k, v tensors in [B, # heads, seq_len, head_dim] format.
3. Step 2 introduces multiple transposes when KVCache and SDPA are replaced by custom modules; a post-export graph pass will undo those.

Test Plan: Existing tests. Make sure perf doesn't regress.

Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054)

[ghstack-poisoned]
2 parents c0bf723 + 928d08a commit b6a4eb5
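To make the intended end state concrete, here is a minimal sketch of a decoupled KV cache and SDPA, both operating on q, k, v in [B, n_heads, seq_len, head_dim]. The class names, buffer names, and constructor arguments below are illustrative assumptions, not the actual ExecuTorch or tune definitions:

import torch
import torch.nn as nn
import torch.nn.functional as F


class KVCache(nn.Module):
    # Owns only the cache buffers; update() takes k/v in [B, H, S, D].
    def __init__(self, max_batch, n_kv_heads, max_seq_len, head_dim, dtype=torch.float32):
        super().__init__()
        cache_shape = (max_batch, n_kv_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S]; k_val, v_val: [B, H, S, D]
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


class SDPA(nn.Module):
    # Attention only, no cache ownership; q, k, v in [B, H, S, D].
    def forward(self, q, k, v, mask=None):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

With this split, swapping in a quantized cache or a custom SDPA kernel becomes a module swap on the attention block rather than a rewrite of it.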

File tree

4 files changed: +9 -9 lines changed

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 4 additions & 4 deletions
@@ -98,8 +98,8 @@ def update(self, input_pos, k_val, v_val):
         However the storage is [B, S, H, D] so we incur transpose in, transpose out
         This shall be removed by subsequent post-export graph pass
         """
-        k_val = k_val.transpose(1, 2).contiguous()
-        v_val = v_val.transpose(1, 2).contiguous()
+        k_val = k_val.transpose(1, 2)
+        v_val = v_val.transpose(1, 2)
         # quantize current k_val and store it in the cache
         quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)

@@ -249,8 +249,8 @@ def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D]
-        k_val = k_val.transpose(1, 2).contiguous()
-        v_val = v_val.transpose(1, 2).contiguous()
+        k_val = k_val.transpose(1, 2)
+        v_val = v_val.transpose(1, 2)
         start_pos = input_pos[0].item()
         _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
         _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
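Dropping the .contiguous() calls above avoids materializing an extra copy on every cache update: transpose() only swaps strides and returns a view over the same storage. A small standalone illustration (shapes are arbitrary examples, not the model's):

import torch

k_val = torch.rand(1, 8, 4, 16)                 # [B, S, H, D]
k_t = k_val.transpose(1, 2)                     # [B, H, S, D] view, no data movement
print(k_t.shape)                                # torch.Size([1, 4, 8, 16])
print(k_t.is_contiguous())                      # False
print(k_t.data_ptr() == k_val.data_ptr())       # True: same underlying storage
print(k_t.contiguous().data_ptr() == k_val.data_ptr())  # False: contiguous() copies

This assumes the downstream quantize and update_cache paths accept strided inputs, or that contiguity is restored at the ATen boundary as in the type_convert change below.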

examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py

Lines changed: 2 additions & 2 deletions
@@ -47,8 +47,8 @@ def _init_cache(self):
         )

     def _init_kv(self):
-        kv_shape = (1, self.seq_len, self.n_kv_heads, self.head_dim)
-        q_shape = (1, self.seq_len, self.n_heads, self.head_dim)
+        kv_shape = (1, self.n_kv_heads, self.seq_len, self.head_dim)
+        q_shape = (1, self.n_heads, self.seq_len, self.head_dim)
         q = torch.rand(q_shape, dtype=self.dtype)
         k = torch.rand(kv_shape, dtype=self.dtype)
         v = torch.rand(kv_shape, dtype=self.dtype)

extension/aten_util/make_aten_functor_from_et_functor.h

Lines changed: 2 additions & 2 deletions
@@ -106,7 +106,7 @@ struct type_convert<
         torch::executor::Tensor>>>
     final {
   explicit type_convert(ATensor value)
-      : value_(value),
+      : value_(value.contiguous()),
        converted_(from_blob(
            value_.mutable_data_ptr(),
            {value_.sizes().begin(), value_.sizes().end()},

@@ -117,7 +117,7 @@ struct type_convert<
   }

  private:
-  ATensor value_;
+  typename remove_const_ref<ATensor>::type value_;
   TensorPtr converted_;
 };
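The added .contiguous() in the converter matters because from_blob wraps a raw data pointer plus sizes and therefore assumes a dense, row-major layout; for a non-contiguous view, the raw buffer stores the elements of the original tensor, not of the view. A Python analogue of the failure mode (illustrative only, not the C++ code path):

import torch

x = torch.arange(6).reshape(2, 3)   # underlying buffer: [0, 1, 2, 3, 4, 5]
v = x.t()                           # 3x2 view over the same buffer, strides (1, 3)
print(v)                            # [[0, 3], [1, 4], [2, 5]]

# Reading the shared buffer densely (what a blob built from v's data pointer
# and sizes alone would see) gives the wrong element order for v:
wrong = x.flatten().reshape(v.shape)
print(wrong)                        # [[0, 1], [2, 3], [4, 5]]

# A contiguous copy lays the elements out in v's logical order:
print(v.contiguous().flatten())     # [0, 3, 1, 4, 2, 5]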

extension/llm/custom_ops/op_sdpa_aot.cpp

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ at::Tensor custom_sdpa_aten(
     const bool is_causal,
     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
     const std::optional<double> scale) {
-  auto output = at::empty_like(q);
+  auto output = at::empty(q.sizes());
   WRAP_TO_ATEN(custom_sdpa_out_no_context, 8)
   (q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
   return output;