@@ -242,7 +242,8 @@ def __init__(
         config: ModelArgs,
         input_len: int,
         cache_lens: Union[int, List[int]],
-        dtype=torch.float32,
+        batch_size: int = 1,
+        dtype: torch.dtype = torch.float32,
         style: str = "shift_pointer",
         mask_val: float = float("-inf"),
     ):
@@ -266,15 +267,21 @@ def __init__(
         if split_mha:
             self.k_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
-                    1, cache_lens[layer_id], none_throws(config.head_dim), dtype=dtype
+                    batch_size,
+                    cache_lens[layer_id],
+                    none_throws(config.head_dim),
+                    dtype=dtype,
                 )
                 for layer_id in range(config.n_layers)
                 for head_id in range(none_throws(config.n_kv_heads))
                 if cache_lens[layer_id] > 0
             }
             self.v_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
-                    1, cache_lens[layer_id], none_throws(config.head_dim), dtype=dtype
+                    batch_size,
+                    cache_lens[layer_id],
+                    none_throws(config.head_dim),
+                    dtype=dtype,
                 )
                 for layer_id in range(config.n_layers)
                 for head_id in range(none_throws(config.n_kv_heads))
@@ -283,7 +290,7 @@ def __init__(
         else:
             self.k_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros(
-                    1,
+                    batch_size,
                     none_throws(config.n_kv_heads),
                     cache_lens[layer_id],
                     none_throws(config.head_dim),
@@ -293,7 +300,7 @@ def __init__(
             }
             self.v_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros(
-                    1,
+                    batch_size,
                     none_throws(config.n_kv_heads),
                     cache_lens[layer_id],
                     none_throws(config.head_dim),
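Taken together, the hunks above give every preallocated KV cache tensor a leading batch dimension. A minimal standalone sketch of the two resulting layouts, using illustrative sizes that are not taken from the source:

```python
import torch

# Illustrative sizes only; the real values come from ModelArgs and cache_lens.
batch_size, cache_len, n_kv_heads, head_dim = 2, 128, 4, 64

# split_mha branch: one cache per (layer, head), shaped (batch, cache_len, head_dim).
k_cache_per_head = torch.zeros(batch_size, cache_len, head_dim, dtype=torch.float32)

# else branch: one cache per layer, shaped (batch, n_kv_heads, cache_len, head_dim).
k_cache_per_layer = torch.zeros(
    batch_size, n_kv_heads, cache_len, head_dim, dtype=torch.float32
)

print(k_cache_per_head.shape)   # torch.Size([2, 128, 64])
print(k_cache_per_layer.shape)  # torch.Size([2, 4, 128, 64])
```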
@@ -323,7 +330,7 @@ def reset(self):
     def prefill(
         self,
         model: Callable[..., Any],
-        tokens: List[int],
+        tokens: Union[List[int], torch.Tensor],
     ) -> torch.Tensor:
         if self.cache_full:
             raise RuntimeError("KV cache is full.")
@@ -336,18 +343,21 @@ def prefill(
             )
         )

+        if isinstance(tokens, list):
+            tokens = torch.tensor([tokens], dtype=torch.int32)
+
         logits = None
         all_logits = None
-        for i in range(0, len(tokens), self.input_len):
-            logits = self._run_once(model, tokens[i : i + self.input_len])[0]
+        for i in range(0, tokens.size(1), self.input_len):
+            logits = self._run_once(model, tokens[:, i : i + self.input_len])[0]
             if self.config.generate_full_logits:
                 if all_logits is None:
                     all_logits = logits
                 else:
                     all_logits = torch.cat([all_logits, logits], dim=1)

         if self.config.generate_full_logits:
-            return all_logits[:, : len(tokens), :]
+            return all_logits[:, : tokens.size(1), :]

         return logits

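With this hunk, `prefill` accepts either the original `List[int]` or a 2-D token tensor and walks it along the sequence dimension in `input_len`-sized windows. A standalone sketch of that chunking, using plain PyTorch and illustrative values:

```python
import torch

input_len = 4  # illustrative; the real value is the manager's self.input_len
tokens = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], dtype=torch.int32)

# Walk the sequence dimension in input_len-sized windows, as the new loop does.
chunks = [tokens[:, i : i + input_len] for i in range(0, tokens.size(1), input_len)]
print([tuple(c.shape) for c in chunks])  # [(1, 4), (1, 4), (1, 2)]
# The final short chunk is padded up to input_len inside _run_once (next hunk).
```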
@@ -510,15 +520,16 @@ def lookahead_decode(  # noqa: C901
     def _run_once(
         self,
         model: Callable[..., Any],
-        tokens: List[int],
+        tokens: Union[List[int], torch.Tensor],
         non_padded_len: Optional[int] = None,
         freqs_cos_override: Optional[torch.Tensor] = None,
         freqs_sin_override: Optional[torch.Tensor] = None,
     ):
-        n_tokens = len(tokens)
+        if isinstance(tokens, list):
+            tokens = torch.tensor([tokens], dtype=torch.int32)
+        n_tokens = tokens.size(1)
         if n_tokens < self.input_len:
-            tokens += [0] * (self.input_len - n_tokens)
-            tokens = torch.tensor([tokens], dtype=torch.int32)  # pyre-ignore[9]
+            tokens = F.pad(tokens, (0, self.input_len - n_tokens))
         if freqs_cos_override is None:
            freqs_cos_override = self.freqs_cos[self.pos : self.pos + self.input_len]
         if freqs_sin_override is None:
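The list-based zero padding in `_run_once` is replaced with a tensor pad that right-pads the sequence dimension. A standalone sketch of that behavior, assuming `F` is `torch.nn.functional` (the corresponding import is not shown in this diff):

```python
import torch
import torch.nn.functional as F

input_len = 8  # illustrative
tokens = torch.tensor([[5, 6, 7]], dtype=torch.int32)

# (0, n) pads the last dimension with n zeros on the right, matching the old
# `tokens += [0] * (input_len - n_tokens)` list padding for a single sequence.
padded = F.pad(tokens, (0, input_len - tokens.size(1)))
print(padded)  # tensor([[5, 6, 7, 0, 0, 0, 0, 0]], dtype=torch.int32)
```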