@@ -219,37 +219,42 @@ def post_process():
 
 
 def smart_mask_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    _, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    # ar_len is unused in smart mask
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates <= max_cache_len:
         for i, k_cache in enumerate(k_caches):
-            k_cache[:, :, pos - ar_len] = new_k_caches[i][:, :, 0]
+            k_cache[:, :, pos : pos + n_updates] = new_k_caches[i][:, :, :n_updates]
 
         for i, v_cache in enumerate(v_caches):
-            v_cache[:, pos - ar_len, :] = new_v_caches[i][:, 0, :]
-        atten_mask[:, :, pos - ar_len] = 0
+            v_cache[:, pos : pos + n_updates, :] = new_v_caches[i][:, :n_updates, :]
+        atten_mask[:, :, pos : pos + n_updates] = 0
+        pos += n_updates
 
-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)
 
 
 def shift_pointer_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    ar_len, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates <= max_cache_len:
         k_caches = [
-            torch.cat([k_cache[:, :, 1:], new_k_caches[i][:, :, :1]], dim=-1)
+            torch.cat(
+                [k_cache[:, :, n_updates:], new_k_caches[i][:, :, :n_updates]], dim=-1
+            )
             for i, k_cache in enumerate(k_caches)
         ]
         v_caches = [
-            torch.cat([v_cache[:, 1:, :], new_v_caches[i][:, :1, :]], dim=1)
+            torch.cat(
+                [v_cache[:, n_updates:, :], new_v_caches[i][:, :n_updates, :]], dim=1
+            )
             for i, v_cache in enumerate(v_caches)
         ]
-        atten_mask[:, :, -pos - 1] = 0
+        atten_mask[:, :, -pos - n_updates - ar_len : -pos - ar_len] = 0
+        pos += n_updates
 
-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)
 
 
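The two updaters above implement the same contract with different cache layouts: smart_mask_updater writes the n_updates new K/V entries in place at the absolute position pos, while shift_pointer_updater rolls the whole cache left by n_updates and appends the new entries at the end. Below is a minimal usage sketch, not part of the patch; the shapes (batch, head_dim, cache_len) for K, (batch, cache_len, head_dim) for V, (batch, ar_len, max_seq_len) for the mask, and the -255.0 fill value are assumptions read off the indexing in the diff, not values confirmed elsewhere.

# Hedged usage sketch, not part of the patch. Assumes the two updaters above are
# defined in the same file; all sizes are toy values.
import torch

batch, head_dim, cache_len, ar_len, max_seq_len = 1, 4, 8, 2, 10
n_updates = 2

k_caches = [torch.zeros(batch, head_dim, cache_len)]
v_caches = [torch.zeros(batch, cache_len, head_dim)]
atten_mask = torch.full((batch, ar_len, max_seq_len), -255.0)
pos = 0

# Pretend the model just produced ar_len new K/V entries (all ones).
new_k = [torch.ones(batch, head_dim, ar_len)]
new_v = [torch.ones(batch, ar_len, head_dim)]

# Smart mask: in-place write of n_updates entries at the absolute position, then advance pos.
atten_mask, pos, k_caches, v_caches = smart_mask_updater(
    None, n_updates, atten_mask, pos, k_caches, v_caches, new_k, new_v
)
print(pos)                    # 2
print(k_caches[0][0, 0, :4])  # tensor([1., 1., 0., 0.]) -- new entries land at slots 0..1

# Shift pointer: the cache is rolled left by n_updates and the new entries are
# concatenated at the tail, so the newest K/V always sit at the end of the tensor.
k_caches2 = [torch.zeros(batch, head_dim, cache_len)]
v_caches2 = [torch.zeros(batch, cache_len, head_dim)]
atten_mask2 = torch.full((batch, ar_len, max_seq_len), -255.0)
atten_mask2, pos2, k_caches2, v_caches2 = shift_pointer_updater(
    ar_len, n_updates, atten_mask2, 0, k_caches2, v_caches2, new_k, new_v
)
print(k_caches2[0][0, 0, -4:])  # tensor([0., 0., 1., 1.]) -- new entries at the tail

The mask bookkeeping differs accordingly: smart mask zeroes the mask at the absolute slots it just filled, while shift pointer zeroes a window addressed from the end of the mask, offset by ar_len, to match the rolled cache.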
@@ -269,70 +274,121 @@ def kv_inference(
     # TODO: change criteria & support batch inputs if necessary
     all_pos = torch.arange(0, max_seq_len, 1, dtype=torch.int32).unsqueeze(0)
 
-    token_list, result_logits = [], []
+    prompt_token_list, total_token_list, result_logits = [], [], []
 
     if isinstance(prompt, str):
         # Llama2 tokenizer has no special tokens
         if isinstance(tokenizer, (SentencePieceTokenizer, HuggingFaceTokenizer)):
-            token_list = tokenizer.encode(prompt, bos=True, eos=False)
+            prompt_token_list = tokenizer.encode(prompt, bos=True, eos=False)
         elif isinstance(tokenizer, TiktokenTokenizer):
-            token_list = tokenizer.encode(
+            prompt_token_list = tokenizer.encode(
                 prompt, bos=True, eos=False, allowed_special="all"
            )
         else:
             raise RuntimeError("Unknown tokenizer")
     else:
         # pyre-ignore
-        token_list = prompt.flatten().tolist()
-    pos = len(token_list) if len(token_list) < ar_len else ar_len
+        prompt_token_list = prompt.flatten().tolist()
+    total_token_list = prompt_token_list
     dtype = torch.int64 if use_i64_token else torch.int32
 
     with torch.no_grad():
-        while token_list[-1] != tokenizer.eos_id and pos < max_seq_len:
-            tmp_token_list = torch.tensor(
-                token_list[pos - ar_len : pos], dtype=dtype
-            ).reshape(1, -1)
-            tmp_pos = all_pos[:, pos - ar_len : pos]
-            tmp_atten_mask = atten_mask
-            if pos < ar_len:
-                tmp_token_list = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=dtype),
-                        torch.tensor(token_list, dtype=dtype).reshape(1, -1),
-                    ],
-                    dim=1,
-                )
-                tmp_pos = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=torch.int32),
-                        all_pos[:, :pos],
-                    ],
-                    dim=1,
-                )
-                tmp_atten_mask = torch.cat(
-                    [
-                        torch.ones(1, ar_len, max_seq_len - pos) * -255.0,
-                        atten_mask[:, :, -pos:],
-                    ],
-                    dim=-1,
-                )
+        # Phase 1: Prefill the prompt in ar_len chunks.
+        num_prompt_tokens = len(prompt_token_list)
+        pos = 0  # Tracks how many prompt tokens have been processed.
+        while pos < num_prompt_tokens:
+            chunk_start_idx = pos
+            # Take a chunk of prompt tokens, up to ar_len length.
+            chunk_end_idx = min(num_prompt_tokens, pos + ar_len)
+            actual_chunk_tokens = prompt_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
 
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[
+                0,
+                pos : pos + num_tokens_in_chunk,
+            ]
+
+            # Run inference.
             logits, new_k_caches, new_v_caches = module(
                 tmp_token_list,
-                tmp_atten_mask,
+                atten_mask,
                 tmp_pos,
                 *k_caches,
                 *v_caches,
             )
             if collect_logits:
-                result_logits.append(logits)
+                result_logits.append(logits[:, :num_tokens_in_chunk])
+
+            # Update the pos, KV cache and attention mask.
             atten_mask, pos, k_caches, v_caches = kv_updater(
-                ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+                ar_len,
+                num_tokens_in_chunk,
+                atten_mask,
+                pos,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
+            )
+            # Append the last run logits to the total_token_list.
+            total_token_list.append(
+                torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+            )
+
+        # Phase 2: Generate tokens until the EOS token is generated or max_seq_len is reached.
+        # When run on wikitext for ppl evaluation, this while-loop is not expected to run.
+        max_cache_len = max_seq_len - ar_len
+        num_tokens = len(total_token_list)
+        while total_token_list[-1] != tokenizer.eos_id and num_tokens < max_seq_len:
+            chunk_start_idx = min(pos, max_cache_len)
+            # Take a chunk of generated tokens, up to ar_len length.
+            chunk_end_idx = num_tokens
+            actual_chunk_tokens = total_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
+
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[0, chunk_start_idx:chunk_end_idx]
+
+            logits, new_k_caches, new_v_caches = module(
+                tmp_token_list,
+                atten_mask,
+                tmp_pos,
+                *k_caches,
+                *v_caches,
             )
-            if pos > len(token_list):
-                token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
+            if collect_logits:
+                result_logits.append(logits[:, :num_tokens_in_chunk])
 
-    logging.info(f"kv inference result:\n{tokenizer.decode(token_list)}")
+            atten_mask, pos, k_caches, v_caches = kv_updater(
+                ar_len,
+                1,
+                atten_mask,
+                pos,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
+            )
+            total_token_list.append(
+                torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+            )
+            num_tokens = len(total_token_list)
+    logging.info(f"kv inference result:\n{tokenizer.decode(total_token_list)}")
     if collect_logits:
         result_logits = torch.cat(result_logits, dim=1)
     return result_logits
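For readers skimming the new kv_inference flow: Phase 1 walks the prompt in ar_len-sized chunks, zero-padding the final chunk, and Phase 2 then decodes through the same padded interface until EOS or max_seq_len. Below is a minimal, torch-free sketch of the Phase 1 chunking arithmetic, not part of the patch; it advances pos inline for brevity, whereas in the patch pos is advanced inside the kv_updater, and the token values are made up.

# Hedged sketch of the Phase 1 chunking, with toy numbers.
ar_len = 4
prompt_token_list = list(range(10))  # 10 toy prompt tokens

pos = 0
while pos < len(prompt_token_list):
    chunk = prompt_token_list[pos : pos + ar_len]
    padded = chunk + [0] * (ar_len - len(chunk))  # mirrors the torch.zeros padding
    print(pos, padded)
    pos += len(chunk)
# Prints:
# 0 [0, 1, 2, 3]
# 4 [4, 5, 6, 7]
# 8 [8, 9, 0, 0]

Only the logits for the real (unpadded) positions are kept, which is why result_logits slices to :num_tokens_in_chunk and the next token is read from index num_tokens_in_chunk - 1.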