Commit 049b31b

Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code sharing"
Summary:
Why? SDPA has been coupled with the KV cache for a while. Initially this was done when we implemented the sdpa_with_kv_cache custom op to reduce the multiple-copy overhead of the KV cache update. (The same could have been achieved with separate custom KV-cache-update and custom SDPA ops; recent changes enabled this.) Because the SDPA module owns the KV cache, we get (a) a non-composable implementation and (b) difficulty reusing model definitions and components from repos like torchtune. The net result is multiple definitions of the same model (Llama) lying around in ExecuTorch, TorchChat, and torchtune. This diff and subsequent ones move toward decoupling the custom KV cache and custom SDPA so they are composable and module-swap friendly with torchtune's model definition.

How? Earlier PRs decoupled the KV cache update from SDPA. So now:
1. Decouple the SDPA nn.Module from the KV cache.
2. Standardize the KVCache and SDPA interfaces: both operate on q, k, v tensors in [B, n_heads, seq_len, head_dim] format.
3. Step 2 introduces extra transposes when KVCache and SDPA are replaced by custom modules, but a graph pass will be written to remove them.

Test Plan: Existing tests. Make sure perf doesn't regress.

Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054)

[ghstack-poisoned]
2 parents 305350d + 5eb26e8 commit 049b31b
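
To make the intended end state concrete, here is a minimal sketch of the decoupled, standardized interface described in the commit message: the KV cache and SDPA are separate nn.Modules that both consume q, k, v in [B, n_heads, seq_len, head_dim] layout. Class and method names below are illustrative placeholders, not the actual ExecuTorch definitions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchKVCache(nn.Module):
    """Illustrative KV cache holding buffers in [B, n_heads, max_seq_len, head_dim]."""

    def __init__(self, max_batch, n_heads, max_seq_len, head_dim, dtype=torch.float32):
        super().__init__()
        shape = (max_batch, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # k_val, v_val: [B, n_heads, seq_len, head_dim]
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


class SketchSDPA(nn.Module):
    """Illustrative SDPA that no longer owns a cache; it only consumes q, k, v."""

    def forward(self, q, k, v, mask=None):
        # q, k, v: [B, n_heads, seq_len, head_dim]
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)


# Attention code composes the two modules; either one can be swapped for a
# custom-op-backed implementation without touching the other.
cache = SketchKVCache(max_batch=1, n_heads=8, max_seq_len=16, head_dim=17)
sdpa = SketchSDPA()
q = torch.rand(1, 8, 3, 17)
k_cur, v_cur = torch.rand(1, 8, 3, 17), torch.rand(1, 8, 3, 17)
k_all, v_all = cache.update(torch.tensor([0, 1, 2]), k_cur, v_cur)
out = sdpa(q, k_all, v_all)  # [1, 8, 3, 17]
```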

File tree

2 files changed: +22 -27 lines

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 8 additions & 8 deletions
@@ -118,12 +118,12 @@ def update(self, input_pos, k_val, v_val):
                 v_zero_points, self.v_cache_zero_points, start_pos
             )
         else:
-            self.k_cache[:, :, input_pos] = quantized_k_val
-            self.k_cache_scales[:, :, input_pos] = k_scales
-            self.k_cache_zero_points[:, :, input_pos] = k_zero_points
-            self.v_cache[:, :, input_pos] = quantized_v_val
-            self.v_cache_scales[:, :, input_pos] = v_scales
-            self.v_cache_zero_points[:, :, input_pos] = v_zero_points
+            self.k_cache[:, input_pos] = quantized_k_val
+            self.k_cache_scales[:, input_pos] = k_scales
+            self.k_cache_zero_points[:, input_pos] = k_zero_points
+            self.v_cache[:, input_pos] = quantized_v_val
+            self.v_cache_scales[:, input_pos] = v_scales
+            self.v_cache_zero_points[:, input_pos] = v_zero_points

         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
@@ -149,8 +149,8 @@ def update(self, input_pos, k_val, v_val):
             _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
             _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
         else:
-            k_out[:, :, input_pos] = k_val
-            v_out[:, :, input_pos] = v_val
+            k_out[:, input_pos] = k_val
+            v_out[:, input_pos] = v_val

         return k_out.transpose(1, 2), v_out.transpose(1, 2)
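
For context on the indexing change above: the quantized cache buffers stay in [B, max_seq_len, n_heads, head_dim] layout, so the non-custom-op fallback writes new tokens along dim 1 with a single `[:, input_pos]` index, and `update` transposes on return so callers receive the standardized [B, n_heads, seq_len, head_dim] tensors. A small, self-contained illustration (shapes here are made up for the example):

```python
import torch

B, max_seq_len, n_heads, head_dim = 1, 16, 8, 17
cache = torch.zeros(B, max_seq_len, n_heads, head_dim)

input_pos = torch.tensor([0, 1, 2])
new_vals = torch.rand(B, input_pos.numel(), n_heads, head_dim)

# Seq-major layout: index only the sequence dimension (dim 1).
cache[:, input_pos] = new_vals

# Transpose on the way out to the standardized [B, n_heads, seq_len, head_dim].
k_for_sdpa = cache.transpose(1, 2)
print(k_for_sdpa.shape)  # torch.Size([1, 8, 16, 17])
```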

examples/models/llama/source_transformation/test_quantized_kv_cache.py

Lines changed: 14 additions & 19 deletions
@@ -24,16 +24,12 @@ def _init_cache(self):
             self.max_seq_len,
             self.n_kv_heads,
             self.head_dim,
-            self.transpose_kv_cache,
             self.enable_dynamic_shape,
             dtype=self.dtype,
         )

     def _init_kv(self):
-        if self.transpose_kv_cache:
-            shape = (1, self.n_kv_heads, self.seq_len, self.head_dim)
-        else:
-            shape = (1, self.seq_len, self.n_kv_heads, self.head_dim)
+        shape = (1, self.n_kv_heads, self.seq_len, self.head_dim)
         k = torch.rand(shape, dtype=self.dtype)
         v = torch.rand(shape, dtype=self.dtype)
         return k, v
@@ -45,11 +41,11 @@ def setUp(self):
         self.n_kv_heads = 8
         self.head_dim = 17
         self.enable_dynamic_shape = False
-        self.transpose_kv_cache = False
         self.dtype = torch.float32

-    def _test_simple_update_fetch(self, is_transposed=False, is_dynamic_shape=False):
-        self.transpose_kv_cache = is_transposed
+    def _test_simple_update_fetch(
+        self, is_dynamic_shape=False, use_custom_update_cache_op=False
+    ):
         self.enable_dynamic_shape = is_dynamic_shape
         input_pos = torch.tensor([0, 1, 2])
         self.seq_len = input_pos.size(0)
@@ -64,10 +60,7 @@ def _test_simple_update_fetch(self, is_transposed=False, is_dynamic_shape=False)
         )

         def index(t, input_pos):
-            if self.transpose_kv_cache:
-                return t[:, :, input_pos, :]
-            else:
-                return t[:, input_pos, :, :]
+            return t[:, :, input_pos, :]

         sliced_k_cache = index(updated_k_cache, input_pos)
         sliced_v_cache = index(updated_v_cache, input_pos)
@@ -115,14 +108,16 @@ def index(t, input_pos):
             atol=1e-02,
         )

-    def test_simple_update_fetch_not_transposed(self):
+    def test_simple_update_fetch(self):
         self._test_simple_update_fetch()

-    def test_simple_update_fetch_not_transposed_dynamic_shape(self):
-        self._test_simple_update_fetch(is_dynamic_shape=True)
+    def test_simple_update_fetch_use_custom_op(self):
+        self._test_simple_update_fetch(use_custom_update_cache_op=True)

-    def test_simple_update_fetch_transposed(self):
-        self._test_simple_update_fetch(is_transposed=True)
+    def test_simple_update_fetch_dynamic_shape(self):
+        self._test_simple_update_fetch(is_dynamic_shape=True)

-    def test_simple_update_fetch_transposed_dynamic_shape(self):
-        self._test_simple_update_fetch(is_transposed=True, is_dynamic_shape=True)
+    def test_simple_update_fetch_dynamic_shape_use_custom_op(self):
+        self._test_simple_update_fetch(
+            is_dynamic_shape=True, use_custom_update_cache_op=True
+        )
