2 changes: 2 additions & 0 deletions examples/qualcomm/oss_scripts/llama/CMakeLists.txt
@@ -36,6 +36,8 @@ list(
${CMAKE_CURRENT_LIST_DIR}/runner/token_generator.h
${CMAKE_CURRENT_LIST_DIR}/runner/imem_alloc.h
${CMAKE_CURRENT_LIST_DIR}/runner/client_mem.h
${CMAKE_CURRENT_LIST_DIR}/runner/lhd_token_generator.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/lhd_token_generator.h
${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.h
${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.cpp
75 changes: 69 additions & 6 deletions examples/qualcomm/oss_scripts/llama/llama.py
@@ -11,6 +11,7 @@
import getpass
import json
import logging
import math
import os
import subprocess
import sys
@@ -90,6 +91,12 @@
logging.getLogger().setLevel(logging.INFO)


def next_power_of_two(n):
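# Round n up to the nearest power of two; used to size the lookahead AR window.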
if n == 0:
return 1
return 2 ** math.ceil(math.log2(n))


def smart_mask_updater(
ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
):
@@ -531,6 +538,28 @@ def compile(args, pte_filename, tokenizer):
use_i64_token=use_i64_token,
)
)
elif args.model_mode == "lookahead":
llama_instance_list.append(
LlamaModel(
kv_config,
# To get better performance, we round up to the nearest power of 2.
ar_len=next_power_of_two(
(args.window + args.gcap) * (args.ngram - 1)
),
output_new_cache_only=True,
output_cache=True,
use_i64_token=use_i64_token,
)
)
llama_instance_list.append(
LlamaModel(
prefill_config,
ar_len=args.prefill_ar_len,
output_new_cache_only=True,
output_cache=True,
use_i64_token=use_i64_token,
)
)
else:
raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")

@@ -630,8 +659,8 @@ def permute(w, heads):
tokenizer=tokenizer,
custom_annotations=custom_annotations,
)
# If hybrid mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
if i == 0 and args.model_mode == "hybrid":
# If hybrid or lookahead mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
if i == 0 and args.model_mode in ["hybrid", "lookahead"]:
output_indices = 0
for node in llama_instance.llama_graph_module.graph.nodes:
if node.op == "output":
@@ -673,7 +702,7 @@ def permute(w, heads):
shared_buffer=args.shared_buffer,
)
quant_attrs = llama_instance_list[0].get_quant_attrs()
elif args.model_mode == "hybrid":
elif args.model_mode in ["hybrid", "lookahead"]:
sample_inputs_list = [
llama_instance.inputs for llama_instance in llama_instance_list
]
@@ -759,6 +788,8 @@ def inference(args, pte_filename, runtime_tokenizer_path, pre_gen_pte=""):
eval_mode = 0
elif args.model_mode == "hybrid":
eval_mode = 1
elif args.model_mode == "lookahead":
eval_mode = 2
else:
raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")

@@ -832,6 +863,9 @@ def post_process():
"--output_path outputs/outputs.txt",
f"--performance_output_path {performance_output_path}",
f"--kv_updater {'SmartMask' if args.kv_updater == smart_mask_updater else 'ShiftPointer'}",
f"--window {args.window}",
f"--gcap {args.gcap}",
f"--ngram {args.ngram}",
runner_args,
]
)
@@ -971,9 +1005,9 @@ def _build_parser():

parser.add_argument(
"--model_mode",
help="Export and inference kv mode or hybrid mode",
help="Export and inference kv mode, hybrid mode, or lookahead decoding mode",
default="kv",
choices=["kv", "hybrid"],
choices=["kv", "hybrid", "lookahead"],
type=str,
)

@@ -986,7 +1020,7 @@

parser.add_argument(
"--prefill_ar_len",
help="The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use this option to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor for hybrid mode.",
help="The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use this option to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor for hybrid and lookahead mode.",
default=32,
type=int,
)
@@ -1007,6 +1041,27 @@
help="Fallback to cpu embedding operator and type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '4,32'.",
)

parser.add_argument(
"--ngram",
help="Represents the size of the n-grams used in the lookahead process.",
default=5,
type=int,
)

parser.add_argument(
"--window",
help="Determines how many future tokens the algorithm attempts to predict in each step.",
default=8,
type=int,
)

parser.add_argument(
"--gcap",
help="Represents the maximum number of speculations or candidate n-grams that the algorithm considers in each step for verification. It balances the trade-off between computation efficiency and exploring more possibilities.",
default=8,
type=int,
)

parser.add_argument("-v", "--verbose", action="store_true")

return parser
@@ -1023,6 +1078,14 @@ def export_llama(args) -> None:
args.max_seq_len >= args.prefill_ar_len
), "Please ensure max_seq_len is >= prefill_ar_len"
pte_filename = "hybrid_llama_qnn"
elif args.model_mode == "lookahead":
assert (
args.max_seq_len >= args.prefill_ar_len
), "Please ensure max_seq_len is >= prefill_ar_len"
assert args.max_seq_len > next_power_of_two(
(args.window + args.gcap) * (args.ngram - 1)
), "Please ensure max_seq_len is > next_power_of_two((args.window + args.gcap) * (args.ngram - 1))"
pte_filename = "lookahead_llama_qnn"
else:
raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")

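For reference, a minimal sketch of how the lookahead AR length falls out of the new flags, using only this PR's defaults (ngram=5, window=8, gcap=8) and the helper added above:

import math

def next_power_of_two(n):
    if n == 0:
        return 1
    return 2 ** math.ceil(math.log2(n))

# Defaults added in this PR: ngram=5, window=8, gcap=8.
ar_len = next_power_of_two((8 + 8) * (5 - 1))  # (window + gcap) * (ngram - 1) = 64
assert ar_len == 64  # 64 is already a power of two, so no rounding occurs
# export_llama() additionally asserts max_seq_len > ar_len in lookahead mode.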
23 changes: 19 additions & 4 deletions examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
@@ -53,12 +53,24 @@ DEFINE_int32(
DEFINE_int32(
eval_mode,
0,
"0: TokenGenerator(kv) / 1: HybridMode (prefill+kv)");
"0: TokenGenerator(kv) / 1: HybridMode (prefill+kv) / 2: Lookahead Decoding");
DEFINE_string(
kv_updater,
"How to update kv cache. Choose between SmartMask and ShiftPointer",
"SmartMask");
"SmartMask",
"How to update kv cache. Choose between SmartMask and ShiftPointer");
DEFINE_int32(num_iters, 1, "total num of iterations to run.");
DEFINE_int32(
ngram,
0,
"[Lookahead Decoding] Represents the size of the n-grams used in the lookahead process.");
DEFINE_int32(
window,
0,
"[Lookahead Decoding] Determines how many future tokens the algorithm attempts to predict in each step.");
DEFINE_int32(
gcap,
0,
"[Lookahead Decoding] Represents the maximum number of speculations or candidate n-grams that the algorithm considers in each step for verification. It balances the trade-off between computation efficiency and exploring more possibilities.");

std::vector<std::string> CollectPrompts(int argc, char** argv) {
// Collect all prompts from command line, example usage:
@@ -111,7 +123,10 @@ int main(int argc, char** argv) {
FLAGS_performance_output_path.c_str(),
FLAGS_temperature,
FLAGS_eval_mode,
FLAGS_kv_updater);
FLAGS_kv_updater,
FLAGS_ngram,
FLAGS_window,
FLAGS_gcap);
auto llama_version = runner.get_llama_version();
std::vector<char> buf;
buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
85 changes: 66 additions & 19 deletions examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp
@@ -51,7 +51,7 @@ void KVManager::init_attention_mask(
int32_t ar_len,
int32_t n_past) {
ET_CHECK_MSG(
attention_map.size() == ar_len,
attention_map.size() <= ar_len,
"The size of attention_map (%zu) doesn't match with ar_len (%d)",
attention_map.size(),
ar_len);
@@ -197,9 +197,11 @@ void KVManager::init_cache(IMemAlloc* buffer_manager, int32_t ar_len) {
? 0
: metadata_.max_cache_len - (metadata_.context_len - cur_ar_len_);
v_cache_[layer][head].buffer = single_layer_v_cache +
head * single_head_size_in + cache_gap * metadata_.head_dim;
v_cache_[layer][head].output_buffer =
single_layer_v_cache + (head + 1) * single_head_size_in;
head * metadata_.head_dim * metadata_.context_len +
cache_gap * metadata_.head_dim;
v_cache_[layer][head].output_buffer = single_layer_v_cache +
head * metadata_.head_dim * metadata_.context_len +
single_head_size_in;
}
}
break;
@@ -311,21 +313,29 @@ bool KVManager::update_cache_tensor(
return updated;
}

void KVManager::update_cache(int32_t ar_len, int32_t n_past, int32_t n_update) {
void KVManager::update_cache(
int32_t ar_len,
int32_t n_past,
int32_t n_update,
const std::vector<bool>& selected) {
ET_CHECK_MSG(
cur_ar_len_ == ar_len,
"Current AR length (%d) is not matched with target AR length (%d). Please rearrange cache first.",
cur_ar_len_,
ar_len);
for (int layer = 0; layer < metadata_.num_layers; ++layer) {
for (int head = 0; head < metadata_.num_heads; ++head) {
update_key(k_cache_[layer][head], n_past, n_update);
update_value(v_cache_[layer][head], n_past, n_update);
update_key(k_cache_[layer][head], n_past, n_update, selected);
update_value(v_cache_[layer][head], n_past, n_update, selected);
}
}
}

void KVManager::update_key(KVCache& k_cache, int32_t n_past, int32_t n_update) {
void KVManager::update_key(
KVCache& k_cache,
int32_t n_past,
int32_t n_update,
const std::vector<bool>& selected) {
uint8_t* write_ptr = k_cache.buffer;
uint8_t* read_ptr = k_cache.output_buffer;
const int32_t copy_size = n_update * sizeof(uint8_t);
@@ -340,22 +350,35 @@ void KVManager::update_key(KVCache& k_cache, int32_t n_past, int32_t n_update) {
write_ptr += iter_size + past_size;
if (kv_updater_ == KVManagerMode::SMART_MASK)
write_ptr += past_size;

for (int i = 0; i < n_iter; ++i) {
std::memcpy(write_ptr, read_ptr, copy_size);
write_ptr += iter_size;
read_ptr += out_size;
if (selected.empty()) {
for (int i = 0; i < n_iter; ++i) {
std::memcpy(write_ptr, read_ptr, copy_size);
write_ptr += iter_size;
read_ptr += out_size;
}
} else {
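// Lookahead decoding path: gather the indices of the accepted speculative
// tokens, then copy only those entries from the model output into the cache.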
std::vector<int32_t> true_indices(n_update);
for (int i = 0, j = 0; i < selected.size() && j < n_update; ++i) {
if (selected[i]) {
true_indices[j++] = i;
}
}
for (int i = 0; i < n_iter; ++i) {
auto wp = write_ptr, rp = read_ptr;
for (auto ind : true_indices) {
*wp++ = rp[ind];
}
write_ptr += iter_size;
read_ptr += out_size;
}
}
}

void KVManager::update_value(
KVCache& v_cache,
int32_t n_past,
int32_t n_update) {
// Value cache doesn't need to copy for SHIFT_POINTER mode
if (kv_updater_ == KVManagerMode::SHIFT_POINTER)
return;

int32_t n_update,
const std::vector<bool>& selected) {
uint8_t* write_ptr = v_cache.buffer;
uint8_t* read_ptr = v_cache.output_buffer;
const int32_t copy_size = n_update * metadata_.head_dim * sizeof(uint8_t);
@@ -364,7 +387,31 @@ void KVManager::update_value(
if (kv_updater_ == KVManagerMode::SMART_MASK)
write_ptr += past_size;

std::memcpy(write_ptr, read_ptr, copy_size);
// Update the value cache for lookahead decoding in SHIFT_POINTER mode
if (kv_updater_ == KVManagerMode::SHIFT_POINTER) {
read_ptr += past_size;
write_ptr = read_ptr;
}

if (selected.empty()) {
// In general, the value cache doesn't need to be copied in SHIFT_POINTER mode
if (kv_updater_ == KVManagerMode::SHIFT_POINTER)
return;
std::memcpy(write_ptr, read_ptr, copy_size);
} else {
int32_t update_times = n_update;
auto wp = write_ptr, rp = read_ptr;
for (auto sel : selected) {
if (sel) {
std::memcpy(wp, rp, metadata_.head_dim * sizeof(uint8_t));
wp += metadata_.head_dim;
update_times--;
if (update_times == 0)
break;
}
rp += metadata_.head_dim;
}
}
}

} // namespace example
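A simplified model of the new selective update, shown below as a sketch only: it assumes flat per-row buffers, while the real code walks per-layer/per-head pointers and applies the SMART_MASK/SHIFT_POINTER offsets.

def update_key_row(cache_row, out_row, n_past, n_update, selected):
    # Dense path: append the first n_update outputs after the past entries.
    if not selected:
        cache_row[n_past:n_past + n_update] = out_row[:n_update]
        return
    # Selective path: pack only the accepted positions, in order.
    accepted = [i for i, sel in enumerate(selected) if sel][:n_update]
    for j, src in enumerate(accepted):
        cache_row[n_past + j] = out_row[src]

def update_value_rows(cache, out, head_dim, n_update, selected):
    # Values are copied head_dim elements at a time, one row per token.
    if not selected:
        cache[:n_update * head_dim] = out[:n_update * head_dim]
        return
    wp = 0
    for i, sel in enumerate(selected):
        if sel:
            cache[wp:wp + head_dim] = out[i * head_dim:(i + 1) * head_dim]
            wp += head_dim
            if wp == n_update * head_dim:
                break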
19 changes: 16 additions & 3 deletions examples/qualcomm/oss_scripts/llama/runner/kv_manager.h
@@ -120,8 +120,13 @@ class KVManager {
* @param ar_len Length of input tokens.
* @param n_past Number of past elements in the cache.
* @param n_update Number of elements to be updated.
* @param selected Indicates which positions should be updated.
*/
void update_cache(int32_t ar_len, int32_t n_past, int32_t n_update);
void update_cache(
int32_t ar_len,
int32_t n_past,
int32_t n_update,
const std::vector<bool>& selected);

const std::vector<std::vector<KVCache>>& get_k_cache_() const {
return k_cache_;
@@ -138,8 +143,16 @@
// Helper functions to rearrange and update key and value caches
void rearrange_key(KVCache& k_cache, int32_t ar_len_dst);
void rearrange_value(KVCache& v_cache, int32_t ar_len_dst);
void update_key(KVCache& k_cache, int32_t n_past, int32_t n_update);
void update_value(KVCache& v_cache, int32_t n_past, int32_t n_update);
void update_key(
KVCache& k_cache,
int32_t n_past,
int32_t n_update,
const std::vector<bool>& selected);
void update_value(
KVCache& v_cache,
int32_t n_past,
int32_t n_update,
const std::vector<bool>& selected);
KVManagerMode kv_updater_;

// metadata