From af5ea63fcff650eb98c75c778782e8b501b02f9b Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Mon, 22 Sep 2025 14:51:51 +0800 Subject: [PATCH 01/13] init max step summary Signed-off-by: Jiaru Jiang --- build.gradle | 2 + .../algorithms/agent/MLChatAgentRunner.java | 149 +++++++++++++++++- .../algorithms/agent/PromptTemplate.java | 3 + .../agent/MLChatAgentRunnerTest.java | 89 +++++++++++ 4 files changed, 241 insertions(+), 2 deletions(-) diff --git a/build.gradle b/build.gradle index de007ef60e..7508ddd28b 100644 --- a/build.gradle +++ b/build.gradle @@ -45,6 +45,7 @@ buildscript { configurations.all { resolutionStrategy { force("org.eclipse.platform:org.eclipse.core.runtime:3.29.0") // for spotless transitive dependency CVE (for 3.26.100) + force("com.google.errorprone:error_prone_annotations:2.18.0") } } } @@ -96,6 +97,7 @@ subprojects { configurations.all { // Force spotless depending on newer version of guava due to CVE-2023-2976. Remove after spotless upgrades. resolutionStrategy.force "com.google.guava:guava:32.1.3-jre" + resolutionStrategy.force "com.google.errorprone:error_prone_annotations:2.18.0" resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0' } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index f785ffb3ba..b0781f50c9 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -34,6 +34,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseLLMOutput; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.substitute; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.CHAT_HISTORY_PREFIX; +import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.SUMMARY_PROMPT_TEMPLATE; import java.security.PrivilegedActionException; import java.util.ArrayList; @@ -125,9 +126,12 @@ public class MLChatAgentRunner implements MLAgentRunner { public static final String INJECT_DATETIME_FIELD = "inject_datetime"; public static final String DATETIME_FORMAT_FIELD = "datetime_format"; public static final String SYSTEM_PROMPT_FIELD = "system_prompt"; + public static final String SUMMARIZE_WHEN_MAX_ITERATION = "summarize_when_max_iteration"; private static final String DEFAULT_MAX_ITERATIONS = "10"; private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task"; + private static final String MAX_ITERATIONS_SUMMARY_MESSAGE = + "Agent reached maximum iterations (%d) without completing the task. Here's a summary of the steps taken:\n\n%s"; private Client client; private Settings settings; @@ -322,6 +326,7 @@ private void runReAct( StringBuilder scratchpadBuilder = new StringBuilder(); List interactions = new CopyOnWriteArrayList<>(); + List executionSteps = new CopyOnWriteArrayList<>(); StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); @@ -380,6 +385,17 @@ private void runReAct( lastActionInput.set(actionInput); lastToolSelectionResponse.set(thoughtResponse); + // Record execution step for summary + if (thought != null && !"null".equals(thought) && !thought.trim().isEmpty()) { + executionSteps.add(String.format("Thought: %s", thought.trim())); + } + if (action != null && !"null".equals(action) && !action.trim().isEmpty()) { + String actionDesc = actionInput != null && !"null".equals(actionInput) + ? String.format("Action: %s(%s)", action.trim(), actionInput.trim()) + : String.format("Action: %s", action.trim()); + executionSteps.add(actionDesc); + } + traceTensors .add( ModelTensors @@ -414,7 +430,11 @@ private void runReAct( additionalInfo, lastThought, maxIterations, - tools + tools, + tmpParameters, + executionSteps, + llm, + tenantId ); return; } @@ -466,6 +486,10 @@ private void runReAct( ); scratchpadBuilder.append(toolResponse).append("\n\n"); + // Record tool result for summary + String outputSummary = outputToOutputString(filteredOutput); + executionSteps.add(String.format("Result: %s", outputSummary)); + saveTraceData( conversationIndexMemory, "ReAct", @@ -513,7 +537,11 @@ private void runReAct( additionalInfo, lastThought, maxIterations, - tools + tools, + tmpParameters, + executionSteps, + llm, + tenantId ); return; } @@ -873,6 +901,65 @@ private static void returnFinalResponse( } private void handleMaxIterationsReached( + String sessionId, + ActionListener listener, + String question, + String parentInteractionId, + boolean verbose, + boolean traceDisabled, + List traceTensors, + ConversationIndexMemory conversationIndexMemory, + AtomicInteger traceNumber, + Map additionalInfo, + AtomicReference lastThought, + int maxIterations, + Map tools, + Map parameters, + List executionSteps, + LLMSpec llmSpec, + String tenantId + ) { + boolean shouldSummarize = Boolean.parseBoolean(parameters.getOrDefault(SUMMARIZE_WHEN_MAX_ITERATION, "false")); + + if (shouldSummarize && !executionSteps.isEmpty()) { + generateLLMSummary(executionSteps, llmSpec, tenantId, ActionListener.wrap(summary -> { + String incompleteResponse = String.format(MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); + sendFinalAnswer( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + conversationIndexMemory, + traceNumber, + additionalInfo, + incompleteResponse + ); + cleanUpResource(tools); + }, e -> { log.warn("Failed to generate LLM summary", e); })); + } else { + // Use traditional approach + sendTraditionalMaxIterationsResponse( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + conversationIndexMemory, + traceNumber, + additionalInfo, + lastThought, + maxIterations, + tools + ); + } + } + + private void sendTraditionalMaxIterationsResponse( String sessionId, ActionListener listener, String question, @@ -906,6 +993,64 @@ private void handleMaxIterationsReached( cleanUpResource(tools); } + void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String tenantId, ActionListener listener) { + if (stepsSummary == null || stepsSummary.isEmpty()) { + listener.onFailure(new IllegalArgumentException("Steps summary cannot be null or empty")); + return; + } + + try { + Map summaryParams = new HashMap<>(); + if (llmSpec.getParameters() != null) { + summaryParams.putAll(llmSpec.getParameters()); + } + String summaryPrompt = String.format(SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepsSummary)); + summaryParams.put("inputs", summaryPrompt); + summaryParams.put("prompt", summaryPrompt); + summaryParams.putIfAbsent("stop", gson.toJson(new String[] { "\n\n", "```" })); + + ActionRequest request = new MLPredictionTaskRequest( + llmSpec.getModelId(), + RemoteInferenceMLInput + .builder() + .algorithm(FunctionName.REMOTE) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(summaryParams).build()) + .build(), + null, + tenantId + ); + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> { + String summary = extractSummaryFromResponse(response); + if (summary != null) { + listener.onResponse(summary); + } else { + listener.onFailure(new RuntimeException("Empty or invalid LLM summary response")); + } + }, listener::onFailure)); + } catch (Exception e) { + listener.onFailure(e); + } + } + + private String extractSummaryFromResponse(MLTaskResponse response) { + try { + String outputString = outputToOutputString(response.getOutput()); + if (outputString != null && !outputString.trim().isEmpty()) { + Map dataMap = gson.fromJson(outputString, Map.class); + if (dataMap.containsKey("response")) { + String summary = String.valueOf(dataMap.get("response")); + if (summary != null && !summary.trim().isEmpty() && !"null".equals(summary)) { + return summary.trim(); + } + } + } + return null; + } catch (Exception e) { + log.warn("Failed to extract summary from response", e); + return null; + } + } + private void saveMessage( ConversationIndexMemory memory, String question, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java index 67b29c3557..f1b1147a7d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java @@ -140,4 +140,7 @@ public class PromptTemplate { - Avoid making assumptions and relying on implicit knowledge. - Your response must be self-contained and ready for the planner to use without modification. Never end with a question. - Break complex searches into simpler queries when appropriate."""; + + public static final String SUMMARY_PROMPT_TEMPLATE = + "Please provide a concise summary of the following agent execution steps. Focus on what the agent was trying to accomplish and what progress was made:\n\n%s\n\nPlease respond in the following JSON format:\n{\"response\": \"your summary here\"}"; } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index bae59994c5..f50c5e8efd 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1118,4 +1118,93 @@ public void testConstructLLMParams_DefaultValues() { Assert.assertTrue(result.containsKey(AgentUtils.RESPONSE_FORMAT_INSTRUCTION)); Assert.assertTrue(result.containsKey(AgentUtils.TOOL_RESPONSE)); } + + @Test + public void testMaxIterationsWithSummaryEnabled() { + // Create LLM spec with max_iteration = 1 to simplify test + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build(); + MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(mlMemorySpec) + .tools(Arrays.asList(firstToolSpec)) + .build(); + + // Reset and setup fresh mocks + Mockito.reset(client); + // First call: LLM response without final_answer to trigger max iterations + // Second call: Summary LLM response + Mockito + .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to analyze the data", "action", FIRST_TOOL))) + .doAnswer(getLLMAnswer(ImmutableMap.of("response", "Summary: Analysis step was attempted"))) + .when(client) + .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + Map params = new HashMap<>(); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); + params.put(MLChatAgentRunner.SUMMARIZE_WHEN_MAX_ITERATION, "true"); + + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify response is captured + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object capturedResponse = objectCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); + assertEquals(1, agentOutput.size()); + + // Verify the response contains summary message + String response = (String) agentOutput.get(0).getDataAsMap().get("response"); + assertTrue( + response.startsWith("Agent reached maximum iterations (1) without completing the task. Here's a summary of the steps taken:") + ); + assertTrue(response.contains("Summary: Analysis step was attempted")); + } + + @Test + public void testMaxIterationsWithSummaryDisabled() { + // Create LLM spec with max_iteration = 1 and summary disabled + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build(); + MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(mlMemorySpec) + .tools(Arrays.asList(firstToolSpec)) + .build(); + + // Reset client mock for this test + Mockito.reset(client); + // Mock LLM response that doesn't contain final_answer + Mockito + .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to use the tool", "action", FIRST_TOOL))) + .when(client) + .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + Map params = new HashMap<>(); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); + params.put(MLChatAgentRunner.SUMMARIZE_WHEN_MAX_ITERATION, "false"); + + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify response is captured + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object capturedResponse = objectCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); + assertEquals(1, agentOutput.size()); + + // Verify the response contains traditional max iterations message + String response = (String) agentOutput.get(0).getDataAsMap().get("response"); + assertEquals("Agent reached maximum iterations (1) without completing the task. Last thought: I need to use the tool", response); + } } From 51dbde86f0f4f581edc42a1dc5188a36bdc89d5a Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Mon, 22 Sep 2025 15:01:30 +0800 Subject: [PATCH 02/13] fix:recover build.gradle Signed-off-by: Jiaru Jiang --- build.gradle | 2 -- 1 file changed, 2 deletions(-) diff --git a/build.gradle b/build.gradle index 7508ddd28b..de007ef60e 100644 --- a/build.gradle +++ b/build.gradle @@ -45,7 +45,6 @@ buildscript { configurations.all { resolutionStrategy { force("org.eclipse.platform:org.eclipse.core.runtime:3.29.0") // for spotless transitive dependency CVE (for 3.26.100) - force("com.google.errorprone:error_prone_annotations:2.18.0") } } } @@ -97,7 +96,6 @@ subprojects { configurations.all { // Force spotless depending on newer version of guava due to CVE-2023-2976. Remove after spotless upgrades. resolutionStrategy.force "com.google.guava:guava:32.1.3-jre" - resolutionStrategy.force "com.google.errorprone:error_prone_annotations:2.18.0" resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0' } From 86e65a10f07f4119c4af29c56c05ae0152a2fa6f Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 10:12:14 +0800 Subject: [PATCH 03/13] add:increase test coverage Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 2 +- .../agent/MLChatAgentRunnerTest.java | 28 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index b0781f50c9..b425088961 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -1032,7 +1032,7 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String tenan } } - private String extractSummaryFromResponse(MLTaskResponse response) { + public String extractSummaryFromResponse(MLTaskResponse response) { try { String outputString = outputToOutputString(response.getOutput()); if (outputString != null && !outputString.trim().isEmpty()) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index f50c5e8efd..05f153d66a 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1207,4 +1207,32 @@ public void testMaxIterationsWithSummaryDisabled() { String response = (String) agentOutput.get(0).getDataAsMap().get("response"); assertEquals("Agent reached maximum iterations (1) without completing the task. Last thought: I need to use the tool", response); } + + @Test + public void testExtractSummaryFromResponse() { + MLTaskResponse response = MLTaskResponse.builder() + .output(ModelTensorOutput.builder() + .mlModelOutputs(Arrays.asList( + ModelTensors.builder() + .mlModelTensors(Arrays.asList( + ModelTensor.builder() + .dataAsMap(ImmutableMap.of("response", "Valid summary text")) + .build())) + .build())) + .build()) + .build(); + + String result = mlChatAgentRunner.extractSummaryFromResponse(response); + assertEquals("Valid summary text", result); + } + + @Test + public void testGenerateLLMSummaryWithNullSteps() { + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + ActionListener listener = Mockito.mock(ActionListener.class); + + mlChatAgentRunner.generateLLMSummary(null, llmSpec, "tenant", listener); + + verify(listener).onFailure(any(IllegalArgumentException.class)); + } } From 24c90ec37f968155188d34d2f50ef6379fed0153 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 10:14:24 +0800 Subject: [PATCH 04/13] fix:spotlessApply Signed-off-by: Jiaru Jiang --- .../agent/MLChatAgentRunnerTest.java | 37 ++++++++++++------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 05f153d66a..275454f6fa 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1210,18 +1210,29 @@ public void testMaxIterationsWithSummaryDisabled() { @Test public void testExtractSummaryFromResponse() { - MLTaskResponse response = MLTaskResponse.builder() - .output(ModelTensorOutput.builder() - .mlModelOutputs(Arrays.asList( - ModelTensors.builder() - .mlModelTensors(Arrays.asList( - ModelTensor.builder() - .dataAsMap(ImmutableMap.of("response", "Valid summary text")) - .build())) - .build())) - .build()) + MLTaskResponse response = MLTaskResponse + .builder() + .output( + ModelTensorOutput + .builder() + .mlModelOutputs( + Arrays + .asList( + ModelTensors + .builder() + .mlModelTensors( + Arrays + .asList( + ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "Valid summary text")).build() + ) + ) + .build() + ) + ) + .build() + ) .build(); - + String result = mlChatAgentRunner.extractSummaryFromResponse(response); assertEquals("Valid summary text", result); } @@ -1230,9 +1241,9 @@ public void testExtractSummaryFromResponse() { public void testGenerateLLMSummaryWithNullSteps() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); ActionListener listener = Mockito.mock(ActionListener.class); - + mlChatAgentRunner.generateLLMSummary(null, llmSpec, "tenant", listener); - + verify(listener).onFailure(any(IllegalArgumentException.class)); } } From f20ae6672c5d282174918a9abea6942604bc9dab Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 15:01:35 +0800 Subject: [PATCH 05/13] fix:use traceTensor Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 37 ++++++------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index b425088961..ba6722c9a0 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -130,8 +130,7 @@ public class MLChatAgentRunner implements MLAgentRunner { private static final String DEFAULT_MAX_ITERATIONS = "10"; private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task"; - private static final String MAX_ITERATIONS_SUMMARY_MESSAGE = - "Agent reached maximum iterations (%d) without completing the task. Here's a summary of the steps taken:\n\n%s"; + private static final String MAX_ITERATIONS_SUMMARY_MESSAGE = MAX_ITERATIONS_MESSAGE + ". Here's a summary of the steps taken:\n\n%s"; private Client client; private Settings settings; @@ -326,8 +325,6 @@ private void runReAct( StringBuilder scratchpadBuilder = new StringBuilder(); List interactions = new CopyOnWriteArrayList<>(); - List executionSteps = new CopyOnWriteArrayList<>(); - StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); tmpParameters.put(PROMPT, newPrompt.get()); @@ -385,17 +382,6 @@ private void runReAct( lastActionInput.set(actionInput); lastToolSelectionResponse.set(thoughtResponse); - // Record execution step for summary - if (thought != null && !"null".equals(thought) && !thought.trim().isEmpty()) { - executionSteps.add(String.format("Thought: %s", thought.trim())); - } - if (action != null && !"null".equals(action) && !action.trim().isEmpty()) { - String actionDesc = actionInput != null && !"null".equals(actionInput) - ? String.format("Action: %s(%s)", action.trim(), actionInput.trim()) - : String.format("Action: %s", action.trim()); - executionSteps.add(actionDesc); - } - traceTensors .add( ModelTensors @@ -432,7 +418,6 @@ private void runReAct( maxIterations, tools, tmpParameters, - executionSteps, llm, tenantId ); @@ -486,10 +471,6 @@ private void runReAct( ); scratchpadBuilder.append(toolResponse).append("\n\n"); - // Record tool result for summary - String outputSummary = outputToOutputString(filteredOutput); - executionSteps.add(String.format("Result: %s", outputSummary)); - saveTraceData( conversationIndexMemory, "ReAct", @@ -539,7 +520,6 @@ private void runReAct( maxIterations, tools, tmpParameters, - executionSteps, llm, tenantId ); @@ -915,14 +895,13 @@ private void handleMaxIterationsReached( int maxIterations, Map tools, Map parameters, - List executionSteps, LLMSpec llmSpec, String tenantId ) { boolean shouldSummarize = Boolean.parseBoolean(parameters.getOrDefault(SUMMARIZE_WHEN_MAX_ITERATION, "false")); - if (shouldSummarize && !executionSteps.isEmpty()) { - generateLLMSummary(executionSteps, llmSpec, tenantId, ActionListener.wrap(summary -> { + if (shouldSummarize && !traceTensors.isEmpty()) { + generateLLMSummary(traceTensors, llmSpec, tenantId, ActionListener.wrap(summary -> { String incompleteResponse = String.format(MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); sendFinalAnswer( sessionId, @@ -993,7 +972,7 @@ private void sendTraditionalMaxIterationsResponse( cleanUpResource(tools); } - void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String tenantId, ActionListener listener) { + void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String tenantId, ActionListener listener) { if (stepsSummary == null || stepsSummary.isEmpty()) { listener.onFailure(new IllegalArgumentException("Steps summary cannot be null or empty")); return; @@ -1004,7 +983,13 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String tenan if (llmSpec.getParameters() != null) { summaryParams.putAll(llmSpec.getParameters()); } - String summaryPrompt = String.format(SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepsSummary)); + + // Convert ModelTensors to strings before joining + List stepStrings = new ArrayList<>(); + for (ModelTensors tensor : stepsSummary) { + stepStrings.add(outputToOutputString(tensor)); + } + String summaryPrompt = String.format(SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); summaryParams.put("inputs", summaryPrompt); summaryParams.put("prompt", summaryPrompt); summaryParams.putIfAbsent("stop", gson.toJson(new String[] { "\n\n", "```" })); From 6b57c8eb39574535c7d8e8a85b8692fce7dcb254 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 15:13:38 +0800 Subject: [PATCH 06/13] fix:String.format() Signed-off-by: Jiaru Jiang --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index ba6722c9a0..bcde6f05d5 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -902,7 +902,7 @@ private void handleMaxIterationsReached( if (shouldSummarize && !traceTensors.isEmpty()) { generateLLMSummary(traceTensors, llmSpec, tenantId, ActionListener.wrap(summary -> { - String incompleteResponse = String.format(MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); + String incompleteResponse = String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); sendFinalAnswer( sessionId, listener, From e0aa50f9ff568c124da5436441e0e73341b600db Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 15:53:50 +0800 Subject: [PATCH 07/13] fix:String.format() Signed-off-by: Jiaru Jiang --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index bcde6f05d5..f06d40ce7c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -989,7 +989,7 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String for (ModelTensors tensor : stepsSummary) { stepStrings.add(outputToOutputString(tensor)); } - String summaryPrompt = String.format(SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); + String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); summaryParams.put("inputs", summaryPrompt); summaryParams.put("prompt", summaryPrompt); summaryParams.putIfAbsent("stop", gson.toJson(new String[] { "\n\n", "```" })); From eeeea3fe0ce069215a22f142029bd7988571ef89 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 16:17:45 +0800 Subject: [PATCH 08/13] fix:reuse sendTraditionalMaxIterationsResponse method Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index f06d40ce7c..55e542038e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -902,8 +902,9 @@ private void handleMaxIterationsReached( if (shouldSummarize && !traceTensors.isEmpty()) { generateLLMSummary(traceTensors, llmSpec, tenantId, ActionListener.wrap(summary -> { - String incompleteResponse = String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); - sendFinalAnswer( + String summaryResponse = String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); + AtomicReference summaryThought = new AtomicReference<>(summaryResponse); + sendTraditionalMaxIterationsResponse( sessionId, listener, question, @@ -914,12 +915,12 @@ private void handleMaxIterationsReached( conversationIndexMemory, traceNumber, additionalInfo, - incompleteResponse + summaryThought, + 0, // 不使用 maxIterations 格式化,直接使用 summaryResponse + tools ); - cleanUpResource(tools); }, e -> { log.warn("Failed to generate LLM summary", e); })); } else { - // Use traditional approach sendTraditionalMaxIterationsResponse( sessionId, listener, @@ -953,9 +954,16 @@ private void sendTraditionalMaxIterationsResponse( int maxIterations, Map tools ) { - String incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) - ? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) - : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); + String incompleteResponse; + if (maxIterations == 0) { + // 直接使用 lastThought 中的完整消息(用于摘要情况) + incompleteResponse = lastThought.get(); + } else { + // 传统格式化(用于普通情况) + incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) + ? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) + : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); + } sendFinalAnswer( sessionId, listener, @@ -991,7 +999,7 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String } String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); summaryParams.put("inputs", summaryPrompt); - summaryParams.put("prompt", summaryPrompt); + summaryParams.put(PROMPT, summaryPrompt); summaryParams.putIfAbsent("stop", gson.toJson(new String[] { "\n\n", "```" })); ActionRequest request = new MLPredictionTaskRequest( From 865d8ae0df61be139620921c6b8aae82f3a8f4e2 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 16:22:39 +0800 Subject: [PATCH 09/13] fix:remove useless comment Signed-off-by: Jiaru Jiang --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 55e542038e..333573bca7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -916,7 +916,7 @@ private void handleMaxIterationsReached( traceNumber, additionalInfo, summaryThought, - 0, // 不使用 maxIterations 格式化,直接使用 summaryResponse + 0, tools ); }, e -> { log.warn("Failed to generate LLM summary", e); })); @@ -956,10 +956,8 @@ private void sendTraditionalMaxIterationsResponse( ) { String incompleteResponse; if (maxIterations == 0) { - // 直接使用 lastThought 中的完整消息(用于摘要情况) incompleteResponse = lastThought.get(); } else { - // 传统格式化(用于普通情况) incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) ? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); From bff9bb5f1deef52e234c700932e4fc50a0e6fd30 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Fri, 10 Oct 2025 13:39:03 +0800 Subject: [PATCH 10/13] fix: delete stop Signed-off-by: Jiaru Jiang --- .../opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java | 1 - 1 file changed, 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 333573bca7..733d3771ca 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -998,7 +998,6 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); summaryParams.put("inputs", summaryPrompt); summaryParams.put(PROMPT, summaryPrompt); - summaryParams.putIfAbsent("stop", gson.toJson(new String[] { "\n\n", "```" })); ActionRequest request = new MLPredictionTaskRequest( llmSpec.getModelId(), From 750904ec329539216d15fb9b3cbbfbb86e047680 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Fri, 10 Oct 2025 15:06:51 +0800 Subject: [PATCH 11/13] fix: refactor Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 33 ++++++++----------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 733d3771ca..71857bf8b7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -903,7 +903,6 @@ private void handleMaxIterationsReached( if (shouldSummarize && !traceTensors.isEmpty()) { generateLLMSummary(traceTensors, llmSpec, tenantId, ActionListener.wrap(summary -> { String summaryResponse = String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); - AtomicReference summaryThought = new AtomicReference<>(summaryResponse); sendTraditionalMaxIterationsResponse( sessionId, listener, @@ -915,12 +914,18 @@ private void handleMaxIterationsReached( conversationIndexMemory, traceNumber, additionalInfo, - summaryThought, - 0, + summaryResponse, tools ); - }, e -> { log.warn("Failed to generate LLM summary", e); })); + }, e -> { + log.error("Failed to generate LLM summary", e); + listener.onFailure(e); + cleanUpResource(tools); + })); } else { + String response = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) + ? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) + : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); sendTraditionalMaxIterationsResponse( sessionId, listener, @@ -932,8 +937,7 @@ private void handleMaxIterationsReached( conversationIndexMemory, traceNumber, additionalInfo, - lastThought, - maxIterations, + response, tools ); } @@ -950,18 +954,9 @@ private void sendTraditionalMaxIterationsResponse( ConversationIndexMemory conversationIndexMemory, AtomicInteger traceNumber, Map additionalInfo, - AtomicReference lastThought, - int maxIterations, + String response, Map tools ) { - String incompleteResponse; - if (maxIterations == 0) { - incompleteResponse = lastThought.get(); - } else { - incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) - ? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) - : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); - } sendFinalAnswer( sessionId, listener, @@ -973,7 +968,7 @@ private void sendTraditionalMaxIterationsResponse( conversationIndexMemory, traceNumber, additionalInfo, - incompleteResponse + response ); cleanUpResource(tools); } @@ -1036,8 +1031,8 @@ public String extractSummaryFromResponse(MLTaskResponse response) { } return null; } catch (Exception e) { - log.warn("Failed to extract summary from response", e); - return null; + log.error("Failed to extract summary from response", e); + throw new RuntimeException("Failed to extract summary from response", e); } } From 2ddad4c0dae8095fad07de30ccaaa76c83fe10a6 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Fri, 10 Oct 2025 15:55:20 +0800 Subject: [PATCH 12/13] fix: json serialization Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 24 ++++++++++++++----- .../agent/MLChatAgentRunnerTest.java | 17 +++++++++++-- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 71857bf8b7..e84ca65564 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -988,7 +988,15 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String // Convert ModelTensors to strings before joining List stepStrings = new ArrayList<>(); for (ModelTensors tensor : stepsSummary) { - stepStrings.add(outputToOutputString(tensor)); + if (tensor != null && tensor.getMlModelTensors() != null) { + for (ModelTensor modelTensor : tensor.getMlModelTensors()) { + if (modelTensor.getResult() != null) { + stepStrings.add(modelTensor.getResult()); + } else if (modelTensor.getDataAsMap() != null && modelTensor.getDataAsMap().containsKey("response")) { + stepStrings.add(String.valueOf(modelTensor.getDataAsMap().get("response"))); + } + } + } } String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); summaryParams.put("inputs", summaryPrompt); @@ -1021,12 +1029,16 @@ public String extractSummaryFromResponse(MLTaskResponse response) { try { String outputString = outputToOutputString(response.getOutput()); if (outputString != null && !outputString.trim().isEmpty()) { - Map dataMap = gson.fromJson(outputString, Map.class); - if (dataMap.containsKey("response")) { - String summary = String.valueOf(dataMap.get("response")); - if (summary != null && !summary.trim().isEmpty() && !"null".equals(summary)) { - return summary.trim(); + try { + Map dataMap = gson.fromJson(outputString, Map.class); + if (dataMap.containsKey("response")) { + String summary = String.valueOf(dataMap.get("response")); + if (summary != null && !summary.trim().isEmpty() && !"null".equals(summary)) { + return summary.trim(); + } } + } catch (Exception jsonException) { + return outputString.trim(); } } return null; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 275454f6fa..c3c1063ab3 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1135,11 +1135,24 @@ public void testMaxIterationsWithSummaryEnabled() { // Reset and setup fresh mocks Mockito.reset(client); + Mockito.reset(firstTool); + when(firstTool.getName()).thenReturn(FIRST_TOOL); + when(firstTool.validate(Mockito.anyMap())).thenReturn(true); + Mockito.doAnswer(generateToolResponse("First tool response")).when(firstTool).run(Mockito.anyMap(), any()); + // First call: LLM response without final_answer to trigger max iterations - // Second call: Summary LLM response + // Second call: Summary LLM response with result field instead of dataAsMap Mockito .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to analyze the data", "action", FIRST_TOOL))) - .doAnswer(getLLMAnswer(ImmutableMap.of("response", "Summary: Analysis step was attempted"))) + .doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor modelTensor = ModelTensor.builder().result("Summary: Analysis step was attempted").build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + MLTaskResponse mlTaskResponse = MLTaskResponse.builder().output(mlModelTensorOutput).build(); + listener.onResponse(mlTaskResponse); + return null; + }) .when(client) .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); From 8de0f73b27efe53ade6abee20b9bc119de376865 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Mon, 13 Oct 2025 14:00:22 +0800 Subject: [PATCH 13/13] fix: parameter Signed-off-by: Jiaru Jiang --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index e84ca65564..526648940b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -999,8 +999,8 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String } } String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); - summaryParams.put("inputs", summaryPrompt); summaryParams.put(PROMPT, summaryPrompt); + summaryParams.putIfAbsent(SYSTEM_PROMPT_FIELD, ""); ActionRequest request = new MLPredictionTaskRequest( llmSpec.getModelId(),