Skip to content

Commit 66302c8

Browse files
pass all parameters including chat_history to run tools (#2714) (#2723)
Signed-off-by: Jing Zhang <[email protected]> (cherry picked from commit cf9ed90) Co-authored-by: Jing Zhang <[email protected]>
1 parent 831fc95 commit 66302c8

File tree

2 files changed

+59
-4
lines changed

2 files changed

+59
-4
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,10 @@ private static void runTool(
473473
llmToolTmpParameters.put(MLAgentExecutor.QUESTION, actionInput);
474474
tools.get(action).run(llmToolTmpParameters, toolListener); // run tool
475475
} else {
476-
tools.get(action).run(toolParams, toolListener); // run tool
476+
Map<String, String> parameters = new HashMap<>();
477+
parameters.putAll(tmpParameters);
478+
parameters.putAll(toolParams);
479+
tools.get(action).run(parameters, toolListener); // run tool
477480
}
478481
} catch (Exception e) {
479482
nextStepListener

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ public class MLChatAgentRunnerTest {
121121
@Captor
122122
private ArgumentCaptor<ActionListener<UpdateResponse>> mlMemoryManagerCapture;
123123
@Captor
124-
private ArgumentCaptor<Map<String, String>> ToolParamsCapture;
124+
private ArgumentCaptor<Map<String, String>> toolParamsCapture;
125125

126126
@Before
127127
@SuppressWarnings("unchecked")
@@ -706,7 +706,7 @@ public void testToolParameters() {
706706
// Verify the size of parameters passed in the tool run method.
707707
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
708708
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
709-
assertEquals(3, ((Map) argumentCaptor.getValue()).size());
709+
assertEquals(14, ((Map) argumentCaptor.getValue()).size());
710710

711711
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
712712
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
@@ -734,7 +734,7 @@ public void testToolUseOriginalInput() {
734734
// Verify the size of parameters passed in the tool run method.
735735
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
736736
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
737-
assertEquals(3, ((Map) argumentCaptor.getValue()).size());
737+
assertEquals(15, ((Map) argumentCaptor.getValue()).size());
738738
assertEquals("raw input", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));
739739

740740
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
@@ -767,6 +767,58 @@ public void testSaveLastTraceFailure() {
767767
Mockito.verify(agentActionListener).onFailure(any(IllegalArgumentException.class));
768768
}
769769

770+
@Test
771+
public void testToolExecutionWithChatHistoryParameter() {
772+
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build();
773+
MLToolSpec firstToolSpec = MLToolSpec
774+
.builder()
775+
.name(FIRST_TOOL)
776+
.parameters(Map.of("firsttoolspec", "firsttoolspec"))
777+
.description("first tool spec")
778+
.type(FIRST_TOOL)
779+
.includeOutputInAgentResponse(false)
780+
.build();
781+
MLToolSpec secondToolSpec = MLToolSpec
782+
.builder()
783+
.name(SECOND_TOOL)
784+
.parameters(Map.of("secondtoolspec", "secondtoolspec"))
785+
.description("second tool spec")
786+
.type(SECOND_TOOL)
787+
.includeOutputInAgentResponse(true)
788+
.build();
789+
final MLAgent mlAgent = MLAgent
790+
.builder()
791+
.name("TestAgent")
792+
.type(MLAgentType.CONVERSATIONAL.name())
793+
.memory(mlMemorySpec)
794+
.llm(llmSpec)
795+
.description("mlagent description")
796+
.tools(Arrays.asList(firstToolSpec, secondToolSpec))
797+
.build();
798+
799+
doAnswer(invocation -> {
800+
ActionListener<List<Interaction>> listener = invocation.getArgument(0);
801+
List<Interaction> interactionList = generateInteractions(2);
802+
Interaction inProgressInteraction = Interaction.builder().id("interaction-99").input("input-99").response(null).build();
803+
interactionList.add(inProgressInteraction);
804+
listener.onResponse(interactionList);
805+
return null;
806+
}).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), messageHistoryLimitCapture.capture());
807+
808+
doAnswer(generateToolResponse("First tool response"))
809+
.when(firstTool)
810+
.run(toolParamsCapture.capture(), toolListenerCaptor.capture());
811+
812+
HashMap<String, String> params = new HashMap<>();
813+
params.put(MESSAGE_HISTORY_LIMIT, "5");
814+
mlChatAgentRunner.run(mlAgent, params, agentActionListener);
815+
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
816+
String chatHistory = params.get(MLChatAgentRunner.CHAT_HISTORY);
817+
Assert.assertFalse(chatHistory.contains("input-99"));
818+
Assert.assertEquals(5, messageHistoryLimitCapture.getValue().intValue());
819+
Assert.assertTrue(toolParamsCapture.getValue().containsKey(MLChatAgentRunner.CHAT_HISTORY));
820+
}
821+
770822
// Helper methods to create MLAgent and parameters
771823
private MLAgent createMLAgentWithTools() {
772824
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();

0 commit comments

Comments
 (0)