diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index f785ffb3ba..526648940b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -34,6 +34,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseLLMOutput; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.substitute; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.CHAT_HISTORY_PREFIX; +import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.SUMMARY_PROMPT_TEMPLATE; import java.security.PrivilegedActionException; import java.util.ArrayList; @@ -125,9 +126,11 @@ public class MLChatAgentRunner implements MLAgentRunner { public static final String INJECT_DATETIME_FIELD = "inject_datetime"; public static final String DATETIME_FORMAT_FIELD = "datetime_format"; public static final String SYSTEM_PROMPT_FIELD = "system_prompt"; + public static final String SUMMARIZE_WHEN_MAX_ITERATION = "summarize_when_max_iteration"; private static final String DEFAULT_MAX_ITERATIONS = "10"; private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task"; + private static final String MAX_ITERATIONS_SUMMARY_MESSAGE = MAX_ITERATIONS_MESSAGE + ". Here's a summary of the steps taken:\n\n%s"; private Client client; private Settings settings; @@ -322,7 +325,6 @@ private void runReAct( StringBuilder scratchpadBuilder = new StringBuilder(); List interactions = new CopyOnWriteArrayList<>(); - StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); tmpParameters.put(PROMPT, newPrompt.get()); @@ -414,7 +416,10 @@ private void runReAct( additionalInfo, lastThought, maxIterations, - tools + tools, + tmpParameters, + llm, + tenantId ); return; } @@ -513,7 +518,10 @@ private void runReAct( additionalInfo, lastThought, maxIterations, - tools + tools, + tmpParameters, + llm, + tenantId ); return; } @@ -885,11 +893,70 @@ private void handleMaxIterationsReached( Map additionalInfo, AtomicReference lastThought, int maxIterations, + Map tools, + Map parameters, + LLMSpec llmSpec, + String tenantId + ) { + boolean shouldSummarize = Boolean.parseBoolean(parameters.getOrDefault(SUMMARIZE_WHEN_MAX_ITERATION, "false")); + + if (shouldSummarize && !traceTensors.isEmpty()) { + generateLLMSummary(traceTensors, llmSpec, tenantId, ActionListener.wrap(summary -> { + String summaryResponse = String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); + sendTraditionalMaxIterationsResponse( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + conversationIndexMemory, + traceNumber, + additionalInfo, + summaryResponse, + tools + ); + }, e -> { + log.error("Failed to generate LLM summary", e); + listener.onFailure(e); + cleanUpResource(tools); + })); + } else { + String response = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) + ? String.format("%s. 
Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) + : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); + sendTraditionalMaxIterationsResponse( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + conversationIndexMemory, + traceNumber, + additionalInfo, + response, + tools + ); + } + } + + private void sendTraditionalMaxIterationsResponse( + String sessionId, + ActionListener listener, + String question, + String parentInteractionId, + boolean verbose, + boolean traceDisabled, + List traceTensors, + ConversationIndexMemory conversationIndexMemory, + AtomicInteger traceNumber, + Map additionalInfo, + String response, Map tools ) { - String incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) - ? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) - : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); sendFinalAnswer( sessionId, listener, @@ -901,11 +968,86 @@ private void handleMaxIterationsReached( conversationIndexMemory, traceNumber, additionalInfo, - incompleteResponse + response ); cleanUpResource(tools); } + void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String tenantId, ActionListener listener) { + if (stepsSummary == null || stepsSummary.isEmpty()) { + listener.onFailure(new IllegalArgumentException("Steps summary cannot be null or empty")); + return; + } + + try { + Map summaryParams = new HashMap<>(); + if (llmSpec.getParameters() != null) { + summaryParams.putAll(llmSpec.getParameters()); + } + + // Convert ModelTensors to strings before joining + List stepStrings = new ArrayList<>(); + for (ModelTensors tensor : stepsSummary) { + if (tensor != null && tensor.getMlModelTensors() != null) { + for (ModelTensor modelTensor : tensor.getMlModelTensors()) { + if (modelTensor.getResult() != null) { + stepStrings.add(modelTensor.getResult()); + } else if (modelTensor.getDataAsMap() != null && modelTensor.getDataAsMap().containsKey("response")) { + stepStrings.add(String.valueOf(modelTensor.getDataAsMap().get("response"))); + } + } + } + } + String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); + summaryParams.put(PROMPT, summaryPrompt); + summaryParams.putIfAbsent(SYSTEM_PROMPT_FIELD, ""); + + ActionRequest request = new MLPredictionTaskRequest( + llmSpec.getModelId(), + RemoteInferenceMLInput + .builder() + .algorithm(FunctionName.REMOTE) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(summaryParams).build()) + .build(), + null, + tenantId + ); + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> { + String summary = extractSummaryFromResponse(response); + if (summary != null) { + listener.onResponse(summary); + } else { + listener.onFailure(new RuntimeException("Empty or invalid LLM summary response")); + } + }, listener::onFailure)); + } catch (Exception e) { + listener.onFailure(e); + } + } + + public String extractSummaryFromResponse(MLTaskResponse response) { + try { + String outputString = outputToOutputString(response.getOutput()); + if (outputString != null && !outputString.trim().isEmpty()) { + try { + Map dataMap = gson.fromJson(outputString, Map.class); + if (dataMap.containsKey("response")) { + String summary = String.valueOf(dataMap.get("response")); + if (summary != null && !summary.trim().isEmpty() && !"null".equals(summary)) { + 
return summary.trim(); + } + } + } catch (Exception jsonException) { + return outputString.trim(); + } + } + return null; + } catch (Exception e) { + log.error("Failed to extract summary from response", e); + throw new RuntimeException("Failed to extract summary from response", e); + } + } + private void saveMessage( ConversationIndexMemory memory, String question, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java index 67b29c3557..f1b1147a7d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java @@ -140,4 +140,7 @@ public class PromptTemplate { - Avoid making assumptions and relying on implicit knowledge. - Your response must be self-contained and ready for the planner to use without modification. Never end with a question. - Break complex searches into simpler queries when appropriate."""; + + public static final String SUMMARY_PROMPT_TEMPLATE = + "Please provide a concise summary of the following agent execution steps. Focus on what the agent was trying to accomplish and what progress was made:\n\n%s\n\nPlease respond in the following JSON format:\n{\"response\": \"your summary here\"}"; } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index bae59994c5..c3c1063ab3 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1118,4 +1118,145 @@ public void testConstructLLMParams_DefaultValues() { Assert.assertTrue(result.containsKey(AgentUtils.RESPONSE_FORMAT_INSTRUCTION)); Assert.assertTrue(result.containsKey(AgentUtils.TOOL_RESPONSE)); } + + @Test + public void testMaxIterationsWithSummaryEnabled() { + // Create LLM spec with max_iteration = 1 to simplify test + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build(); + MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(mlMemorySpec) + .tools(Arrays.asList(firstToolSpec)) + .build(); + + // Reset and setup fresh mocks + Mockito.reset(client); + Mockito.reset(firstTool); + when(firstTool.getName()).thenReturn(FIRST_TOOL); + when(firstTool.validate(Mockito.anyMap())).thenReturn(true); + Mockito.doAnswer(generateToolResponse("First tool response")).when(firstTool).run(Mockito.anyMap(), any()); + + // First call: LLM response without final_answer to trigger max iterations + // Second call: Summary LLM response with result field instead of dataAsMap + Mockito + .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to analyze the data", "action", FIRST_TOOL))) + .doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor modelTensor = ModelTensor.builder().result("Summary: Analysis step was attempted").build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = 
ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + MLTaskResponse mlTaskResponse = MLTaskResponse.builder().output(mlModelTensorOutput).build(); + listener.onResponse(mlTaskResponse); + return null; + }) + .when(client) + .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + Map params = new HashMap<>(); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); + params.put(MLChatAgentRunner.SUMMARIZE_WHEN_MAX_ITERATION, "true"); + + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify response is captured + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object capturedResponse = objectCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); + assertEquals(1, agentOutput.size()); + + // Verify the response contains summary message + String response = (String) agentOutput.get(0).getDataAsMap().get("response"); + assertTrue( + response.startsWith("Agent reached maximum iterations (1) without completing the task. Here's a summary of the steps taken:") + ); + assertTrue(response.contains("Summary: Analysis step was attempted")); + } + + @Test + public void testMaxIterationsWithSummaryDisabled() { + // Create LLM spec with max_iteration = 1 and summary disabled + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build(); + MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(mlMemorySpec) + .tools(Arrays.asList(firstToolSpec)) + .build(); + + // Reset client mock for this test + Mockito.reset(client); + // Mock LLM response that doesn't contain final_answer + Mockito + .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to use the tool", "action", FIRST_TOOL))) + .when(client) + .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + Map params = new HashMap<>(); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); + params.put(MLChatAgentRunner.SUMMARIZE_WHEN_MAX_ITERATION, "false"); + + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify response is captured + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object capturedResponse = objectCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); + assertEquals(1, agentOutput.size()); + + // Verify the response contains traditional max iterations message + String response = (String) agentOutput.get(0).getDataAsMap().get("response"); + assertEquals("Agent reached maximum iterations (1) without completing the task. 
Last thought: I need to use the tool", response); + } + + @Test + public void testExtractSummaryFromResponse() { + MLTaskResponse response = MLTaskResponse + .builder() + .output( + ModelTensorOutput + .builder() + .mlModelOutputs( + Arrays + .asList( + ModelTensors + .builder() + .mlModelTensors( + Arrays + .asList( + ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "Valid summary text")).build() + ) + ) + .build() + ) + ) + .build() + ) + .build(); + + String result = mlChatAgentRunner.extractSummaryFromResponse(response); + assertEquals("Valid summary text", result); + } + + @Test + public void testGenerateLLMSummaryWithNullSteps() { + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + ActionListener listener = Mockito.mock(ActionListener.class); + + mlChatAgentRunner.generateLLMSummary(null, llmSpec, "tenant", listener); + + verify(listener).onFailure(any(IllegalArgumentException.class)); + } }
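
Usage sketch (drawn from the tests above, not part of the applied diff): the summary behavior is opt-in per request and is controlled by the new summarize_when_max_iteration parameter. When the ReAct loop exhausts max_iteration without reaching a final answer, the runner asks the LLM to summarize the collected trace tensors and responds with MAX_ITERATIONS_SUMMARY_MESSAGE; with the parameter absent or set to "false", the existing "Last thought: ..." response is preserved. A minimal caller sketch, assuming an MLAgent, MLChatAgentRunner, and ActionListener already wired up as in MLChatAgentRunnerTest:

    // Hypothetical caller; the constant names come from the fields added in this diff.
    // SUMMARIZE_WHEN_MAX_ITERATION defaults to "false" when the parameter is omitted.
    Map<String, String> params = new HashMap<>();
    params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
    params.put(MLChatAgentRunner.SUMMARIZE_WHEN_MAX_ITERATION, "true");
    mlChatAgentRunner.run(mlAgent, params, agentActionListener);

When summarization runs, generateLLMSummary accepts trace tensors that carry either a plain result string or a dataAsMap entry under the "response" key, and extractSummaryFromResponse falls back to the raw output string when the LLM reply is not valid JSON, as exercised by testMaxIterationsWithSummaryEnabled and testExtractSummaryFromResponse.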