| 
5 | 5 | 
 
  | 
6 | 6 | package org.opensearch.ml.engine.algorithms.agent;  | 
7 | 7 | 
 
  | 
 | 8 | +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;  | 
8 | 9 | import static org.opensearch.ml.common.utils.StringUtils.gson;  | 
 | 10 | +import static org.opensearch.ml.common.utils.StringUtils.isJson;  | 
 | 11 | +import static org.opensearch.ml.common.utils.StringUtils.toJson;  | 
9 | 12 | import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.MESSAGE_HISTORY_LIMIT;  | 
 | 13 | +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION;  | 
 | 14 | +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION_INPUT;  | 
10 | 15 | import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY;  | 
11 | 16 | import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT;  | 
12 | 17 | import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.EXAMPLES;  | 
 | 18 | +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.FINAL_ANSWER;  | 
13 | 19 | import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.OS_INDICES;  | 
 | 20 | +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.THOUGHT;  | 
 | 21 | +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.THOUGHT_RESPONSE;  | 
14 | 22 | import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS;  | 
15 | 23 | import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;  | 
16 | 24 | import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS;  | 
 | 
19 | 27 | import java.security.PrivilegedActionException;  | 
20 | 28 | import java.security.PrivilegedExceptionAction;  | 
21 | 29 | import java.util.ArrayList;  | 
 | 30 | +import java.util.Collection;  | 
22 | 31 | import java.util.HashMap;  | 
23 | 32 | import java.util.List;  | 
 | 33 | +import java.util.Locale;  | 
24 | 34 | import java.util.Map;  | 
25 | 35 | import java.util.Optional;  | 
 | 36 | +import java.util.Set;  | 
26 | 37 | import java.util.regex.Matcher;  | 
27 | 38 | import java.util.regex.Pattern;  | 
28 | 39 | 
 
  | 
 | 
33 | 44 | import org.opensearch.ml.common.output.model.ModelTensor;  | 
34 | 45 | import org.opensearch.ml.common.output.model.ModelTensorOutput;  | 
35 | 46 | import org.opensearch.ml.common.spi.tools.Tool;  | 
 | 47 | +import org.opensearch.ml.common.utils.StringUtils;  | 
36 | 48 | 
 
  | 
 | 49 | +import lombok.extern.log4j.Log4j2;  | 
 | 50 | + | 
 | 51 | +@Log4j2  | 
37 | 52 | public class AgentUtils {  | 
38 | 53 | 
 
  | 
39 | 54 |     public static final String SELECTED_TOOLS = "selected_tools";  | 
@@ -167,23 +182,166 @@ public static String extractModelResponseJson(String text) {  | 
167 | 182 |         return extractModelResponseJson(text, null);  | 
168 | 183 |     }  | 
169 | 184 | 
 
  | 
170 |  | -    public static String extractModelResponseJson(String text, List<String> llmResponsePatterns) {  | 
171 |  | -        Pattern jsonBlockPattern = Pattern.compile("```json\\s*([\\s\\S]+?)\\s*```");  | 
172 |  | -        Matcher jsonBlockMatcher = jsonBlockPattern.matcher(text);  | 
173 |  | - | 
174 |  | -        if (jsonBlockMatcher.find()) {  | 
175 |  | -            return jsonBlockMatcher.group(1);  | 
 | 185 | +    public static Map<String, String> parseLLMOutput(  | 
 | 186 | +        ModelTensorOutput tmpModelTensorOutput,  | 
 | 187 | +        List<String> llmResponsePatterns,  | 
 | 188 | +        Set<String> inputTools  | 
 | 189 | +    ) {  | 
 | 190 | +        Map<String, String> modelOutput = new HashMap<>();  | 
 | 191 | +        Map<String, ?> dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();  | 
 | 192 | +        if (dataAsMap.size() == 1 && dataAsMap.containsKey("response")) {  | 
 | 193 | +            String llmReasoningResponse = (String) dataAsMap.get("response");  | 
 | 194 | +            String thoughtResponse = null;  | 
 | 195 | +            try {  | 
 | 196 | +                thoughtResponse = extractModelResponseJson(llmReasoningResponse, llmResponsePatterns);  | 
 | 197 | +                modelOutput.put(THOUGHT_RESPONSE, thoughtResponse);  | 
 | 198 | +            } catch (IllegalArgumentException e) {  | 
 | 199 | +                modelOutput.put(THOUGHT_RESPONSE, llmReasoningResponse);  | 
 | 200 | +                thoughtResponse = llmReasoningResponse;  | 
 | 201 | +            }  | 
 | 202 | +            parseThoughtResponse(modelOutput, thoughtResponse);  | 
176 | 203 |         } else {  | 
177 |  | -            String matchedPart = findMatchedPart(text, MODEL_RESPONSE_PATTERNS);  | 
178 |  | -            if (matchedPart == null && llmResponsePatterns != null) {  | 
179 |  | -                // If no match is found, try additional patterns if provided  | 
180 |  | -                matchedPart = findMatchedPart(text, llmResponsePatterns);  | 
 | 204 | +            extractParams(modelOutput, dataAsMap, THOUGHT);  | 
 | 205 | +            extractParams(modelOutput, dataAsMap, ACTION);  | 
 | 206 | +            extractParams(modelOutput, dataAsMap, ACTION_INPUT);  | 
 | 207 | +            extractParams(modelOutput, dataAsMap, FINAL_ANSWER);  | 
 | 208 | +            try {  | 
 | 209 | +                modelOutput.put(THOUGHT_RESPONSE, StringUtils.toJson(dataAsMap));  | 
 | 210 | +            } catch (Exception e) {  | 
 | 211 | +                log.warn("Failed to parse model response", e);  | 
 | 212 | +            }  | 
 | 213 | +        }  | 
 | 214 | +        String action = modelOutput.get(ACTION);  | 
 | 215 | +        if (action != null) {  | 
 | 216 | +            String matchedTool = getMatchedTool(inputTools, action);  | 
 | 217 | +            if (matchedTool != null) {  | 
 | 218 | +                modelOutput.put(ACTION, matchedTool);  | 
 | 219 | +            } else {  | 
 | 220 | +                modelOutput.remove(ACTION);  | 
 | 221 | +            }  | 
 | 222 | +        }  | 
 | 223 | +        if (!modelOutput.containsKey(ACTION) && !modelOutput.containsKey(FINAL_ANSWER)) {  | 
 | 224 | +            modelOutput.put(FINAL_ANSWER, modelOutput.get(THOUGHT_RESPONSE));  | 
 | 225 | +        }  | 
 | 226 | +        return modelOutput;  | 
 | 227 | +    }  | 
 | 228 | + | 
 | 229 | +    public static String getMatchedTool(Collection<String> tools, String action) {  | 
 | 230 | +        for (String tool : tools) {  | 
 | 231 | +            if (action.toLowerCase(Locale.ROOT).contains(tool.toLowerCase(Locale.ROOT))) {  | 
 | 232 | +                return tool;  | 
181 | 233 |             }  | 
 | 234 | +        }  | 
 | 235 | +        return null;  | 
 | 236 | +    }  | 
 | 237 | + | 
 | 238 | +    public static void extractParams(Map<String, String> modelOutput, Map<String, ?> dataAsMap, String paramName) {  | 
 | 239 | +        if (dataAsMap.containsKey(paramName)) {  | 
 | 240 | +            modelOutput.put(paramName, toJson(dataAsMap.get(paramName)));  | 
 | 241 | +        }  | 
 | 242 | +    }  | 
 | 243 | + | 
 | 244 | +    public static String extractModelResponseJson(String text, List<String> llmResponsePatterns) {  | 
 | 245 | +        if (text.contains("```json")) {  | 
 | 246 | +            text = text.substring(text.indexOf("```json") + "```json".length());  | 
 | 247 | +            if (text.contains("```")) {  | 
 | 248 | +                text = text.substring(0, text.lastIndexOf("```"));  | 
 | 249 | +            }  | 
 | 250 | +        }  | 
 | 251 | +        text = text.trim();  | 
 | 252 | +        if (isJson(text)) {  | 
 | 253 | +            return text;  | 
 | 254 | +        }  | 
 | 255 | +        String matchedPart = null;  | 
 | 256 | +        if (llmResponsePatterns != null) {  | 
 | 257 | +            matchedPart = findMatchedPart(text, llmResponsePatterns);  | 
182 | 258 |             if (matchedPart != null) {  | 
183 | 259 |                 return matchedPart;  | 
184 | 260 |             }  | 
185 |  | -            throw new IllegalArgumentException("Model output is invalid");  | 
186 | 261 |         }  | 
 | 262 | +        matchedPart = findMatchedPart(text, MODEL_RESPONSE_PATTERNS);  | 
 | 263 | +        if (matchedPart != null) {  | 
 | 264 | +            return matchedPart;  | 
 | 265 | +        }  | 
 | 266 | +        throw new IllegalArgumentException("Model output is invalid");  | 
 | 267 | +    }  | 
 | 268 | + | 
 | 269 | +    public static void parseThoughtResponse(Map<String, String> modelOutput, String thoughtResponse) {  | 
 | 270 | +        if (thoughtResponse != null) {  | 
 | 271 | +            if (isJson(thoughtResponse)) {  | 
 | 272 | +                modelOutput.putAll(getParameterMap(gson.fromJson(thoughtResponse, Map.class)));  | 
 | 273 | +            } else {// sometimes LLM return invalid json response  | 
 | 274 | +                String thought = extractThought(thoughtResponse);  | 
 | 275 | +                String action = extractAction(thoughtResponse);  | 
 | 276 | +                String actionInput = extractActionInput(thoughtResponse);  | 
 | 277 | +                String finalAnswer = extractFinalAnswer(thoughtResponse);  | 
 | 278 | +                if (thought != null) {  | 
 | 279 | +                    modelOutput.put(THOUGHT, thought);  | 
 | 280 | +                }  | 
 | 281 | +                if (action != null) {  | 
 | 282 | +                    modelOutput.put(ACTION, action);  | 
 | 283 | +                }  | 
 | 284 | +                if (actionInput != null) {  | 
 | 285 | +                    modelOutput.put(ACTION_INPUT, actionInput);  | 
 | 286 | +                }  | 
 | 287 | +                if (finalAnswer != null) {  | 
 | 288 | +                    modelOutput.put(FINAL_ANSWER, finalAnswer);  | 
 | 289 | +                }  | 
 | 290 | +            }  | 
 | 291 | +        }  | 
 | 292 | +    }  | 
 | 293 | + | 
 | 294 | +    public static String extractFinalAnswer(String text) {  | 
 | 295 | +        String result = null;  | 
 | 296 | +        if (text.contains("\"final_answer\"")) {  | 
 | 297 | +            String pattern = "\"final_answer\"\\s*:\\s*\"(.*?)$";  | 
 | 298 | +            Pattern jsonBlockPattern = Pattern.compile(pattern, Pattern.DOTALL);  | 
 | 299 | +            Matcher jsonBlockMatcher = jsonBlockPattern.matcher(text);  | 
 | 300 | +            if (jsonBlockMatcher.find()) {  | 
 | 301 | +                result = jsonBlockMatcher.group(1);  | 
 | 302 | +            }  | 
 | 303 | +        }  | 
 | 304 | +        return result;  | 
 | 305 | +    }  | 
 | 306 | + | 
 | 307 | +    public static String extractThought(String text) {  | 
 | 308 | +        String result = null;  | 
 | 309 | +        if (text.contains("\"thought\"")) {  | 
 | 310 | +            String pattern = "\"thought\"\\s*:\\s*\"(.*?)\"\\s*,\\s*[\"final_answer\"|\"action\"]";  | 
 | 311 | +            Pattern jsonBlockPattern = Pattern.compile(pattern, Pattern.DOTALL);  | 
 | 312 | +            Matcher jsonBlockMatcher = jsonBlockPattern.matcher(text);  | 
 | 313 | +            if (jsonBlockMatcher.find()) {  | 
 | 314 | +                result = jsonBlockMatcher.group(1);  | 
 | 315 | +            }  | 
 | 316 | +        }  | 
 | 317 | +        return result;  | 
 | 318 | +    }  | 
 | 319 | + | 
 | 320 | +    public static String extractAction(String text) {  | 
 | 321 | +        String result = null;  | 
 | 322 | +        if (text.contains("\"action\"")) {  | 
 | 323 | +            String pattern = "\"action\"\\s*:\\s*\"(.*?)(?:\"action_input\"|$)";  | 
 | 324 | +            Pattern jsonBlockPattern = Pattern.compile(pattern, Pattern.DOTALL);  | 
 | 325 | +            Matcher jsonBlockMatcher = jsonBlockPattern.matcher(text);  | 
 | 326 | +            if (jsonBlockMatcher.find()) {  | 
 | 327 | +                result = jsonBlockMatcher.group(1);  | 
 | 328 | +            }  | 
 | 329 | +        }  | 
 | 330 | +        return result;  | 
 | 331 | +    }  | 
 | 332 | + | 
 | 333 | +    public static String extractActionInput(String text) {  | 
 | 334 | +        String result = null;  | 
 | 335 | +        if (text.contains("\"action_input\"")) {  | 
 | 336 | +            String pattern = "\"action_input\"\\s*:\\s*\"((?:[^\\\"]|\\\")*)\"";  | 
 | 337 | +            Pattern jsonBlockPattern = Pattern.compile(pattern, Pattern.DOTALL); // Add Pattern.DOTALL to match across newlines  | 
 | 338 | +            Matcher jsonBlockMatcher = jsonBlockPattern.matcher(text);  | 
 | 339 | +            if (jsonBlockMatcher.find()) {  | 
 | 340 | +                result = jsonBlockMatcher.group(1);  | 
 | 341 | +                result = result.replace("\\\"", "\"");  | 
 | 342 | +            }  | 
 | 343 | +        }  | 
 | 344 | +        return result;  | 
187 | 345 |     }  | 
188 | 346 | 
 
  | 
189 | 347 |     public static String findMatchedPart(String text, List<String> patternList) {  | 
 | 
0 commit comments