@@ -211,7 +211,6 @@ struct server_task {
     static slot_params params_from_json_cmpl(
             const llama_context * ctx,
             const common_params & params_base,
-            const common_chat_template * tmpl,
             const json & data) {
         const llama_model * model = llama_get_model(ctx);
         const llama_vocab * vocab = llama_model_get_vocab(model);
@@ -330,30 +329,19 @@ struct server_task {
             }
         }
 
-        if (tmpl && params_base.use_jinja) {
-            common_chat_params chat_params;
-            chat_params.messages = json_value(data, "messages", json::array());
-            chat_params.tools = json_value(data, "tools", json());
-            chat_params.tool_choice = json_value(data, "tool_choice", std::string("auto"));
-            chat_params.json_schema = json_value(data, "json_schema", json());
-            chat_params.parallel_tool_calls = json_value(data, "parallel_tool_calls", false);
-            chat_params.stream = json_value(data, "stream", false);
-
-            auto chat_data = common_chat_init(*tmpl, chat_params);
-            params.chat_parser = std::move(chat_data.handler);
-            params.sampling.grammar = chat_data.grammar;
-            for (const auto & stop : chat_data.additional_stops) {
-                params.antiprompt.push_back(stop);
+        if (!params_base.use_jinja) {
+            if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
+                throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
             }
-            for (const auto & trigger : chat_data.grammar_triggers) {
-                auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
-                if (ids.size() == 1) {
-                    LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
-                    params.sampling.grammar_trigger_tokens.push_back(ids[0]);
-                    continue;
+            if (data.contains("json_schema") && !data.contains("grammar")) {
+                try {
+                    auto schema = json_value(data, "json_schema", json::object());
+                    params.sampling.grammar = json_schema_to_grammar(schema);
+                } catch (const std::exception & e) {
+                    throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
                 }
-                LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
-                params.sampling.grammar_trigger_words.push_back(trigger);
+            } else {
+                params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
             }
         }
 
@@ -363,15 +351,13 @@ struct server_task {
         }
         if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
-                auto schema = json_value(data, "json_schema", json::object());
-                params.sampling.grammar = json_schema_to_grammar(schema);
+                params.sampling.grammar = json_schema_to_grammar(json_value(data, "json_schema", json::object()));
             } catch (const std::exception & e) {
                 throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
             }
         } else {
             params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
         }
-        LOG_INF("Grammar: %s\n", params.sampling.grammar.c_str());
 
         {
             params.sampling.logit_bias.clear();
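Note on the hunks above: in the non-Jinja path, "json_schema" and "grammar" are now mutually exclusive, and a schema is converted into a GBNF grammar before sampling. The standalone sketch below mirrors that selection logic; pick_grammar and convert_schema are illustrative names only, with convert_schema standing in for the json_schema_to_grammar() call the server uses.

    #include <nlohmann/json.hpp>

    #include <functional>
    #include <stdexcept>
    #include <string>

    using json = nlohmann::json;

    // Pick the grammar for a request: either the raw "grammar" string or one derived
    // from "json_schema", never both (same rule as the diff above).
    static std::string pick_grammar(
            const json & data,
            const std::string & default_grammar,
            const std::function<std::string(const json &)> & convert_schema) {
        const bool has_schema  = data.contains("json_schema") && !data.at("json_schema").is_null();
        const bool has_grammar = data.contains("grammar")     && !data.at("grammar").is_null();
        if (has_schema && has_grammar) {
            throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
        }
        if (has_schema) {
            return convert_schema(data.at("json_schema"));
        }
        return data.value("grammar", default_grammar);
    }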
@@ -2248,9 +2234,15 @@ struct server_context {
     }
 
     void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
-        auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send);
-        if (!opt_msg) {
-            return;
+        common_chat_msg msg;
+        if (slot.params.chat_parser) {
+            if (auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send)) {
+                msg = *opt_msg;
+            } else {
+                return;
+            }
+        } else {
+            msg.content = tkn.text_to_send;
         }
         auto res = std::make_unique<server_task_result_cmpl_partial>();
 
@@ -2267,7 +2259,7 @@ struct server_context {
         res->oaicompat = slot.params.oaicompat;
         res->oaicompat_model = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-        res->oaicompat_chat_msg = *opt_msg;
+        res->oaicompat_chat_msg = msg;
 
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
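The partial-response hunks above make the chat parser optional: with a parser, a token may produce no output yet (the parser keeps buffering); without one, the raw token text is forwarded as content. A toy model of that control flow, with partial_parser and next_chunk as made-up names (the real object is whatever slot.params.chat_parser points to):

    #include <optional>
    #include <string>

    // Stand-in parser: parse_partial() buffers input and only returns a piece once it
    // has something well-formed to emit.
    struct partial_parser {
        std::string buffer;

        std::optional<std::string> parse_partial(const std::string & piece) {
            buffer += piece;
            if (buffer.size() < 8) {
                return std::nullopt;   // pretend we are still inside an unfinished tool call
            }
            std::string out;
            out.swap(buffer);
            return out;
        }
    };

    // Mirrors the branching in send_partial_response(): no parser -> pass the text
    // through unchanged; parser present but not ready -> emit nothing for this token.
    static std::optional<std::string> next_chunk(partial_parser * parser, const std::string & piece) {
        if (!parser) {
            return piece;
        }
        return parser->parse_partial(piece);
    }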
@@ -2308,7 +2300,11 @@ struct server_context {
         res->oaicompat = slot.params.oaicompat;
         res->oaicompat_model = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-        res->oaicompat_chat_msg = slot.params.chat_parser->parse_final(slot.generated_text);
+        res->oaicompat_chat_msg = slot.params.chat_parser ? slot.params.chat_parser->parse_final(slot.generated_text) : common_chat_msg {
+            /* .role       = */ "assistant",
+            /* .content    = */ slot.generated_text,
+            /* .tool_calls = */ {}
+        };
 
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
@@ -3773,7 +3769,7 @@ int main(int argc, char ** argv) {
             std::function<bool()> is_connection_closed,
             httplib::Response & res,
             oaicompat_type oaicompat,
-            const common_chat_template * tmpl) {
+            const common_chat_template * tmpl = nullptr) {
         GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
 
         if (ctx_server.params_base.embedding) {
@@ -3785,21 +3781,29 @@ int main(int argc, char ** argv) {
         std::vector<server_task> tasks;
 
         try {
-            fprintf(stderr, "PROMPT: %s\n", data.at("prompt").get<std::string>().c_str());
-            std::string prompt;
+            common_chat_data chat_data;
             if (tmpl && ctx_server.params_base.use_jinja) {
-                auto chat_data = common_chat_init(*tmpl, {
-                    /* .messages = */ json_data(data, "messages", json::array()),
-                    /* .tools = */ json_data(data, "tools", json()),
-                    /
+                chat_data = common_chat_init(*tmpl, {
+                    /* .messages = */ json_value(data, "messages", json::array()),
+                    /* .tools = */ json_value(data, "tools", json()),
+                    /* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
+                    /* .json_schema = */ json_value(data, "json_schema", json()),
+                    /* .parallel_tool_calls = */ json_value(data, "parallel_tool_calls", false),
+                    /* .stream = */ json_value(data, "stream", false),
+                    /* .grammar = */ json_value(data, "grammar", std::string("")),
                 });
-
-                prompt = ctx_server.chat_templates.template_default->render(data.at("prompt").get<std::string>());
+                if (data.contains("grammar")) {
+                    chat_data.grammar = data.at("grammar");
+                }
             } else {
-                prompt = data.at("prompt").get<std::string>();
+                chat_data.prompt = data.at("prompt");
+                if (data.contains("grammar")) {
+                    chat_data.grammar = data.at("grammar");
+                } else if (data.contains("json_schema")) {
+                    chat_data.grammar = json_schema_to_grammar(data.at("json_schema"));
+                }
             }
-            task.params.chat_parser = common_chat_init()
-            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
+            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, true, true);
             tasks.reserve(tokenized_prompts.size());
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                 server_task task = server_task(type);
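For orientation, the fields of chat_data that this diff reads or writes are prompt, grammar, grammar_triggers, additional_stops, and parser. The sketch below is only an inferred shape with made-up names; the real common_chat_data returned by common_chat_init() lives in llama.cpp's common library and may hold more, or differently typed, members.

    #include <memory>
    #include <string>
    #include <vector>

    struct chat_parser_sketch;           // per-request parser, cloned per tokenized prompt

    struct grammar_trigger_sketch {
        std::string word;                // text (or single token) that activates the grammar
    };

    struct chat_data_sketch {
        std::string                         prompt;            // rendered template or raw prompt
        std::string                         grammar;           // GBNF grammar, possibly derived from json_schema
        std::vector<grammar_trigger_sketch> grammar_triggers;  // fed into sampling.grammar_trigger_{tokens,words}
        std::vector<std::string>            additional_stops;  // copied into task.params.antiprompt
        std::shared_ptr<chat_parser_sketch> parser;            // moved into the last task, cloned for the rest
    };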
@@ -3811,16 +3815,27 @@ int main(int argc, char ** argv) {
                 task.params = server_task::params_from_json_cmpl(
                         ctx_server.ctx,
                         ctx_server.params_base,
-                        nullptr,
                         data);
                 task.id_selected_slot = json_value(data, "id_slot", -1);
 
                 // OAI-compat
                 task.params.oaicompat = oaicompat;
                 task.params.oaicompat_cmpl_id = completion_id;
-                task.params.chat_parser = common_chat_init()
-                task.params.oaicompat_tools = json_value(data, "tools", json());
-                task.params.oaicompat_tool_call_style = tool_call_style;
+                task.params.sampling.grammar = chat_data.grammar;
+                for (const auto & trigger : chat_data.grammar_triggers) {
+                    auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
+                    if (ids.size() == 1) {
+                        LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
+                        task.params.sampling.grammar_trigger_tokens.push_back(ids[0]);
+                        continue;
+                    }
+                    LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
+                    task.params.sampling.grammar_trigger_words.push_back(trigger);
+                }
+                task.params.antiprompt = chat_data.additional_stops;
+                if (chat_data.parser) {
+                    task.params.chat_parser = i == tokenized_prompts.size() - 1 ? std::move(chat_data.parser) : chat_data.parser->clone();
+                }
                 // oaicompat_model is already populated by params_from_json_cmpl
 
                 tasks.push_back(task);
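One parser instance has to end up in each task, which is why the loop above clones the parser for every tokenized prompt and moves the original into the last one. A minimal sketch of that ownership pattern, using the hypothetical names parser_sketch and distribute_parser:

    #include <memory>
    #include <vector>

    struct parser_sketch {
        parser_sketch() = default;
        parser_sketch(const parser_sketch &) = default;
        virtual ~parser_sketch() = default;

        virtual std::unique_ptr<parser_sketch> clone() const {
            return std::make_unique<parser_sketch>(*this);
        }
    };

    // Give each of n_tasks tasks its own parser: clone for all but the last task,
    // which takes ownership of the original (assumes parser != nullptr when n_tasks > 0).
    static std::vector<std::unique_ptr<parser_sketch>> distribute_parser(
            std::unique_ptr<parser_sketch> parser, size_t n_tasks) {
        std::vector<std::unique_ptr<parser_sketch>> out;
        out.reserve(n_tasks);
        for (size_t i = 0; i < n_tasks; i++) {
            out.push_back(i + 1 == n_tasks ? std::move(parser) : parser->clone());
        }
        return out;
    }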
@@ -4005,7 +4020,8 @@ int main(int argc, char ** argv) {
             data,
             req.is_connection_closed,
             res,
-            OAICOMPAT_TYPE_CHAT);
+            OAICOMPAT_TYPE_CHAT,
+            &chat_template);
     };
 
     const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {