Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 36 additions & 45 deletions examples/run/run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ class ArgumentParser {

class LlamaData {
public:
llama_model_ptr model;
llama_sampler_ptr sampler;
llama_context_ptr context;
llama_cpp::model model;
llama_cpp::sampler sampler;
llama_cpp::context context;
std::vector<llama_chat_message> messages;

int init(const Options & opt) {
Expand All @@ -133,11 +133,11 @@ class LlamaData {

private:
// Initializes the model and returns a unique pointer to it
llama_model_ptr initialize_model(const std::string & model_path, const int ngl) {
llama_cpp::model initialize_model(const std::string & model_path, const int ngl) {
llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = ngl;

llama_model_ptr model(llama_load_model_from_file(model_path.c_str(), model_params));
llama_cpp::model model(llama_cpp::load_model_from_file(model_path, model_params));
if (!model) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
}
Expand All @@ -146,12 +146,12 @@ class LlamaData {
}

// Initializes the context with the specified parameters
llama_context_ptr initialize_context(const llama_model_ptr & model, const int n_ctx) {
llama_cpp::context initialize_context(const llama_cpp::model & model, const int n_ctx) {
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = n_ctx;
ctx_params.n_batch = n_ctx;

llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params));
llama_cpp::context context(llama_cpp::new_context_with_model(model, ctx_params));
if (!context) {
fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
}
Expand All @@ -160,8 +160,8 @@ class LlamaData {
}

// Initializes and configures the sampler
llama_sampler_ptr initialize_sampler() {
llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params()));
llama_cpp::sampler initialize_sampler() {
llama_cpp::sampler sampler(llama_cpp::sampler_chain_init(llama_sampler_chain_default_params()));
llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(0.8f));
llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
Expand All @@ -179,34 +179,20 @@ static void add_message(const char * role, const std::string & text, LlamaData &
owned_content.push_back(std::move(content));
}

// Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(const LlamaData & llama_data, std::vector<char> & formatted, const bool append) {
int result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(),
llama_data.messages.size(), append, formatted.data(), formatted.size());
if (result > static_cast<int>(formatted.size())) {
formatted.resize(result);
result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(),
llama_data.messages.size(), append, formatted.data(), formatted.size());
}

return result;
}

// Function to tokenize the prompt
static int tokenize_prompt(const llama_model_ptr & model, const std::string & prompt,
static int tokenize_prompt(const llama_cpp::model & model, const std::string & prompt,
std::vector<llama_token> & prompt_tokens) {
const int n_prompt_tokens = -llama_tokenize(model.get(), prompt.c_str(), prompt.size(), NULL, 0, true, true);
prompt_tokens.resize(n_prompt_tokens);
if (llama_tokenize(model.get(), prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
true) < 0) {
GGML_ABORT("failed to tokenize the prompt\n");
try {
prompt_tokens = llama_cpp::tokenize(model, prompt, false, true);
return prompt_tokens.size();
} catch (const std::exception & e) {
fprintf(stderr, "failed to tokenize the prompt: %s\n", e.what());
return -1;
}

return n_prompt_tokens;
}

// Check if we have enough space in the context to evaluate this batch
static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) {
static int check_context_size(const llama_cpp::context & ctx, const llama_batch & batch) {
const int n_ctx = llama_n_ctx(ctx.get());
const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get());
if (n_ctx_used + batch.n_tokens > n_ctx) {
Expand All @@ -219,15 +205,14 @@ static int check_context_size(const llama_context_ptr & ctx, const llama_batch &
}

// convert the token to a string
static int convert_token_to_string(const llama_model_ptr & model, const llama_token token_id, std::string & piece) {
char buf[256];
int n = llama_token_to_piece(model.get(), token_id, buf, sizeof(buf), 0, true);
if (n < 0) {
GGML_ABORT("failed to convert token to piece\n");
static int convert_token_to_string(const llama_cpp::model & model, const llama_token token_id, std::string & piece) {
try {
piece = llama_cpp::token_to_piece(model, token_id, 0, true);
return 0;
} catch (const std::exception & e) {
fprintf(stderr, "failed to convert token to piece: %s\n", e.what());
return -1;
}

piece = std::string(buf, n);
return 0;
}

static void print_word_and_concatenate_to_response(const std::string & piece, std::string & response) {
Expand Down Expand Up @@ -308,14 +293,20 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
// Helper function to apply the chat template and handle errors
static int apply_chat_template_with_error_handling(const LlamaData & llama_data, std::vector<char> & formatted,
const bool is_user_input, int & output_length) {
const int new_len = apply_chat_template(llama_data, formatted, is_user_input);
if (new_len < 0) {
fprintf(stderr, "failed to apply the chat template\n");
try {
std::string res = llama_cpp::chat_apply_template(
llama_data.model,
"",
llama_data.messages,
is_user_input);
output_length = res.size();
formatted.resize(output_length);
std::memcpy(formatted.data(), res.c_str(), output_length);
return output_length;
} catch (const std::exception & e) {
fprintf(stderr, "failed to apply chat template: %s\n", e.what());
return -1;
}

output_length = new_len;
return 0;
}

// Helper function to handle user input
Expand Down
47 changes: 44 additions & 3 deletions include/llama-cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
#endif

#include <memory>
#include <string>

#include "llama.h"

namespace llama_cpp {

struct llama_model_deleter {
void operator()(llama_model * model) { llama_free_model(model); }
};
Expand All @@ -20,6 +23,44 @@ struct llama_sampler_deleter {
void operator()(llama_sampler * sampler) { llama_sampler_free(sampler); }
};

typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;
typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr;
typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr;
typedef std::unique_ptr<llama_model, llama_model_deleter> model;
typedef std::unique_ptr<llama_context, llama_context_deleter> context;
typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> sampler;

inline model load_model_from_file(const std::string & path_model, llama_model_params params) {
return model(llama_load_model_from_file(path_model.c_str(), params));
}

inline context new_context_with_model(const model & model, llama_context_params params) {
return context(llama_new_context_with_model(model.get(), params));
}

inline sampler sampler_chain_init(llama_sampler_chain_params params) {
return sampler(llama_sampler_chain_init(params));
}

std::vector<llama_token> tokenize(
const llama_cpp::model & model,
const std::string & raw_text,
bool add_special,
bool parse_special = false);

std::string token_to_piece(
const llama_cpp::model & model,
llama_token token,
int32_t lstrip,
bool special);

std::string detokenize(
const llama_cpp::model & model,
const std::vector<llama_token> & tokens,
bool remove_special,
bool unparse_special);

std::string chat_apply_template(
const llama_cpp::model & model,
const std::string & tmpl,
const std::vector<llama_chat_message> & chat,
bool add_ass);

} // namespace llama_cpp
61 changes: 61 additions & 0 deletions src/llama.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "llama-cpp.h"
#include "llama-impl.h"
#include "llama-vocab.h"
#include "llama-sampling.h"
Expand Down Expand Up @@ -21818,6 +21819,14 @@ int32_t llama_tokenize(
return llama_tokenize_impl(model->vocab, text, text_len, tokens, n_tokens_max, add_special, parse_special);
}

std::vector<llama_token> llama_cpp::tokenize(
const llama_cpp::model & model,
const std::string & raw_text,
bool add_special,
bool parse_special) {
return llama_tokenize_internal(model->vocab, raw_text, add_special, parse_special);
}

int32_t llama_token_to_piece(
const struct llama_model * model,
llama_token token,
Expand All @@ -21828,6 +21837,23 @@ int32_t llama_token_to_piece(
return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
}

std::string llama_cpp::token_to_piece(
const llama_cpp::model & model,
llama_token token,
int32_t lstrip,
bool special) {
std::vector<char> buf(64);
int32_t n = llama_token_to_piece_impl(model->vocab, token, buf.data(), buf.size(), lstrip, special);
if (n > (int32_t) buf.size()) {
buf.resize(n);
llama_token_to_piece_impl(model->vocab, token, buf.data(), buf.size(), lstrip, special);
} else if (n < 0) {
// TODO: make special type of expection here
throw std::runtime_error("failed to convert token to piece");
}
return std::string(buf.data(), n);
}

int32_t llama_detokenize(
const struct llama_model * model,
const llama_token * tokens,
Expand All @@ -21839,6 +21865,23 @@ int32_t llama_detokenize(
return llama_detokenize_impl(model->vocab, tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
}

std::string llama_cpp::detokenize(
const llama_cpp::model & model,
const std::vector<llama_token> & tokens,
bool remove_special,
bool unparse_special) {
std::vector<char> buf(1024);
int32_t n = llama_detokenize_impl(model->vocab, tokens.data(), tokens.size(), buf.data(), buf.size(), remove_special, unparse_special);
if (n > (int32_t) buf.size()) {
buf.resize(n);
llama_detokenize_impl(model->vocab, tokens.data(), tokens.size(), buf.data(), buf.size(), remove_special, unparse_special);
} else if (n < 0) {
// TODO: make special type of expection here
throw std::runtime_error("failed to detokenize");
}
return std::string(buf.data(), n);
}

//
// chat templates
//
Expand Down Expand Up @@ -22172,6 +22215,24 @@ int32_t llama_chat_apply_template(
return res;
}

std::string llama_cpp::chat_apply_template(
const llama_cpp::model & model,
const std::string & tmpl,
const std::vector<llama_chat_message> & chat,
bool add_ass) {
std::vector<char> buf;
const char * tmpl_c = tmpl.empty() ? nullptr : tmpl.c_str();
int32_t n = llama_chat_apply_template(model.get(), tmpl_c, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
if (n > (int32_t) buf.size()) {
buf.resize(n);
llama_chat_apply_template(model.get(), tmpl_c, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
} else if (n < 0) {
// TODO: make special type of expection here
throw std::runtime_error("failed to format chat template");
}
return std::string(buf.data(), n);
}

//
// sampling
//
Expand Down
Loading