Skip to content

Commit 8c91e14

Browse files
ylwu-amzndhrubo-os
andauthored
Support output filter, unify tool parameter handling and improve Sear… (opensearch-project#4053)
* Support output filter, unify tool parameter handling and improve SearchIndexTool output parsing Signed-off-by: Yaliang Wu <[email protected]> * fix ut Signed-off-by: Yaliang Wu <[email protected]> * refactor tool parameter parsing logic Signed-off-by: Yaliang Wu <[email protected]> * fix tool parameter parsing Signed-off-by: Yaliang Wu <[email protected]> * fix comments Signed-off-by: Yaliang Wu <[email protected]> --------- Signed-off-by: Yaliang Wu <[email protected]> Co-authored-by: Dhrubo Saha <[email protected]>
1 parent 145f9f8 commit 8c91e14

File tree

21 files changed

+984
-346
lines changed

21 files changed

+984
-346
lines changed

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.ml.common.utils;
77

8+
import static org.apache.commons.text.StringEscapeUtils.escapeJson;
89
import static org.opensearch.action.ValidateActions.addValidationError;
910

1011
import java.nio.ByteBuffer;
@@ -16,6 +17,7 @@
1617
import java.security.PrivilegedExceptionAction;
1718
import java.util.ArrayList;
1819
import java.util.Base64;
20+
import java.util.Collections;
1921
import java.util.HashMap;
2022
import java.util.HashSet;
2123
import java.util.List;
@@ -40,6 +42,7 @@
4042
import com.google.gson.JsonObject;
4143
import com.google.gson.JsonParser;
4244
import com.google.gson.JsonSyntaxException;
45+
import com.google.gson.reflect.TypeToken;
4346
import com.jayway.jsonpath.JsonPath;
4447
import com.jayway.jsonpath.PathNotFoundException;
4548
import com.networknt.schema.JsonSchema;
@@ -111,6 +114,30 @@ public static boolean isJson(String json) {
111114
}
112115
}
113116

117+
/**
118+
* Ensures that a string is properly JSON escaped.
119+
*
120+
* <p>This method examines the input string and determines whether it already represents
121+
* valid JSON content. If the input is valid JSON, it is returned unchanged. Otherwise,
122+
* the input is treated as a plain string and escaped according to JSON string literal
123+
* rules.</p>
124+
*
125+
* <p>Examples:</p>
126+
* <pre>
127+
* prepareJsonValue("hello") → "\"hello\""
128+
* prepareJsonValue("\"hello\"") → "\\\"hello\\\""
129+
* prepareJsonValue("{\"key\":123}") → {\"key\":123} (valid JSON object, unchanged)
130+
* </pre>
131+
* @param input
132+
* @return
133+
*/
134+
public static String prepareJsonValue(String input) {
135+
if (isJson(input)) {
136+
return input;
137+
}
138+
return escapeJson(input);
139+
}
140+
114141
public static String toUTF8(String rawString) {
115142
ByteBuffer buffer = StandardCharsets.UTF_8.encode(rawString);
116143

@@ -552,4 +579,22 @@ public static boolean matchesSafePattern(String value) {
552579
return SAFE_INPUT_PATTERN.matcher(value).matches();
553580
}
554581

582+
/**
583+
* Parses a JSON array string into a List of Strings.
584+
*
585+
* @param jsonArrayString JSON array string to parse (e.g., "[\"item1\", \"item2\"]")
586+
* @return List of strings parsed from the JSON array, or an empty list if the input is
587+
* null, empty, or invalid JSON
588+
*/
589+
public static List<String> parseStringArrayToList(String jsonArrayString) {
590+
if (jsonArrayString == null || jsonArrayString.trim().isEmpty()) {
591+
return Collections.emptyList();
592+
}
593+
try {
594+
return gson.fromJson(jsonArrayString, TypeToken.getParameterized(List.class, String.class).getType());
595+
} catch (JsonSyntaxException e) {
596+
log.error("Failed to parse JSON array string: {}", jsonArrayString, e);
597+
return Collections.emptyList();
598+
}
599+
}
555600
}

common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java

Lines changed: 99 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,13 @@
77

88
import static org.junit.Assert.assertEquals;
99
import static org.junit.Assert.assertFalse;
10+
import static org.junit.Assert.assertNotEquals;
1011
import static org.junit.Assert.assertNotNull;
1112
import static org.junit.Assert.assertNull;
13+
import static org.junit.Assert.assertSame;
1214
import static org.junit.Assert.assertThrows;
1315
import static org.junit.Assert.assertTrue;
14-
import static org.opensearch.ml.common.utils.StringUtils.TO_STRING_FUNCTION_NAME;
15-
import static org.opensearch.ml.common.utils.StringUtils.collectToStringPrefixes;
16-
import static org.opensearch.ml.common.utils.StringUtils.getJsonPath;
17-
import static org.opensearch.ml.common.utils.StringUtils.isValidJSONPath;
18-
import static org.opensearch.ml.common.utils.StringUtils.obtainFieldNameFromJsonPath;
19-
import static org.opensearch.ml.common.utils.StringUtils.parseParameters;
20-
import static org.opensearch.ml.common.utils.StringUtils.toJson;
16+
import static org.opensearch.ml.common.utils.StringUtils.*;
2117

2218
import java.io.IOException;
2319
import java.util.ArrayList;
@@ -190,7 +186,7 @@ public void addDefaultMethod_NoEscape() {
190186
public void addDefaultMethod_Escape() {
191187
String input = "return escape(\"abc\n123\");";
192188
String result = StringUtils.addDefaultMethod(input);
193-
Assert.assertNotEquals(input, result);
189+
assertNotEquals(input, result);
194190
assertTrue(result.startsWith(StringUtils.DEFAULT_ESCAPE_FUNCTION));
195191
}
196192

@@ -858,4 +854,99 @@ public void testValidateFields_InvalidCharacterSet() {
858854
assertTrue(exception.getMessage().contains("Field1"));
859855
}
860856

857+
@Test
858+
public void prepareJsonValue_returnsRawIfJson() {
859+
String json = "{\"key\": 123}";
860+
String result = StringUtils.prepareJsonValue(json);
861+
assertSame(json, result); // branch where isJson(input)==true
862+
}
863+
864+
@Test
865+
public void prepareJsonValue_escapesBadCharsOtherwise() {
866+
String input = "Tom & Jerry \"<script>";
867+
String escaped = StringUtils.prepareJsonValue(input);
868+
assertNotEquals(input, escaped);
869+
assertFalse(StringUtils.isJson(escaped));
870+
assertEquals("Tom & Jerry \\\"<script>", escaped);
871+
}
872+
873+
@Test
874+
public void testParseStringArrayToList_validJsonArray() {
875+
// Arrange
876+
String jsonArray = "[\"apple\", \"banana\", \"cherry\"]";
877+
878+
// Act
879+
List<String> result = parseStringArrayToList(jsonArray);
880+
881+
// Assert
882+
assertEquals(Arrays.asList("apple", "banana", "cherry"), result);
883+
}
884+
885+
@Test
886+
public void testParseStringArrayToList_emptyArray() {
887+
// Arrange
888+
String jsonArray = "[]";
889+
890+
// Act
891+
List<String> result = parseStringArrayToList(jsonArray);
892+
893+
// Assert
894+
assertTrue(result.isEmpty());
895+
}
896+
897+
@Test
898+
public void testParseStringArrayToList_withSpecialCharacters() {
899+
// Arrange
900+
String jsonArray = "[\"hello\", \"world!\", \"special: @#$%^&*()\"]";
901+
902+
// Act
903+
List<String> result = parseStringArrayToList(jsonArray);
904+
905+
// Assert
906+
assertEquals(Arrays.asList("hello", "world!", "special: @#$%^&*()"), result);
907+
}
908+
909+
@Test
910+
public void testParseStringArrayToList_withNullElement() {
911+
// Arrange
912+
String jsonArray = "[\"first\", null, \"third\"]";
913+
914+
// Act
915+
List<String> result = parseStringArrayToList(jsonArray);
916+
917+
// Assert
918+
assertEquals(3, result.size());
919+
assertEquals("first", result.get(0));
920+
assertNull(result.get(1));
921+
assertEquals("third", result.get(2));
922+
}
923+
924+
@Test
925+
public void testParseStringArrayToList_jsonWithTrailingComma() {
926+
// Arrange
927+
String jsonWithTrailingComma = "[\"apple\", \"banana\",]"; // Invalid trailing comma
928+
929+
List<String> result = parseStringArrayToList(jsonWithTrailingComma);
930+
931+
// Assert
932+
assertEquals(Arrays.asList("apple", "banana", null), result);
933+
assertEquals(3, result.size());
934+
}
935+
936+
@Test
937+
public void testParseStringArrayToList_nonArrayJson() {
938+
// Arrange
939+
String nonArrayJson = "{\"key\": \"value\"}";
940+
941+
// Act & Assert
942+
List<String> array = parseStringArrayToList(nonArrayJson);
943+
assertEquals(0, array.size());
944+
}
945+
946+
@Test
947+
public void testParseStringArrayToList_Null() {
948+
List<String> array = parseStringArrayToList(null);
949+
assertEquals(0, array.size());
950+
}
951+
861952
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

Lines changed: 11 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import static org.opensearch.ml.common.CommonValue.MCP_CONNECTORS_FIELD;
1111
import static org.opensearch.ml.common.CommonValue.MCP_CONNECTOR_ID_FIELD;
1212
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
13-
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
1413
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
1514
import static org.opensearch.ml.common.utils.StringUtils.gson;
1615
import static org.opensearch.ml.common.utils.StringUtils.isJson;
@@ -29,6 +28,7 @@
2928
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;
3029
import static org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner.RESPONSE_FIELD;
3130
import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS;
31+
import static org.opensearch.ml.engine.tools.ToolUtils.getToolName;
3232

3333
import java.io.IOException;
3434
import java.lang.reflect.Type;
@@ -81,6 +81,7 @@
8181
import org.opensearch.ml.engine.encryptor.Encryptor;
8282
import org.opensearch.ml.engine.function_calling.FunctionCalling;
8383
import org.opensearch.ml.engine.tools.McpSseTool;
84+
import org.opensearch.ml.engine.tools.ToolUtils;
8485
import org.opensearch.remote.metadata.client.GetDataObjectRequest;
8586
import org.opensearch.remote.metadata.client.SdkClient;
8687
import org.opensearch.remote.metadata.common.SdkClientUtils;
@@ -646,10 +647,6 @@ public static int getMessageHistoryLimit(Map<String, String> params) {
646647
return messageHistoryLimitStr != null ? Integer.parseInt(messageHistoryLimitStr) : LAST_N_INTERACTIONS;
647648
}
648649

649-
public static String getToolName(MLToolSpec toolSpec) {
650-
return toolSpec.getName() != null ? toolSpec.getName() : toolSpec.getType();
651-
}
652-
653650
public static List<MLToolSpec> getMlToolSpecs(MLAgent mlAgent, Map<String, String> params) {
654651
String selectedToolsStr = params.get(SELECTED_TOOLS);
655652
List<MLToolSpec> toolSpecs = new ArrayList<>();
@@ -841,7 +838,8 @@ public static void createTools(
841838
return;
842839
}
843840
for (MLToolSpec toolSpec : toolSpecs) {
844-
Tool tool = createTool(toolFactories, params, toolSpec, mlAgent.getTenantId());
841+
Map<String, String> toolParams = ToolUtils.buildToolParameters(params, toolSpec, mlAgent.getTenantId());
842+
Tool tool = ToolUtils.createTool(toolFactories, toolParams, toolSpec);
845843
tools.put(tool.getName(), tool);
846844
if (toolSpec.getAttributes() != null) {
847845
if (tool.getAttributes() == null) {
@@ -856,55 +854,6 @@ public static void createTools(
856854
}
857855
}
858856

859-
public static Tool createTool(
860-
Map<String, Tool.Factory> toolFactories,
861-
Map<String, String> params,
862-
MLToolSpec toolSpec,
863-
String tenantId
864-
) {
865-
if (!toolFactories.containsKey(toolSpec.getType())) {
866-
throw new IllegalArgumentException("Tool not found: " + toolSpec.getType());
867-
}
868-
Map<String, String> executeParams = new HashMap<>();
869-
if (toolSpec.getParameters() != null) {
870-
executeParams.putAll(toolSpec.getParameters());
871-
}
872-
executeParams.put(TENANT_ID_FIELD, tenantId);
873-
for (String key : params.keySet()) {
874-
String toolNamePrefix = getToolName(toolSpec) + ".";
875-
if (key.startsWith(toolNamePrefix)) {
876-
executeParams.put(key.replace(toolNamePrefix, ""), params.get(key));
877-
}
878-
}
879-
Map<String, Object> toolParams = new HashMap<>();
880-
toolParams.putAll(executeParams);
881-
Map<String, Object> runtimeResources = toolSpec.getRuntimeResources();
882-
if (runtimeResources != null) {
883-
toolParams.putAll(runtimeResources);
884-
}
885-
Tool tool = toolFactories.get(toolSpec.getType()).create(toolParams);
886-
String toolName = getToolName(toolSpec);
887-
tool.setName(toolName);
888-
889-
if (toolSpec.getDescription() != null) {
890-
tool.setDescription(toolSpec.getDescription());
891-
}
892-
if (params.containsKey(toolName + ".description")) {
893-
tool.setDescription(params.get(toolName + ".description"));
894-
}
895-
896-
return tool;
897-
}
898-
899-
public static List<String> getToolNames(Map<String, Tool> tools) {
900-
final List<String> inputTools = new ArrayList<>();
901-
for (Map.Entry<String, Tool> entry : tools.entrySet()) {
902-
String toolName = entry.getValue().getName();
903-
inputTools.add(toolName);
904-
}
905-
return inputTools;
906-
}
907-
908857
public static Map<String, String> constructToolParams(
909858
Map<String, Tool> tools,
910859
Map<String, MLToolSpec> toolSpecMap,
@@ -916,8 +865,15 @@ public static Map<String, String> constructToolParams(
916865
Map<String, String> toolParams = new HashMap<>();
917866
Map<String, String> toolSpecParams = toolSpecMap.get(action).getParameters();
918867
Map<String, String> toolSpecConfigMap = toolSpecMap.get(action).getConfigMap();
868+
MLToolSpec toolSpec = toolSpecMap.get(action);
919869
if (toolSpecParams != null) {
920870
toolParams.putAll(toolSpecParams);
871+
for (String key : toolSpecParams.keySet()) {
872+
String toolNamePrefix = getToolName(toolSpec) + ".";
873+
if (key.startsWith(toolNamePrefix)) {
874+
toolParams.put(key.replace(toolNamePrefix, ""), toolSpecParams.get(key));
875+
}
876+
}
921877
}
922878
if (toolSpecConfigMap != null) {
923879
toolParams.putAll(toolSpecConfigMap);

0 commit comments

Comments
 (0)