Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import static org.opensearch.ml.common.utils.ToolUtils.parseResponse;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTool;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMcpToolSpecs;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs;
import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION;
Expand All @@ -26,6 +27,7 @@
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

import org.opensearch.action.StepListener;
import org.opensearch.action.update.UpdateResponse;
Expand Down Expand Up @@ -158,38 +160,67 @@ private void runAgent(
String parentInteractionId
) {

StepListener<Object> firstStepListener = null;
Tool firstTool = null;
List<ModelTensor> flowAgentOutput = new ArrayList<>();
Map<String, String> firstToolExecuteParams = null;
StepListener<Object> previousStepListener = null;
Map<String, Object> additionalInfo = new ConcurrentHashMap<>();
List<MLToolSpec> toolSpecs = getMlToolSpecs(mlAgent, params);

if (toolSpecs == null || toolSpecs.isEmpty()) {
listener.onFailure(new IllegalArgumentException("no tool configured"));
return;
}
AtomicInteger traceNumber = new AtomicInteger(0);
if (memory != null) {
flowAgentOutput.add(ModelTensor.builder().name(MEMORY_ID).result(memoryId).build());
flowAgentOutput.add(ModelTensor.builder().name(PARENT_INTERACTION_ID_FIELD).result(parentInteractionId).build());
}
// Create a common method to handle both success and failure cases
Consumer<List<MLToolSpec>> processTools = (allToolSpecs) -> {
StepListener<Object> firstStepListener = null;
Tool firstTool = null;
List<ModelTensor> flowAgentOutput = new ArrayList<>();
Map<String, String> firstToolExecuteParams = null;
StepListener<Object> previousStepListener = null;
Map<String, Object> additionalInfo = new ConcurrentHashMap<>();
if (toolSpecs == null || toolSpecs.isEmpty()) {
listener.onFailure(new IllegalArgumentException("no tool configured"));
return;
}
AtomicInteger traceNumber = new AtomicInteger(0);
if (memory != null) {
flowAgentOutput.add(ModelTensor.builder().name(MEMORY_ID).result(memoryId).build());
flowAgentOutput.add(ModelTensor.builder().name(PARENT_INTERACTION_ID_FIELD).result(parentInteractionId).build());
}

MLMemorySpec memorySpec = mlAgent.getMemory();
for (int i = 0; i <= toolSpecs.size(); i++) {
if (i == 0) {
MLToolSpec toolSpec = toolSpecs.get(i);
firstToolExecuteParams = ToolUtils.buildToolParameters(params, toolSpec, mlAgent.getTenantId());
Tool tool = createTool(toolFactories, firstToolExecuteParams, toolSpec);
firstStepListener = new StepListener<>();
previousStepListener = firstStepListener;
firstTool = tool;
} else {
MLToolSpec previousToolSpec = toolSpecs.get(i - 1);
StepListener<Object> nextStepListener = new StepListener<>();
int finalI = i;
previousStepListener.whenComplete(output -> {
MLMemorySpec memorySpec = mlAgent.getMemory();
for (int i = 0; i <= toolSpecs.size(); i++) {
if (i == 0) {
MLToolSpec toolSpec = toolSpecs.get(i);
firstToolExecuteParams = ToolUtils.buildToolParameters(params, toolSpec, mlAgent.getTenantId());
Tool tool = createTool(toolFactories, firstToolExecuteParams, toolSpec);
firstStepListener = new StepListener<>();
previousStepListener = firstStepListener;
firstTool = tool;
} else {
MLToolSpec previousToolSpec = toolSpecs.get(i - 1);
StepListener<Object> nextStepListener = new StepListener<>();
int finalI = i;
previousStepListener.whenComplete(output -> {
processOutput(
params,
listener,
memory,
memoryId,
parentInteractionId,
toolSpecs,
flowAgentOutput,
additionalInfo,
traceNumber,
memorySpec,
previousToolSpec,
finalI,
output,
mlAgent.getTenantId(),
nextStepListener
);
}, e -> {
log.error("Failed to run flow agent", e);
listener.onFailure(e);
});
previousStepListener = nextStepListener;
}
}
if (toolSpecs.size() == 1) {
firstTool.run(firstToolExecuteParams, ActionListener.wrap(output -> {
MLToolSpec toolSpec = toolSpecs.get(0);
processOutput(
params,
listener,
Expand All @@ -201,43 +232,27 @@ private void runAgent(
additionalInfo,
traceNumber,
memorySpec,
previousToolSpec,
finalI,
toolSpec,
1,
output,
mlAgent.getTenantId(),
nextStepListener
null
);
}, e -> {
log.error("Failed to run flow agent", e);
listener.onFailure(e);
});
previousStepListener = nextStepListener;
}, e -> { listener.onFailure(e); }));
} else {
firstTool.run(firstToolExecuteParams, firstStepListener);
}
}
if (toolSpecs.size() == 1) {
firstTool.run(firstToolExecuteParams, ActionListener.wrap(output -> {
MLToolSpec toolSpec = toolSpecs.get(0);
processOutput(
params,
listener,
memory,
memoryId,
parentInteractionId,
toolSpecs,
flowAgentOutput,
additionalInfo,
traceNumber,
memorySpec,
toolSpec,
1,
output,
mlAgent.getTenantId(),
null
);
}, e -> { listener.onFailure(e); }));
} else {
firstTool.run(firstToolExecuteParams, firstStepListener);
}
};

// Fetch MCP tools and handle both success and failure cases
getMcpToolSpecs(mlAgent, client, sdkClient, encryptor, ActionListener.wrap(mcpTools -> {
toolSpecs.addAll(mcpTools);
processTools.accept(toolSpecs);
}, e -> {
log.warn("Failed to get MCP tools, continuing with base tools only", e);
processTools.accept(toolSpecs);
}));

}

@SuppressWarnings("removal")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
import static org.opensearch.ml.common.utils.ToolUtils.getToolName;
import static org.opensearch.ml.common.utils.ToolUtils.parseResponse;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTool;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMcpToolSpecs;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;

import org.opensearch.action.StepListener;
import org.opensearch.action.update.UpdateResponse;
Expand Down Expand Up @@ -149,12 +151,24 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
}

MLToolSpec toolSpec = toolSpecs.get(finalI);
Map<String, String> executeParams = ToolUtils.buildToolParameters(params, toolSpec, mlAgent.getTenantId());
Tool tool = createTool(toolFactories, executeParams, toolSpec);
if (finalI < toolSpecs.size()) {
tool.run(executeParams, nextStepListener);
}

// Create a common method to handle both success and failure cases
Consumer<List<MLToolSpec>> processTools = (allToolSpecs) -> {
Map<String, String> executeParams = ToolUtils.buildToolParameters(params, toolSpec, mlAgent.getTenantId());
Tool tool = createTool(toolFactories, executeParams, toolSpec);
if (finalI < toolSpecs.size()) {
tool.run(executeParams, nextStepListener);
}
};

// Fetch MCP tools and handle both success and failure cases
getMcpToolSpecs(mlAgent, client, sdkClient, encryptor, ActionListener.wrap(mcpTools -> {
toolSpecs.addAll(mcpTools);
processTools.accept(toolSpecs);
}, e -> {
log.warn("Failed to get MCP tools, continuing with base tools only", e);
processTools.accept(toolSpecs);
}));
}, e -> {
log.error("Failed to run flow agent", e);
listener.onFailure(e);
Expand Down
Loading