@@ -121,7 +121,7 @@ public class MLChatAgentRunnerTest {
121121 @ Captor
122122 private ArgumentCaptor <ActionListener <UpdateResponse >> mlMemoryManagerCapture ;
123123 @ Captor
124- private ArgumentCaptor <Map <String , String >> ToolParamsCapture ;
124+ private ArgumentCaptor <Map <String , String >> toolParamsCapture ;
125125
126126 @ Before
127127 @ SuppressWarnings ("unchecked" )
@@ -706,7 +706,7 @@ public void testToolParameters() {
706706 // Verify the size of parameters passed in the tool run method.
707707 ArgumentCaptor argumentCaptor = ArgumentCaptor .forClass (Map .class );
708708 verify (firstTool ).run ((Map <String , String >) argumentCaptor .capture (), any ());
709- assertEquals (3 , ((Map ) argumentCaptor .getValue ()).size ());
709+ assertEquals (14 , ((Map ) argumentCaptor .getValue ()).size ());
710710
711711 Mockito .verify (agentActionListener ).onResponse (objectCaptor .capture ());
712712 ModelTensorOutput modelTensorOutput = (ModelTensorOutput ) objectCaptor .getValue ();
@@ -734,7 +734,7 @@ public void testToolUseOriginalInput() {
734734 // Verify the size of parameters passed in the tool run method.
735735 ArgumentCaptor argumentCaptor = ArgumentCaptor .forClass (Map .class );
736736 verify (firstTool ).run ((Map <String , String >) argumentCaptor .capture (), any ());
737- assertEquals (3 , ((Map ) argumentCaptor .getValue ()).size ());
737+ assertEquals (15 , ((Map ) argumentCaptor .getValue ()).size ());
738738 assertEquals ("raw input" , ((Map <?, ?>) argumentCaptor .getValue ()).get ("input" ));
739739
740740 Mockito .verify (agentActionListener ).onResponse (objectCaptor .capture ());
@@ -767,6 +767,58 @@ public void testSaveLastTraceFailure() {
767767 Mockito .verify (agentActionListener ).onFailure (any (IllegalArgumentException .class ));
768768 }
769769
770+ @ Test
771+ public void testToolExecutionWithChatHistoryParameter () {
772+ LLMSpec llmSpec = LLMSpec .builder ().modelId ("MODEL_ID" ).parameters (Map .of ("max_iteration" , "1" )).build ();
773+ MLToolSpec firstToolSpec = MLToolSpec
774+ .builder ()
775+ .name (FIRST_TOOL )
776+ .parameters (Map .of ("firsttoolspec" , "firsttoolspec" ))
777+ .description ("first tool spec" )
778+ .type (FIRST_TOOL )
779+ .includeOutputInAgentResponse (false )
780+ .build ();
781+ MLToolSpec secondToolSpec = MLToolSpec
782+ .builder ()
783+ .name (SECOND_TOOL )
784+ .parameters (Map .of ("secondtoolspec" , "secondtoolspec" ))
785+ .description ("second tool spec" )
786+ .type (SECOND_TOOL )
787+ .includeOutputInAgentResponse (true )
788+ .build ();
789+ final MLAgent mlAgent = MLAgent
790+ .builder ()
791+ .name ("TestAgent" )
792+ .type (MLAgentType .CONVERSATIONAL .name ())
793+ .memory (mlMemorySpec )
794+ .llm (llmSpec )
795+ .description ("mlagent description" )
796+ .tools (Arrays .asList (firstToolSpec , secondToolSpec ))
797+ .build ();
798+
799+ doAnswer (invocation -> {
800+ ActionListener <List <Interaction >> listener = invocation .getArgument (0 );
801+ List <Interaction > interactionList = generateInteractions (2 );
802+ Interaction inProgressInteraction = Interaction .builder ().id ("interaction-99" ).input ("input-99" ).response (null ).build ();
803+ interactionList .add (inProgressInteraction );
804+ listener .onResponse (interactionList );
805+ return null ;
806+ }).when (conversationIndexMemory ).getMessages (memoryInteractionCapture .capture (), messageHistoryLimitCapture .capture ());
807+
808+ doAnswer (generateToolResponse ("First tool response" ))
809+ .when (firstTool )
810+ .run (toolParamsCapture .capture (), toolListenerCaptor .capture ());
811+
812+ HashMap <String , String > params = new HashMap <>();
813+ params .put (MESSAGE_HISTORY_LIMIT , "5" );
814+ mlChatAgentRunner .run (mlAgent , params , agentActionListener );
815+ Mockito .verify (agentActionListener ).onResponse (objectCaptor .capture ());
816+ String chatHistory = params .get (MLChatAgentRunner .CHAT_HISTORY );
817+ Assert .assertFalse (chatHistory .contains ("input-99" ));
818+ Assert .assertEquals (5 , messageHistoryLimitCapture .getValue ().intValue ());
819+ Assert .assertTrue (toolParamsCapture .getValue ().containsKey (MLChatAgentRunner .CHAT_HISTORY ));
820+ }
821+
770822 // Helper methods to create MLAgent and parameters
771823 private MLAgent createMLAgentWithTools () {
772824 LLMSpec llmSpec = LLMSpec .builder ().modelId ("MODEL_ID" ).build ();
0 commit comments