19 changes: 10 additions & 9 deletions examples/run/run.cpp
@@ -729,11 +729,12 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) {
 
 // Function to tokenize the prompt
 static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
-                           std::vector<llama_token> & prompt_tokens) {
-    const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+                           std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data,
+                           const bool is_first) {
+    const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
     prompt_tokens.resize(n_prompt_tokens);
-    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
-                       true) < 0) {
+    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(),
+                       llama_get_kv_cache_used_cells(llama_data.context.get()) == 0, true) < 0) {
Comment on lines +736 to +737
Member
Suggested change
-    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(),
-                       llama_get_kv_cache_used_cells(llama_data.context.get()) == 0, true) < 0) {
+    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(),
+                       is_first, true) < 0) {

Collaborator Author
Following the simple-chat example, the first call to llama_tokenize uses this parameter: is_first, and the second uses: llama_get_kv_cache_used_cells(llama_data.context.get()) == 0.
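A minimal sketch of the idea (illustrative only, not code from this PR or from simple-chat): the flag passed as add_special can be derived from KV-cache occupancy, which is zero only before the first llama_decode of the conversation.

// Illustrative helper, assuming a llama_context obtained from llama_init_from_model
static bool is_first_turn(llama_context * ctx) {
    // the KV cache has no used cells until the first llama_decode() call has run,
    // so this is true exactly once: when tokenizing the first message
    return llama_get_kv_cache_used_cells(ctx) == 0;
}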

Member
@ggerganov ggerganov Jan 20, 2025
Ah, I didn't even notice. Maybe this way it is better than introducing the is_first parameter.

Collaborator Author
@ericcurtin ericcurtin Jan 20, 2025

Not pretending I know the details and just copying simple-chat :)

Member
I pushed this variant to both simple-chat and run - seems much simpler than the is_first solution: #11311
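A rough sketch of what that simpler variant could look like (inferred from this thread, not copied from #11311; the trailing return value is assumed): tokenize_prompt keeps no is_first parameter and computes the flag once from the KV cache, reusing it for both llama_tokenize calls.

// Hypothetical simplified tokenize_prompt, sketched from the discussion above
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
                           std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
    const bool is_first = llama_get_kv_cache_used_cells(llama_data.context.get()) == 0;

    const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
    prompt_tokens.resize(n_prompt_tokens);
    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(),
                       is_first, true) < 0) {
        printe("failed to tokenize the prompt\n");
        return -1;  // caller treats a negative value as failure
    }

    return n_prompt_tokens;  // assumed: the token count is returned on success
}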

printe("failed to tokenize the prompt\n");
return -1;
}
@@ -774,11 +775,11 @@ static void print_word_and_concatenate_to_response(const std::string & piece, st
 }
 
 // helper function to evaluate a prompt and generate a response
-static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) {
+static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response, const bool is_first) {
     const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get());
 
     std::vector<llama_token> tokens;
-    if (tokenize_prompt(vocab, prompt, tokens) < 0) {
+    if (tokenize_prompt(vocab, prompt, tokens, llama_data, is_first) < 0) {
         return 1;
     }

@@ -852,13 +853,13 @@ static int read_user_input(std::string & user_input) {
 
 // Function to generate a response based on the prompt
 static int generate_response(LlamaData & llama_data, const std::string & prompt, std::string & response,
-                             const bool stdout_a_terminal) {
+                             const bool stdout_a_terminal, const int prev_len) {
     // Set response color
     if (stdout_a_terminal) {
         printf("\033[33m");
     }
 
-    if (generate(llama_data, prompt, response)) {
+    if (generate(llama_data, prompt, response, prev_len == 0)) {
         printe("failed to generate response\n");
         return 1;
     }
@@ -948,7 +949,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
 
         std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
         std::string response;
-        if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
+        if (generate_response(llama_data, prompt, response, stdout_a_terminal, prev_len)) {
             return 1;
         }
