1111import static org .opensearch .ml .engine .algorithms .agent .MLChatAgentRunner .CONTEXT ;
1212import static org .opensearch .ml .engine .algorithms .agent .MLChatAgentRunner .EXAMPLES ;
1313import static org .opensearch .ml .engine .algorithms .agent .MLChatAgentRunner .OS_INDICES ;
14- import static org .opensearch .ml .engine .algorithms .agent .MLChatAgentRunner .PROMPT_PREFIX ;
15- import static org .opensearch .ml .engine .algorithms .agent .MLChatAgentRunner .PROMPT_SUFFIX ;
1614import static org .opensearch .ml .engine .algorithms .agent .MLChatAgentRunner .TOOL_DESCRIPTIONS ;
1715import static org .opensearch .ml .engine .algorithms .agent .MLChatAgentRunner .TOOL_NAMES ;
1816import static org .opensearch .ml .engine .memory .ConversationIndexMemory .LAST_N_INTERACTIONS ;
1917
2018import java .security .AccessController ;
2119import java .security .PrivilegedActionException ;
2220import java .security .PrivilegedExceptionAction ;
21+ import java .util .ArrayList ;
2322import java .util .HashMap ;
2423import java .util .List ;
2524import java .util .Map ;
2827import java .util .regex .Pattern ;
2928
3029import org .apache .commons .text .StringSubstitutor ;
30+ import org .opensearch .core .common .Strings ;
31+ import org .opensearch .ml .common .agent .MLAgent ;
3132import org .opensearch .ml .common .agent .MLToolSpec ;
3233import org .opensearch .ml .common .output .model .ModelTensor ;
3334import org .opensearch .ml .common .output .model .ModelTensorOutput ;
3435import org .opensearch .ml .common .spi .tools .Tool ;
3536
3637public class AgentUtils {
3738
39+ public static final String SELECTED_TOOLS = "selected_tools" ;
40+ public static final String PROMPT_PREFIX = "prompt.prefix" ;
41+ public static final String PROMPT_SUFFIX = "prompt.suffix" ;
42+ public static final String RESPONSE_FORMAT_INSTRUCTION = "prompt.format_instruction" ;
43+ public static final String TOOL_RESPONSE = "prompt.tool_response" ;
44+ public static final String PROMPT_CHAT_HISTORY_PREFIX = "prompt.chat_history_prefix" ;
45+ public static final String DISABLE_TRACE = "disable_trace" ;
46+ public static final String VERBOSE = "verbose" ;
47+
3848 public static String addExamplesToPrompt (Map <String , String > parameters , String prompt ) {
3949 Map <String , String > examplesMap = new HashMap <>();
4050 if (parameters .containsKey (EXAMPLES )) {
@@ -150,17 +160,43 @@ public static String addContextToPrompt(Map<String, String> parameters, String p
150160 return prompt ;
151161 }
152162
163+ public static List <String > MODEL_RESPONSE_PATTERNS = List
164+ .of ("\\ {\\ s*(\" (thought|action|action_input|final_answer)\" \\ s*:\\ s*\" .*?\" \\ s*,?\\ s*)+\\ }" );
165+
153166 public static String extractModelResponseJson (String text ) {
154- Pattern pattern = Pattern . compile ( "```json \\ s*([ \\ s \\ S]+?) \\ s*```" );
155- Matcher matcher = pattern . matcher ( text );
167+ return extractModelResponseJson ( text , null );
168+ }
156169
157- if (matcher .find ()) {
158- return matcher .group (1 );
170+ public static String extractModelResponseJson (String text , List <String > llmResponsePatterns ) {
171+ Pattern jsonBlockPattern = Pattern .compile ("```json\\ s*([\\ s\\ S]+?)\\ s*```" );
172+ Matcher jsonBlockMatcher = jsonBlockPattern .matcher (text );
173+
174+ if (jsonBlockMatcher .find ()) {
175+ return jsonBlockMatcher .group (1 );
159176 } else {
177+ String matchedPart = findMatchedPart (text , MODEL_RESPONSE_PATTERNS );
178+ if (matchedPart == null && llmResponsePatterns != null ) {
179+ // If no match is found, try additional patterns if provided
180+ matchedPart = findMatchedPart (text , llmResponsePatterns );
181+ }
182+ if (matchedPart != null ) {
183+ return matchedPart ;
184+ }
160185 throw new IllegalArgumentException ("Model output is invalid" );
161186 }
162187 }
163188
189+ public static String findMatchedPart (String text , List <String > patternList ) {
190+ for (String p : patternList ) {
191+ Pattern pattern = Pattern .compile (p );
192+ Matcher matcher = pattern .matcher (text );
193+ if (matcher .find ()) {
194+ return matcher .group ();
195+ }
196+ }
197+ return null ;
198+ }
199+
164200 public static String outputToOutputString (Object output ) throws PrivilegedActionException {
165201 String outputString ;
166202 if (output instanceof ModelTensorOutput ) {
@@ -179,16 +215,6 @@ public static String outputToOutputString(Object output) throws PrivilegedAction
179215 return outputString ;
180216 }
181217
182- public static String parseInputFromLLMReturn (Map <String , ?> retMap ) {
183- Object actionInput = retMap .get ("action_input" );
184- if (actionInput instanceof Map ) {
185- return gson .toJson (actionInput );
186- } else {
187- return String .valueOf (actionInput );
188- }
189-
190- }
191-
192218 public static int getMessageHistoryLimit (Map <String , String > params ) {
193219 String messageHistoryLimitStr = params .get (MESSAGE_HISTORY_LIMIT );
194220 return messageHistoryLimitStr != null ? Integer .parseInt (messageHistoryLimitStr ) : LAST_N_INTERACTIONS ;
@@ -197,4 +223,75 @@ public static int getMessageHistoryLimit(Map<String, String> params) {
197223 public static String getToolName (MLToolSpec toolSpec ) {
198224 return toolSpec .getName () != null ? toolSpec .getName () : toolSpec .getType ();
199225 }
226+
227+ public static List <MLToolSpec > getMlToolSpecs (MLAgent mlAgent , Map <String , String > params ) {
228+ String selectedToolsStr = params .get (SELECTED_TOOLS );
229+ List <MLToolSpec > toolSpecs = mlAgent .getTools ();
230+ if (!Strings .isEmpty (selectedToolsStr )) {
231+ List <String > selectedTools = gson .fromJson (selectedToolsStr , List .class );
232+ Map <String , MLToolSpec > toolNameSpecMap = new HashMap <>();
233+ for (MLToolSpec toolSpec : toolSpecs ) {
234+ toolNameSpecMap .put (getToolName (toolSpec ), toolSpec );
235+ }
236+ List <MLToolSpec > selectedToolSpecs = new ArrayList <>();
237+ for (String tool : selectedTools ) {
238+ if (toolNameSpecMap .containsKey (tool )) {
239+ selectedToolSpecs .add (toolNameSpecMap .get (tool ));
240+ }
241+ }
242+ toolSpecs = selectedToolSpecs ;
243+ }
244+ return toolSpecs ;
245+ }
246+
247+ public static void createTools (
248+ Map <String , Tool .Factory > toolFactories ,
249+ Map <String , String > params ,
250+ List <MLToolSpec > toolSpecs ,
251+ Map <String , Tool > tools ,
252+ Map <String , MLToolSpec > toolSpecMap
253+ ) {
254+ for (MLToolSpec toolSpec : toolSpecs ) {
255+ Tool tool = createTool (toolFactories , params , toolSpec );
256+ tools .put (tool .getName (), tool );
257+ toolSpecMap .put (tool .getName (), toolSpec );
258+ }
259+ }
260+
261+ public static Tool createTool (Map <String , Tool .Factory > toolFactories , Map <String , String > params , MLToolSpec toolSpec ) {
262+ if (!toolFactories .containsKey (toolSpec .getType ())) {
263+ throw new IllegalArgumentException ("Tool not found: " + toolSpec .getType ());
264+ }
265+ Map <String , String > executeParams = new HashMap <>();
266+ if (toolSpec .getParameters () != null ) {
267+ executeParams .putAll (toolSpec .getParameters ());
268+ }
269+ for (String key : params .keySet ()) {
270+ String toolNamePrefix = getToolName (toolSpec ) + "." ;
271+ if (key .startsWith (toolNamePrefix )) {
272+ executeParams .put (key .replace (toolNamePrefix , "" ), params .get (key ));
273+ }
274+ }
275+ Tool tool = toolFactories .get (toolSpec .getType ()).create (executeParams );
276+ String toolName = getToolName (toolSpec );
277+ tool .setName (toolName );
278+
279+ if (toolSpec .getDescription () != null ) {
280+ tool .setDescription (toolSpec .getDescription ());
281+ }
282+ if (params .containsKey (toolName + ".description" )) {
283+ tool .setDescription (params .get (toolName + ".description" ));
284+ }
285+
286+ return tool ;
287+ }
288+
289+ public static List <String > getToolNames (Map <String , Tool > tools ) {
290+ final List <String > inputTools = new ArrayList <>();
291+ for (Map .Entry <String , Tool > entry : tools .entrySet ()) {
292+ String toolName = entry .getValue ().getName ();
293+ inputTools .add (toolName );
294+ }
295+ return inputTools ;
296+ }
200297}
0 commit comments