Skip to content

Commit 8ef3528

Browse files
authored
move tool utils to common (opensearch-project#4081)
* move tool utils to common Signed-off-by: Yaliang Wu <[email protected]> * avoid adding spi dependency to common Signed-off-by: Yaliang Wu <[email protected]> * add javadoc to ToolUtils Signed-off-by: Yaliang Wu <[email protected]> --------- Signed-off-by: Yaliang Wu <[email protected]>
1 parent 878bbcc commit 8ef3528

File tree

17 files changed

+268
-179
lines changed

17 files changed

+268
-179
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ToolUtils.java renamed to common/src/main/java/org/opensearch/ml/common/utils/ToolUtils.java

Lines changed: 80 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.engine.tools;
6+
package org.opensearch.ml.common.utils;
77

88
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
99
import static org.opensearch.ml.common.utils.StringUtils.gson;
1010

11-
import java.util.ArrayList;
1211
import java.util.HashMap;
1312
import java.util.List;
1413
import java.util.Map;
@@ -19,8 +18,6 @@
1918
import org.opensearch.ml.common.output.model.ModelTensor;
2019
import org.opensearch.ml.common.output.model.ModelTensorOutput;
2120
import org.opensearch.ml.common.output.model.ModelTensors;
22-
import org.opensearch.ml.common.spi.tools.Tool;
23-
import org.opensearch.ml.common.utils.StringUtils;
2421

2522
import com.google.gson.reflect.TypeToken;
2623
import com.jayway.jsonpath.JsonPath;
@@ -38,6 +35,19 @@ public class ToolUtils {
3835
public static final String TOOL_OUTPUT_FILTERS_FIELD = "output_filter";
3936
public static final String TOOL_REQUIRED_PARAMS = "required_parameters";
4037

38+
/**
39+
* Extracts required parameters based on tool attributes specification.
40+
* <p>
41+
* The method performs the following:
42+
* <ul>
43+
* <li>If required parameters are specified in attributes, only those parameters are extracted</li>
44+
* <li>If no required parameters are specified, all parameters are returned</li>
45+
* </ul>
46+
*
47+
* @param parameters The input parameters map to extract from
48+
* @param attributes The attributes map containing required parameter specifications
49+
* @return Map containing only the required parameters
50+
*/
4151
public static Map<String, String> extractRequiredParameters(Map<String, String> parameters, Map<String, ?> attributes) {
4252
Map<String, String> extractedParameters = new HashMap<>();
4353
if (parameters == null) {
@@ -56,6 +66,26 @@ public static Map<String, String> extractRequiredParameters(Map<String, String>
5666
return extractedParameters;
5767
}
5868

69+
/**
70+
* Extracts and processes input parameters, including handling "input" parameter.
71+
* <p>
72+
* The method performs the following steps:
73+
* <ol>
74+
* <li>Extracts required parameters based on tool attributes specification</li>
75+
* <li>If an "input" parameter exists:
76+
* <ul>
77+
* <li>Substitutes any parameter placeholders</li>
78+
* <li>Parses it as a JSON map</li>
79+
* <li>Merges the parsed values with other parameters</li>
80+
* </ul>
81+
* </li>
82+
* </ol>
83+
*
84+
* @param parameters The raw input parameters
85+
* @param attributes The tool attributes containing parameter specifications
86+
* @return Map of processed input parameters
87+
* @throws IllegalArgumentException if input JSON parsing fails
88+
*/
5989
public static Map<String, String> extractInputParameters(Map<String, String> parameters, Map<String, ?> attributes) {
6090
Map<String, String> extractedParameters = ToolUtils.extractRequiredParameters(parameters, attributes);
6191
if (extractedParameters.containsKey("input")) {
@@ -73,6 +103,22 @@ public static Map<String, String> extractInputParameters(Map<String, String> par
73103
return extractedParameters;
74104
}
75105

106+
/**
107+
* Builds the final parameter map for tool execution.
108+
* <p>
109+
* The method performs the following steps:
110+
* <ol>
111+
* <li>Combines tool specification parameters with input parameters</li>
112+
* <li>Processes tool-specific parameter prefixes</li>
113+
* <li>Applies configuration overrides from tool specification</li>
114+
* <li>Adds tenant identification</li>
115+
* </ol>
116+
*
117+
* @param parameters The input parameters to process
118+
* @param toolSpec The tool specification containing default parameters and configuration
119+
* @param tenantId The identifier for the tenant
120+
* @return Map of processed parameters ready for tool execution
121+
*/
76122
public static Map<String, String> buildToolParameters(Map<String, String> parameters, MLToolSpec toolSpec, String tenantId) {
77123
Map<String, String> executeParams = new HashMap<>();
78124
if (toolSpec.getParameters() != null) {
@@ -102,30 +148,14 @@ public static Map<String, String> buildToolParameters(Map<String, String> parame
102148
return executeParams;
103149
}
104150

105-
public static Tool createTool(Map<String, Tool.Factory> toolFactories, Map<String, String> executeParams, MLToolSpec toolSpec) {
106-
if (!toolFactories.containsKey(toolSpec.getType())) {
107-
throw new IllegalArgumentException("Tool not found: " + toolSpec.getType());
108-
}
109-
Map<String, Object> toolParams = new HashMap<>();
110-
toolParams.putAll(executeParams);
111-
Map<String, Object> runtimeResources = toolSpec.getRuntimeResources();
112-
if (runtimeResources != null) {
113-
toolParams.putAll(runtimeResources);
114-
}
115-
Tool tool = toolFactories.get(toolSpec.getType()).create(toolParams);
116-
String toolName = getToolName(toolSpec);
117-
tool.setName(toolName);
118-
119-
if (toolSpec.getDescription() != null) {
120-
tool.setDescription(toolSpec.getDescription());
121-
}
122-
if (executeParams.containsKey(toolName + ".description")) {
123-
tool.setDescription(executeParams.get(toolName + ".description"));
124-
}
125-
126-
return tool;
127-
}
128-
151+
/**
152+
* Filters tool output based on specified output filters in tool parameters.
153+
* Uses JSONPath expressions to extract specific portions of the response.
154+
*
155+
* @param toolParams The tool parameters containing output filter specifications
156+
* @param response The raw tool response to filter
157+
* @return Filtered output if successful, original response if filtering fails
158+
*/
129159
public static Object filterToolOutput(Map<String, String> toolParams, Object response) {
130160
if (toolParams != null && toolParams.containsKey(TOOL_OUTPUT_FILTERS_FIELD)) {
131161
try {
@@ -142,6 +172,20 @@ public static Object filterToolOutput(Map<String, String> toolParams, Object res
142172
return response;
143173
}
144174

175+
/**
176+
* Parses different types of tool responses into a JSON string representation.
177+
* <p>
178+
* Handles the following special cases:
179+
* <ul>
180+
* <li>ModelTensors - converts to XContent JSON representation</li>
181+
* <li>ModelTensor - converts to XContent JSON representation</li>
182+
* <li>ModelTensorOutput - converts to XContent JSON representation</li>
183+
* <li>Other types - converts to generic JSON string</li>
184+
* </ul>
185+
*
186+
* @param output The tool output object to parse
187+
* @return JSON string representation of the output
188+
*/
145189
public static String parseResponse(Object output) {
146190
try {
147191
if (output instanceof List && !((List) output).isEmpty() && ((List) output).get(0) instanceof ModelTensors) {
@@ -159,16 +203,15 @@ public static String parseResponse(Object output) {
159203
}
160204
}
161205

162-
public static List<String> getToolNames(Map<String, Tool> tools) {
163-
final List<String> inputTools = new ArrayList<>();
164-
for (Map.Entry<String, Tool> entry : tools.entrySet()) {
165-
String toolName = entry.getValue().getName();
166-
inputTools.add(toolName);
167-
}
168-
return inputTools;
169-
}
170-
206+
/**
207+
* Gets the tool name from a tool specification.
208+
* Returns the specified name if available, otherwise returns the tool type.
209+
*
210+
* @param toolSpec The tool specification
211+
* @return The name of the tool
212+
*/
171213
public static String getToolName(MLToolSpec toolSpec) {
172214
return toolSpec.getName() != null ? toolSpec.getName() : toolSpec.getType();
173215
}
216+
174217
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ToolUtilsTest.java renamed to common/src/test/java/org/opensearch/ml/common/utils/ToolUtilsTest.java

Lines changed: 7 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,14 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.engine.tools;
7-
8-
import static org.junit.Assert.*;
9-
import static org.mockito.ArgumentMatchers.any;
10-
import static org.mockito.ArgumentMatchers.argThat;
11-
import static org.mockito.Mockito.mock;
12-
import static org.mockito.Mockito.verify;
13-
import static org.mockito.Mockito.when;
6+
package org.opensearch.ml.common.utils;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertFalse;
10+
import static org.junit.Assert.assertTrue;
1411
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
15-
import static org.opensearch.ml.engine.tools.ToolUtils.TOOL_REQUIRED_PARAMS;
16-
import static org.opensearch.ml.engine.tools.ToolUtils.filterToolOutput;
12+
import static org.opensearch.ml.common.utils.ToolUtils.TOOL_REQUIRED_PARAMS;
13+
import static org.opensearch.ml.common.utils.ToolUtils.filterToolOutput;
1714

1815
import java.util.ArrayList;
1916
import java.util.HashMap;
@@ -22,54 +19,9 @@
2219

2320
import org.junit.Test;
2421
import org.opensearch.ml.common.agent.MLToolSpec;
25-
import org.opensearch.ml.common.spi.tools.Tool;
2622

2723
public class ToolUtilsTest {
2824

29-
@Test
30-
public void testCreateTool_Success() {
31-
Map<String, Tool.Factory> toolFactories = new HashMap<>();
32-
Tool.Factory factory = mock(Tool.Factory.class);
33-
Tool mockTool = mock(Tool.class);
34-
when(factory.create(any())).thenReturn(mockTool);
35-
toolFactories.put("test_tool", factory);
36-
37-
MLToolSpec toolSpec = MLToolSpec
38-
.builder()
39-
.type("test_tool")
40-
.name("TestTool")
41-
.description("Original description")
42-
.parameters(Map.of("param1", "value1"))
43-
.runtimeResources(Map.of("resource1", "value2"))
44-
.build();
45-
46-
Map<String, String> params = new HashMap<>();
47-
params.put("TestTool.param2", "value3");
48-
params.put("TestTool.description", "Custom description");
49-
50-
Map<String, String> toolParameters = ToolUtils.buildToolParameters(params, toolSpec, "test_tenant");
51-
ToolUtils.createTool(toolFactories, toolParameters, toolSpec);
52-
53-
verify(factory).create(argThat(toolParamsMap -> {
54-
Map<String, Object> toolParams = (Map<String, Object>) toolParamsMap;
55-
return toolParams.get("param1").equals("value1")
56-
&& toolParams.get("param2").equals("value3")
57-
&& toolParams.get("resource1").equals("value2")
58-
&& toolParams.get(TENANT_ID_FIELD).equals("test_tenant");
59-
}));
60-
61-
verify(mockTool).setName("TestTool");
62-
verify(mockTool).setDescription("Custom description");
63-
}
64-
65-
@Test
66-
public void testCreateTool_ToolNotFound() {
67-
Map<String, Tool.Factory> toolFactories = new HashMap<>();
68-
MLToolSpec toolSpec = MLToolSpec.builder().type("non_existent_tool").name("TestTool").build();
69-
70-
assertThrows(IllegalArgumentException.class, () -> ToolUtils.createTool(toolFactories, new HashMap<>(), toolSpec));
71-
}
72-
7325
@Test
7426
public void testExtractRequiredParameters_WithRequiredParameters() {
7527
Map<String, String> parameters = new HashMap<>();
@@ -219,70 +171,6 @@ public void testBuildToolParameters_WithNullToolSpecParameters() {
219171
assertEquals("test_tenant", result.get(TENANT_ID_FIELD));
220172
}
221173

222-
@Test
223-
public void testCreateTool_WithDescription() {
224-
Map<String, Tool.Factory> toolFactories = new HashMap<>();
225-
Tool.Factory factory = mock(Tool.Factory.class);
226-
Tool mockTool = mock(Tool.class);
227-
when(factory.create(any())).thenReturn(mockTool);
228-
toolFactories.put("test_tool", factory);
229-
230-
MLToolSpec toolSpec = MLToolSpec.builder().type("test_tool").name("TestTool").description("Tool description").build();
231-
232-
Map<String, String> params = new HashMap<>();
233-
234-
Tool result = ToolUtils.createTool(toolFactories, params, toolSpec);
235-
236-
verify(mockTool).setName("TestTool");
237-
verify(mockTool).setDescription("Tool description");
238-
assertEquals(mockTool, result);
239-
}
240-
241-
@Test
242-
public void testCreateTool_WithRuntimeResources() {
243-
Map<String, Tool.Factory> toolFactories = new HashMap<>();
244-
Tool.Factory factory = mock(Tool.Factory.class);
245-
Tool mockTool = mock(Tool.class);
246-
when(factory.create(any())).thenReturn(mockTool);
247-
toolFactories.put("test_tool", factory);
248-
249-
Map<String, Object> runtimeResources = new HashMap<>();
250-
runtimeResources.put("resource1", "value1");
251-
runtimeResources.put("resource2", 42);
252-
253-
MLToolSpec toolSpec = MLToolSpec.builder().type("test_tool").name("TestTool").runtimeResources(runtimeResources).build();
254-
255-
Map<String, String> params = new HashMap<>();
256-
params.put("param1", "value1");
257-
258-
ToolUtils.createTool(toolFactories, params, toolSpec);
259-
260-
verify(factory).create(argThat(toolParamsMap -> {
261-
Map<String, Object> toolParams = (Map<String, Object>) toolParamsMap;
262-
return toolParams.get("param1").equals("value1")
263-
&& toolParams.get("resource1").equals("value1")
264-
&& toolParams.get("resource2").equals(42);
265-
}));
266-
}
267-
268-
@Test
269-
public void testCreateTool_WithNullRuntimeResources() {
270-
Map<String, Tool.Factory> toolFactories = new HashMap<>();
271-
Tool.Factory factory = mock(Tool.Factory.class);
272-
Tool mockTool = mock(Tool.class);
273-
when(factory.create(any())).thenReturn(mockTool);
274-
toolFactories.put("test_tool", factory);
275-
276-
MLToolSpec toolSpec = MLToolSpec.builder().type("test_tool").name("TestTool").runtimeResources(null).build();
277-
278-
Map<String, String> params = new HashMap<>();
279-
params.put("param1", "value1");
280-
281-
ToolUtils.createTool(toolFactories, params, toolSpec);
282-
283-
verify(factory).create(argThat(toolParamsMap -> ((Map<String, Object>) toolParamsMap).get("param1").equals("value1")));
284-
}
285-
286174
@Test
287175
public void testFilterToolOutput_NoFiltering() {
288176
// Create a simple object

0 commit comments

Comments
 (0)