Commit 18030f9

Qualcomm AI Engine Direct - Refactor calibration flow (#13150)
Summary: Update the calibration flow to speed up wikitext calibration. cc: @haowhsu-quic, @winskuo-quic
1 parent 8e85857 commit 18030f9

File tree: 1 file changed (+109, −53)

examples/qualcomm/oss_scripts/llama/decoder_utils.py

Lines changed: 109 additions & 53 deletions
@@ -219,37 +219,42 @@ def post_process():
 
 
 def smart_mask_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    _, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    # ar_len is unused in smart mask
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates <= max_cache_len:
         for i, k_cache in enumerate(k_caches):
-            k_cache[:, :, pos - ar_len] = new_k_caches[i][:, :, 0]
+            k_cache[:, :, pos : pos + n_updates] = new_k_caches[i][:, :, :n_updates]
 
         for i, v_cache in enumerate(v_caches):
-            v_cache[:, pos - ar_len, :] = new_v_caches[i][:, 0, :]
-        atten_mask[:, :, pos - ar_len] = 0
+            v_cache[:, pos : pos + n_updates, :] = new_v_caches[i][:, :n_updates, :]
+        atten_mask[:, :, pos : pos + n_updates] = 0
+        pos += n_updates
 
-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)
 
 
 def shift_pointer_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    ar_len, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates <= max_cache_len:
         k_caches = [
-            torch.cat([k_cache[:, :, 1:], new_k_caches[i][:, :, :1]], dim=-1)
+            torch.cat(
+                [k_cache[:, :, n_updates:], new_k_caches[i][:, :, :n_updates]], dim=-1
+            )
             for i, k_cache in enumerate(k_caches)
         ]
         v_caches = [
-            torch.cat([v_cache[:, 1:, :], new_v_caches[i][:, :1, :]], dim=1)
+            torch.cat(
+                [v_cache[:, n_updates:, :], new_v_caches[i][:, :n_updates, :]], dim=1
+            )
            for i, v_cache in enumerate(v_caches)
         ]
-        atten_mask[:, :, -pos - 1] = 0
+        atten_mask[:, :, -pos - n_updates - ar_len : -pos - ar_len] = 0
+        pos += n_updates
 
-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)
 
 
@@ -269,70 +274,121 @@ def kv_inference(
     # TODO: change criteria & support batch inputs if necessary
     all_pos = torch.arange(0, max_seq_len, 1, dtype=torch.int32).unsqueeze(0)
 
-    token_list, result_logits = [], []
+    prompt_token_list, total_token_list, result_logits = [], [], []
 
     if isinstance(prompt, str):
         # Llama2 tokenizer has no special tokens
         if isinstance(tokenizer, (SentencePieceTokenizer, HuggingFaceTokenizer)):
-            token_list = tokenizer.encode(prompt, bos=True, eos=False)
+            prompt_token_list = tokenizer.encode(prompt, bos=True, eos=False)
         elif isinstance(tokenizer, TiktokenTokenizer):
-            token_list = tokenizer.encode(
+            prompt_token_list = tokenizer.encode(
                 prompt, bos=True, eos=False, allowed_special="all"
             )
         else:
             raise RuntimeError("Unknown tokenizer")
     else:
         # pyre-ignore
-        token_list = prompt.flatten().tolist()
-    pos = len(token_list) if len(token_list) < ar_len else ar_len
+        prompt_token_list = prompt.flatten().tolist()
+    total_token_list = prompt_token_list
     dtype = torch.int64 if use_i64_token else torch.int32
 
     with torch.no_grad():
-        while token_list[-1] != tokenizer.eos_id and pos < max_seq_len:
-            tmp_token_list = torch.tensor(
-                token_list[pos - ar_len : pos], dtype=dtype
-            ).reshape(1, -1)
-            tmp_pos = all_pos[:, pos - ar_len : pos]
-            tmp_atten_mask = atten_mask
-            if pos < ar_len:
-                tmp_token_list = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=dtype),
-                        torch.tensor(token_list, dtype=dtype).reshape(1, -1),
-                    ],
-                    dim=1,
-                )
-                tmp_pos = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=torch.int32),
-                        all_pos[:, :pos],
-                    ],
-                    dim=1,
-                )
-                tmp_atten_mask = torch.cat(
-                    [
-                        torch.ones(1, ar_len, max_seq_len - pos) * -255.0,
-                        atten_mask[:, :, -pos:],
-                    ],
-                    dim=-1,
-                )
+        # Phase 1: Prefill the prompt in ar_len chunks.
+        num_prompt_tokens = len(prompt_token_list)
+        pos = 0  # Tracks how many prompt tokens have been processed.
+        while pos < num_prompt_tokens:
+            chunk_start_idx = pos
+            # Take a chunk of prompt tokens, up to ar_len length.
+            chunk_end_idx = min(num_prompt_tokens, pos + ar_len)
+            actual_chunk_tokens = prompt_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
 
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[
+                0,
+                pos : pos + num_tokens_in_chunk,
+            ]
+
+            # Run inference.
             logits, new_k_caches, new_v_caches = module(
                 tmp_token_list,
-                tmp_atten_mask,
+                atten_mask,
                 tmp_pos,
                 *k_caches,
                 *v_caches,
             )
             if collect_logits:
-                result_logits.append(logits)
+                result_logits.append(logits[:, :num_tokens_in_chunk])
+
+            # Update the pos, KV cache and attention mask.
             atten_mask, pos, k_caches, v_caches = kv_updater(
-                ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+                ar_len,
+                num_tokens_in_chunk,
+                atten_mask,
+                pos,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
+            )
+        # Append the last run logits to the total_token_list.
+        total_token_list.append(
+            torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+        )
+
+        # Phase 2: Generate tokens until the EOS token is generated or max_seq_len is reached.
+        # When run on wikitext for ppl evaluation, this while-loop is not expected to run.
+        max_cache_len = max_seq_len - ar_len
+        num_tokens = len(total_token_list)
+        while total_token_list[-1] != tokenizer.eos_id and num_tokens < max_seq_len:
+            chunk_start_idx = min(pos, max_cache_len)
+            # Take a chunk of generated tokens, up to ar_len length.
+            chunk_end_idx = num_tokens
+            actual_chunk_tokens = total_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
+
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[0, chunk_start_idx:chunk_end_idx]
+
+            logits, new_k_caches, new_v_caches = module(
+                tmp_token_list,
+                atten_mask,
+                tmp_pos,
+                *k_caches,
+                *v_caches,
             )
-            if pos > len(token_list):
-                token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
+            if collect_logits:
+                result_logits.append(logits[:, :num_tokens_in_chunk])
 
-        logging.info(f"kv inference result:\n{tokenizer.decode(token_list)}")
+            atten_mask, pos, k_caches, v_caches = kv_updater(
+                ar_len,
+                1,
+                atten_mask,
+                pos,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
+            )
+            total_token_list.append(
+                torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+            )
+            num_tokens = len(total_token_list)
+        logging.info(f"kv inference result:\n{tokenizer.decode(total_token_list)}")
     if collect_logits:
         result_logits = torch.cat(result_logits, dim=1)
     return result_logits
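The refactored kv_inference is the main change behind the wikitext calibration speedup mentioned in the summary: Phase 1 prefills the prompt in ar_len-sized chunks, one forward pass per chunk with inputs zero-padded to the fixed ar_len shape, and Phase 2 only runs if tokens still need to be generated afterwards. The sketch below isolates the Phase 1 chunking arithmetic; the helper name and the toy prompt are assumptions for illustration, not code from this commit.

import torch


def chunk_prompt(prompt_token_list, ar_len, dtype=torch.int32):
    # Yield (tmp_token_list, tmp_pos, num_tokens_in_chunk) for each prefill chunk.
    num_prompt_tokens = len(prompt_token_list)
    all_pos = torch.arange(0, num_prompt_tokens, dtype=torch.int32).unsqueeze(0)
    pos = 0
    while pos < num_prompt_tokens:
        chunk_end_idx = min(num_prompt_tokens, pos + ar_len)
        actual_chunk_tokens = prompt_token_list[pos:chunk_end_idx]
        num_tokens_in_chunk = len(actual_chunk_tokens)

        # Fixed-size inputs, zero-padded on the right, because each forward
        # pass always consumes exactly ar_len token slots.
        tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
        tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
            actual_chunk_tokens, dtype=dtype
        )
        tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
        tmp_pos[0, :num_tokens_in_chunk] = all_pos[0, pos:chunk_end_idx]

        yield tmp_token_list, tmp_pos, num_tokens_in_chunk
        pos += num_tokens_in_chunk


# A 10-token prompt with ar_len=4 is consumed in 3 forward passes
# (chunks of 4, 4, and 2 tokens) rather than one pass per position.
for tokens, positions, n in chunk_prompt(list(range(100, 110)), ar_len=4):
    print(n, tokens.tolist(), positions.tolist())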
