Commit f142f4a

talumbau authored and copybara-github committed
KV Cache updates use dynamic_update_slice
PiperOrigin-RevId: 705660857
1 parent 5a93316 commit f142f4a

4 files changed: +33 −30 lines


ai_edge_torch/generative/layers/attention.py

Lines changed: 2 additions & 6 deletions
@@ -241,9 +241,7 @@ def forward(
       q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
-      kv_cache = kv_utils.update(
-          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
-      )
+      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
 
     y = self.sdpa_func(
@@ -379,9 +377,7 @@ def forward(
       q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
-      kv_cache = kv_utils.update(
-          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
-      )
+      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
     if mask is None:
       mask = torch.zeros(
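On the caller side the only visible change is the dropped enable_hlfb keyword: which cache-update implementation runs is now decided inside kv_utils.update itself (see kv_cache.py below). A minimal sketch of the new call pattern; everything other than kv_utils.update and the k_cache/v_cache fields (the toy attend wrapper, the sdpa placeholder) is illustrative, not code from the repo:

# Illustrative placeholder forward; `sdpa` stands in for the layer's sdpa_func.
def attend(q, k, v, input_pos, kv_cache, sdpa, mask=None):
  if kv_cache is not None:
    kv_cache = kv_utils.update(kv_cache, input_pos, k, v)  # no enable_hlfb arg
    k, v = kv_cache.k_cache, kv_cache.v_cache
  return sdpa(q, k, v, mask), kv_cache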

ai_edge_torch/generative/layers/kv_cache.py

Lines changed: 25 additions & 18 deletions
@@ -146,7 +146,7 @@ def update(
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
     v_slice: torch.Tensor,
-    enable_hlfb: bool = True,
+    use_dus: bool = True,
 ) -> KVCacheEntry:
   """Out of place update of Cache buffer.
@@ -155,17 +155,14 @@ def update(
     input_pos (torch.Tensor): The update slice positions.
     k_slice (torch.Tensor): The K slice to be updated in the new cache.
     v_slice (torch.Tensor): The V slice to be updated in the new cache.
-    enable_hlfb (bool, optional): Whether the op is annotated for export with
-      High Level Function Boundary. Defaults to True.
 
   Returns:
     KVCacheEntry: The updated KVCache entry based on the passed inputs.
   """
-  # Don't enable HLFB for kv cache op for now, since it won't work with LLM
-  # inference engine. Remove this part once we ship a new LLM inference engine.
-  enable_hlfb=False
-  update_func = _update_kv_hlfb_impl if enable_hlfb else _update_kv_base_impl
-  return update_func(cache, input_pos, k_slice, v_slice)
+  # Turn dynamic_update_slice updates off for now.
+  use_dus=False
+  update_kv_cache = _update_kv_impl if use_dus else _update_kv_base_impl
+  return update_kv_cache(cache, input_pos, k_slice, v_slice)
 
 
 def _update_kv_base_impl(
@@ -181,18 +178,28 @@ def _update_kv_base_impl(
   return updated_cache
 
 
-def _update_kv_hlfb_impl(
+def _get_slice_indices(positions: torch.Tensor) -> torch.Tensor:
+  """Dynamic Update Slice updates are a variadic sequence of 0-rank tensors."""
+
+  zero = torch.zeros([]).int()
+  positions = positions.int()[0].reshape([])
+  return [zero, positions, zero, zero]
+
+
+def _update_kv_impl(
     cache: KVCacheEntry,
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
     v_slice: torch.Tensor,
 ) -> KVCacheEntry:
-  """Update the cache buffer with High Level Function Boundary annotation."""
-  builder = hlfb.StableHLOCompositeBuilder(name="odml.update_external_kv_cache")
-  k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
-      cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
-  )
-  k = k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
-  v = v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
-  k, v = builder.mark_outputs(k, v)
-  return KVCacheEntry(k, v)
+  """Update the cache buffer for K and V caches."""
+  # NB: Here assume that input_pos == range(input_pos[0], len(input_pos))
+
+  k_slice_indices = _get_slice_indices(input_pos)
+  v_slice_indices = _get_slice_indices(input_pos)
+
+  k = dynamic_update_slice(cache.k_cache, k_slice, k_slice_indices)
+  v = dynamic_update_slice(cache.v_cache, v_slice, v_slice_indices)
+
+  updated_cache = KVCacheEntry(k, v)
+  return updated_cache
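To make the new update path concrete, here is a small self-contained sketch of the semantics that _get_slice_indices and dynamic_update_slice rely on. dynamic_update_slice_ref below is an illustrative stand-in written for this note (it is not the op imported by kv_cache.py): it writes the update as one contiguous block whose start corner is given by one scalar index per dimension, which is why only input_pos[0] matters and why the positions are assumed to be contiguous starting there.

import torch

def dynamic_update_slice_ref(operand, update, start_indices):
  """Illustrative stand-in for an XLA-style dynamic_update_slice.

  Writes `update` into a copy of `operand` as one contiguous block whose
  start corner is given by `start_indices` (one 0-rank index per dimension).
  """
  result = operand.clone()
  region = tuple(
      slice(int(i), int(i) + s) for i, s in zip(start_indices, update.shape)
  )
  result[region] = update
  return result

# Toy cache of shape [batch, seq_len, num_query_groups, head_dim] = [1, 4, 1, 2].
cache = torch.zeros(1, 4, 1, 2, dtype=torch.int64)
k_slice = torch.full((1, 2, 1, 2), 7, dtype=torch.int64)

# As in _get_slice_indices above, only input_pos[0] is used; the update lands
# as one contiguous block starting at that sequence position.
input_pos = torch.tensor([0, 1])
zero = torch.zeros([]).int()
start = [zero, input_pos.int()[0].reshape([]), zero, zero]

updated = dynamic_update_slice_ref(cache, k_slice, start)
print(updated.flatten().tolist())  # [7, 7, 7, 7, 0, 0, 0, 0]

Note that update still hard-codes use_dus=False, so _update_kv_base_impl remains the active path until that line is removed.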

ai_edge_torch/generative/test/test_kv_cache.py

Lines changed: 3 additions & 3 deletions
@@ -71,18 +71,18 @@ def test_cache_udpate(self):
         [0, 0, 5, 5, 0, 0, 0, 0],
     )
     # multi-slice update
-    input_pos = torch.tensor([0, 3])
+    input_pos = torch.tensor([0, 1])
     k_slice = v_slice = torch.full(
         (1, 2, NUM_QG, HEAD_DIM), 7, dtype=torch.float
     )
     updated_entry = kv_utils.update(entry, input_pos, k_slice, v_slice)
     self.assertEqual(
         updated_entry.k_cache.numpy().flatten().tolist(),
-        [7, 7, 0, 0, 0, 0, 7, 7],
+        [7, 7, 7, 7, 0, 0, 0, 0],
     )
     self.assertEqual(
         updated_entry.v_cache.numpy().flatten().tolist(),
-        [7, 7, 0, 0, 0, 0, 7, 7],
+        [7, 7, 7, 7, 0, 0, 0, 0],
     )
 
   def test_serialization(self):
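The expected values change because the two implementations place a multi-position slice differently. A small standalone illustration (toy shapes chosen so the cache flattens to eight elements; not code from the repo):

import torch

cache = torch.zeros(1, 4, 1, 2, dtype=torch.int64)   # flattens to 8 zeros
update = torch.full((1, 2, 1, 2), 7, dtype=torch.int64)

# index_copy-style update (the base impl): each listed position is written
# independently, so input_pos = [0, 3] scatters the slice.
scattered = cache.index_copy(1, torch.tensor([0, 3]), update)
print(scattered.flatten().tolist())  # [7, 7, 0, 0, 0, 0, 7, 7]

# dynamic_update_slice-style update: one contiguous block starting at
# input_pos[0], hence the test now uses input_pos = [0, 1].
contiguous = cache.clone()
contiguous[:, 0:2] = update
print(contiguous.flatten().tolist())  # [7, 7, 7, 7, 0, 0, 0, 0]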

ai_edge_torch/generative/test/test_model_conversion.py

Lines changed: 3 additions & 3 deletions
@@ -100,8 +100,8 @@ def test_toy_model_with_kv_cache_with_hlfb(self):
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def test_toy_model_has_ekv_op(self):
-    """Tests that the model has the external kv cache op."""
+  def test_toy_model_has_dus_op(self):
+    """Tests that the model has the dynamic update slice op."""
     _, edge_model, _ = self._get_params(enable_hlfb=True)
     interpreter_ = interpreter.InterpreterWithCustomOps(
         custom_op_registerers=["GenAIOpsRegisterer"],
@@ -111,7 +111,7 @@ def test_toy_model_has_ekv_op(self):
 
     # pylint: disable=protected-access
     op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
-    self.assertIn("odml.update_external_kv_cache", op_names)
+    self.assertIn("DYNAMIC_UPDATE_SLICE", op_names)
 
   def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
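Outside the test harness, the same check can be run against any converted .tflite flatbuffer. A rough sketch, assuming a model file on disk (the path is hypothetical); _get_ops_details() is the same private TF Lite interpreter method the test uses, so it may change across versions, and models that register custom GenAI ops may need the InterpreterWithCustomOps wrapper shown above instead of the plain interpreter:

import tensorflow as tf

interp = tf.lite.Interpreter(model_path="toy_model_with_kv_cache.tflite")
# pylint: disable=protected-access
op_names = [op["op_name"] for op in interp._get_ops_details()]
assert "DYNAMIC_UPDATE_SLICE" in op_names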
