Skip to content

Commit f97bbdd

Browse files
committed
Fix to allow all EOGs to trigger a stop; Occam's GLM4 fix
1 parent bd7a40f commit f97bbdd

File tree

6 files changed

+54
-22
lines changed

6 files changed

+54
-22
lines changed

gpttype_adapter.cpp

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -301,14 +301,29 @@ static int GetEosID(FileFormat file_format, int32_t n_vocab)
301301
}
302302
return eosID;
303303
}
304-
static int GetEotID(FileFormat file_format)
304+
305+
static std::vector<int> GetEogIDs(FileFormat file_format, int32_t n_vocab)
305306
{
307+
std::vector<int> alleogs;
308+
int eos = GetEosID(file_format, n_vocab);
306309
if(file_format == FileFormat::GGUF_GENERIC)
307310
{
308311
const llama_vocab * tmpvocab = llama_model_get_vocab(llama_get_model(llama_ctx_v4));
309-
return llama_vocab_eot(tmpvocab);
312+
int eot = llama_vocab_eot(tmpvocab);
313+
std::set<int> eogs = tmpvocab->get_eogs();
314+
if (eot >= 0) {
315+
eogs.insert(eot);
316+
}
317+
if (eos >= 0) {
318+
eogs.insert(eos);
319+
}
320+
alleogs = std::vector<int>(eogs.begin(), eogs.end());
321+
} else {
322+
if (eos >= 0) {
323+
alleogs.push_back(eos);
324+
}
310325
}
311-
return -1;
326+
return alleogs;
312327
}
313328

314329
static float LowestLogit(const std::vector<float> & logits)
@@ -1550,16 +1565,16 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
15501565
}
15511566
}
15521567

1553-
const llama_token eos = GetEosID(file_format,n_vocab);
1554-
const llama_token eot = GetEotID(file_format);
1568+
const std::vector<llama_token> eog_tokens = GetEogIDs(file_format,n_vocab);
15551569

15561570
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
15571571
std::vector<llama_grammar_candidate> candidates_grammar;
15581572

15591573
for (size_t i = 0; i < candidates->size; ++i) {
15601574
const llama_token id = candidates->data[i].id;
15611575
const std::string piece = FileFormatTokenizeID(id,file_format);
1562-
if (id == eos || (id==eot && id!=-1)) {
1576+
bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), id) != eog_tokens.end();
1577+
if (found_eog) {
15631578
if (!allow_eos) {
15641579
candidates->data[i].logit = -INFINITY;
15651580
}
@@ -1711,7 +1726,9 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna
17111726

17121727
static void grammar_accept_token(FileFormat file_format, int32_t n_vocab, struct llama_grammar * grammar, llama_token token)
17131728
{
1714-
if (token == GetEosID(file_format,n_vocab) || (token!=-1 && token == GetEotID(file_format))) {
1729+
const std::vector<llama_token> eog_tokens = GetEogIDs(file_format,n_vocab);
1730+
bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), token) != eog_tokens.end();
1731+
if (found_eog) {
17151732
for (const auto & stack : grammar->stacks) {
17161733
if (stack.empty()) {
17171734
return;
@@ -3827,8 +3844,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
38273844
}
38283845
}
38293846

3830-
unsigned int eosID = GetEosID(file_format, n_vocab);
3831-
unsigned int eotID = GetEotID(file_format);
3847+
const std::vector<llama_token> eog_tokens = GetEogIDs(file_format,n_vocab);
38323848
float * logitsPtr;
38333849
float lowestLogit = 0;
38343850
int btsize = banned_token_ids.size();
@@ -3886,13 +3902,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
38863902
if (!inputs.allow_eos_token && !inputs.bypass_eos_token)
38873903
{
38883904
// set the logit of the eos token to very low to avoid sampling it
3889-
if(eosID!=LLAMA_TOKEN_NULL)
3890-
{
3891-
logitsPtr[eosID] = lowestLogit;
3892-
}
3893-
if(eotID!=-1)
3905+
for(int i=0;i<eog_tokens.size();++i)
38943906
{
3895-
logitsPtr[eotID] = lowestLogit;
3907+
logitsPtr[eog_tokens[i]] = lowestLogit;
38963908
}
38973909
}
38983910
if(btsize>0)
@@ -3958,7 +3970,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
39583970
for (auto eid : embd)
39593971
{
39603972
std::string tokenizedstr = FileFormatTokenizeID(eid, file_format, inputs.render_special);
3961-
if(!inputs.render_special && (eid==eosID || (eid==eotID && eid!=-1) || VecContainsIntVal(special_stop_sequence,id))) //extra filter to avoid unwanted special tokens
3973+
bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), eid) != eog_tokens.end();
3974+
if(!inputs.render_special && (found_eog || VecContainsIntVal(special_stop_sequence,id))) //extra filter to avoid unwanted special tokens
39623975
{
39633976
tokenizedstr = ""; //prevent render
39643977
}
@@ -4059,7 +4072,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
40594072

40604073
if(!early_abort)
40614074
{
4062-
if(!inputs.bypass_eos_token && inputs.allow_eos_token && (id==eosID || (id==eotID && id!=-1)))
4075+
bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), id) != eog_tokens.end();
4076+
if(!inputs.bypass_eos_token && inputs.allow_eos_token && found_eog)
40634077
{
40644078
if(allow_regular_prints)
40654079
{

include/llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <stdint.h>
1313
#include <stdio.h>
1414
#include <stdbool.h>
15+
#include <set>
1516

1617
#ifdef LLAMA_SHARED
1718
# if defined(_WIN32) && !defined(__MINGW32__)
@@ -941,6 +942,8 @@ extern "C" {
941942
LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
942943
LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
943944

945+
LLAMA_API std::set<int> llama_vocab_get_eogs(const struct llama_vocab * vocab);
946+
944947
LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
945948
LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
946949

koboldcpp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
dry_seq_break_max = 128
5353

5454
# global vars
55-
KcppVersion = "1.92"
55+
KcppVersion = "1.92.1"
5656
showdebug = True
5757
kcpp_instance = None #global running instance
5858
global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False}

src/llama-graph.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1287,6 +1287,10 @@ ggml_tensor * llm_graph_context::build_attn(
12871287

12881288
if (wo) {
12891289
cur = build_lora_mm(wo, cur);
1290+
if (arch == LLM_ARCH_GLM4) {
1291+
// GLM4 seems to have numerical issues with half-precision accumulators
1292+
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1293+
}
12901294
}
12911295

12921296
if (wo_b) {
@@ -1367,10 +1371,6 @@ ggml_tensor * llm_graph_context::build_attn(
13671371

13681372
if (wo) {
13691373
cur = build_lora_mm(wo, cur);
1370-
if (arch == LLM_ARCH_GLM4) {
1371-
// GLM4 seems to have numerical issues with half-precision accumulators
1372-
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1373-
}
13741374
}
13751375

13761376
if (wo_b) {

src/llama-vocab.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,6 +1538,7 @@ struct llama_vocab::impl {
15381538
bool is_user_defined(llama_token id) const;
15391539
bool is_unused (llama_token id) const;
15401540
bool is_eog (llama_token id) const;
1541+
std::set<int> get_eogs() const;
15411542

15421543
uint8_t token_to_byte(llama_token id) const;
15431544

@@ -2396,6 +2397,10 @@ bool llama_vocab::impl::is_eog(llama_token id) const {
23962397
return id != LLAMA_TOKEN_NULL && special_eog_ids.count(id) > 0;
23972398
}
23982399

2400+
std::set<int> llama_vocab::impl::get_eogs() const {
2401+
return special_eog_ids;
2402+
}
2403+
23992404
uint8_t llama_vocab::impl::token_to_byte(llama_token id) const {
24002405
GGML_ASSERT(get_type() != LLAMA_VOCAB_TYPE_NONE);
24012406
GGML_ASSERT(is_byte(id));
@@ -3121,6 +3126,10 @@ bool llama_vocab::is_eog(llama_token id) const {
31213126
return pimpl->is_eog(id);
31223127
}
31233128

3129+
std::set<int> llama_vocab::get_eogs() const {
3130+
return pimpl->get_eogs();
3131+
}
3132+
31243133
uint8_t llama_vocab::token_to_byte(llama_token id) const {
31253134
return pimpl->token_to_byte(id);
31263135
}
@@ -3431,6 +3440,11 @@ llama_token llama_vocab_eot(const struct llama_vocab * vocab) {
34313440
return vocab->token_eot();
34323441
}
34333442

3443+
std::set<int> llama_vocab_get_eogs(const struct llama_vocab * vocab)
3444+
{
3445+
return vocab->get_eogs();
3446+
}
3447+
34343448
// deprecated
34353449
llama_token llama_vocab_cls(const struct llama_vocab * vocab) {
34363450
return vocab->token_bos();

src/llama-vocab.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ struct llama_vocab {
4040
bool is_user_defined(llama_token id) const;
4141
bool is_unused (llama_token id) const;
4242
bool is_eog (llama_token id) const;
43+
std::set<int> get_eogs() const;
4344

4445
uint8_t token_to_byte(llama_token id) const;
4546
llama_token byte_to_token(uint8_t ch) const;

0 commit comments

Comments (0)