Skip to content

Commit 4122b38

Browse files
authored
Ensure tool failures are returned back to the agent (opensearch-project#4052)
1 parent 974b418 commit 4122b38

File tree

18 files changed

+306
-125
lines changed

18 files changed

+306
-125
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ public class AgentTool implements Tool {
3535
public static final String TYPE = "AgentTool";
3636
private final Client client;
3737

38+
@Setter
3839
private String agentId;
3940
@Setter
4041
@Getter
@@ -51,30 +52,41 @@ public class AgentTool implements Tool {
5152
private Map<String, Object> attributes;
5253

5354
public AgentTool(Client client, String agentId) {
55+
if (agentId == null || agentId.isBlank()) {
56+
throw new IllegalArgumentException("Agent ID cannot be null or empty");
57+
}
58+
5459
this.client = client;
5560
this.agentId = agentId;
5661
}
5762

5863
@Override
5964
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
60-
Map<String, String> extractedParameters = ToolUtils.extractInputParameters(parameters, attributes);
61-
String tenantId = parameters.get(TENANT_ID_FIELD);
62-
AgentMLInput agentMLInput = AgentMLInput
63-
.AgentMLInputBuilder()
64-
.agentId(agentId)
65-
.tenantId(tenantId)
66-
.functionName(FunctionName.AGENT)
67-
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(extractedParameters).build())
68-
.build();
69-
ActionRequest request = new MLExecuteTaskRequest(FunctionName.AGENT, agentMLInput, false);
70-
client.execute(MLExecuteTaskAction.INSTANCE, request, ActionListener.wrap(r -> {
71-
ModelTensorOutput output = (ModelTensorOutput) r.getOutput();
72-
listener.onResponse((T) output);
73-
}, e -> {
74-
log.error("Failed to run agent " + agentId, e);
65+
try {
66+
if (agentId == null || agentId.isBlank()) {
67+
throw new IllegalArgumentException("Agent ID not registered in tool");
68+
}
69+
Map<String, String> extractedParameters = ToolUtils.extractInputParameters(parameters, attributes);
70+
String tenantId = parameters.get(TENANT_ID_FIELD);
71+
AgentMLInput agentMLInput = AgentMLInput
72+
.AgentMLInputBuilder()
73+
.agentId(agentId)
74+
.tenantId(tenantId)
75+
.functionName(FunctionName.AGENT)
76+
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(extractedParameters).build())
77+
.build();
78+
ActionRequest request = new MLExecuteTaskRequest(FunctionName.AGENT, agentMLInput, false);
79+
client.execute(MLExecuteTaskAction.INSTANCE, request, ActionListener.wrap(r -> {
80+
ModelTensorOutput output = (ModelTensorOutput) r.getOutput();
81+
listener.onResponse((T) output);
82+
}, e -> {
83+
log.error("Failed to run agent " + agentId, e);
84+
listener.onFailure(e);
85+
}));
86+
} catch (Exception e) {
87+
log.error("Failed to run AgentTool with agent: {}", agentId, e);
7588
listener.onFailure(e);
76-
}));
77-
89+
}
7890
}
7991

8092
@Override

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ConnectorTool.java

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.util.List;
99
import java.util.Map;
1010

11+
import org.apache.commons.lang3.StringUtils;
1112
import org.opensearch.action.ActionRequest;
1213
import org.opensearch.core.action.ActionListener;
1314
import org.opensearch.ml.common.FunctionName;
@@ -57,10 +58,11 @@ public class ConnectorTool implements Tool {
5758
private String connectorId;
5859

5960
public ConnectorTool(Client client, String connectorId) {
60-
this.client = client;
61-
if (connectorId == null) {
62-
throw new IllegalArgumentException("connector_id can't be null");
61+
if (StringUtils.isBlank(connectorId)) {
62+
throw new IllegalArgumentException("Connector ID can't be null or empty");
6363
}
64+
65+
this.client = client;
6466
this.connectorId = connectorId;
6567

6668
outputParser = new Parser() {
@@ -74,23 +76,31 @@ public Object parse(Object o) {
7476

7577
@Override
7678
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
77-
Map<String, String> parameters = ToolUtils.extractInputParameters(originalParameters, attributes);
78-
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build();
79-
MLInput mlInput = RemoteInferenceMLInput.builder().algorithm(FunctionName.CONNECTOR).inputDataset(inputDataSet).build();
80-
ActionRequest request = new MLExecuteConnectorRequest(connectorId, mlInput);
81-
82-
client.execute(MLExecuteConnectorAction.INSTANCE, request, ActionListener.wrap(r -> {
83-
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput();
84-
modelTensorOutput.getMlModelOutputs();
85-
if (outputParser == null) {
86-
listener.onResponse((T) modelTensorOutput.getMlModelOutputs());
87-
} else {
88-
listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs()));
79+
try {
80+
if (connectorId.isBlank()) {
81+
throw new IllegalArgumentException("Connector is not registered in tool");
8982
}
90-
}, e -> {
91-
log.error("Failed to run model " + connectorId, e);
83+
Map<String, String> parameters = ToolUtils.extractInputParameters(originalParameters, attributes);
84+
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build();
85+
MLInput mlInput = RemoteInferenceMLInput.builder().algorithm(FunctionName.CONNECTOR).inputDataset(inputDataSet).build();
86+
ActionRequest request = new MLExecuteConnectorRequest(connectorId, mlInput);
87+
88+
client.execute(MLExecuteConnectorAction.INSTANCE, request, ActionListener.wrap(r -> {
89+
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput();
90+
modelTensorOutput.getMlModelOutputs();
91+
if (outputParser == null) {
92+
listener.onResponse((T) modelTensorOutput.getMlModelOutputs());
93+
} else {
94+
listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs()));
95+
}
96+
}, e -> {
97+
log.error("Failed to run model " + connectorId, e);
98+
listener.onFailure(e);
99+
}));
100+
} catch (Exception e) {
101+
log.error("Failed to run ConnectorTool with connector: {}", connectorId, e);
92102
listener.onFailure(e);
93-
}));
103+
}
94104
}
95105

96106
@Override
@@ -100,15 +110,12 @@ public String getType() {
100110

101111
@Override
102112
public boolean validate(Map<String, String> parameters) {
103-
if (parameters == null || parameters.size() == 0) {
104-
return false;
105-
}
106-
return true;
113+
return parameters != null && !parameters.isEmpty();
107114
}
108115

109116
public static class Factory implements Tool.Factory<ConnectorTool> {
110117
public static final String TYPE = "ConnectorTool";
111-
public static final String DEFAULT_DESCRIPTION = "This tool will invoke external service.";
118+
public static final String DEFAULT_DESCRIPTION = "Invokes external service. Required: 'connector_id'. Returns: service response.";
112119
private Client client;
113120
private static Factory INSTANCE;
114121

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexMappingTool.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,20 @@
3232

3333
import lombok.Getter;
3434
import lombok.Setter;
35+
import lombok.extern.log4j.Log4j2;
3536

37+
@Log4j2
3638
@ToolAnnotation(IndexMappingTool.TYPE)
3739
public class IndexMappingTool implements Tool {
3840
public static final String TYPE = "IndexMappingTool";
3941
public static final String STRICT_FIELD = "strict";
4042
private static final String DEFAULT_DESCRIPTION = String
4143
.join(
4244
" ",
43-
"This tool gets index mapping information from a certain index.",
44-
"It takes 1 required argument named 'index' which is a comma-delimited list of one or more indices to get mapping information from, which expands wildcards.",
45-
"It takes 1 optional argument named 'local' which means whether to return information from the local node only instead of the cluster manager node (Default is false).",
46-
"The tool returns a list of index mappings and settings for each index.",
47-
"The mappings are in JSON format under the key 'properties' which includes the field name as a key and a JSON object with field type under the key 'type'.",
48-
"The settings are in flattened map with 'index' as the top element and key-value pairs for each setting."
45+
"This tool returns index mappings and settings for specified indices.",
46+
"Required argument: 'index' - comma-delimited list of one or more indices (supports wildcards like 'my-index-*').",
47+
"Optional argument: 'local' - if true, returns info from local node only instead of cluster manager (default: false).",
48+
"Response format: For each index, 'mappings' contains field definitions under 'properties' (each field has a 'type'), and 'settings' contains configuration as a flattened key-value map."
4949
);
5050
public static final String DEFAULT_INPUT_SCHEMA = "{\"type\":\"object\",\""
5151
+ "properties\":{\"index\":{\"type\":\"array\",\"description\":\"OpenSearch index name list, separated by comma. "
@@ -170,6 +170,7 @@ public void onFailure(final Exception e) {
170170

171171
client.admin().indices().getIndex(getIndexRequest, internalListener);
172172
} catch (Exception e) {
173+
log.error("Failed to run IndexMappingTool", e);
173174
listener.onFailure(e);
174175
}
175176
}
@@ -181,7 +182,7 @@ public String getType() {
181182

182183
@Override
183184
public boolean validate(Map<String, String> parameters) {
184-
return parameters != null && parameters.containsKey("index");
185+
return parameters != null && !parameters.isEmpty() && parameters.containsKey("index");
185186
}
186187

187188
/**

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ListIndexTool.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,9 @@ public class ListIndexTool implements Tool {
7575
public static final String DEFAULT_DESCRIPTION = String
7676
.join(
7777
" ",
78-
"This tool gets index information from the OpenSearch cluster.",
79-
"It takes 2 optional arguments named `indices` which is a comma-delimited list of one or more indices to get information from (default is an empty list meaning all indices),",
80-
"and `local` which means whether to return information from the local node only instead of the cluster manager node (default is false).",
81-
"The tool returns the indices information, including `health`, `status`, `index`, `uuid`, `pri`, `rep`, `docs.count`, `docs.deleted`, `store.size`, `pri.store. size `, `pri.store.size`, `pri.store`."
78+
"This tool returns information about indices in the OpenSearch cluster along with the index `health`, `status`, `index`, `uuid`, `pri`, `rep`, `docs.count`, `docs.deleted`, `store.size`, `pri.store. size `, `pri.store.size`, `pri.store`.",
79+
"Optional arguments: 1. `indices`, a comma-delimited list of one or more indices to get information from (default is an empty list meaning all indices). Use only valid index names.",
80+
"2. `local`, whether to return information from the local node only instead of the cluster manager node (Default is false)"
8281
);
8382
public static final String DEFAULT_INPUT_SCHEMA = "{\"type\":\"object\","
8483
+ "\"properties\":{\"indices\":{\"type\":\"array\",\"items\": {\"type\": \"string\"},"
@@ -181,6 +180,7 @@ public <T> void run(Map<String, String> originalParameters, ActionListener<T> li
181180
internalListener
182181
);
183182
} catch (Exception e) {
183+
log.error("Failed to run ListIndexTool", e);
184184
listener.onFailure(e);
185185
}
186186
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ public class MLModelTool implements WithModelTool {
6767
private String responseField;
6868

6969
public MLModelTool(Client client, String modelId, String responseField) {
70+
if (modelId == null || modelId.isBlank()) {
71+
throw new IllegalArgumentException("Model ID can't be null or empty");
72+
}
73+
7074
this.client = client;
7175
this.modelId = modelId;
7276
this.responseField = responseField;
@@ -89,27 +93,32 @@ public MLModelTool(Client client, String modelId, String responseField) {
8993

9094
@Override
9195
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
92-
Map<String, String> parameters = ToolUtils.extractInputParameters(originalParameters, attributes);
93-
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build();
94-
String tenantId = null;
95-
if (parameters != null) {
96-
tenantId = parameters.get(TENANT_ID_FIELD);
97-
}
96+
try {
97+
Map<String, String> parameters = ToolUtils.extractInputParameters(originalParameters, attributes);
98+
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build();
99+
String tenantId = null;
100+
if (parameters != null) {
101+
tenantId = parameters.get(TENANT_ID_FIELD);
102+
}
98103

99-
ActionRequest request = MLPredictionTaskRequest
100-
.builder()
101-
.modelId(modelId)
102-
.mlInput(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build())
103-
.tenantId(tenantId)
104-
.build();
105-
client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(r -> {
106-
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput();
107-
modelTensorOutput.getMlModelOutputs();
108-
listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs()));
109-
}, e -> {
110-
log.error("Failed to run model {}", modelId, e);
104+
ActionRequest request = MLPredictionTaskRequest
105+
.builder()
106+
.modelId(modelId)
107+
.mlInput(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build())
108+
.tenantId(tenantId)
109+
.build();
110+
client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(r -> {
111+
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput();
112+
modelTensorOutput.getMlModelOutputs();
113+
listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs()));
114+
}, e -> {
115+
log.error("Failed to run model {}", modelId, e);
116+
listener.onFailure(e);
117+
}));
118+
} catch (Exception e) {
119+
log.error("Failed to run MLModelTool for model: {}", modelId, e);
111120
listener.onFailure(e);
112-
}));
121+
}
113122
}
114123

115124
@Override
@@ -134,10 +143,7 @@ public void setName(String s) {
134143

135144
@Override
136145
public boolean validate(Map<String, String> parameters) {
137-
if (parameters == null || parameters.size() == 0) {
138-
return false;
139-
}
140-
return true;
146+
return parameters != null && !parameters.isEmpty();
141147
}
142148

143149
public static class Factory implements WithModelTool.Factory<MLModelTool> {

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ public class SearchIndexTool implements Tool {
6262

6363
public static final String TYPE = "SearchIndexTool";
6464
private static final String DEFAULT_DESCRIPTION =
65-
"Use this tool to search an index by providing two parameters: 'index' for the index name, and 'query' for the OpenSearch DSL formatted query. Only use this tool when both index name and DSL query is available.";
65+
"Use this tool to search an index by providing two parameters: 'index' for the index name, and 'query' for the OpenSearch DSL formatted query. Only use this tool when both index name and DSL query is available. "
66+
+ "Returns documents matching the query in the provided index.";
6667

6768
public static final String DEFAULT_INPUT_SCHEMA = "{\"type\":\"object\","
6869
+ "\"properties\":{\"index\":{\"type\":\"string\",\"description\":\"OpenSearch index name. for example: index1\"},"
@@ -179,17 +180,20 @@ public <T> void run(Map<String, String> originalParameters, ActionListener<T> li
179180
log.error("Invalid JSON input: {}", input, e);
180181
}
181182
}
183+
182184
if (StringUtils.isEmpty(index)) {
183185
index = parameters.get(INDEX_FIELD);
184186
}
187+
185188
if (StringUtils.isEmpty(query)) {
186189
query = parameters.get(QUERY_FIELD);
187190
}
191+
188192
if (StringUtils.isEmpty(index) || StringUtils.isEmpty(query)) {
189193
listener
190194
.onFailure(
191195
new IllegalArgumentException(
192-
"SearchIndexTool's two parameters: index and query are required and should in valid format!"
196+
"SearchIndexTool's two parameters: index and query are required and should be in valid format"
193197
)
194198
);
195199
return;
@@ -236,7 +240,7 @@ public <T> void run(Map<String, String> originalParameters, ActionListener<T> li
236240
client.search(searchRequest, actionListener);
237241
}
238242
} catch (Exception e) {
239-
log.error("Failed to search index", e);
243+
log.error("Failed to run SearchIndexTool", e);
240244
listener.onFailure(e);
241245
}
242246
}

0 commit comments

Comments
 (0)