
Commit 829bcbc

Qualcomm AI Engine Direct - Refactor calibration flow
Summary: Update the calibration flow to speed up wikitext calibration.
1 parent 6bc312a commit 829bcbc
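
The speedup comes from prefilling the calibration prompt in ar_len-sized chunks instead of advancing one token per forward pass. As a rough illustration of the call-count difference (the prompt length and ar_len below are made-up numbers, not taken from the commit):

```python
import math

# Hypothetical numbers for a wikitext calibration prompt; illustration only.
num_prompt_tokens = 2048
ar_len = 128

# Old flow: roughly one module() forward pass per prompt token.
calls_before = num_prompt_tokens

# New flow: the prompt is prefilled in ar_len-sized chunks.
calls_after = math.ceil(num_prompt_tokens / ar_len)

print(calls_before, calls_after)  # 2048 vs 16 forward passes
```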

File tree

1 file changed: +109 -53 lines changed


examples/qualcomm/oss_scripts/llama/decoder_utils.py

Lines changed: 109 additions & 53 deletions
@@ -216,37 +216,42 @@ def post_process():
 
 
 def smart_mask_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    _, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    # ar_len is unused in smart mask
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates <= max_cache_len:
         for i, k_cache in enumerate(k_caches):
-            k_cache[:, :, pos - ar_len] = new_k_caches[i][:, :, 0]
+            k_cache[:, :, pos : pos + n_updates] = new_k_caches[i][:, :, :n_updates]
 
         for i, v_cache in enumerate(v_caches):
-            v_cache[:, pos - ar_len, :] = new_v_caches[i][:, 0, :]
-        atten_mask[:, :, pos - ar_len] = 0
+            v_cache[:, pos : pos + n_updates, :] = new_v_caches[i][:, :n_updates, :]
+        atten_mask[:, :, pos : pos + n_updates] = 0
+        pos += n_updates
 
-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)
 
 
 def shift_pointer_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    ar_len, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates <= max_cache_len:
         k_caches = [
-            torch.cat([k_cache[:, :, 1:], new_k_caches[i][:, :, :1]], dim=-1)
+            torch.cat(
+                [k_cache[:, :, n_updates:], new_k_caches[i][:, :, :n_updates]], dim=-1
+            )
             for i, k_cache in enumerate(k_caches)
         ]
         v_caches = [
-            torch.cat([v_cache[:, 1:, :], new_v_caches[i][:, :1, :]], dim=1)
+            torch.cat(
+                [v_cache[:, n_updates:, :], new_v_caches[i][:, :n_updates, :]], dim=1
+            )
             for i, v_cache in enumerate(v_caches)
         ]
-        atten_mask[:, :, -pos - 1] = 0
+        atten_mask[:, :, -pos - n_updates - ar_len : -pos - ar_len] = 0
+        pos += n_updates
 
-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)
 
 
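
For reference, a minimal standalone sketch of the batched smart-mask update shown in the hunk above, run on toy tensors. Every shape, plus the `max_cache_len`, `ar_len`, `pos`, and `n_updates` values, is invented for illustration; the real caches and mask come from the exported module.

```python
import torch

# Invented toy shapes: 1 (batch*heads), head_dim 8, cache length 32, window ar_len 16.
max_cache_len, ar_len = 32, 16
pos, n_updates = 0, 5

k_cache = torch.zeros(1, 8, max_cache_len)    # keys:   [..., head_dim, cache_len]
v_cache = torch.zeros(1, max_cache_len, 8)    # values: [..., cache_len, head_dim]
atten_mask = torch.full((1, ar_len, max_cache_len + ar_len), -255.0)
new_k = torch.randn(1, 8, ar_len)
new_v = torch.randn(1, ar_len, 8)

# Smart mask: write n_updates new entries in place and unmask those positions,
# mirroring the refactored smart_mask_updater above.
if pos + n_updates <= max_cache_len:
    k_cache[:, :, pos : pos + n_updates] = new_k[:, :, :n_updates]
    v_cache[:, pos : pos + n_updates, :] = new_v[:, :n_updates, :]
    atten_mask[:, :, pos : pos + n_updates] = 0
    pos += n_updates
```

A single in-place write now covers a whole chunk of positions, whereas the previous updater advanced one position per call.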
@@ -266,69 +271,120 @@ def kv_inference(
     # TODO: change criteria & support batch inputs if necessary
     all_pos = torch.arange(0, max_seq_len, 1, dtype=torch.int32).unsqueeze(0)
 
-    token_list, result_logits = [], []
+    prompt_token_list, total_token_list, result_logits = [], [], []
 
     if isinstance(prompt, str):
         # Llama2 tokenizer has no special tokens
         if isinstance(tokenizer, (SentencePieceTokenizer, HuggingFaceTokenizer)):
-            token_list = tokenizer.encode(prompt, bos=True, eos=False)
+            prompt_token_list = tokenizer.encode(prompt, bos=True, eos=False)
         elif isinstance(tokenizer, TiktokenTokenizer):
-            token_list = tokenizer.encode(
+            prompt_token_list = tokenizer.encode(
                 prompt, bos=True, eos=False, allowed_special="all"
             )
         else:
             raise RuntimeError("Unknown tokenizer")
     else:
-        token_list = prompt.flatten().tolist()
-    pos = len(token_list) if len(token_list) < ar_len else ar_len
+        prompt_token_list = prompt.flatten().tolist()
+    total_token_list = prompt_token_list
     dtype = torch.int64 if use_i64_token else torch.int32
 
     with torch.no_grad():
-        while token_list[-1] != tokenizer.eos_id and pos < max_seq_len:
-            tmp_token_list = torch.tensor(
-                token_list[pos - ar_len : pos], dtype=dtype
-            ).reshape(1, -1)
-            tmp_pos = all_pos[:, pos - ar_len : pos]
-            tmp_atten_mask = atten_mask
-            if pos < ar_len:
-                tmp_token_list = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=dtype),
-                        torch.tensor(token_list, dtype=dtype).reshape(1, -1),
-                    ],
-                    dim=1,
-                )
-                tmp_pos = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=torch.int32),
-                        all_pos[:, :pos],
-                    ],
-                    dim=1,
-                )
-                tmp_atten_mask = torch.cat(
-                    [
-                        torch.ones(1, ar_len, max_seq_len - pos) * -255.0,
-                        atten_mask[:, :, -pos:],
-                    ],
-                    dim=-1,
-                )
+        # Phase 1: Prefill the prompt in ar_len chunks.
+        num_prompt_tokens = len(prompt_token_list)
+        pos = 0  # Tracks how many prompt tokens have been processed.
+        while pos < num_prompt_tokens:
+            chunk_start_idx = pos
+            # Take a chunk of prompt tokens, up to ar_len length.
+            chunk_end_idx = min(num_prompt_tokens, pos + ar_len)
+            actual_chunk_tokens = prompt_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
 
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[
+                0,
+                pos : pos + num_tokens_in_chunk,
+            ]
+
+            # Run inference.
             logits, new_k_caches, new_v_caches = module(
                 tmp_token_list,
-                tmp_atten_mask,
+                atten_mask,
                 tmp_pos,
                 *k_caches,
                 *v_caches,
             )
             if collect_logits:
-                result_logits.append(logits)
+                result_logits.append(logits[:, :num_tokens_in_chunk])
+
+            # Update the pos, KV cache and attention mask.
             atten_mask, pos, k_caches, v_caches = kv_updater(
-                ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+                ar_len,
+                num_tokens_in_chunk,
+                atten_mask,
+                pos,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
+            )
+            # Append the last run logits to the total_token_list.
+            total_token_list.append(
+                torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+            )
+
+        # Phase 2: Generate tokens until the EOS token is generated or max_seq_len is reached.
+        # When run on wikitext for ppl evaluation, this while-loop is not expected to run.
+        max_cache_len = max_seq_len - ar_len
+        num_tokens = len(total_token_list)
+        while total_token_list[-1] != tokenizer.eos_id and num_tokens < max_seq_len:
+            chunk_start_idx = min(pos, max_cache_len)
+            # Take a chunk of generated tokens, up to ar_len length.
+            chunk_end_idx = num_tokens
+            actual_chunk_tokens = total_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
+
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[0, chunk_start_idx:chunk_end_idx]
+
+            logits, new_k_caches, new_v_caches = module(
+                tmp_token_list,
+                atten_mask,
+                tmp_pos,
+                *k_caches,
+                *v_caches,
             )
-            if pos > len(token_list):
-                token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
+            if collect_logits:
+                result_logits.append(logits[:, :num_tokens_in_chunk])
 
-        logging.info(f"kv inference result:\n{tokenizer.decode(token_list)}")
+            atten_mask, pos, k_caches, v_caches = kv_updater(
+                ar_len,
+                1,
+                atten_mask,
+                pos,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
+            )
+            total_token_list.append(
+                torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+            )
+            num_tokens = len(total_token_list)
+        logging.info(f"kv inference result:\n{tokenizer.decode(total_token_list)}")
     if collect_logits:
         result_logits = torch.cat(result_logits, dim=1)
     return result_logits
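
And a matching toy sketch of the shift-pointer variant, which rolls the caches by n_updates entries per call instead of one. As before, every shape and value below is invented for illustration and follows the same layout assumptions as the smart-mask sketch earlier.

```python
import torch

# Invented toy shapes, same layout assumptions as the smart-mask sketch above.
max_cache_len, ar_len = 32, 16
pos, n_updates = 0, 5

k_caches = [torch.zeros(1, 8, max_cache_len)]
v_caches = [torch.zeros(1, max_cache_len, 8)]
atten_mask = torch.full((1, ar_len, max_cache_len + ar_len), -255.0)
new_k_caches = [torch.randn(1, 8, ar_len)]
new_v_caches = [torch.randn(1, ar_len, 8)]

# Shift pointer: drop the oldest n_updates entries, append the n_updates new ones,
# and unmask the freshly filled region, as in the refactored shift_pointer_updater.
if pos + n_updates <= max_cache_len:
    k_caches = [
        torch.cat([k[:, :, n_updates:], new_k_caches[i][:, :, :n_updates]], dim=-1)
        for i, k in enumerate(k_caches)
    ]
    v_caches = [
        torch.cat([v[:, n_updates:, :], new_v_caches[i][:, :n_updates, :]], dim=1)
        for i, v in enumerate(v_caches)
    ]
    atten_mask[:, :, -pos - n_updates - ar_len : -pos - ar_len] = 0
    pos += n_updates
```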
