diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java index 9a0d6002fd..fa9dc429ec 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java @@ -8,6 +8,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; +import static org.opensearch.ml.common.CommonValue.VERSION_3_3_0; import java.io.IOException; import java.time.Instant; @@ -50,12 +51,14 @@ public class MLAgentUpdateInput implements ToXContentObject, Writeable { public static final String MEMORY_TYPE_FIELD = "type"; public static final String MEMORY_SESSION_ID_FIELD = "session_id"; public static final String MEMORY_WINDOW_SIZE_FIELD = "window_size"; + public static final String TYPE_FIELD = "type"; public static final String APP_TYPE_FIELD = "app_type"; public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; @Getter private String agentId; private String name; + private String type; private String description; private String llmModelId; private Map llmParameters; @@ -72,6 +75,7 @@ public class MLAgentUpdateInput implements ToXContentObject, Writeable { public MLAgentUpdateInput( String agentId, String name, + String type, String description, String llmModelId, Map llmParameters, @@ -86,6 +90,7 @@ public MLAgentUpdateInput( ) { this.agentId = agentId; this.name = name; + this.type = type; this.description = description; this.llmModelId = llmModelId; this.llmParameters = llmParameters; @@ -104,6 +109,7 @@ public MLAgentUpdateInput(StreamInput in) throws IOException { Version streamInputVersion = in.getVersion(); agentId = in.readString(); name = in.readOptionalString(); + type = streamInputVersion.onOrAfter(VERSION_3_3_0) ? in.readOptionalString() : null; description = in.readOptionalString(); llmModelId = in.readOptionalString(); if (in.readBoolean()) { @@ -134,6 +140,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (name != null) { builder.field(AGENT_NAME_FIELD, name); } + if (type != null) { + builder.field(TYPE_FIELD, type); + } if (description != null) { builder.field(DESCRIPTION_FIELD, description); } @@ -184,6 +193,9 @@ public void writeTo(StreamOutput out) throws IOException { Version streamOutputVersion = out.getVersion(); out.writeString(agentId); out.writeOptionalString(name); + if (streamOutputVersion.onOrAfter(VERSION_3_3_0)) { + out.writeOptionalString(type); + } out.writeOptionalString(description); out.writeOptionalString(llmModelId); if (llmParameters != null && !llmParameters.isEmpty()) { @@ -220,6 +232,7 @@ public void writeTo(StreamOutput out) throws IOException { public static MLAgentUpdateInput parse(XContentParser parser) throws IOException { String agentId = null; String name = null; + String type = null; String description = null; String llmModelId = null; Map llmParameters = null; @@ -243,6 +256,9 @@ public static MLAgentUpdateInput parse(XContentParser parser) throws IOException case AGENT_NAME_FIELD: name = parser.text(); break; + case TYPE_FIELD: + type = parser.text(); + break; case DESCRIPTION_FIELD: description = parser.text(); break; @@ -313,6 +329,7 @@ public static MLAgentUpdateInput parse(XContentParser parser) throws IOException return new MLAgentUpdateInput( agentId, name, + type, description, llmModelId, llmParameters, @@ -328,6 +345,9 @@ public static MLAgentUpdateInput parse(XContentParser parser) throws IOException } public MLAgent toMLAgent(MLAgent originalAgent) { + if (type != null && !type.equals(originalAgent.getType())) { + throw new IllegalArgumentException("Agent type cannot be updated"); + } LLMSpec finalLlm; if (llmModelId == null && (llmParameters == null || llmParameters.isEmpty())) { finalLlm = originalAgent.getLlm(); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java index 72eb035279..4cf2924c84 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java @@ -392,6 +392,7 @@ public void testParseWithAllFields() throws Exception { { "agent_id": "test-agent-id", "name": "test-agent", + "type": "flow", "description": "test description", "llm": { "model_id": "test-model-id", @@ -423,6 +424,7 @@ public void testParseWithAllFields() throws Exception { """; testParseFromJsonString(inputStr, parsedInput -> { assertEquals("test-agent", parsedInput.getName()); + assertEquals("flow", parsedInput.getType()); assertEquals("test description", parsedInput.getDescription()); assertEquals("test-model-id", parsedInput.getLlmModelId()); assertEquals(1, parsedInput.getTools().size()); @@ -959,6 +961,41 @@ public void testCombinedLLMAndMemoryPartialUpdates() { assertEquals(Integer.valueOf(10), updatedAgent.getMemory().getWindowSize()); // Updated } + @Test + public void testAgentTypeValidation() { + MLAgent originalAgent = MLAgent.builder().type(MLAgentType.FLOW.name()).name("Test Agent").build(); + + // Same type should be allowed + MLAgentUpdateInput sameTypeInput = MLAgentUpdateInput + .builder() + .agentId("test-agent-id") + .type(MLAgentType.FLOW.name()) + .name("Updated Name") + .build(); + + MLAgent updatedAgent = sameTypeInput.toMLAgent(originalAgent); + assertEquals(MLAgentType.FLOW.name(), updatedAgent.getType()); + assertEquals("Updated Name", updatedAgent.getName()); + + // Different type should throw error + MLAgentUpdateInput differentTypeInput = MLAgentUpdateInput + .builder() + .agentId("test-agent-id") + .type(MLAgentType.CONVERSATIONAL.name()) + .name("Updated Name") + .build(); + + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> { differentTypeInput.toMLAgent(originalAgent); }); + assertEquals("Agent type cannot be updated", e.getMessage()); + + // No type provided should work (original type) + MLAgentUpdateInput noTypeInput = MLAgentUpdateInput.builder().agentId("test-agent-id").name("Updated Name").build(); + + MLAgent originalAgentType = noTypeInput.toMLAgent(originalAgent); + assertEquals(MLAgentType.FLOW.name(), originalAgentType.getType()); + assertEquals("Updated Name", originalAgentType.getName()); + } + @Test public void testStreamInputOutputWithVersion() throws IOException { MLAgentUpdateInput input = MLAgentUpdateInput