
Commit f0e16d1

Update calibration flow to enhance the speed of wikitext calibration
1 parent 7a1a1d3 commit f0e16d1
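
For orientation, an editor's note (not part of the commit): the change reworks kv_inference so the calibration prompt is prefilled in ar_len-sized chunks, and both KV-cache updaters gain an n_updates argument so a whole chunk is written into the cache per call. The sketch below is a hypothetical back-of-the-envelope comparison of forward-pass counts with made-up sizes; it is not the committed code.

import math

# Hypothetical sizes, for illustration only.
ar_len = 32               # autoregressive window the compiled graph accepts
num_prompt_tokens = 2048  # e.g. one wikitext calibration sample

# Old flow: the window advanced one position per inference, so prefill cost
# roughly one forward pass per prompt token beyond the first window.
passes_before = max(1, num_prompt_tokens - ar_len + 1)

# New flow: the prompt is consumed in ar_len-sized chunks, one pass per chunk.
passes_after = math.ceil(num_prompt_tokens / ar_len)

print(passes_before, passes_after)  # 2017 vs. 64 forward passes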

File tree

1 file changed: +117 -54 lines changed

examples/qualcomm/oss_scripts/llama/decoder_utils.py

Lines changed: 117 additions & 54 deletions
@@ -216,37 +216,42 @@ def post_process():
 
 
 def smart_mask_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    _, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    # ar_len is unused in smart mask
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates < max_cache_len:
         for i, k_cache in enumerate(k_caches):
-            k_cache[:, :, pos - ar_len] = new_k_caches[i][:, :, 0]
+            k_cache[:, :, pos : pos + n_updates] = new_k_caches[i][:, :, :n_updates]
 
         for i, v_cache in enumerate(v_caches):
-            v_cache[:, pos - ar_len, :] = new_v_caches[i][:, 0, :]
-        atten_mask[:, :, pos - ar_len] = 0
+            v_cache[:, pos : pos + n_updates, :] = new_v_caches[i][:, :n_updates, :]
+        atten_mask[:, :, pos : pos + n_updates] = 0
+        pos += n_updates
 
-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)
 
 
 def shift_pointer_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    ar_len, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates < max_cache_len:
         k_caches = [
-            torch.cat([k_cache[:, :, 1:], new_k_caches[i][:, :, :1]], dim=-1)
+            torch.cat(
+                [k_cache[:, :, n_updates:], new_k_caches[i][:, :, :n_updates]], dim=-1
+            )
             for i, k_cache in enumerate(k_caches)
         ]
         v_caches = [
-            torch.cat([v_cache[:, 1:, :], new_v_caches[i][:, :1, :]], dim=1)
+            torch.cat(
+                [v_cache[:, n_updates:, :], new_v_caches[i][:, :n_updates, :]], dim=1
+            )
             for i, v_cache in enumerate(v_caches)
         ]
-        atten_mask[:, :, -pos - 1] = 0
+        atten_mask[:, :, -pos - n_updates - ar_len : -pos - ar_len] = 0
+        pos += n_updates
 
-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)
 
 
@@ -266,69 +271,127 @@ def kv_inference(
     # TODO: change criteria & support batch inputs if necessary
     all_pos = torch.arange(0, max_seq_len, 1, dtype=torch.int32).unsqueeze(0)
 
-    token_list, result_logits = [], []
+    prompt_token_list, generated_token_list, result_logits = [], [], []
 
     if isinstance(prompt, str):
         # Llama2 tokenizer has no special tokens
         if isinstance(tokenizer, (SentencePieceTokenizer, HuggingFaceTokenizer)):
-            token_list = tokenizer.encode(prompt, bos=True, eos=False)
+            prompt_token_list = tokenizer.encode(prompt, bos=True, eos=False)
         elif isinstance(tokenizer, TiktokenTokenizer):
-            token_list = tokenizer.encode(
+            prompt_token_list = tokenizer.encode(
                 prompt, bos=True, eos=False, allowed_special="all"
             )
         else:
             raise RuntimeError("Unknown tokenizer")
     else:
-        token_list = prompt.flatten().tolist()
-    pos = len(token_list) if len(token_list) < ar_len else ar_len
+        prompt_token_list = prompt.flatten().tolist()
     dtype = torch.int64 if use_i64_token else torch.int32
 
     with torch.no_grad():
-        while token_list[-1] != tokenizer.eos_id and pos < max_seq_len:
-            tmp_token_list = torch.tensor(
-                token_list[pos - ar_len : pos], dtype=dtype
-            ).reshape(1, -1)
-            tmp_pos = all_pos[:, pos - ar_len : pos]
-            tmp_atten_mask = atten_mask
-            if pos < ar_len:
-                tmp_token_list = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=dtype),
-                        torch.tensor(token_list, dtype=dtype).reshape(1, -1),
-                    ],
-                    dim=1,
-                )
-                tmp_pos = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=torch.int32),
-                        all_pos[:, :pos],
-                    ],
-                    dim=1,
-                )
-                tmp_atten_mask = torch.cat(
-                    [
-                        torch.ones(1, ar_len, max_seq_len - pos) * -255.0,
-                        atten_mask[:, :, -pos:],
-                    ],
-                    dim=-1,
-                )
+        # Phase 1: Prefill the prompt in ar_len chunks.
+        num_prompt_tokens = len(prompt_token_list)
+        prompt_processed_so_far = (
+            0  # Tracks how many prompt tokens have been processed.
+        )
+        while prompt_processed_so_far < num_prompt_tokens:
+            chunk_start_idx = prompt_processed_so_far
+            # Take a chunk of prompt tokens, up to ar_len length.
+            chunk_end_idx = min(num_prompt_tokens, prompt_processed_so_far + ar_len)
+            actual_chunk_tokens = prompt_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
+
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[
+                0,
+                prompt_processed_so_far : prompt_processed_so_far + num_tokens_in_chunk,
+            ]
 
+            # Run inference.
             logits, new_k_caches, new_v_caches = module(
                 tmp_token_list,
-                tmp_atten_mask,
+                atten_mask,
                 tmp_pos,
                 *k_caches,
                 *v_caches,
             )
             if collect_logits:
-                result_logits.append(logits)
-            atten_mask, pos, k_caches, v_caches = kv_updater(
-                ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+                result_logits.append(logits[:, :num_tokens_in_chunk])
+
+            # Update prompt_processed_so_far, the KV cache, and the attention mask.
+            atten_mask, prompt_processed_so_far, k_caches, v_caches = kv_updater(
+                ar_len,
+                num_tokens_in_chunk,
+                atten_mask,
+                prompt_processed_so_far,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
+            )
+        # Append the token predicted at the last prompt position to generated_token_list.
+        generated_token_list.append(
+            torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+        )
+
+        # Phase 2: Generate tokens until the EOS token is generated or max_seq_len is reached.
+        # When run on wikitext for ppl evaluation, this while-loop is not expected to run.
+        token_generated_so_far = 0
+        # After prompt processing, one new token has been generated.
+        pos = prompt_processed_so_far + 1
+        while generated_token_list[-1] != tokenizer.eos_id and pos < max_seq_len:
+            num_generated_tokens = len(generated_token_list)
+            chunk_start_idx = token_generated_so_far
+            # Take a chunk of generated tokens, up to ar_len length.
+            chunk_end_idx = min(num_generated_tokens, token_generated_so_far + ar_len)
+            actual_chunk_tokens = generated_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
+
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[
+                0, pos : pos + num_tokens_in_chunk
+            ]
+
+            logits, new_k_caches, new_v_caches = module(
+                tmp_token_list,
+                atten_mask,
+                tmp_pos,
+                *k_caches,
+                *v_caches,
             )
-            if pos > len(token_list):
-                token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
+            if collect_logits:
 
-    logging.info(f"kv inference result:\n{tokenizer.decode(token_list)}")
+                result_logits.append(logits[:, :num_tokens_in_chunk])
+            atten_mask, pos, k_caches, v_caches = kv_updater(
+                ar_len,
+                1,
+                atten_mask,
+                pos,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
+            )
+            generated_token_list.append(
+                torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+            )
+            token_generated_so_far += 1
+    logging.info(
+        f"kv inference result:\n{tokenizer.decode(prompt_token_list + generated_token_list)}"
+    )
     if collect_logits:
         result_logits = torch.cat(result_logits, dim=1)
     return result_logits
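
As a companion illustration (editor's sketch, not the committed code): the new n_updates path lets smart_mask_updater absorb a whole chunk of freshly computed K/V entries per call during prefill. The toy below uses made-up cache shapes and random tensors standing in for model outputs; the real cache and mask shapes come from the exported model in this example and are not reproduced here.

import torch

# Toy dimensions, chosen only for illustration.
n_heads, head_dim, max_cache_len, ar_len = 2, 4, 64, 8

k_caches = [torch.zeros(n_heads, head_dim, max_cache_len)]  # last dim = cache length
v_caches = [torch.zeros(n_heads, max_cache_len, head_dim)]  # middle dim = cache length
atten_mask = torch.full((1, ar_len, max_cache_len), -255.0)

def smart_mask_updater(_, n_updates, atten_mask, pos, k_caches, v_caches, new_k, new_v):
    # Same structure as the committed version: copy n_updates new cache
    # entries and unmask them in one call, instead of one entry per call.
    max_cache_len = k_caches[0].size(-1)
    if pos + n_updates < max_cache_len:
        for i, k_cache in enumerate(k_caches):
            k_cache[:, :, pos : pos + n_updates] = new_k[i][:, :, :n_updates]
        for i, v_cache in enumerate(v_caches):
            v_cache[:, pos : pos + n_updates, :] = new_v[i][:, :n_updates, :]
        atten_mask[:, :, pos : pos + n_updates] = 0
        pos += n_updates
    return atten_mask, pos, k_caches, v_caches

# Prefill a hypothetical 20-token prompt in ar_len-sized chunks.
pos, num_prompt_tokens, num_calls = 0, 20, 0
while pos < num_prompt_tokens:
    n = min(ar_len, num_prompt_tokens - pos)
    new_k = [torch.randn(n_heads, head_dim, ar_len)]  # stand-ins for model outputs
    new_v = [torch.randn(n_heads, ar_len, head_dim)]
    atten_mask, pos, k_caches, v_caches = smart_mask_updater(
        None, n, atten_mask, pos, k_caches, v_caches, new_k, new_v
    )
    num_calls += 1

print(pos, num_calls)  # 20 tokens absorbed in ceil(20 / 8) = 3 updater calls

The shift-pointer variant in the diff follows the same pattern, rolling the caches and the mask window by n_updates instead of writing in place.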
