@@ -216,37 +216,42 @@ def post_process():


 def smart_mask_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    _, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    # ar_len is unused in smart mask
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates < max_cache_len:
         for i, k_cache in enumerate(k_caches):
-            k_cache[:, :, pos - ar_len] = new_k_caches[i][:, :, 0]
+            k_cache[:, :, pos : pos + n_updates] = new_k_caches[i][:, :, :n_updates]

         for i, v_cache in enumerate(v_caches):
-            v_cache[:, pos - ar_len, :] = new_v_caches[i][:, 0, :]
-        atten_mask[:, :, pos - ar_len] = 0
+            v_cache[:, pos : pos + n_updates, :] = new_v_caches[i][:, :n_updates, :]
+        atten_mask[:, :, pos : pos + n_updates] = 0
+        pos += n_updates

-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)


 def shift_pointer_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    ar_len, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates < max_cache_len:
         k_caches = [
-            torch.cat([k_cache[:, :, 1:], new_k_caches[i][:, :, :1]], dim=-1)
+            torch.cat(
+                [k_cache[:, :, n_updates:], new_k_caches[i][:, :, :n_updates]], dim=-1
+            )
             for i, k_cache in enumerate(k_caches)
         ]
         v_caches = [
-            torch.cat([v_cache[:, 1:, :], new_v_caches[i][:, :1, :]], dim=1)
+            torch.cat(
+                [v_cache[:, n_updates:, :], new_v_caches[i][:, :n_updates, :]], dim=1
+            )
             for i, v_cache in enumerate(v_caches)
         ]
-        atten_mask[:, :, -pos - 1] = 0
+        atten_mask[:, :, -pos - n_updates - ar_len : -pos - ar_len] = 0
+        pos += n_updates

-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)


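Note: both updaters now advance the cache by n_updates entries per call instead of a single token, which is what lets the rewritten kv_inference in the next hunk prefill the prompt in ar_len-sized chunks. Below is a minimal sketch of the two write patterns on dummy tensors; the (batch, head_dim, max_cache_len) K-cache layout is an assumption inferred from the indexing above, not part of this change.

    # Toy comparison of the two n_updates-aware cache writes, K cache only.
    import torch

    batch, head_dim, max_cache_len, ar_len = 1, 4, 8, 2
    pos, n_updates = 0, 2

    k_cache = torch.zeros(batch, head_dim, max_cache_len)
    new_k = torch.ones(batch, head_dim, ar_len)

    # Smart mask: write the fresh entries in place at [pos, pos + n_updates).
    smart_k = k_cache.clone()
    smart_k[:, :, pos : pos + n_updates] = new_k[:, :, :n_updates]

    # Shift pointer: drop the oldest n_updates slots and append the fresh ones at the end.
    shifted_k = torch.cat([k_cache[:, :, n_updates:], new_k[:, :, :n_updates]], dim=-1)

    print(smart_k[0, 0])    # ones in slots 0 and 1, zeros elsewhere
    print(shifted_k[0, 0])  # ones in the last two slots, zeros elsewhere

In both functions the `pos + n_updates < max_cache_len` guard skips the update once the window would run past the cache capacity.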
@@ -266,69 +271,127 @@ def kv_inference(
     # TODO: change criteria & support batch inputs if necessary
     all_pos = torch.arange(0, max_seq_len, 1, dtype=torch.int32).unsqueeze(0)

-    token_list, result_logits = [], []
+    prompt_token_list, generated_token_list, result_logits = [], [], []

     if isinstance(prompt, str):
         # Llama2 tokenizer has no special tokens
         if isinstance(tokenizer, (SentencePieceTokenizer, HuggingFaceTokenizer)):
-            token_list = tokenizer.encode(prompt, bos=True, eos=False)
+            prompt_token_list = tokenizer.encode(prompt, bos=True, eos=False)
         elif isinstance(tokenizer, TiktokenTokenizer):
-            token_list = tokenizer.encode(
+            prompt_token_list = tokenizer.encode(
                 prompt, bos=True, eos=False, allowed_special="all"
             )
         else:
             raise RuntimeError("Unknown tokenizer")
     else:
-        token_list = prompt.flatten().tolist()
-    pos = len(token_list) if len(token_list) < ar_len else ar_len
+        prompt_token_list = prompt.flatten().tolist()
     dtype = torch.int64 if use_i64_token else torch.int32

     with torch.no_grad():
-        while token_list[-1] != tokenizer.eos_id and pos < max_seq_len:
-            tmp_token_list = torch.tensor(
-                token_list[pos - ar_len : pos], dtype=dtype
-            ).reshape(1, -1)
-            tmp_pos = all_pos[:, pos - ar_len : pos]
-            tmp_atten_mask = atten_mask
-            if pos < ar_len:
-                tmp_token_list = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=dtype),
-                        torch.tensor(token_list, dtype=dtype).reshape(1, -1),
-                    ],
-                    dim=1,
-                )
-                tmp_pos = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=torch.int32),
-                        all_pos[:, :pos],
-                    ],
-                    dim=1,
-                )
-                tmp_atten_mask = torch.cat(
-                    [
-                        torch.ones(1, ar_len, max_seq_len - pos) * -255.0,
-                        atten_mask[:, :, -pos:],
-                    ],
-                    dim=-1,
-                )
+        # Phase 1: Prefill the prompt in ar_len chunks.
+        num_prompt_tokens = len(prompt_token_list)
+        prompt_processed_so_far = (
+            0  # Tracks how many prompt tokens have been processed.
+        )
+        while prompt_processed_so_far < num_prompt_tokens:
+            chunk_start_idx = prompt_processed_so_far
+            # Take a chunk of prompt tokens, up to ar_len in length.
+            chunk_end_idx = min(num_prompt_tokens, prompt_processed_so_far + ar_len)
+            actual_chunk_tokens = prompt_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
+
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[
+                0,
+                prompt_processed_so_far : prompt_processed_so_far + num_tokens_in_chunk,
+            ]

+            # Run inference.
             logits, new_k_caches, new_v_caches = module(
                 tmp_token_list,
-                tmp_atten_mask,
+                atten_mask,
                 tmp_pos,
                 *k_caches,
                 *v_caches,
             )
             if collect_logits:
-                result_logits.append(logits)
-            atten_mask, pos, k_caches, v_caches = kv_updater(
-                ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+                result_logits.append(logits[:, :num_tokens_in_chunk])
+
+            # Update prompt_processed_so_far, the KV cache, and the attention mask.
+            atten_mask, prompt_processed_so_far, k_caches, v_caches = kv_updater(
+                ar_len,
+                num_tokens_in_chunk,
+                atten_mask,
+                prompt_processed_so_far,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
+            )
+        # Append the token predicted at the last position of the final prefill chunk.
+        generated_token_list.append(
+            torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+        )
+
+        # Phase 2: Generate tokens until the EOS token is generated or max_seq_len is reached.
+        # When run on wikitext for ppl evaluation, this while-loop is not expected to run.
+        token_generated_so_far = 0
+        # After prompt processing, exactly one new token has been generated.
+        pos = prompt_processed_so_far + 1
+        while generated_token_list[-1] != tokenizer.eos_id and pos < max_seq_len:
+            num_generated_tokens = len(generated_token_list)
+            chunk_start_idx = token_generated_so_far
+            # Take a chunk of generated tokens, up to ar_len in length.
+            chunk_end_idx = min(num_generated_tokens, token_generated_so_far + ar_len)
+            actual_chunk_tokens = generated_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
+
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[
+                0, pos : pos + num_tokens_in_chunk
+            ]
+
+            logits, new_k_caches, new_v_caches = module(
+                tmp_token_list,
+                atten_mask,
+                tmp_pos,
+                *k_caches,
+                *v_caches,
             )
-            if pos > len(token_list):
-                token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
-
-    logging.info(f"kv inference result:\n{tokenizer.decode(token_list)}")
+            if collect_logits:
+                result_logits.append(logits[:, :num_tokens_in_chunk])
+            atten_mask, pos, k_caches, v_caches = kv_updater(
+                ar_len,
+                1,
+                atten_mask,
+                pos,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
+            )
+            generated_token_list.append(
+                torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+            )
+            token_generated_so_far += 1
+    logging.info(
+        f"kv inference result:\n{tokenizer.decode(prompt_token_list + generated_token_list)}"
+    )
     if collect_logits:
         result_logits = torch.cat(result_logits, dim=1)
     return result_logits
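For reference, the rewritten loop runs in two phases: prefill consumes the prompt in ar_len-sized chunks (calling kv_updater with n_updates = num_tokens_in_chunk), then decode appends one generated token per step (n_updates = 1) until EOS or max_seq_len. A small sketch of the chunk-padding arithmetic used in both phases, with made-up token ids; the tensor names mirror the diff, the values are illustrative only.

    import torch

    ar_len, dtype = 4, torch.int32
    prompt_token_list = [11, 12, 13, 14, 15, 16]   # dummy token ids
    prompt_processed_so_far = 4                    # second (partial) chunk of the prompt

    chunk = prompt_token_list[prompt_processed_so_far : prompt_processed_so_far + ar_len]
    num_tokens_in_chunk = len(chunk)               # 2 real tokens, 2 zero-padded slots

    tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
    tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(chunk, dtype=dtype)

    tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
    tmp_pos[0, :num_tokens_in_chunk] = torch.arange(
        prompt_processed_so_far, prompt_processed_so_far + num_tokens_in_chunk, dtype=torch.int32
    )

    print(tmp_token_list)  # tensor([[15, 16,  0,  0]], dtype=torch.int32)
    print(tmp_pos)         # tensor([[4, 5, 0, 0]], dtype=torch.int32)

Only the logits at index num_tokens_in_chunk - 1, the last real position of the chunk, feed the argmax that picks the next token.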