 import static org.opensearch.ml.common.utils.StringUtils.processTextDoc;
 import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE;
 import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.INTERACTIONS_PREFIX;
-import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_PATH;
-import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_TOOL_USE;
-import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER;
-import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.NO_ESCAPE_PARAMS;
 import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_CHAT_HISTORY_PREFIX;
 import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX;
 import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX;
 import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.RESPONSE_FORMAT_INSTRUCTION;
-import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_PATH;
-import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_TOOL_INPUT;
-import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_TOOL_NAME;
 import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID;
-import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID_PATH;
 import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_RESPONSE;
-import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_TEMPLATE;
+import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_RESULT;
 import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.VERBOSE;
 import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.cleanUpResource;
 import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.constructToolParams;
 import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
 import org.opensearch.ml.common.utils.StringUtils;
 import org.opensearch.ml.engine.encryptor.Encryptor;
+import org.opensearch.ml.engine.function_calling.FunctionCalling;
+import org.opensearch.ml.engine.function_calling.FunctionCallingFactory;
+import org.opensearch.ml.engine.function_calling.LLMMessage;
 import org.opensearch.ml.engine.memory.ConversationIndexMemory;
 import org.opensearch.ml.engine.memory.ConversationIndexMessage;
 import org.opensearch.ml.engine.tools.MLModelTool;
@@ -117,7 +112,6 @@ public class MLChatAgentRunner implements MLAgentRunner {
     public static final String FINAL_ANSWER = "final_answer";
     public static final String THOUGHT_RESPONSE = "thought_response";
     public static final String INTERACTIONS = "_interactions";
-    public static final String DEFAULT_NO_ESCAPE_PARAMS = "_chat_history,_tools,_interactions,tool_configs";
     public static final String INTERACTION_TEMPLATE_TOOL_RESPONSE = "interaction_template.tool_response";
     public static final String CHAT_HISTORY_QUESTION_TEMPLATE = "chat_history_template.user_question";
     public static final String CHAT_HISTORY_RESPONSE_TEMPLATE = "chat_history_template.ai_response";
@@ -170,116 +164,11 @@ public void run(MLAgent mlAgent, Map<String, String> inputParams, ActionListener
         params.putAll(inputParams);

         String llmInterface = params.get(LLM_INTERFACE);
-        // todo: introduce function calling
-        // handle parameters based on llmInterface
-        if ("openai/v1/chat/completions".equalsIgnoreCase(llmInterface)) {
-            if (!params.containsKey(NO_ESCAPE_PARAMS)) {
-                params.put(NO_ESCAPE_PARAMS, DEFAULT_NO_ESCAPE_PARAMS);
-            }
-            params.put(LLM_RESPONSE_FILTER, "$.choices[0].message.content");
-
-            params
-                .put(
-                    TOOL_TEMPLATE,
-                    "{\"type\": \"function\", \"function\": { \"name\": \"${tool.name}\", \"description\": \"${tool.description}\", \"parameters\": ${tool.attributes.input_schema}, \"strict\": ${tool.attributes.strict:-false} } }"
-                );
-            params.put(TOOL_CALLS_PATH, "$.choices[0].message.tool_calls");
-            params.put(TOOL_CALLS_TOOL_NAME, "function.name");
-            params.put(TOOL_CALLS_TOOL_INPUT, "function.arguments");
-            params.put(TOOL_CALL_ID_PATH, "id");
-            params.put("tool_configs", ", \"tools\": [${parameters._tools:-}], \"parallel_tool_calls\": false");
-
-            params.put("tool_choice", "auto");
-            params.put("parallel_tool_calls", "false");
-
-            params.put("interaction_template.assistant_tool_calls_path", "$.choices[0].message");
-            params
-                .put(
-                    "interaction_template.tool_response",
-                    "{ \"role\": \"tool\", \"tool_call_id\": \"${_interactions.tool_call_id}\", \"content\": \"${_interactions.tool_response}\" }"
-                );
-
-            params.put("chat_history_template.user_question", "{\"role\": \"user\",\"content\": \"${_chat_history.message.question}\"}");
-            params.put("chat_history_template.ai_response", "{\"role\": \"assistant\",\"content\": \"${_chat_history.message.response}\"}");
-
-            params.put(LLM_FINISH_REASON_PATH, "$.choices[0].finish_reason");
-            params.put(LLM_FINISH_REASON_TOOL_USE, "tool_calls");
-        } else if ("bedrock/converse/claude".equalsIgnoreCase(llmInterface)) {
-            if (!params.containsKey(NO_ESCAPE_PARAMS)) {
-                params.put(NO_ESCAPE_PARAMS, DEFAULT_NO_ESCAPE_PARAMS);
-            }
-            params.put(LLM_RESPONSE_FILTER, "$.output.message.content[0].text");
-
-            params
-                .put(
-                    TOOL_TEMPLATE,
-                    "{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\": {\"json\": ${tool.attributes.input_schema} } }}"
-                );
-            params.put(TOOL_CALLS_PATH, "$.output.message.content[*].toolUse");
-            params.put(TOOL_CALLS_TOOL_NAME, "name");
-            params.put(TOOL_CALLS_TOOL_INPUT, "input");
-            params.put(TOOL_CALL_ID_PATH, "toolUseId");
-            params.put("tool_configs", ", \"toolConfig\": {\"tools\": [${parameters._tools:-}]}");
-
-            params.put("interaction_template.assistant_tool_calls_path", "$.output.message");
-            params
-                .put(
-                    "interaction_template.tool_response",
-                    "{\"role\":\"user\",\"content\":[{\"toolResult\":{\"toolUseId\":\"${_interactions.tool_call_id}\",\"content\":[{\"text\":\"${_interactions.tool_response}\"}]}}]}"
-                );
-
-            params
-                .put(
-                    "chat_history_template.user_question",
-                    "{\"role\":\"user\",\"content\":[{\"text\":\"${_chat_history.message.question}\"}]}"
-                );
-            params
-                .put(
-                    "chat_history_template.ai_response",
-                    "{\"role\":\"assistant\",\"content\":[{\"text\":\"${_chat_history.message.response}\"}]}"
-                );
-
-            params.put(LLM_FINISH_REASON_PATH, "$.stopReason");
-            params.put(LLM_FINISH_REASON_TOOL_USE, "tool_use");
-        } else if ("bedrock/converse/deepseek_r1".equalsIgnoreCase(llmInterface)) {
-            if (!params.containsKey(NO_ESCAPE_PARAMS)) {
-                params.put(NO_ESCAPE_PARAMS, "_chat_history,_interactions");
-            }
-            params.put(LLM_RESPONSE_FILTER, "$.output.message.content[0].text");
-            params.put("llm_final_response_post_filter", "$.message.content[0].text");
-
-            params
-                .put(
-                    TOOL_TEMPLATE,
-                    "{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\": {\"json\": ${tool.attributes.input_schema} } }}"
-                );
-            params.put(TOOL_CALLS_PATH, "_llm_response.tool_calls");
-            params.put(TOOL_CALLS_TOOL_NAME, "tool_name");
-            params.put(TOOL_CALLS_TOOL_INPUT, "input");
-            params.put(TOOL_CALL_ID_PATH, "id");
-
-            params.put("interaction_template.assistant_tool_calls_path", "$.output.message");
-            params.put("interaction_template.assistant_tool_calls_exclude_path", "[ \"$.output.message.content[?(@.reasoningContent)]\" ]");
-            params
-                .put(
-                    "interaction_template.tool_response",
-                    "{\"role\":\"user\",\"content\":[ {\"text\":\"{\\\"tool_call_id\\\":\\\"${_interactions.tool_call_id}\\\",\\\"tool_result\\\": \\\"${_interactions.tool_response}\\\"\"} ]}"
-                );
-
-            params
-                .put(
-                    "chat_history_template.user_question",
-                    "{\"role\":\"user\",\"content\":[{\"text\":\"${_chat_history.message.question}\"}]}"
-                );
-            params
-                .put(
-                    "chat_history_template.ai_response",
-                    "{\"role\":\"assistant\",\"content\":[{\"text\":\"${_chat_history.message.response}\"}]}"
-                );
-
-            params.put(LLM_FINISH_REASON_PATH, "_llm_response.stop_reason");
-            params.put(LLM_FINISH_REASON_TOOL_USE, "tool_use");
+        FunctionCalling functionCalling = FunctionCallingFactory.create(llmInterface);
+        if (functionCalling != null) {
+            functionCalling.configure(params);
         }
+
         String memoryType = mlAgent.getMemory().getType();
         String memoryId = params.get(MLAgentExecutor.MEMORY_ID);
         String appType = mlAgent.getAppType();
@@ -347,23 +236,30 @@ public void run(MLAgent mlAgent, Map<String, String> inputParams, ActionListener
                    }
                }

-                runAgent(mlAgent, params, listener, memory, memory.getConversationId());
+                runAgent(mlAgent, params, listener, memory, memory.getConversationId(), functionCalling);
             }, e -> {
                 log.error("Failed to get chat history", e);
                 listener.onFailure(e);
             }), messageHistoryLimit);
         }, listener::onFailure));
     }

-    private void runAgent(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener, Memory memory, String sessionId) {
+    private void runAgent(
+        MLAgent mlAgent,
+        Map<String, String> params,
+        ActionListener<Object> listener,
+        Memory memory,
+        String sessionId,
+        FunctionCalling functionCalling
+    ) {
         List<MLToolSpec> toolSpecs = getMlToolSpecs(mlAgent, params);

         // Create a common method to handle both success and failure cases
         Consumer<List<MLToolSpec>> processTools = (allToolSpecs) -> {
             Map<String, Tool> tools = new HashMap<>();
             Map<String, MLToolSpec> toolSpecMap = new HashMap<>();
             createTools(toolFactories, params, allToolSpecs, tools, toolSpecMap, mlAgent);
-            runReAct(mlAgent.getLlm(), tools, toolSpecMap, params, memory, sessionId, mlAgent.getTenantId(), listener);
+            runReAct(mlAgent.getLlm(), tools, toolSpecMap, params, memory, sessionId, mlAgent.getTenantId(), listener, functionCalling);
         };

         // Fetch MCP tools and handle both success and failure cases
@@ -384,7 +280,8 @@ private void runReAct(
         Memory memory,
         String sessionId,
         String tenantId,
-        ActionListener<Object> listener
+        ActionListener<Object> listener,
+        FunctionCalling functionCalling
     ) {
         Map<String, String> tmpParameters = constructLLMParams(llm, parameters);
         String prompt = constructLLMPrompt(tools, tmpParameters);
@@ -437,7 +334,8 @@ private void runReAct(
                         tmpModelTensorOutput,
                         llmResponsePatterns,
                         tools.keySet(),
-                        interactions
+                        interactions,
+                        functionCalling
                     );

                     String thought = String.valueOf(modelOutput.get(THOUGHT));
@@ -510,7 +408,8 @@ private void runReAct(
                         actionInput,
                         toolParams,
                         interactions,
-                        toolCallId
+                        toolCallId,
+                        functionCalling
                     );
                 } else {
                     String res = String.format(Locale.ROOT, "Failed to run the tool %s which is unsupported.", action);
@@ -675,20 +574,28 @@ private static void runTool(
         String actionInput,
         Map<String, String> toolParams,
         List<String> interactions,
-        String toolCallId
+        String toolCallId,
+        FunctionCalling functionCalling
     ) {
         if (tools.get(action).validate(toolParams)) {
             try {
                 String finalAction = action;
                 ActionListener<Object> toolListener = ActionListener.wrap(r -> {
-                    interactions
-                        .add(
-                            substitute(
-                                tmpParameters.get(INTERACTION_TEMPLATE_TOOL_RESPONSE),
-                                Map.of(TOOL_CALL_ID, toolCallId, "tool_response", processTextDoc(StringUtils.toJson(r))),
-                                INTERACTIONS_PREFIX
-                            )
-                        );
+                    if (functionCalling != null) {
+                        List<Map<String, Object>> toolResults = List.of(Map.of(TOOL_CALL_ID, toolCallId, TOOL_RESULT, Map.of("text", r)));
+                        List<LLMMessage> llmMessages = functionCalling.supply(toolResults);
+                        // TODO: support multiple tool calls at the same time so that multiple LLMMessages can be generated here
+                        interactions.add(llmMessages.getFirst().getResponse());
+                    } else {
+                        interactions
+                            .add(
+                                substitute(
+                                    tmpParameters.get(INTERACTION_TEMPLATE_TOOL_RESPONSE),
+                                    Map.of(TOOL_CALL_ID, toolCallId, "tool_response", processTextDoc(StringUtils.toJson(r))),
+                                    INTERACTIONS_PREFIX
+                                )
+                            );
+                    }
                     nextStepListener.onResponse(r);
                 }, e -> {
                     interactions
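
For reference, here is a minimal sketch of the FunctionCalling contract this diff programs against. Only the call sites above are authoritative (FunctionCallingFactory.create(llmInterface), configure(params), supply(toolResults), and LLMMessage.getResponse()); the signatures and comments below are an assumed reconstruction for illustration, not the actual types in org.opensearch.ml.engine.function_calling.

// Assumed sketch, inferred from the call sites in this diff.
import java.util.List;
import java.util.Map;

public interface FunctionCalling {
    // Injects the LLM-interface-specific parameters (response filter, tool template,
    // tool-call paths, finish-reason paths, interaction and chat-history templates)
    // that MLChatAgentRunner.run() previously hard-coded per llmInterface.
    void configure(Map<String, String> params);

    // Converts tool results (maps keyed by tool_call_id and tool_result) into
    // provider-specific messages that the runner appends to the interactions list.
    List<LLMMessage> supply(List<Map<String, Object>> toolResults);
}

interface LLMMessage {
    // Serialized message body added to interactions by runTool above.
    String getResponse();
}

Under this reading, FunctionCallingFactory.create(llmInterface) would return the implementation matching the configured llm_interface, or null so the runner falls back to the template-substitution path kept in the else branch of runTool.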