Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseLLMOutput;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.substitute;
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.CHAT_HISTORY_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.SUMMARY_PROMPT_TEMPLATE;

import java.security.PrivilegedActionException;
import java.util.ArrayList;
Expand Down Expand Up @@ -125,9 +126,12 @@ public class MLChatAgentRunner implements MLAgentRunner {
public static final String INJECT_DATETIME_FIELD = "inject_datetime";
public static final String DATETIME_FORMAT_FIELD = "datetime_format";
public static final String SYSTEM_PROMPT_FIELD = "system_prompt";
public static final String SUMMARIZE_WHEN_MAX_ITERATION = "summarize_when_max_iteration";

private static final String DEFAULT_MAX_ITERATIONS = "10";
private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task";
private static final String MAX_ITERATIONS_SUMMARY_MESSAGE =
"Agent reached maximum iterations (%d) without completing the task. Here's a summary of the steps taken:\n\n%s";

private Client client;
private Settings settings;
Expand Down Expand Up @@ -322,6 +326,7 @@ private void runReAct(

StringBuilder scratchpadBuilder = new StringBuilder();
List<String> interactions = new CopyOnWriteArrayList<>();
List<String> executionSteps = new CopyOnWriteArrayList<>();

StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}");
AtomicReference<String> newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt));
Expand Down Expand Up @@ -380,6 +385,17 @@ private void runReAct(
lastActionInput.set(actionInput);
lastToolSelectionResponse.set(thoughtResponse);

// Record execution step for summary
if (thought != null && !"null".equals(thought) && !thought.trim().isEmpty()) {
executionSteps.add(String.format("Thought: %s", thought.trim()));
}
if (action != null && !"null".equals(action) && !action.trim().isEmpty()) {
String actionDesc = actionInput != null && !"null".equals(actionInput)
? String.format("Action: %s(%s)", action.trim(), actionInput.trim())
: String.format("Action: %s", action.trim());
executionSteps.add(actionDesc);
}

traceTensors
.add(
ModelTensors
Expand Down Expand Up @@ -414,7 +430,11 @@ private void runReAct(
additionalInfo,
lastThought,
maxIterations,
tools
tools,
tmpParameters,
executionSteps,
llm,
tenantId
);
return;
}
Expand Down Expand Up @@ -466,6 +486,10 @@ private void runReAct(
);
scratchpadBuilder.append(toolResponse).append("\n\n");

// Record tool result for summary
String outputSummary = outputToOutputString(filteredOutput);
executionSteps.add(String.format("Result: %s", outputSummary));

saveTraceData(
conversationIndexMemory,
"ReAct",
Expand Down Expand Up @@ -513,7 +537,11 @@ private void runReAct(
additionalInfo,
lastThought,
maxIterations,
tools
tools,
tmpParameters,
executionSteps,
llm,
tenantId
);
return;
}
Expand Down Expand Up @@ -873,6 +901,65 @@ private static void returnFinalResponse(
}

/**
 * Terminates a ReAct run that exhausted its iteration budget.
 *
 * When {@code summarize_when_max_iteration} is enabled and execution steps were
 * recorded, asks the LLM for a summary of the steps taken and returns it to the
 * caller; otherwise (or if summarization fails) falls back to the traditional
 * "last thought" max-iterations response. In every path the supplied listener is
 * notified exactly once and tool resources are released.
 */
private void handleMaxIterationsReached(
    String sessionId,
    ActionListener<Object> listener,
    String question,
    String parentInteractionId,
    boolean verbose,
    boolean traceDisabled,
    List<ModelTensors> traceTensors,
    ConversationIndexMemory conversationIndexMemory,
    AtomicInteger traceNumber,
    Map<String, Object> additionalInfo,
    AtomicReference<String> lastThought,
    int maxIterations,
    Map<String, Tool> tools,
    Map<String, String> parameters,
    List<String> executionSteps,
    LLMSpec llmSpec,
    String tenantId
) {
    boolean shouldSummarize = Boolean.parseBoolean(parameters.getOrDefault(SUMMARIZE_WHEN_MAX_ITERATION, "false"));

    if (shouldSummarize && !executionSteps.isEmpty()) {
        generateLLMSummary(executionSteps, llmSpec, tenantId, ActionListener.wrap(summary -> {
            String incompleteResponse = String.format(MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary);
            sendFinalAnswer(
                sessionId,
                listener,
                question,
                parentInteractionId,
                verbose,
                traceDisabled,
                traceTensors,
                conversationIndexMemory,
                traceNumber,
                additionalInfo,
                incompleteResponse
            );
            cleanUpResource(tools);
        }, e -> {
            // Summarization is best-effort: on failure fall back to the traditional
            // response so the caller is still answered and tools are cleaned up
            // (sendTraditionalMaxIterationsResponse releases them). Previously the
            // listener was never notified here, leaving the agent call hanging.
            log.warn("Failed to generate LLM summary, falling back to default max-iterations response", e);
            sendTraditionalMaxIterationsResponse(
                sessionId,
                listener,
                question,
                parentInteractionId,
                verbose,
                traceDisabled,
                traceTensors,
                conversationIndexMemory,
                traceNumber,
                additionalInfo,
                lastThought,
                maxIterations,
                tools
            );
        }));
    } else {
        // Use traditional approach
        sendTraditionalMaxIterationsResponse(
            sessionId,
            listener,
            question,
            parentInteractionId,
            verbose,
            traceDisabled,
            traceTensors,
            conversationIndexMemory,
            traceNumber,
            additionalInfo,
            lastThought,
            maxIterations,
            tools
        );
    }
}

private void sendTraditionalMaxIterationsResponse(
String sessionId,
ActionListener<Object> listener,
String question,
Expand Down Expand Up @@ -906,6 +993,64 @@ private void handleMaxIterationsReached(
cleanUpResource(tools);
}

void generateLLMSummary(List<String> stepsSummary, LLMSpec llmSpec, String tenantId, ActionListener<String> listener) {
if (stepsSummary == null || stepsSummary.isEmpty()) {
listener.onFailure(new IllegalArgumentException("Steps summary cannot be null or empty"));
return;
}

try {
Map<String, String> summaryParams = new HashMap<>();
if (llmSpec.getParameters() != null) {
summaryParams.putAll(llmSpec.getParameters());
}
String summaryPrompt = String.format(SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepsSummary));
summaryParams.put("inputs", summaryPrompt);
summaryParams.put("prompt", summaryPrompt);
summaryParams.putIfAbsent("stop", gson.toJson(new String[] { "\n\n", "```" }));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need input here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we reuse the static string public static final String PROMPT = "prompt"; we already have?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without the inputs field, it’ll throw an error.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the connector configuration? This might a little tricky, guess it may related to connector configuration.

And why we need stop here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked the connector configuration, inputs is required. And you're right about stop: it's unnecessary, so I've removed it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The parameters may vary depending on the connector configuration. Can you add a new field system_prompt with empty value as well? The best way is to use the parameters used in previous LLM inference step and replace prompt with our summary prompt.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your suggestion. I've resolved it.


ActionRequest request = new MLPredictionTaskRequest(
llmSpec.getModelId(),
RemoteInferenceMLInput
.builder()
.algorithm(FunctionName.REMOTE)
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(summaryParams).build())
.build(),
null,
tenantId
);
client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> {
String summary = extractSummaryFromResponse(response);
if (summary != null) {
listener.onResponse(summary);
} else {
listener.onFailure(new RuntimeException("Empty or invalid LLM summary response"));
}
}, listener::onFailure));
} catch (Exception e) {
listener.onFailure(e);
}
}

/**
 * Extracts the "response" field from an LLM summary prediction result.
 *
 * @param response the raw prediction task response
 * @return the trimmed summary text, or {@code null} when the output is missing,
 *         unparsable, blank, or the literal string "null"
 */
public String extractSummaryFromResponse(MLTaskResponse response) {
    try {
        String outputString = outputToOutputString(response.getOutput());
        if (outputString == null || outputString.trim().isEmpty()) {
            return null;
        }
        // A JSON "null" payload makes dataMap null; the resulting NPE is caught below.
        Map<String, Object> dataMap = gson.fromJson(outputString, Map.class);
        if (!dataMap.containsKey("response")) {
            return null;
        }
        String summary = String.valueOf(dataMap.get("response"));
        if (summary == null || summary.trim().isEmpty() || "null".equals(summary)) {
            return null;
        }
        return summary.trim();
    } catch (Exception e) {
        log.warn("Failed to extract summary from response", e);
        return null;
    }
}

private void saveMessage(
ConversationIndexMemory memory,
String question,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,7 @@ public class PromptTemplate {
- Avoid making assumptions and relying on implicit knowledge.
- Your response must be self-contained and ready for the planner to use without modification. Never end with a question.
- Break complex searches into simpler queries when appropriate.""";

// Prompt sent to the LLM when summarizing execution steps after the agent hits
// max iterations; %s is replaced with the newline-joined step list. Requests a
// JSON object with a "response" key, which extractSummaryFromResponse reads back.
public static final String SUMMARY_PROMPT_TEMPLATE =
    "Please provide a concise summary of the following agent execution steps. Focus on what the agent was trying to accomplish and what progress was made:\n\n%s\n\nPlease respond in the following JSON format:\n{\"response\": \"your summary here\"}";
}
Original file line number Diff line number Diff line change
Expand Up @@ -1118,4 +1118,132 @@ public void testConstructLLMParams_DefaultValues() {
Assert.assertTrue(result.containsKey(AgentUtils.RESPONSE_FORMAT_INSTRUCTION));
Assert.assertTrue(result.containsKey(AgentUtils.TOOL_RESPONSE));
}

// Verifies that when summarize_when_max_iteration=true, hitting the iteration
// limit triggers a second LLM call whose summary is embedded in the final answer.
@Test
public void testMaxIterationsWithSummaryEnabled() {
    // Create LLM spec with max_iteration = 1 to simplify test
    LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build();
    MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
    final MLAgent mlAgent = MLAgent
        .builder()
        .name("TestAgent")
        .type(MLAgentType.CONVERSATIONAL.name())
        .llm(llmSpec)
        .memory(mlMemorySpec)
        .tools(Arrays.asList(firstToolSpec))
        .build();

    // Reset and setup fresh mocks
    Mockito.reset(client);
    // First call: LLM response without final_answer to trigger max iterations
    // Second call: Summary LLM response
    // NOTE: the doAnswer ordering mirrors the call order and must not change.
    Mockito
        .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to analyze the data", "action", FIRST_TOOL)))
        .doAnswer(getLLMAnswer(ImmutableMap.of("response", "Summary: Analysis step was attempted")))
        .when(client)
        .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));

    Map<String, String> params = new HashMap<>();
    params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
    // Opt in to the LLM-generated summary on max-iteration exhaustion
    params.put(MLChatAgentRunner.SUMMARIZE_WHEN_MAX_ITERATION, "true");

    mlChatAgentRunner.run(mlAgent, params, agentActionListener);

    // Verify response is captured
    verify(agentActionListener).onResponse(objectCaptor.capture());
    Object capturedResponse = objectCaptor.getValue();
    assertTrue(capturedResponse instanceof ModelTensorOutput);

    ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
    List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors();
    assertEquals(1, agentOutput.size());

    // Verify the response contains summary message
    String response = (String) agentOutput.get(0).getDataAsMap().get("response");
    assertTrue(
        response.startsWith("Agent reached maximum iterations (1) without completing the task. Here's a summary of the steps taken:")
    );
    assertTrue(response.contains("Summary: Analysis step was attempted"));
}

// Verifies that with summarize_when_max_iteration=false the agent falls back to
// the traditional "Last thought" message and makes no summary LLM call.
@Test
public void testMaxIterationsWithSummaryDisabled() {
    // Create LLM spec with max_iteration = 1 and summary disabled
    LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build();
    MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
    final MLAgent mlAgent = MLAgent
        .builder()
        .name("TestAgent")
        .type(MLAgentType.CONVERSATIONAL.name())
        .llm(llmSpec)
        .memory(mlMemorySpec)
        .tools(Arrays.asList(firstToolSpec))
        .build();

    // Reset client mock for this test
    Mockito.reset(client);
    // Mock LLM response that doesn't contain final_answer
    Mockito
        .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to use the tool", "action", FIRST_TOOL)))
        .when(client)
        .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));

    Map<String, String> params = new HashMap<>();
    params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
    params.put(MLChatAgentRunner.SUMMARIZE_WHEN_MAX_ITERATION, "false");

    mlChatAgentRunner.run(mlAgent, params, agentActionListener);

    // Verify response is captured
    verify(agentActionListener).onResponse(objectCaptor.capture());
    Object capturedResponse = objectCaptor.getValue();
    assertTrue(capturedResponse instanceof ModelTensorOutput);

    ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
    List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors();
    assertEquals(1, agentOutput.size());

    // Verify the response contains traditional max iterations message
    String response = (String) agentOutput.get(0).getDataAsMap().get("response");
    assertEquals("Agent reached maximum iterations (1) without completing the task. Last thought: I need to use the tool", response);
}

// A well-formed prediction output with a "response" field should yield that text.
@Test
public void testExtractSummaryFromResponse() {
    // Build the nested response structure bottom-up with named intermediates.
    ModelTensor tensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "Valid summary text")).build();
    ModelTensors tensors = ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build();
    ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(tensors)).build();
    MLTaskResponse response = MLTaskResponse.builder().output(output).build();

    String result = mlChatAgentRunner.extractSummaryFromResponse(response);
    assertEquals("Valid summary text", result);
}

// A null step list must be rejected with IllegalArgumentException before any LLM call.
@Test
public void testGenerateLLMSummaryWithNullSteps() {
    ActionListener<String> summaryListener = Mockito.mock(ActionListener.class);
    LLMSpec spec = LLMSpec.builder().modelId("MODEL_ID").build();

    mlChatAgentRunner.generateLLMSummary(null, spec, "tenant", summaryListener);

    verify(summaryListener).onFailure(any(IllegalArgumentException.class));
}
}
Loading