Skip to content

Commit 1cffe21

Browse files
fine tune prompt;refactor conversational agent code (#2094) (#2107)
* fine tune prompt;refactor conversational agent code Signed-off-by: Yaliang Wu <[email protected]> * put listener to last Signed-off-by: Yaliang Wu <[email protected]> * address comments Signed-off-by: Yaliang Wu <[email protected]> * check if selectedToolsStr is empty Signed-off-by: Yaliang Wu <[email protected]> --------- Signed-off-by: Yaliang Wu <[email protected]> (cherry picked from commit ad14420) Co-authored-by: Yaliang Wu <[email protected]>
1 parent cbd16d3 commit 1cffe21

File tree

12 files changed

+754
-559
lines changed

12 files changed

+754
-559
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

Lines changed: 113 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,14 @@
1111
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT;
1212
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.EXAMPLES;
1313
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.OS_INDICES;
14-
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_PREFIX;
15-
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_SUFFIX;
1614
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS;
1715
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;
1816
import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS;
1917

2018
import java.security.AccessController;
2119
import java.security.PrivilegedActionException;
2220
import java.security.PrivilegedExceptionAction;
21+
import java.util.ArrayList;
2322
import java.util.HashMap;
2423
import java.util.List;
2524
import java.util.Map;
@@ -28,13 +27,24 @@
2827
import java.util.regex.Pattern;
2928

3029
import org.apache.commons.text.StringSubstitutor;
30+
import org.opensearch.core.common.Strings;
31+
import org.opensearch.ml.common.agent.MLAgent;
3132
import org.opensearch.ml.common.agent.MLToolSpec;
3233
import org.opensearch.ml.common.output.model.ModelTensor;
3334
import org.opensearch.ml.common.output.model.ModelTensorOutput;
3435
import org.opensearch.ml.common.spi.tools.Tool;
3536

3637
public class AgentUtils {
3738

39+
public static final String SELECTED_TOOLS = "selected_tools";
40+
public static final String PROMPT_PREFIX = "prompt.prefix";
41+
public static final String PROMPT_SUFFIX = "prompt.suffix";
42+
public static final String RESPONSE_FORMAT_INSTRUCTION = "prompt.format_instruction";
43+
public static final String TOOL_RESPONSE = "prompt.tool_response";
44+
public static final String PROMPT_CHAT_HISTORY_PREFIX = "prompt.chat_history_prefix";
45+
public static final String DISABLE_TRACE = "disable_trace";
46+
public static final String VERBOSE = "verbose";
47+
3848
public static String addExamplesToPrompt(Map<String, String> parameters, String prompt) {
3949
Map<String, String> examplesMap = new HashMap<>();
4050
if (parameters.containsKey(EXAMPLES)) {
@@ -150,17 +160,43 @@ public static String addContextToPrompt(Map<String, String> parameters, String p
150160
return prompt;
151161
}
152162

163+
public static List<String> MODEL_RESPONSE_PATTERNS = List
164+
.of("\\{\\s*(\"(thought|action|action_input|final_answer)\"\\s*:\\s*\".*?\"\\s*,?\\s*)+\\}");
165+
153166
public static String extractModelResponseJson(String text) {
154-
Pattern pattern = Pattern.compile("```json\\s*([\\s\\S]+?)\\s*```");
155-
Matcher matcher = pattern.matcher(text);
167+
return extractModelResponseJson(text, null);
168+
}
156169

157-
if (matcher.find()) {
158-
return matcher.group(1);
170+
public static String extractModelResponseJson(String text, List<String> llmResponsePatterns) {
171+
Pattern jsonBlockPattern = Pattern.compile("```json\\s*([\\s\\S]+?)\\s*```");
172+
Matcher jsonBlockMatcher = jsonBlockPattern.matcher(text);
173+
174+
if (jsonBlockMatcher.find()) {
175+
return jsonBlockMatcher.group(1);
159176
} else {
177+
String matchedPart = findMatchedPart(text, MODEL_RESPONSE_PATTERNS);
178+
if (matchedPart == null && llmResponsePatterns != null) {
179+
// If no match is found, try additional patterns if provided
180+
matchedPart = findMatchedPart(text, llmResponsePatterns);
181+
}
182+
if (matchedPart != null) {
183+
return matchedPart;
184+
}
160185
throw new IllegalArgumentException("Model output is invalid");
161186
}
162187
}
163188

189+
public static String findMatchedPart(String text, List<String> patternList) {
190+
for (String p : patternList) {
191+
Pattern pattern = Pattern.compile(p);
192+
Matcher matcher = pattern.matcher(text);
193+
if (matcher.find()) {
194+
return matcher.group();
195+
}
196+
}
197+
return null;
198+
}
199+
164200
public static String outputToOutputString(Object output) throws PrivilegedActionException {
165201
String outputString;
166202
if (output instanceof ModelTensorOutput) {
@@ -179,16 +215,6 @@ public static String outputToOutputString(Object output) throws PrivilegedAction
179215
return outputString;
180216
}
181217

182-
public static String parseInputFromLLMReturn(Map<String, ?> retMap) {
183-
Object actionInput = retMap.get("action_input");
184-
if (actionInput instanceof Map) {
185-
return gson.toJson(actionInput);
186-
} else {
187-
return String.valueOf(actionInput);
188-
}
189-
190-
}
191-
192218
public static int getMessageHistoryLimit(Map<String, String> params) {
193219
String messageHistoryLimitStr = params.get(MESSAGE_HISTORY_LIMIT);
194220
return messageHistoryLimitStr != null ? Integer.parseInt(messageHistoryLimitStr) : LAST_N_INTERACTIONS;
@@ -197,4 +223,75 @@ public static int getMessageHistoryLimit(Map<String, String> params) {
197223
public static String getToolName(MLToolSpec toolSpec) {
198224
return toolSpec.getName() != null ? toolSpec.getName() : toolSpec.getType();
199225
}
226+
227+
public static List<MLToolSpec> getMlToolSpecs(MLAgent mlAgent, Map<String, String> params) {
228+
String selectedToolsStr = params.get(SELECTED_TOOLS);
229+
List<MLToolSpec> toolSpecs = mlAgent.getTools();
230+
if (!Strings.isEmpty(selectedToolsStr)) {
231+
List<String> selectedTools = gson.fromJson(selectedToolsStr, List.class);
232+
Map<String, MLToolSpec> toolNameSpecMap = new HashMap<>();
233+
for (MLToolSpec toolSpec : toolSpecs) {
234+
toolNameSpecMap.put(getToolName(toolSpec), toolSpec);
235+
}
236+
List<MLToolSpec> selectedToolSpecs = new ArrayList<>();
237+
for (String tool : selectedTools) {
238+
if (toolNameSpecMap.containsKey(tool)) {
239+
selectedToolSpecs.add(toolNameSpecMap.get(tool));
240+
}
241+
}
242+
toolSpecs = selectedToolSpecs;
243+
}
244+
return toolSpecs;
245+
}
246+
247+
public static void createTools(
248+
Map<String, Tool.Factory> toolFactories,
249+
Map<String, String> params,
250+
List<MLToolSpec> toolSpecs,
251+
Map<String, Tool> tools,
252+
Map<String, MLToolSpec> toolSpecMap
253+
) {
254+
for (MLToolSpec toolSpec : toolSpecs) {
255+
Tool tool = createTool(toolFactories, params, toolSpec);
256+
tools.put(tool.getName(), tool);
257+
toolSpecMap.put(tool.getName(), toolSpec);
258+
}
259+
}
260+
261+
public static Tool createTool(Map<String, Tool.Factory> toolFactories, Map<String, String> params, MLToolSpec toolSpec) {
262+
if (!toolFactories.containsKey(toolSpec.getType())) {
263+
throw new IllegalArgumentException("Tool not found: " + toolSpec.getType());
264+
}
265+
Map<String, String> executeParams = new HashMap<>();
266+
if (toolSpec.getParameters() != null) {
267+
executeParams.putAll(toolSpec.getParameters());
268+
}
269+
for (String key : params.keySet()) {
270+
String toolNamePrefix = getToolName(toolSpec) + ".";
271+
if (key.startsWith(toolNamePrefix)) {
272+
executeParams.put(key.replace(toolNamePrefix, ""), params.get(key));
273+
}
274+
}
275+
Tool tool = toolFactories.get(toolSpec.getType()).create(executeParams);
276+
String toolName = getToolName(toolSpec);
277+
tool.setName(toolName);
278+
279+
if (toolSpec.getDescription() != null) {
280+
tool.setDescription(toolSpec.getDescription());
281+
}
282+
if (params.containsKey(toolName + ".description")) {
283+
tool.setDescription(params.get(toolName + ".description"));
284+
}
285+
286+
return tool;
287+
}
288+
289+
public static List<String> getToolNames(Map<String, Tool> tools) {
290+
final List<String> inputTools = new ArrayList<>();
291+
for (Map.Entry<String, Tool> entry : tools.entrySet()) {
292+
String toolName = entry.getValue().getName();
293+
inputTools.add(toolName);
294+
}
295+
return inputTools;
296+
}
200297
}

0 commit comments

Comments
 (0)