@@ -169,18 +169,18 @@ struct delta_data {
169169};
170170
171171static delta_data init_delta (const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
172- common_chat_inputs params ;
173- params .parallel_tool_calls = true ;
174- params .messages = json::array ();
175- params .messages .push_back (user_message);
176- params .tools = tools;
177- auto prefix_data = common_chat_params_init (tmpl, params );
178- params .messages .push_back (delta_message);
179- params .add_generation_prompt = false ;
180- auto full_data = common_chat_params_init (tmpl, params );
181-
182- std::string prefix = prefix_data .prompt ;
183- std::string full = full_data .prompt ;
172+ common_chat_inputs inputs ;
173+ inputs .parallel_tool_calls = true ;
174+ inputs .messages = json::array ();
175+ inputs .messages .push_back (user_message);
176+ inputs .tools = tools;
177+ auto params_prefix = common_chat_params_init (tmpl, inputs );
178+ inputs .messages .push_back (delta_message);
179+ inputs .add_generation_prompt = false ;
180+ auto params_full = common_chat_params_init (tmpl, inputs );
181+
182+ std::string prefix = params_prefix .prompt ;
183+ std::string full = params_full .prompt ;
184184
185185 // Check full starts with prefix
186186 if (full.find (prefix) != 0 ) {
@@ -203,7 +203,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
203203 break ;
204204 }
205205 }
206- return {delta, full_data .grammar , full_data .parser };
206+ return {delta, params_full .grammar , params_full .parser };
207207}
208208
209209/*
@@ -220,12 +220,6 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
220220 };
221221
222222 for (const auto & tool_choice : json ({" auto" , " required" })) {
223- common_chat_inputs params;
224- params.tool_choice = tool_choice;
225- params.parallel_tool_calls = true ;
226- params.messages = json {user_message, test_message};
227- params.tools = tools;
228-
229223 auto data = init_delta (tmpl, end_tokens, user_message, test_message, tools);
230224 if (!expected_delta.empty ()) {
231225 assert_equals (expected_delta, data.delta );
@@ -309,17 +303,18 @@ static void test_template_output_parsers() {
309303 tools_params.tools .push_back (special_function_tool);
310304
311305 auto describe = [](const common_chat_template & tmpl, const common_chat_inputs & params) {
312- auto data = common_chat_params_init (tmpl, params);
313- return data.format ;
306+ return common_chat_params_init (tmpl, params).format ;
314307 };
315308
316309 {
317- const common_chat_template tmpl (read_file (" models/templates/google-gemma-2-2b-it.jinja" ), " <s>" , " </s>" );
310+ const common_chat_template tmpl (read_file (
311+ " models/templates/google-gemma-2-2b-it.jinja" ), " <s>" , " </s>" );
318312 std::vector<std::string> end_tokens { " <end_of_turn>" };
319313
320314 assert_equals (std::string (" content-only" ), describe (tmpl, no_tools_params));
321315 assert_equals (std::string (" generic tool calls" ), describe (tmpl, tools_params));
322- assert_equals (std::string (" generic tool calls" ), describe (common_chat_template (read_file (" models/templates/microsoft-Phi-3.5-mini-instruct.jinja" ), " <s>" , " </s>" ), tools_params));
316+ assert_equals (std::string (" generic tool calls" ), describe (common_chat_template (read_file (
317+ " models/templates/microsoft-Phi-3.5-mini-instruct.jinja" ), " <s>" , " </s>" ), tools_params));
323318
324319 // Generic tool calls doesn't generate / parse content-only messages symmetrically.
325320 assert_msg_equals (msg_from_json (text_message), common_chat_params_init (tmpl, tools_params).parser (
@@ -340,7 +335,8 @@ static void test_template_output_parsers() {
340335 " }" );
341336 }
342337 {
343- const common_chat_template tmpl (read_file (" models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja" ), " <s>" , " </s>" );
338+ const common_chat_template tmpl (read_file (
339+ " models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja" ), " <s>" , " </s>" );
344340 std::vector<std::string> end_tokens { " </s>" };
345341
346342 assert_equals (std::string (" mistral nemo tool calls" ), describe (tmpl, tools_params));
@@ -351,12 +347,15 @@ static void test_template_output_parsers() {
351347 /* skip_grammar_test= */ true );
352348 }
353349 {
354- const common_chat_template tmpl (read_file (" models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja" ), " <s>" , " </s>" );
350+ const common_chat_template tmpl (read_file (
351+ " models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja" ), " <s>" , " </s>" );
355352 std::vector<std::string> end_tokens { " <|im_end|>" };
356353
357354 assert_equals (std::string (" hermes 2 pro tool calls" ), describe (tmpl, tools_params));
358- assert_equals (std::string (" hermes 2 pro tool calls" ), describe (common_chat_template (read_file (" models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja" ), " <s>" , " </s>" ), tools_params));
359- assert_equals (std::string (" hermes 2 pro tool calls" ), describe (common_chat_template (read_file (" models/templates/Qwen-Qwen2.5-7B-Instruct.jinja" ), " <s>" , " </s>" ), tools_params));
355+ assert_equals (std::string (" hermes 2 pro tool calls" ), describe (common_chat_template (read_file (
356+ " models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja" ), " <s>" , " </s>" ), tools_params));
357+ assert_equals (std::string (" hermes 2 pro tool calls" ), describe (common_chat_template (read_file (
358+ " models/templates/Qwen-Qwen2.5-7B-Instruct.jinja" ), " <s>" , " </s>" ), tools_params));
360359
361360 test_template (tmpl, end_tokens, text_message, tools, " Hello, world!" , /* skip_grammar_test= */ true );
362361 test_template (tmpl, end_tokens, tool_call_message, tools,
@@ -369,11 +368,13 @@ static void test_template_output_parsers() {
369368 " </tool_call>" );
370369 }
371370 {
372- const common_chat_template tmpl (read_file (" models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja" ), " <s>" , " </s>" );
371+ const common_chat_template tmpl (read_file (
372+ " models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja" ), " <s>" , " </s>" );
373373 std::vector<std::string> end_tokens { " <|eom_id|>" , " <|eot_id|>" };
374374
375375 assert_equals (std::string (" llama 3.x tool calls (w/ builtin tools)" ), describe (tmpl, tools_params));
376- assert_equals (std::string (" llama 3.x tool calls (w/ builtin tools)" ), describe (common_chat_template (read_file (" models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja" ), " <s>" , " </s>" ), tools_params));
376+ assert_equals (std::string (" llama 3.x tool calls (w/ builtin tools)" ), describe (common_chat_template (read_file (
377+ " models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja" ), " <s>" , " </s>" ), tools_params));
377378
378379 // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
379380 test_template (tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
@@ -384,7 +385,8 @@ static void test_template_output_parsers() {
384385 " {\" name\" : \" special_function\" , \" parameters\" : {\" arg1\" : 1}}" );
385386 }
386387 {
387- const common_chat_template tmpl (read_file (" models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja" ), " <s>" , " </s>" );
388+ const common_chat_template tmpl (read_file (
389+ " models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja" ), " <s>" , " </s>" );
388390 std::vector<std::string> end_tokens { " <|eom_id|>" , " <|eot_id|>" };
389391
390392 assert_equals (std::string (" llama 3.x tool calls" ), describe (tmpl, tools_params));
@@ -395,7 +397,8 @@ static void test_template_output_parsers() {
395397 " {\" name\" : \" special_function\" , \" parameters\" : {\" arg1\" : 1}}" );
396398 }
397399 {
398- const common_chat_template tmpl (read_file (" models/templates/meetkai-functionary-medium-v3.1.jinja" ), " <s>" , " </s>" );
400+ const common_chat_template tmpl (read_file (
401+ " models/templates/meetkai-functionary-medium-v3.1.jinja" ), " <s>" , " </s>" );
399402 std::vector<std::string> end_tokens { " <|eom_id|>" , " <|eot_id|>" };
400403
401404 assert_equals (std::string (" functionary v3.1 llama 3.1 tool calls" ), describe (tmpl, tools_params));
@@ -406,7 +409,8 @@ static void test_template_output_parsers() {
406409 " <function=special_function>{\" arg1\" : 1}</function>" );
407410 }
408411 {
409- const common_chat_template tmpl (read_file (" models/templates/meetkai-functionary-medium-v3.2.jinja" ), " <s>" , " </s>" );
412+ const common_chat_template tmpl (read_file (
413+ " models/templates/meetkai-functionary-medium-v3.2.jinja" ), " <s>" , " </s>" );
410414 std::vector<std::string> end_tokens { " <|eom_id|>" , " <|eot_id|>" };
411415
412416 assert_equals (std::string (" functionary v3.2 content-only" ), describe (tmpl, no_tools_params));
@@ -420,7 +424,8 @@ static void test_template_output_parsers() {
420424 " {\" arg1\" : 1}" );
421425 }
422426 {
423- const common_chat_template tmpl (read_file (" models/templates/fireworks-ai-llama-3-firefunction-v2.jinja" ), " <s>" , " </s>" );
427+ const common_chat_template tmpl (read_file (
428+ " models/templates/fireworks-ai-llama-3-firefunction-v2.jinja" ), " <s>" , " </s>" );
424429 std::vector<std::string> end_tokens { " <|eot_id|>" };
425430
426431 assert_equals (std::string (" firefunction v2 tool calls" ), describe (tmpl, tools_params));
@@ -431,7 +436,8 @@ static void test_template_output_parsers() {
431436 " functools[{\" name\" : \" special_function\" , \" arguments\" : {\" arg1\" : 1}}]" );
432437 }
433438 {
434- const common_chat_template tmpl (read_file (" models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja" ), " <s>" , " </s>" );
439+ const common_chat_template tmpl (read_file (
440+ " models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja" ), " <s>" , " </s>" );
435441 std::vector<std::string> end_tokens { " <|end▁of▁sentence|>" };
436442
437443 assert_equals (std::string (" deepseek r1 tool calls" ), describe (tmpl, tools_params));
0 commit comments