Skip to content

Commit d52e518

Browse files
add config field in MLToolSpec for static parameters (#2977) (#3133)
* add config field in MLToolSpec for static parameters Signed-off-by: Jing Zhang <[email protected]> * add version control Signed-off-by: Jing Zhang <[email protected]> * address comments I Signed-off-by: Jing Zhang <[email protected]> * address commits II Signed-off-by: Jing Zhang <[email protected]> * address comments III Signed-off-by: Jing Zhang <[email protected]> --------- Signed-off-by: Jing Zhang <[email protected]> (cherry picked from commit 9ed0040) Co-authored-by: Jing Zhang <[email protected]>
1 parent cf70562 commit d52e518

File tree

11 files changed

+341
-22
lines changed

11 files changed

+341
-22
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,4 +581,5 @@ public class CommonValue {
581581
public static final Version VERSION_2_15_0 = Version.fromString("2.15.0");
582582
public static final Version VERSION_2_16_0 = Version.fromString("2.16.0");
583583
public static final Version VERSION_2_17_0 = Version.fromString("2.17.0");
584+
public static final Version VERSION_2_18_0 = Version.fromString("2.18.0");
584585
}

common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
import java.io.IOException;
1212
import java.util.Map;
1313

14+
import org.opensearch.Version;
1415
import org.opensearch.core.common.io.stream.StreamInput;
1516
import org.opensearch.core.common.io.stream.StreamOutput;
1617
import org.opensearch.core.xcontent.ToXContentObject;
1718
import org.opensearch.core.xcontent.XContentBuilder;
1819
import org.opensearch.core.xcontent.XContentParser;
20+
import org.opensearch.ml.common.CommonValue;
1921

2022
import lombok.Builder;
2123
import lombok.EqualsAndHashCode;
@@ -24,20 +26,31 @@
2426
@EqualsAndHashCode
2527
@Getter
2628
public class MLToolSpec implements ToXContentObject {
29+
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_TOOL_CONFIG = CommonValue.VERSION_2_18_0;
30+
2731
public static final String TOOL_TYPE_FIELD = "type";
2832
public static final String TOOL_NAME_FIELD = "name";
2933
public static final String DESCRIPTION_FIELD = "description";
3034
public static final String PARAMETERS_FIELD = "parameters";
3135
public static final String INCLUDE_OUTPUT_IN_AGENT_RESPONSE = "include_output_in_agent_response";
36+
public static final String CONFIG_FIELD = "config";
3237

3338
private String type;
3439
private String name;
3540
private String description;
3641
private Map<String, String> parameters;
3742
private boolean includeOutputInAgentResponse;
43+
private Map<String, String> configMap;
3844

3945
@Builder(toBuilder = true)
40-
public MLToolSpec(String type, String name, String description, Map<String, String> parameters, boolean includeOutputInAgentResponse) {
46+
public MLToolSpec(
47+
String type,
48+
String name,
49+
String description,
50+
Map<String, String> parameters,
51+
boolean includeOutputInAgentResponse,
52+
Map<String, String> configMap
53+
) {
4154
if (type == null) {
4255
throw new IllegalArgumentException("tool type is null");
4356
}
@@ -46,6 +59,7 @@ public MLToolSpec(String type, String name, String description, Map<String, Stri
4659
this.description = description;
4760
this.parameters = parameters;
4861
this.includeOutputInAgentResponse = includeOutputInAgentResponse;
62+
this.configMap = configMap;
4963
}
5064

5165
public MLToolSpec(StreamInput input) throws IOException {
@@ -56,6 +70,9 @@ public MLToolSpec(StreamInput input) throws IOException {
5670
parameters = input.readMap(StreamInput::readString, StreamInput::readOptionalString);
5771
}
5872
includeOutputInAgentResponse = input.readBoolean();
73+
if (input.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_TOOL_CONFIG) && input.readBoolean()) {
74+
configMap = input.readMap(StreamInput::readString, StreamInput::readOptionalString);
75+
}
5976
}
6077

6178
public void writeTo(StreamOutput out) throws IOException {
@@ -69,6 +86,14 @@ public void writeTo(StreamOutput out) throws IOException {
6986
out.writeBoolean(false);
7087
}
7188
out.writeBoolean(includeOutputInAgentResponse);
89+
if (out.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_TOOL_CONFIG)) {
90+
if (configMap != null) {
91+
out.writeBoolean(true);
92+
out.writeMap(configMap, StreamOutput::writeString, StreamOutput::writeOptionalString);
93+
} else {
94+
out.writeBoolean(false);
95+
}
96+
}
7297
}
7398

7499
@Override
@@ -87,6 +112,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
87112
builder.field(PARAMETERS_FIELD, parameters);
88113
}
89114
builder.field(INCLUDE_OUTPUT_IN_AGENT_RESPONSE, includeOutputInAgentResponse);
115+
if (configMap != null && !configMap.isEmpty()) {
116+
builder.field(CONFIG_FIELD, configMap);
117+
}
90118
builder.endObject();
91119
return builder;
92120
}
@@ -97,6 +125,7 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
97125
String description = null;
98126
Map<String, String> parameters = null;
99127
boolean includeOutputInAgentResponse = false;
128+
Map<String, String> configMap = null;
100129

101130
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
102131
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -119,6 +148,9 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
119148
case INCLUDE_OUTPUT_IN_AGENT_RESPONSE:
120149
includeOutputInAgentResponse = parser.booleanValue();
121150
break;
151+
case CONFIG_FIELD:
152+
configMap = getParameterMap(parser.map());
153+
break;
122154
default:
123155
parser.skipChildren();
124156
break;
@@ -131,6 +163,7 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
131163
.description(description)
132164
.parameters(parameters)
133165
.includeOutputInAgentResponse(includeOutputInAgentResponse)
166+
.configMap(configMap)
134167
.build();
135168
}
136169

common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public void constructor_NullName() {
4646
MLAgentType.CONVERSATIONAL.name(),
4747
"test",
4848
new LLMSpec("test_model", Map.of("test_key", "test_value")),
49-
List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)),
49+
List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())),
5050
null,
5151
null,
5252
Instant.EPOCH,
@@ -66,7 +66,7 @@ public void constructor_NullType() {
6666
null,
6767
"test",
6868
new LLMSpec("test_model", Map.of("test_key", "test_value")),
69-
List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)),
69+
List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())),
7070
null,
7171
null,
7272
Instant.EPOCH,
@@ -86,7 +86,7 @@ public void constructor_NullLLMSpec() {
8686
MLAgentType.CONVERSATIONAL.name(),
8787
"test",
8888
null,
89-
List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)),
89+
List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())),
9090
null,
9191
null,
9292
Instant.EPOCH,
@@ -100,7 +100,14 @@ public void constructor_NullLLMSpec() {
100100
public void constructor_DuplicateTool() {
101101
exceptionRule.expect(IllegalArgumentException.class);
102102
exceptionRule.expectMessage("Duplicate tool defined: test_tool_name");
103-
MLToolSpec mlToolSpec = new MLToolSpec("test_tool_type", "test_tool_name", "test", Collections.EMPTY_MAP, false);
103+
MLToolSpec mlToolSpec = new MLToolSpec(
104+
"test_tool_type",
105+
"test_tool_name",
106+
"test",
107+
Collections.emptyMap(),
108+
false,
109+
Collections.emptyMap()
110+
);
104111
MLAgent agent = new MLAgent(
105112
"test_name",
106113
MLAgentType.CONVERSATIONAL.name(),
@@ -123,7 +130,7 @@ public void writeTo() throws IOException {
123130
"CONVERSATIONAL",
124131
"test",
125132
new LLMSpec("test_model", Map.of("test_key", "test_value")),
126-
List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)),
133+
List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())),
127134
Map.of("test", "test"),
128135
new MLMemorySpec("test", "123", 0),
129136
Instant.EPOCH,
@@ -150,7 +157,7 @@ public void writeTo_NullLLM() throws IOException {
150157
"FLOW",
151158
"test",
152159
null,
153-
List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)),
160+
List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())),
154161
Map.of("test", "test"),
155162
new MLMemorySpec("test", "123", 0),
156163
Instant.EPOCH,
@@ -194,7 +201,7 @@ public void writeTo_NullParameters() throws IOException {
194201
MLAgentType.CONVERSATIONAL.name(),
195202
"test",
196203
new LLMSpec("test_model", Map.of("test_key", "test_value")),
197-
List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)),
204+
List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())),
198205
null,
199206
new MLMemorySpec("test", "123", 0),
200207
Instant.EPOCH,
@@ -216,7 +223,7 @@ public void writeTo_NullMemory() throws IOException {
216223
"CONVERSATIONAL",
217224
"test",
218225
new LLMSpec("test_model", Map.of("test_key", "test_value")),
219-
List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)),
226+
List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())),
220227
Map.of("test", "test"),
221228
null,
222229
Instant.EPOCH,
@@ -238,7 +245,7 @@ public void toXContent() throws IOException {
238245
"CONVERSATIONAL",
239246
"test",
240247
new LLMSpec("test_model", Map.of("test_key", "test_value")),
241-
List.of(new MLToolSpec("test", "test", "test", Map.of("test", "test"), false)),
248+
List.of(new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Collections.emptyMap())),
242249
Map.of("test", "test"),
243250
new MLMemorySpec("test", "123", 0),
244251
Instant.EPOCH,
@@ -294,7 +301,7 @@ public void fromStream() throws IOException {
294301
MLAgentType.CONVERSATIONAL.name(),
295302
"test",
296303
new LLMSpec("test_model", Map.of("test_key", "test_value")),
297-
List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)),
304+
List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())),
298305
Map.of("test", "test"),
299306
new MLMemorySpec("test", "123", 0),
300307
Instant.EPOCH,

0 commit comments

Comments
 (0)