@@ -131,9 +131,22 @@ struct slot_params {
131131 lora.push_back ({{" id" , i}, {" scale" , this ->lora [i].scale }});
132132 }
133133
134- std::vector<std::string> grammar_trigger_words;
135- for (const auto & trigger : sampling.grammar_trigger_words ) {
136- grammar_trigger_words.push_back (trigger.word );
134+ auto grammar_triggers = json::array ();
135+ for (const auto & trigger : sampling.grammar_triggers ) {
136+ switch (trigger.type ) {
137+ case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
138+ grammar_triggers.push_back ({{" word" , std::get<std::string>(trigger.value )}});
139+ break ;
140+ case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
141+ grammar_triggers.push_back ({{" pattern" , std::get<std::string>(trigger.value )}});
142+ break ;
143+ case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
144+ grammar_triggers.push_back ({{" pattern_start" , std::get<std::string>(trigger.value )}});
145+ break ;
146+ case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
147+ grammar_triggers.push_back ({{" token" , std::get<llama_token>(trigger.value )}});
148+ break ;
149+ }
137150 }
138151
139152 return json {
@@ -170,8 +183,8 @@ struct slot_params {
170183 {" n_probs" , sampling.n_probs },
171184 {" min_keep" , sampling.min_keep },
172185 {" grammar" , sampling.grammar },
173- {" grammar_trigger_words " , grammar_trigger_words },
174- {" grammar_trigger_tokens " , sampling. grammar_trigger_tokens },
186+ {" grammar_lazy " , sampling. grammar_lazy },
187+ {" grammar_triggers " , grammar_triggers },
175188 {" preserved_tokens" , sampling.preserved_tokens },
176189 {" chat_format" , common_chat_format_name (oaicompat_chat_format)},
177190 {" samplers" , samplers},
@@ -356,24 +369,6 @@ struct server_task {
356369 }
357370
358371 {
359- const auto grammar_triggers = data.find (" grammar_triggers" );
360- if (grammar_triggers != data.end ()) {
361- for (const auto & t : *grammar_triggers) {
362- common_grammar_trigger trigger;
363- trigger.word = t.at (" word" );
364- trigger.at_start = t.at (" at_start" );
365-
366- auto ids = common_tokenize (vocab, trigger.word , /* add_special= */ false , /* parse_special= */ true );
367- if (ids.size () == 1 ) {
368- SRV_DBG (" Grammar trigger token: %d (`%s`)\n " , ids[0 ], trigger.word .c_str ());
369- params.sampling .grammar_trigger_tokens .push_back (ids[0 ]);
370- params.sampling .preserved_tokens .insert (ids[0 ]);
371- continue ;
372- }
373- SRV_DBG (" Grammar trigger word: `%s`\n " , trigger.word .c_str ());
374- params.sampling .grammar_trigger_words .push_back (trigger);
375- }
376- }
377372 const auto preserved_tokens = data.find (" preserved_tokens" );
378373 if (preserved_tokens != data.end ()) {
379374 for (const auto & t : *preserved_tokens) {
@@ -383,12 +378,48 @@ struct server_task {
383378 params.sampling .preserved_tokens .insert (ids[0 ]);
384379 } else {
385380 // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
386- SRV_WRN (" Not preserved because more than 1 token (wrong chat template override?): %s\n " , t.get <std::string>().c_str ());
381+ SRV_DBG (" Not preserved because more than 1 token: %s\n " , t.get <std::string>().c_str ());
382+ }
383+ }
384+ }
385+ const auto grammar_triggers = data.find (" grammar_triggers" );
386+ if (grammar_triggers != data.end ()) {
387+ for (const auto & t : *grammar_triggers) {
388+ auto type = static_cast <common_grammar_trigger_type>(t.at (" type" ));
389+ switch (type) {
390+ case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
391+ {
392+ const std::string & word = t.at (" value" );
393+ auto ids = common_tokenize (vocab, word, /* add_special= */ false , /* parse_special= */ true );
394+ if (ids.size () == 1 ) {
395+ auto token = ids[0 ];
396+ if (std::find (params.sampling .preserved_tokens .begin (), params.sampling .preserved_tokens .end (), token) == params.sampling .preserved_tokens .end ()) {
397+ throw std::runtime_error (" Grammar trigger word should be marked as preserved token: " + word);
398+ }
399+ SRV_DBG (" Grammar trigger token: %d (`%s`)\n " , token, word.c_str ());
400+ params.sampling .grammar_triggers .push_back ({COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN, token});
401+ } else {
402+ SRV_DBG (" Grammar trigger word: `%s`\n " , word.c_str ());
403+ params.sampling .grammar_triggers .push_back ({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
404+ }
405+ break ;
406+ }
407+ case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
408+ case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
409+ {
410+ const std::string & pattern = t.at (" value" );
411+ params.sampling .grammar_triggers .push_back ({type, pattern});
412+ break ;
413+ }
414+ case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
415+ throw std::runtime_error (" Unespected token trigger" );
416+ default :
417+ throw std::runtime_error (" Unknown trigger type" );
387418 }
388419 }
389420 }
390421 if (params.sampling .grammar_lazy ) {
391- GGML_ASSERT (params.sampling .grammar_trigger_tokens . size () > 0 || params. sampling . grammar_trigger_words .size () > 0 );
422+ GGML_ASSERT (params.sampling .grammar_triggers .size () > 0 );
392423 }
393424 }
394425
@@ -2045,7 +2076,7 @@ struct server_context {
20452076
20462077 if (slot.n_predict > 0 && slot.params .n_predict > slot.n_predict ) {
20472078 // Might be better to reject the request with a 400 ?
2048- SLT_WRN (slot, " n_predict = %d exceeds server configuration, setting to %d" , slot.params .n_predict , slot.n_predict );
2079+ SLT_WRN (slot, " n_predict = %d exceeds server configuration, setting to %d\n " , slot.params .n_predict , slot.n_predict );
20492080 slot.params .n_predict = slot.n_predict ;
20502081 }
20512082
0 commit comments