@@ -301,14 +301,29 @@ static int GetEosID(FileFormat file_format, int32_t n_vocab)
301301 }
302302 return eosID;
303303}
304- static int GetEotID (FileFormat file_format)
304+
305+ static std::vector<int > GetEogIDs (FileFormat file_format, int32_t n_vocab)
305306{
307+ std::vector<int > alleogs;
308+ int eos = GetEosID (file_format, n_vocab);
306309 if (file_format == FileFormat::GGUF_GENERIC)
307310 {
308311 const llama_vocab * tmpvocab = llama_model_get_vocab (llama_get_model (llama_ctx_v4));
309- return llama_vocab_eot (tmpvocab);
312+ int eot = llama_vocab_eot (tmpvocab);
313+ std::set<int > eogs = tmpvocab->get_eogs ();
314+ if (eot >= 0 ) {
315+ eogs.insert (eot);
316+ }
317+ if (eos >= 0 ) {
318+ eogs.insert (eos);
319+ }
320+ alleogs = std::vector<int >(eogs.begin (), eogs.end ());
321+ } else {
322+ if (eos >= 0 ) {
323+ alleogs.push_back (eos);
324+ }
310325 }
311- return - 1 ;
326+ return alleogs ;
312327}
313328
314329static float LowestLogit (const std::vector<float > & logits)
@@ -1550,16 +1565,16 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
15501565 }
15511566 }
15521567
1553- const llama_token eos = GetEosID (file_format,n_vocab);
1554- const llama_token eot = GetEotID (file_format);
1568+ const std::vector<llama_token> eog_tokens = GetEogIDs (file_format,n_vocab);
15551569
15561570 std::vector<std::pair<std::vector<uint32_t >, llama_partial_utf8>> candidates_decoded;
15571571 std::vector<llama_grammar_candidate> candidates_grammar;
15581572
15591573 for (size_t i = 0 ; i < candidates->size ; ++i) {
15601574 const llama_token id = candidates->data [i].id ;
15611575 const std::string piece = FileFormatTokenizeID (id,file_format);
1562- if (id == eos || (id==eot && id!=-1 )) {
1576+ bool found_eog = std::find (eog_tokens.begin (), eog_tokens.end (), id) != eog_tokens.end ();
1577+ if (found_eog) {
15631578 if (!allow_eos) {
15641579 candidates->data [i].logit = -INFINITY;
15651580 }
@@ -1711,7 +1726,9 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna
17111726
17121727static void grammar_accept_token (FileFormat file_format, int32_t n_vocab, struct llama_grammar * grammar, llama_token token)
17131728{
1714- if (token == GetEosID (file_format,n_vocab) || (token!=-1 && token == GetEotID (file_format))) {
1729+ const std::vector<llama_token> eog_tokens = GetEogIDs (file_format,n_vocab);
1730+ bool found_eog = std::find (eog_tokens.begin (), eog_tokens.end (), token) != eog_tokens.end ();
1731+ if (found_eog) {
17151732 for (const auto & stack : grammar->stacks ) {
17161733 if (stack.empty ()) {
17171734 return ;
@@ -3827,8 +3844,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
38273844 }
38283845 }
38293846
3830- unsigned int eosID = GetEosID (file_format, n_vocab);
3831- unsigned int eotID = GetEotID (file_format);
3847+ const std::vector<llama_token> eog_tokens = GetEogIDs (file_format,n_vocab);
38323848 float * logitsPtr;
38333849 float lowestLogit = 0 ;
38343850 int btsize = banned_token_ids.size ();
@@ -3886,13 +3902,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
38863902 if (!inputs.allow_eos_token && !inputs.bypass_eos_token )
38873903 {
38883904 // set the logit of the eos token to very low to avoid sampling it
3889- if (eosID!=LLAMA_TOKEN_NULL)
3890- {
3891- logitsPtr[eosID] = lowestLogit;
3892- }
3893- if (eotID!=-1 )
3905+ for (int i=0 ;i<eog_tokens.size ();++i)
38943906 {
3895- logitsPtr[eotID ] = lowestLogit;
3907+ logitsPtr[eog_tokens[i] ] = lowestLogit;
38963908 }
38973909 }
38983910 if (btsize>0 )
@@ -3958,7 +3970,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
39583970 for (auto eid : embd)
39593971 {
39603972 std::string tokenizedstr = FileFormatTokenizeID (eid, file_format, inputs.render_special );
3961- if (!inputs.render_special && (eid==eosID || (eid==eotID && eid!=-1 ) || VecContainsIntVal (special_stop_sequence,id))) // extra filter to avoid unwanted special tokens
3973+ bool found_eog = std::find (eog_tokens.begin (), eog_tokens.end (), eid) != eog_tokens.end ();
3974+ if (!inputs.render_special && (found_eog || VecContainsIntVal (special_stop_sequence,id))) // extra filter to avoid unwanted special tokens
39623975 {
39633976 tokenizedstr = " " ; // prevent render
39643977 }
@@ -4059,7 +4072,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
40594072
40604073 if (!early_abort)
40614074 {
4062- if (!inputs.bypass_eos_token && inputs.allow_eos_token && (id==eosID || (id==eotID && id!=-1 )))
4075+ bool found_eog = std::find (eog_tokens.begin (), eog_tokens.end (), id) != eog_tokens.end ();
4076+ if (!inputs.bypass_eos_token && inputs.allow_eos_token && found_eog)
40634077 {
40644078 if (allow_regular_prints)
40654079 {
0 commit comments