@@ -16,6 +16,7 @@ std::string common_chat_format_name(common_chat_format format) {
1616        case  COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return  " Functionary v3.2"  ;
1717        case  COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return  " Functionary v3.1 Llama 3.1"  ;
1818        case  COMMON_CHAT_FORMAT_HERMES_2_PRO: return  " Hermes 2 Pro"  ;
19+         case  COMMON_CHAT_FORMAT_COMMAND_R7B: return  " Command R7B"  ;
1920        default :
2021            throw  std::runtime_error (" Unknown chat format"  );
2122    }
@@ -317,6 +318,79 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
317318    return  parse_prefixed_json_tool_call_array (input, " [TOOL_CALLS]"  );
318319}
319320
321+ static  common_chat_params common_chat_params_init_command_r7b (const  common_chat_template & tmpl, const  struct  common_chat_inputs  & inputs) {
322+     common_chat_params data;
323+     data.grammar_lazy  = inputs.tool_choice  != " required"  ;
324+     data.grammar  = build_grammar ([&](const  common_grammar_builder & builder) {
325+         auto  schemas = json::array ();
326+         foreach_function (inputs.tools , [&](const  json & tool) {
327+             const  auto  & function = tool[" function"  ];
328+             schemas.push_back ({
329+                 {" type"  , " object"  },
330+                 {" properties"  , {
331+                     {" tool_call_id"  , {
332+                         {" type"  , " string"  },
333+                         //  Command-R's template expects an integer string.
334+                         {" pattern"  , " ^[0-9]{1,10}$"  },
335+                     }},
336+                     {" tool_name"  , {
337+                         {" type"  , " string"  },
338+                         {" const"  , function[" name"  ]},
339+                     }},
340+                     {" parameters"  , function[" parameters"  ]},
341+                 }},
342+                 {" required"  , json::array ({" tool_call_id"  , " tool_name"  , " parameters"  })},
343+             });
344+         });
345+         auto  schema = json {
346+             {" type"  , " array"  },
347+             {" items"  , schemas.size () == 1  ? schemas[0 ] : json {{" anyOf"  , schemas}}},
348+             {" minItems"  , 1 },
349+         };
350+         if  (!inputs.parallel_tool_calls ) {
351+             schema[" maxItems"  ] = 1 ;
352+         }
353+         builder.add_rule (" root"  , " \" <|START_ACTION|>\"  "   + builder.add_schema (" tool_calls"  , schema) + "  \" <|END_ACTION|>\" "  );
354+     }, grammar_options);
355+     data.grammar_triggers .push_back ({" <|START_ACTION|>"  , /*  .at_start = */   false });
356+     data.preserved_tokens  = {
357+         " <|START_RESPONSE|>"  ,
358+         " <|END_RESPONSE|>"  ,
359+         " <|START_THINKING|>"  ,
360+         " <|END_THINKING|>"  ,
361+         " <|END_ACTION|>"  ,
362+     };
363+     data.prompt  = tmpl.apply (inputs.messages , inputs.tools .empty () ? json () : inputs.tools , inputs.add_generation_prompt );
364+     data.format  = COMMON_CHAT_FORMAT_COMMAND_R7B;
365+     return  data;
366+ }
367+ static  common_chat_msg common_chat_parse_command_r7b (const  std::string & input) {
368+     static  std::regex response_regex (" <\\ |START_RESPONSE\\ |>(.*?)<\\ |END_RESPONSE\\ |>"  );
369+     static  std::regex thought_action_regex (" <\\ |START_THINKING\\ |>([\\ s\\ S\\ n\\ r]*?)<\\ |END_THINKING\\ |><\\ |START_ACTION\\ |>([\\ s\\ S\\ n\\ r]*?)<\\ |END_ACTION\\ |>"  );
370+     std::smatch match;
371+ 
372+     common_chat_msg result;
373+     result.role  = " assistant"  ;
374+     if  (std::regex_match (input, match, response_regex)) {
375+         result.content  = match[1 ].str ();
376+     } else  if  (std::regex_match (input, match, thought_action_regex)) {
377+         result.tool_plan  = match[1 ].str ();
378+         auto  actions_str = match[2 ].str ();
379+         auto  actions = json::parse (actions_str);
380+         for  (const  auto  & action : actions) {
381+             result.tool_calls .push_back ({
382+                 /*  .name = */        action[" tool_name"  ],
383+                 /*  .arguments = */   action[" parameters"  ].dump (),
384+                 /*  .id = */          action[" tool_call_id"  ],
385+             });
386+         }
387+     } else  {
388+         LOG_ERR (" Failed to parse command_r output"  );
389+         result.content  = input;
390+     }
391+     return  result;
392+ }
393+ 
320394static  void  expect_tool_parameters (const  std::string & name, const  json & parameters, const  std::vector<std::string> & expected_properties) {
321395    if  (!parameters.is_object () || !parameters.contains (" type"  ) || parameters[" type"  ] != " object"   || !parameters.contains (" properties"  ) || !parameters.contains (" required"  )) {
322396        throw  std::runtime_error (" Parameters of tool "   + name + "  must be an object w/ required properties"  );
@@ -462,6 +536,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
462536                " \" <|tool▁call▁begin|>function<|tool▁sep|>"   + name + " \\ n```json\\ n\"  "   + args_rule + "  \" ```<|tool▁call▁end|>\" "  ));
463537        });
464538        data.grammar_triggers .push_back ({" <|tool▁calls▁begin|>"  , /*  .at_start = */   false });
539+         data.preserved_tokens  = {
540+             " <|tool▁sep|>"  ,
541+             " <|tool▁call▁end|>"  ,
542+         };
465543        builder.add_rule (" root"  , " \" <|tool▁calls▁begin|>\"  ("   + string_join (tool_rules, "  | "  ) + " )"   + (inputs.parallel_tool_calls  ? " *"   : " "  ) + "  space"  );
466544    }, grammar_options);
467545    data.prompt  = tmpl.apply (inputs.messages , inputs.tools .empty () ? json () : inputs.tools , inputs.add_generation_prompt );
@@ -704,8 +782,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
704782        auto  tool_call = " \" <tool_call>\"  space "   + builder.add_rule (" tool_call"  , string_join (tool_rules, "  | "  )) + "  \" </tool_call>\"  space"  ;
705783        builder.add_rule (" root"  , inputs.parallel_tool_calls  ? " ("   + tool_call + " )+"   : tool_call);
706784        data.grammar_triggers .push_back ({" <tool_call>"  , /*  .at_start = */   false });
707-         //  Not really a trigger but need to print this special token to get a successful parse.
708-         data.grammar_triggers .push_back ({" </tool_call>"  , /*  .at_start = */   false });
785+         data.preserved_tokens  = { " </tool_call>"   };
709786    }, grammar_options);
710787
711788    data.prompt  = tmpl.apply (inputs.messages , inputs.tools .empty () ? json () : inputs.tools , inputs.add_generation_prompt );
@@ -822,6 +899,9 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
822899    if  (src.find (" [TOOL_CALLS]"  ) != std::string::npos) {
823900        return  common_chat_params_init_mistral_nemo (tmpl, inputs);
824901    }
902+     if  (src.find (" <|END_THINKING|><|START_ACTION|>"  ) != std::string::npos) {
903+         return  common_chat_params_init_command_r7b (tmpl, inputs);
904+     }
825905    return  common_chat_params_init_generic (tmpl, inputs);
826906}
827907
@@ -855,6 +935,8 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
855935            return  common_chat_parse_hermes_2_pro (input);
856936        case  COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
857937            return  common_chat_parse_firefunction_v2 (input);
938+         case  COMMON_CHAT_FORMAT_COMMAND_R7B:
939+             return  common_chat_parse_command_r7b (input);
858940        default :
859941            throw  std::runtime_error (" Unsupported format: "   + common_chat_format_name (format));
860942    }
0 commit comments