commit 78861a3
Author: ochafik

    Wire LLM_KV_TOKENIZER_CHAT_TEMPLATE_N in llama_model_chat_template

1 parent: cb72cf1
7 files changed, 17 insertions(+), 23 deletions(-)

common/common.cpp
Lines changed: 2 additions & 14 deletions

@@ -1822,17 +1822,6 @@ std::string common_chat_format_example(const struct llama_model * model,
     return common_chat_apply_template(model, tmpl, msgs, true);
 }
 
-static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) {
-    int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
-    if (tlen > 0) {
-        std::vector<char> curr_tmpl_buf(tlen + 1, 0);
-        if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
-            return std::string(curr_tmpl_buf.data(), tlen);
-        }
-    }
-    return "";
-}
-
 llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
 {
     auto vocab = llama_model_get_vocab(model);
@@ -1841,9 +1830,8 @@ llama_chat_templates llama_chat_templates_from_model(const struct llama_model *
     std::string default_template_src = chat_template_override;
     std::string tool_use_template_src = chat_template_override;
     if (chat_template_override.empty()) {
-        // TODO:
-        default_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template");
-        tool_use_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use");
+        default_template_src = llama_model_chat_template(model, /* name */ nullptr);
+        tool_use_template_src = llama_model_chat_template(model, /* name */ "tool_use");
     }
     if (default_template_src.empty() || default_template_src == "chatml") {
         if (!tool_use_template_src.empty()) {
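One thing worth noting when reading this hunk: the removed _llama_model_meta_val_str helper returned an empty std::string when the key was missing, whereas llama_model_chat_template returns a raw pointer that may be nullptr. A defensive variant of the new assignments, purely illustrative and not part of the commit, would guard the pointer before constructing the strings:

    // Illustrative only: preserve the removed helper's empty-string fallback
    // by guarding the possibly-null C strings before constructing std::string.
    const char * def_tmpl  = llama_model_chat_template(model, /* name */ nullptr);
    const char * tool_tmpl = llama_model_chat_template(model, /* name */ "tool_use");
    default_template_src  = def_tmpl  ? def_tmpl  : "";
    tool_use_template_src = tool_tmpl ? tool_tmpl : "";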

examples/run/run.cpp
Lines changed: 2 additions & 2 deletions

@@ -713,11 +713,11 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(LlamaData & llama_data, const bool append) {
     int result = llama_chat_apply_template(
-        llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
+        llama_model_chat_template(llama_data.model.get(), /* name */ nullptr), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
+        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get(), /* name */ nullptr), llama_data.messages.data(),
                                            llama_data.messages.size(), append, llama_data.fmtted.data(),
                                            llama_data.fmtted.size());
     }
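For context, apply_chat_template uses llama.cpp's usual two-pass sizing convention: apply once into the existing buffer, and if the returned length exceeds the buffer, grow it and apply again. A minimal sketch of that pattern in isolation (tmpl and msgs are assumed to be in scope):

    // Hypothetical sketch of the two-pass sizing pattern used above.
    std::vector<char> buf(1024);
    int result = llama_chat_apply_template(tmpl, msgs.data(), msgs.size(), /* add_ass */ true,
                                           buf.data(), buf.size());
    if (result > static_cast<int>(buf.size())) {
        buf.resize(result); // first pass undersized: grow the buffer and re-apply
        result = llama_chat_apply_template(tmpl, msgs.data(), msgs.size(), /* add_ass */ true,
                                           buf.data(), buf.size());
    }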

examples/simple-chat/simple-chat.cpp
Lines changed: 1 addition & 1 deletion

@@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
             break;
         }
 
-        const char * tmpl = llama_model_chat_template(model);
+        const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
 
         // add the user input to the message list and format it
         messages.push_back({"user", strdup(user.c_str())});

include/llama.h
Lines changed: 1 addition & 1 deletion

@@ -503,7 +503,7 @@ extern "C" {
     LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
 
     // Get the default chat template. Returns nullptr if not available
-    LLAMA_API const char * llama_model_chat_template(const struct llama_model * model);
+    LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name);
 
     // Returns the total number of parameters in the model
     LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);

src/llama-arch.cpp
Lines changed: 4 additions & 2 deletions

@@ -179,6 +179,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_HF_JSON,         "tokenizer.huggingface.json"      },
     { LLM_KV_TOKENIZER_RWKV,            "tokenizer.rwkv.world"            },
     { LLM_KV_TOKENIZER_CHAT_TEMPLATE,   "tokenizer.chat_template"         },
+    { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s"      },
     { LLM_KV_TOKENIZER_FIM_PRE_ID,      "tokenizer.ggml.fim_pre_token_id" },
     { LLM_KV_TOKENIZER_FIM_SUF_ID,      "tokenizer.ggml.fim_suf_token_id" },
     { LLM_KV_TOKENIZER_FIM_MID_ID,      "tokenizer.ggml.fim_mid_token_id" },
@@ -1443,10 +1444,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 };
 
-LLM_KV::LLM_KV(llm_arch arch) : arch(arch) {}
+LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
 
 std::string LLM_KV::operator()(llm_kv kv) const {
-    return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
+    return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
+                  : ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
 }
 
 std::string LLM_TN_IMPL::str() const {
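To make the suffix mechanics concrete: "tokenizer.chat_template.%s" is a printf-style pattern, so a suffix such as "tool_use" is intended to expand it to the key tokenizer.chat_template.tool_use. A standalone sketch of that expansion, assuming printf semantics and deliberately not reusing the patch's ::format helper:

    #include <cstdio>
    #include <string>

    // Standalone illustration: expand a printf-style key pattern with a suffix,
    // the way LLM_KV_TOKENIZER_CHAT_TEMPLATE_N is meant to be expanded.
    static std::string expand_kv(const char * pattern, const char * suffix) {
        char buf[256];
        std::snprintf(buf, sizeof(buf), pattern, suffix);
        return buf;
    }

    // expand_kv("tokenizer.chat_template.%s", "tool_use")
    //     == "tokenizer.chat_template.tool_use"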

src/llama-arch.h
Lines changed: 3 additions & 1 deletion

@@ -177,6 +177,7 @@ enum llm_kv {
     LLM_KV_TOKENIZER_HF_JSON,
     LLM_KV_TOKENIZER_RWKV,
     LLM_KV_TOKENIZER_CHAT_TEMPLATE,
+    LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
     LLM_KV_TOKENIZER_FIM_PRE_ID,
     LLM_KV_TOKENIZER_FIM_SUF_ID,
     LLM_KV_TOKENIZER_FIM_MID_ID,
@@ -335,9 +336,10 @@ enum llm_tensor_layer {
 };
 
 struct LLM_KV {
-    LLM_KV(llm_arch arch);
+    LLM_KV(llm_arch arch, const char * suffix = nullptr);
 
     llm_arch arch;
+    const char * suffix;
 
     std::string operator()(llm_kv kv) const;
 };

src/llama-model.cpp
Lines changed: 4 additions & 2 deletions

@@ -3912,8 +3912,10 @@ uint64_t llama_model_size(const struct llama_model * model) {
     return model->size();
 }
 
-const char * llama_model_chat_template(const struct llama_model * model) {
-    const auto & it = model->gguf_kv.find(LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE));
+const char * llama_model_chat_template(const struct llama_model * model, const char * name) {
+    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
+                          : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
+    const auto & it = model->gguf_kv.find(key);
     if (it == model->gguf_kv.end()) {
         return nullptr;
     }