
Commit 3f91832

talk-llama : sync llama.cpp

1 parent cff8868, commit 3f91832

17 files changed: +581 additions, -231 deletions

examples/talk-llama/llama-arch.cpp

Lines changed: 7 additions & 2 deletions
@@ -179,6 +179,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_HF_JSON,         "tokenizer.huggingface.json" },
     { LLM_KV_TOKENIZER_RWKV,            "tokenizer.rwkv.world" },
     { LLM_KV_TOKENIZER_CHAT_TEMPLATE,   "tokenizer.chat_template" },
+    { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" },
     { LLM_KV_TOKENIZER_FIM_PRE_ID,      "tokenizer.ggml.fim_pre_token_id" },
     { LLM_KV_TOKENIZER_FIM_SUF_ID,      "tokenizer.ggml.fim_suf_token_id" },
     { LLM_KV_TOKENIZER_FIM_MID_ID,      "tokenizer.ggml.fim_mid_token_id" },

@@ -1023,6 +1024,9 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
     { LLM_TENSOR_OUTPUT,    "output" },
     { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
     { LLM_TENSOR_ATTN_QKV,  "blk.%d.attn_qkv" },
+    { LLM_TENSOR_ATTN_Q,    "blk.%d.attn_q" },
+    { LLM_TENSOR_ATTN_K,    "blk.%d.attn_k" },
+    { LLM_TENSOR_ATTN_V,    "blk.%d.attn_v" },
     { LLM_TENSOR_ATTN_OUT,  "blk.%d.attn_output" },
     { LLM_TENSOR_FFN_NORM,  "blk.%d.ffn_norm" },
     { LLM_TENSOR_FFN_UP,    "blk.%d.ffn_up" },

@@ -1443,10 +1447,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 };
 
-LLM_KV::LLM_KV(llm_arch arch) : arch(arch) {}
+LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
 
 std::string LLM_KV::operator()(llm_kv kv) const {
-    return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
+    return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
+                  : ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
 }
 
 std::string LLM_TN_IMPL::str() const {
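
Note: the optional suffix is forwarded to ::format as an extra argument so that per-variant keys such as the new "tokenizer.chat_template.%s" entry can be filled in (GGUF stores named chat templates under keys of the form tokenizer.chat_template.<name>). A minimal usage sketch, not part of the diff; the "tool_use" name is purely illustrative:

    // Sketch only: requesting a named chat-template key via the suffixed LLM_KV.
    LLM_KV kv_default(LLM_ARCH_LLAMA);             // no suffix: keys format as before
    LLM_KV kv_named  (LLM_ARCH_LLAMA, "tool_use"); // suffix forwarded as an extra ::format argument

    std::string base  = kv_default(LLM_KV_TOKENIZER_CHAT_TEMPLATE);   // "tokenizer.chat_template"
    std::string named = kv_named  (LLM_KV_TOKENIZER_CHAT_TEMPLATE_N); // a "tokenizer.chat_template.*" key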

examples/talk-llama/llama-arch.h

Lines changed: 3 additions & 1 deletion
@@ -177,6 +177,7 @@ enum llm_kv {
     LLM_KV_TOKENIZER_HF_JSON,
     LLM_KV_TOKENIZER_RWKV,
     LLM_KV_TOKENIZER_CHAT_TEMPLATE,
+    LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
     LLM_KV_TOKENIZER_FIM_PRE_ID,
     LLM_KV_TOKENIZER_FIM_SUF_ID,
     LLM_KV_TOKENIZER_FIM_MID_ID,

@@ -335,9 +336,10 @@ enum llm_tensor_layer {
 };
 
 struct LLM_KV {
-    LLM_KV(llm_arch arch);
+    LLM_KV(llm_arch arch, const char * suffix = nullptr);
 
     llm_arch arch;
+    const char * suffix;
 
     std::string operator()(llm_kv kv) const;
 };

examples/talk-llama/llama-chat.cpp

Lines changed: 11 additions & 2 deletions
@@ -51,6 +51,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "llama3",     LLM_CHAT_TEMPLATE_LLAMA_3 },
     { "chatglm3",   LLM_CHAT_TEMPLATE_CHATGML_3 },
     { "chatglm4",   LLM_CHAT_TEMPLATE_CHATGML_4 },
+    { "glmedge",    LLM_CHAT_TEMPLATE_GLMEDGE },
     { "minicpm",    LLM_CHAT_TEMPLATE_MINICPM },
     { "exaone3",    LLM_CHAT_TEMPLATE_EXAONE_3 },
     { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },

@@ -115,7 +116,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
     } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
         return LLM_CHAT_TEMPLATE_PHI_3;
     } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
-        return LLM_CHAT_TEMPLATE_FALCON_3;
+        return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
     } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
         return LLM_CHAT_TEMPLATE_ZEPHYR;
     } else if (tmpl_contains("bos_token + message['role']")) {

@@ -152,7 +153,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_MINICPM;
     } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
         return LLM_CHAT_TEMPLATE_DEEPSEEK_2;
-    } else if (tmpl_contains(LU8("'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'"))) {
+    } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) {
         return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
     } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
         // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb

@@ -440,6 +441,14 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<|assistant|>";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|" << role << "|>" << "\n" << message->content;
+        }
+        if (add_ass) {
+            ss << "<|assistant|>";
+        }
     } else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) {
         // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
         for (auto message : chat) {
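
Note: for reference, a small self-contained snippet that replays the GLMEDGE branch above on a toy conversation to show the rendered prompt; the local msg struct is only a stand-in for llama_chat_message and is not part of the diff:

    // Illustration only: mirrors the GLMEDGE formatting loop added above.
    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    int main() {
        struct msg { std::string role, content; };
        const std::vector<msg> chat = { {"user", "Hello"}, {"assistant", "Hi"} };
        const bool add_ass = true; // append the assistant header so the model continues the reply

        std::ostringstream ss;
        for (const auto & m : chat) {
            ss << "<|" << m.role << "|>" << "\n" << m.content;
        }
        if (add_ass) {
            ss << "<|assistant|>";
        }
        std::cout << ss.str() << "\n";
        // output:
        // <|user|>
        // Hello<|assistant|>
        // Hi<|assistant|>
    }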

examples/talk-llama/llama-chat.h

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_LLAMA_3,
     LLM_CHAT_TEMPLATE_CHATGML_3,
     LLM_CHAT_TEMPLATE_CHATGML_4,
+    LLM_CHAT_TEMPLATE_GLMEDGE,
     LLM_CHAT_TEMPLATE_MINICPM,
     LLM_CHAT_TEMPLATE_EXAONE_3,
     LLM_CHAT_TEMPLATE_RWKV_WORLD,

examples/talk-llama/llama-grammar.cpp

Lines changed: 86 additions & 6 deletions
@@ -560,7 +560,7 @@ bool llama_grammar_parser::parse(const char * src) {
             }
         }
     } catch (const std::exception & err) {
-        fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
+        fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src);
         rules.clear();
         return false;
     }

@@ -960,10 +960,28 @@ struct llama_grammar * llama_grammar_init_impl(
     // Important: vec_rules has to be moved here, not copied, because stacks contains
     // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
     // then the pointers would be invalidated when the local vec_rules goes out of scope.
-    return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
+    return new llama_grammar {
+        vocab,
+        std::move(vec_rules),
+        std::move(stacks),
+        /* .partial_utf8 = */     {},
+        /* .lazy = */             false,
+        /* .awaiting_trigger = */ false,
+        /* .trigger_buffer = */   "",
+        /* .trigger_tokens = */   {},
+        /* .trigger_words = */    {},
+    };
 }
 
-struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
+struct llama_grammar * llama_grammar_init_impl(
+        const struct llama_vocab * vocab,
+        const char * grammar_str,
+        const char * grammar_root,
+        bool lazy,
+        const char ** trigger_words,
+        size_t num_trigger_words,
+        const llama_token * trigger_tokens,
+        size_t num_trigger_tokens) {
     llama_grammar_parser parser;
 
     // if there is a grammar, parse it

@@ -1035,10 +1053,31 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
         }
     } while (true);
 
+    std::vector<llama_token> vec_trigger_tokens;
+    std::vector<std::string> vec_trigger_words;
+    for (size_t i = 0; i < num_trigger_tokens; i++) {
+        GGML_ASSERT(trigger_tokens != nullptr);
+        vec_trigger_tokens.push_back(trigger_tokens[i]);
+    }
+    for (size_t i = 0; i < num_trigger_words; i++) {
+        GGML_ASSERT(trigger_words != nullptr);
+        vec_trigger_words.push_back(trigger_words[i]);
+    }
+
     // Important: vec_rules has to be moved here, not copied, because stacks contains
     // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
     // then the pointers would be invalidated when the local vec_rules goes out of scope.
-    return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
+    return new llama_grammar {
+        vocab,
+        std::move(vec_rules),
+        std::move(stacks),
+        /* .partial_utf8 = */     {},
+        /* .lazy = */             lazy,
+        /* .awaiting_trigger = */ lazy,
+        /* .trigger_buffer = */   "",
+        std::move(vec_trigger_tokens),
+        std::move(vec_trigger_words),
+    };
 }
 
 void llama_grammar_free_impl(struct llama_grammar * grammar) {

@@ -1055,6 +1094,11 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
         grammar.rules,
         grammar.stacks,
         grammar.partial_utf8,
+        grammar.lazy,
+        grammar.awaiting_trigger,
+        grammar.trigger_buffer,
+        grammar.trigger_tokens,
+        grammar.trigger_words,
     };
 
     // redirect elements in stacks to point to new rules

@@ -1076,6 +1120,10 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
 void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
     GGML_ASSERT(grammar.vocab != nullptr);
 
+    if (grammar.awaiting_trigger) {
+        return;
+    }
+
     bool allow_eog = false;
     for (const auto & stack : grammar.stacks) {
         if (stack.empty()) {

@@ -1115,6 +1163,34 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
 void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
     GGML_ASSERT(grammar.vocab != nullptr);
 
+    const auto & piece = grammar.vocab->token_to_piece(token);
+
+    if (grammar.awaiting_trigger) {
+        if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
+            grammar.awaiting_trigger = false;
+            grammar.trigger_buffer.clear();
+            llama_grammar_accept_str(grammar, piece);
+            LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
+            return;
+        } else {
+            // TODO: consider a smarter incremental substring search algorithm (store last position to search from).
+            grammar.trigger_buffer += piece;
+            for (const auto & word : grammar.trigger_words) {
+                auto pos = grammar.trigger_buffer.find(word);
+                if (pos != std::string::npos) {
+                    grammar.awaiting_trigger = false;
+                    auto constrained_str = grammar.trigger_buffer.substr(pos);
+                    grammar.trigger_buffer.clear();
+                    llama_grammar_accept_str(grammar, constrained_str);
+                    LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str());
+                    return;
+                }
+            }
+            LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`) (buffer: `%s`)\n", token, piece.c_str(), grammar.trigger_buffer.c_str());
+            return;
+        }
+    }
+
     if (grammar.vocab->is_eog(token)) {
         for (const auto & stack : grammar.stacks) {
             if (stack.empty()) {

@@ -1124,8 +1200,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
         GGML_ABORT("fatal error");
     }
 
-    const std::string & piece = grammar.vocab->token_to_piece(token);
+    llama_grammar_accept_str(grammar, piece);
+}
 
+void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
     // Note terminating 0 in decoded string
     const auto decoded = decode_utf8(piece, grammar.partial_utf8);
     const auto & code_points = decoded.first;

@@ -1135,5 +1213,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
     }
 
     grammar.partial_utf8 = decoded.second;
-    GGML_ASSERT(!grammar.stacks.empty());
+    if (grammar.stacks.empty()) {
+        throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
+    }
 }
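
Note: the accept path above implements the lazy-trigger flow: while awaiting_trigger is set, pieces are only buffered; a trigger token activates the grammar immediately, while a trigger word activates it and feeds everything from the word onward into the grammar. Below is a stand-alone illustration of that buffering logic on plain strings; the "<tool_call>" trigger and the sample pieces are example values, not taken from the diff:

    // Illustration only: replays the trigger-buffer bookkeeping on plain strings.
    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
        const std::vector<std::string> trigger_words = { "<tool_call>" };
        const std::vector<std::string> pieces = { "I will ", "call a tool: ", "<tool_call>{\"name\"" };

        std::string trigger_buffer;
        bool awaiting_trigger = true;

        for (const auto & piece : pieces) {
            if (!awaiting_trigger) {
                break; // once triggered, later pieces would go straight to llama_grammar_accept_str
            }
            trigger_buffer += piece;
            for (const auto & word : trigger_words) {
                const auto pos = trigger_buffer.find(word);
                if (pos != std::string::npos) {
                    awaiting_trigger = false;
                    // only the text from the trigger word onward is grammar-constrained
                    std::cout << "constrained: " << trigger_buffer.substr(pos) << "\n";
                    trigger_buffer.clear();
                    break;
                }
            }
        }
        return 0;
    }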

examples/talk-llama/llama-grammar.h

Lines changed: 22 additions & 1 deletion
@@ -114,6 +114,15 @@ struct llama_grammar {
 
     // buffer for partially generated UTF-8 sequence from accepted tokens
     llama_partial_utf8 partial_utf8;
+
+    // lazy grammars wait for trigger words or tokens before constraining the sampling.
+    // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
+    // (useful e.g. for tool_choice=required)
+    bool lazy = false;
+    bool awaiting_trigger = false; // Initialized to true for lazy grammars only
+    std::string trigger_buffer;    // Output buffered by lazy grammar. Will be cleared once trigger is found.
+    std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
+    std::vector<std::string> trigger_words;
 };
 
 //

@@ -127,7 +136,15 @@ struct llama_grammar * llama_grammar_init_impl(
         size_t n_rules,
         size_t start_rule_index);
 
-struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
+struct llama_grammar * llama_grammar_init_impl(
+        const struct llama_vocab * vocab,
+        const char * grammar_str,
+        const char * grammar_root,
+        bool lazy,
+        const char ** trigger_words,
+        size_t num_trigger_words,
+        const llama_token * trigger_tokens,
+        size_t num_trigger_tokens);
 
 void llama_grammar_free_impl(struct llama_grammar * grammar);
 
@@ -141,3 +158,7 @@ void llama_grammar_apply_impl(
 void llama_grammar_accept_impl(
         struct llama_grammar & grammar,
         llama_token token);
+
+void llama_grammar_accept_str(
+        struct llama_grammar & grammar,
+        const std::string & piece);
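
Note: putting the new declarations together, a hedged sketch of how an internal caller might construct a lazy grammar with a word trigger. The trigger string and start-rule name are assumptions for illustration, not values defined by this diff; vocab and grammar_str are assumed to come from the caller:

    #include "llama-grammar.h"

    // Sketch under the header above: a lazy grammar that stays dormant until the model
    // emits "<tool_call>", after which output is constrained by the GBNF in grammar_str.
    static llama_grammar * make_lazy_tool_grammar(const llama_vocab * vocab, const char * grammar_str) {
        const char * trigger_words[] = { "<tool_call>" };

        return llama_grammar_init_impl(
            vocab,
            grammar_str,
            "root",                    // start rule name (assumed)
            /* lazy               = */ true,
            trigger_words,
            /* num_trigger_words  = */ 1,
            /* trigger_tokens     = */ nullptr,
            /* num_trigger_tokens = */ 0);
    }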

examples/talk-llama/llama-mmap.cpp

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
 #include <cstring>
 #include <climits>
 #include <stdexcept>
+#include <cerrno>
 
 #ifdef __has_include
 #if __has_include(<unistd.h>)
