@@ -133,20 +133,7 @@ struct slot_params {
133133
134134 auto grammar_triggers = json::array ();
135135 for (const auto & trigger : sampling.grammar_triggers ) {
136- switch (trigger.type ) {
137- case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
138- grammar_triggers.push_back ({{" word" , trigger.value }});
139- break ;
140- case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
141- grammar_triggers.push_back ({{" pattern" , trigger.value }});
142- break ;
143- case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
144- grammar_triggers.push_back ({{" pattern_start" , trigger.value }});
145- break ;
146- case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
147- grammar_triggers.push_back ({{" token" , trigger.token }});
148- break ;
149- }
136+ grammar_triggers.push_back (trigger.to_json <json>());
150137 }
151138
152139 return json {
@@ -385,44 +372,31 @@ struct server_task {
385372 const auto grammar_triggers = data.find (" grammar_triggers" );
386373 if (grammar_triggers != data.end ()) {
387374 for (const auto & t : *grammar_triggers) {
388- auto type = static_cast <common_grammar_trigger_type>(t.at (" type" ));
389- switch (type) {
390- case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
391- {
392- const std::string & word = t.at (" value" );
393- auto ids = common_tokenize (vocab, word, /* add_special= */ false , /* parse_special= */ true );
394- if (ids.size () == 1 ) {
395- auto token = ids[0 ];
396- if (std::find (params.sampling .preserved_tokens .begin (), params.sampling .preserved_tokens .end (), (llama_token) token) == params.sampling .preserved_tokens .end ()) {
397- throw std::runtime_error (" Grammar trigger word should be marked as preserved token: " + word);
398- }
399- SRV_DBG (" Grammar trigger token: %d (`%s`)\n " , token, word.c_str ());
400- common_grammar_trigger trigger;
401- trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
402- trigger.value = token;
403- params.sampling .grammar_triggers .push_back (trigger);
404- } else {
405- SRV_DBG (" Grammar trigger word: `%s`\n " , word.c_str ());
406- params.sampling .grammar_triggers .push_back ({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
375+ auto ct = common_grammar_trigger::from_json (t);
376+ if (ct.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
377+ const auto & word = ct.value ;
378+ auto ids = common_tokenize (vocab, word, /* add_special= */ false , /* parse_special= */ true );
379+ if (ids.size () == 1 ) {
380+ auto token = ids[0 ];
381+ if (std::find (params.sampling .preserved_tokens .begin (), params.sampling .preserved_tokens .end (), (llama_token) token) == params.sampling .preserved_tokens .end ()) {
382+ throw std::runtime_error (" Grammar trigger word should be marked as preserved token: " + word);
407383 }
408- break ;
409- }
410- case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
411- case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
412- {
413- const std::string & pattern = t. at ( " value " );
414- params. sampling . grammar_triggers . push_back ({type, pattern} );
415- break ;
384+ SRV_DBG ( " Grammar trigger token: %d (`%s`) \n " , token, word. c_str ()) ;
385+ common_grammar_trigger trigger;
386+ trigger. type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
387+ trigger. value = (llama_token) token;
388+ params. sampling . grammar_triggers . push_back (trigger);
389+ } else {
390+ SRV_DBG ( " Grammar trigger word: `%s` \n " , word. c_str () );
391+ params. sampling . grammar_triggers . push_back ({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}) ;
416392 }
417- case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
418- throw std::runtime_error (" Unespected token trigger" );
419- default :
420- throw std::runtime_error (" Unknown trigger type" );
393+ } else {
394+ params.sampling .grammar_triggers .push_back (ct);
421395 }
422396 }
423397 }
424398 if (params.sampling .grammar_lazy ) {
425- GGML_ASSERT (params.sampling .grammar_triggers .size () > 0 );
399+ GGML_ASSERT (! params.sampling .grammar_triggers .empty () );
426400 }
427401 }
428402
0 commit comments