
Commit 16526f3

Pass token indices to KV cache
1 parent dc66754 commit 16526f3

File tree

4 files changed: +93 -16 lines changed


litgpt/kvcache/base.py

Lines changed: 43 additions & 7 deletions
@@ -158,7 +158,12 @@ def max_tokens_forward(self) -> int:
         """
         raise NotImplementedError()
 
-    def forward(self, key: torch.Tensor, value: torch.Tensor) -> KeysAndValues:
+    def forward(
+        self,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        token_idx: torch.Tensor,
+    ) -> KeysAndValues:
         """
         Accepts key and value tensors for `1 <= num <= max_tokens_forward`
         new token positions. These are written into the cache. If the cache
@@ -176,6 +181,7 @@ def forward(self, key: torch.Tensor, value: torch.Tensor) -> KeysAndValues:
             key: New keys, `(eff_batch_size, n_query_groups, num, head_size)`,
                 where `1 <= num <= max_tokens_forward`
             value: New values, `(eff_batch_size, n_query_groups, num, head_size)`
+            token_idx: Token indices of input sequence, `(eff_batch_size, num)`
 
         Returns:
             key_cached, value_cached, `(eff_batch_size, n_query_groups, T,
@@ -203,6 +209,7 @@ def update(self, *args, **kwargs):
         Args:
             *args: Depends on subclass
             **kwargs: Depends on subclass
+
         """
         raise NotImplementedError()
 
@@ -228,7 +235,12 @@ def max_prefill_length(self) -> Optional[int]:
         """
         raise NotImplementedError()
 
-    def prefill(self, key: torch.Tensor, value: torch.Tensor):
+    def prefill(
+        self,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        token_idx: torch.Tensor,
+    ):
         """
         Starts a generation loop by passing key and value tensors coming from
         a prefill with embeddings coming from the prompts. The length must be
@@ -239,6 +251,8 @@ def prefill(self, key: torch.Tensor, value: torch.Tensor):
         Args:
             key: Prefill keys, `(eff_batch_size, n_query_groups, T, head_size)`
             value: Prefill values, `(eff_batch_size, n_query_groups, T, head_size)`
+            token_idx: Token indices of input sequence, `(eff_batch_size, T)`
+
         """
         raise NotImplementedError()
 
@@ -271,6 +285,7 @@ def size_estimate(self) -> Tuple[int, Dict[str, int]]:
 
         Returns:
             num_bits_total, bits_by_part (unit is bit)
+
         """
         raise NotImplementedError()
 
@@ -287,6 +302,7 @@ def size_estimate_apriori(cls, params: KVCacheParams, **kwargs) -> Tuple[int, Di
 
         Returns:
             num_bits_total, bits_by_part (unit is bit)
+
         """
         raise NotImplementedError()
 
@@ -326,6 +342,7 @@ class DenseKVCache(KVCache):
 
     Note: If the cache is full, :meth:`forward` raises an exception. The cache
     buffers are allocated up front and are not enlarged later on.
+
     """
     def __init__(
         self,
@@ -344,7 +361,6 @@ def __init__(
            dtype: Data type for buffers
            max_sequence_length: Cache length. If not given, we use
                `config.block_size`
-           max_tokens_forward: See parent class
            head_size: Size of final dimension of buffers. Defaults to head
                size of model
 
@@ -380,7 +396,12 @@ def max_prefill_length(self) -> Optional[int]:
     def current_length(self) -> int:
         return self.next_position
 
-    def forward(self, key: torch.Tensor, value: torch.Tensor) -> KeysAndValues:
+    def forward(
+        self,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        token_idx: torch.Tensor,
+    ) -> KeysAndValues:
         if self.next_position is None:
             raise IndexError("Cache needs to be initialized with 'prefill' before being used")
         num = key.shape[2]
@@ -416,7 +437,12 @@ def forward(self, key: torch.Tensor, value: torch.Tensor) -> KeysAndValues:
     def update(self, *args, **kwargs):
         pass
 
-    def prefill(self, key: torch.Tensor, value: torch.Tensor):
+    def prefill(
+        self,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        token_idx: torch.Tensor,
+    ):
         if key.dim() != 4:
             raise ValueError("key must have 4 dimensions")
         init_length = key.shape[2]
@@ -517,7 +543,12 @@ def max_tokens_forward(self) -> int:
     def max_prefill_length(self) -> Optional[int]:
         return None
 
-    def forward(self, key: torch.Tensor, value: torch.Tensor) -> KeysAndValues:
+    def forward(
+        self,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        token_idx: torch.Tensor,
+    ) -> KeysAndValues:
         if self.next_position is None:
             raise IndexError("Cache needs to be initialized with 'prefill' before being used")
         if key.ndim != 4:
@@ -563,7 +594,12 @@ def forward(self, key: torch.Tensor, value: torch.Tensor) -> KeysAndValues:
     def update(self, *args, **kwargs):
         pass
 
-    def prefill(self, key: torch.Tensor, value: torch.Tensor):
+    def prefill(
+        self,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        token_idx: torch.Tensor,
+    ):
         if key.dim() != 4:
             raise ValueError("key must have 4 dimensions")
         init_length = key.shape[2]
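
With this change, every cache interaction carries the token ids that produced the keys and values. A minimal sketch of the new calling convention follows; `ToyCache` and all shapes are illustrative stand-ins (the real cache classes live in `litgpt.kvcache`), not the actual litgpt implementation.

# Sketch of the updated (key, value, token_idx) calling convention.
# ToyCache is a hypothetical stand-in, not part of litgpt.
import torch


class ToyCache:
    """Stores keys/values densely; mirrors the new three-argument API."""

    def prefill(self, key, value, token_idx):
        # key/value: (eff_batch_size, n_query_groups, T, head_size)
        # token_idx: (eff_batch_size, T), ids of the tokens in the cache slots
        self.k, self.v, self.tokens = key, value, token_idx

    def __call__(self, key, value, token_idx):
        # key/value: (eff_batch_size, n_query_groups, num, head_size)
        # token_idx: (eff_batch_size, num), with 1 <= num <= max_tokens_forward
        self.k = torch.cat([self.k, key], dim=2)
        self.v = torch.cat([self.v, value], dim=2)
        self.tokens = torch.cat([self.tokens, token_idx], dim=1)
        return self.k, self.v


eff_batch_size, n_query_groups, head_size = 2, 4, 8
prompt_len, vocab_size = 16, 128
cache = ToyCache()

# Prefill with the prompt, passing the prompt's token indices alongside K/V
k = torch.randn(eff_batch_size, n_query_groups, prompt_len, head_size)
v = torch.randn(eff_batch_size, n_query_groups, prompt_len, head_size)
idx = torch.randint(0, vocab_size, (eff_batch_size, prompt_len))
cache.prefill(key=k, value=v, token_idx=idx)

# One decode step: a single new position plus its token index
k1 = torch.randn(eff_batch_size, n_query_groups, 1, head_size)
v1 = torch.randn(eff_batch_size, n_query_groups, 1, head_size)
idx1 = torch.randint(0, vocab_size, (eff_batch_size, 1))
k_all, v_all = cache(k1, v1, token_idx=idx1)  # each (2, 4, 17, 8)

The shapes match the docstrings above: keys and values are `(eff_batch_size, n_query_groups, num, head_size)` while `token_idx` is `(eff_batch_size, num)`.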

litgpt/model.py

Lines changed: 6 additions & 3 deletions
@@ -309,7 +309,7 @@ def forward(
             attn = block.attn
             if attn.kv_cache.batch_size < eff_batch_size:
                 raise ValueError(f"Batch size {eff_batch_size} is too large for KV cache layer {l_ix} (batch size {attn.kv_cache.batch_size}). Use 'assign_kv_caches' or `set_kv_cache'")
-            x = block(x, cos, sin, input_pos, self.mask_cache)
+            x = block(x, cos, sin, idx, input_pos, self.mask_cache)
 
         x = self.transformer.ln_f(x)
         clamp_head = partial(
@@ -428,6 +428,7 @@ def forward(
         x: torch.Tensor,
         cos: torch.Tensor,
         sin: torch.Tensor,
+        token_idx: torch.Tensor,
         input_pos: Optional[int] = None,
         mask_cache: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -457,6 +458,7 @@ def forward(
             x_normed,
             cos=cos,
             sin=sin,
+            token_idx=token_idx,
             input_pos=input_pos,
             mask_cache=mask_cache,
         )
@@ -511,6 +513,7 @@ def forward(
         x: torch.Tensor,
         cos: torch.Tensor,
         sin: torch.Tensor,
+        token_idx: torch.Tensor,
         input_pos: Optional[int] = None,
         mask_cache: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -596,12 +599,12 @@ def forward(
             # Instead of asking for the key and value tensors as such,
             # `k_and_v` allows access to them. Since they are never needed at
             # the same time, this can save memory.
-            k_and_v = self.kv_cache(k, v)
+            k_and_v = self.kv_cache(k, v, token_idx=token_idx)
             # k, v: (B, nh_k, cache_length, hs)
         else:
             if for_prefill:
                 # Prefill KV cache
-                self.kv_cache.prefill(key=k, value=v)
+                self.kv_cache.prefill(key=k, value=v, token_idx=token_idx)
             # In this case, `k_and_v` can vend both keys and values at the same
             # time.
             k_and_v = DefaultKeysAndValues(k, v)
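
The model-side edits are pure plumbing: `GPT.forward` hands its input ids `idx` to each block, `Block.forward` relays them as `token_idx`, and the attention layer passes them on to the cache, untouched. A runnable toy of that pattern is below; `ToyBlock`/`ToyAttention`/`toy_gpt_forward` are hypothetical stand-ins (the real methods also take `cos`, `sin`, masks, and positions).

# Runnable toy of the token-index plumbing; stand-in classes, not litgpt code.
import torch


class ToyAttention:
    def __init__(self):
        self.seen_tokens = []  # records what would reach the KV cache

    def forward(self, x, token_idx):
        # In litgpt this is where kv_cache(k, v, token_idx=...) or
        # kv_cache.prefill(key=k, value=v, token_idx=...) is called.
        self.seen_tokens.append(token_idx)
        return x


class ToyBlock:
    def __init__(self):
        self.attn = ToyAttention()

    def __call__(self, x, token_idx):
        # Block.forward: token_idx is relayed to attention unchanged
        return self.attn.forward(x, token_idx=token_idx)


def toy_gpt_forward(blocks, idx):
    # GPT.forward: the input token ids `idx` go to every block
    x = torch.zeros(idx.shape[0], idx.shape[1], 16)  # stand-in embeddings
    for block in blocks:
        x = block(x, token_idx=idx)
    return x


blocks = [ToyBlock() for _ in range(2)]
idx = torch.randint(0, 128, (2, 5))  # (batch, seq_len) token ids
toy_gpt_forward(blocks, idx)
assert all(b.attn.seen_tokens[0] is idx for b in blocks)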

tests/kvcache/test_base.py

Lines changed: 16 additions & 2 deletions
@@ -15,6 +15,7 @@ def test_most_recent():
     seed = 31415927
     random.seed(seed)
     torch.random.manual_seed(seed)
+    vocab_size = 128
 
     params = KVCacheParams(
         batch_size=3,
@@ -34,9 +35,22 @@ def test_most_recent():
         num_prefill = max_prefill_length
 
     keys, values = random_keys_values(params, num=num_insert)
-    kv_cache.prefill(keys[:, :, :num_prefill, :], values[:, :, :num_prefill, :])
+    token_idx = torch.randint(
+        low=0,
+        high=vocab_size,
+        size=(params.batch_size, num_insert),
+    )
+    kv_cache.prefill(
+        key=keys[:, :, :num_prefill, :],
+        value=values[:, :, :num_prefill, :],
+        token_idx=token_idx[:, :num_prefill],
+    )
     for pos in range(num_prefill, num_insert):
-        kv_cache(keys[:, :, pos:(pos + 1), :], values[:, :, pos:(pos + 1), :])
+        kv_cache(
+            keys[:, :, pos:(pos + 1), :],
+            values[:, :, pos:(pos + 1), :],
+            token_idx=token_idx[:, pos:(pos + 1)],
+        )
         if kv_cache.update_requires_attn_weights():
             attn_weights = random_attn_weights(params, num=kv_cache.current_length)
             kv_cache.update(attn_weights=attn_weights)

tests/kvcache/test_generic.py

Lines changed: 28 additions & 4 deletions
@@ -18,6 +18,7 @@ def test_store_retrieve(name):
     seed = 31415927
     random.seed(seed)
     torch.random.manual_seed(seed)
+    vocab_size = 128
 
     params = KVCacheParams(
         batch_size=3,
@@ -40,11 +41,22 @@ def test_store_retrieve(name):
         num_prefill = max_prefill_length
 
     keys, values = random_keys_values(params, num=num_insert)
-    kv_cache.prefill(keys[:, :, :num_prefill, :], values[:, :, :num_prefill, :])
+    token_idx = torch.randint(
+        low=0,
+        high=vocab_size,
+        size=(params.batch_size, num_insert),
+    )
+    kv_cache.prefill(
+        key=keys[:, :, :num_prefill, :],
+        value=values[:, :, :num_prefill, :],
+        token_idx=token_idx[:, :num_prefill],
+    )
     keys_and_values = None
     for pos in range(num_prefill, num_insert):
         keys_and_values = kv_cache(
-            keys[:, :, pos:(pos + 1), :], values[:, :, pos:(pos + 1), :]
+            keys[:, :, pos:(pos + 1), :],
+            values[:, :, pos:(pos + 1), :],
+            token_idx=token_idx[:, pos:(pos + 1)],
         )
         if kv_cache.update_requires_attn_weights():
             attn_weights = random_attn_weights(params, num=kv_cache.current_length)
@@ -80,6 +92,7 @@ def test_prefill(name):
     seed = 31415927
     random.seed(seed)
     torch.random.manual_seed(seed)
+    vocab_size = 128
     num_compares = 3
 
     params = KVCacheParams(
@@ -95,18 +108,29 @@ def test_prefill(name):
     kv_cache = create_kv_cache(name, params)
 
     keys, values = random_keys_values(params, num=cache_length)
+    token_idx = torch.randint(
+        low=0,
+        high=vocab_size,
+        size=(params.batch_size, cache_length),
+    )
     keys_cached = []
     values_cached = []
     max_prefill_length = kv_cache.max_prefill_length
     for _ in range(num_compares):
         num_prefill = random.randint(cache_length // 8, cache_length)
         if max_prefill_length is not None and num_prefill > max_prefill_length:
             num_prefill = max_prefill_length
-        kv_cache.prefill(keys[:, :, :num_prefill, :], values[:, :, :num_prefill, :])
+        kv_cache.prefill(
+            key=keys[:, :, :num_prefill, :],
+            value=values[:, :, :num_prefill, :],
+            token_idx=token_idx[:, :num_prefill],
+        )
         keys_and_values = None
         for pos in range(num_prefill, cache_length):
             keys_and_values = kv_cache(
-                keys[:, :, pos:(pos + 1), :], values[:, :, pos:(pos + 1), :]
+                keys[:, :, pos:(pos + 1), :],
+                values[:, :, pos:(pos + 1), :],
+                token_idx=token_idx[:, pos:(pos + 1)],
            )
            if kv_cache.update_requires_attn_weights():
                attn_weights = random_attn_weights(params, num=kv_cache.current_length)
