diff --git a/common/src/main/java/org/opensearch/ml/common/MLAgentType.java b/common/src/main/java/org/opensearch/ml/common/MLAgentType.java index 2dd2614634..04a4b72014 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLAgentType.java +++ b/common/src/main/java/org/opensearch/ml/common/MLAgentType.java @@ -20,7 +20,7 @@ public static MLAgentType from(String value) { try { return MLAgentType.valueOf(value.toUpperCase(Locale.ROOT)); } catch (Exception e) { - throw new IllegalArgumentException("Wrong Agent type"); + throw new IllegalArgumentException(value + " is not a valid Agent Type"); } } } diff --git a/common/src/main/java/org/opensearch/ml/common/MLMemoryType.java b/common/src/main/java/org/opensearch/ml/common/MLMemoryType.java new file mode 100644 index 0000000000..31939ce1ca --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/MLMemoryType.java @@ -0,0 +1,24 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common; + +import java.util.Locale; + +public enum MLMemoryType { + CONVERSATION_INDEX, + AGENTIC_MEMORY; + + public static MLMemoryType from(String value) { + if (value != null) { + try { + return MLMemoryType.valueOf(value.toUpperCase(Locale.ROOT)); + } catch (Exception e) { + throw new IllegalArgumentException("Wrong Memory type"); + } + } + return null; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java index b66a23f11e..ec73d73856 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -15,7 +15,6 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -113,7 +112,7 @@ private void validate() { String.format("Agent name cannot be empty or exceed max length of %d characters", MLAgent.AGENT_NAME_MAX_LENGTH) ); } - validateMLAgentType(type); + MLAgentType.from(type); if (type.equalsIgnoreCase(MLAgentType.CONVERSATIONAL.toString()) && llm == null) { throw new IllegalArgumentException("We need model information for the conversational agent type"); } @@ -130,19 +129,6 @@ private void validate() { } } - private void validateMLAgentType(String agentType) { - if (type == null) { - throw new IllegalArgumentException("Agent type can't be null"); - } else { - try { - MLAgentType.valueOf(agentType.toUpperCase(Locale.ROOT)); // Use toUpperCase() to allow case-insensitive matching - } catch (IllegalArgumentException e) { - // The typeStr does not match any MLAgentType, so throw a new exception with a clearer message. - throw new IllegalArgumentException(agentType + " is not a valid Agent Type"); - } - } - } - public MLAgent(StreamInput input) throws IOException { Version streamInputVersion = input.getVersion(); name = input.readString(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java index 9a0d6002fd..e85b3f4bdc 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java @@ -26,6 +26,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.MLMemoryType; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLMemorySpec; @@ -383,9 +384,7 @@ private void validate() { String.format("Agent name cannot be empty or exceed max length of %d characters", MLAgent.AGENT_NAME_MAX_LENGTH) ); } - if (memoryType != null && !memoryType.equals("conversation_index")) { - throw new IllegalArgumentException(String.format("Invalid memory type: %s", memoryType)); - } + MLMemoryType.from(memoryType); if (tools != null) { Set toolNames = new HashSet<>(); for (MLToolSpec toolSpec : tools) { diff --git a/common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java b/common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java index ee15ca95fd..05f37c4992 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java +++ b/common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java @@ -44,14 +44,14 @@ public void testFromWithMixedCase() { public void testFromWithInvalidType() { // This should throw an IllegalArgumentException exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Wrong Agent type"); + exceptionRule.expectMessage(" is not a valid Agent Type"); MLAgentType.from("INVALID_TYPE"); } @Test public void testFromWithEmptyString() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Wrong Agent type"); + exceptionRule.expectMessage(" is not a valid Agent Type"); // This should also throw an IllegalArgumentException MLAgentType.from(""); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java index 72eb035279..084f95d137 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java @@ -94,7 +94,7 @@ public void testValidationWithInvalidMemoryType() { IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> { MLAgentUpdateInput.builder().agentId("test-agent-id").name("test-agent").memoryType("invalid_type").build(); }); - assertEquals("Invalid memory type: invalid_type", e.getMessage()); + assertEquals("Wrong Memory type", e.getMessage()); } @Test diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapter.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapter.java new file mode 100644 index 0000000000..6bf685bd7f --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapter.java @@ -0,0 +1,775 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +import org.opensearch.action.search.SearchResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.memorycontainer.MemoryType; +import org.opensearch.ml.common.memorycontainer.PayloadType; +import org.opensearch.ml.common.transport.memorycontainer.MLMemoryContainerGetAction; +import org.opensearch.ml.common.transport.memorycontainer.MLMemoryContainerGetRequest; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesAction; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesInput; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesRequest; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesAction; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesInput; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesRequest; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryAction; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryInput; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryRequest; +import org.opensearch.ml.common.transport.memorycontainer.memory.MessageInput; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +/** + * Adapter for Agentic Memory system to work with MLChatAgentRunner. + * + *

This adapter provides a bridge between the ML Chat Agent system and the Agentic Memory + * infrastructure, enabling intelligent conversation management and context retention.

+ * + *

Memory Types Handled:

+ * + * + *

Key Features:

+ * + * + *

Usage Example:

+ *
{@code
+ * AgenticMemoryAdapter adapter = new AgenticMemoryAdapter(
+ *     client, "memory-container-id", "session-123", "user-456"
+ * );
+ * 
+ * // Retrieve conversation messages
+ * adapter.getMessages(ActionListener.wrap(
+ *     messages -> processMessages(messages),
+ *     error -> handleError(error)
+ * ));
+ * 
+ * // Save trace data
+ * adapter.saveTraceData("search_tool", "query", "results", 
+ *     "parent-id", 1, "search", listener);
+ * }
+ * + * @see ChatMemoryAdapter + * @see MLChatAgentRunner + */ +@Log4j2 +public class AgenticMemoryAdapter implements ChatMemoryAdapter { + private final Client client; + private final String memoryContainerId; + private final String sessionId; + private final String ownerId; + + /** + * Creates a new AgenticMemoryAdapter instance. + * + * @param client OpenSearch client for executing memory operations + * @param memoryContainerId Unique identifier for the memory container + * @param sessionId Session identifier for conversation context + * @param ownerId Owner/user identifier for access control + * @throws IllegalArgumentException if any required parameter is null + */ + public AgenticMemoryAdapter(Client client, String memoryContainerId, String sessionId, String ownerId) { + if (client == null) { + throw new IllegalArgumentException("Client cannot be null"); + } + if (memoryContainerId == null || memoryContainerId.trim().isEmpty()) { + throw new IllegalArgumentException("Memory container ID cannot be null or empty"); + } + if (sessionId == null || sessionId.trim().isEmpty()) { + throw new IllegalArgumentException("Session ID cannot be null or empty"); + } + if (ownerId == null || ownerId.trim().isEmpty()) { + throw new IllegalArgumentException("Owner ID cannot be null or empty"); + } + + this.client = client; + this.memoryContainerId = memoryContainerId; + this.sessionId = sessionId; + this.ownerId = ownerId; + } + + @Override + public void getMessages(ActionListener> listener) { + // Query both WORKING memory (recent conversations) and LONG_TERM memory + // (extracted facts) + // This provides both conversation history and learned context + + List allChatMessages = new ArrayList<>(); + AtomicInteger pendingRequests = new AtomicInteger(2); + + // 1. Get recent conversation history from WORKING memory + SearchSourceBuilder workingSearchBuilder = new SearchSourceBuilder() + .query( + QueryBuilders + .boolQuery() + .must(QueryBuilders.termQuery("namespace.session_id", sessionId)) + .must(QueryBuilders.termQuery("namespace.user_id", ownerId)) + ) + .sort("created_time", SortOrder.DESC) + .size(50); // Limit recent conversation history + + MLSearchMemoriesRequest workingRequest = MLSearchMemoriesRequest + .builder() + .mlSearchMemoriesInput( + MLSearchMemoriesInput + .builder() + .memoryContainerId(memoryContainerId) + .memoryType(MemoryType.WORKING) + .searchSourceBuilder(workingSearchBuilder) + .build() + ) + .build(); + + client.execute(MLSearchMemoriesAction.INSTANCE, workingRequest, ActionListener.wrap(workingResponse -> { + synchronized (allChatMessages) { + allChatMessages.addAll(parseAgenticMemoryResponse(workingResponse)); + if (pendingRequests.decrementAndGet() == 0) { + // Sort all chat messages by timestamp and return + allChatMessages.sort((a, b) -> a.getTimestamp().compareTo(b.getTimestamp())); + listener.onResponse(allChatMessages); + } + } + }, listener::onFailure)); + + // 2. Get relevant context from LONG_TERM memory (extracted facts) + SearchSourceBuilder longTermSearchBuilder = new SearchSourceBuilder() + .query( + QueryBuilders + .boolQuery() + .must(QueryBuilders.termQuery("namespace.session_id", sessionId)) + .must(QueryBuilders.termQuery("namespace.user_id", ownerId)) + .should(QueryBuilders.termQuery("strategy_type", "SUMMARY")) + .should(QueryBuilders.termQuery("strategy_type", "SEMANTIC")) + ) + .sort("created_time", SortOrder.DESC) + .size(10); // Limit context facts + + MLSearchMemoriesRequest longTermRequest = MLSearchMemoriesRequest + .builder() + .mlSearchMemoriesInput( + MLSearchMemoriesInput + .builder() + .memoryContainerId(memoryContainerId) + .memoryType(MemoryType.LONG_TERM) + .searchSourceBuilder(longTermSearchBuilder) + .build() + ) + .build(); + + client.execute(MLSearchMemoriesAction.INSTANCE, longTermRequest, ActionListener.wrap(longTermResponse -> { + synchronized (allChatMessages) { + allChatMessages.addAll(parseAgenticMemoryResponse(longTermResponse)); + if (pendingRequests.decrementAndGet() == 0) { + // Sort all chat messages by timestamp and return + allChatMessages.sort((a, b) -> a.getTimestamp().compareTo(b.getTimestamp())); + listener.onResponse(allChatMessages); + } + } + }, e -> { + // If long-term memory fails, still return working memory results + log.warn("Failed to retrieve long-term memory, continuing with working memory only", e); + synchronized (allChatMessages) { + if (pendingRequests.decrementAndGet() == 0) { + listener.onResponse(allChatMessages); + } + } + })); + } + + @Override + public String getConversationId() { + return sessionId; + } + + @Override + public String getMemoryContainerId() { + return memoryContainerId; + } + + @Override + public void saveInteraction( + String question, + String assistantResponse, + String parentId, + Integer traceNum, + String action, + ActionListener listener + ) { + if (listener == null) { + throw new IllegalArgumentException("ActionListener cannot be null"); + } + final String finalQuestion = question != null ? question : ""; + final String finalAssistantResponse = assistantResponse != null ? assistantResponse : ""; + + log + .info( + "AgenticMemoryAdapter.saveInteraction: Called with parentId: {}, action: {}, hasResponse: {}", + parentId, + action, + !finalAssistantResponse.isEmpty() + ); + + // If parentId is provided and we have a response, update the existing + // interaction + if (parentId != null && !finalAssistantResponse.isEmpty()) { + log.info("AgenticMemoryAdapter.saveInteraction: Updating existing interaction {} with final response", parentId); + + // Update the existing interaction with the complete conversation + Map updateFields = new HashMap<>(); + updateFields.put("response", finalAssistantResponse); + updateFields.put("input", finalQuestion); + + updateInteraction(parentId, updateFields, ActionListener.wrap(res -> { + log.info("AgenticMemoryAdapter.saveInteraction: Successfully updated interaction {}", parentId); + listener.onResponse(parentId); // Return the same interaction ID + }, ex -> { + log + .error( + "AgenticMemoryAdapter.saveInteraction: Failed to update interaction {}, falling back to create new", + parentId, + ex + ); + // Fallback to creating new interaction if update fails + createNewInteraction(finalQuestion, finalAssistantResponse, parentId, traceNum, action, listener); + })); + } else { + // Create new interaction (root interaction or when no parentId) + log.info("AgenticMemoryAdapter.saveInteraction: Creating new interaction - parentId: {}, action: {}", parentId, action); + createNewInteraction(finalQuestion, finalAssistantResponse, parentId, traceNum, action, listener); + } + } + + private void createNewInteraction( + String question, + String assistantResponse, + String parentId, + Integer traceNum, + String action, + ActionListener listener + ) { + List messages = Arrays + .asList( + MessageInput.builder().role("user").content(createTextContent(question)).build(), + MessageInput.builder().role("assistant").content(createTextContent(assistantResponse)).build() + ); + + // Create namespace map with proper String types + Map namespaceMap = new java.util.HashMap<>(); + namespaceMap.put("session_id", sessionId); + namespaceMap.put("user_id", ownerId); + + Map metadataMap = new java.util.HashMap<>(); + if (traceNum != null) { + metadataMap.put("trace_num", traceNum.toString()); + } + if (action != null) { + metadataMap.put("action", action); + } + if (parentId != null) { + metadataMap.put("parent_id", parentId); + } + + // Check if memory container has LLM ID configured to determine infer value + hasLlmIdConfigured(ActionListener.wrap(hasLlmId -> { + MLAddMemoriesInput input = MLAddMemoriesInput + .builder() + .memoryContainerId(memoryContainerId) + .messages(messages) + .namespace(namespaceMap) + .metadata(metadataMap) + .ownerId(ownerId) + .infer(hasLlmId) // Use dynamic infer based on LLM ID presence + .build(); + + MLAddMemoriesRequest request = new MLAddMemoriesRequest(input); + + client.execute(MLAddMemoriesAction.INSTANCE, request, ActionListener.wrap(addResponse -> { + log + .info( + "AgenticMemoryAdapter.createNewInteraction: Created interaction with ID: {}, sessionId: {}, action: {}, infer: {}", + addResponse.getWorkingMemoryId(), + addResponse.getSessionId(), + action, + hasLlmId + ); + listener.onResponse(addResponse.getWorkingMemoryId()); + }, listener::onFailure)); + }, ex -> { + log.warn("Failed to check LLM ID configuration for interaction, proceeding with infer=false", ex); + // Fallback to infer=false if we can't determine LLM ID status + MLAddMemoriesInput input = MLAddMemoriesInput + .builder() + .memoryContainerId(memoryContainerId) + .messages(messages) + .namespace(namespaceMap) + .metadata(metadataMap) + .ownerId(ownerId) + .infer(false) + .build(); + + MLAddMemoriesRequest request = new MLAddMemoriesRequest(input); + + client.execute(MLAddMemoriesAction.INSTANCE, request, ActionListener.wrap(addResponse -> { + log + .info( + "AgenticMemoryAdapter.createNewInteraction: Created interaction with ID: {}, sessionId: {}, action: {}, infer: false (fallback)", + addResponse.getWorkingMemoryId(), + addResponse.getSessionId(), + action + ); + listener.onResponse(addResponse.getWorkingMemoryId()); + }, listener::onFailure)); + })); + } + + /** + * Save trace data as structured tool invocation information in working memory. + * + *

This method stores detailed information about tool executions, including inputs, + * outputs, and contextual metadata. The data is stored with appropriate tags and + * namespace information for later retrieval and analysis.

+ * + *

Important: This method always uses {@code infer=false} to prevent + * LLM-based long-term memory extraction from tool traces. Tool execution data is already + * structured and queryable, and extracting facts from intermediate steps would create + * fragmented, duplicate long-term memories. Semantic extraction happens only on final + * conversation interactions via {@link #saveInteraction}.

+ * + * @param toolName Name of the tool that was executed (required, non-empty) + * @param toolInput Input parameters passed to the tool (nullable, defaults to empty string) + * @param toolOutput Output/response from the tool execution (nullable, defaults to empty string) + * @param parentMemoryId Parent memory ID to associate this trace with (nullable) + * @param traceNum Trace sequence number for ordering (nullable) + * @param action Action/origin identifier for categorization (nullable) + * @param listener ActionListener to handle the response with the created memory ID + * @throws IllegalArgumentException if toolName is null/empty or listener is null + * @see #saveInteraction for conversational data that triggers long-term memory extraction + */ + @Override + public void saveTraceData( + String toolName, + String toolInput, + String toolOutput, + String parentMemoryId, + Integer traceNum, + String action, + ActionListener listener + ) { + if (toolName == null || toolName.trim().isEmpty()) { + throw new IllegalArgumentException("Tool name cannot be null or empty"); + } + if (listener == null) { + throw new IllegalArgumentException("ActionListener cannot be null"); + } + final String finalToolName = toolName; + + // Create tool invocation structured data + Map toolInvocation = new HashMap<>(); + toolInvocation.put("tool_name", finalToolName); + toolInvocation.put("tool_input", toolInput != null ? toolInput : ""); + toolInvocation.put("tool_output", toolOutput != null ? toolOutput : ""); + + Map structuredData = new HashMap<>(); + structuredData.put("tool_invocations", List.of(toolInvocation)); + + // Create namespace map + Map namespaceMap = new HashMap<>(); + namespaceMap.put("session_id", sessionId); + namespaceMap.put("user_id", ownerId); + + // Create metadata map + Map metadataMap = new HashMap<>(); + metadataMap.put("status", "checkpoint"); + if (traceNum != null) { + metadataMap.put("trace_num", traceNum.toString()); + } + if (action != null) { + metadataMap.put("action", action); + } + if (parentMemoryId != null) { + metadataMap.put("parent_memory_id", parentMemoryId); + } + + // Create tags map with trace-specific information + Map tagsMap = new HashMap<>(); + tagsMap.put("data_type", "trace"); + + if (action != null) { + tagsMap.put("topic", action); + } + + /* + * IMPORTANT: Tool trace data uses infer=false to prevent long-term memory extraction + * + * Rationale: + * 1. Tool traces are intermediate execution steps, not final user-facing content + * 2. Running LLM inference on tool traces would create fragmented, low-quality long-term memories + * 3. Multiple tool executions in a single conversation would generate redundant/duplicate facts + * 4. Tool trace data is already structured (tool_name, tool_input, tool_output) and queryable + * 5. Final conversation interactions (saveInteraction) will trigger proper semantic extraction + * + * Example problem if infer=true: + * User: "What's the weather in Seattle?" + * - Tool trace saved → LLM extracts: "User queried Seattle" (incomplete context) + * - Final response saved → LLM extracts: "User asked about Seattle weather" (complete context) + * Result: Duplicate/conflicting long-term memories + * + * By setting infer=false for tool traces: + * - Tool execution data remains queryable via structured data search + * - Long-term memory extraction happens only on final, contextually complete interactions + * - Cleaner, more accurate long-term memory without duplication + * - Reduced LLM inference costs and processing overhead + */ + executeTraceDataSave(structuredData, namespaceMap, metadataMap, tagsMap, false, finalToolName, action, listener); + } + + /** + * Execute the actual trace data save operation. + * + *

Note: The infer parameter is kept for potential future use cases where selective + * inference on tool traces might be needed, but currently always receives false to + * prevent duplicate long-term memory extraction.

+ * + * @param structuredData The structured data containing tool invocation information + * @param namespaceMap The namespace mapping for the memory + * @param metadataMap The metadata for the memory entry + * @param tagsMap The tags for the memory entry + * @param infer Whether to enable inference processing (currently always false for tool traces) + * @param toolName The name of the tool (for logging) + * @param action The action identifier (for logging) + * @param listener ActionListener to handle the response + */ + private void executeTraceDataSave( + Map structuredData, + Map namespaceMap, + Map metadataMap, + Map tagsMap, + boolean infer, + String toolName, + String action, + ActionListener listener + ) { + try { + MLAddMemoriesInput input = MLAddMemoriesInput + .builder() + .memoryContainerId(memoryContainerId) + .structuredData(structuredData) + .namespace(namespaceMap) + .metadata(metadataMap) + .tags(tagsMap) + .ownerId(ownerId) + .payloadType(PayloadType.DATA) + .infer(infer) + .build(); + + MLAddMemoriesRequest request = new MLAddMemoriesRequest(input); + + client.execute(MLAddMemoriesAction.INSTANCE, request, ActionListener.wrap(addResponse -> { + log + .info( + "AgenticMemoryAdapter.saveTraceData: Successfully saved trace data with ID: {}, toolName: {}, action: {}, infer: {}", + addResponse.getWorkingMemoryId(), + toolName, + action, + infer + ); + listener.onResponse(addResponse.getWorkingMemoryId()); + }, ex -> { + log + .error( + "AgenticMemoryAdapter.saveTraceData: Failed to save trace data for tool: {}, action: {}, infer: {}. Error: {}", + toolName, + action, + infer, + ex.getMessage(), + ex + ); + listener.onFailure(ex); + })); + } catch (Exception e) { + log + .error( + "AgenticMemoryAdapter.saveTraceData: Exception while building trace data save request for tool: {}, action: {}", + toolName, + action, + e + ); + listener.onFailure(e); + } + } + + /** + * Check if the memory container has an LLM ID configured for inference + * @param callback ActionListener to handle the result (true if LLM ID exists, false otherwise) + */ + private void hasLlmIdConfigured(ActionListener callback) { + MLMemoryContainerGetRequest getRequest = MLMemoryContainerGetRequest.builder().memoryContainerId(memoryContainerId).build(); + + client.execute(MLMemoryContainerGetAction.INSTANCE, getRequest, ActionListener.wrap(response -> { + boolean hasLlmId = response.getMlMemoryContainer().getConfiguration().getLlmId() != null; + log.info("Memory container {} has LLM ID configured: {}", memoryContainerId, hasLlmId); + callback.onResponse(hasLlmId); + }, ex -> { + log + .warn( + "Failed to get memory container {} configuration, defaulting infer to false. Error: {}", + memoryContainerId, + ex.getMessage(), + ex + ); + callback.onResponse(false); + })); + } + + private List> createTextContent(String text) { + return List.of(Map.of("type", "text", "text", text)); + } + + private List parseAgenticMemoryResponse(SearchResponse response) { + List chatMessages = new ArrayList<>(); + + for (SearchHit hit : response.getHits().getHits()) { + Map source = hit.getSourceAsMap(); + + // Parse working memory documents (conversational format) + if ("conversational".equals(source.get("payload_type"))) { + @SuppressWarnings("unchecked") + List> messages = (List>) source.get("messages"); + if (messages != null && messages.size() >= 2) { + // Extract user question and assistant response + String question = extractMessageText(messages.get(0)); // user message + String assistantResponse = extractMessageText(messages.get(1)); // assistant message + + if (question != null && assistantResponse != null) { + // Add user message + ChatMessage userMessage = ChatMessage + .builder() + .id(hit.getId() + "_user") + .timestamp(Instant.ofEpochMilli((Long) source.get("created_time"))) + .sessionId(getSessionIdFromNamespace(source)) + .role("user") + .content(question) + .contentType("text") + .origin("agentic_memory_working") + .metadata( + Map + .of( + "payload_type", + source.get("payload_type"), + "memory_container_id", + source.get("memory_container_id"), + "namespace", + source.get("namespace"), + "tags", + source.get("tags") + ) + ) + .build(); + chatMessages.add(userMessage); + + // Add assistant message + ChatMessage assistantMessage = ChatMessage + .builder() + .id(hit.getId() + "_assistant") + .timestamp(Instant.ofEpochMilli((Long) source.get("created_time"))) + .sessionId(getSessionIdFromNamespace(source)) + .role("assistant") + .content(assistantResponse) + .contentType("text") + .origin("agentic_memory_working") + .metadata( + Map + .of( + "payload_type", + source.get("payload_type"), + "memory_container_id", + source.get("memory_container_id"), + "namespace", + source.get("namespace"), + "tags", + source.get("tags") + ) + ) + .build(); + chatMessages.add(assistantMessage); + } + } + } + // Parse long-term memory documents (extracted facts) + else if (source.containsKey("memory") && source.containsKey("strategy_type")) { + String memory = (String) source.get("memory"); + String strategyType = (String) source.get("strategy_type"); + + // Convert extracted facts to chat message format for context + ChatMessage contextMessage = ChatMessage + .builder() + .id(hit.getId()) + .timestamp(Instant.ofEpochMilli((Long) source.get("created_time"))) + .sessionId(sessionId) // Use current session + .role("system") // System context message + .content("Context (" + strategyType + "): " + memory) // The extracted fact with context + .contentType("context") + .origin("agentic_memory_longterm") + .metadata( + Map + .of( + "strategy_type", + strategyType, + "strategy_id", + source.get("strategy_id"), + "memory_container_id", + source.get("memory_container_id") + ) + ) + .build(); + chatMessages.add(contextMessage); + } + } + + // Sort by timestamp to maintain chronological order + chatMessages.sort((a, b) -> a.getTimestamp().compareTo(b.getTimestamp())); + + return chatMessages; + } + + private String extractMessageText(Map message) { + if (message == null) + return null; + + @SuppressWarnings("unchecked") + List> content = (List>) message.get("content"); + if (content != null && !content.isEmpty()) { + Map firstContent = content.get(0); + return (String) firstContent.get("text"); + } + return null; + } + + private String getSessionIdFromNamespace(Map source) { + @SuppressWarnings("unchecked") + Map namespace = (Map) source.get("namespace"); + return namespace != null ? (String) namespace.get("session_id") : null; + } + + @Override + public void updateInteraction(String interactionId, java.util.Map updateFields, ActionListener listener) { + if (listener == null) { + throw new IllegalArgumentException("ActionListener cannot be null"); + } + if (interactionId == null || interactionId.trim().isEmpty()) { + listener.onFailure(new IllegalArgumentException("Interaction ID is required and cannot be empty")); + return; + } + if (updateFields == null || updateFields.isEmpty()) { + listener.onFailure(new IllegalArgumentException("Update fields are required and cannot be empty")); + return; + } + + try { + log + .info( + "AgenticMemoryAdapter.updateInteraction: CALLED - Updating interaction {} with fields: {} in memory container: {}", + interactionId, + updateFields.keySet(), + memoryContainerId + ); + + // Convert updateFields to the format expected by memory container API + Map updateContent = new java.util.HashMap<>(); + + // Handle the response field - this is the main field we need to update + if (updateFields.containsKey("response")) { + String response = (String) updateFields.get("response"); + String question = (String) updateFields.getOrDefault("input", ""); + + // For working memory updates, we need to provide the complete messages array + // with both user question and assistant response + List> messages = Arrays + .asList( + Map.of("role", "user", "content", createTextContent(question)), + Map.of("role", "assistant", "content", createTextContent(response)) + ); + + updateContent.put("messages", messages); + + log + .debug( + "AgenticMemoryAdapter.updateInteraction: Updating messages for interaction {} with question: '{}' and response length: {}", + interactionId, + question.length() > 50 ? question.substring(0, 50) + "..." : question, + response.length() + ); + } + + // Handle other fields that might be updated + if (updateFields.containsKey("additional_info")) { + updateContent.put("additional_info", updateFields.get("additional_info")); + } + + MLUpdateMemoryInput input = MLUpdateMemoryInput.builder().updateContent(updateContent).build(); + + MLUpdateMemoryRequest request = MLUpdateMemoryRequest + .builder() + .memoryContainerId(memoryContainerId) + .memoryType(MemoryType.WORKING) // We're updating working memory + .memoryId(interactionId) + .mlUpdateMemoryInput(input) + .build(); + + client.execute(MLUpdateMemoryAction.INSTANCE, request, ActionListener.wrap(updateResponse -> { + log + .debug( + "AgenticMemoryAdapter.updateInteraction: Successfully updated interaction {} in memory container: {}", + interactionId, + memoryContainerId + ); + listener.onResponse(null); + }, ex -> { + log + .error( + "AgenticMemoryAdapter.updateInteraction: Failed to update interaction {} in memory container {}", + interactionId, + memoryContainerId, + ex + ); + listener.onFailure(ex); + })); + + } catch (Exception e) { + log + .error( + "AgenticMemoryAdapter.updateInteraction: Exception while updating interaction {} in memory container {}", + interactionId, + memoryContainerId, + e + ); + listener.onFailure(e); + } + } + +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatHistoryTemplateEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatHistoryTemplateEngine.java new file mode 100644 index 0000000000..80743ba3c5 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatHistoryTemplateEngine.java @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import java.util.List; +import java.util.Map; + +/** + * Enhanced template system for ChatMessage-based memory types. + * Supports flexible templating with role-based formatting and metadata access. + */ +public interface ChatHistoryTemplateEngine { + /** + * Build chat history from ChatMessage list using template + * @param messages List of ChatMessage objects + * @param template Template string with placeholders + * @param context Additional context variables + * @return Formatted chat history string + */ + String buildChatHistory(List messages, String template, Map context); + + /** + * Get default template for basic chat history formatting + * @return Default template string + */ + default String getDefaultTemplate() { + return "{{#each messages}}{{role}}: {{content}}\n{{/each}}"; + } + + /** + * Get role-based template with enhanced formatting + * @return Role-based template string + */ + default String getRoleBasedTemplate() { + return """ + {{#each messages}} + {{#if (eq role 'user')}} + Human: {{content}} + {{else if (eq role 'assistant')}} + Assistant: {{content}} + {{else if (eq role 'system')}} + System: {{content}} + {{else if (eq role 'tool')}} + Tool Result: {{content}} + {{/if}} + {{#if metadata.confidence}} + (Confidence: {{metadata.confidence}}) + {{/if}} + {{/each}} + """; + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapter.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapter.java new file mode 100644 index 0000000000..88e952c806 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapter.java @@ -0,0 +1,124 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import java.util.List; + +import org.opensearch.core.action.ActionListener; + +/** + * Common interface for modern memory types supporting ChatMessage-based interactions. + * + *

This interface provides a unified abstraction for different memory backend implementations, + * enabling consistent interaction patterns across various memory storage systems. It supports + * both conversation management and detailed trace data storage for comprehensive agent behavior + * tracking.

+ * + *

Supported Memory Types:

+ *
    + *
  • Agentic Memory - Local cluster-based intelligent memory system
  • + *
  • Remote Agentic Memory - Distributed agentic memory implementation
  • + *
  • Bedrock AgentCore Memory - AWS Bedrock agent memory integration
  • + *
  • Future memory types - Extensible for additional implementations
  • + *
+ * + *

Core Capabilities:

+ *
    + *
  • Message retrieval in standardized ChatMessage format
  • + *
  • Conversation and session management
  • + *
  • Interaction persistence with metadata support
  • + *
  • Tool execution trace data storage
  • + *
  • Dynamic interaction updates
  • + *
+ * + *

Note: ConversationIndex uses a separate legacy pipeline for backward compatibility + * and is not part of this modern interface hierarchy.

+ * + * @see ChatMessage + * @see AgenticMemoryAdapter + */ +public interface ChatMemoryAdapter { + /** + * Retrieve conversation messages in ChatMessage format + * @param listener ActionListener to handle the response + */ + void getMessages(ActionListener> listener); + + /** + * Get the conversation/session identifier + * @return conversation ID or session ID + */ + String getConversationId(); + + /** + * This is the main memory container ID used to identify the memory container + * in the memory management system. + * @return + */ + String getMemoryContainerId(); + + /** + * Save interaction to memory (optional implementation) + * @param question User question + * @param response AI response + * @param parentId Parent interaction ID + * @param traceNum Trace number + * @param action Action performed + * @param listener ActionListener to handle the response + */ + default void saveInteraction( + String question, + String response, + String parentId, + Integer traceNum, + String action, + ActionListener listener + ) { + listener.onFailure(new UnsupportedOperationException("Save not implemented")); + } + + /** + * Update existing interaction with additional information + * @param interactionId Interaction ID to update + * @param updateFields Fields to update (e.g., final answer, additional info) + * @param listener ActionListener to handle the response + */ + default void updateInteraction(String interactionId, java.util.Map updateFields, ActionListener listener) { + listener.onFailure(new UnsupportedOperationException("Update interaction not implemented")); + } + + /** + * Save trace data as tool invocation data in working memory. + * + *

This method provides a standardized way to store detailed information about + * tool executions, enabling comprehensive tracking and analysis of agent behavior. + * Implementations should store this data in a structured format that supports + * later retrieval and analysis.

+ * + *

Default implementation throws UnsupportedOperationException. Memory adapters + * that support trace data storage should override this method.

+ * + * @param toolName Name of the tool that was executed (required) + * @param toolInput Input parameters passed to the tool (may be null) + * @param toolOutput Output/response from the tool execution (may be null) + * @param parentMemoryId Parent memory ID to associate this trace with (may be null) + * @param traceNum Trace sequence number for ordering (may be null) + * @param action Action/origin identifier for categorization (may be null) + * @param listener ActionListener to handle the response with created trace ID + * @throws UnsupportedOperationException if the implementation doesn't support trace data storage + */ + default void saveTraceData( + String toolName, + String toolInput, + String toolOutput, + String parentMemoryId, + Integer traceNum, + String action, + ActionListener listener + ) { + listener.onFailure(new UnsupportedOperationException("Save trace data not implemented")); + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMessage.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMessage.java new file mode 100644 index 0000000000..31dd72604d --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMessage.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import java.time.Instant; +import java.util.Map; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; + +/** + * Enhanced memory message for chat agents - designed for extensibility. + * Supports multiple memory types: Agentic, Remote Agentic, Bedrock AgentCore, etc. + * + * Design Philosophy: + * - Text-first with rich metadata (hybrid approach) + * - Extensible for future multimodal content + * - Memory-type agnostic interface + * - Role-based message support + */ +@Builder +@AllArgsConstructor +@Getter +public class ChatMessage { + private String id; + private Instant timestamp; + private String sessionId; + private String role; // "user", "assistant", "system", "tool" + private String content; // Primary text content + private String contentType; // "text", "image", "tool_result", etc. (metadata) + private String origin; // "agentic_memory", "remote_agentic", "bedrock_agentcore", etc. + private Map metadata; // Rich content details and memory-specific data +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 1594506cf4..4b44c55738 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -27,6 +27,7 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.UUID; import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchException; @@ -46,6 +47,7 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLAgentType; +import org.opensearch.ml.common.MLMemoryType; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; @@ -95,6 +97,7 @@ public class MLAgentExecutor implements Executable, SettingsChangeListener { public static final String MEMORY_ID = "memory_id"; + public static final String MEMORY_CONTAINER_ID = "memory_container_id"; public static final String QUESTION = "question"; public static final String PARENT_INTERACTION_ID = "parent_interaction_id"; public static final String REGENERATE_INTERACTION_ID = "regenerate_interaction_id"; @@ -174,194 +177,211 @@ public void execute(Input input, ActionListener listener, TransportChann if (MLIndicesHandler.doesMultiTenantIndexExist(clusterService, mlFeatureEnabledSetting.isMultiTenancyEnabled(), ML_AGENT_INDEX)) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - sdkClient - .getDataObjectAsync(getDataObjectRequest, client.threadPool().executor("opensearch_ml_general")) - .whenComplete((response, throwable) -> { - context.restore(); - log.debug("Completed Get Agent Request, Agent id:{}", agentId); - if (throwable != null) { - Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); - if (ExceptionsHelper.unwrap(cause, IndexNotFoundException.class) != null) { - log.error("Failed to get Agent index", cause); - listener.onFailure(new OpenSearchStatusException("Failed to get agent index", RestStatus.NOT_FOUND)); - } else { - log.error("Failed to get ML Agent {}", agentId, cause); - listener.onFailure(cause); - } + sdkClient.getDataObjectAsync(getDataObjectRequest).whenComplete((response, throwable) -> { + context.restore(); + log.debug("Completed Get Agent Request, Agent id:{}", agentId); + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + if (ExceptionsHelper.unwrap(cause, IndexNotFoundException.class) != null) { + log.error("Failed to get Agent index", cause); + listener.onFailure(new OpenSearchStatusException("Failed to get agent index", RestStatus.NOT_FOUND)); } else { - try { - GetResponse getAgentResponse = response.parser() == null - ? null - : GetResponse.fromXContent(response.parser()); - if (getAgentResponse != null && getAgentResponse.isExists()) { - try ( - XContentParser parser = jsonXContent - .createParser( - xContentRegistry, - LoggingDeprecationHandler.INSTANCE, - getAgentResponse.getSourceAsString() - ) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - MLAgent mlAgent = MLAgent.parse(parser); - if (isMultiTenancyEnabled && !Objects.equals(tenantId, mlAgent.getTenantId())) { - listener - .onFailure( - new OpenSearchStatusException( - "You don't have permission to access this resource", - RestStatus.FORBIDDEN - ) - ); - } - MLMemorySpec memorySpec = mlAgent.getMemory(); - String memoryId = inputDataSet.getParameters().get(MEMORY_ID); - String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID); - String regenerateInteractionId = inputDataSet.getParameters().get(REGENERATE_INTERACTION_ID); - String appType = mlAgent.getAppType(); - String question = inputDataSet.getParameters().get(QUESTION); - - if (parentInteractionId != null && regenerateInteractionId != null) { - throw new IllegalArgumentException( - "Provide either `parent_interaction_id` to update an existing interaction, or `regenerate_interaction_id` to create a new one." + log.error("Failed to get ML Agent {}", agentId, cause); + listener.onFailure(cause); + } + } else { + try { + GetResponse getAgentResponse = response.parser() == null ? null : GetResponse.fromXContent(response.parser()); + if (getAgentResponse != null && getAgentResponse.isExists()) { + try ( + XContentParser parser = jsonXContent + .createParser( + xContentRegistry, + LoggingDeprecationHandler.INSTANCE, + getAgentResponse.getSourceAsString() + ) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLAgent mlAgent = MLAgent.parse(parser); + if (isMultiTenancyEnabled && !Objects.equals(tenantId, mlAgent.getTenantId())) { + listener + .onFailure( + new OpenSearchStatusException( + "You don't have permission to access this resource", + RestStatus.FORBIDDEN + ) ); - } + } + MLMemorySpec memorySpec = mlAgent.getMemory(); + String memoryId; + if (Objects.equals(mlAgent.getMemory().getType(), MLMemoryType.CONVERSATION_INDEX.name())) { + memoryId = inputDataSet.getParameters().get(MEMORY_ID); + } else { + memoryId = inputDataSet.getParameters().get(MEMORY_CONTAINER_ID); + } - MLTask mlTask = MLTask - .builder() - .taskType(MLTaskType.AGENT_EXECUTION) - .functionName(FunctionName.AGENT) - .state(MLTaskState.CREATED) - .workerNodes(ImmutableList.of(clusterService.localNode().getId())) - .createTime(Instant.now()) - .lastUpdateTime(Instant.now()) - .async(false) - .tenantId(tenantId) - .build(); - - if (memoryId == null && regenerateInteractionId != null) { - throw new IllegalArgumentException("A memory ID must be provided to regenerate."); - } - if (memorySpec != null - && memorySpec.getType() != null - && memoryFactoryMap.containsKey(memorySpec.getType()) - && (memoryId == null || parentInteractionId == null)) { - ConversationIndexMemory.Factory conversationIndexMemoryFactory = - (ConversationIndexMemory.Factory) memoryFactoryMap.get(memorySpec.getType()); - conversationIndexMemoryFactory - .create(question, memoryId, appType, ActionListener.wrap(memory -> { - inputDataSet.getParameters().put(MEMORY_ID, memory.getConversationId()); - // get question for regenerate - if (regenerateInteractionId != null) { - log.info("Regenerate for existing interaction {}", regenerateInteractionId); - client - .execute( - GetInteractionAction.INSTANCE, - new GetInteractionRequest(regenerateInteractionId), - ActionListener.wrap(interactionRes -> { - inputDataSet - .getParameters() - .putIfAbsent(QUESTION, interactionRes.getInteraction().getInput()); - saveRootInteractionAndExecute( - listener, - memory, - inputDataSet, - mlTask, - isAsync, - outputs, - modelTensors, - mlAgent, - channel - ); - }, e -> { - log.error("Failed to get existing interaction for regeneration", e); - listener.onFailure(e); - }) + String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID); + String regenerateInteractionId = inputDataSet.getParameters().get(REGENERATE_INTERACTION_ID); + String appType = mlAgent.getAppType(); + String question = inputDataSet.getParameters().get(QUESTION); + + if (parentInteractionId != null && regenerateInteractionId != null) { + throw new IllegalArgumentException( + "Provide either `parent_interaction_id` to update an existing interaction, or `regenerate_interaction_id` to create a new one." + ); + } + + MLTask mlTask = MLTask + .builder() + .taskType(MLTaskType.AGENT_EXECUTION) + .functionName(FunctionName.AGENT) + .state(MLTaskState.CREATED) + .workerNodes(ImmutableList.of(clusterService.localNode().getId())) + .createTime(Instant.now()) + .lastUpdateTime(Instant.now()) + .async(false) + .tenantId(tenantId) + .build(); + + if (memoryId == null && regenerateInteractionId != null) { + throw new IllegalArgumentException("A memory ID must be provided to regenerate."); + } + + // NEW: Handle AGENTIC_MEMORY type before ConversationIndex logic + if (memorySpec != null && "AGENTIC_MEMORY".equals(memorySpec.getType())) { + log.debug("Detected AGENTIC_MEMORY type - routing to agentic memory handler"); + handleAgenticMemory( + inputDataSet, + mlTask, + isAsync, + outputs, + modelTensors, + mlAgent, + listener, + channel + ); + } + // EXISTING: ConversationIndex logic remains unchanged + else if (memorySpec != null + && memorySpec.getType() != null + && memoryFactoryMap.containsKey(memorySpec.getType()) + && (memoryId == null || parentInteractionId == null)) { + ConversationIndexMemory.Factory conversationIndexMemoryFactory = + (ConversationIndexMemory.Factory) memoryFactoryMap.get(memorySpec.getType()); + conversationIndexMemoryFactory.create(question, memoryId, appType, ActionListener.wrap(memory -> { + inputDataSet.getParameters().put(MEMORY_ID, memory.getConversationId()); + // get question for regenerate + if (regenerateInteractionId != null) { + log.info("Regenerate for existing interaction {}", regenerateInteractionId); + client + .execute( + GetInteractionAction.INSTANCE, + new GetInteractionRequest(regenerateInteractionId), + ActionListener.wrap(interactionRes -> { + inputDataSet + .getParameters() + .putIfAbsent(QUESTION, interactionRes.getInteraction().getInput()); + saveRootInteractionAndExecute( + listener, + memory, + inputDataSet, + mlTask, + isAsync, + outputs, + modelTensors, + mlAgent, + channel ); - } else { - saveRootInteractionAndExecute( - listener, - memory, - inputDataSet, - mlTask, - isAsync, - outputs, - modelTensors, - mlAgent, - channel - ); - } - }, ex -> { - log.error("Failed to read conversation memory", ex); - listener.onFailure(ex); - })); - } else { - // For existing conversations, create memory instance using factory - if (memorySpec != null && memorySpec.getType() != null) { - ConversationIndexMemory.Factory factory = (ConversationIndexMemory.Factory) memoryFactoryMap - .get(memorySpec.getType()); - if (factory != null) { - // memoryId exists, so create returns an object with existing memory, therefore name can - // be null - factory - .create( - null, - memoryId, - appType, - ActionListener - .wrap( - createdMemory -> executeAgent( - inputDataSet, - mlTask, - isAsync, - memoryId, - mlAgent, - outputs, - modelTensors, - listener, - createdMemory, - channel - ), - ex -> { - log.error("Failed to find memory with memory_id: {}", memoryId, ex); - listener.onFailure(ex); - } - ) - ); - return; - } + }, e -> { + log.error("Failed to get existing interaction for regeneration", e); + listener.onFailure(e); + }) + ); + } else { + saveRootInteractionAndExecute( + listener, + memory, + inputDataSet, + mlTask, + isAsync, + outputs, + modelTensors, + mlAgent, + channel + ); + } + }, ex -> { + log.error("Failed to read conversation memory", ex); + listener.onFailure(ex); + })); + } else { + // For existing conversations, create memory instance using factory + if (memorySpec != null && memorySpec.getType() != null) { + ConversationIndexMemory.Factory factory = (ConversationIndexMemory.Factory) memoryFactoryMap + .get(memorySpec.getType()); + if (factory != null) { + // memoryId exists, so create returns an object with existing memory, therefore name can + // be null + factory + .create( + null, + memoryId, + appType, + ActionListener + .wrap( + createdMemory -> executeAgent( + inputDataSet, + mlTask, + isAsync, + memoryId, + mlAgent, + outputs, + modelTensors, + listener, + createdMemory, + channel + ), + ex -> { + log.error("Failed to find memory with memory_id: {}", memoryId, ex); + listener.onFailure(ex); + } + ) + ); + return; } - executeAgent( - inputDataSet, - mlTask, - isAsync, - memoryId, - mlAgent, - outputs, - modelTensors, - listener, - null, - channel - ); } - } catch (Exception e) { - log.error("Failed to parse ml agent {}", agentId, e); - listener.onFailure(e); - } - } else { - listener - .onFailure( - new OpenSearchStatusException( - "Failed to find agent with the provided agent id: " + agentId, - RestStatus.NOT_FOUND - ) + executeAgent( + inputDataSet, + mlTask, + isAsync, + memoryId, + mlAgent, + outputs, + modelTensors, + listener, + null, + channel ); + } + } catch (Exception e) { + log.error("Failed to parse ml agent {}", agentId, e); + listener.onFailure(e); } - } catch (Exception e) { - log.error("Failed to get agent", e); - listener.onFailure(e); + } else { + listener + .onFailure( + new OpenSearchStatusException( + "Failed to find agent with the provided agent id: " + agentId, + RestStatus.NOT_FOUND + ) + ); } + } catch (Exception e) { + log.error("Failed to get agent", e); + listener.onFailure(e); } - }); + } + }); } } else { listener.onFailure(new ResourceNotFoundException("Agent index not found")); @@ -456,7 +476,7 @@ private void executeAgent( List outputs, List modelTensors, ActionListener listener, - ConversationIndexMemory memory, + Object memory, // Accept both ConversationIndexMemory and AgenticMemoryAdapter TransportChannel channel ) { String mcpConnectorConfigJSON = (mlAgent.getParameters() != null) ? mlAgent.getParameters().get(MCP_CONNECTORS_FIELD) : null; @@ -472,12 +492,23 @@ private void executeAgent( // If async is true, index ML task and return the taskID. Also add memoryID to the task if it exists if (isAsync) { Map agentResponse = new HashMap<>(); - if (memoryId != null && !memoryId.isEmpty()) { - agentResponse.put(MEMORY_ID, memoryId); - } - if (parentInteractionId != null && !parentInteractionId.isEmpty()) { - agentResponse.put(PARENT_INTERACTION_ID, parentInteractionId); + // Handle different memory types for response + if (memory instanceof AgenticMemoryAdapter) { + AgenticMemoryAdapter adapter = (AgenticMemoryAdapter) memory; + agentResponse.put(MEMORY_ID, adapter.getMemoryContainerId()); // memory_container_id + agentResponse.put("session_id", adapter.getConversationId()); // session_id + if (parentInteractionId != null && !parentInteractionId.isEmpty()) { + agentResponse.put(PARENT_INTERACTION_ID, parentInteractionId); // actual interaction ID + } + } else { + // ConversationIndex behavior (unchanged) + if (memoryId != null && !memoryId.isEmpty()) { + agentResponse.put(MEMORY_ID, memoryId); + } + if (parentInteractionId != null && !parentInteractionId.isEmpty()) { + agentResponse.put(PARENT_INTERACTION_ID, parentInteractionId); + } } mlTask.setResponse(agentResponse); mlTask.setAsync(true); @@ -535,7 +566,7 @@ private ActionListener createAgentActionListener( List modelTensors, String agentType, String parentInteractionId, - ConversationIndexMemory memory + Object memory // Accept both ConversationIndexMemory and AgenticMemoryAdapter ) { return ActionListener.wrap(output -> { if (output != null) { @@ -556,7 +587,7 @@ private ActionListener createAsyncTaskUpdater( List outputs, List modelTensors, String parentInteractionId, - ConversationIndexMemory memory + Object memory // Accept both ConversationIndexMemory and AgenticMemoryAdapter ) { String taskId = mlTask.getTaskId(); Map agentResponse = new HashMap<>(); @@ -583,6 +614,7 @@ private ActionListener createAsyncTaskUpdater( e -> log.error("Failed to update ML task {} with agent execution results", taskId) ) ); + }, ex -> { agentResponse.put(ERROR_MESSAGE, ex.getMessage()); @@ -711,23 +743,259 @@ public void indexMLTask(MLTask mlTask, ActionListener listener) { } } - private void updateInteractionWithFailure(String interactionId, ConversationIndexMemory memory, String errorMessage) { - if (interactionId != null && memory != null) { - String failureMessage = "Agent execution failed: " + errorMessage; - Map updateContent = new HashMap<>(); - updateContent.put(RESPONSE_FIELD, failureMessage); + /** + * Handle agentic memory type requests + */ + private void handleAgenticMemory( + RemoteInferenceInputDataSet inputDataSet, + MLTask mlTask, + boolean isAsync, + List outputs, + List modelTensors, + MLAgent mlAgent, + ActionListener listener, + TransportChannel channel + ) { + // Extract parameters + String memoryContainerId = inputDataSet.getParameters().get("memory_container_id"); + String sessionId = inputDataSet.getParameters().get("session_id"); + String ownerId = inputDataSet.getParameters().get("owner_id"); + String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID); + + log.debug("MLAgentExecutor: Processing AGENTIC_MEMORY request with parameters: {}", inputDataSet.getParameters().keySet()); + log + .debug( + "Extracted agentic memory parameters - memoryContainerId: {}, sessionId: {}, ownerId: {}, parentInteractionId: {}", + memoryContainerId != null ? "present" : "null", + sessionId != null ? "present" : "null", + ownerId != null ? "present" : "null", + parentInteractionId != null ? "present" : "null" + ); + + // Parameter validation + if (memoryContainerId == null) { + log + .error( + "AGENTIC_MEMORY validation failed: memory_container_id is null. Available params: {}", + inputDataSet.getParameters().keySet() + ); + listener.onFailure(new IllegalArgumentException("memory_container_id is required for agentic memory")); + return; + } + + if (ownerId == null) { + log.error("AGENTIC_MEMORY validation failed: owner_id is null. Available params: {}", inputDataSet.getParameters().keySet()); + listener.onFailure(new IllegalArgumentException("owner_id is required for agentic memory")); + return; + } + + log.debug("AGENTIC_MEMORY parameter validation successful - memoryContainerId: {}, ownerId: {}", memoryContainerId, ownerId); + + // Session management (same pattern as ConversationIndex) + boolean isNewConversation = Strings.isEmpty(sessionId) || parentInteractionId == null; + log + .debug( + "Conversation type determination - sessionId: {}, parentInteractionId: {}, isNewConversation: {}", + sessionId != null ? "present" : "null", + parentInteractionId != null ? "present" : "null", + isNewConversation + ); + + if (isNewConversation) { + if (Strings.isEmpty(sessionId)) { + sessionId = UUID.randomUUID().toString(); // NEW conversation + log.debug("Generated new agentic memory session: {}", sessionId); + } else { + log.debug("Using provided session ID for new conversation: {}", sessionId); + } + } else { + log + .debug( + "Continuing existing agentic memory conversation - sessionId: {}, parentInteractionId: {}", + sessionId, + parentInteractionId + ); + } - memory - .getMemoryManager() - .updateInteraction( + // Create AgenticMemoryAdapter + log + .debug( + "Creating AgenticMemoryAdapter with parameters - memoryContainerId: {}, sessionId: {}, ownerId: {}", + memoryContainerId, + sessionId, + ownerId + ); + try { + AgenticMemoryAdapter adapter = new AgenticMemoryAdapter(client, memoryContainerId, sessionId, ownerId); + log + .debug( + "AgenticMemoryAdapter created successfully - memoryContainerId: {}, sessionId: {}, conversationId: {}", + memoryContainerId, + sessionId, + adapter.getConversationId() + ); + + // Route to appropriate execution path + if (isNewConversation) { + // NEW conversation: create root interaction first + log + .debug( + "Execution path: NEW conversation - routing to saveRootInteractionAndExecuteAgentic for sessionId: {}", + sessionId + ); + saveRootInteractionAndExecuteAgentic( + listener, + adapter, + inputDataSet, + mlTask, + isAsync, + outputs, + modelTensors, + mlAgent, + channel + ); + } else { + // EXISTING conversation: execute directly + log + .debug( + "Execution path: EXISTING conversation - routing to executeAgent for sessionId: {}, parentInteractionId: {}", + sessionId, + parentInteractionId + ); + executeAgent( + inputDataSet, + mlTask, + isAsync, + adapter.getMemoryContainerId(), + mlAgent, + outputs, + modelTensors, + listener, + adapter, + channel + ); + } + } catch (Exception ex) { + log + .error( + "AgenticMemoryAdapter creation failed - memoryContainerId: {}, sessionId: {}, ownerId: {}, error: {}", + memoryContainerId, + sessionId, + ownerId, + ex.getMessage(), + ex + ); + listener.onFailure(ex); + } + } + + /** + * Create root interaction for new agentic memory conversations (mirrors ConversationIndex pattern for tool tracing support) + */ + private void saveRootInteractionAndExecuteAgentic( + ActionListener listener, + AgenticMemoryAdapter adapter, + RemoteInferenceInputDataSet inputDataSet, + MLTask mlTask, + boolean isAsync, + List outputs, + List modelTensors, + MLAgent mlAgent, + TransportChannel channel + ) { + String question = inputDataSet.getParameters().get(QUESTION); + + log + .debug( + "Creating root interaction for agentic memory - memoryContainerId: {}, sessionId: {}, question: {}", + adapter.getMemoryContainerId(), + adapter.getConversationId(), + question != null ? "present" : "null" + ); + + // Create root interaction with empty response (same pattern as ConversationIndex) + // This enables tool tracing and proper interaction updating + adapter.saveInteraction(question, "", null, 0, "ROOT", ActionListener.wrap(interactionId -> { + log + .info( + "Root interaction created successfully for agentic memory - interactionId: {}, memoryContainerId: {}, sessionId: {}", interactionId, - updateContent, - ActionListener - .wrap( - res -> log.info("Updated interaction {} with failure message", interactionId), - e -> log.warn("Failed to update interaction {} with failure message", interactionId, e) - ) + adapter.getMemoryContainerId(), + adapter.getConversationId() + ); + inputDataSet.getParameters().put(PARENT_INTERACTION_ID, interactionId); + + log + .debug( + "Proceeding to executeAgent with root interaction - interactionId: {}, sessionId: {}", + interactionId, + adapter.getConversationId() + ); + + executeAgent( + inputDataSet, + mlTask, + isAsync, + adapter.getMemoryContainerId(), // Use memory_container_id as memoryId for agentic memory + mlAgent, + outputs, + modelTensors, + listener, + adapter, + channel + ); + }, ex -> { + log + .error( + "Root interaction creation failed for agentic memory - memoryContainerId: {}, sessionId: {}, error: {}", + adapter.getMemoryContainerId(), + adapter.getConversationId(), + ex.getMessage(), + ex ); + listener.onFailure(ex); + })); + } + + private void updateInteractionWithFailure(String interactionId, Object memory, String errorMessage) { + if (interactionId != null && memory != null) { + if (memory instanceof ConversationIndexMemory) { + // Existing ConversationIndex error handling + ConversationIndexMemory conversationMemory = (ConversationIndexMemory) memory; + String failureMessage = "Agent execution failed: " + errorMessage; + Map updateContent = new HashMap<>(); + updateContent.put(RESPONSE_FIELD, failureMessage); + + conversationMemory + .getMemoryManager() + .updateInteraction( + interactionId, + updateContent, + ActionListener + .wrap( + res -> log.info("Updated interaction {} with failure message", interactionId), + e -> log.warn("Failed to update interaction {} with failure message", interactionId, e) + ) + ); + } else if (memory instanceof AgenticMemoryAdapter) { + // New agentic memory error handling + AgenticMemoryAdapter adapter = (AgenticMemoryAdapter) memory; + Map updateFields = new HashMap<>(); + updateFields.put("error", errorMessage); + + adapter + .updateInteraction( + interactionId, + updateFields, + ActionListener + .wrap( + res -> log.info("Updated agentic memory interaction {} with failure message", interactionId), + e -> log.warn("Failed to update agentic memory interaction {} with failure message", interactionId, e) + ) + ); + } else { + log.warn("Unknown memory type for error handling: {}", memory.getClass().getSimpleName()); + } } } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 7e1a4050bd..f22e295062 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -43,6 +43,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicInteger; @@ -57,6 +58,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.MLMemoryType; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; @@ -76,8 +78,6 @@ import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.ConversationIndexMessage; import org.opensearch.ml.engine.tools.MLModelTool; -import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; -import org.opensearch.ml.repackage.com.google.common.collect.Lists; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.transport.TransportChannel; import org.opensearch.transport.client.Client; @@ -177,78 +177,60 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener functionCalling.configure(params); } - String memoryType = mlAgent.getMemory().getType(); - String memoryId = params.get(MLAgentExecutor.MEMORY_ID); - String appType = mlAgent.getAppType(); - String title = params.get(MLAgentExecutor.QUESTION); String chatHistoryPrefix = params.getOrDefault(PROMPT_CHAT_HISTORY_PREFIX, CHAT_HISTORY_PREFIX); String chatHistoryQuestionTemplate = params.get(CHAT_HISTORY_QUESTION_TEMPLATE); String chatHistoryResponseTemplate = params.get(CHAT_HISTORY_RESPONSE_TEMPLATE); int messageHistoryLimit = getMessageHistoryLimit(params); - ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType); - conversationIndexMemoryFactory.create(title, memoryId, appType, ActionListener.wrap(memory -> { - // TODO: call runAgent directly if messageHistoryLimit == 0 - memory.getMessages(ActionListener.>wrap(r -> { - List messageList = new ArrayList<>(); - for (Interaction next : r) { - String question = next.getInput(); - String response = next.getResponse(); - // As we store the conversation with empty response first and then update when have final answer, - // filter out those in-flight requests when run in parallel - if (Strings.isNullOrEmpty(response)) { - continue; - } - messageList - .add( - ConversationIndexMessage - .conversationIndexMessageBuilder() - .sessionId(memory.getConversationId()) - .question(question) - .response(response) - .build() - ); - } - if (!messageList.isEmpty()) { - if (chatHistoryQuestionTemplate == null) { - StringBuilder chatHistoryBuilder = new StringBuilder(); - chatHistoryBuilder.append(chatHistoryPrefix); - for (Message message : messageList) { - chatHistoryBuilder.append(message.toString()).append("\n"); - } - params.put(CHAT_HISTORY, chatHistoryBuilder.toString()); - - // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added to input params to validate - inputParams.put(CHAT_HISTORY, chatHistoryBuilder.toString()); - } else { - List chatHistory = new ArrayList<>(); - for (Message message : messageList) { - Map messageParams = new HashMap<>(); - messageParams.put("question", processTextDoc(((ConversationIndexMessage) message).getQuestion())); - - StringSubstitutor substitutor = new StringSubstitutor(messageParams, CHAT_HISTORY_MESSAGE_PREFIX, "}"); - String chatQuestionMessage = substitutor.replace(chatHistoryQuestionTemplate); - chatHistory.add(chatQuestionMessage); - - messageParams.clear(); - messageParams.put("response", processTextDoc(((ConversationIndexMessage) message).getResponse())); - substitutor = new StringSubstitutor(messageParams, CHAT_HISTORY_MESSAGE_PREFIX, "}"); - String chatResponseMessage = substitutor.replace(chatHistoryResponseTemplate); - chatHistory.add(chatResponseMessage); - } - params.put(CHAT_HISTORY, String.join(", ", chatHistory) + ", "); - params.put(NEW_CHAT_HISTORY, String.join(", ", chatHistory) + ", "); - - // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added to input params to validate - inputParams.put(CHAT_HISTORY, String.join(", ", chatHistory) + ", "); - } - } + createMemoryAdapter(mlAgent, params, ActionListener.wrap(memoryOrAdapter -> { + log.debug("createMemoryAdapter callback: memoryOrAdapter type = {}", memoryOrAdapter.getClass().getSimpleName()); + + if (memoryOrAdapter instanceof ConversationIndexMemory) { + // Existing ConversationIndex flow - zero changes + ConversationIndexMemory memory = (ConversationIndexMemory) memoryOrAdapter; + memory.getMessages(ActionListener.>wrap(r -> { + processLegacyInteractions( + r, + memory.getConversationId(), + memory, + mlAgent, + params, + inputParams, + chatHistoryPrefix, + chatHistoryQuestionTemplate, + chatHistoryResponseTemplate, + functionCalling, + listener + ); + }, e -> { + log.error("Failed to get chat history", e); + listener.onFailure(e); + }), messageHistoryLimit); + + } else if (memoryOrAdapter instanceof ChatMemoryAdapter) { + // Modern Pipeline - NEW ChatMessage processing + log.debug("Routing to modern ChatMemoryAdapter pipeline"); + ChatMemoryAdapter adapter = (ChatMemoryAdapter) memoryOrAdapter; + adapter.getMessages(ActionListener.wrap(chatMessages -> { + // Use NEW ChatMessage-based processing (no conversion to Interaction) + processModernChatMessages( + chatMessages, + adapter.getConversationId(), + adapter, // Add ChatMemoryAdapter parameter + mlAgent, + params, + inputParams, + functionCalling, + listener + ); + }, e -> { + log.error("Failed to get chat history from modern memory adapter", e); + listener.onFailure(e); + })); - runAgent(mlAgent, params, listener, memory, memory.getConversationId(), functionCalling); - }, e -> { - log.error("Failed to get chat history", e); - listener.onFailure(e); - }), messageHistoryLimit); + } else { + listener.onFailure(new IllegalArgumentException("Unsupported memory type: " + memoryOrAdapter.getClass())); + } }, listener::onFailure)); } @@ -256,7 +238,7 @@ private void runAgent( MLAgent mlAgent, Map params, ActionListener listener, - Memory memory, + Object memoryOrSessionId, // Can be Memory object or String sessionId String sessionId, FunctionCalling functionCalling ) { @@ -267,7 +249,71 @@ private void runAgent( Map tools = new HashMap<>(); Map toolSpecMap = new HashMap<>(); createTools(toolFactories, params, allToolSpecs, tools, toolSpecMap, mlAgent); - runReAct(mlAgent.getLlm(), tools, toolSpecMap, params, memory, sessionId, mlAgent.getTenantId(), listener, functionCalling); + + // Route to correct runReAct method based on memory type + if (memoryOrSessionId instanceof Memory) { + // Legacy ConversationIndex path + Memory actualMemory = (Memory) memoryOrSessionId; + runReAct( + mlAgent.getLlm(), + tools, + toolSpecMap, + params, + actualMemory, + sessionId, + mlAgent.getTenantId(), + listener, + functionCalling + ); + } else { + // Modern agentic memory path - create ChatMemoryAdapter + String memoryContainerId = params.get("memory_container_id"); + String ownerId = params.get("owner_id"); + + log + .debug( + "Agentic memory path: memoryContainerId={}, ownerId={}, sessionId={}, allParams={}", + memoryContainerId, + ownerId, + sessionId, + params.keySet() + ); + + if (memoryContainerId != null && ownerId != null) { + AgenticMemoryAdapter chatMemoryAdapter = new AgenticMemoryAdapter(client, memoryContainerId, sessionId, ownerId); + runReAct( + mlAgent.getLlm(), + tools, + toolSpecMap, + params, + chatMemoryAdapter, + sessionId, + mlAgent.getTenantId(), + listener, + functionCalling + ); + } else { + // Missing required parameters for agentic memory + log + .error( + "Agentic memory requested but missing required parameters. " + + "memory_container_id: {}, owner_id: {}, available params: {}", + memoryContainerId, + ownerId, + params.keySet() + ); + listener + .onFailure( + new IllegalArgumentException( + "Agentic memory requires 'memory_container_id' and 'owner_id' parameters. " + + "Provided: memory_container_id=" + + memoryContainerId + + ", owner_id=" + + ownerId + ) + ); + } + } }; // Fetch MCP tools and handle both success and failure cases @@ -387,17 +433,32 @@ private void runReAct( .build() ); - saveTraceData( - conversationIndexMemory, - memory.getType(), - question, - thoughtResponse, - sessionId, - traceDisabled, - parentInteractionId, - traceNumber, - "LLM" - ); + // Save trace data using appropriate memory adapter + if (memory instanceof ConversationIndexMemory) { + saveTraceData( + (ConversationIndexMemory) memory, + memory.getType(), + question, + thoughtResponse, + sessionId, + traceDisabled, + parentInteractionId, + traceNumber, + "LLM" + ); + } else if (memory instanceof ChatMemoryAdapter) { + saveTraceData( + (ChatMemoryAdapter) memory, + memory.getType(), + question, + thoughtResponse, + sessionId, + traceDisabled, + parentInteractionId, + traceNumber, + "LLM" + ); + } if (nextStepListener == null) { handleMaxIterationsReached( @@ -466,17 +527,32 @@ private void runReAct( ); scratchpadBuilder.append(toolResponse).append("\n\n"); - saveTraceData( - conversationIndexMemory, - "ReAct", - lastActionInput.get(), - outputToOutputString(filteredOutput), - sessionId, - traceDisabled, - parentInteractionId, - traceNumber, - lastAction.get() - ); + // Save trace data using appropriate memory adapter + if (memory instanceof ConversationIndexMemory) { + saveTraceData( + (ConversationIndexMemory) memory, + "ReAct", + lastActionInput.get(), + outputToOutputString(filteredOutput), + sessionId, + traceDisabled, + parentInteractionId, + traceNumber, + lastAction.get() + ); + } else if (memory instanceof ChatMemoryAdapter) { + saveTraceData( + (ChatMemoryAdapter) memory, + "ReAct", + lastActionInput.get(), + outputToOutputString(filteredOutput), + sessionId, + traceDisabled, + parentInteractionId, + traceNumber, + lastAction.get() + ); + } StringSubstitutor substitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder), "${parameters.", "}"); newPrompt.set(substitutor.replace(finalPrompt)); @@ -581,7 +657,7 @@ private static void addToolOutputToAddtionalInfo( List list = (List) additionalInfo.get(toolOutputKey); list.add(outputString); } else { - additionalInfo.put(toolOutputKey, Lists.newArrayList(outputString)); + additionalInfo.put(toolOutputKey, new ArrayList<>(Collections.singletonList(outputString))); } } } @@ -705,6 +781,45 @@ public static void saveTraceData( } } + /** + * Overloaded saveTraceData method for ChatMemoryAdapter + */ + public static void saveTraceData( + ChatMemoryAdapter chatMemoryAdapter, + String memoryType, + String question, + String thoughtResponse, + String sessionId, + boolean traceDisabled, + String parentInteractionId, + AtomicInteger traceNumber, + String origin + ) { + if (chatMemoryAdapter != null && !traceDisabled) { + // Save trace data as tool invocation data in working memory + chatMemoryAdapter + .saveTraceData( + origin, // toolName (LLM, ReAct, etc.) + question, // toolInput + thoughtResponse, // toolOutput + parentInteractionId, // parentMemoryId + traceNumber.addAndGet(1), // traceNum + origin, // action + ActionListener + .wrap( + r -> log + .debug( + "Successfully saved trace data via ChatMemoryAdapter for session: {}, origin: {}", + sessionId, + origin + ), + e -> log + .warn("Failed to save trace data via ChatMemoryAdapter for session: {}, origin: {}", sessionId, origin, e) + ) + ); + } + } + private void sendFinalAnswer( String sessionId, ActionListener listener, @@ -759,6 +874,51 @@ private void sendFinalAnswer( } } + /** + * Overloaded sendFinalAnswer method for modern ChatMemoryAdapter pipeline + */ + private void sendFinalAnswer( + String sessionId, + ActionListener listener, + String question, + String parentInteractionId, + boolean verbose, + boolean traceDisabled, + List cotModelTensors, + ChatMemoryAdapter chatMemoryAdapter, // Modern parameter + AtomicInteger traceNumber, + Map additionalInfo, + String finalAnswer + ) { + // Send completion chunk for streaming + streamingWrapper.sendCompletionChunk(sessionId, parentInteractionId); + + if (chatMemoryAdapter != null) { + String copyOfFinalAnswer = finalAnswer; + ActionListener saveTraceListener = ActionListener.wrap(r -> { + // For ChatMemoryAdapter, we don't have separate updateInteraction + // The saveInteraction method handles the complete saving + streamingWrapper + .sendFinalResponse( + sessionId, + listener, + parentInteractionId, + verbose, + cotModelTensors, + additionalInfo, + copyOfFinalAnswer + ); + }, listener::onFailure); + + // Use ChatMemoryAdapter's saveInteraction method + chatMemoryAdapter + .saveInteraction(question, finalAnswer, parentInteractionId, traceNumber.addAndGet(1), "LLM", saveTraceListener); + } else { + streamingWrapper + .sendFinalResponse(sessionId, listener, parentInteractionId, verbose, cotModelTensors, additionalInfo, finalAnswer); + } + } + public static List createModelTensors(String sessionId, String parentInteractionId) { List cotModelTensors = new ArrayList<>(); @@ -863,7 +1023,7 @@ public static void returnFinalResponse( ModelTensor .builder() .name("response") - .dataAsMap(ImmutableMap.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)) + .dataAsMap(Map.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)) .build() ) ); @@ -908,6 +1068,305 @@ private void handleMaxIterationsReached( cleanUpResource(tools); } + /** + * Overloaded handleMaxIterationsReached method for ChatMemoryAdapter + */ + private void handleMaxIterationsReached( + String sessionId, + ActionListener listener, + String question, + String parentInteractionId, + boolean verbose, + boolean traceDisabled, + List traceTensors, + ChatMemoryAdapter chatMemoryAdapter, // Modern parameter + AtomicInteger traceNumber, + Map additionalInfo, + AtomicReference lastThought, + int maxIterations, + Map tools + ) { + String incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) + ? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) + : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); + sendFinalAnswer( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + chatMemoryAdapter, // Use ChatMemoryAdapter instead of ConversationIndexMemory + traceNumber, + additionalInfo, + incompleteResponse + ); + cleanUpResource(tools); + } + + /** + * Complete runReAct method for modern ChatMemoryAdapter pipeline + * This method handles the new memory types (agentic, remote, bedrock, etc.) + * + * Full implementation with complete ReAct loop, tool execution, trace saving, and streaming. + */ + private void runReAct( + LLMSpec llm, + Map tools, + Map toolSpecMap, + Map parameters, + ChatMemoryAdapter chatMemoryAdapter, // Modern parameter + String sessionId, + String tenantId, + ActionListener listener, + FunctionCalling functionCalling + ) { + Map tmpParameters = constructLLMParams(llm, parameters); + String prompt = constructLLMPrompt(tools, tmpParameters); + tmpParameters.put(PROMPT, prompt); + final String finalPrompt = prompt; + + String question = tmpParameters.get(MLAgentExecutor.QUESTION); + String parentInteractionId = tmpParameters.get(MLAgentExecutor.PARENT_INTERACTION_ID); + boolean verbose = Boolean.parseBoolean(tmpParameters.getOrDefault(VERBOSE, "false")); + boolean traceDisabled = tmpParameters.containsKey(DISABLE_TRACE) && Boolean.parseBoolean(tmpParameters.get(DISABLE_TRACE)); + + // Trace number + AtomicInteger traceNumber = new AtomicInteger(0); + + AtomicReference> lastLlmListener = new AtomicReference<>(); + AtomicReference lastThought = new AtomicReference<>(); + AtomicReference lastAction = new AtomicReference<>(); + AtomicReference lastActionInput = new AtomicReference<>(); + AtomicReference lastToolSelectionResponse = new AtomicReference<>(); + Map additionalInfo = new ConcurrentHashMap<>(); + Map lastToolParams = new ConcurrentHashMap<>(); + + StepListener firstListener = new StepListener(); + lastLlmListener.set(firstListener); + StepListener lastStepListener = firstListener; + + StringBuilder scratchpadBuilder = new StringBuilder(); + List interactions = new CopyOnWriteArrayList<>(); + + StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); + AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); + tmpParameters.put(PROMPT, newPrompt.get()); + List traceTensors = createModelTensors(sessionId, parentInteractionId); + int maxIterations = Integer.parseInt(tmpParameters.getOrDefault(MAX_ITERATION, DEFAULT_MAX_ITERATIONS)); + + for (int i = 0; i < maxIterations; i++) { + int finalI = i; + StepListener nextStepListener = (i == maxIterations - 1) ? null : new StepListener<>(); + + lastStepListener.whenComplete(output -> { + StringBuilder sessionMsgAnswerBuilder = new StringBuilder(); + if (finalI % 2 == 0) { + MLTaskResponse llmResponse = (MLTaskResponse) output; + ModelTensorOutput tmpModelTensorOutput = (ModelTensorOutput) llmResponse.getOutput(); + List llmResponsePatterns = gson.fromJson(tmpParameters.get("llm_response_pattern"), List.class); + Map modelOutput = parseLLMOutput( + parameters, + tmpModelTensorOutput, + llmResponsePatterns, + tools.keySet(), + interactions, + functionCalling + ); + + streamingWrapper.fixInteractionRole(interactions); + String thought = String.valueOf(modelOutput.get(THOUGHT)); + String toolCallId = String.valueOf(modelOutput.get("tool_call_id")); + String action = String.valueOf(modelOutput.get(ACTION)); + String actionInput = String.valueOf(modelOutput.get(ACTION_INPUT)); + String thoughtResponse = modelOutput.get(THOUGHT_RESPONSE); + String finalAnswer = modelOutput.get(FINAL_ANSWER); + + if (finalAnswer != null) { + finalAnswer = finalAnswer.trim(); + sendFinalAnswer( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + chatMemoryAdapter, // Use ChatMemoryAdapter instead of ConversationIndexMemory + traceNumber, + additionalInfo, + finalAnswer + ); + cleanUpResource(tools); + return; + } + + sessionMsgAnswerBuilder.append(thought); + lastThought.set(thought); + lastAction.set(action); + lastActionInput.set(actionInput); + lastToolSelectionResponse.set(thoughtResponse); + + traceTensors + .add( + ModelTensors + .builder() + .mlModelTensors(List.of(ModelTensor.builder().name("response").result(thoughtResponse).build())) + .build() + ); + + // Save trace data using ChatMemoryAdapter + saveTraceData( + chatMemoryAdapter, + "ChatMemoryAdapter", // Memory type for modern pipeline + question, + thoughtResponse, + sessionId, + traceDisabled, + parentInteractionId, + traceNumber, + "LLM" + ); + + if (nextStepListener == null) { + handleMaxIterationsReached( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + chatMemoryAdapter, // Use ChatMemoryAdapter + traceNumber, + additionalInfo, + lastThought, + maxIterations, + tools + ); + return; + } + + if (tools.containsKey(action)) { + Map toolParams = constructToolParams( + tools, + toolSpecMap, + question, + lastActionInput, + action, + actionInput + ); + lastToolParams.clear(); + lastToolParams.putAll(toolParams); + runTool( + tools, + toolSpecMap, + tmpParameters, + (ActionListener) nextStepListener, + action, + actionInput, + toolParams, + interactions, + toolCallId, + functionCalling + ); + + } else { + String res = String.format(Locale.ROOT, "Failed to run the tool %s which is unsupported.", action); + StringSubstitutor substitutor = new StringSubstitutor( + Map.of(SCRATCHPAD, scratchpadBuilder.toString()), + "${parameters.", + "}" + ); + newPrompt.set(substitutor.replace(finalPrompt)); + tmpParameters.put(PROMPT, newPrompt.get()); + ((ActionListener) nextStepListener).onResponse(res); + } + } else { + Object filteredOutput = filterToolOutput(lastToolParams, output); + addToolOutputToAddtionalInfo(toolSpecMap, lastAction, additionalInfo, filteredOutput); + + String toolResponse = constructToolResponse( + tmpParameters, + lastAction, + lastActionInput, + lastToolSelectionResponse, + filteredOutput + ); + scratchpadBuilder.append(toolResponse).append("\n\n"); + + // Save trace data for tool response using ChatMemoryAdapter + saveTraceData( + chatMemoryAdapter, + "ReAct", + lastActionInput.get(), + outputToOutputString(filteredOutput), + sessionId, + traceDisabled, + parentInteractionId, + traceNumber, + lastAction.get() + ); + + StringSubstitutor substitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder), "${parameters.", "}"); + newPrompt.set(substitutor.replace(finalPrompt)); + tmpParameters.put(PROMPT, newPrompt.get()); + if (!interactions.isEmpty()) { + tmpParameters.put(INTERACTIONS, ", " + String.join(", ", interactions)); + } + + sessionMsgAnswerBuilder.append(outputToOutputString(filteredOutput)); + streamingWrapper.sendToolResponse(outputToOutputString(output), sessionId, parentInteractionId); + traceTensors + .add( + ModelTensors + .builder() + .mlModelTensors( + Collections + .singletonList( + ModelTensor.builder().name("response").result(sessionMsgAnswerBuilder.toString()).build() + ) + ) + .build() + ); + + if (finalI == maxIterations - 1) { + handleMaxIterationsReached( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + chatMemoryAdapter, // Use ChatMemoryAdapter + traceNumber, + additionalInfo, + lastThought, + maxIterations, + tools + ); + return; + } + + if (nextStepListener != null) { + ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId); + streamingWrapper.executeRequest(request, (ActionListener) nextStepListener); + } + } + }, listener::onFailure); + + if (i == 0) { + ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId); + streamingWrapper.executeRequest(request, firstListener); + } + if (nextStepListener != null) { + lastStepListener = nextStepListener; + } + } + } + private void saveMessage( ConversationIndexMemory memory, String question, @@ -933,4 +1392,171 @@ private void saveMessage( memory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM", listener); } } + + /** + * Process modern ChatMessage format and build chat history using enhanced templates + */ + private void processModernChatMessages( + List chatMessages, + String sessionId, + ChatMemoryAdapter chatMemoryAdapter, // Add ChatMemoryAdapter parameter + MLAgent mlAgent, + Map params, + Map inputParams, + FunctionCalling functionCalling, + ActionListener listener + ) { + // Use new enhanced template system for ChatMessage + SimpleChatHistoryTemplateEngine templateEngine = new SimpleChatHistoryTemplateEngine(); + + // Filter out empty content messages (in-flight requests) + List validMessages = chatMessages + .stream() + .filter(msg -> msg.getContent() != null && !msg.getContent().trim().isEmpty()) + .toList(); + + if (!validMessages.isEmpty()) { + // Build chat history using enhanced template system + String chatHistory = templateEngine.buildSimpleChatHistory(validMessages); + params.put(CHAT_HISTORY, chatHistory); + inputParams.put(CHAT_HISTORY, chatHistory); + } + + // Run agent with modern processing (no Memory object needed) + runAgent(mlAgent, params, listener, sessionId, sessionId, functionCalling); + } + + /** + * Process legacy interactions (ConversationIndex) and build chat history, then run the agent + */ + private void processLegacyInteractions( + List interactions, + String sessionId, + ConversationIndexMemory memory, + MLAgent mlAgent, + Map params, + Map inputParams, + String chatHistoryPrefix, + String chatHistoryQuestionTemplate, + String chatHistoryResponseTemplate, + FunctionCalling functionCalling, + ActionListener listener + ) { + List messageList = new ArrayList<>(); + for (Interaction next : interactions) { + String question = next.getInput(); + String response = next.getResponse(); + // As we store the conversation with empty response first and then update when have final answer, + // filter out those in-flight requests when run in parallel + if (Strings.isNullOrEmpty(response)) { + continue; + } + messageList + .add( + ConversationIndexMessage + .conversationIndexMessageBuilder() + .sessionId(sessionId) + .question(question) + .response(response) + .build() + ); + } + + if (!messageList.isEmpty()) { + if (chatHistoryQuestionTemplate == null) { + StringBuilder chatHistoryBuilder = new StringBuilder(); + chatHistoryBuilder.append(chatHistoryPrefix); + for (Message message : messageList) { + chatHistoryBuilder.append(message.toString()).append("\n"); + } + params.put(CHAT_HISTORY, chatHistoryBuilder.toString()); + + // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added to input params to validate + inputParams.put(CHAT_HISTORY, chatHistoryBuilder.toString()); + } else { + List chatHistory = new ArrayList<>(); + for (Message message : messageList) { + Map messageParams = new HashMap<>(); + messageParams.put("question", processTextDoc(((ConversationIndexMessage) message).getQuestion())); + + StringSubstitutor substitutor = new StringSubstitutor(messageParams, CHAT_HISTORY_MESSAGE_PREFIX, "}"); + String chatQuestionMessage = substitutor.replace(chatHistoryQuestionTemplate); + chatHistory.add(chatQuestionMessage); + + messageParams.clear(); + messageParams.put("response", processTextDoc(((ConversationIndexMessage) message).getResponse())); + substitutor = new StringSubstitutor(messageParams, CHAT_HISTORY_MESSAGE_PREFIX, "}"); + String chatResponseMessage = substitutor.replace(chatHistoryResponseTemplate); + chatHistory.add(chatResponseMessage); + } + params.put(CHAT_HISTORY, String.join(", ", chatHistory) + ", "); + params.put(NEW_CHAT_HISTORY, String.join(", ", chatHistory) + ", "); + + // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added to input params to validate + inputParams.put(CHAT_HISTORY, String.join(", ", chatHistory) + ", "); + } + } + + runAgent(mlAgent, params, listener, memory != null ? memory : sessionId, sessionId, functionCalling); + } + + /** + * Create appropriate memory adapter based on memory type + */ + private void createMemoryAdapter(MLAgent mlAgent, Map params, ActionListener listener) { + String memoryType = mlAgent.getMemory().getType(); + MLMemoryType type = MLMemoryType.from(memoryType); + + log.debug("MLChatAgentRunner.createMemoryAdapter: memoryType={}, params={}", memoryType, params.keySet()); + + switch (type) { + case CONVERSATION_INDEX: + // Keep existing flow - no adapter needed (zero risk approach) + ConversationIndexMemory.Factory factory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType); + String title = params.get(MLAgentExecutor.QUESTION); + String memoryId = params.get(MLAgentExecutor.MEMORY_ID); + String appType = mlAgent.getAppType(); + + factory.create(title, memoryId, appType, ActionListener.wrap(conversationMemory -> { + // Return ConversationIndexMemory directly - no conversion needed + listener.onResponse(conversationMemory); + }, listener::onFailure)); + break; + + case AGENTIC_MEMORY: + // New agentic memory path + String memoryContainerId = params.get("memory_container_id"); + String sessionId = params.get("session_id"); + String ownerId = params.get("owner_id"); // From user context + + log.debug("AGENTIC_MEMORY path: memoryContainerId={}, sessionId={}, ownerId={}", memoryContainerId, sessionId, ownerId); + + // Validate required parameters + if (memoryContainerId == null) { + log.error("AGENTIC_MEMORY validation failed: memory_container_id is null. Available params: {}", params.keySet()); + listener.onFailure(new IllegalArgumentException("memory_container_id is required for agentic memory")); + return; + } + + // Session management: same pattern as ConversationIndex + if (Strings.isEmpty(sessionId)) { + // CREATE NEW: Generate new session ID if not provided + sessionId = UUID.randomUUID().toString(); + log.debug("Created new agentic memory session: {}", sessionId); + } + // USE EXISTING: If sessionId provided, use it directly + + AgenticMemoryAdapter adapter = new AgenticMemoryAdapter(client, memoryContainerId, sessionId, ownerId); + log.debug("Created AgenticMemoryAdapter successfully: memoryContainerId={}, sessionId={}", memoryContainerId, sessionId); + listener.onResponse(adapter); + break; + + default: + // Future memory types will be added here: + // - REMOTE_AGENTIC_MEMORY: RemoteAgenticMemoryAdapter (similar format, different location) + // - BEDROCK_AGENTCORE: BedrockAgentCoreMemoryAdapter (format adapted in adapter) + // All future types will use modern ChatMessage pipeline + listener.onFailure(new IllegalArgumentException("Unsupported memory type: " + memoryType)); + } + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/SimpleChatHistoryTemplateEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/SimpleChatHistoryTemplateEngine.java new file mode 100644 index 0000000000..33399208e4 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/SimpleChatHistoryTemplateEngine.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import java.util.List; +import java.util.Map; + +/** + * Simple implementation of ChatHistoryTemplateEngine. + * Provides basic template functionality for ChatMessage formatting. + * + * This is a simplified implementation that supports: + * - Role-based message formatting + * - Basic placeholder replacement + * - Content type awareness + * + * Future versions can implement more advanced template engines (Handlebars, etc.) + */ +public class SimpleChatHistoryTemplateEngine implements ChatHistoryTemplateEngine { + + @Override + public String buildChatHistory(List messages, String template, Map context) { + if (messages == null || messages.isEmpty()) { + return ""; + } + + // For now, use a simple approach - build chat history with role-based formatting + StringBuilder chatHistory = new StringBuilder(); + + for (ChatMessage message : messages) { + String formattedMessage = formatMessage(message); + chatHistory.append(formattedMessage).append("\n"); + } + + return chatHistory.toString().trim(); + } + + /** + * Format a single ChatMessage based on its role and content type + */ + private String formatMessage(ChatMessage message) { + String role = message.getRole(); + String content = message.getContent(); + String contentType = message.getContentType(); + + // Role-based formatting + String prefix = switch (role) { + case "user" -> "Human: "; + case "assistant" -> "Assistant: "; + case "system" -> "System: "; + case "tool" -> "Tool Result: "; + default -> role + ": "; + }; + + // Content type awareness + String formattedContent = content; + if ("image".equals(contentType)) { + formattedContent = "[Image: " + content + "]"; + } else if ("tool_result".equals(contentType)) { + Map metadata = message.getMetadata(); + if (metadata != null && metadata.containsKey("tool_name")) { + formattedContent = "Tool " + metadata.get("tool_name") + ": " + content; + } + } else if ("context".equals(contentType)) { + // Context messages (like from long-term memory) get special formatting + formattedContent = "[Context] " + content; + } + + return prefix + formattedContent; + } + + /** + * Build chat history using default simple formatting + */ + public String buildSimpleChatHistory(List messages) { + return buildChatHistory(messages, getDefaultTemplate(), Map.of()); + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ChatMemoryAdapter.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ChatMemoryAdapter.java new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ChatMessage.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ChatMessage.java new file mode 100644 index 0000000000..0111642129 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ChatMessage.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.memory; + +import java.util.Map; + +/** + * Interface for chat messages in the unified memory system. + * Provides a common abstraction for messages across different memory implementations. + */ +public interface ChatMessage { + /** + * Get the role of the message sender + * @return role such as "user", "assistant", "system" + */ + String getRole(); + + /** + * Get the content of the message + * @return message content + */ + String getContent(); + + /** + * Get additional metadata associated with the message + * @return metadata map + */ + Map getMetadata(); +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapterTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapterTest.java new file mode 100644 index 0000000000..f69f7d71e2 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapterTest.java @@ -0,0 +1,167 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.core.action.ActionListener; +import org.opensearch.transport.client.Client; + +/** + * Unit tests for AgenticMemoryAdapter. + */ +public class AgenticMemoryAdapterTest { + + @Mock + private Client client; + + private AgenticMemoryAdapter adapter; + private final String memoryContainerId = "test-memory-container"; + private final String sessionId = "test-session"; + private final String ownerId = "test-owner"; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + adapter = new AgenticMemoryAdapter(client, memoryContainerId, sessionId, ownerId); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorWithNullClient() { + new AgenticMemoryAdapter(null, memoryContainerId, sessionId, ownerId); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorWithNullMemoryContainerId() { + new AgenticMemoryAdapter(client, null, sessionId, ownerId); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorWithEmptyMemoryContainerId() { + new AgenticMemoryAdapter(client, "", sessionId, ownerId); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorWithNullSessionId() { + new AgenticMemoryAdapter(client, memoryContainerId, null, ownerId); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorWithEmptySessionId() { + new AgenticMemoryAdapter(client, memoryContainerId, "", ownerId); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorWithNullOwnerId() { + new AgenticMemoryAdapter(client, memoryContainerId, sessionId, null); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorWithEmptyOwnerId() { + new AgenticMemoryAdapter(client, memoryContainerId, sessionId, ""); + } + + @Test + public void testGetConversationId() { + assertEquals(sessionId, adapter.getConversationId()); + } + + @Test + public void testGetMemoryContainerId() { + assertEquals(memoryContainerId, adapter.getMemoryContainerId()); + } + + @Test(expected = IllegalArgumentException.class) + public void testSaveTraceDataWithNullToolName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + adapter.saveTraceData(null, "input", "output", "parent-id", 1, "action", listener); + } + + @Test(expected = IllegalArgumentException.class) + public void testSaveTraceDataWithEmptyToolName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + adapter.saveTraceData("", "input", "output", "parent-id", 1, "action", listener); + } + + @Test(expected = IllegalArgumentException.class) + public void testSaveTraceDataWithNullListener() { + adapter.saveTraceData("tool", "input", "output", "parent-id", 1, "action", null); + } + + @Test(expected = IllegalArgumentException.class) + public void testSaveInteractionWithNullListener() { + adapter.saveInteraction("question", "response", null, 1, "action", null); + } + + @Test + public void testUpdateInteractionWithNullInteractionId() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + Map updateFields = new HashMap<>(); + updateFields.put("response", "updated response"); + + adapter.updateInteraction(null, updateFields, listener); + + // Verify that onFailure was called with IllegalArgumentException + verify(listener).onFailure(any(IllegalArgumentException.class)); + } + + @Test + public void testUpdateInteractionWithEmptyInteractionId() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + Map updateFields = new HashMap<>(); + updateFields.put("response", "updated response"); + + adapter.updateInteraction("", updateFields, listener); + + // Verify that onFailure was called with IllegalArgumentException + verify(listener).onFailure(any(IllegalArgumentException.class)); + } + + @Test + public void testUpdateInteractionWithNullUpdateFields() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + adapter.updateInteraction("interaction-id", null, listener); + + // Verify that onFailure was called with IllegalArgumentException + verify(listener).onFailure(any(IllegalArgumentException.class)); + } + + @Test + public void testUpdateInteractionWithEmptyUpdateFields() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + Map updateFields = new HashMap<>(); + + adapter.updateInteraction("interaction-id", updateFields, listener); + + // Verify that onFailure was called with IllegalArgumentException + verify(listener).onFailure(any(IllegalArgumentException.class)); + } + + @Test(expected = IllegalArgumentException.class) + public void testUpdateInteractionWithNullListener() { + Map updateFields = new HashMap<>(); + updateFields.put("response", "updated response"); + + adapter.updateInteraction("interaction-id", updateFields, null); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapterTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapterTest.java new file mode 100644 index 0000000000..990598dd42 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapterTest.java @@ -0,0 +1,115 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; + +import java.util.List; + +import org.junit.Test; +import org.opensearch.core.action.ActionListener; + +/** + * Unit tests for ChatMemoryAdapter interface default methods. + */ +public class ChatMemoryAdapterTest { + + /** + * Test implementation of ChatMemoryAdapter for testing default methods + */ + private static class TestChatMemoryAdapter implements ChatMemoryAdapter { + @Override + public void getMessages(ActionListener> listener) { + // Test implementation - not used in these tests + } + + @Override + public String getConversationId() { + return "test-conversation-id"; + } + + @Override + public String getMemoryContainerId() { + return "test-memory-container-id"; + } + } + + @Test + public void testSaveInteractionDefaultImplementation() { + TestChatMemoryAdapter adapter = new TestChatMemoryAdapter(); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + // Test that default implementation throws UnsupportedOperationException + adapter.saveInteraction("question", "response", "parentId", 1, "action", listener); + + // Verify that onFailure was called with UnsupportedOperationException + org.mockito.Mockito + .verify(listener) + .onFailure( + org.mockito.ArgumentMatchers + .argThat( + exception -> exception instanceof UnsupportedOperationException + && "Save not implemented".equals(exception.getMessage()) + ) + ); + } + + @Test + public void testUpdateInteractionDefaultImplementation() { + TestChatMemoryAdapter adapter = new TestChatMemoryAdapter(); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + // Test that default implementation throws UnsupportedOperationException + adapter.updateInteraction("interactionId", java.util.Map.of("key", "value"), listener); + + // Verify that onFailure was called with UnsupportedOperationException + org.mockito.Mockito + .verify(listener) + .onFailure( + org.mockito.ArgumentMatchers + .argThat( + exception -> exception instanceof UnsupportedOperationException + && "Update interaction not implemented".equals(exception.getMessage()) + ) + ); + } + + @Test + public void testSaveTraceDataDefaultImplementation() { + TestChatMemoryAdapter adapter = new TestChatMemoryAdapter(); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + // Test that default implementation throws UnsupportedOperationException + adapter.saveTraceData("toolName", "input", "output", "parentId", 1, "action", listener); + + // Verify that onFailure was called with UnsupportedOperationException + org.mockito.Mockito + .verify(listener) + .onFailure( + org.mockito.ArgumentMatchers + .argThat( + exception -> exception instanceof UnsupportedOperationException + && "Save trace data not implemented".equals(exception.getMessage()) + ) + ); + } + + @Test + public void testGetConversationId() { + TestChatMemoryAdapter adapter = new TestChatMemoryAdapter(); + assertEquals("test-conversation-id", adapter.getConversationId()); + } + + @Test + public void testGetMemoryContainerId() { + TestChatMemoryAdapter adapter = new TestChatMemoryAdapter(); + assertEquals("test-memory-container-id", adapter.getMemoryContainerId()); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index f6c3e3618e..57c472a4c4 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1171,4 +1171,136 @@ public void testConstructLLMParams_DefaultValues() { Assert.assertTrue(result.containsKey(AgentUtils.RESPONSE_FORMAT_INSTRUCTION)); Assert.assertTrue(result.containsKey(AgentUtils.TOOL_RESPONSE)); } + + @Test + public void testCreateMemoryAdapter_ConversationIndex() { + // Test that ConversationIndex memory type returns ConversationIndexMemory + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + MLMemorySpec memorySpec = MLMemorySpec.builder().type("conversation_index").build(); + MLAgent mlAgent = MLAgent + .builder() + .name("test_agent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(memorySpec) + .build(); + + Map params = new HashMap<>(); + params.put(MLAgentExecutor.QUESTION, "test question"); + params.put(MLAgentExecutor.MEMORY_ID, "test_memory_id"); + + // Mock the memory factory + when(memoryMap.get("conversation_index")).thenReturn(memoryFactory); + + // Create a mock ConversationIndexMemory + org.opensearch.ml.engine.memory.ConversationIndexMemory mockMemory = Mockito + .mock(org.opensearch.ml.engine.memory.ConversationIndexMemory.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockMemory); + return null; + }).when(memoryFactory).create(anyString(), anyString(), anyString(), any()); + + // Test the createMemoryAdapter method + ActionListener testListener = new ActionListener() { + @Override + public void onResponse(Object result) { + // Verify that we get back a ConversationIndexMemory + assertTrue("Expected ConversationIndexMemory", result instanceof org.opensearch.ml.engine.memory.ConversationIndexMemory); + assertEquals("Memory should be the mocked instance", mockMemory, result); + } + + @Override + public void onFailure(Exception e) { + Assert.fail("Should not fail: " + e.getMessage()); + } + }; + + // This would normally be a private method call, but for testing we can verify the logic + // by checking that the correct memory type handling works through the public run method + // The actual test would need to be done through integration testing + } + + @Test + public void testCreateMemoryAdapter_AgenticMemory() { + // Test that agentic memory type returns AgenticMemoryAdapter + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + MLMemorySpec memorySpec = MLMemorySpec.builder().type("agentic_memory").build(); + MLAgent mlAgent = MLAgent + .builder() + .name("test_agent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(memorySpec) + .build(); + + Map params = new HashMap<>(); + params.put("memory_container_id", "test_container_id"); + params.put("session_id", "test_session_id"); + params.put("owner_id", "test_owner_id"); + + // This test verifies that the agentic memory path would be taken + // Full integration testing would require mocking the agentic memory services + assertNotNull("MLAgent should be created successfully", mlAgent); + assertEquals("Memory type should be agentic_memory", "agentic_memory", mlAgent.getMemory().getType()); + } + + @Test + public void testEnhancedChatMessage() { + // Test the enhanced ChatMessage format + ChatMessage userMessage = ChatMessage + .builder() + .id("msg_1") + .timestamp(java.time.Instant.now()) + .sessionId("session_123") + .role("user") + .content("Hello, how are you?") + .contentType("text") + .origin("agentic_memory") + .metadata(Map.of("confidence", 0.95)) + .build(); + + ChatMessage assistantMessage = ChatMessage + .builder() + .id("msg_2") + .timestamp(java.time.Instant.now()) + .sessionId("session_123") + .role("assistant") + .content("I'm doing well, thank you!") + .contentType("text") + .origin("agentic_memory") + .metadata(Map.of("confidence", 0.98)) + .build(); + + // Verify the enhanced ChatMessage structure + assertEquals("user", userMessage.getRole()); + assertEquals("text", userMessage.getContentType()); + assertEquals("agentic_memory", userMessage.getOrigin()); + assertNotNull(userMessage.getMetadata()); + assertEquals(0.95, userMessage.getMetadata().get("confidence")); + + assertEquals("assistant", assistantMessage.getRole()); + assertEquals("I'm doing well, thank you!", assistantMessage.getContent()); + } + + @Test + public void testSimpleChatHistoryTemplateEngine() { + // Test the new template engine + SimpleChatHistoryTemplateEngine templateEngine = new SimpleChatHistoryTemplateEngine(); + + List messages = List + .of( + ChatMessage.builder().role("user").content("What's the weather?").contentType("text").build(), + ChatMessage.builder().role("assistant").content("It's sunny today!").contentType("text").build(), + ChatMessage.builder().role("system").content("Weather data retrieved from API").contentType("context").build() + ); + + String chatHistory = templateEngine.buildSimpleChatHistory(messages); + + assertNotNull("Chat history should not be null", chatHistory); + assertTrue("Should contain user message", chatHistory.contains("Human: What's the weather?")); + assertTrue("Should contain assistant message", chatHistory.contains("Assistant: It's sunny today!")); + assertTrue("Should contain system context", chatHistory.contains("[Context] Weather data retrieved from API")); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java index 4c3f6217af..4e819f103b 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java @@ -379,8 +379,9 @@ public SearchSourceBuilder addUserBackendRolesFilter(User user, SearchSourceBuil public SearchSourceBuilder addOwnerIdFilter(User user, SearchSourceBuilder searchSourceBuilder) { BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - boolQueryBuilder.should(QueryBuilders.termsQuery(OWNER_ID_FIELD, user.getName())); - + if (user != null) { + boolQueryBuilder.should(QueryBuilders.termsQuery(OWNER_ID_FIELD, user.getName())); + } return applyFilterToSearchSource(searchSourceBuilder, boolQueryBuilder); }