From 13142960ff2bcd78f42576c7e61d9f4c6a610735 Mon Sep 17 00:00:00 2001 From: Abdul Muneer Kolarkunnu Date: Wed, 15 Oct 2025 16:15:23 +0530 Subject: [PATCH] [FEATURE] Support MCP for Flow and Conversational Flow Agent Resolves #3807 Signed-off-by: Abdul Muneer Kolarkunnu --- .../MLConversationalFlowAgentRunner.java | 139 ++++++++++-------- .../algorithms/agent/MLFlowAgentRunner.java | 24 ++- 2 files changed, 96 insertions(+), 67 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java index 54d847b929..f821b0cab2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java @@ -16,6 +16,7 @@ import static org.opensearch.ml.common.utils.ToolUtils.parseResponse; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTool; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMcpToolSpecs; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION; @@ -26,6 +27,7 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; import org.opensearch.action.StepListener; import org.opensearch.action.update.UpdateResponse; @@ -158,38 +160,67 @@ private void runAgent( String parentInteractionId ) { - StepListener firstStepListener = null; - Tool firstTool = null; - List flowAgentOutput = new ArrayList<>(); - Map firstToolExecuteParams = null; - StepListener previousStepListener = null; - Map additionalInfo = new ConcurrentHashMap<>(); List toolSpecs = getMlToolSpecs(mlAgent, params); - if (toolSpecs == null || toolSpecs.isEmpty()) { - listener.onFailure(new IllegalArgumentException("no tool configured")); - return; - } - AtomicInteger traceNumber = new AtomicInteger(0); - if (memory != null) { - flowAgentOutput.add(ModelTensor.builder().name(MEMORY_ID).result(memoryId).build()); - flowAgentOutput.add(ModelTensor.builder().name(PARENT_INTERACTION_ID_FIELD).result(parentInteractionId).build()); - } + // Create a common method to handle both success and failure cases + Consumer> processTools = (allToolSpecs) -> { + StepListener firstStepListener = null; + Tool firstTool = null; + List flowAgentOutput = new ArrayList<>(); + Map firstToolExecuteParams = null; + StepListener previousStepListener = null; + Map additionalInfo = new ConcurrentHashMap<>(); + if (toolSpecs == null || toolSpecs.isEmpty()) { + listener.onFailure(new IllegalArgumentException("no tool configured")); + return; + } + AtomicInteger traceNumber = new AtomicInteger(0); + if (memory != null) { + flowAgentOutput.add(ModelTensor.builder().name(MEMORY_ID).result(memoryId).build()); + flowAgentOutput.add(ModelTensor.builder().name(PARENT_INTERACTION_ID_FIELD).result(parentInteractionId).build()); + } - MLMemorySpec memorySpec = mlAgent.getMemory(); - for (int i = 0; i <= toolSpecs.size(); i++) { - if (i == 0) { - MLToolSpec toolSpec = toolSpecs.get(i); - firstToolExecuteParams = ToolUtils.buildToolParameters(params, toolSpec, mlAgent.getTenantId()); - Tool tool = createTool(toolFactories, firstToolExecuteParams, toolSpec); - firstStepListener = new StepListener<>(); - previousStepListener = firstStepListener; - firstTool = tool; - } else { - MLToolSpec previousToolSpec = toolSpecs.get(i - 1); - StepListener nextStepListener = new StepListener<>(); - int finalI = i; - previousStepListener.whenComplete(output -> { + MLMemorySpec memorySpec = mlAgent.getMemory(); + for (int i = 0; i <= toolSpecs.size(); i++) { + if (i == 0) { + MLToolSpec toolSpec = toolSpecs.get(i); + firstToolExecuteParams = ToolUtils.buildToolParameters(params, toolSpec, mlAgent.getTenantId()); + Tool tool = createTool(toolFactories, firstToolExecuteParams, toolSpec); + firstStepListener = new StepListener<>(); + previousStepListener = firstStepListener; + firstTool = tool; + } else { + MLToolSpec previousToolSpec = toolSpecs.get(i - 1); + StepListener nextStepListener = new StepListener<>(); + int finalI = i; + previousStepListener.whenComplete(output -> { + processOutput( + params, + listener, + memory, + memoryId, + parentInteractionId, + toolSpecs, + flowAgentOutput, + additionalInfo, + traceNumber, + memorySpec, + previousToolSpec, + finalI, + output, + mlAgent.getTenantId(), + nextStepListener + ); + }, e -> { + log.error("Failed to run flow agent", e); + listener.onFailure(e); + }); + previousStepListener = nextStepListener; + } + } + if (toolSpecs.size() == 1) { + firstTool.run(firstToolExecuteParams, ActionListener.wrap(output -> { + MLToolSpec toolSpec = toolSpecs.get(0); processOutput( params, listener, @@ -201,43 +232,27 @@ private void runAgent( additionalInfo, traceNumber, memorySpec, - previousToolSpec, - finalI, + toolSpec, + 1, output, mlAgent.getTenantId(), - nextStepListener + null ); - }, e -> { - log.error("Failed to run flow agent", e); - listener.onFailure(e); - }); - previousStepListener = nextStepListener; + }, e -> { listener.onFailure(e); })); + } else { + firstTool.run(firstToolExecuteParams, firstStepListener); } - } - if (toolSpecs.size() == 1) { - firstTool.run(firstToolExecuteParams, ActionListener.wrap(output -> { - MLToolSpec toolSpec = toolSpecs.get(0); - processOutput( - params, - listener, - memory, - memoryId, - parentInteractionId, - toolSpecs, - flowAgentOutput, - additionalInfo, - traceNumber, - memorySpec, - toolSpec, - 1, - output, - mlAgent.getTenantId(), - null - ); - }, e -> { listener.onFailure(e); })); - } else { - firstTool.run(firstToolExecuteParams, firstStepListener); - } + }; + + // Fetch MCP tools and handle both success and failure cases + getMcpToolSpecs(mlAgent, client, sdkClient, encryptor, ActionListener.wrap(mcpTools -> { + toolSpecs.addAll(mcpTools); + processTools.accept(toolSpecs); + }, e -> { + log.warn("Failed to get MCP tools, continuing with base tools only", e); + processTools.accept(toolSpecs); + })); + } @SuppressWarnings("removal") diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java index 30725a8c47..6a13958e8b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java @@ -11,12 +11,14 @@ import static org.opensearch.ml.common.utils.ToolUtils.getToolName; import static org.opensearch.ml.common.utils.ToolUtils.parseResponse; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTool; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMcpToolSpecs; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; import org.opensearch.action.StepListener; import org.opensearch.action.update.UpdateResponse; @@ -149,12 +151,24 @@ public void run(MLAgent mlAgent, Map params, ActionListener executeParams = ToolUtils.buildToolParameters(params, toolSpec, mlAgent.getTenantId()); - Tool tool = createTool(toolFactories, executeParams, toolSpec); - if (finalI < toolSpecs.size()) { - tool.run(executeParams, nextStepListener); - } + // Create a common method to handle both success and failure cases + Consumer> processTools = (allToolSpecs) -> { + Map executeParams = ToolUtils.buildToolParameters(params, toolSpec, mlAgent.getTenantId()); + Tool tool = createTool(toolFactories, executeParams, toolSpec); + if (finalI < toolSpecs.size()) { + tool.run(executeParams, nextStepListener); + } + }; + + // Fetch MCP tools and handle both success and failure cases + getMcpToolSpecs(mlAgent, client, sdkClient, encryptor, ActionListener.wrap(mcpTools -> { + toolSpecs.addAll(mcpTools); + processTools.accept(toolSpecs); + }, e -> { + log.warn("Failed to get MCP tools, continuing with base tools only", e); + processTools.accept(toolSpecs); + })); }, e -> { log.error("Failed to run flow agent", e); listener.onFailure(e);