Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
import static org.opensearch.ml.common.CommonValue.VERSION_3_3_0;

import java.io.IOException;
import java.time.Instant;
Expand Down Expand Up @@ -50,12 +51,14 @@ public class MLAgentUpdateInput implements ToXContentObject, Writeable {
public static final String MEMORY_TYPE_FIELD = "type";
public static final String MEMORY_SESSION_ID_FIELD = "session_id";
public static final String MEMORY_WINDOW_SIZE_FIELD = "window_size";
public static final String TYPE_FIELD = "type";
public static final String APP_TYPE_FIELD = "app_type";
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time";

@Getter
private String agentId;
private String name;
private String type;
private String description;
private String llmModelId;
private Map<String, String> llmParameters;
Expand All @@ -72,6 +75,7 @@ public class MLAgentUpdateInput implements ToXContentObject, Writeable {
public MLAgentUpdateInput(
String agentId,
String name,
String type,
String description,
String llmModelId,
Map<String, String> llmParameters,
Expand All @@ -86,6 +90,7 @@ public MLAgentUpdateInput(
) {
this.agentId = agentId;
this.name = name;
this.type = type;
this.description = description;
this.llmModelId = llmModelId;
this.llmParameters = llmParameters;
Expand All @@ -104,6 +109,7 @@ public MLAgentUpdateInput(StreamInput in) throws IOException {
Version streamInputVersion = in.getVersion();
agentId = in.readString();
name = in.readOptionalString();
type = streamInputVersion.onOrAfter(VERSION_3_3_0) ? in.readOptionalString() : null;
description = in.readOptionalString();
llmModelId = in.readOptionalString();
if (in.readBoolean()) {
Expand Down Expand Up @@ -134,6 +140,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (name != null) {
builder.field(AGENT_NAME_FIELD, name);
}
if (type != null) {
builder.field(TYPE_FIELD, type);
}
if (description != null) {
builder.field(DESCRIPTION_FIELD, description);
}
Expand Down Expand Up @@ -184,6 +193,9 @@ public void writeTo(StreamOutput out) throws IOException {
Version streamOutputVersion = out.getVersion();
out.writeString(agentId);
out.writeOptionalString(name);
if (streamOutputVersion.onOrAfter(VERSION_3_3_0)) {
out.writeOptionalString(type);
}
out.writeOptionalString(description);
out.writeOptionalString(llmModelId);
if (llmParameters != null && !llmParameters.isEmpty()) {
Expand Down Expand Up @@ -220,6 +232,7 @@ public void writeTo(StreamOutput out) throws IOException {
public static MLAgentUpdateInput parse(XContentParser parser) throws IOException {
String agentId = null;
String name = null;
String type = null;
String description = null;
String llmModelId = null;
Map<String, String> llmParameters = null;
Expand All @@ -243,6 +256,9 @@ public static MLAgentUpdateInput parse(XContentParser parser) throws IOException
case AGENT_NAME_FIELD:
name = parser.text();
break;
case TYPE_FIELD:
type = parser.text();
break;
case DESCRIPTION_FIELD:
description = parser.text();
break;
Expand Down Expand Up @@ -313,6 +329,7 @@ public static MLAgentUpdateInput parse(XContentParser parser) throws IOException
return new MLAgentUpdateInput(
agentId,
name,
type,
description,
llmModelId,
llmParameters,
Expand All @@ -328,6 +345,9 @@ public static MLAgentUpdateInput parse(XContentParser parser) throws IOException
}

public MLAgent toMLAgent(MLAgent originalAgent) {
if (type != null && !type.equals(originalAgent.getType())) {
throw new IllegalArgumentException("Agent type cannot be updated");
}
LLMSpec finalLlm;
if (llmModelId == null && (llmParameters == null || llmParameters.isEmpty())) {
finalLlm = originalAgent.getLlm();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ public void testParseWithAllFields() throws Exception {
{
"agent_id": "test-agent-id",
"name": "test-agent",
"type": "flow",
"description": "test description",
"llm": {
"model_id": "test-model-id",
Expand Down Expand Up @@ -423,6 +424,7 @@ public void testParseWithAllFields() throws Exception {
""";
testParseFromJsonString(inputStr, parsedInput -> {
assertEquals("test-agent", parsedInput.getName());
assertEquals("flow", parsedInput.getType());
assertEquals("test description", parsedInput.getDescription());
assertEquals("test-model-id", parsedInput.getLlmModelId());
assertEquals(1, parsedInput.getTools().size());
Expand Down Expand Up @@ -959,6 +961,41 @@ public void testCombinedLLMAndMemoryPartialUpdates() {
assertEquals(Integer.valueOf(10), updatedAgent.getMemory().getWindowSize()); // Updated
}

@Test
public void testAgentTypeValidation() {
MLAgent originalAgent = MLAgent.builder().type(MLAgentType.FLOW.name()).name("Test Agent").build();

// Same type should be allowed
MLAgentUpdateInput sameTypeInput = MLAgentUpdateInput
.builder()
.agentId("test-agent-id")
.type(MLAgentType.FLOW.name())
.name("Updated Name")
.build();

MLAgent updatedAgent = sameTypeInput.toMLAgent(originalAgent);
assertEquals(MLAgentType.FLOW.name(), updatedAgent.getType());
assertEquals("Updated Name", updatedAgent.getName());

// Different type should throw error
MLAgentUpdateInput differentTypeInput = MLAgentUpdateInput
.builder()
.agentId("test-agent-id")
.type(MLAgentType.CONVERSATIONAL.name())
.name("Updated Name")
.build();

IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> { differentTypeInput.toMLAgent(originalAgent); });
assertEquals("Agent type cannot be updated", e.getMessage());

// No type provided should work (original type)
MLAgentUpdateInput noTypeInput = MLAgentUpdateInput.builder().agentId("test-agent-id").name("Updated Name").build();

MLAgent originalAgentType = noTypeInput.toMLAgent(originalAgent);
assertEquals(MLAgentType.FLOW.name(), originalAgentType.getType());
assertEquals("Updated Name", originalAgentType.getName());
}

@Test
public void testStreamInputOutputWithVersion() throws IOException {
MLAgentUpdateInput input = MLAgentUpdateInput
Expand Down