@@ -219,37 +219,42 @@ def post_process():
 
 
 def smart_mask_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    _, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    # ar_len is unused in smart mask
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates <= max_cache_len:
         for i, k_cache in enumerate(k_caches):
-            k_cache[:, :, pos - ar_len] = new_k_caches[i][:, :, 0]
+            k_cache[:, :, pos : pos + n_updates] = new_k_caches[i][:, :, :n_updates]
 
         for i, v_cache in enumerate(v_caches):
-            v_cache[:, pos - ar_len, :] = new_v_caches[i][:, 0, :]
-        atten_mask[:, :, pos - ar_len] = 0
+            v_cache[:, pos : pos + n_updates, :] = new_v_caches[i][:, :n_updates, :]
+        atten_mask[:, :, pos : pos + n_updates] = 0
+        pos += n_updates
 
-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)
 
 
 def shift_pointer_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    ar_len, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates <= max_cache_len:
         k_caches = [
-            torch.cat([k_cache[:, :, 1:], new_k_caches[i][:, :, :1]], dim=-1)
+            torch.cat(
+                [k_cache[:, :, n_updates:], new_k_caches[i][:, :, :n_updates]], dim=-1
+            )
             for i, k_cache in enumerate(k_caches)
         ]
         v_caches = [
-            torch.cat([v_cache[:, 1:, :], new_v_caches[i][:, :1, :]], dim=1)
+            torch.cat(
+                [v_cache[:, n_updates:, :], new_v_caches[i][:, :n_updates, :]], dim=1
+            )
             for i, v_cache in enumerate(v_caches)
         ]
-        atten_mask[:, :, -pos - 1] = 0
+        atten_mask[:, :, -pos - n_updates - ar_len : -pos - ar_len] = 0
+        pos += n_updates
 
-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)
 
 
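Note: both updaters now consume n_updates tokens per call instead of one, which is what lets kv_inference below prefill the prompt in ar_len-sized chunks. smart_mask_updater writes the new entries in place at absolute position pos, while shift_pointer_updater rolls the cache with torch.cat, dropping the oldest n_updates entries. The sketch below is a new, illustrative-only example (the tiny shapes and the -255.0 mask fill value are assumptions, not part of this patch); it runs smart_mask_updater on toy tensors to show how pos, the caches, and the attention mask advance together.

# Illustrative only: tiny shapes, assumed mask convention (0 = attend, -255.0 = masked).
import torch

def smart_mask_updater(
    _, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
):
    max_cache_len = k_caches[0].size(-1)
    if pos + n_updates <= max_cache_len:
        for i, k_cache in enumerate(k_caches):
            k_cache[:, :, pos : pos + n_updates] = new_k_caches[i][:, :, :n_updates]
        for i, v_cache in enumerate(v_caches):
            v_cache[:, pos : pos + n_updates, :] = new_v_caches[i][:, :n_updates, :]
        atten_mask[:, :, pos : pos + n_updates] = 0
        pos += n_updates
    return (atten_mask, pos, k_caches, v_caches)

head_dim, ar_len, max_cache_len = 4, 2, 8
k_caches = [torch.zeros(1, head_dim, max_cache_len)]
v_caches = [torch.zeros(1, max_cache_len, head_dim)]
atten_mask = torch.full((1, ar_len, max_cache_len + ar_len), -255.0)
pos = 0

# Pretend the model just produced ar_len new cache entries; write both of them at once.
new_k = [torch.ones(1, head_dim, ar_len)]
new_v = [torch.ones(1, ar_len, head_dim)]
atten_mask, pos, k_caches, v_caches = smart_mask_updater(
    None, 2, atten_mask, pos, k_caches, v_caches, new_k, new_v
)
print(pos)                    # 2
print(k_caches[0][0, 0, :4])  # tensor([1., 1., 0., 0.])
print(atten_mask[0, 0, :4])   # tensor([0., 0., -255., -255.])
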
@@ -269,70 +274,121 @@ def kv_inference(
     # TODO: change criteria & support batch inputs if necessary
     all_pos = torch.arange(0, max_seq_len, 1, dtype=torch.int32).unsqueeze(0)
 
-    token_list, result_logits = [], []
+    prompt_token_list, total_token_list, result_logits = [], [], []
 
     if isinstance(prompt, str):
         # Llama2 tokenizer has no special tokens
         if isinstance(tokenizer, (SentencePieceTokenizer, HuggingFaceTokenizer)):
-            token_list = tokenizer.encode(prompt, bos=True, eos=False)
+            prompt_token_list = tokenizer.encode(prompt, bos=True, eos=False)
         elif isinstance(tokenizer, TiktokenTokenizer):
-            token_list = tokenizer.encode(
+            prompt_token_list = tokenizer.encode(
                 prompt, bos=True, eos=False, allowed_special="all"
             )
         else:
            raise RuntimeError("Unknown tokenizer")
     else:
         # pyre-ignore
-        token_list = prompt.flatten().tolist()
-    pos = len(token_list) if len(token_list) < ar_len else ar_len
+        prompt_token_list = prompt.flatten().tolist()
+    total_token_list = prompt_token_list
     dtype = torch.int64 if use_i64_token else torch.int32
 
     with torch.no_grad():
-        while token_list[-1] != tokenizer.eos_id and pos < max_seq_len:
-            tmp_token_list = torch.tensor(
-                token_list[pos - ar_len : pos], dtype=dtype
-            ).reshape(1, -1)
-            tmp_pos = all_pos[:, pos - ar_len : pos]
-            tmp_atten_mask = atten_mask
-            if pos < ar_len:
-                tmp_token_list = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=dtype),
-                        torch.tensor(token_list, dtype=dtype).reshape(1, -1),
-                    ],
-                    dim=1,
-                )
-                tmp_pos = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=torch.int32),
-                        all_pos[:, :pos],
-                    ],
-                    dim=1,
-                )
-                tmp_atten_mask = torch.cat(
-                    [
-                        torch.ones(1, ar_len, max_seq_len - pos) * -255.0,
-                        atten_mask[:, :, -pos:],
-                    ],
-                    dim=-1,
-                )
+        # Phase 1: Prefill the prompt in ar_len chunks.
+        num_prompt_tokens = len(prompt_token_list)
+        pos = 0  # Tracks how many prompt tokens have been processed.
+        while pos < num_prompt_tokens:
+            chunk_start_idx = pos
+            # Take a chunk of prompt tokens, up to ar_len length.
+            chunk_end_idx = min(num_prompt_tokens, pos + ar_len)
+            actual_chunk_tokens = prompt_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
 
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[
+                0,
+                pos : pos + num_tokens_in_chunk,
+            ]
+
+            # Run inference.
             logits, new_k_caches, new_v_caches = module(
                 tmp_token_list,
-                tmp_atten_mask,
+                atten_mask,
                 tmp_pos,
                 *k_caches,
                 *v_caches,
             )
             if collect_logits:
-                result_logits.append(logits)
+                result_logits.append(logits[:, :num_tokens_in_chunk])
+
+            # Update the pos, KV cache and attention mask.
             atten_mask, pos, k_caches, v_caches = kv_updater(
-                ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+                ar_len,
+                num_tokens_in_chunk,
+                atten_mask,
+                pos,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
+            )
+            # Append the last run logits to the total_token_list.
+            total_token_list.append(
+                torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+            )
+
+        # Phase 2: Generate tokens until the EOS token is generated or max_seq_len is reached.
+        # When run on wikitext for ppl evaluation, this while-loop is not expected to run.
+        max_cache_len = max_seq_len - ar_len
+        num_tokens = len(total_token_list)
+        while total_token_list[-1] != tokenizer.eos_id and num_tokens < max_seq_len:
+            chunk_start_idx = min(pos, max_cache_len)
+            # Take a chunk of generated tokens, up to ar_len length.
+            chunk_end_idx = num_tokens
+            actual_chunk_tokens = total_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
+
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[0, chunk_start_idx:chunk_end_idx]
+
+            logits, new_k_caches, new_v_caches = module(
+                tmp_token_list,
+                atten_mask,
+                tmp_pos,
+                *k_caches,
+                *v_caches,
             )
-            if pos > len(token_list):
-                token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
+            if collect_logits:
+                result_logits.append(logits[:, :num_tokens_in_chunk])
 
-    logging.info(f"kv inference result:\n{tokenizer.decode(token_list)}")
+            atten_mask, pos, k_caches, v_caches = kv_updater(
+                ar_len,
+                1,
+                atten_mask,
+                pos,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
+            )
+            total_token_list.append(
+                torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+            )
+            num_tokens = len(total_token_list)
+    logging.info(f"kv inference result:\n{tokenizer.decode(total_token_list)}")
     if collect_logits:
         result_logits = torch.cat(result_logits, dim=1)
     return result_logits
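
For reference, the rewritten kv_inference drives every forward pass with a fixed shape of ar_len token slots: Phase 1 feeds the prompt in ar_len-sized chunks (zero-padding the last one), and Phase 2 decodes one token at a time in the same fixed-width layout. The helper below is a hypothetical, illustrative-only sketch of just the Phase 1 chunking/padding step (chunk_prompt is not a function in this patch); it shows the kind of tmp_token_list / tmp_pos tensors the loop builds.

# Hypothetical helper illustrating the Phase 1 chunking; not part of the patch.
import torch

def chunk_prompt(prompt_token_list, ar_len, dtype=torch.int32):
    """Yield (tmp_token_list, tmp_pos, num_tokens_in_chunk) triples of fixed width ar_len."""
    pos = 0
    num_prompt_tokens = len(prompt_token_list)
    while pos < num_prompt_tokens:
        chunk = prompt_token_list[pos : min(num_prompt_tokens, pos + ar_len)]
        n = len(chunk)
        tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
        tmp_token_list[0, :n] = torch.tensor(chunk, dtype=dtype)
        tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
        tmp_pos[0, :n] = torch.arange(pos, pos + n, dtype=torch.int32)
        yield tmp_token_list, tmp_pos, n
        pos += n

# A 5-token prompt with ar_len=4 produces one full chunk and one zero-padded chunk.
for tokens, positions, n in chunk_prompt([11, 12, 13, 14, 15], ar_len=4):
    print(n, tokens.tolist(), positions.tolist())
# 4 [[11, 12, 13, 14]] [[0, 1, 2, 3]]
# 1 [[15, 0, 0, 0]] [[4, 0, 0, 0]]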