@@ -216,37 +216,42 @@ def post_process():
 
 
 def smart_mask_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    _, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    # ar_len is unused in smart mask
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates <= max_cache_len:
         for i, k_cache in enumerate(k_caches):
-            k_cache[:, :, pos - ar_len] = new_k_caches[i][:, :, 0]
+            k_cache[:, :, pos : pos + n_updates] = new_k_caches[i][:, :, :n_updates]
 
         for i, v_cache in enumerate(v_caches):
-            v_cache[:, pos - ar_len, :] = new_v_caches[i][:, 0, :]
-        atten_mask[:, :, pos - ar_len] = 0
+            v_cache[:, pos : pos + n_updates, :] = new_v_caches[i][:, :n_updates, :]
+        atten_mask[:, :, pos : pos + n_updates] = 0
+        pos += n_updates
 
-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)
 
 
 def shift_pointer_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    ar_len, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates <= max_cache_len:
         k_caches = [
-            torch.cat([k_cache[:, :, 1:], new_k_caches[i][:, :, :1]], dim=-1)
+            torch.cat(
+                [k_cache[:, :, n_updates:], new_k_caches[i][:, :, :n_updates]], dim=-1
+            )
             for i, k_cache in enumerate(k_caches)
         ]
         v_caches = [
-            torch.cat([v_cache[:, 1:, :], new_v_caches[i][:, :1, :]], dim=1)
+            torch.cat(
+                [v_cache[:, n_updates:, :], new_v_caches[i][:, :n_updates, :]], dim=1
+            )
             for i, v_cache in enumerate(v_caches)
         ]
-        atten_mask[:, :, -pos - 1] = 0
+        atten_mask[:, :, -pos - n_updates - ar_len : -pos - ar_len] = 0
+        pos += n_updates
 
-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)
 
 
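The two updaters above implement the same cache-update contract but differ in where the fresh entries land: smart_mask_updater writes the n_updates new K/V slices in place at [pos, pos + n_updates) and unmasks those positions, while shift_pointer_updater rolls the cache left by n_updates and appends the new slices at the end. A minimal, self-contained sketch of that difference on a single dummy K cache (the shapes below are illustrative only, not the exporter's real layout):

import torch

# Hypothetical shape for illustration: (heads, head_dim, cache_len).
heads, head_dim, cache_len, n_updates, pos = 1, 2, 6, 2, 0
new_k = torch.arange(1, 1 + heads * head_dim * n_updates, dtype=torch.float32).reshape(
    heads, head_dim, n_updates
)

# Smart mask: write the new slices in place at [pos, pos + n_updates).
k_smart = torch.zeros(heads, head_dim, cache_len)
k_smart[:, :, pos : pos + n_updates] = new_k

# Shift pointer: drop the oldest n_updates slots and append the new slices at the end.
k_shift = torch.cat(
    [torch.zeros(heads, head_dim, cache_len)[:, :, n_updates:], new_k], dim=-1
)

print(k_smart[0, 0])  # tensor([1., 2., 0., 0., 0., 0.])
print(k_shift[0, 0])  # tensor([0., 0., 0., 0., 1., 2.])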
@@ -266,69 +271,120 @@ def kv_inference(
     # TODO: change criteria & support batch inputs if necessary
     all_pos = torch.arange(0, max_seq_len, 1, dtype=torch.int32).unsqueeze(0)
 
-    token_list, result_logits = [], []
+    prompt_token_list, total_token_list, result_logits = [], [], []
 
     if isinstance(prompt, str):
         # Llama2 tokenizer has no special tokens
         if isinstance(tokenizer, (SentencePieceTokenizer, HuggingFaceTokenizer)):
-            token_list = tokenizer.encode(prompt, bos=True, eos=False)
+            prompt_token_list = tokenizer.encode(prompt, bos=True, eos=False)
         elif isinstance(tokenizer, TiktokenTokenizer):
-            token_list = tokenizer.encode(
+            prompt_token_list = tokenizer.encode(
                 prompt, bos=True, eos=False, allowed_special="all"
             )
         else:
             raise RuntimeError("Unknown tokenizer")
     else:
-        token_list = prompt.flatten().tolist()
-    pos = len(token_list) if len(token_list) < ar_len else ar_len
+        prompt_token_list = prompt.flatten().tolist()
+    total_token_list = prompt_token_list
     dtype = torch.int64 if use_i64_token else torch.int32
 
     with torch.no_grad():
-        while token_list[-1] != tokenizer.eos_id and pos < max_seq_len:
-            tmp_token_list = torch.tensor(
-                token_list[pos - ar_len : pos], dtype=dtype
-            ).reshape(1, -1)
-            tmp_pos = all_pos[:, pos - ar_len : pos]
-            tmp_atten_mask = atten_mask
-            if pos < ar_len:
-                tmp_token_list = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=dtype),
-                        torch.tensor(token_list, dtype=dtype).reshape(1, -1),
-                    ],
-                    dim=1,
-                )
-                tmp_pos = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=torch.int32),
-                        all_pos[:, :pos],
-                    ],
-                    dim=1,
-                )
-                tmp_atten_mask = torch.cat(
-                    [
-                        torch.ones(1, ar_len, max_seq_len - pos) * -255.0,
-                        atten_mask[:, :, -pos:],
-                    ],
-                    dim=-1,
-                )
+        # Phase 1: Prefill the prompt in ar_len chunks.
+        num_prompt_tokens = len(prompt_token_list)
+        pos = 0  # Tracks how many prompt tokens have been processed.
+        while pos < num_prompt_tokens:
+            chunk_start_idx = pos
+            # Take a chunk of prompt tokens, up to ar_len length.
+            chunk_end_idx = min(num_prompt_tokens, pos + ar_len)
+            actual_chunk_tokens = prompt_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
 
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[
+                0,
+                pos : pos + num_tokens_in_chunk,
+            ]
+
+            # Run inference.
             logits, new_k_caches, new_v_caches = module(
                 tmp_token_list,
-                tmp_atten_mask,
+                atten_mask,
                 tmp_pos,
                 *k_caches,
                 *v_caches,
             )
             if collect_logits:
-                result_logits.append(logits)
+                result_logits.append(logits[:, :num_tokens_in_chunk])
+
+            # Update the pos, KV cache and attention mask.
             atten_mask, pos, k_caches, v_caches = kv_updater(
-                ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+                ar_len,
+                num_tokens_in_chunk,
+                atten_mask,
+                pos,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
+            )
+        # Append the last run logits to the total_token_list.
+        total_token_list.append(
+            torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+        )
+
+        # Phase 2: Generate tokens until the EOS token is generated or max_seq_len is reached.
+        # When run on wikitext for ppl evaluation, this while-loop is not expected to run.
+        max_cache_len = max_seq_len - ar_len
+        num_tokens = len(total_token_list)
+        while total_token_list[-1] != tokenizer.eos_id and num_tokens < max_seq_len:
+            chunk_start_idx = min(pos, max_cache_len)
+            # Take a chunk of generated tokens, up to ar_len length.
+            chunk_end_idx = num_tokens
+            actual_chunk_tokens = total_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
+
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[0, chunk_start_idx:chunk_end_idx]
+
+            logits, new_k_caches, new_v_caches = module(
+                tmp_token_list,
+                atten_mask,
+                tmp_pos,
+                *k_caches,
+                *v_caches,
             )
-            if pos > len(token_list):
-                token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
+            if collect_logits:
+                result_logits.append(logits[:, :num_tokens_in_chunk])
 
-    logging.info(f"kv inference result:\n{tokenizer.decode(token_list)}")
+            atten_mask, pos, k_caches, v_caches = kv_updater(
+                ar_len,
+                1,
+                atten_mask,
+                pos,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
+            )
+            total_token_list.append(
+                torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+            )
+            num_tokens = len(total_token_list)
+    logging.info(f"kv inference result:\n{tokenizer.decode(total_token_list)}")
     if collect_logits:
         result_logits = torch.cat(result_logits, dim=1)
     return result_logits
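As a quick, self-contained illustration of the phase-1 chunking in kv_inference (the token ids and ar_len below are made up for the example), the prompt is consumed in fixed ar_len windows and the final partial window is zero-padded before being fed to the model:

import torch

ar_len = 4
prompt_token_list = list(range(100, 110))  # 10 dummy token ids
dtype = torch.int32

pos = 0
while pos < len(prompt_token_list):
    chunk = prompt_token_list[pos : pos + ar_len]
    tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
    tmp_token_list[0, : len(chunk)] = torch.tensor(chunk, dtype=dtype)
    print(pos, tmp_token_list.tolist())
    pos += len(chunk)  # the real loop advances pos through kv_updater instead

# Prints windows starting at 0, 4, and 8; the last one is padded to [108, 109, 0, 0].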