examples/qualcomm/oss_scripts/llama/decoder_utils.py (162 changes: 109 additions & 53 deletions)
@@ -219,37 +219,42 @@ def post_process():


 def smart_mask_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    _, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    # ar_len is unused in smart mask
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates <= max_cache_len:
         for i, k_cache in enumerate(k_caches):
-            k_cache[:, :, pos - ar_len] = new_k_caches[i][:, :, 0]
+            k_cache[:, :, pos : pos + n_updates] = new_k_caches[i][:, :, :n_updates]

         for i, v_cache in enumerate(v_caches):
-            v_cache[:, pos - ar_len, :] = new_v_caches[i][:, 0, :]
-        atten_mask[:, :, pos - ar_len] = 0
+            v_cache[:, pos : pos + n_updates, :] = new_v_caches[i][:, :n_updates, :]
+        atten_mask[:, :, pos : pos + n_updates] = 0
+        pos += n_updates

-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)
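
Review note: the smart-mask variant writes the cache in place, so a batched update of n_updates tokens fills absolute slots pos through pos + n_updates - 1 and unmasks exactly that window. A minimal sketch exercising the function above; the tensor shapes here are illustrative assumptions, not values taken from this file:

import torch

# Toy shapes (assumed): batch 1, head_dim 4, cache length 8, ar_len 4.
max_cache_len, ar_len, n_updates = 8, 4, 3
k_caches = [torch.zeros(1, 4, max_cache_len)]  # [batch, head_dim, cache_len]
v_caches = [torch.zeros(1, max_cache_len, 4)]  # [batch, cache_len, head_dim]
atten_mask = torch.full((1, ar_len, max_cache_len + ar_len), -255.0)
new_k_caches = [torch.ones(1, 4, ar_len)]
new_v_caches = [torch.ones(1, ar_len, 4)]

pos = 0
atten_mask, pos, k_caches, v_caches = smart_mask_updater(
    ar_len, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
)
assert pos == n_updates
assert torch.all(k_caches[0][:, :, :n_updates] == 1)  # slots [0, n_updates) filled
assert torch.all(atten_mask[:, :, :n_updates] == 0)   # and unmasked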


 def shift_pointer_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    ar_len, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates <= max_cache_len:
         k_caches = [
-            torch.cat([k_cache[:, :, 1:], new_k_caches[i][:, :, :1]], dim=-1)
+            torch.cat(
+                [k_cache[:, :, n_updates:], new_k_caches[i][:, :, :n_updates]], dim=-1
+            )
             for i, k_cache in enumerate(k_caches)
         ]
         v_caches = [
-            torch.cat([v_cache[:, 1:, :], new_v_caches[i][:, :1, :]], dim=1)
+            torch.cat(
+                [v_cache[:, n_updates:, :], new_v_caches[i][:, :n_updates, :]], dim=1
+            )
             for i, v_cache in enumerate(v_caches)
         ]
-        atten_mask[:, :, -pos - 1] = 0
+        atten_mask[:, :, -pos - n_updates - ar_len : -pos - ar_len] = 0
+        pos += n_updates

-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)
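
Review note: the shift-pointer variant slides the window instead of writing in place: the oldest n_updates entries are dropped, the new ones are concatenated at the end, and the unmask window is addressed from the right edge of the mask, offset by ar_len. The same kind of toy check, under the same assumed shapes:

import torch

max_cache_len, ar_len, n_updates = 8, 4, 2
k_caches = [torch.zeros(1, 4, max_cache_len)]
v_caches = [torch.zeros(1, max_cache_len, 4)]
atten_mask = torch.full((1, ar_len, max_cache_len + ar_len), -255.0)
new_k_caches = [torch.ones(1, 4, ar_len)]
new_v_caches = [torch.ones(1, ar_len, 4)]

pos = 0
atten_mask, pos, k_caches, v_caches = shift_pointer_updater(
    ar_len, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
)
assert pos == n_updates
assert torch.all(k_caches[0][:, :, -n_updates:] == 1)  # newest entries at the right edge
assert torch.all(atten_mask[:, :, -n_updates - ar_len : -ar_len] == 0)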


@@ -269,70 +274,121 @@ def kv_inference(
     # TODO: change criteria & support batch inputs if necessary
     all_pos = torch.arange(0, max_seq_len, 1, dtype=torch.int32).unsqueeze(0)

-    token_list, result_logits = [], []
+    prompt_token_list, total_token_list, result_logits = [], [], []

     if isinstance(prompt, str):
         # Llama2 tokenizer has no special tokens
         if isinstance(tokenizer, (SentencePieceTokenizer, HuggingFaceTokenizer)):
-            token_list = tokenizer.encode(prompt, bos=True, eos=False)
+            prompt_token_list = tokenizer.encode(prompt, bos=True, eos=False)
         elif isinstance(tokenizer, TiktokenTokenizer):
-            token_list = tokenizer.encode(
+            prompt_token_list = tokenizer.encode(
                 prompt, bos=True, eos=False, allowed_special="all"
             )
         else:
             raise RuntimeError("Unknown tokenizer")
     else:
         # pyre-ignore
-        token_list = prompt.flatten().tolist()
-    pos = len(token_list) if len(token_list) < ar_len else ar_len
+        prompt_token_list = prompt.flatten().tolist()
+    total_token_list = prompt_token_list
     dtype = torch.int64 if use_i64_token else torch.int32

     with torch.no_grad():
-        while token_list[-1] != tokenizer.eos_id and pos < max_seq_len:
-            tmp_token_list = torch.tensor(
-                token_list[pos - ar_len : pos], dtype=dtype
-            ).reshape(1, -1)
-            tmp_pos = all_pos[:, pos - ar_len : pos]
-            tmp_atten_mask = atten_mask
-            if pos < ar_len:
-                tmp_token_list = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=dtype),
-                        torch.tensor(token_list, dtype=dtype).reshape(1, -1),
-                    ],
-                    dim=1,
-                )
-                tmp_pos = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=torch.int32),
-                        all_pos[:, :pos],
-                    ],
-                    dim=1,
-                )
-                tmp_atten_mask = torch.cat(
-                    [
-                        torch.ones(1, ar_len, max_seq_len - pos) * -255.0,
-                        atten_mask[:, :, -pos:],
-                    ],
-                    dim=-1,
-                )
+        # Phase 1: Prefill the prompt in ar_len chunks.
+        num_prompt_tokens = len(prompt_token_list)
+        pos = 0  # Tracks how many prompt tokens have been processed.
+        while pos < num_prompt_tokens:
+            chunk_start_idx = pos
+            # Take a chunk of prompt tokens, up to ar_len length.
+            chunk_end_idx = min(num_prompt_tokens, pos + ar_len)
+            actual_chunk_tokens = prompt_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
+
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[
+                0,
+                pos : pos + num_tokens_in_chunk,
+            ]
+
+            # Run inference.
             logits, new_k_caches, new_v_caches = module(
                 tmp_token_list,
-                tmp_atten_mask,
+                atten_mask,
                 tmp_pos,
                 *k_caches,
                 *v_caches,
             )
             if collect_logits:
-                result_logits.append(logits)
+                result_logits.append(logits[:, :num_tokens_in_chunk])

+            # Update the pos, KV cache and attention mask.
             atten_mask, pos, k_caches, v_caches = kv_updater(
-                ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+                ar_len,
+                num_tokens_in_chunk,
+                atten_mask,
+                pos,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
             )
+        # Append the last run logits to the total_token_list.
+        total_token_list.append(
+            torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+        )
+
+        # Phase 2: Generate tokens until the EOS token is generated or max_seq_len is reached.
+        # When run on wikitext for ppl evaluation, this while-loop is not expected to run.
+        max_cache_len = max_seq_len - ar_len
+        num_tokens = len(total_token_list)
+        while total_token_list[-1] != tokenizer.eos_id and num_tokens < max_seq_len:
+            chunk_start_idx = min(pos, max_cache_len)
+            # Take a chunk of generated tokens, up to ar_len length.
+            chunk_end_idx = num_tokens
+            actual_chunk_tokens = total_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
+
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[0, chunk_start_idx:chunk_end_idx]
+
+            logits, new_k_caches, new_v_caches = module(
+                tmp_token_list,
+                atten_mask,
+                tmp_pos,
+                *k_caches,
+                *v_caches,
+            )
-            if pos > len(token_list):
-                token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
+            if collect_logits:
+                result_logits.append(logits[:, :num_tokens_in_chunk])

-    logging.info(f"kv inference result:\n{tokenizer.decode(token_list)}")
+            atten_mask, pos, k_caches, v_caches = kv_updater(
+                ar_len,
+                1,
+                atten_mask,
+                pos,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
+            )
+            total_token_list.append(
+                torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+            )
+            num_tokens = len(total_token_list)
+    logging.info(f"kv inference result:\n{tokenizer.decode(total_token_list)}")
     if collect_logits:
         result_logits = torch.cat(result_logits, dim=1)
     return result_logits
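
Review note: stepping back from the diff, phase 1 now consumes the prompt in ar_len-sized chunks (kv_updater advances pos by num_tokens_in_chunk) and phase 2 decodes one token per step (n_updates=1), while every input is still zero-padded to the fixed ar_len shape the compiled module expects. A standalone sketch of just the phase-1 chunking arithmetic, with no model involved; all names are local to the example:

# Illustration of the prefill chunking (no model, no tensors).
ar_len = 4
prompt_token_list = list(range(10))  # 10 prompt tokens

pos = 0
chunks = []
while pos < len(prompt_token_list):
    chunk = prompt_token_list[pos : min(len(prompt_token_list), pos + ar_len)]
    # Each chunk is zero-padded to ar_len before being fed to the fixed-shape module.
    padded = chunk + [0] * (ar_len - len(chunk))
    chunks.append(padded)
    pos += len(chunk)  # kv_updater advances pos by num_tokens_in_chunk

assert chunks == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 0, 0]]
assert pos == 10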