Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public static MLAgentType from(String value) {
try {
return MLAgentType.valueOf(value.toUpperCase(Locale.ROOT));
} catch (Exception e) {
throw new IllegalArgumentException("Wrong Agent type");
throw new IllegalArgumentException(value + " is not a valid Agent Type");
}
}
}
24 changes: 24 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/MLMemoryType.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common;

import java.util.Locale;

public enum MLMemoryType {
CONVERSATION_INDEX,
AGENTIC_MEMORY;

public static MLMemoryType from(String value) {
if (value != null) {
try {
return MLMemoryType.valueOf(value.toUpperCase(Locale.ROOT));
} catch (Exception e) {
throw new IllegalArgumentException("Wrong Memory type");
}
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -113,7 +112,7 @@ private void validate() {
String.format("Agent name cannot be empty or exceed max length of %d characters", MLAgent.AGENT_NAME_MAX_LENGTH)
);
}
validateMLAgentType(type);
MLAgentType.from(type);
if (type.equalsIgnoreCase(MLAgentType.CONVERSATIONAL.toString()) && llm == null) {
throw new IllegalArgumentException("We need model information for the conversational agent type");
}
Expand All @@ -130,19 +129,6 @@ private void validate() {
}
}

private void validateMLAgentType(String agentType) {
if (type == null) {
throw new IllegalArgumentException("Agent type can't be null");
} else {
try {
MLAgentType.valueOf(agentType.toUpperCase(Locale.ROOT)); // Use toUpperCase() to allow case-insensitive matching
} catch (IllegalArgumentException e) {
// The typeStr does not match any MLAgentType, so throw a new exception with a clearer message.
throw new IllegalArgumentException(agentType + " is not a valid Agent Type");
}
}
}

public MLAgent(StreamInput input) throws IOException {
Version streamInputVersion = input.getVersion();
name = input.readString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.MLMemoryType;
import org.opensearch.ml.common.agent.LLMSpec;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLMemorySpec;
Expand Down Expand Up @@ -383,9 +384,7 @@ private void validate() {
String.format("Agent name cannot be empty or exceed max length of %d characters", MLAgent.AGENT_NAME_MAX_LENGTH)
);
}
if (memoryType != null && !memoryType.equals("conversation_index")) {
throw new IllegalArgumentException(String.format("Invalid memory type: %s", memoryType));
}
MLMemoryType.from(memoryType);
if (tools != null) {
Set<String> toolNames = new HashSet<>();
for (MLToolSpec toolSpec : tools) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@ public void testFromWithMixedCase() {
public void testFromWithInvalidType() {
// This should throw an IllegalArgumentException
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Wrong Agent type");
exceptionRule.expectMessage(" is not a valid Agent Type");
MLAgentType.from("INVALID_TYPE");
}

@Test
public void testFromWithEmptyString() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Wrong Agent type");
exceptionRule.expectMessage(" is not a valid Agent Type");
// This should also throw an IllegalArgumentException
MLAgentType.from("");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ public void testValidationWithInvalidMemoryType() {
IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> {
MLAgentUpdateInput.builder().agentId("test-agent-id").name("test-agent").memoryType("invalid_type").build();
});
assertEquals("Invalid memory type: invalid_type", e.getMessage());
assertEquals("Wrong Memory type", e.getMessage());
}

@Test
Expand Down
Loading
Loading