diff --git a/examples/qualcomm/oss_scripts/llama/decoder_utils.py b/examples/qualcomm/oss_scripts/llama/decoder_utils.py
index 87a1e313dd7..8bfc0d135c0 100644
--- a/examples/qualcomm/oss_scripts/llama/decoder_utils.py
+++ b/examples/qualcomm/oss_scripts/llama/decoder_utils.py
@@ -219,37 +219,42 @@ def post_process():
 
 
 def smart_mask_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    _, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    # ar_len is unused in smart mask
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates <= max_cache_len:
         for i, k_cache in enumerate(k_caches):
-            k_cache[:, :, pos - ar_len] = new_k_caches[i][:, :, 0]
+            k_cache[:, :, pos : pos + n_updates] = new_k_caches[i][:, :, :n_updates]
         for i, v_cache in enumerate(v_caches):
-            v_cache[:, pos - ar_len, :] = new_v_caches[i][:, 0, :]
-        atten_mask[:, :, pos - ar_len] = 0
+            v_cache[:, pos : pos + n_updates, :] = new_v_caches[i][:, :n_updates, :]
+        atten_mask[:, :, pos : pos + n_updates] = 0
+        pos += n_updates
 
-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)
 
 
 def shift_pointer_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    ar_len, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates <= max_cache_len:
         k_caches = [
-            torch.cat([k_cache[:, :, 1:], new_k_caches[i][:, :, :1]], dim=-1)
+            torch.cat(
+                [k_cache[:, :, n_updates:], new_k_caches[i][:, :, :n_updates]], dim=-1
+            )
             for i, k_cache in enumerate(k_caches)
         ]
         v_caches = [
-            torch.cat([v_cache[:, 1:, :], new_v_caches[i][:, :1, :]], dim=1)
+            torch.cat(
+                [v_cache[:, n_updates:, :], new_v_caches[i][:, :n_updates, :]], dim=1
+            )
             for i, v_cache in enumerate(v_caches)
         ]
-        atten_mask[:, :, -pos - 1] = 0
+        atten_mask[:, :, -pos - n_updates - ar_len : -pos - ar_len] = 0
+        pos += n_updates
 
-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)
 
 
 def kv_inference(
@@ -269,70 +274,121 @@ def kv_inference(
     # TODO: change criteria & support batch inputs if necessary
     all_pos = torch.arange(0, max_seq_len, 1, dtype=torch.int32).unsqueeze(0)
-    token_list, result_logits = [], []
+    prompt_token_list, total_token_list, result_logits = [], [], []
     if isinstance(prompt, str):
         # Llama2 tokenizer has no special tokens
         if isinstance(tokenizer, (SentencePieceTokenizer, HuggingFaceTokenizer)):
-            token_list = tokenizer.encode(prompt, bos=True, eos=False)
+            prompt_token_list = tokenizer.encode(prompt, bos=True, eos=False)
         elif isinstance(tokenizer, TiktokenTokenizer):
-            token_list = tokenizer.encode(
+            prompt_token_list = tokenizer.encode(
                 prompt, bos=True, eos=False, allowed_special="all"
             )
         else:
             raise RuntimeError("Unknown tokenizer")
     else:
         # pyre-ignore
-        token_list = prompt.flatten().tolist()
-    pos = len(token_list) if len(token_list) < ar_len else ar_len
+        prompt_token_list = prompt.flatten().tolist()
+    total_token_list = prompt_token_list
     dtype = torch.int64 if use_i64_token else torch.int32
 
     with torch.no_grad():
-        while token_list[-1] != tokenizer.eos_id and pos < max_seq_len:
-            tmp_token_list = torch.tensor(
-                token_list[pos - ar_len : pos], dtype=dtype
-            ).reshape(1, -1)
-            tmp_pos = all_pos[:, pos - ar_len : pos]
-            tmp_atten_mask = atten_mask
-            if pos < ar_len:
-                tmp_token_list = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=dtype),
-                        torch.tensor(token_list, dtype=dtype).reshape(1, -1),
-                    ],
-                    dim=1,
-                )
-                tmp_pos = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=torch.int32),
-                        all_pos[:, :pos],
-                    ],
-                    dim=1,
-                )
-                tmp_atten_mask = torch.cat(
-                    [
-                        torch.ones(1, ar_len, max_seq_len - pos) * -255.0,
-                        atten_mask[:, :, -pos:],
-                    ],
-                    dim=-1,
-                )
+        # Phase 1: Prefill the prompt in ar_len chunks.
+        num_prompt_tokens = len(prompt_token_list)
+        pos = 0  # Tracks how many prompt tokens have been processed.
+        while pos < num_prompt_tokens:
+            chunk_start_idx = pos
+            # Take a chunk of prompt tokens, up to ar_len length.
+            chunk_end_idx = min(num_prompt_tokens, pos + ar_len)
+            actual_chunk_tokens = prompt_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[
+                0,
+                pos : pos + num_tokens_in_chunk,
+            ]
+
+            # Run inference.
             logits, new_k_caches, new_v_caches = module(
                 tmp_token_list,
-                tmp_atten_mask,
+                atten_mask,
                 tmp_pos,
                 *k_caches,
                 *v_caches,
             )
             if collect_logits:
-                result_logits.append(logits)
+                result_logits.append(logits[:, :num_tokens_in_chunk])
+
+            # Update the pos, KV cache and attention mask.
             atten_mask, pos, k_caches, v_caches = kv_updater(
-                ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+                ar_len,
+                num_tokens_in_chunk,
+                atten_mask,
+                pos,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
+            )
+        # Append the token predicted by the last chunk's logits to total_token_list.
+        total_token_list.append(
+            torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+        )
+
+        # Phase 2: Generate tokens until the EOS token is generated or max_seq_len is reached.
+        # When run on wikitext for ppl evaluation, this while-loop is not expected to run.
+        max_cache_len = max_seq_len - ar_len
+        num_tokens = len(total_token_list)
+        while total_token_list[-1] != tokenizer.eos_id and num_tokens < max_seq_len:
+            chunk_start_idx = min(pos, max_cache_len)
+            # Take a chunk of generated tokens, up to ar_len length.
+            chunk_end_idx = num_tokens
+            actual_chunk_tokens = total_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
+
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[0, chunk_start_idx:chunk_end_idx]
+
+            logits, new_k_caches, new_v_caches = module(
+                tmp_token_list,
+                atten_mask,
+                tmp_pos,
+                *k_caches,
+                *v_caches,
             )
-            if pos > len(token_list):
-                token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
+            if collect_logits:
+                result_logits.append(logits[:, :num_tokens_in_chunk])
 
-    logging.info(f"kv inference result:\n{tokenizer.decode(token_list)}")
+            atten_mask, pos, k_caches, v_caches = kv_updater(
+                ar_len,
+                1,
+                atten_mask,
+                pos,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
+            )
+            total_token_list.append(
+                torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+            )
+            num_tokens = len(total_token_list)
+    logging.info(f"kv inference result:\n{tokenizer.decode(total_token_list)}")
     if collect_logits:
         result_logits = torch.cat(result_logits, dim=1)
         return result_logits
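Reviewer note: the sketch below exercises the new multi-token contract of smart_mask_updater on toy tensors, outside the runner. The single-layer cache list, the (batch, hidden, cache) and (batch, cache, hidden) K/V layouts, the toy sizes, and the -255.0 mask fill are illustrative assumptions, not values taken from the script; only the contract itself (write n_updates fresh entries at pos, unmask those cache slots, then advance pos) comes from the patch above.

# Standalone sanity check (not part of the patch); shapes are assumptions.
import torch


def smart_mask_updater(
    _, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
):
    # Same body as the patched function above.
    max_cache_len = k_caches[0].size(-1)
    if pos + n_updates <= max_cache_len:
        for i, k_cache in enumerate(k_caches):
            k_cache[:, :, pos : pos + n_updates] = new_k_caches[i][:, :, :n_updates]
        for i, v_cache in enumerate(v_caches):
            v_cache[:, pos : pos + n_updates, :] = new_v_caches[i][:, :n_updates, :]
        atten_mask[:, :, pos : pos + n_updates] = 0
        pos += n_updates
    return (atten_mask, pos, k_caches, v_caches)


ar_len, hidden, max_cache_len = 4, 8, 8  # assumed toy sizes
k_caches = [torch.zeros(1, hidden, max_cache_len)]  # assumed K layout: (batch, hidden, cache)
v_caches = [torch.zeros(1, max_cache_len, hidden)]  # assumed V layout: (batch, cache, hidden)
# Assumed mask layout: cache slots first, then the current ar_len window;
# -255.0 means "masked", 0 means "attend".
atten_mask = torch.full((1, ar_len, max_cache_len + ar_len), -255.0)
pos = 0

# Pretend the model just returned K/V for a 3-token prompt chunk, i.e.
# n_updates == num_tokens_in_chunk in the phase-1 prefill loop.
new_k = [torch.ones(1, hidden, ar_len)]
new_v = [torch.ones(1, ar_len, hidden)]
atten_mask, pos, k_caches, v_caches = smart_mask_updater(
    ar_len, 3, atten_mask, pos, k_caches, v_caches, new_k, new_v
)

assert pos == 3
assert torch.all(k_caches[0][:, :, :3] == 1)  # three new entries written in place
assert torch.all(v_caches[0][:, 3:, :] == 0)  # rest of the cache is untouched
assert torch.all(atten_mask[:, :, :3] == 0)  # filled slots are now attendable
assert torch.all(atten_mask[:, :, 3:] == -255.0)

shift_pointer_updater honors the same (ar_len, n_updates, ...) signature, but rolls the caches left by n_updates via torch.cat instead of writing in place, which is why its mask update indexes from the right-hand end of atten_mask.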