@@ -242,7 +242,8 @@ def __init__(
242
242
config : ModelArgs ,
243
243
input_len : int ,
244
244
cache_lens : Union [int , List [int ]],
245
- dtype = torch .float32 ,
245
+ batch_size : int = 1 ,
246
+ dtype : torch .dtype = torch .float32 ,
246
247
style : str = "shift_pointer" ,
247
248
mask_val : float = float ("-inf" ),
248
249
):
@@ -266,15 +267,21 @@ def __init__(
266
267
if split_mha :
267
268
self .k_caches = {
268
269
StaticKVCache .calculate_cache_key (layer_id , head_id ): torch .zeros (
269
- 1 , cache_lens [layer_id ], none_throws (config .head_dim ), dtype = dtype
270
+ batch_size ,
271
+ cache_lens [layer_id ],
272
+ none_throws (config .head_dim ),
273
+ dtype = dtype ,
270
274
)
271
275
for layer_id in range (config .n_layers )
272
276
for head_id in range (none_throws (config .n_kv_heads ))
273
277
if cache_lens [layer_id ] > 0
274
278
}
275
279
self .v_caches = {
276
280
StaticKVCache .calculate_cache_key (layer_id , head_id ): torch .zeros (
277
- 1 , cache_lens [layer_id ], none_throws (config .head_dim ), dtype = dtype
281
+ batch_size ,
282
+ cache_lens [layer_id ],
283
+ none_throws (config .head_dim ),
284
+ dtype = dtype ,
278
285
)
279
286
for layer_id in range (config .n_layers )
280
287
for head_id in range (none_throws (config .n_kv_heads ))
@@ -283,7 +290,7 @@ def __init__(
283
290
else :
284
291
self .k_caches = {
285
292
StaticKVCache .calculate_cache_key (layer_id , 0 ): torch .zeros (
286
- 1 ,
293
+ batch_size ,
287
294
none_throws (config .n_kv_heads ),
288
295
cache_lens [layer_id ],
289
296
none_throws (config .head_dim ),
@@ -293,7 +300,7 @@ def __init__(
293
300
}
294
301
self .v_caches = {
295
302
StaticKVCache .calculate_cache_key (layer_id , 0 ): torch .zeros (
296
- 1 ,
303
+ batch_size ,
297
304
none_throws (config .n_kv_heads ),
298
305
cache_lens [layer_id ],
299
306
none_throws (config .head_dim ),
@@ -323,7 +330,7 @@ def reset(self):
323
330
def prefill (
324
331
self ,
325
332
model : Callable [..., Any ],
326
- tokens : List [int ],
333
+ tokens : Union [ List [int ], torch . Tensor ],
327
334
) -> torch .Tensor :
328
335
if self .cache_full :
329
336
raise RuntimeError ("KV cache is full." )
@@ -336,18 +343,21 @@ def prefill(
336
343
)
337
344
)
338
345
346
+ if isinstance (tokens , list ):
347
+ tokens = torch .tensor ([tokens ], dtype = torch .int32 )
348
+
339
349
logits = None
340
350
all_logits = None
341
- for i in range (0 , len ( tokens ), self .input_len ):
342
- logits = self ._run_once (model , tokens [i : i + self .input_len ])[0 ]
351
+ for i in range (0 , tokens . size ( 1 ), self .input_len ):
352
+ logits = self ._run_once (model , tokens [:, i : i + self .input_len ])[0 ]
343
353
if self .config .generate_full_logits :
344
354
if all_logits is None :
345
355
all_logits = logits
346
356
else :
347
357
all_logits = torch .cat ([all_logits , logits ], dim = 1 )
348
358
349
359
if self .config .generate_full_logits :
350
- return all_logits [:, : len ( tokens ), :]
360
+ return all_logits [:, : tokens . size ( 1 ), :]
351
361
352
362
return logits
353
363
@@ -510,15 +520,16 @@ def lookahead_decode( # noqa: C901
510
520
def _run_once (
511
521
self ,
512
522
model : Callable [..., Any ],
513
- tokens : List [int ],
523
+ tokens : Union [ List [int ], torch . Tensor ],
514
524
non_padded_len : Optional [int ] = None ,
515
525
freqs_cos_override : Optional [torch .Tensor ] = None ,
516
526
freqs_sin_override : Optional [torch .Tensor ] = None ,
517
527
):
518
- n_tokens = len (tokens )
528
+ if isinstance (tokens , list ):
529
+ tokens = torch .tensor ([tokens ], dtype = torch .int32 )
530
+ n_tokens = tokens .size (1 )
519
531
if n_tokens < self .input_len :
520
- tokens += [0 ] * (self .input_len - n_tokens )
521
- tokens = torch .tensor ([tokens ], dtype = torch .int32 ) # pyre-ignore[9]
532
+ tokens = F .pad (tokens , (0 , self .input_len - n_tokens ))
522
533
if freqs_cos_override is None :
523
534
freqs_cos_override = self .freqs_cos [self .pos : self .pos + self .input_len ]
524
535
if freqs_sin_override is None :
0 commit comments