diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLDeleteWorkingMemoryRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLDeleteWorkingMemoryRequest.java deleted file mode 100644 index 0cae293a36..0000000000 --- a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLDeleteWorkingMemoryRequest.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.memorycontainer.memory; - -import java.io.IOException; - -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; - -import lombok.Builder; -import lombok.Getter; - -@Getter -public class MLDeleteWorkingMemoryRequest extends ActionRequest { - private String memoryContainerId; - private String workingMemoryId; - - @Builder - public MLDeleteWorkingMemoryRequest(String memoryContainerId, String workingMemoryId) { - this.memoryContainerId = memoryContainerId; - this.workingMemoryId = workingMemoryId; - } - - public MLDeleteWorkingMemoryRequest(StreamInput in) throws IOException { - super(in); - this.memoryContainerId = in.readString(); - this.workingMemoryId = in.readString(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - out.writeString(memoryContainerId); - out.writeString(workingMemoryId); - } - - @Override - public ActionRequestValidationException validate() { - ActionRequestValidationException exception = null; - if (memoryContainerId == null || memoryContainerId.isEmpty()) { - exception = new ActionRequestValidationException(); - exception.addValidationError("Memory container id is required"); - } - if (workingMemoryId == null || workingMemoryId.isEmpty()) { - if (exception 
== null) { - exception = new ActionRequestValidationException(); - } - exception.addValidationError("Working memory id is required"); - } - return exception; - } -} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetWorkingMemoryRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetWorkingMemoryRequest.java deleted file mode 100644 index 0bf0e9b4c7..0000000000 --- a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetWorkingMemoryRequest.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.memorycontainer.memory; - -import java.io.IOException; - -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; - -import lombok.Builder; -import lombok.Getter; - -@Getter -public class MLGetWorkingMemoryRequest extends ActionRequest { - private String memoryContainerId; - private String workingMemoryId; - - @Builder - public MLGetWorkingMemoryRequest(String memoryContainerId, String workingMemoryId) { - this.memoryContainerId = memoryContainerId; - this.workingMemoryId = workingMemoryId; - } - - public MLGetWorkingMemoryRequest(StreamInput in) throws IOException { - super(in); - this.memoryContainerId = in.readString(); - this.workingMemoryId = in.readString(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - out.writeString(memoryContainerId); - out.writeString(workingMemoryId); - } - - @Override - public ActionRequestValidationException validate() { - ActionRequestValidationException exception = null; - if (memoryContainerId == null || memoryContainerId.isEmpty()) { - exception = new ActionRequestValidationException(); - 
exception.addValidationError("Memory container id is required"); - } - if (workingMemoryId == null || workingMemoryId.isEmpty()) { - if (exception == null) { - exception = new ActionRequestValidationException(); - } - exception.addValidationError("Working memory id is required"); - } - return exception; - } -} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetWorkingMemoryResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetWorkingMemoryResponse.java deleted file mode 100644 index 887e31f66b..0000000000 --- a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetWorkingMemoryResponse.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.memorycontainer.memory; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import org.opensearch.core.action.ActionResponse; -import org.opensearch.core.common.io.stream.InputStreamStreamInput; -import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.ToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.ml.common.memorycontainer.MLWorkingMemory; - -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; - -@Getter -@ToString -public class MLGetWorkingMemoryResponse extends ActionResponse implements ToXContentObject { - - private MLWorkingMemory workingMemory; - - @Builder - public MLGetWorkingMemoryResponse(MLWorkingMemory workingMemory) { - this.workingMemory = workingMemory; - } - - public MLGetWorkingMemoryResponse(StreamInput in) throws IOException { - super(in); - workingMemory 
= new MLWorkingMemory(in); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - workingMemory.writeTo(out); - } - - @Override - public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { - return workingMemory.toXContent(xContentBuilder, params); - } - - public static MLGetWorkingMemoryResponse fromActionResponse(ActionResponse actionResponse) { - if (actionResponse instanceof MLGetWorkingMemoryResponse) { - return (MLGetWorkingMemoryResponse) actionResponse; - } - - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { - actionResponse.writeTo(osso); - try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { - return new MLGetWorkingMemoryResponse(input); - } - } catch (IOException e) { - throw new UncheckedIOException("failed to parse ActionResponse into MLGetWorkingMemoryResponse", e); - } - } -} diff --git a/common/src/main/java/org/opensearch/ml/common/utils/IndexUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/IndexUtils.java index dbb27940d0..c4f1530417 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/IndexUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/IndexUtils.java @@ -52,17 +52,33 @@ public class IndexUtils { public static final Map MAPPING_PLACEHOLDERS = Map .of(USER_PLACEHOLDER, "index-mappings/placeholders/user.json", CONNECTOR_PLACEHOLDER, "index-mappings/placeholders/connector.json"); - public static String getMappingFromFile(String path) throws IOException { + /** + * Loads a resource file from the classpath as a String. + * This is a utility method for loading JSON or text resources. 
+ * + * @param path The path to the resource file relative to the classpath root + * @param resourceType A descriptive name for the resource type (e.g., "mapping", "schema") for error messages + * @return The resource content as a trimmed String + * @throws IOException if the resource cannot be found or loaded + * @throws IllegalArgumentException if the resource is empty + */ + public static String loadResourceFromFile(String path, String resourceType) throws IOException { URL url = IndexUtils.class.getClassLoader().getResource(path); if (url == null) { - throw new IOException("Resource not found: " + path); + throw new IOException(resourceType + " resource not found: " + path); } - String mapping = Resources.toString(url, Charsets.UTF_8).trim(); - if (mapping.isEmpty()) { - throw new IllegalArgumentException("Empty mapping found at: " + path); + String content = Resources.toString(url, Charsets.UTF_8).trim(); + if (content.isEmpty()) { + throw new IllegalArgumentException("Empty " + resourceType + " found at: " + path); } + return content; + } + + public static String getMappingFromFile(String path) throws IOException { + String mapping = loadResourceFromFile(path, "Mapping"); + mapping = replacePlaceholders(mapping); validateMapping(mapping); diff --git a/common/src/main/java/org/opensearch/ml/common/utils/LlmResultPathGenerator.java b/common/src/main/java/org/opensearch/ml/common/utils/LlmResultPathGenerator.java new file mode 100644 index 0000000000..c23392ec30 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/utils/LlmResultPathGenerator.java @@ -0,0 +1,244 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.utils; + +import java.io.IOException; +import java.util.Iterator; +import java.util.Map; + +import org.opensearch.OpenSearchParseException; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import 
lombok.extern.log4j.Log4j2; + +/** + * Utility class for auto-generating JSONPath expressions from JSON Schema. + * + * This class analyzes the "output" schema from MLModel's modelInterface field + * and generates a JSONPath expression to extract LLM text responses from + * connector-specific dataAsMap structures. + * + * The generator looks for fields marked with the custom schema property + * "x-llm-output": true to identify the target LLM text field. + * + * Example Usage: + *
+ * String outputSchema = model.getModelInterface().get("output");
+ * String llmResultPath = LlmResultPathGenerator.generate(outputSchema);
+ * // Returns: "$.choices[0].message.content" (for OpenAI)
+ * // or: "$.content[0].text" (for Bedrock Claude)
+ * 
+ */ +@Log4j2 +public class LlmResultPathGenerator { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + // Custom JSON Schema extension marker for LLM output fields + private static final String LLM_OUTPUT_MARKER = "x-llm-output"; + + /** + * Generates a JSONPath expression from the model's output schema. + * + * This method searches for fields marked with "x-llm-output": true in the schema. + * It is designed to work with properly annotated schemas from supported models + * (GPT-4o-mini, GPT-5, Claude 3.7+). + * + * If no marker is found, returns null and the caller should use a default fallback path. + * + * @param outputSchemaJson The JSON Schema string from model.interface.output + * @return JSONPath expression (e.g., "$.choices[0].message.content"), or null if no marker found + * @throws IOException if schema parsing fails + * @throws OpenSearchParseException if schema structure is invalid + */ + public static String generate(String outputSchemaJson) throws IOException { + if (outputSchemaJson == null || outputSchemaJson.trim().isEmpty()) { + log.warn("Output schema is null or empty, cannot generate llm_result_path"); + return null; + } + + log.info("Starting llm_result_path auto-generation (schema size: {} bytes)", outputSchemaJson.length()); + + try { + JsonNode schemaRoot = MAPPER.readTree(outputSchemaJson); + + // Navigate to dataAsMap schema node using hardcoded path; if not found, search from root + JsonNode searchRoot = navigateToDataAsMapSchema(schemaRoot); + if (searchRoot == null) { + log.info("No dataAsMap schema found in standard ModelTensorOutput path, searching from schema root"); + searchRoot = schemaRoot; + } else { + log.info("Found dataAsMap schema at standard path, searching within dataAsMap structure"); + } + + // Search for LLM output field with x-llm-output marker + String jsonPath = findLlmTextField(searchRoot, "$"); + + if (jsonPath == null) { + log.warn("No field with x-llm-output marker found in schema, will use 
fallback path"); + return null; + } + + log.info("Successfully generated llm_result_path: {}", jsonPath); + return jsonPath; + + } catch (Exception e) { + log.error("Failed to generate llm_result_path from schema", e); + throw new OpenSearchParseException("Schema parsing error: " + e.getMessage(), e); + } + } + + /** + * Navigates to the dataAsMap schema node using the rigid ModelTensorOutput structure. + * + * The path follows the serialization structure defined by: + * - ModelTensorOutput.INFERENCE_RESULT_FIELD = "inference_results" + * - ModelTensors.OUTPUT_FIELD = "output" + * - ModelTensor.DATA_AS_MAP_FIELD = "dataAsMap" + * + * Schema path: properties.inference_results.items.properties.output.items.properties.dataAsMap + * + * @param schemaRoot The root schema node + * @return The dataAsMap schema node if found, null otherwise + */ + private static JsonNode navigateToDataAsMapSchema(JsonNode schemaRoot) { + if (schemaRoot == null || schemaRoot.isMissingNode()) { + log.debug("Schema root is null or missing"); + return null; + } + + log + .debug( + "Navigating to dataAsMap using rigid path: properties.inference_results.items.properties.output.items.properties.dataAsMap" + ); + + // Follow the rigid ModelTensorOutput → ModelTensors → ModelTensor structure + JsonNode dataAsMapSchema = schemaRoot + .path("properties") + .path("inference_results") + .path("items") + .path("properties") + .path("output") + .path("items") + .path("properties") + .path("dataAsMap"); + + if (dataAsMapSchema.isMissingNode()) { + log.debug("dataAsMap not found at standard path"); + return null; + } + + log.debug("Successfully navigated to dataAsMap schema"); + return dataAsMapSchema; + } + + /** + * Recursively searches for the LLM text field marked with "x-llm-output": true. 
+ * + * @param schemaNode The current schema node to search + * @param currentPath The current JSONPath being built + * @return JSONPath expression to the LLM text field, or null if not found + */ + private static String findLlmTextField(JsonNode schemaNode, String currentPath) { + return findLlmTextFieldWithMarker(schemaNode, currentPath); + } + + /** + * Searches ONLY for fields with explicit "x-llm-output": true marker. + * Does NOT use any heuristic field name matching. + * + * @param schemaNode The current schema node to search + * @param currentPath The current JSONPath being built + * @return JSONPath expression if marker found, null otherwise + */ + private static String findLlmTextFieldWithMarker(JsonNode schemaNode, String currentPath) { + if (schemaNode == null || schemaNode.isMissingNode()) { + return null; + } + + // Check if this field has the x-llm-output marker + JsonNode marker = schemaNode.get(LLM_OUTPUT_MARKER); + if (marker != null && marker.isBoolean() && marker.asBoolean()) { + log.info("Found x-llm-output marker at path: {}", currentPath); + return currentPath; + } + + // Get the type of this schema node + JsonNode typeNode = schemaNode.get("type"); + String type = typeNode != null && typeNode.isTextual() ? typeNode.asText() : null; + + // If it's an object, recursively search properties + if ("object".equals(type) || schemaNode.has("properties")) { + JsonNode properties = schemaNode.get("properties"); + if (properties != null && properties.isObject()) { + Iterator> fields = properties.fields(); + while (fields.hasNext()) { + Map.Entry field = fields.next(); + String fieldName = field.getKey(); + JsonNode fieldSchema = field.getValue(); + + String newPath = currentPath + "." 
+ fieldName; + log.debug("Checking field: {}", newPath); + String result = findLlmTextFieldWithMarker(fieldSchema, newPath); + if (result != null) { + return result; + } + } + } + } + + // If it's an array, navigate into items + if ("array".equals(type) || schemaNode.has("items")) { + JsonNode items = schemaNode.get("items"); + if (items != null) { + String newPath = currentPath + "[0]"; + log.debug("Navigating into array: {}", newPath); + String result = findLlmTextFieldWithMarker(items, newPath); + if (result != null) { + return result; + } + } + } + + return null; + } + + /** + * Validates that a generated JSONPath can be parsed and applied. + * + * This is a basic validation that checks if the path syntax is valid. + * It does not validate against actual data. + * + * @param jsonPath The JSONPath expression to validate + * @return true if the path appears valid, false otherwise + */ + public static boolean isValidJsonPath(String jsonPath) { + if (jsonPath == null || jsonPath.trim().isEmpty()) { + return false; + } + + // Basic validation: must start with $ + if (!jsonPath.startsWith("$")) { + return false; + } + + // Check for balanced brackets + int bracketCount = 0; + for (char c : jsonPath.toCharArray()) { + if (c == '[') + bracketCount++; + if (c == ']') + bracketCount--; + if (bracketCount < 0) + return false; + } + + return bracketCount == 0; + } + +} diff --git a/common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java index a896f6075f..973fd6b979 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java @@ -5,7 +5,9 @@ package org.opensearch.ml.common.utils; +import java.io.IOException; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import org.opensearch.ml.common.connector.Connector; import 
org.opensearch.ml.common.connector.ConnectorAction; @@ -17,869 +19,278 @@ @Log4j2 public class ModelInterfaceUtils { - private static final String GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT = "{\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"parameters\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"inputs\": {\n" - + " \"type\": \"string\"\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"inputs\"\n" - + " ]\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"parameters\"\n" - + " ]\n" - + "}"; + // Schema loading infrastructure + private static final String SCHEMA_BASE_PATH = "model-interface-schemas"; + private static final ConcurrentHashMap SCHEMA_CACHE = new ConcurrentHashMap<>(); - private static final String GENERAL_EMBEDDING_MODEL_INTERFACE_INPUT = "{\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"parameters\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"texts\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"string\"\n" - + " }\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"texts\"\n" - + " ]\n" - + " }\n" - + " }\n" - + "}"; - - private static final String TITAN_TEXT_EMBEDDING_MODEL_INTERFACE_INPUT = "{\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"parameters\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"inputText\": {\n" - + " \"type\": \"string\"\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"inputText\"\n" - + " ]\n" - + " }\n" - + " }\n" - + "}"; - - private static final String TITAN_MULTI_MODAL_EMBEDDING_MODEL_INTERFACE_INPUT = "{\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"parameters\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"inputText\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"inputImage\": {\n" - + " \"type\": \"string\"\n" - + " }\n" - + " }\n" - + " }\n" - + " 
}\n" - + "}"; - - private static final String AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE_INPUT = "{\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"parameters\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"Text\": {\n" - + " \"type\": \"string\"\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"Text\"\n" - + " ]\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"parameters\"\n" - + " ]\n" - + "}"; - - private static final String AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE_INPUT = "{\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"parameters\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"bytes\": {\n" - + " \"type\": \"string\"\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"bytes\"\n" - + " ]\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"parameters\"\n" - + " ]\n" - + "}"; - - private static final String GENERAL_CONVERSATIONAL_MODEL_INTERFACE_OUTPUT = "{\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"inference_results\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"output\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"name\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"dataAsMap\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"response\": {\n" - + " \"type\": \"string\"\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"response\"\n" - + " ]\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"name\",\n" - + " \"dataAsMap\"\n" - + " ]\n" - + " }\n" - + " },\n" - + " \"status_code\": {\n" - + " \"type\": \"integer\"\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"output\",\n" - + " \"status_code\"\n" - + " ]\n" - + " }\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " 
\"inference_results\"\n" - + " ]\n" - + "}"; - - private static final String BEDROCK_AI21_J2_MID_V1_RAW_MODEL_INTERFACE_OUTPUT = "{\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"inference_results\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"output\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"name\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"dataAsMap\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"id\": {\n" - + " \"type\": \"number\"\n" - + " },\n" - + " \"prompt\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"text\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"tokens\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"generatedToken\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"token\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"logprob\": {\n" - + " \"type\": \"number\"\n" - + " },\n" - + " \"raw_logprob\": {\n" - + " \"type\": \"number\"\n" - + " }\n" - + " }\n" - + " },\n" - + " \"textRange\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"start\": {\n" - + " \"type\": \"number\"\n" - + " },\n" - + " \"end\": {\n" - + " \"type\": \"number\"\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " },\n" - + " \"completions\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"data\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"text\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"tokens\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"generatedToken\": {\n" - 
+ " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"token\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"logprob\": {\n" - + " \"type\": \"number\"\n" - + " },\n" - + " \"raw_logprob\": {\n" - + " \"type\": \"number\"\n" - + " }\n" - + " }\n" - + " },\n" - + " \"textRange\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"start\": {\n" - + " \"type\": \"number\"\n" - + " },\n" - + " \"end\": {\n" - + " \"type\": \"number\"\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " },\n" - + " \"finishReason\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"reason\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"length\": {\n" - + " \"type\": \"number\"\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " },\n" - + " \"status_code\": {\n" - + " \"type\": \"integer\"\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + "}"; - - private static final String BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE_OUTPUT = "{\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"inference_results\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"output\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"name\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"dataAsMap\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"type\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"completion\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"stop_reason\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"stop\": {\n" - + " \"type\": \"string\"\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"type\",\n" - + " \"completion\"\n" - + " ]\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " 
\"name\",\n" - + " \"dataAsMap\"\n" - + " ]\n" - + " }\n" - + " },\n" - + " \"status_code\": {\n" - + " \"type\": \"integer\"\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"output\",\n" - + " \"status_code\"\n" - + " ]\n" - + " }\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"inference_results\"\n" - + " ]\n" - + "}"; - - private static final String GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT = "{\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"inference_results\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"output\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"name\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"data_type\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"shape\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"integer\"\n" - + " }\n" - + " },\n" - + " \"data\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"number\"\n" - + " }\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"name\",\n" - + " \"data_type\",\n" - + " \"shape\",\n" - + " \"data\"\n" - + " ]\n" - + " }\n" - + " },\n" - + " \"status_code\": {\n" - + " \"type\": \"integer\"\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"output\",\n" - + " \"status_code\"\n" - + " ]\n" - + " }\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"inference_results\"\n" - + " ]\n" - + "}"; - - private static final String AMAZON_TITAN_EMBEDDING_V1_RAW_INTERFACE_OUTPUT = "{\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"inference_results\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"output\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " 
\"name\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"dataAsMap\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"embedding\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"number\"\n" - + " }\n" - + " },\n" - + " \"inputTextTokenCount\": {\n" - + " \"type\": \"integer\"\n" - + " }\n" - + " },\n" - + " \"required\": [\"embedding\", \"inputTextTokenCount\"]\n" - + " }\n" - + " },\n" - + " \"required\": [\"name\", \"dataAsMap\"]\n" - + " }\n" - + " },\n" - + " \"status_code\": {\n" - + " \"type\": \"integer\"\n" - + " }\n" - + " },\n" - + " \"required\": [\"output\", \"status_code\"]\n" - + " }\n" - + " }\n" - + " },\n" - + " \"required\": [\"inference_results\"]\n" - + "}"; - - private static final String COHERE_EMBEDDING_V3_RAW_INTERFACE_OUTPUT = "{\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"inference_results\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"output\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"name\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"dataAsMap\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"id\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"texts\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"string\"\n" - + " }\n" - + " },\n" - + " \"embeddings\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"number\"\n" - + " }\n" - + " }\n" - + " },\n" - + " \"response_type\": {\n" - + " \"type\": \"string\"\n" - + " }\n" - + " },\n" - + " \"required\": [\"id\", \"texts\", \"embeddings\", \"response_type\"]\n" - + " }\n" - + " },\n" - + " \"required\": [\"name\", \"dataAsMap\"]\n" - + " }\n" - + " },\n" - + " \"status_code\": {\n" - + " \"type\": 
\"integer\"\n" - + " }\n" - + " },\n" - + " \"required\": [\"output\", \"status_code\"]\n" - + " }\n" - + " }\n" - + " },\n" - + " \"required\": [\"inference_results\"]\n" - + "}"; - - private static final String AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE_OUTPUT = "{\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"inference_results\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"output\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"name\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"dataAsMap\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"response\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"Languages\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"LanguageCode\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"Score\": {\n" - + " \"type\": \"number\"\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"LanguageCode\",\n" - + " \"Score\"\n" - + " ]\n" - + " }\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"Languages\"\n" - + " ]\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"response\"\n" - + " ]\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"name\",\n" - + " \"dataAsMap\"\n" - + " ]\n" - + " }\n" - + " },\n" - + " \"status_code\": {\n" - + " \"type\": \"integer\"\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"output\",\n" - + " \"status_code\"\n" - + " ]\n" - + " }\n" - + " }\n" - + " },\n" - + " \"required\": [\n" - + " \"inference_results\"\n" - + " ]\n" - + "}"; - - private static final String AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE_OUTPUT = "{\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"inference_results\": {\n" - + " \"type\": 
\"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"output\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"name\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"dataAsMap\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"Blocks\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"BlockType\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"Geometry\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"BoundingBox\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"Height\": {\n" - + " \"type\": \"number\"\n" - + " },\n" - + " \"Left\": {\n" - + " \"type\": \"number\"\n" - + " },\n" - + " \"Top\": {\n" - + " \"type\": \"number\"\n" - + " },\n" - + " \"Width\": {\n" - + " \"type\": \"number\"\n" - + " }\n" - + " }\n" - + " },\n" - + " \"Polygon\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"X\": {\n" - + " \"type\": \"number\"\n" - + " },\n" - + " \"Y\": {\n" - + " \"type\": \"number\"\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " },\n" - + " \"Id\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"Relationships\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"Ids\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " \"type\": \"string\"\n" - + " }\n" - + " },\n" - + " \"Type\": {\n" - + " \"type\": \"string\"\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " },\n" - + " \"DetectDocumentTextModelVersion\": {\n" - + " \"type\": \"string\"\n" - + " },\n" - + " \"DocumentMetadata\": {\n" - + " \"type\": \"object\",\n" - + " \"properties\": {\n" - + " \"Pages\": {\n" - + " 
\"type\": \"number\"\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " },\n" - + " \"status_code\": {\n" - + " \"type\": \"number\"\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + "}"; - - public static final Map BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE = Map - .of("input", GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT, "output", GENERAL_CONVERSATIONAL_MODEL_INTERFACE_OUTPUT); - - public static final Map BEDROCK_AI21_LABS_JURASSIC2_MID_V1_RAW_MODEL_INTERFACE = Map - .of("input", GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT, "output", BEDROCK_AI21_J2_MID_V1_RAW_MODEL_INTERFACE_OUTPUT); - - public static final Map BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE = Map - .of("input", GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT, "output", GENERAL_CONVERSATIONAL_MODEL_INTERFACE_OUTPUT); - - public static final Map BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE = Map - .of("input", GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT, "output", BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE_OUTPUT); + /** + * Loads a schema from the resources directory with caching. 
+ * + * @param schemaPath The path to the schema file relative to resources root + * @return The schema content as a String + * @throws IOException if the schema file cannot be loaded + */ + private static String loadSchemaFromFile(String schemaPath) throws IOException { + // Check cache first + String cachedSchema = SCHEMA_CACHE.get(schemaPath); + if (cachedSchema != null) { + return cachedSchema; + } - public static final Map BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE = Map - .of("input", GENERAL_EMBEDDING_MODEL_INTERFACE_INPUT, "output", GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT); + // Load from file using shared utility + String schema = IndexUtils.loadResourceFromFile(schemaPath, "Schema"); - public static final Map BEDROCK_COHERE_EMBED_ENGLISH_V3_RAW_MODEL_INTERFACE = Map - .of("input", GENERAL_EMBEDDING_MODEL_INTERFACE_INPUT, "output", COHERE_EMBEDDING_V3_RAW_INTERFACE_OUTPUT); + // Cache and return + SCHEMA_CACHE.put(schemaPath, schema); + log.debug("Loaded and cached schema from: {}", schemaPath); + return schema; + } - public static final Map BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE = Map - .of("input", GENERAL_EMBEDDING_MODEL_INTERFACE_INPUT, "output", GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT); + /** + * Loads an input schema by name. + * + * @param schemaName The name of the input schema file (without .json extension) + * @return The schema content as a String + * @throws IOException if the schema file cannot be loaded + */ + private static String getInputSchema(String schemaName) throws IOException { + String schemaPath = SCHEMA_BASE_PATH + "/input/" + schemaName + ".json"; + return loadSchemaFromFile(schemaPath); + } - public static final Map BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_RAW_MODEL_INTERFACE = Map - .of("input", GENERAL_EMBEDDING_MODEL_INTERFACE_INPUT, "output", COHERE_EMBEDDING_V3_RAW_INTERFACE_OUTPUT); + /** + * Loads an output schema by name. 
+ * + * @param schemaName The name of the output schema file (without .json extension) + * @return The schema content as a String + * @throws IOException if the schema file cannot be loaded + */ + private static String getOutputSchema(String schemaName) throws IOException { + String schemaPath = SCHEMA_BASE_PATH + "/output/" + schemaName + ".json"; + return loadSchemaFromFile(schemaPath); + } - public static final Map BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE = Map - .of("input", TITAN_TEXT_EMBEDDING_MODEL_INTERFACE_INPUT, "output", GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT); + /** + * Enum representing all supported model interface schema variants. + * Each variant defines input and output schema file names and provides lazy loading with caching. + */ + public enum ModelInterfaceSchema { + BEDROCK_AI21_LABS_JURASSIC2_MID_V1( + "bedrock_ai21_labs_jurassic2_mid_v1", + "general_conversational_single_round_input", + "general_conversational_single_round_output" + ), + BEDROCK_AI21_LABS_JURASSIC2_MID_V1_RAW( + "bedrock_ai21_labs_jurassic2_mid_v1_raw", + "general_conversational_single_round_input", + "bedrock_ai21_j2_raw_output" + ), + BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET( + "bedrock_anthropic_claude_v3_sonnet", + "general_conversational_single_round_input", + "general_conversational_single_round_output" + ), + BEDROCK_ANTHROPIC_CLAUDE_V2( + "bedrock_anthropic_claude_v2", + "general_conversational_single_round_input", + "bedrock_anthropic_claude_v2_output" + ), + BEDROCK_COHERE_EMBED_ENGLISH_V3("bedrock_cohere_embed_english_v3", "general_embedding_input", "general_embedding_output"), + BEDROCK_COHERE_EMBED_ENGLISH_V3_RAW( + "bedrock_cohere_embed_english_v3_raw", + "general_embedding_input", + "cohere_embedding_v3_raw_output" + ), + BEDROCK_COHERE_EMBED_MULTILINGUAL_V3("bedrock_cohere_embed_multilingual_v3", "general_embedding_input", "general_embedding_output"), + BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_RAW( + "bedrock_cohere_embed_multilingual_v3_raw", + 
"general_embedding_input", + "cohere_embedding_v3_raw_output" + ), + BEDROCK_TITAN_EMBED_TEXT_V1("bedrock_titan_embed_text_v1", "titan_text_embedding_input", "general_embedding_output"), + BEDROCK_TITAN_EMBED_TEXT_V1_RAW( + "bedrock_titan_embed_text_v1_raw", + "titan_text_embedding_input", + "amazon_titan_embedding_v1_raw_output" + ), + BEDROCK_TITAN_EMBED_MULTI_MODAL_V1( + "bedrock_titan_embed_multi_modal_v1", + "titan_multi_modal_embedding_input", + "general_embedding_output" + ), + BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_RAW( + "bedrock_titan_embed_multi_modal_v1_raw", + "titan_multi_modal_embedding_input", + "amazon_titan_embedding_v1_raw_output" + ), + AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE( + "amazon_comprehend_detectdomainantlanguage", + "amazon_comprehend_input", + "amazon_comprehend_output" + ), + AMAZON_TEXTRACT_DETECTDOCUMENTTEXT("amazon_textract_detectdocumenttext", "amazon_textract_input", "amazon_textract_output"), + BEDROCK_ANTHROPIC_CLAUDE_USE_SYSTEM_PROMPT( + "bedrock_anthropic_claude_use_system_prompt", + "bedrock_anthropic_claude_use_system_prompt_input", + "bedrock_anthropic_claude_use_system_prompt_output" + ), + OPENAI_CHAT_COMPLETIONS("openai_chat_completions", "openai_chat_completions_input", "openai_chat_completions_output"); + + private final String name; + private final String inputSchemaFile; + private final String outputSchemaFile; + + ModelInterfaceSchema(String name, String inputSchemaFile, String outputSchemaFile) { + this.name = name; + this.inputSchemaFile = inputSchemaFile; + this.outputSchemaFile = outputSchemaFile; + } - public static final Map BEDROCK_TITAN_EMBED_TEXT_V1_RAW_MODEL_INTERFACE = Map - .of("input", TITAN_TEXT_EMBEDDING_MODEL_INTERFACE_INPUT, "output", AMAZON_TITAN_EMBEDDING_V1_RAW_INTERFACE_OUTPUT); + /** + * Gets the model interface as a Map with "input" and "output" keys. + * Schema files are cached at the I/O level by loadSchemaFromFile(). 
+ * + * @return Map containing input and output schemas + * @throws RuntimeException if schemas cannot be loaded + */ + public Map getInterface() { + try { + String input = getInputSchema(inputSchemaFile); + String output = getOutputSchema(outputSchemaFile); + return Map.of("input", input, "output", output); + } catch (IOException e) { + log.error("Failed to load model interface schema: {}", name, e); + throw new RuntimeException("Failed to load schema: " + name, e); + } + } - public static final Map BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE = Map - .of("input", TITAN_MULTI_MODAL_EMBEDDING_MODEL_INTERFACE_INPUT, "output", GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT); + /** + * Finds a ModelInterfaceSchema by its name string (case-insensitive). + * + * @param name The schema name to look up + * @return The matching ModelInterfaceSchema + * @throws IllegalArgumentException if no matching schema is found + */ + public static ModelInterfaceSchema fromString(String name) { + if (name == null || name.isBlank()) { + throw new IllegalArgumentException("Schema name cannot be null or blank"); + } - public static final Map BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_RAW_MODEL_INTERFACE = Map - .of("input", TITAN_MULTI_MODAL_EMBEDDING_MODEL_INTERFACE_INPUT, "output", AMAZON_TITAN_EMBEDDING_V1_RAW_INTERFACE_OUTPUT); + for (ModelInterfaceSchema schema : values()) { + if (schema.name.equalsIgnoreCase(name)) { + return schema; + } + } - public static final Map AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE = Map - .of( - "input", - AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE_INPUT, - "output", - AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE_OUTPUT - ); + throw new IllegalArgumentException("Unknown model interface schema: " + name); + } - public static final Map AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE = Map - .of( - "input", - AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE_INPUT, - "output", - 
AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE_OUTPUT - ); + public String getName() { + return name; + } + } - private static Map createPresetModelInterfaceByConnector(Connector connector) { + private static ModelInterfaceSchema createPresetModelInterfaceByConnector(Connector connector) { if (connector.getParameters() != null) { ConnectorAction connectorAction = connector.getActions().get(0); + + // Check for OpenAI Chat Completions models (outside service_name switch) + String model = connector.getParameters().get("model"); + String url = connectorAction.getUrl(); + if (model != null && url != null) { + boolean isOpenAIModel = model.equals("gpt-3.5-turbo") || model.equals("gpt-4o-mini") || model.equals("gpt-5"); + boolean isChatCompletionsEndpoint = url.endsWith("v1/chat/completions"); + + if (isOpenAIModel && isChatCompletionsEndpoint) { + log.debug("Detected OpenAI Chat Completions model: {}", model); + return ModelInterfaceSchema.OPENAI_CHAT_COMPLETIONS; + } + } + switch ((connector.getParameters().get("service_name") != null) ? connector.getParameters().get("service_name") : "null") { case "bedrock": log.debug("Detected Amazon Bedrock model"); - switch ((connector.getParameters().get("model") != null) ? connector.getParameters().get("model") : "null") { + switch ((model != null) ? 
model : "null") { case "ai21.j2-mid-v1": if (connectorAction.getPostProcessFunction() != null && !connectorAction.getPostProcessFunction().isBlank()) { - log - .debug( - "Creating preset model interface for Amazon Bedrock model with post-process function: {}", - connector.getParameters().get("model") - ); - return BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE; + log.debug("Creating preset model interface for Amazon Bedrock model with post-process function: {}", model); + return ModelInterfaceSchema.BEDROCK_AI21_LABS_JURASSIC2_MID_V1; } else { - log .debug( "Creating preset model interface for Amazon Bedrock model without post-process function: {}", - connector.getParameters().get("model") + model ); - return BEDROCK_AI21_LABS_JURASSIC2_MID_V1_RAW_MODEL_INTERFACE; + return ModelInterfaceSchema.BEDROCK_AI21_LABS_JURASSIC2_MID_V1_RAW; } case "anthropic.claude-3-sonnet-20240229-v1:0": - log - .debug( - "Creating preset model interface for Amazon Bedrock model: {}", - connector.getParameters().get("model") - ); - return BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE; + log.debug("Creating preset model interface for Amazon Bedrock model: {}", model); + return ModelInterfaceSchema.BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET; case "anthropic.claude-v2": - log - .debug( - "Creating preset model interface for Amazon Bedrock model: {}", - connector.getParameters().get("model") - ); - return BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE; + log.debug("Creating preset model interface for Amazon Bedrock model: {}", model); + return ModelInterfaceSchema.BEDROCK_ANTHROPIC_CLAUDE_V2; case "cohere.embed-english-v3": if (connectorAction.getPostProcessFunction() != null && connectorAction.getPostProcessFunction().equalsIgnoreCase(MLPostProcessFunction.COHERE_EMBEDDING)) { - log - .debug( - "Creating preset model interface for Amazon Bedrock model with post-process function: {}", - connector.getParameters().get("model") - ); - return BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE; + 
log.debug("Creating preset model interface for Amazon Bedrock model with post-process function: {}", model); + return ModelInterfaceSchema.BEDROCK_COHERE_EMBED_ENGLISH_V3; } else { log .debug( "Creating preset model interface for Amazon Bedrock model without post-process function: {}", - connector.getParameters().get("model") + model ); - return BEDROCK_COHERE_EMBED_ENGLISH_V3_RAW_MODEL_INTERFACE; + return ModelInterfaceSchema.BEDROCK_COHERE_EMBED_ENGLISH_V3_RAW; } case "cohere.embed-multilingual-v3": if (connectorAction.getPostProcessFunction() != null && connectorAction.getPostProcessFunction().equalsIgnoreCase(MLPostProcessFunction.COHERE_EMBEDDING)) { - log - .debug( - "Creating preset model interface for Amazon Bedrock model with post-process function: {}", - connector.getParameters().get("model") - ); - return BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE; + log.debug("Creating preset model interface for Amazon Bedrock model with post-process function: {}", model); + return ModelInterfaceSchema.BEDROCK_COHERE_EMBED_MULTILINGUAL_V3; } else { log .debug( "Creating preset model interface for Amazon Bedrock model without post-process function: {}", - connector.getParameters().get("model") + model ); - return BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_RAW_MODEL_INTERFACE; + return ModelInterfaceSchema.BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_RAW; } case "amazon.titan-embed-text-v1": if (connectorAction.getPostProcessFunction() != null && connectorAction.getPostProcessFunction().equalsIgnoreCase(MLPostProcessFunction.BEDROCK_EMBEDDING)) { - log - .debug( - "Creating preset model interface for Amazon Bedrock model with post-process function: {}", - connector.getParameters().get("model") - ); - return BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE; + log.debug("Creating preset model interface for Amazon Bedrock model with post-process function: {}", model); + return ModelInterfaceSchema.BEDROCK_TITAN_EMBED_TEXT_V1; } else { log .debug( "Creating preset model interface for 
Amazon Bedrock model without post-process function: {}", - connector.getParameters().get("model") + model ); - return BEDROCK_TITAN_EMBED_TEXT_V1_RAW_MODEL_INTERFACE; + return ModelInterfaceSchema.BEDROCK_TITAN_EMBED_TEXT_V1_RAW; } case "amazon.titan-embed-image-v1": if (connectorAction.getPostProcessFunction() != null && connectorAction.getPostProcessFunction().equalsIgnoreCase(MLPostProcessFunction.BEDROCK_EMBEDDING)) { - log - .debug( - "Creating preset model interface for Amazon Bedrock model with post-process function: {}", - connector.getParameters().get("model") - ); - return BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE; + log.debug("Creating preset model interface for Amazon Bedrock model with post-process function: {}", model); + return ModelInterfaceSchema.BEDROCK_TITAN_EMBED_MULTI_MODAL_V1; } else { log .debug( "Creating preset model interface for Amazon Bedrock model without post-process function: {}", - connector.getParameters().get("model") + model ); - return BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_RAW_MODEL_INTERFACE; + return ModelInterfaceSchema.BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_RAW; + } + case "us.anthropic.claude-3-7-sonnet-20250219-v1:0": + case "us.anthropic.claude-sonnet-4-20250514-v1:0": + // Check if use_system_prompt parameter is true + String useSystemPrompt = connector.getParameters().get("use_system_prompt"); + if ("true".equalsIgnoreCase(useSystemPrompt)) { + log.debug("Creating preset model interface for Amazon Bedrock Claude model with system prompt: {}", model); + return ModelInterfaceSchema.BEDROCK_ANTHROPIC_CLAUDE_USE_SYSTEM_PROMPT; } + log.debug("Model {} does not use system prompt parameter, skipping preset interface", model); + return null; default: return null; } @@ -895,14 +306,14 @@ private static Map createPresetModelInterfaceByConnector(Connect "Creating preset model interface for Amazon Comprehend API: {}", connector.getParameters().get("api_name") ); - return AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE; + 
return ModelInterfaceSchema.AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE; default: return null; } case "textract": log.debug("Detected Amazon Textract model"); log.debug("Creating preset model interface for Amazon Textract DetectDocumentText API"); - return AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE; + return ModelInterfaceSchema.AMAZON_TEXTRACT_DETECTDOCUMENTTEXT; default: return null; } @@ -919,9 +330,9 @@ public static void updateRegisterModelInputModelInterfaceFieldsByConnector( MLRegisterModelInput registerModelInput, Connector connector ) { - Map presetModelInterface = createPresetModelInterfaceByConnector(connector); - if (presetModelInterface != null) { - registerModelInput.setModelInterface(presetModelInterface); + ModelInterfaceSchema schema = createPresetModelInterfaceByConnector(connector); + if (schema != null) { + registerModelInput.setModelInterface(schema.getInterface()); } } @@ -930,9 +341,9 @@ public static void updateRegisterModelInputModelInterfaceFieldsByConnector( * @param registerModelInput the register model input */ public static void updateRegisterModelInputModelInterfaceFieldsByConnector(MLRegisterModelInput registerModelInput) { - Map presetModelInterface = createPresetModelInterfaceByConnector(registerModelInput.getConnector()); - if (presetModelInterface != null) { - registerModelInput.setModelInterface(presetModelInterface); + ModelInterfaceSchema schema = createPresetModelInterfaceByConnector(registerModelInput.getConnector()); + if (schema != null) { + registerModelInput.setModelInterface(schema.getInterface()); } } } diff --git a/common/src/main/java/org/opensearch/ml/common/utils/message/ClaudeMessageFormatter.java b/common/src/main/java/org/opensearch/ml/common/utils/message/ClaudeMessageFormatter.java new file mode 100644 index 0000000000..57643513aa --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/utils/message/ClaudeMessageFormatter.java @@ -0,0 +1,173 @@ +/* + * Copyright OpenSearch Contributors + * 
SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.utils.message; + +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.transport.memorycontainer.memory.MessageInput; +import org.opensearch.ml.common.utils.StringUtils; + +import lombok.extern.log4j.Log4j2; + +/** + * Message formatter for Claude models (and similar models using system_prompt parameter). + * + *

Format characteristics: + *

    + *
  • System prompt: Placed in "system_prompt" parameter
  • + *
  • Messages: Array of user/assistant messages (NO system role)
  • + *
  • Content: Normalized to have "type" field
  • + *
+ * + *

Compatible with: + *

    + *
  • Claude 3.x (Bedrock, Anthropic API)
  • + *
  • Claude 4.x (Sonnet, Opus)
  • + *
  • Any model with system_prompt in input schema
  • + *
+ * + *

Example output: + *

+ * {
+ *   "system_prompt": "You are a helpful assistant",
+ *   "messages": "[{\"role\":\"user\",\"content\":[{\"type\":\"text\",\"text\":\"Hello\"}]}]"
+ * }
+ * 
+ */ +@Log4j2 +public class ClaudeMessageFormatter implements MessageFormatter { + + @Override + public Map formatRequest(String systemPrompt, List messages, Map additionalConfig) { + Map parameters = new HashMap<>(); + + // Claude-style: system_prompt as parameter + if (systemPrompt != null && !systemPrompt.isBlank()) { + parameters.put("system_prompt", systemPrompt); + log.debug("System prompt added as parameter"); + } + + // Build messages array with content processing + try { + String messagesJson = buildMessagesArray(messages, additionalConfig); + parameters.put("messages", messagesJson); + log.debug("Built messages array with {} messages", messages != null ? messages.size() : 0); + } catch (IOException e) { + log.error("Failed to build messages array", e); + throw new RuntimeException("Failed to format Claude request", e); + } + + return parameters; + } + + @Override + public List> processContent(List> content) { + if (content == null || content.isEmpty()) { + return content; + } + + return content.stream().map(this::normalizeContentObject).collect(Collectors.toList()); + } + + /** + * Normalize a single content object to ensure it has "type" field. + * + *

Rules: + *

    + *
  • If object has "type" field → return as-is (standard LLM format)
  • + *
  • If object lacks "type" field → wrap as {"type": "text", "text": JSON_STRING}
  • + *
+ * + * @param obj Content object to normalize + * @return Normalized content object with "type" field + */ + private Map normalizeContentObject(Map obj) { + if (obj == null || obj.isEmpty()) { + return obj; + } + + // Already has type field → standard format + if (obj.containsKey("type")) { + return obj; + } + + // No type field → user-defined object, wrap as text + Map wrapped = new HashMap<>(); + wrapped.put("type", "text"); + String jsonText = StringUtils.toJson(obj); + wrapped.put("text", jsonText); + return wrapped; + } + + /** + * Build messages JSON array from MessageInput list. + * + *

Includes: + *

    + *
  • Optional system_prompt_message from config (added first)
  • + *
  • User/assistant messages with processed content
  • + *
  • Optional user_prompt_message from config (added last)
  • + *
+ * + * @param messages List of messages to include + * @param additionalConfig Optional config with extra messages + * @return JSON string representing messages array + * @throws IOException if JSON building fails + */ + private String buildMessagesArray(List messages, Map additionalConfig) throws IOException { + XContentBuilder builder = jsonXContent.contentBuilder(); + builder.startArray(); + + // Optional system_prompt_message from config + if (additionalConfig != null && additionalConfig.containsKey("system_prompt_message")) { + Object systemPromptMsg = additionalConfig.get("system_prompt_message"); + if (systemPromptMsg instanceof Map) { + @SuppressWarnings("unchecked") + Map msgMap = (Map) systemPromptMsg; + builder.map(msgMap); + } + } + + // User messages (with content processing) + if (messages != null) { + for (MessageInput message : messages) { + builder.startObject(); + + if (message.getRole() != null) { + builder.field("role", message.getRole()); + } + + // Process content to ensure type fields + if (message.getContent() != null) { + List> processedContent = processContent(message.getContent()); + builder.field("content", processedContent); + } + + builder.endObject(); + } + } + + // Optional user_prompt_message from config + if (additionalConfig != null && additionalConfig.containsKey("user_prompt_message")) { + Object userPromptMsg = additionalConfig.get("user_prompt_message"); + if (userPromptMsg instanceof Map) { + @SuppressWarnings("unchecked") + Map msgMap = (Map) userPromptMsg; + builder.map(msgMap); + } + } + + builder.endArray(); + return builder.toString(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/utils/message/MessageFormatter.java b/common/src/main/java/org/opensearch/ml/common/utils/message/MessageFormatter.java new file mode 100644 index 0000000000..712167660d --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/utils/message/MessageFormatter.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch 
Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.utils.message; + +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.transport.memorycontainer.memory.MessageInput; + +/** + * Strategy interface for formatting LLM requests based on model requirements. + * + *

Each formatter implementation handles: + *

    + *
  • System prompt placement (parameter vs message)
  • + *
  • Content object normalization (type field enforcement)
  • + *
  • Message array construction
  • + *
+ * + *

Implementations: + *

    + *
  • {@link ClaudeMessageFormatter}: Uses system_prompt parameter (Claude, Bedrock models)
  • + *
  • {@link OpenAIMessageFormatter}: Injects system message in array (OpenAI, GPT models)
  • + *
+ * + *

Usage: + *

+ * MessageFormatter formatter = MessageFormatterFactory.getFormatterForModel(modelId, cache);
+ * Map<String, String> params = formatter.formatRequest(systemPrompt, messages, config);
+ * 
+ */ +public interface MessageFormatter { + + /** + * Format a complete LLM request with proper system prompt placement. + * + *

This method handles: + *

    + *
  • Placing system prompt in correct location (parameter or message)
  • + *
  • Building messages array with content processing
  • + *
  • Incorporating additional config messages (system_prompt_message, user_prompt_message)
  • + *
+ * + * @param systemPrompt The system prompt to inject (may be null or blank) + * @param messages User/assistant messages to include in the request + * @param additionalConfig Optional configuration containing: + *
    + *
  • system_prompt_message: Additional system-level message
  • + *
  • user_prompt_message: Additional user message
  • + *
+ * @return Map of request parameters ready for MLInput, typically containing: + *
    + *
  • "messages": JSON array of messages
  • + *
  • "system_prompt": System prompt (Claude-style only)
  • + *
+ * @throws RuntimeException if message array building fails + */ + Map formatRequest(String systemPrompt, List messages, Map additionalConfig); + + /** + * Process message content objects to ensure LLM compatibility. + * + *

Content processing rules: + *

    + *
  • Objects WITH "type" field → keep as-is (standard LLM format like + * {"type": "text", "text": "..."} or {"type": "image_url", "image_url": {...}})
  • + *
  • Objects WITHOUT "type" field → wrap as {"type": "text", "text": JSON_STRING} + * where JSON_STRING is the serialized user-defined object
  • + *
+ * + *

This ensures that user-defined content objects are properly formatted + * for LLM consumption without breaking standard multimodal content. + * + * @param content List of content objects from message (may be null or empty) + * @return Processed content list with all objects having "type" field + */ + List> processContent(List> content); +} diff --git a/common/src/main/java/org/opensearch/ml/common/utils/message/MessageFormatterFactory.java b/common/src/main/java/org/opensearch/ml/common/utils/message/MessageFormatterFactory.java new file mode 100644 index 0000000000..4b9c4c783f --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/utils/message/MessageFormatterFactory.java @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.utils.message; + +import lombok.extern.log4j.Log4j2; + +/** + * Factory for creating appropriate MessageFormatter based on model's input schema. + * + *

Decision Algorithm: + *

    + *
  1. Get model's input schema (from cache or other source)
  2. + *
  3. Pass schema to {@link #getFormatter(String)}
  4. + *
  5. Check if schema contains "system_prompt" field
  6. + *
  7. If YES → Claude formatter
  8. + *
  9. If NO → OpenAI formatter
  10. + *
  11. On any error → Claude formatter (safe default)
  12. + *
+ * + *

The factory uses singleton formatters for performance (stateless, thread-safe). + * + *

Usage: + *

+ * // Get input schema from model cache
+ * Map<String, String> modelInterface = modelCacheHelper.getModelInterface(modelId);
+ * String inputSchema = modelInterface.get("input");
+ *
+ * // Get appropriate formatter
+ * MessageFormatter formatter = MessageFormatterFactory.getFormatter(inputSchema);
+ *
+ * // Or use explicit formatters for testing
+ * MessageFormatter claude = MessageFormatterFactory.getClaudeFormatter();
+ * 
+ */ +@Log4j2 +public class MessageFormatterFactory { + + // Singleton formatter instances (thread-safe, stateless) + private static final MessageFormatter CLAUDE_FORMATTER = new ClaudeMessageFormatter(); + private static final MessageFormatter OPENAI_FORMATTER = new OpenAIMessageFormatter(); + + /** + * Get appropriate formatter based on input schema JSON. + * + *

This is the core decision logic: + *

    + *
  • Presence of "system_prompt" → Claude-style formatter
  • + *
  • Absence of "system_prompt" → OpenAI-style formatter
  • + *
+ * + *

The detection uses simple string matching for performance and reliability. + * + * @param inputSchemaJson The JSON schema string from model interface + * @return Appropriate formatter (never null, defaults to Claude) + */ + public static MessageFormatter getFormatter(String inputSchemaJson) { + if (inputSchemaJson == null || inputSchemaJson.isBlank()) { + log.debug("No input schema provided, defaulting to Claude formatter"); + return CLAUDE_FORMATTER; + } + + try { + // Simple and fast: check if schema contains system_prompt field + // This is more reliable than parsing JSON and navigating paths + boolean hasSystemPromptParam = inputSchemaJson.contains("\"system_prompt\""); + + if (hasSystemPromptParam) { + log.debug("Schema contains system_prompt parameter, using Claude formatter"); + return CLAUDE_FORMATTER; + } else { + log.debug("Schema lacks system_prompt parameter, using OpenAI formatter"); + return OPENAI_FORMATTER; + } + } catch (Exception e) { + log.warn("Failed to analyze input schema, defaulting to Claude formatter", e); + return CLAUDE_FORMATTER; + } + } + + /** + * Get Claude formatter explicitly (for testing or explicit usage). + * + * @return Claude message formatter instance + */ + public static MessageFormatter getClaudeFormatter() { + return CLAUDE_FORMATTER; + } + + /** + * Get OpenAI formatter explicitly (for testing or explicit usage). 
+ * + * @return OpenAI message formatter instance + */ + public static MessageFormatter getOpenAIFormatter() { + return OPENAI_FORMATTER; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/utils/message/OpenAIMessageFormatter.java b/common/src/main/java/org/opensearch/ml/common/utils/message/OpenAIMessageFormatter.java new file mode 100644 index 0000000000..2bbbab1f24 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/utils/message/OpenAIMessageFormatter.java @@ -0,0 +1,207 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.utils.message; + +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.transport.memorycontainer.memory.MessageInput; +import org.opensearch.ml.common.utils.StringUtils; + +import lombok.extern.log4j.Log4j2; + +/** + * Message formatter for OpenAI models (and similar models without system_prompt parameter). + * + *

Format characteristics: + *

    + *
  • System prompt: Injected as first message with role="system"
  • + *
  • Messages: Array including system + user/assistant messages
  • + *
  • Content: Normalized to have "type" field
  • + *
+ * + *

Compatible with: + *

    + *
  • GPT-4, GPT-4o, GPT-4o-mini
  • + *
  • GPT-3.5-turbo
  • + *
  • Any Chat Completions API model
  • + *
  • Models without system_prompt in schema
  • + *
+ * + *

Example output: + *

+ * {
+ *   "messages": "[{\"role\":\"system\",\"content\":[{\"type\":\"text\",\"text\":\"You are helpful\"}]},
+ *                 {\"role\":\"user\",\"content\":[{\"type\":\"text\",\"text\":\"Hello\"}]}]"
+ * }
+ * 
+ */ +@Log4j2 +public class OpenAIMessageFormatter implements MessageFormatter { + + @Override + public Map formatRequest(String systemPrompt, List messages, Map additionalConfig) { + Map parameters = new HashMap<>(); + + // OpenAI-style: NO system_prompt parameter + // System prompt goes as first message in array + + List allMessages = new ArrayList<>(); + + // Inject system prompt as first message + if (systemPrompt != null && !systemPrompt.isBlank()) { + allMessages.add(createSystemMessage(systemPrompt)); + log.debug("System prompt injected as first message"); + } + + // Add user messages + if (messages != null) { + allMessages.addAll(messages); + } + + // Build messages array with content processing + try { + String messagesJson = buildMessagesArray(allMessages, additionalConfig); + parameters.put("messages", messagesJson); + log.debug("Built messages array with {} messages", allMessages.size()); + } catch (IOException e) { + log.error("Failed to build messages array", e); + throw new RuntimeException("Failed to format OpenAI request", e); + } + + return parameters; + } + + @Override + public List> processContent(List> content) { + if (content == null || content.isEmpty()) { + return content; + } + + return content.stream().map(this::normalizeContentObject).collect(Collectors.toList()); + } + + /** + * Create a system message from prompt text. + * + *

+ * <p>Creates a MessageInput with:
+ * <ul>
+ *   <li>role: "system"</li>
+ *   <li>content: [{"type": "text", "text": prompt}]</li>
+ * </ul>
+ * + * @param prompt The system prompt text + * @return MessageInput configured as system message + */ + private MessageInput createSystemMessage(String prompt) { + List> content = new ArrayList<>(); + Map textContent = new HashMap<>(); + textContent.put("type", "text"); + textContent.put("text", prompt); + content.add(textContent); + + return MessageInput.builder().role("system").content(content).build(); + } + + /** + * Normalize a single content object to ensure it has "type" field. + * + *

+ * <p>Rules:
+ * <ul>
+ *   <li>If object has "type" field → return as-is (standard LLM format)</li>
+ *   <li>If object lacks "type" field → wrap as {"type": "text", "text": JSON_STRING}</li>
+ * </ul>
+ * + * @param obj Content object to normalize + * @return Normalized content object with "type" field + */ + private Map normalizeContentObject(Map obj) { + if (obj == null || obj.isEmpty()) { + return obj; + } + + // Already has type field → standard format + if (obj.containsKey("type")) { + return obj; + } + + // No type field → user-defined object, wrap as text + Map wrapped = new HashMap<>(); + wrapped.put("type", "text"); + String jsonText = StringUtils.toJson(obj); + wrapped.put("text", jsonText); + return wrapped; + } + + /** + * Build messages JSON array from MessageInput list. + * + *

+ * <p>Includes:
+ * <ul>
+ *   <li>Optional system_prompt_message from config (added first)</li>
+ *   <li>All messages (including injected system message) with processed content</li>
+ *   <li>Optional user_prompt_message from config (added last)</li>
+ * </ul>
+ * + * @param messages List of ALL messages to include (including system) + * @param additionalConfig Optional config with extra messages + * @return JSON string representing messages array + * @throws IOException if JSON building fails + */ + private String buildMessagesArray(List messages, Map additionalConfig) throws IOException { + XContentBuilder builder = jsonXContent.contentBuilder(); + builder.startArray(); + + // Optional system_prompt_message from config (rare for OpenAI) + if (additionalConfig != null && additionalConfig.containsKey("system_prompt_message")) { + Object systemPromptMsg = additionalConfig.get("system_prompt_message"); + if (systemPromptMsg instanceof Map) { + @SuppressWarnings("unchecked") + Map msgMap = (Map) systemPromptMsg; + builder.map(msgMap); + } + } + + // All messages (including injected system message) + if (messages != null) { + for (MessageInput message : messages) { + builder.startObject(); + + if (message.getRole() != null) { + builder.field("role", message.getRole()); + } + + // Process content to ensure type fields + if (message.getContent() != null) { + List> processedContent = processContent(message.getContent()); + builder.field("content", processedContent); + } + + builder.endObject(); + } + } + + // Optional user_prompt_message from config + if (additionalConfig != null && additionalConfig.containsKey("user_prompt_message")) { + Object userPromptMsg = additionalConfig.get("user_prompt_message"); + if (userPromptMsg instanceof Map) { + @SuppressWarnings("unchecked") + Map msgMap = (Map) userPromptMsg; + builder.map(msgMap); + } + } + + builder.endArray(); + return builder.toString(); + } +} diff --git a/common/src/main/resources/model-interface-schemas/input/amazon_comprehend_input.json b/common/src/main/resources/model-interface-schemas/input/amazon_comprehend_input.json new file mode 100644 index 0000000000..cea590af42 --- /dev/null +++ 
b/common/src/main/resources/model-interface-schemas/input/amazon_comprehend_input.json @@ -0,0 +1,19 @@ +{ + "type": "object", + "properties": { + "parameters": { + "type": "object", + "properties": { + "Text": { + "type": "string" + } + }, + "required": [ + "Text" + ] + } + }, + "required": [ + "parameters" + ] +} diff --git a/common/src/main/resources/model-interface-schemas/input/amazon_textract_input.json b/common/src/main/resources/model-interface-schemas/input/amazon_textract_input.json new file mode 100644 index 0000000000..c53ef15719 --- /dev/null +++ b/common/src/main/resources/model-interface-schemas/input/amazon_textract_input.json @@ -0,0 +1,19 @@ +{ + "type": "object", + "properties": { + "parameters": { + "type": "object", + "properties": { + "bytes": { + "type": "string" + } + }, + "required": [ + "bytes" + ] + } + }, + "required": [ + "parameters" + ] +} diff --git a/common/src/main/resources/model-interface-schemas/input/bedrock_anthropic_claude_use_system_prompt_input.json b/common/src/main/resources/model-interface-schemas/input/bedrock_anthropic_claude_use_system_prompt_input.json new file mode 100644 index 0000000000..c2b2b1134f --- /dev/null +++ b/common/src/main/resources/model-interface-schemas/input/bedrock_anthropic_claude_use_system_prompt_input.json @@ -0,0 +1,77 @@ +{ + "type": "object", + "properties": { + "parameters": { + "type": "object", + "properties": { + "system_prompt": { + "type": "string", + "description": "System prompt to guide the model's behavior" + }, + "messages": { + "type": "array", + "description": "Array of message objects in conversational format", + "items": { + "type": "object", + "properties": { + "role": { + "type": "string", + "description": "Role of the message sender (e.g., user, assistant, system)" + }, + "content": { + "oneOf": [ + { + "type": "string", + "description": "Simple text content" + }, + { + "type": "array", + "description": "Array of content blocks for multi-part messages", + "items": { + 
"type": "object", + "description": "Content block with type-specific properties (e.g., text, image, video, tool_use, tool_result)", + "properties": { + "type": { + "type": "string", + "description": "Type of content block" + } + }, + "required": ["type"], + "additionalProperties": true + } + } + ] + } + }, + "required": ["role", "content"] + } + }, + "max_tokens": { + "type": "integer", + "description": "Maximum number of tokens to generate" + }, + "temperature": { + "type": "number", + "description": "Sampling temperature (0.0 to 1.0)" + }, + "top_p": { + "type": "number", + "description": "Nucleus sampling parameter" + }, + "top_k": { + "type": "integer", + "description": "Top-k sampling parameter" + }, + "stop_sequences": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Sequences that will stop generation" + } + }, + "required": ["messages"] + } + }, + "required": ["parameters"] +} diff --git a/common/src/main/resources/model-interface-schemas/input/general_conversational_single_round_input.json b/common/src/main/resources/model-interface-schemas/input/general_conversational_single_round_input.json new file mode 100644 index 0000000000..200fc990e0 --- /dev/null +++ b/common/src/main/resources/model-interface-schemas/input/general_conversational_single_round_input.json @@ -0,0 +1,19 @@ +{ + "type": "object", + "properties": { + "parameters": { + "type": "object", + "properties": { + "inputs": { + "type": "string" + } + }, + "required": [ + "inputs" + ] + } + }, + "required": [ + "parameters" + ] +} diff --git a/common/src/main/resources/model-interface-schemas/input/general_embedding_input.json b/common/src/main/resources/model-interface-schemas/input/general_embedding_input.json new file mode 100644 index 0000000000..1e8ffc4d31 --- /dev/null +++ b/common/src/main/resources/model-interface-schemas/input/general_embedding_input.json @@ -0,0 +1,19 @@ +{ + "type": "object", + "properties": { + "parameters": { + "type": "object", + 
"properties": { + "texts": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "required": [ + "texts" + ] + } + } +} diff --git a/common/src/main/resources/model-interface-schemas/input/openai_chat_completions_input.json b/common/src/main/resources/model-interface-schemas/input/openai_chat_completions_input.json new file mode 100644 index 0000000000..a11ad75311 --- /dev/null +++ b/common/src/main/resources/model-interface-schemas/input/openai_chat_completions_input.json @@ -0,0 +1,109 @@ +{ + "type": "object", + "properties": { + "parameters": { + "type": "object", + "properties": { + "messages": { + "type": "array", + "description": "Array of message objects in conversational format", + "items": { + "type": "object", + "properties": { + "role": { + "type": "string", + "description": "Role of the message sender (e.g., system, user, assistant, tool)" + }, + "content": { + "oneOf": [ + { + "type": "string", + "description": "Simple text content" + }, + { + "type": "array", + "description": "Array of content blocks for multi-part messages", + "items": { + "type": "object", + "description": "Content block with type-specific properties (e.g., text, image_url, video, audio, file)", + "properties": { + "type": { + "type": "string", + "description": "Type of content block" + } + }, + "required": ["type"], + "additionalProperties": true + } + } + ] + }, + "name": { + "type": "string", + "description": "Optional name for the message sender" + } + }, + "required": ["role", "content"] + } + }, + "model": { + "type": "string", + "description": "ID of the model to use" + }, + "max_tokens": { + "type": "integer", + "description": "Maximum number of tokens to generate" + }, + "temperature": { + "type": "number", + "description": "Sampling temperature (0.0 to 2.0)" + }, + "top_p": { + "type": "number", + "description": "Nucleus sampling parameter" + }, + "n": { + "type": "integer", + "description": "Number of chat completion choices to generate" + }, + "stream": { 
+ "type": "boolean", + "description": "Whether to stream partial message deltas" + }, + "stop": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "type": "string" + } + } + ], + "description": "Sequences where the API will stop generating tokens" + }, + "presence_penalty": { + "type": "number", + "description": "Penalty for new tokens based on presence in text so far" + }, + "frequency_penalty": { + "type": "number", + "description": "Penalty for new tokens based on frequency in text so far" + }, + "logit_bias": { + "type": "object", + "description": "Modify likelihood of specified tokens appearing", + "additionalProperties": true + }, + "user": { + "type": "string", + "description": "Unique identifier for the end-user" + } + }, + "required": ["messages"] + } + }, + "required": ["parameters"] +} diff --git a/common/src/main/resources/model-interface-schemas/input/titan_multi_modal_embedding_input.json b/common/src/main/resources/model-interface-schemas/input/titan_multi_modal_embedding_input.json new file mode 100644 index 0000000000..6e54632184 --- /dev/null +++ b/common/src/main/resources/model-interface-schemas/input/titan_multi_modal_embedding_input.json @@ -0,0 +1,16 @@ +{ + "type": "object", + "properties": { + "parameters": { + "type": "object", + "properties": { + "inputText": { + "type": "string" + }, + "inputImage": { + "type": "string" + } + } + } + } +} diff --git a/common/src/main/resources/model-interface-schemas/input/titan_text_embedding_input.json b/common/src/main/resources/model-interface-schemas/input/titan_text_embedding_input.json new file mode 100644 index 0000000000..efbc7ef073 --- /dev/null +++ b/common/src/main/resources/model-interface-schemas/input/titan_text_embedding_input.json @@ -0,0 +1,16 @@ +{ + "type": "object", + "properties": { + "parameters": { + "type": "object", + "properties": { + "inputText": { + "type": "string" + } + }, + "required": [ + "inputText" + ] + } + } +} diff --git 
a/common/src/main/resources/model-interface-schemas/output/amazon_comprehend_output.json b/common/src/main/resources/model-interface-schemas/output/amazon_comprehend_output.json new file mode 100644 index 0000000000..e5a019ce3d --- /dev/null +++ b/common/src/main/resources/model-interface-schemas/output/amazon_comprehend_output.json @@ -0,0 +1,72 @@ +{ + "type": "object", + "properties": { + "inference_results": { + "type": "array", + "items": { + "type": "object", + "properties": { + "output": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "dataAsMap": { + "type": "object", + "properties": { + "response": { + "type": "object", + "properties": { + "Languages": { + "type": "array", + "items": { + "type": "object", + "properties": { + "LanguageCode": { + "type": "string" + }, + "Score": { + "type": "number" + } + }, + "required": [ + "LanguageCode", + "Score" + ] + } + } + }, + "required": [ + "Languages" + ] + } + }, + "required": [ + "response" + ] + } + }, + "required": [ + "name", + "dataAsMap" + ] + } + }, + "status_code": { + "type": "integer" + } + }, + "required": [ + "output", + "status_code" + ] + } + } + }, + "required": [ + "inference_results" + ] +} diff --git a/common/src/main/resources/model-interface-schemas/output/amazon_textract_output.json b/common/src/main/resources/model-interface-schemas/output/amazon_textract_output.json new file mode 100644 index 0000000000..3c8d83b74e --- /dev/null +++ b/common/src/main/resources/model-interface-schemas/output/amazon_textract_output.json @@ -0,0 +1,110 @@ +{ + "type": "object", + "properties": { + "inference_results": { + "type": "array", + "items": { + "type": "object", + "properties": { + "output": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "dataAsMap": { + "type": "object", + "properties": { + "Blocks": { + "type": "array", + "items": { + "type": "object", + "properties": { + 
"BlockType": { + "type": "string" + }, + "Geometry": { + "type": "object", + "properties": { + "BoundingBox": { + "type": "object", + "properties": { + "Height": { + "type": "number" + }, + "Left": { + "type": "number" + }, + "Top": { + "type": "number" + }, + "Width": { + "type": "number" + } + } + }, + "Polygon": { + "type": "array", + "items": { + "type": "object", + "properties": { + "X": { + "type": "number" + }, + "Y": { + "type": "number" + } + } + } + } + } + }, + "Id": { + "type": "string" + }, + "Relationships": { + "type": "array", + "items": { + "type": "object", + "properties": { + "Ids": { + "type": "array", + "items": { + "type": "string" + } + }, + "Type": { + "type": "string" + } + } + } + } + } + } + }, + "DetectDocumentTextModelVersion": { + "type": "string" + }, + "DocumentMetadata": { + "type": "object", + "properties": { + "Pages": { + "type": "number" + } + } + } + } + } + } + } + }, + "status_code": { + "type": "number" + } + } + } + } + } +} diff --git a/common/src/main/resources/model-interface-schemas/output/amazon_titan_embedding_v1_raw_output.json b/common/src/main/resources/model-interface-schemas/output/amazon_titan_embedding_v1_raw_output.json new file mode 100644 index 0000000000..8ef2ec7a3c --- /dev/null +++ b/common/src/main/resources/model-interface-schemas/output/amazon_titan_embedding_v1_raw_output.json @@ -0,0 +1,45 @@ +{ + "type": "object", + "properties": { + "inference_results": { + "type": "array", + "items": { + "type": "object", + "properties": { + "output": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "dataAsMap": { + "type": "object", + "properties": { + "embedding": { + "type": "array", + "items": { + "type": "number" + } + }, + "inputTextTokenCount": { + "type": "integer" + } + }, + "required": ["embedding", "inputTextTokenCount"] + } + }, + "required": ["name", "dataAsMap"] + } + }, + "status_code": { + "type": "integer" + } + }, + "required": 
["output", "status_code"] + } + } + }, + "required": ["inference_results"] +} diff --git a/common/src/main/resources/model-interface-schemas/output/bedrock_ai21_j2_raw_output.json b/common/src/main/resources/model-interface-schemas/output/bedrock_ai21_j2_raw_output.json new file mode 100644 index 0000000000..f7af0d7817 --- /dev/null +++ b/common/src/main/resources/model-interface-schemas/output/bedrock_ai21_j2_raw_output.json @@ -0,0 +1,136 @@ +{ + "type": "object", + "properties": { + "inference_results": { + "type": "array", + "items": { + "type": "object", + "properties": { + "output": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "dataAsMap": { + "type": "object", + "properties": { + "id": { + "type": "number" + }, + "prompt": { + "type": "object", + "properties": { + "text": { + "type": "string" + }, + "tokens": { + "type": "array", + "items": { + "type": "object", + "properties": { + "generatedToken": { + "type": "object", + "properties": { + "token": { + "type": "string" + }, + "logprob": { + "type": "number" + }, + "raw_logprob": { + "type": "number" + } + } + }, + "textRange": { + "type": "object", + "properties": { + "start": { + "type": "number" + }, + "end": { + "type": "number" + } + } + } + } + } + } + } + }, + "completions": { + "type": "array", + "items": { + "type": "object", + "properties": { + "data": { + "type": "object", + "properties": { + "text": { + "type": "string" + }, + "tokens": { + "type": "array", + "items": { + "type": "object", + "properties": { + "generatedToken": { + "type": "object", + "properties": { + "token": { + "type": "string" + }, + "logprob": { + "type": "number" + }, + "raw_logprob": { + "type": "number" + } + } + }, + "textRange": { + "type": "object", + "properties": { + "start": { + "type": "number" + }, + "end": { + "type": "number" + } + } + } + } + } + } + } + }, + "finishReason": { + "type": "object", + "properties": { + "reason": { + "type": "string" 
+ }, + "length": { + "type": "number" + } + } + } + } + } + } + } + } + } + } + }, + "status_code": { + "type": "integer" + } + } + } + } + } +} diff --git a/common/src/main/resources/model-interface-schemas/output/bedrock_anthropic_claude_use_system_prompt_output.json b/common/src/main/resources/model-interface-schemas/output/bedrock_anthropic_claude_use_system_prompt_output.json new file mode 100644 index 0000000000..8dce39077b --- /dev/null +++ b/common/src/main/resources/model-interface-schemas/output/bedrock_anthropic_claude_use_system_prompt_output.json @@ -0,0 +1,80 @@ +{ + "type": "object", + "properties": { + "inference_results": { + "type": "array", + "description": "Array of inference results from model prediction", + "items": { + "type": "object", + "properties": { + "output": { + "type": "array", + "description": "Array of model tensor outputs", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Name of the output tensor" + }, + "dataAsMap": { + "type": "object", + "description": "Claude Messages API response", + "properties": { + "id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "role": { + "type": "string" + }, + "content": { + "type": "array", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "text": { + "type": "string", + "x-llm-output": true + } + }, + "required": ["type"], + "additionalProperties": true + } + }, + "model": { + "type": "string" + }, + "stop_reason": { + "type": "string" + }, + "stop_sequence": { + "type": ["string", "null"] + }, + "usage": { + "type": "object", + "additionalProperties": true + } + }, + "required": ["id", "type", "role", "content", "model"] + } + }, + "additionalProperties": true + } + }, + "status_code": { + "type": "integer", + "description": "HTTP status code of the response" + } + }, + "additionalProperties": true + } + } + }, + "required": ["inference_results"] +} diff --git 
a/common/src/main/resources/model-interface-schemas/output/bedrock_anthropic_claude_v2_output.json b/common/src/main/resources/model-interface-schemas/output/bedrock_anthropic_claude_v2_output.json new file mode 100644 index 0000000000..f9cf540dcb --- /dev/null +++ b/common/src/main/resources/model-interface-schemas/output/bedrock_anthropic_claude_v2_output.json @@ -0,0 +1,59 @@ +{ + "type": "object", + "properties": { + "inference_results": { + "type": "array", + "items": { + "type": "object", + "properties": { + "output": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "dataAsMap": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "completion": { + "type": "string" + }, + "stop_reason": { + "type": "string" + }, + "stop": { + "type": "string" + } + }, + "required": [ + "type", + "completion" + ] + } + }, + "required": [ + "name", + "dataAsMap" + ] + } + }, + "status_code": { + "type": "integer" + } + }, + "required": [ + "output", + "status_code" + ] + } + } + }, + "required": [ + "inference_results" + ] +} diff --git a/common/src/main/resources/model-interface-schemas/output/cohere_embedding_v3_raw_output.json b/common/src/main/resources/model-interface-schemas/output/cohere_embedding_v3_raw_output.json new file mode 100644 index 0000000000..b4401fb5fe --- /dev/null +++ b/common/src/main/resources/model-interface-schemas/output/cohere_embedding_v3_raw_output.json @@ -0,0 +1,57 @@ +{ + "type": "object", + "properties": { + "inference_results": { + "type": "array", + "items": { + "type": "object", + "properties": { + "output": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "dataAsMap": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "texts": { + "type": "array", + "items": { + "type": "string" + } + }, + "embeddings": { + "type": "array", + "items": { + "type": "array", + "items": { + 
"type": "number" + } + } + }, + "response_type": { + "type": "string" + } + }, + "required": ["id", "texts", "embeddings", "response_type"] + } + }, + "required": ["name", "dataAsMap"] + } + }, + "status_code": { + "type": "integer" + } + }, + "required": ["output", "status_code"] + } + } + }, + "required": ["inference_results"] +} diff --git a/common/src/main/resources/model-interface-schemas/output/general_conversational_single_round_output.json b/common/src/main/resources/model-interface-schemas/output/general_conversational_single_round_output.json new file mode 100644 index 0000000000..416149ceca --- /dev/null +++ b/common/src/main/resources/model-interface-schemas/output/general_conversational_single_round_output.json @@ -0,0 +1,49 @@ +{ + "type": "object", + "properties": { + "inference_results": { + "type": "array", + "items": { + "type": "object", + "properties": { + "output": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "dataAsMap": { + "type": "object", + "properties": { + "response": { + "type": "string" + } + }, + "required": [ + "response" + ] + } + }, + "required": [ + "name", + "dataAsMap" + ] + } + }, + "status_code": { + "type": "integer" + } + }, + "required": [ + "output", + "status_code" + ] + } + } + }, + "required": [ + "inference_results" + ] +} diff --git a/common/src/main/resources/model-interface-schemas/output/general_embedding_output.json b/common/src/main/resources/model-interface-schemas/output/general_embedding_output.json new file mode 100644 index 0000000000..8b3cfff380 --- /dev/null +++ b/common/src/main/resources/model-interface-schemas/output/general_embedding_output.json @@ -0,0 +1,55 @@ +{ + "type": "object", + "properties": { + "inference_results": { + "type": "array", + "items": { + "type": "object", + "properties": { + "output": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "data_type": { + 
"type": "string" + }, + "shape": { + "type": "array", + "items": { + "type": "integer" + } + }, + "data": { + "type": "array", + "items": { + "type": "number" + } + } + }, + "required": [ + "name", + "data_type", + "shape", + "data" + ] + } + }, + "status_code": { + "type": "integer" + } + }, + "required": [ + "output", + "status_code" + ] + } + } + }, + "required": [ + "inference_results" + ] +} diff --git a/common/src/main/resources/model-interface-schemas/output/openai_chat_completions_output.json b/common/src/main/resources/model-interface-schemas/output/openai_chat_completions_output.json new file mode 100644 index 0000000000..78abd8a9ea --- /dev/null +++ b/common/src/main/resources/model-interface-schemas/output/openai_chat_completions_output.json @@ -0,0 +1,130 @@ +{ + "type": "object", + "properties": { + "inference_results": { + "type": "array", + "description": "Array of inference results from model prediction", + "items": { + "type": "object", + "properties": { + "output": { + "type": "array", + "description": "Array of model tensor outputs", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Name of the output tensor" + }, + "dataAsMap": { + "type": "object", + "description": "OpenAI Chat Completions API response", + "properties": { + "id": { + "type": "string", + "description": "Unique identifier for the chat completion" + }, + "object": { + "type": "string", + "description": "Object type (e.g., chat.completion)" + }, + "created": { + "type": ["integer", "number"], + "description": "Unix timestamp of when the completion was created" + }, + "model": { + "type": "string", + "description": "The model used for completion" + }, + "choices": { + "type": "array", + "description": "Array of chat completion choices", + "items": { + "type": "object", + "properties": { + "index": { + "type": "integer", + "description": "Index of this choice" + }, + "message": { + "type": "object", + "description": "The generated 
message", + "properties": { + "role": { + "type": "string", + "description": "Role of the message sender" + }, + "content": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string" + } + }, + "required": ["type"], + "additionalProperties": true + } + }, + { + "type": "null" + } + ], + "x-llm-output": true + } + }, + "additionalProperties": true + }, + "finish_reason": { + "type": ["string", "null"], + "description": "Reason why the completion finished" + } + }, + "additionalProperties": true + } + }, + "usage": { + "type": "object", + "description": "Token usage information", + "properties": { + "prompt_tokens": { + "type": ["integer", "number"] + }, + "completion_tokens": { + "type": ["integer", "number"] + }, + "total_tokens": { + "type": ["integer", "number"] + } + }, + "additionalProperties": true + }, + "system_fingerprint": { + "type": ["string", "null"], + "description": "Backend configuration fingerprint" + } + }, + "required": ["id", "object", "created", "model", "choices"] + } + }, + "additionalProperties": true + } + }, + "status_code": { + "type": "integer", + "description": "HTTP status code of the response" + } + }, + "additionalProperties": true + } + } + }, + "required": ["inference_results"] +} diff --git a/common/src/test/java/org/opensearch/ml/common/utils/LlmResultPathGeneratorTest.java b/common/src/test/java/org/opensearch/ml/common/utils/LlmResultPathGeneratorTest.java new file mode 100644 index 0000000000..22a44d6113 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/utils/LlmResultPathGeneratorTest.java @@ -0,0 +1,554 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.utils; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import 
static org.junit.Assert.assertTrue; + +import java.io.IOException; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.OpenSearchParseException; + +/** + * Unit tests for LlmResultPathGenerator. + */ +public class LlmResultPathGeneratorTest { + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + /** + * Test schema with explicit x-llm-output marker. + */ + @Test + public void testGenerate_WithExplicitMarker() throws IOException { + String schema = "{\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"choices\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"message\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"content\": {\n" + + " \"type\": \"string\",\n" + + " \"x-llm-output\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + String result = LlmResultPathGenerator.generate(schema); + + assertNotNull(result); + assertEquals("$.choices[0].message.content", result); + } + + /** + * Test schema without x-llm-output marker returns null. 
+ */ + @Test + public void testGenerate_WithHeuristicMatching_Content() throws IOException { + String schema = "{\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"content\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"text\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + String result = LlmResultPathGenerator.generate(schema); + + // Without x-llm-output marker, should return null + assertNull(result); + } + + /** + * Test Bedrock Claude v3 style response structure. + */ + @Test + public void testGenerate_BedrockClaudeV3Style() throws IOException { + String schema = "{\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"content\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"text\": {\n" + + " \"type\": \"string\",\n" + + " \"x-llm-output\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + String result = LlmResultPathGenerator.generate(schema); + + assertNotNull(result); + assertEquals("$.content[0].text", result); + } + + /** + * Test Bedrock Claude v2 style (simple completion field). 
+ */ + @Test + public void testGenerate_BedrockClaudeV2Style() throws IOException { + String schema = "{\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"completion\": {\n" + + " \"type\": \"string\",\n" + + " \"x-llm-output\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + String result = LlmResultPathGenerator.generate(schema); + + assertNotNull(result); + assertEquals("$.completion", result); + } + + /** + * Test with null schema input. + */ + @Test + public void testGenerate_NullSchema() throws IOException { + String result = LlmResultPathGenerator.generate(null); + assertNull(result); + } + + /** + * Test with empty schema input. + */ + @Test + public void testGenerate_EmptySchema() throws IOException { + String result = LlmResultPathGenerator.generate(""); + assertNull(result); + } + + /** + * Test with malformed JSON schema. + */ + @Test + public void testGenerate_MalformedJson() { + exceptionRule.expect(OpenSearchParseException.class); + exceptionRule.expectMessage("Schema parsing error"); + + try { + LlmResultPathGenerator.generate("{invalid json}"); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + /** + * Test schema without dataAsMap structure (should attempt fallback). 
+ */ + @Test + public void testGenerate_NoDataAsMapStructure() throws IOException { + String schema = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"response\": {\n" + + " \"type\": \"string\",\n" + + " \"x-llm-output\": true\n" + + " }\n" + + " }\n" + + "}"; + + String result = LlmResultPathGenerator.generate(schema); + + assertNotNull(result); + assertEquals("$.response", result); + } + + /** + * Test schema with no LLM text field markers or common names. + */ + @Test + public void testGenerate_NoLlmTextField() throws IOException { + String schema = "{\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"unknownField\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + String result = LlmResultPathGenerator.generate(schema); + + assertNull(result); + } + + /** + * Test deeply nested structure with marker. 
+ */ + @Test + public void testGenerate_DeeplyNested() throws IOException { + String schema = "{\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"data\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"results\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"text\": {\n" + + " \"type\": \"string\",\n" + + " \"x-llm-output\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + String result = LlmResultPathGenerator.generate(schema); + + assertNotNull(result); + assertEquals("$.data.results[0].output.text", result); + } + + /** + * Test isValidJsonPath with valid paths. + */ + @Test + public void testIsValidJsonPath_ValidPaths() { + assertTrue(LlmResultPathGenerator.isValidJsonPath("$.content")); + assertTrue(LlmResultPathGenerator.isValidJsonPath("$.choices[0].message.content")); + assertTrue(LlmResultPathGenerator.isValidJsonPath("$.data.results[0].output.text")); + assertTrue(LlmResultPathGenerator.isValidJsonPath("$")); + } + + /** + * Test isValidJsonPath with invalid paths. 
+ */ + @Test + public void testIsValidJsonPath_InvalidPaths() { + assertFalse(LlmResultPathGenerator.isValidJsonPath(null)); + assertFalse(LlmResultPathGenerator.isValidJsonPath("")); + assertFalse(LlmResultPathGenerator.isValidJsonPath("content")); // Missing $ + assertFalse(LlmResultPathGenerator.isValidJsonPath("$.choices[0.message.content")); // Unbalanced brackets + assertFalse(LlmResultPathGenerator.isValidJsonPath("$.choices]0[.message")); // Unbalanced brackets + } + + /** + * Test schema with multiple string fields, marker takes precedence. + */ + @Test + public void testGenerate_MarkerTakesPrecedence() throws IOException { + String schema = "{\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"text\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"content\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"actualOutput\": {\n" + + " \"type\": \"string\",\n" + + " \"x-llm-output\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + String result = LlmResultPathGenerator.generate(schema); + + assertNotNull(result); + // Should return actualOutput since it has explicit marker + assertEquals("$.actualOutput", result); + } + + /** + * Test schema with array at root level of dataAsMap. 
+ */ + @Test + public void testGenerate_ArrayAtRootLevel() throws IOException { + String schema = "{\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"message\": {\n" + + " \"type\": \"string\",\n" + + " \"x-llm-output\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + String result = LlmResultPathGenerator.generate(schema); + + assertNotNull(result); + assertEquals("$[0].message", result); + } + + /** + * Test schema without x-llm-output marker returns null. + */ + @Test + public void testGenerate_HeuristicResponseField() throws IOException { + String schema = "{\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"response\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"otherField\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + String result = LlmResultPathGenerator.generate(schema); + + // Without x-llm-output marker, should return null + assertNull(result); + } + + /** + * Test schema with x-llm-output set to false (should be ignored and return null). 
+ */ + @Test + public void testGenerate_MarkerSetToFalse() throws IOException { + String schema = "{\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"wrongField\": {\n" + + " \"type\": \"string\",\n" + + " \"x-llm-output\": false\n" + + " },\n" + + " \"content\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + String result = LlmResultPathGenerator.generate(schema); + + // Without x-llm-output marker set to true, should return null + assertNull(result); + } + + /** + * Test minimal valid schema structure without marker returns null. + */ + @Test + public void testGenerate_MinimalValidSchema() throws IOException { + String schema = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"text\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " }\n" + + "}"; + + String result = LlmResultPathGenerator.generate(schema); + + // Without x-llm-output marker, should return null + assertNull(result); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/utils/ModelInterfaceSchemaTest.java b/common/src/test/java/org/opensearch/ml/common/utils/ModelInterfaceSchemaTest.java new file mode 100644 index 0000000000..2fd03653cd --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/utils/ModelInterfaceSchemaTest.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.utils; + +import static org.junit.Assert.*; + +import java.util.Map; + +import org.junit.Test; +import org.opensearch.ml.common.utils.ModelInterfaceUtils.ModelInterfaceSchema; + +/** + * Unit tests for ModelInterfaceSchema enum. 
+ * Tests validation, schema loading, and backward compatibility. + */ +public class ModelInterfaceSchemaTest { + + @Test + public void testFromString_ValidNames() { + // Test case-insensitive lookup + ModelInterfaceSchema schema1 = ModelInterfaceSchema.fromString("bedrock_ai21_labs_jurassic2_mid_v1"); + assertEquals(ModelInterfaceSchema.BEDROCK_AI21_LABS_JURASSIC2_MID_V1, schema1); + + ModelInterfaceSchema schema2 = ModelInterfaceSchema.fromString("BEDROCK_AI21_LABS_JURASSIC2_MID_V1"); + assertEquals(ModelInterfaceSchema.BEDROCK_AI21_LABS_JURASSIC2_MID_V1, schema2); + + ModelInterfaceSchema schema3 = ModelInterfaceSchema.fromString("bedrock_anthropic_claude_v2"); + assertEquals(ModelInterfaceSchema.BEDROCK_ANTHROPIC_CLAUDE_V2, schema3); + } + + @Test(expected = IllegalArgumentException.class) + public void testFromString_InvalidName() { + ModelInterfaceSchema.fromString("invalid_schema_name"); + } + + @Test(expected = IllegalArgumentException.class) + public void testFromString_NullName() { + ModelInterfaceSchema.fromString(null); + } + + @Test(expected = IllegalArgumentException.class) + public void testFromString_BlankName() { + ModelInterfaceSchema.fromString(" "); + } + + @Test + public void testGetInterface_ReturnsValidMap() { + Map interface1 = ModelInterfaceSchema.BEDROCK_AI21_LABS_JURASSIC2_MID_V1.getInterface(); + + assertNotNull("Interface should not be null", interface1); + assertTrue("Interface should contain 'input' key", interface1.containsKey("input")); + assertTrue("Interface should contain 'output' key", interface1.containsKey("output")); + assertFalse("Input schema should not be empty", interface1.get("input").isEmpty()); + assertFalse("Output schema should not be empty", interface1.get("output").isEmpty()); + } + + @Test + public void testGetInterface_ConsistentResults() { + // Multiple calls should return equivalent maps (file I/O is cached at loadSchemaFromFile level) + Map interface1 = 
ModelInterfaceSchema.BEDROCK_ANTHROPIC_CLAUDE_V2.getInterface(); + Map interface2 = ModelInterfaceSchema.BEDROCK_ANTHROPIC_CLAUDE_V2.getInterface(); + + assertEquals("Results should be equal", interface1, interface2); + assertEquals("Input schemas should match", interface1.get("input"), interface2.get("input")); + assertEquals("Output schemas should match", interface1.get("output"), interface2.get("output")); + } + + @Test + public void testAllSchemas_CanLoadSuccessfully() { + // Verify all enum values can load their schemas without errors + for (ModelInterfaceSchema schema : ModelInterfaceSchema.values()) { + Map schemaInterface = schema.getInterface(); + assertNotNull("Schema should not be null: " + schema.name(), schemaInterface); + assertEquals("Schema should have exactly 2 keys", 2, schemaInterface.size()); + } + } + + @Test + public void testGetName() { + String name = ModelInterfaceSchema.BEDROCK_TITAN_EMBED_TEXT_V1.getName(); + assertEquals("bedrock_titan_embed_text_v1", name); + } + + @Test + public void testEnumValues_AllPresent() { + ModelInterfaceSchema[] values = ModelInterfaceSchema.values(); + + // Verify we have all expected schemas + assertEquals("Should have 16 schema variants", 16, values.length); + + // Verify specific schemas exist + assertNotNull(ModelInterfaceSchema.valueOf("BEDROCK_AI21_LABS_JURASSIC2_MID_V1")); + assertNotNull(ModelInterfaceSchema.valueOf("AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE")); + assertNotNull(ModelInterfaceSchema.valueOf("AMAZON_TEXTRACT_DETECTDOCUMENTTEXT")); + assertNotNull(ModelInterfaceSchema.valueOf("BEDROCK_ANTHROPIC_CLAUDE_USE_SYSTEM_PROMPT")); + assertNotNull(ModelInterfaceSchema.valueOf("OPENAI_CHAT_COMPLETIONS")); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/utils/ModelInterfaceUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/ModelInterfaceUtilsTest.java index 2bab51f2f6..9e912046a9 100644 --- 
a/common/src/test/java/org/opensearch/ml/common/utils/ModelInterfaceUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/ModelInterfaceUtilsTest.java @@ -8,20 +8,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; -import static org.opensearch.ml.common.utils.ModelInterfaceUtils.AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE; -import static org.opensearch.ml.common.utils.ModelInterfaceUtils.AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE; -import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE; -import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_AI21_LABS_JURASSIC2_MID_V1_RAW_MODEL_INTERFACE; -import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE; -import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE; -import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE; -import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_COHERE_EMBED_ENGLISH_V3_RAW_MODEL_INTERFACE; -import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE; -import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_RAW_MODEL_INTERFACE; -import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE; -import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_RAW_MODEL_INTERFACE; -import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE; -import static 
org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_TITAN_EMBED_TEXT_V1_RAW_MODEL_INTERFACE; +import static org.opensearch.ml.common.utils.ModelInterfaceUtils.ModelInterfaceSchema; import static org.opensearch.ml.common.utils.ModelInterfaceUtils.updateRegisterModelInputModelInterfaceFieldsByConnector; import java.util.HashMap; @@ -100,7 +87,10 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_A .actions(List.of(connectorActionWithPostProcessFunction)) .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); - assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE); + assertEquals( + registerModelInputWithStandaloneConnector.getModelInterface(), + ModelInterfaceSchema.BEDROCK_AI21_LABS_JURASSIC2_MID_V1.getInterface() + ); } @Test @@ -115,7 +105,10 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_A .actions(List.of(connectorActionWithoutPostProcessFunction)) .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); - assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_AI21_LABS_JURASSIC2_MID_V1_RAW_MODEL_INTERFACE); + assertEquals( + registerModelInputWithStandaloneConnector.getModelInterface(), + ModelInterfaceSchema.BEDROCK_AI21_LABS_JURASSIC2_MID_V1_RAW.getInterface() + ); } @Test @@ -131,7 +124,10 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_A .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); - assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE); + assertEquals( + registerModelInputWithStandaloneConnector.getModelInterface(), + 
ModelInterfaceSchema.BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET.getInterface() + ); } @Test @@ -147,7 +143,10 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_A .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); - assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE); + assertEquals( + registerModelInputWithStandaloneConnector.getModelInterface(), + ModelInterfaceSchema.BEDROCK_ANTHROPIC_CLAUDE_V2.getInterface() + ); } @Test @@ -173,7 +172,10 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_C .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); - assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE); + assertEquals( + registerModelInputWithStandaloneConnector.getModelInterface(), + ModelInterfaceSchema.BEDROCK_COHERE_EMBED_ENGLISH_V3.getInterface() + ); } @Test @@ -188,7 +190,10 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_C .actions(List.of(connectorActionWithoutPostProcessFunction)) .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); - assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_COHERE_EMBED_ENGLISH_V3_RAW_MODEL_INTERFACE); + assertEquals( + registerModelInputWithStandaloneConnector.getModelInterface(), + ModelInterfaceSchema.BEDROCK_COHERE_EMBED_ENGLISH_V3_RAW.getInterface() + ); } @Test @@ -214,7 +219,10 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_C .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); - assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), 
BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE); + assertEquals( + registerModelInputWithStandaloneConnector.getModelInterface(), + ModelInterfaceSchema.BEDROCK_COHERE_EMBED_MULTILINGUAL_V3.getInterface() + ); } @Test @@ -231,7 +239,7 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_C updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); assertEquals( registerModelInputWithStandaloneConnector.getModelInterface(), - BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_RAW_MODEL_INTERFACE + ModelInterfaceSchema.BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_RAW.getInterface() ); } @@ -258,7 +266,10 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_T .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); - assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE); + assertEquals( + registerModelInputWithStandaloneConnector.getModelInterface(), + ModelInterfaceSchema.BEDROCK_TITAN_EMBED_TEXT_V1.getInterface() + ); } @Test @@ -273,7 +284,10 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_T .actions(List.of(connectorActionWithoutPostProcessFunction)) .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); - assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_TITAN_EMBED_TEXT_V1_RAW_MODEL_INTERFACE); + assertEquals( + registerModelInputWithStandaloneConnector.getModelInterface(), + ModelInterfaceSchema.BEDROCK_TITAN_EMBED_TEXT_V1_RAW.getInterface() + ); } @Test @@ -299,7 +313,10 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_T .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); - 
assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE); + assertEquals( + registerModelInputWithStandaloneConnector.getModelInterface(), + ModelInterfaceSchema.BEDROCK_TITAN_EMBED_MULTI_MODAL_V1.getInterface() + ); } @Test @@ -314,7 +331,222 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_T .actions(List.of(connectorActionWithoutPostProcessFunction)) .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); - assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_RAW_MODEL_INTERFACE); + assertEquals( + registerModelInputWithStandaloneConnector.getModelInterface(), + ModelInterfaceSchema.BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_RAW.getInterface() + ); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_ANTHROPIC_CLAUDE_3_7_SONNET_WITH_SYSTEM_PROMPT() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + parameters.put("model", "us.anthropic.claude-3-7-sonnet-20250219-v1:0"); + parameters.put("use_system_prompt", "true"); + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithPostProcessFunction)) + .build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertEquals( + registerModelInputWithStandaloneConnector.getModelInterface(), + ModelInterfaceSchema.BEDROCK_ANTHROPIC_CLAUDE_USE_SYSTEM_PROMPT.getInterface() + ); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_ANTHROPIC_CLAUDE_3_7_SONNET_WITHOUT_SYSTEM_PROMPT() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + parameters.put("model", "us.anthropic.claude-3-7-sonnet-20250219-v1:0"); + 
parameters.put("use_system_prompt", "false"); + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithPostProcessFunction)) + .build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertNull(registerModelInputWithStandaloneConnector.getModelInterface()); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_ANTHROPIC_CLAUDE_3_7_SONNET_MISSING_SYSTEM_PROMPT() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + parameters.put("model", "us.anthropic.claude-3-7-sonnet-20250219-v1:0"); + // use_system_prompt parameter not set + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithPostProcessFunction)) + .build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertNull(registerModelInputWithStandaloneConnector.getModelInterface()); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_ANTHROPIC_CLAUDE_SONNET_4_WITH_SYSTEM_PROMPT() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + parameters.put("model", "us.anthropic.claude-sonnet-4-20250514-v1:0"); + parameters.put("use_system_prompt", "true"); + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithPostProcessFunction)) + .build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertEquals( + registerModelInputWithStandaloneConnector.getModelInterface(), + ModelInterfaceSchema.BEDROCK_ANTHROPIC_CLAUDE_USE_SYSTEM_PROMPT.getInterface() + ); + } + + @Test + public void 
testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_ANTHROPIC_CLAUDE_SONNET_4_WITHOUT_SYSTEM_PROMPT() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + parameters.put("model", "us.anthropic.claude-sonnet-4-20250514-v1:0"); + parameters.put("use_system_prompt", "false"); + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithPostProcessFunction)) + .build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertNull(registerModelInputWithStandaloneConnector.getModelInterface()); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_ANTHROPIC_CLAUDE_SONNET_4_MISSING_SYSTEM_PROMPT() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + parameters.put("model", "us.anthropic.claude-sonnet-4-20250514-v1:0"); + // use_system_prompt parameter not set + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithPostProcessFunction)) + .build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertNull(registerModelInputWithStandaloneConnector.getModelInterface()); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorOPENAI_GPT_3_5_TURBO() { + Map parameters = new HashMap<>(); + parameters.put("model", "gpt-3.5-turbo"); + + ConnectorAction openaiAction = ConnectorAction + .builder() + .actionType(PREDICT) + .method("POST") + .url("https://api.openai.com/v1/chat/completions") + .requestBody("{\"model\": \"${parameters.model}\", \"messages\": \"${parameters.messages}\"}") + .build(); + + connector = HttpConnector.builder().protocol("http").parameters(parameters).actions(List.of(openaiAction)).build(); + + 
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertEquals( + registerModelInputWithStandaloneConnector.getModelInterface(), + ModelInterfaceSchema.OPENAI_CHAT_COMPLETIONS.getInterface() + ); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorOPENAI_GPT_4O_MINI() { + Map parameters = new HashMap<>(); + parameters.put("model", "gpt-4o-mini"); + + ConnectorAction openaiAction = ConnectorAction + .builder() + .actionType(PREDICT) + .method("POST") + .url("https://api.openai.com/v1/chat/completions") + .requestBody("{\"model\": \"${parameters.model}\", \"messages\": \"${parameters.messages}\"}") + .build(); + + connector = HttpConnector.builder().protocol("http").parameters(parameters).actions(List.of(openaiAction)).build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertEquals( + registerModelInputWithStandaloneConnector.getModelInterface(), + ModelInterfaceSchema.OPENAI_CHAT_COMPLETIONS.getInterface() + ); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorOPENAI_GPT_5() { + Map parameters = new HashMap<>(); + parameters.put("model", "gpt-5"); + + ConnectorAction openaiAction = ConnectorAction + .builder() + .actionType(PREDICT) + .method("POST") + .url("https://api.openai.com/v1/chat/completions") + .requestBody("{\"model\": \"${parameters.model}\", \"messages\": \"${parameters.messages}\"}") + .build(); + + connector = HttpConnector.builder().protocol("http").parameters(parameters).actions(List.of(openaiAction)).build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertEquals( + registerModelInputWithStandaloneConnector.getModelInterface(), + ModelInterfaceSchema.OPENAI_CHAT_COMPLETIONS.getInterface() + ); + } + + @Test + public void 
testUpdateRegisterModelInputModelInterfaceFieldsByConnectorOPENAI_WRONG_ENDPOINT() { + Map parameters = new HashMap<>(); + parameters.put("model", "gpt-3.5-turbo"); + + ConnectorAction openaiAction = ConnectorAction + .builder() + .actionType(PREDICT) + .method("POST") + .url("https://api.openai.com/v1/completions") + .requestBody("{\"model\": \"${parameters.model}\", \"prompt\": \"${parameters.prompt}\"}") + .build(); + + connector = HttpConnector.builder().protocol("http").parameters(parameters).actions(List.of(openaiAction)).build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertNull(registerModelInputWithStandaloneConnector.getModelInterface()); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorOPENAI_WRONG_MODEL() { + Map parameters = new HashMap<>(); + parameters.put("model", "gpt-4"); + + ConnectorAction openaiAction = ConnectorAction + .builder() + .actionType(PREDICT) + .method("POST") + .url("https://api.openai.com/v1/chat/completions") + .requestBody("{\"model\": \"${parameters.model}\", \"messages\": \"${parameters.messages}\"}") + .build(); + + connector = HttpConnector.builder().protocol("http").parameters(parameters).actions(List.of(openaiAction)).build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertNull(registerModelInputWithStandaloneConnector.getModelInterface()); } @Test @@ -332,7 +564,7 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorAMAZON_CO updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); assertEquals( registerModelInputWithStandaloneConnector.getModelInterface(), - AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE + ModelInterfaceSchema.AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE.getInterface() ); } @@ -349,7 +581,10 @@ public void 
testUpdateRegisterModelInputModelInterfaceFieldsByConnectorAMAZON_TE .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); - assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE); + assertEquals( + registerModelInputWithStandaloneConnector.getModelInterface(), + ModelInterfaceSchema.AMAZON_TEXTRACT_DETECTDOCUMENTTEXT.getInterface() + ); } @Test @@ -419,7 +654,10 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorNullParam .build(); registerModelInputWithInnerConnector.setConnector(connector); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithInnerConnector); - assertEquals(registerModelInputWithInnerConnector.getModelInterface(), BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE); + assertEquals( + registerModelInputWithInnerConnector.getModelInterface(), + ModelInterfaceSchema.BEDROCK_AI21_LABS_JURASSIC2_MID_V1.getInterface() + ); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/utils/message/MessageFormatterFactoryTests.java b/common/src/test/java/org/opensearch/ml/common/utils/message/MessageFormatterFactoryTests.java new file mode 100644 index 0000000000..db92631a7a --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/utils/message/MessageFormatterFactoryTests.java @@ -0,0 +1,304 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.utils.message; + +import static org.junit.Assert.*; + +import org.junit.Test; + +/** + * Unit tests for MessageFormatterFactory. + * Tests the factory's ability to select the correct formatter based on input schema. 
+ */ +public class MessageFormatterFactoryTests { + + // Claude schema example - contains "system_prompt" parameter + private static final String CLAUDE_SCHEMA = "{" + + "\"type\":\"object\"," + + "\"properties\":{" + + "\"system_prompt\":{\"type\":\"string\"}," + + "\"messages\":{\"type\":\"array\"}" + + "}" + + "}"; + + // OpenAI schema example - NO "system_prompt" parameter + private static final String OPENAI_SCHEMA = "{" + + "\"type\":\"object\"," + + "\"properties\":{" + + "\"messages\":{\"type\":\"array\"}," + + "\"temperature\":{\"type\":\"number\"}" + + "}" + + "}"; + + @Test + public void testFactoryWithClaudeSchema() { + MessageFormatter formatter = MessageFormatterFactory.getFormatter(CLAUDE_SCHEMA); + + assertNotNull("Formatter should not be null", formatter); + assertTrue("Should return ClaudeMessageFormatter", formatter instanceof ClaudeMessageFormatter); + } + + @Test + public void testFactoryWithOpenAISchema() { + MessageFormatter formatter = MessageFormatterFactory.getFormatter(OPENAI_SCHEMA); + + assertNotNull("Formatter should not be null", formatter); + assertTrue("Should return OpenAIMessageFormatter", formatter instanceof OpenAIMessageFormatter); + } + + @Test + public void testFactoryWithNullSchema() { + MessageFormatter formatter = MessageFormatterFactory.getFormatter(null); + + assertNotNull("Formatter should not be null", formatter); + // Default to Claude formatter + assertTrue("Should default to ClaudeMessageFormatter", formatter instanceof ClaudeMessageFormatter); + } + + @Test + public void testFactoryWithBlankSchema() { + MessageFormatter formatter = MessageFormatterFactory.getFormatter(""); + + assertNotNull("Formatter should not be null", formatter); + // Default to Claude formatter + assertTrue("Should default to ClaudeMessageFormatter", formatter instanceof ClaudeMessageFormatter); + } + + @Test + public void testFactoryWithWhitespaceSchema() { + MessageFormatter formatter = MessageFormatterFactory.getFormatter(" "); + + 
assertNotNull("Formatter should not be null", formatter); + // Default to Claude formatter + assertTrue("Should default to ClaudeMessageFormatter", formatter instanceof ClaudeMessageFormatter); + } + + @Test + public void testFactoryWithMalformedSchema() { + // Invalid JSON - missing closing brace + String malformedSchema = "{\"properties\":{\"system_prompt\":{\"type\":\"string\"}"; + + MessageFormatter formatter = MessageFormatterFactory.getFormatter(malformedSchema); + + assertNotNull("Formatter should not be null", formatter); + // Even with malformed JSON, should return formatter (checks for string presence) + assertTrue("Should return ClaudeMessageFormatter (contains system_prompt)", formatter instanceof ClaudeMessageFormatter); + } + + @Test + public void testFactoryWithPartialClaudeSchema() { + // Schema that contains system_prompt in a string value but not as a field name + String partialSchema = "{\"description\":\"This schema has system_prompt field\"}"; + + MessageFormatter formatter = MessageFormatterFactory.getFormatter(partialSchema); + + assertNotNull("Formatter should not be null", formatter); + // Factory checks for "\"system_prompt\"" (with quotes), not just "system_prompt" + // This schema contains system_prompt only in a description, not as a field name + assertTrue("Should return OpenAIMessageFormatter (no quoted field name)", formatter instanceof OpenAIMessageFormatter); + } + + @Test + public void testFactoryWithComplexClaudeSchema() { + String complexSchema = "{" + + "\"$schema\":\"http://json-schema.org/draft-07/schema#\"," + + "\"type\":\"object\"," + + "\"properties\":{" + + "\"model\":{\"type\":\"string\"}," + + "\"system_prompt\":{\"type\":\"string\",\"description\":\"System prompt\"}," + + "\"messages\":{\"type\":\"array\"}," + + "\"temperature\":{\"type\":\"number\"}" + + "}," + + "\"required\":[\"messages\"]" + + "}"; + + MessageFormatter formatter = MessageFormatterFactory.getFormatter(complexSchema); + + assertNotNull("Formatter 
should not be null", formatter); + assertTrue("Should return ClaudeMessageFormatter", formatter instanceof ClaudeMessageFormatter); + } + + @Test + public void testFactoryWithComplexOpenAISchema() { + String complexSchema = "{" + + "\"$schema\":\"http://json-schema.org/draft-07/schema#\"," + + "\"type\":\"object\"," + + "\"properties\":{" + + "\"model\":{\"type\":\"string\"}," + + "\"messages\":{\"type\":\"array\",\"items\":{\"type\":\"object\"}}," + + "\"temperature\":{\"type\":\"number\"}," + + "\"max_tokens\":{\"type\":\"integer\"}" + + "}," + + "\"required\":[\"model\",\"messages\"]" + + "}"; + + MessageFormatter formatter = MessageFormatterFactory.getFormatter(complexSchema); + + assertNotNull("Formatter should not be null", formatter); + assertTrue("Should return OpenAIMessageFormatter", formatter instanceof OpenAIMessageFormatter); + } + + @Test + public void testFactoryReturnsClaudeFormatter() { + MessageFormatter formatter = MessageFormatterFactory.getClaudeFormatter(); + + assertNotNull("Formatter should not be null", formatter); + assertTrue("Should return ClaudeMessageFormatter", formatter instanceof ClaudeMessageFormatter); + } + + @Test + public void testFactoryReturnsOpenAIFormatter() { + MessageFormatter formatter = MessageFormatterFactory.getOpenAIFormatter(); + + assertNotNull("Formatter should not be null", formatter); + assertTrue("Should return OpenAIMessageFormatter", formatter instanceof OpenAIMessageFormatter); + } + + @Test + public void testFactorySingletonPattern() { + // Get Claude formatter multiple times + MessageFormatter claude1 = MessageFormatterFactory.getClaudeFormatter(); + MessageFormatter claude2 = MessageFormatterFactory.getClaudeFormatter(); + + // Should be the same instance (singleton) + assertSame("Should return same Claude formatter instance", claude1, claude2); + + // Get OpenAI formatter multiple times + MessageFormatter openai1 = MessageFormatterFactory.getOpenAIFormatter(); + MessageFormatter openai2 = 
MessageFormatterFactory.getOpenAIFormatter(); + + // Should be the same instance (singleton) + assertSame("Should return same OpenAI formatter instance", openai1, openai2); + + // Claude and OpenAI should be different instances + assertNotSame("Claude and OpenAI formatters should be different", claude1, openai1); + } + + @Test + public void testFactoryNeverReturnsNull() { + // Test various edge cases to ensure factory never returns null + + String[] testSchemas = { null, "", " ", "invalid json{{{", "{}", "{\"random\":\"schema\"}", CLAUDE_SCHEMA, OPENAI_SCHEMA }; + + for (String schema : testSchemas) { + MessageFormatter formatter = MessageFormatterFactory.getFormatter(schema); + assertNotNull("Factory should never return null for schema: " + schema, formatter); + } + } + + @Test + public void testFactorySchemaDetection_CaseSensitive() { + // Test with different casings - should be case-sensitive + String upperCaseSchema = "{\"SYSTEM_PROMPT\":\"value\"}"; + MessageFormatter formatter1 = MessageFormatterFactory.getFormatter(upperCaseSchema); + // Should NOT match (case sensitive) - defaults to Claude anyway + assertNotNull("Should return formatter", formatter1); + + String mixedCaseSchema = "{\"System_Prompt\":\"value\"}"; + MessageFormatter formatter2 = MessageFormatterFactory.getFormatter(mixedCaseSchema); + // Should NOT match exact string "system_prompt" + assertNotNull("Should return formatter", formatter2); + + String correctCaseSchema = "{\"system_prompt\":\"value\"}"; + MessageFormatter formatter3 = MessageFormatterFactory.getFormatter(correctCaseSchema); + assertTrue("Should match exact case", formatter3 instanceof ClaudeMessageFormatter); + } + + @Test + public void testFactoryWithMinimalClaudeSchema() { + String minimalSchema = "{\"system_prompt\":\"\"}"; + + MessageFormatter formatter = MessageFormatterFactory.getFormatter(minimalSchema); + + assertNotNull("Formatter should not be null", formatter); + assertTrue("Should return ClaudeMessageFormatter", 
formatter instanceof ClaudeMessageFormatter); + } + + @Test + public void testFactoryWithMinimalOpenAISchema() { + String minimalSchema = "{\"messages\":[]}"; + + MessageFormatter formatter = MessageFormatterFactory.getFormatter(minimalSchema); + + assertNotNull("Formatter should not be null", formatter); + assertTrue("Should return OpenAIMessageFormatter", formatter instanceof OpenAIMessageFormatter); + } + + @Test + public void testFactorySchemaWithSystemPromptInDescription() { + // Edge case: "system_prompt" appears in description but not as a field name + String schema = "{" + + "\"properties\":{" + + "\"messages\":{\"type\":\"array\",\"description\":\"Messages array, not system_prompt\"}" + + "}" + + "}"; + + MessageFormatter formatter = MessageFormatterFactory.getFormatter(schema); + + assertNotNull("Formatter should not be null", formatter); + // Factory checks for "\"system_prompt\"" (with quotes as a field name) + // This schema only has system_prompt in a description string, not as a field name + assertTrue("Should return OpenAIMessageFormatter (system_prompt not a field name)", formatter instanceof OpenAIMessageFormatter); + } + + @Test + public void testFactoryRealWorldClaudeSchema() { + // Real-world Claude schema example + String realClaudeSchema = "{" + + "\"type\":\"object\"," + + "\"properties\":{" + + "\"anthropic_version\":{\"type\":\"string\"}," + + "\"max_tokens\":{\"type\":\"integer\"}," + + "\"system_prompt\":{\"type\":\"string\"}," + + "\"messages\":{" + + "\"type\":\"array\"," + + "\"items\":{" + + "\"type\":\"object\"," + + "\"properties\":{" + + "\"role\":{\"type\":\"string\"}," + + "\"content\":{\"type\":\"array\"}" + + "}" + + "}" + + "}" + + "}" + + "}"; + + MessageFormatter formatter = MessageFormatterFactory.getFormatter(realClaudeSchema); + + assertNotNull("Formatter should not be null", formatter); + assertTrue("Should return ClaudeMessageFormatter", formatter instanceof ClaudeMessageFormatter); + } + + @Test + public void 
testFactoryRealWorldOpenAISchema() { + // Real-world OpenAI schema example (Chat Completions API) + String realOpenAISchema = "{" + + "\"type\":\"object\"," + + "\"properties\":{" + + "\"model\":{\"type\":\"string\"}," + + "\"messages\":{" + + "\"type\":\"array\"," + + "\"items\":{" + + "\"type\":\"object\"," + + "\"properties\":{" + + "\"role\":{\"type\":\"string\",\"enum\":[\"system\",\"user\",\"assistant\"]}," + + "\"content\":{\"type\":\"string\"}" + + "}," + + "\"required\":[\"role\",\"content\"]" + + "}" + + "}," + + "\"temperature\":{\"type\":\"number\",\"minimum\":0,\"maximum\":2}," + + "\"max_tokens\":{\"type\":\"integer\"}" + + "}," + + "\"required\":[\"model\",\"messages\"]" + + "}"; + + MessageFormatter formatter = MessageFormatterFactory.getFormatter(realOpenAISchema); + + assertNotNull("Formatter should not be null", formatter); + assertTrue("Should return OpenAIMessageFormatter", formatter instanceof OpenAIMessageFormatter); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/utils/message/MessageFormatterTests.java b/common/src/test/java/org/opensearch/ml/common/utils/message/MessageFormatterTests.java new file mode 100644 index 0000000000..8bbec171da --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/utils/message/MessageFormatterTests.java @@ -0,0 +1,445 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.utils.message; + +import static org.junit.Assert.*; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.json.JSONArray; +import org.json.JSONObject; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.ml.common.transport.memorycontainer.memory.MessageInput; + +/** + * Comprehensive unit tests for MessageFormatter implementations. + * Tests both ClaudeMessageFormatter and OpenAIMessageFormatter. 
+ */ +public class MessageFormatterTests { + + private ClaudeMessageFormatter claudeFormatter; + private OpenAIMessageFormatter openaiFormatter; + + @Before + public void setUp() { + claudeFormatter = new ClaudeMessageFormatter(); + openaiFormatter = new OpenAIMessageFormatter(); + } + + // ============================================================ + // ClaudeMessageFormatter Tests + // ============================================================ + + @Test + public void testClaudeFormatter_SystemPromptInParameters() { + String systemPrompt = "You are a helpful assistant."; + List messages = createSimpleUserMessage("Hello"); + + Map result = claudeFormatter.formatRequest(systemPrompt, messages, null); + + // Verify system_prompt parameter exists + assertTrue("Should contain system_prompt parameter", result.containsKey("system_prompt")); + assertEquals("System prompt should match", systemPrompt, result.get("system_prompt")); + + // Verify messages parameter exists + assertTrue("Should contain messages parameter", result.containsKey("messages")); + } + + @Test + public void testClaudeFormatter_MessagesArrayStructure() { + String systemPrompt = "You are a helpful assistant."; + List messages = createSimpleUserMessage("Hello"); + + Map result = claudeFormatter.formatRequest(systemPrompt, messages, null); + + // Parse messages JSON + String messagesJson = result.get("messages"); + assertNotNull("Messages should not be null", messagesJson); + + JSONArray messagesArray = new JSONArray(messagesJson); + assertEquals("Should have 1 message", 1, messagesArray.length()); + + JSONObject firstMessage = messagesArray.getJSONObject(0); + assertEquals("Role should be user", "user", firstMessage.getString("role")); + + // Verify NO system role in messages array + for (int i = 0; i < messagesArray.length(); i++) { + JSONObject msg = messagesArray.getJSONObject(i); + assertNotEquals("Should NOT have system role in messages", "system", msg.getString("role")); + } + } + + @Test + 
public void testClaudeFormatter_BlankSystemPrompt() { + List messages = createSimpleUserMessage("Hello"); + + Map result = claudeFormatter.formatRequest("", messages, null); + + // Should NOT contain system_prompt if blank + assertFalse("Should not contain empty system_prompt", result.containsKey("system_prompt")); + + // Should still have messages + assertTrue("Should contain messages", result.containsKey("messages")); + } + + @Test + public void testClaudeFormatter_NullSystemPrompt() { + List messages = createSimpleUserMessage("Hello"); + + Map result = claudeFormatter.formatRequest(null, messages, null); + + // Should NOT contain system_prompt if null + assertFalse("Should not contain null system_prompt", result.containsKey("system_prompt")); + } + + @Test + public void testClaudeFormatter_ContentNormalization_StandardFormat() { + // Content object WITH "type" field - should remain unchanged + Map contentWithType = new HashMap<>(); + contentWithType.put("type", "text"); + contentWithType.put("text", "Hello world"); + + List> content = new ArrayList<>(); + content.add(contentWithType); + + List> processed = claudeFormatter.processContent(content); + + assertEquals("Should have 1 content object", 1, processed.size()); + assertTrue("Should contain type field", processed.get(0).containsKey("type")); + assertEquals("Type should be text", "text", processed.get(0).get("type")); + assertEquals("Text should match", "Hello world", processed.get(0).get("text")); + } + + @Test + public void testClaudeFormatter_ContentNormalization_UserDefinedObject() { + // Content object WITHOUT "type" field - should be wrapped + Map contentWithoutType = new HashMap<>(); + contentWithoutType.put("custom_field", "custom_value"); + contentWithoutType.put("another_field", 123); + + List> content = new ArrayList<>(); + content.add(contentWithoutType); + + List> processed = claudeFormatter.processContent(content); + + assertEquals("Should have 1 wrapped content object", 1, processed.size()); + 
assertTrue("Should contain type field", processed.get(0).containsKey("type")); + assertEquals("Type should be text", "text", processed.get(0).get("type")); + assertTrue("Should contain text field", processed.get(0).containsKey("text")); + + // The "text" field should contain JSON string of original object + String jsonText = (String) processed.get(0).get("text"); + assertNotNull("Text should not be null", jsonText); + assertTrue("Should contain custom_field", jsonText.contains("custom_field")); + } + + @Test + public void testClaudeFormatter_ContentNormalization_MixedFormat() { + List> content = new ArrayList<>(); + + // Standard format object + Map standard = new HashMap<>(); + standard.put("type", "text"); + standard.put("text", "Hello"); + content.add(standard); + + // User-defined object + Map custom = new HashMap<>(); + custom.put("custom_key", "custom_value"); + content.add(custom); + + List> processed = claudeFormatter.processContent(content); + + assertEquals("Should have 2 content objects", 2, processed.size()); + + // First should be unchanged + assertEquals("First should be unchanged", standard, processed.get(0)); + + // Second should be wrapped + assertTrue("Second should have type field", processed.get(1).containsKey("type")); + assertEquals("Second type should be text", "text", processed.get(1).get("type")); + } + + @Test + public void testClaudeFormatter_EmptyMessagesList() { + String systemPrompt = "You are helpful."; + List emptyMessages = new ArrayList<>(); + + Map result = claudeFormatter.formatRequest(systemPrompt, emptyMessages, null); + + assertTrue("Should contain system_prompt", result.containsKey("system_prompt")); + assertTrue("Should contain messages", result.containsKey("messages")); + + JSONArray messagesArray = new JSONArray(result.get("messages")); + assertEquals("Messages array should be empty", 0, messagesArray.length()); + } + + @Test + public void testClaudeFormatter_MultipleMessages() { + String systemPrompt = "You are helpful."; 
+ List messages = new ArrayList<>(); + messages.add(createMessageInput("user", "Hello")); + messages.add(createMessageInput("assistant", "Hi there!")); + messages.add(createMessageInput("user", "How are you?")); + + Map result = claudeFormatter.formatRequest(systemPrompt, messages, null); + + JSONArray messagesArray = new JSONArray(result.get("messages")); + assertEquals("Should have 3 messages", 3, messagesArray.length()); + + assertEquals("First role should be user", "user", messagesArray.getJSONObject(0).getString("role")); + assertEquals("Second role should be assistant", "assistant", messagesArray.getJSONObject(1).getString("role")); + assertEquals("Third role should be user", "user", messagesArray.getJSONObject(2).getString("role")); + } + + // ============================================================ + // OpenAIMessageFormatter Tests + // ============================================================ + + @Test + public void testOpenAIFormatter_SystemPromptAsMessage() { + String systemPrompt = "You are a helpful assistant."; + List messages = createSimpleUserMessage("Hello"); + + Map result = openaiFormatter.formatRequest(systemPrompt, messages, null); + + // Parse messages JSON + String messagesJson = result.get("messages"); + assertNotNull("Messages should not be null", messagesJson); + + JSONArray messagesArray = new JSONArray(messagesJson); + + // Should have 2 messages: system + user + assertEquals("Should have 2 messages (system + user)", 2, messagesArray.length()); + + // First message should be system + JSONObject firstMessage = messagesArray.getJSONObject(0); + assertEquals("First message should have system role", "system", firstMessage.getString("role")); + + // Verify system message content + JSONArray systemContent = firstMessage.getJSONArray("content"); + assertEquals("System content should have 1 item", 1, systemContent.length()); + JSONObject systemContentObj = systemContent.getJSONObject(0); + assertEquals("System content type should be 
text", "text", systemContentObj.getString("type")); + assertEquals("System content text should match", systemPrompt, systemContentObj.getString("text")); + + // Second message should be user + JSONObject secondMessage = messagesArray.getJSONObject(1); + assertEquals("Second message should have user role", "user", secondMessage.getString("role")); + } + + @Test + public void testOpenAIFormatter_NoSystemPromptParameter() { + String systemPrompt = "You are helpful."; + List messages = createSimpleUserMessage("Hello"); + + Map result = openaiFormatter.formatRequest(systemPrompt, messages, null); + + // Should NOT have system_prompt parameter + assertFalse("Should NOT contain system_prompt parameter", result.containsKey("system_prompt")); + + // Should only have messages parameter + assertTrue("Should contain messages parameter", result.containsKey("messages")); + assertEquals("Should only have 1 parameter", 1, result.size()); + } + + @Test + public void testOpenAIFormatter_BlankSystemPrompt() { + List messages = createSimpleUserMessage("Hello"); + + Map result = openaiFormatter.formatRequest("", messages, null); + + // Parse messages + JSONArray messagesArray = new JSONArray(result.get("messages")); + + // Should only have user message, NO system message + assertEquals("Should only have 1 message (user)", 1, messagesArray.length()); + assertEquals("First message should be user", "user", messagesArray.getJSONObject(0).getString("role")); + } + + @Test + public void testOpenAIFormatter_NullSystemPrompt() { + List messages = createSimpleUserMessage("Hello"); + + Map result = openaiFormatter.formatRequest(null, messages, null); + + // Parse messages + JSONArray messagesArray = new JSONArray(result.get("messages")); + + // Should only have user message + assertEquals("Should only have 1 message (user)", 1, messagesArray.length()); + assertEquals("First message should be user", "user", messagesArray.getJSONObject(0).getString("role")); + } + + @Test + public void 
testOpenAIFormatter_ContentNormalization_StandardFormat() { + // Same logic as Claude formatter + Map contentWithType = new HashMap<>(); + contentWithType.put("type", "text"); + contentWithType.put("text", "Hello world"); + + List> content = new ArrayList<>(); + content.add(contentWithType); + + List> processed = openaiFormatter.processContent(content); + + assertEquals("Should have 1 content object", 1, processed.size()); + assertEquals("Should be unchanged", contentWithType, processed.get(0)); + } + + @Test + public void testOpenAIFormatter_ContentNormalization_UserDefinedObject() { + // Same logic as Claude formatter + Map contentWithoutType = new HashMap<>(); + contentWithoutType.put("custom_field", "custom_value"); + + List> content = new ArrayList<>(); + content.add(contentWithoutType); + + List> processed = openaiFormatter.processContent(content); + + assertEquals("Should have 1 wrapped object", 1, processed.size()); + assertTrue("Should have type field", processed.get(0).containsKey("type")); + assertEquals("Type should be text", "text", processed.get(0).get("type")); + assertTrue("Should have text field", processed.get(0).containsKey("text")); + } + + @Test + public void testOpenAIFormatter_MultipleMessages() { + String systemPrompt = "You are helpful."; + List messages = new ArrayList<>(); + messages.add(createMessageInput("user", "Hello")); + messages.add(createMessageInput("assistant", "Hi!")); + + Map result = openaiFormatter.formatRequest(systemPrompt, messages, null); + + JSONArray messagesArray = new JSONArray(result.get("messages")); + + // Should have system + user + assistant = 3 messages + assertEquals("Should have 3 messages", 3, messagesArray.length()); + assertEquals("First should be system", "system", messagesArray.getJSONObject(0).getString("role")); + assertEquals("Second should be user", "user", messagesArray.getJSONObject(1).getString("role")); + assertEquals("Third should be assistant", "assistant", 
messagesArray.getJSONObject(2).getString("role")); + } + + @Test + public void testOpenAIFormatter_EmptyMessagesList() { + String systemPrompt = "You are helpful."; + List emptyMessages = new ArrayList<>(); + + Map result = openaiFormatter.formatRequest(systemPrompt, emptyMessages, null); + + JSONArray messagesArray = new JSONArray(result.get("messages")); + + // Should only have system message + assertEquals("Should have 1 message (system)", 1, messagesArray.length()); + assertEquals("Should be system message", "system", messagesArray.getJSONObject(0).getString("role")); + } + + // ============================================================ + // Edge Cases and Error Handling + // ============================================================ + + @Test + public void testClaudeFormatter_NullContent() { + // Test with null content list - implementation returns null for null input + List> processed = claudeFormatter.processContent(null); + assertNull("Should return null for null input", processed); + } + + @Test + public void testOpenAIFormatter_NullContent() { + // Test with null content list - implementation returns null for null input + List> processed = openaiFormatter.processContent(null); + assertNull("Should return null for null input", processed); + } + + @Test + public void testClaudeFormatter_EmptyContent() { + List> emptyContent = new ArrayList<>(); + List> processed = claudeFormatter.processContent(emptyContent); + + assertNotNull("Should return non-null result", processed); + assertEquals("Should return empty list", 0, processed.size()); + } + + @Test + public void testClaudeFormatter_WithAdditionalConfig() { + String systemPrompt = "You are helpful."; + List messages = createSimpleUserMessage("Hello"); + Map additionalConfig = new HashMap<>(); + additionalConfig.put("temperature", 0.7); + additionalConfig.put("max_tokens", 100); + + Map result = claudeFormatter.formatRequest(systemPrompt, messages, additionalConfig); + + // Should still have basic 
parameters + assertTrue("Should contain system_prompt", result.containsKey("system_prompt")); + assertTrue("Should contain messages", result.containsKey("messages")); + } + + @Test + public void testOpenAIFormatter_WithAdditionalConfig() { + String systemPrompt = "You are helpful."; + List messages = createSimpleUserMessage("Hello"); + Map additionalConfig = new HashMap<>(); + additionalConfig.put("temperature", 0.7); + + Map result = openaiFormatter.formatRequest(systemPrompt, messages, additionalConfig); + + // Should have messages parameter + assertTrue("Should contain messages", result.containsKey("messages")); + assertFalse("Should NOT contain system_prompt", result.containsKey("system_prompt")); + } + + @Test + public void testClaudeFormatter_ComplexContent() { + // Test with complex nested content + Map complexContent = new HashMap<>(); + Map nestedMap = new HashMap<>(); + nestedMap.put("nested_key", "nested_value"); + complexContent.put("data", nestedMap); + complexContent.put("array", List.of("item1", "item2")); + + List> content = new ArrayList<>(); + content.add(complexContent); + + List> processed = claudeFormatter.processContent(content); + + assertEquals("Should have 1 wrapped object", 1, processed.size()); + assertTrue("Should have type field", processed.get(0).containsKey("type")); + assertEquals("Type should be text", "text", processed.get(0).get("type")); + assertTrue("Should have text field", processed.get(0).containsKey("text")); + + // Verify the wrapped JSON contains the original data + String wrappedJson = (String) processed.get(0).get("text"); + assertTrue("Should contain nested data", wrappedJson.contains("nested_key") || wrappedJson.contains("data")); + } + + // ============================================================ + // Helper Methods + // ============================================================ + + private List createSimpleUserMessage(String text) { + List messages = new ArrayList<>(); + messages.add(createMessageInput("user", 
text)); + return messages; + } + + private MessageInput createMessageInput(String role, String text) { + List> content = new ArrayList<>(); + Map textContent = new HashMap<>(); + textContent.put("type", "text"); + textContent.put("text", text); + content.add(textContent); + + return MessageInput.builder().role(role).content(content).build(); + } +} diff --git a/docs/remote_inference_blueprints/bedrock_connector_anthropic_claude3.7_use_system_prompt_blueprint.md b/docs/remote_inference_blueprints/bedrock_connector_anthropic_claude3.7_use_system_prompt_blueprint.md new file mode 100644 index 0000000000..03eed8e4cc --- /dev/null +++ b/docs/remote_inference_blueprints/bedrock_connector_anthropic_claude3.7_use_system_prompt_blueprint.md @@ -0,0 +1,355 @@ +# Bedrock connector blueprint example for Claude 3.7 model + +Anthropic's Claude 3.7 Sonnet model is now available on Amazon Bedrock. For more details, check out this [blog](https://aws.amazon.com/blogs/aws/anthropics-claude-3-7-sonnet-the-first-hybrid-reasoning-model-is-now-available-in-amazon-bedrock/). + +Claude 3.7 is Anthropic's first hybrid reasoning model, supporting two modes: standard and extended thinking. This doc covers both modes. + +## 1. Add connector endpoint to trusted URLs: + +Note: This step is only necessary for OpenSearch versions prior to 2.11.0. + +```json +PUT /_cluster/settings +{ + "persistent": { + "plugins.ml_commons.trusted_connector_endpoints_regex": [ + "^https://bedrock-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$" + ] + } +} +``` + +## 2. Standard mode +### 2.1 Create connector + +If you are using self-managed Opensearch, you should supply AWS credentials: + +Note: +1. User needs to use [inference profile](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html) for invocation of this model. We can see the profile ID for Claude 3.7 is `us.anthropic.claude-3-7-sonnet-20250219-v1:0` for three available US regions `us-east-1`, `us-east-2`, `us-west-2`. 
+ +```json +POST /_plugins/_ml/connectors/_create +{ + "name": "Amazon Bedrock claude v3.7", + "description": "Test connector for Amazon Bedrock claude v3.7", + "version": 1, + "protocol": "aws_sigv4", + "credential": { + "access_key": "", + "secret_key": "", + "session_token": "" + }, + "parameters": { + "region": "", + "service_name": "bedrock", + "max_tokens": 8000, + "temperature": 1, + "anthropic_version": "bedrock-2023-05-31", + "model": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", + "use_system_prompt": true + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "headers": { + "content-type": "application/json" + }, + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke", + "request_body": "{ \"system\": \"${parameters.system_prompt}\", \"anthropic_version\": \"${parameters.anthropic_version}\", \"max_tokens\": ${parameters.max_tokens}, \"temperature\": ${parameters.temperature}, \"messages\": ${parameters.messages} }" + } + ] +} +``` + +If using the AWS Opensearch Service, you can provide an IAM role arn that allows access to the bedrock service. 
+Refer to this [AWS doc](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/ml-amazon-connector.html) + +```json +POST /_plugins/_ml/connectors/_create +{ + "name": "Amazon Bedrock claude v3.7", + "description": "Test connector for Amazon Bedrock claude v3.7", + "version": 1, + "protocol": "aws_sigv4", + "credential": { + "roleArn": "" + }, + "parameters": { + "region": "", + "service_name": "bedrock", + "max_tokens": 8000, + "temperature": 1, + "anthropic_version": "bedrock-2023-05-31", + "model": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", + "use_system_prompt": true + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "headers": { + "content-type": "application/json" + }, + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke", + "request_body": "{ \"system\": \"${parameters.system_prompt}\", \"anthropic_version\": \"${parameters.anthropic_version}\", \"max_tokens\": ${parameters.max_tokens}, \"temperature\": ${parameters.temperature}, \"messages\": ${parameters.messages} }" + } + ] +} +``` + +Sample response: +```json +{ + "connector_id": "fa5tP5UBX2k07okSp89B" +} +``` + +### 2.2 Register model + +```json +POST /_plugins/_ml/models/_register?deploy=true +{ + "name": "anthropic.claude-v3.7", + "function_name": "remote", + "description": "claude v3.7 model", + "connector_id": "fa5tP5UBX2k07okSp89B" +} +``` + +Sample response: +```json +{ + "task_id": "fq5uP5UBX2k07okSFM__", + "status": "CREATED", + "model_id": "f65uP5UBX2k07okSFc8P" +} +``` + +### 2.3 Test model inference + +```json +POST /_plugins/_ml/models/f65uP5UBX2k07okSFc8P/_predict +{ + "parameters": { + "system_prompt": "You are a helpful assistant.", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "hello world" + } + ] + } + ] + } +} +``` + +Sample response: +```json +{ + "inference_results": [ + { + "output": [ + { + "name": "response", + "dataAsMap": { + "id": 
"msg_bdrk_012spGFGr4CcD1PWb2TSfYut", + "type": "message", + "role": "assistant", + "model": "claude-3-7-sonnet-20250219", + "content": [ + { + "type": "text", + "text": "Hello! It's nice to meet you. How can I help you today?" + } + ], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 9.0, + "cache_creation_input_tokens": 0.0, + "cache_read_input_tokens": 0.0, + "output_tokens": 19.0 + } + } + } + ], + "status_code": 200 + } + ] +} +``` + +## 3. Extended thinking mode + +Extended thinking mode allows Claude 3.7 to perform more in-depth reasoning before providing a response. + +### 3.1 Create connector + +```json +POST /_plugins/_ml/connectors/_create +{ + "name": "Amazon Bedrock claude v3.7", + "description": "Test connector for Amazon Bedrock claude v3.7", + "version": 1, + "protocol": "aws_sigv4", + "credential": { + "access_key": "", + "secret_key": "", + "session_token": "" + }, + "parameters": { + "region": "", + "service_name": "bedrock", + "max_tokens": 8000, + "temperature": 1, + "anthropic_version": "bedrock-2023-05-31", + "model": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", + "budget_tokens": 1024, + "use_system_prompt": true + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "headers": { + "content-type": "application/json" + }, + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke", + "request_body": "{ \"system\": \"${parameters.system_prompt}\", \"anthropic_version\": \"${parameters.anthropic_version}\", \"max_tokens\": ${parameters.max_tokens}, \"temperature\": ${parameters.temperature}, \"messages\": ${parameters.messages}, \"thinking\": {\"type\": \"enabled\", \"budget_tokens\": ${parameters.budget_tokens} } }" + } + ] +} +``` + +If using the AWS Opensearch Service, you can provide an IAM role arn that allows access to the bedrock service. 
+Refer to this [AWS doc](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/ml-amazon-connector.html) + +```json +POST /_plugins/_ml/connectors/_create +{ + "name": "Amazon Bedrock claude v3.7", + "description": "Test connector for Amazon Bedrock claude v3.7", + "version": 1, + "protocol": "aws_sigv4", + "credential": { + "roleArn": "" + }, + "parameters": { + "region": "", + "service_name": "bedrock", + "max_tokens": 8000, + "temperature": 1, + "anthropic_version": "bedrock-2023-05-31", + "model": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", + "budget_tokens": 1024, + "use_system_prompt": true + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "headers": { + "content-type": "application/json" + }, + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke", + "request_body": "{ \"system\": \"${parameters.system_prompt}\", \"anthropic_version\": \"${parameters.anthropic_version}\", \"max_tokens\": ${parameters.max_tokens}, \"temperature\": ${parameters.temperature}, \"messages\": ${parameters.messages}, \"thinking\": {\"type\": \"enabled\", \"budget_tokens\": ${parameters.budget_tokens} } }" + } + ] +} +``` + +Sample response: +```json +{ + "connector_id": "1652P5UBX2k07okSys_J" +} +``` + +### 3.2 Register model: + +```json +POST /_plugins/_ml/models/_register?deploy=true +{ + "name": "anthropic.claude-v3.7", + "function_name": "remote", + "description": "claude v3.7 model", + "connector_id": "1652P5UBX2k07okSys_J" +} +``` + +Sample response: +```json +{ + "task_id": "5K53P5UBX2k07okSXc-7", + "status": "CREATED", + "model_id": "5a53P5UBX2k07okSXc_M" +} +``` + +### 3.3 Test model inference + +```json +POST /_plugins/_ml/models/5a53P5UBX2k07okSXc_M/_predict +{ + "parameters": { + "system_prompt": "You are a helpful assistant.", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "hello world" + } + ] + } + ] + } +} +``` + +Sample response: +```json 
+{ + "inference_results": [ + { + "output": [ + { + "name": "response", + "dataAsMap": { + "id": "msg_bdrk_01TqgZsyqsxhNGAGVjRjCP6N", + "type": "message", + "role": "assistant", + "model": "claude-3-7-sonnet-20250219", + "content": [ + { + "type": "thinking", + "thinking": "This is a simple greeting phrase \"hello world\" which is often the first program someone writes when learning a new programming language. The person could be:\n1. Simply greeting me casually\n2. Making a reference to programming\n3. Testing if I'm working\n\nI'll respond with a friendly greeting that acknowledges the \"hello world\" phrase and its connection to programming culture, while being conversational.", + "signature": "" + }, + { + "type": "text", + "text": "Hello! It's nice to meet you. \"Hello world\" is such a classic phrase - it's often the first program many people write when learning to code! How are you doing today? Is there something I can help you with?" + } + ], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 37.0, + "cache_creation_input_tokens": 0.0, + "cache_read_input_tokens": 0.0, + "output_tokens": 143.0 + } + } + } + ], + "status_code": 200 + } + ] +} +``` diff --git a/docs/remote_inference_blueprints/bedrock_connector_anthropic_claude4_use_system_prompt_blueprint.md b/docs/remote_inference_blueprints/bedrock_connector_anthropic_claude4_use_system_prompt_blueprint.md new file mode 100644 index 0000000000..163b6cefac --- /dev/null +++ b/docs/remote_inference_blueprints/bedrock_connector_anthropic_claude4_use_system_prompt_blueprint.md @@ -0,0 +1,346 @@ +# Bedrock connector blueprint example for Claude 4 models + +Anthropic's Claude 4 models are now available on Amazon Bedrock. For more details, check out this [blog](https://www.aboutamazon.com/news/aws/anthropic-claude-4-opus-sonnet-amazon-bedrock). 
+ +Similar to Claude 3.7 Sonnet, Claude 4 offers both standard mode and [extended thinking mode](https://www.anthropic.com/news/visible-extended-thinking). Extended thinking mode directs the model to think more deeply about trickier questions by creating `thinking` content blocks for its internal reasoning. This also provides transparency into Claude's thought process before it delivers a final answer. + +This blueprint will cover both the standard mode and the extended thinking mode. + +## 1. Add connector endpoint to trusted URLs: + +Note: This step is only necessary for OpenSearch versions prior to 2.11.0. + +```json +PUT /_cluster/settings +{ + "persistent": { + "plugins.ml_commons.trusted_connector_endpoints_regex": [ + "^https://bedrock-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$" + ] + } +} +``` + +## 2. Standard mode + +If you would like to use the extended thinking mode, skip to [section 3](#section3). +### 2.1 Create connector + +If you are using self-managed Opensearch, you should supply AWS credentials: + +Note: +1. Users need to use an [inference profile](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html) to invoke this model. The profile IDs for Claude Sonnet 4 and Claude Opus 4 are `us.anthropic.claude-sonnet-4-20250514-v1:0` and `us.anthropic.claude-opus-4-20250514-v1:0` respectively, for three available US regions `us-east-1`, `us-east-2` and `us-west-2`. 
+ +```json +POST /_plugins/_ml/connectors/_create +{ + "name": "Amazon Bedrock claude v4", + "description": "Test connector for Amazon Bedrock claude v4", + "version": 1, + "protocol": "aws_sigv4", + "credential": { + "access_key": "", + "secret_key": "", + "session_token": "" + }, + "parameters": { + "region": "", + "service_name": "bedrock", + "max_tokens": 8000, + "temperature": 1, + "anthropic_version": "bedrock-2023-05-31", + "model": "us.anthropic.claude-sonnet-4-20250514-v1:0", + "use_system_prompt": true + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "headers": { + "content-type": "application/json" + }, + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke", + "request_body": "{ \"system\": \"${parameters.system_prompt}\", \"anthropic_version\": \"${parameters.anthropic_version}\", \"max_tokens\": ${parameters.max_tokens}, \"temperature\": ${parameters.temperature}, \"messages\": ${parameters.messages} }" + } + ] +} +``` + +If using AWS Opensearch Service, you can provide an IAM role arn that allows access to the bedrock service. 
+Refer to this [AWS doc](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/ml-amazon-connector.html) + +```json +POST /_plugins/_ml/connectors/_create +{ + "name": "Amazon Bedrock claude v4", + "description": "Test connector for Amazon Bedrock claude v4", + "version": 1, + "protocol": "aws_sigv4", + "credential": { + "roleArn": "" + }, + "parameters": { + "region": "", + "service_name": "bedrock", + "max_tokens": 8000, + "temperature": 1, + "anthropic_version": "bedrock-2023-05-31", + "model": "us.anthropic.claude-sonnet-4-20250514-v1:0", + "use_system_prompt": true + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "headers": { + "content-type": "application/json" + }, + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke", + "request_body": "{ \"system\": \"${parameters.system_prompt}\", \"anthropic_version\": \"${parameters.anthropic_version}\", \"max_tokens\": ${parameters.max_tokens}, \"temperature\": ${parameters.temperature}, \"messages\": ${parameters.messages} }" + } + ] +} +``` + +Sample response: +```json +{ + "connector_id":"5kxp_5YBIvu8EdWRQuez" +} +``` + +### 2.2 Register model + +```json +POST /_plugins/_ml/models/_register?deploy=true +{ + "name": "anthropic.claude-v4", + "function_name": "remote", + "description": "claude v4 model", + "connector_id": "5kxp_5YBIvu8EdWRQuez" +} +``` + +Sample response: +```json +{ + "task_id":"6kxq_5YBIvu8EdWRwedJ", + "status":"CREATED", + "model_id":"7Exq_5YBIvu8EdWRwefI" +} +``` + +### 2.3 Test model inference + +```json +POST /_plugins/_ml/models/7Exq_5YBIvu8EdWRwefI/_predict +{ + "parameters": { + "system_prompt": "You are a helpful assistant.", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "hello world" + } + ] + } + ] + } +} +``` + +Sample response: +```json +{ + "inference_results": [{ + "output": [{ + "name": "response", + "dataAsMap": { + "id": "msg_bdrk_017wv2bnUmKroe7C48MHdu32", 
+ "type": "message", + "role": "assistant", + "model": "claude-sonnet-4-20250514", + "content": [{ + "type": "text", + "text": "Hello! Nice to meet you. How are you doing today? Is there anything I can help you with?" + }], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 9.0, + "cache_creation_input_tokens": 0.0, + "cache_read_input_tokens": 0.0, + "output_tokens": 25.0 + } + } + }], + "status_code": 200 + }] +} +``` + +## 3. Extended thinking mode + +Extended thinking mode allows Claude 4 to perform more in-depth reasoning before providing a response. Note that `budget_tokens` can be specified in parameters, which determines the number of tokens Claude can use for its internal reasoning process. See Claude [documentation](https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#how-to-use-extended-thinking) for more details. + +### 3.1 Create connector + +```json +POST /_plugins/_ml/connectors/_create +{ + "name": "Amazon Bedrock claude v4", + "description": "Test connector for Amazon Bedrock claude v4", + "version": 1, + "protocol": "aws_sigv4", + "credential": { + "access_key": "", + "secret_key": "", + "session_token": "" + }, + "parameters": { + "region": "", + "service_name": "bedrock", + "max_tokens": 8000, + "temperature": 1, + "anthropic_version": "bedrock-2023-05-31", + "model": "us.anthropic.claude-sonnet-4-20250514-v1:0", + "budget_tokens": 1024, + "use_system_prompt": true + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "headers": { + "content-type": "application/json" + }, + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke", + "request_body": "{ \"system\": \"${parameters.system_prompt}\", \"anthropic_version\": \"${parameters.anthropic_version}\", \"max_tokens\": ${parameters.max_tokens}, \"temperature\": ${parameters.temperature}, \"messages\": ${parameters.messages}, \"thinking\": {\"type\": \"enabled\", \"budget_tokens\": 
${parameters.budget_tokens} } }" + } + ] +} +``` + +If using the AWS Opensearch Service, you can provide an IAM role arn that allows access to the bedrock service. +Refer to this [AWS doc](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/ml-amazon-connector.html) + +```json +POST /_plugins/_ml/connectors/_create +{ + "name": "Amazon Bedrock claude v4", + "description": "Test connector for Amazon Bedrock claude v4", + "version": 1, + "protocol": "aws_sigv4", + "credential": { + "roleArn": "" + }, + "parameters": { + "region": "", + "service_name": "bedrock", + "max_tokens": 8000, + "temperature": 1, + "anthropic_version": "bedrock-2023-05-31", + "model": "us.anthropic.claude-sonnet-4-20250514-v1:0", + "budget_tokens": 1024, + "use_system_prompt": true + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "headers": { + "content-type": "application/json" + }, + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke", + "request_body": "{ \"system\": \"${parameters.system_prompt}\", \"anthropic_version\": \"${parameters.anthropic_version}\", \"max_tokens\": ${parameters.max_tokens}, \"temperature\": ${parameters.temperature}, \"messages\": ${parameters.messages}, \"thinking\": {\"type\": \"enabled\", \"budget_tokens\": ${parameters.budget_tokens} } }" + } + ] +} +``` + +Sample response: +```json +{ + "connector_id":"DEx5_5YBIvu8EdWRTOiq" +} +``` + +### 3.2 Register model: + +```json +POST /_plugins/_ml/models/_register?deploy=true +{ + "name": "anthropic.claude-v4", + "function_name": "remote", + "description": "claude v4 model with extended thinking", + "connector_id": "DEx5_5YBIvu8EdWRTOiq" +} +``` + +Sample response: +```json +{ + "task_id":"DUx6_5YBIvu8EdWRLuj1", + "status":"CREATED", + "model_id":"Dkx6_5YBIvu8EdWRL-gO" +} +``` + +### 3.3 Test model inference + +```json +POST /_plugins/_ml/models/Dkx6_5YBIvu8EdWRL-gO/_predict +{ + "parameters": { + "system_prompt": "You are a 
helpful assistant.", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "hello world" + } + ] + } + ] + } +} +``` + +Sample response: +```json +{ + "inference_results": [{ + "output": [{ + "name": "response", + "dataAsMap": { + "id": "msg_bdrk_0117MNj2HVP7dXmeDGeCSQaL", + "type": "message", + "role": "assistant", + "model": "claude-sonnet-4-20250514", + "content": [{ + "type": "thinking", + "thinking": "The user has sent me a simple \"hello world\" message. This is a classic, friendly greeting that's often used as a first program or test message in programming and casual conversation. I should respond in a warm, welcoming way.", + "signature": "" + }, { + "type": "text", + "text": "Hello! It's nice to meet you. How are you doing today? Is there anything I can help you with?" + }], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 37.0, + "cache_creation_input_tokens": 0.0, + "cache_read_input_tokens": 0.0, + "output_tokens": 84.0 + } + } + }], + "status_code": 200 + }] +} +``` diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingService.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingService.java index 592d4dba2f..2e43260381 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingService.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingService.java @@ -15,7 +15,6 @@ import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SESSION_SUMMARY_PROMPT; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SUMMARY_FACTS_EXTRACTION_PROMPT; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.USER_PREFERENCE_FACTS_EXTRACTION_PROMPT; -import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.USER_PREFERENCE_JSON_ENFORCEMENT_MESSAGE; 
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import java.io.IOException; @@ -23,14 +22,11 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import org.opensearch.OpenSearchException; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; @@ -47,9 +43,13 @@ import org.opensearch.ml.common.transport.memorycontainer.memory.MessageInput; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.utils.LlmResultPathGenerator; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.common.utils.message.MessageFormatter; +import org.opensearch.ml.common.utils.message.MessageFormatterFactory; import org.opensearch.ml.engine.processor.MLProcessorType; import org.opensearch.ml.engine.processor.ProcessorChain; +import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.transport.client.Client; import com.jayway.jsonpath.JsonPath; @@ -59,14 +59,21 @@ @Log4j2 public class MemoryProcessingService { - public static final String DEFAULT_LLM_RESULT_PATH = "$.content[0].text"; + /** + * Smart fallback path for Claude models when auto-generation fails. + * This path works with all Claude models (v2, v3, Sonnet 4, etc.) that use the Messages API. 
+ */ + public static final String CLAUDE_SYSTEM_PROMPT_PATH = "$.content[0].text"; + private final Client client; private final NamedXContentRegistry xContentRegistry; private final ProcessorChain extractJsonProcessorChain; + private final MLModelCacheHelper modelCacheHelper; - public MemoryProcessingService(Client client, NamedXContentRegistry xContentRegistry) { + public MemoryProcessingService(Client client, NamedXContentRegistry xContentRegistry, MLModelCacheHelper modelCacheHelper) { this.client = client; this.xContentRegistry = xContentRegistry; + this.modelCacheHelper = modelCacheHelper; List> processorConfigs = new ArrayList<>(); processorConfigs.add(Map.of("type", MLProcessorType.EXTRACT_JSON.getValue(), "extract_type", "object")); this.extractJsonProcessorChain = new ProcessorChain(processorConfigs); @@ -111,75 +118,24 @@ public void extractFactsFromConversation( isOverride ? "strategy-override" : "container-config" ); - Map stringParameters = new HashMap<>(); - - // Determine default prompt based on strategy type - String defaultPrompt; - MemoryStrategyType type = strategy.getType(); - if (type == MemoryStrategyType.USER_PREFERENCE) { - defaultPrompt = USER_PREFERENCE_FACTS_EXTRACTION_PROMPT; - } else if (type == MemoryStrategyType.SUMMARY) { - defaultPrompt = SUMMARY_FACTS_EXTRACTION_PROMPT; - } else { - defaultPrompt = SEMANTIC_FACTS_EXTRACTION_PROMPT; - } + // Get appropriate formatter for the model + MessageFormatter formatter = getFormatterForModel(llmModelId); + log.debug("Using {} for model {}", formatter.getClass().getSimpleName(), llmModelId); - if (strategy.getStrategyConfig() == null || strategy.getStrategyConfig().isEmpty()) { - stringParameters.put("system_prompt", defaultPrompt); - } else { - Object customPrompt = strategy.getStrategyConfig().get("system_prompt"); - if (customPrompt == null || customPrompt.toString().isBlank()) { - stringParameters.put("system_prompt", defaultPrompt); - } else if 
(!validatePromptFormat(customPrompt.toString())) { - log.error("Invalid custom prompt format - must specify JSON response format with 'facts' array"); - listener.onFailure(new IllegalArgumentException("Custom prompt must specify JSON response format with 'facts' array")); - return; - } else { - stringParameters.put("system_prompt", customPrompt.toString()); // use custom strategy - } + // Determine system prompt (default or custom) + String systemPrompt = determineSystemPrompt(strategy, listener); + if (systemPrompt == null) { + // Validation failed, listener already called + return; } - try { - XContentBuilder messagesBuilder = jsonXContent.contentBuilder(); - messagesBuilder.startArray(); - Map strategyConfig = strategy.getStrategyConfig(); - if (strategyConfig != null && strategyConfig.containsKey("system_prompt_message")) { - Object systemPromptMsg = strategyConfig.get("system_prompt_message"); - if (systemPromptMsg != null && systemPromptMsg instanceof Map) { - messagesBuilder.map((Map) systemPromptMsg); - } - } - for (MessageInput message : messages) { - message.toXContent(messagesBuilder, ToXContent.EMPTY_PARAMS); - } - if (strategyConfig != null && strategyConfig.containsKey("user_prompt_message")) { - Object userPromptMsg = strategyConfig.get("user_prompt_message"); - if (userPromptMsg != null && userPromptMsg instanceof Map) { - messagesBuilder.map((Map) userPromptMsg); - } - } else { // Add default user prompt (when strategyConfig is null or doesn't have user_prompt_message) - MessageInput message = getMessageInput("Please extract information from our conversation so far"); - message.toXContent(messagesBuilder, ToXContent.EMPTY_PARAMS); - } - - // Always add JSON enforcement message for fact extraction - String enforcementMsg = (strategy.getType() == MemoryStrategyType.USER_PREFERENCE) - ? 
USER_PREFERENCE_JSON_ENFORCEMENT_MESSAGE - : JSON_ENFORCEMENT_MESSAGE; - MessageInput enforcementMessage = getMessageInput(enforcementMsg); - enforcementMessage.toXContent(messagesBuilder, ToXContent.EMPTY_PARAMS); + // Build full message list with additional messages from config + List allMessages = buildFullMessageList(messages, strategy); - messagesBuilder.endArray(); - String messagesJson = messagesBuilder.toString(); - stringParameters.put("messages", messagesJson); - - log.debug("LLM request - processing {} messages", messages.size()); - } catch (Exception e) { - log.error("Failed to build messages JSON", e); - listener.onResponse(new ArrayList<>()); - return; - } + // Format the request using the appropriate formatter + Map stringParameters = formatter.formatRequest(systemPrompt, allMessages, strategy.getStrategyConfig()); + // Continue with existing MLInput building and execution MLInput mlInput = MLInput .builder() .algorithm(FunctionName.REMOTE) @@ -192,7 +148,7 @@ public void extractFactsFromConversation( try { log.debug("Received LLM response, parsing facts..."); MLOutput mlOutput = response.getOutput(); - List facts = parseFactsFromLLMResponse(strategy, mlOutput); + List facts = parseFactsFromLLMResponse(strategy, mlOutput, llmModelId); log.debug("Extracted {} facts from LLM response", facts.size()); listener.onResponse(facts); } catch (Exception e) { @@ -238,6 +194,10 @@ public void makeMemoryDecisions( allSearchResults.size() ); + // Get appropriate formatter for the model + MessageFormatter formatter = getFormatterForModel(llmModelId); + log.debug("Using {} for model {}", formatter.getClass().getSimpleName(), llmModelId); + List oldMemories = new ArrayList<>(); for (FactSearchResult result : allSearchResults) { oldMemories @@ -250,60 +210,45 @@ public void makeMemoryDecisions( .retrievedFacts(extractedFacts) .build(); - Map stringParameters = new HashMap<>(); - stringParameters.put("system_prompt", DEFAULT_UPDATE_MEMORY_PROMPT); - String 
decisionRequestJson = decisionRequest.toJsonString(); - try { - XContentBuilder messagesBuilder = jsonXContent.contentBuilder(); - messagesBuilder.startArray(); - messagesBuilder.startObject(); - messagesBuilder.field("role", "user"); - messagesBuilder.startArray("content"); - messagesBuilder.startObject(); - messagesBuilder.field("type", "text"); - messagesBuilder.field("text", decisionRequestJson); - messagesBuilder.endObject(); - messagesBuilder.endArray(); - messagesBuilder.endObject(); - messagesBuilder.endArray(); - - String messagesJson = messagesBuilder.toString(); - stringParameters.put("messages", messagesJson); + // Create user message with decision request JSON + List messages = new ArrayList<>(); + List> content = new ArrayList<>(); + content.add(Map.of("type", "text", "text", decisionRequestJson)); + messages.add(MessageInput.builder().role("user").content(content).build()); - log - .debug( - "Making memory decisions for {} extracted facts and {} existing memories", - extractedFacts.size(), - allSearchResults.size() - ); + // Format the request using the appropriate formatter + Map stringParameters = formatter.formatRequest(DEFAULT_UPDATE_MEMORY_PROMPT, messages, null); - RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(stringParameters).build(); - MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(); + log + .debug( + "Making memory decisions for {} extracted facts and {} existing memories", + extractedFacts.size(), + allSearchResults.size() + ); - MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder().modelId(llmModelId).mlInput(mlInput).build(); + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(stringParameters).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(); - client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, 
ActionListener.wrap(response -> { - try { - List decisions = parseMemoryDecisions(response); - log.debug("LLM made {} memory decisions", decisions.size()); - listener.onResponse(decisions); - } catch (Exception e) { - log.error("Failed to parse memory decisions from LLM response", e); - listener.onFailure(e); - } - }, e -> { - log.error("Failed to get memory decisions from LLM", e); + MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder().modelId(llmModelId).mlInput(mlInput).build(); + + client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, ActionListener.wrap(response -> { + try { + List decisions = parseMemoryDecisions(response, llmModelId); + log.debug("LLM made {} memory decisions", decisions.size()); + listener.onResponse(decisions); + } catch (Exception e) { + log.error("Failed to parse memory decisions from LLM response", e); listener.onFailure(e); - })); - } catch (Exception e) { - log.error("Failed to build memory decision request", e); + } + }, e -> { + log.error("Failed to get memory decisions from LLM", e); listener.onFailure(e); - } + })); } - private List parseFactsFromLLMResponse(MemoryStrategy strategy, MLOutput mlOutput) { + private List parseFactsFromLLMResponse(MemoryStrategy strategy, MLOutput mlOutput, String modelId) { List facts = new ArrayList<>(); if (!(mlOutput instanceof ModelTensorOutput)) { @@ -325,12 +270,20 @@ private List parseFactsFromLLMResponse(MemoryStrategy strategy, MLOutput for (int i = 0; i < modelTensors.getMlModelTensors().size(); i++) { Map dataMap = modelTensors.getMlModelTensors().get(i).getDataAsMap(); - String llmResultPath = Optional - .ofNullable(strategy.getStrategyConfig()) - .map(config -> config.get("llm_result_path")) - .map(Object::toString) - .orElse(DEFAULT_LLM_RESULT_PATH); - Object filterdResult = JsonPath.read(dataMap, llmResultPath); + + // Auto-generate the result path based on model's output schema + String llmResultPath = getAutoGeneratedLlmResultPath(modelId); + 
log.debug("Using llm_result_path: {}", llmResultPath); + + // Gracefully handle malformed responses + Object filterdResult = null; + try { + filterdResult = JsonPath.read(dataMap, llmResultPath); + } catch (Exception e) { + log.warn("Failed to extract LLM result from path '{}': {}. Skipping tensor {}", llmResultPath, e.getMessage(), i); + continue; + } + String llmResult = null; if (filterdResult != null) { llmResult = StringUtils.toJson(filterdResult); @@ -362,7 +315,7 @@ private List parseFactsFromLLMResponse(MemoryStrategy strategy, MLOutput return facts; } - private List parseMemoryDecisions(MLTaskResponse response) { + private List parseMemoryDecisions(MLTaskResponse response, String modelId) { try { MLOutput mlOutput = response.getOutput(); if (!(mlOutput instanceof ModelTensorOutput)) { @@ -377,21 +330,32 @@ private List parseMemoryDecisions(MLTaskResponse response) { Map dataMap = tensors.get(0).getMlModelTensors().get(0).getDataAsMap(); - String responseContent = null; - if (dataMap.containsKey("response")) { - responseContent = (String) dataMap.get("response"); - } else if (dataMap.containsKey("content")) { - List> contentList = (List>) dataMap.get("content"); - if (contentList != null && !contentList.isEmpty()) { - Map firstContent = contentList.get(0); - responseContent = (String) firstContent.get("text"); - } + // Auto-generate the result path based on model's output schema (same as fact extraction) + String llmResultPath = getAutoGeneratedLlmResultPath(modelId); + log.debug("Using llm_result_path: {}", llmResultPath); + + // Extract response content using the auto-generated path + Object filteredResult = null; + try { + filteredResult = JsonPath.read(dataMap, llmResultPath); + } catch (Exception e) { + log + .error( + "Failed to extract response from path '{}': {}. 
Available keys: {}", + llmResultPath, + e.getMessage(), + dataMap.keySet() + ); + throw new IllegalStateException("Failed to extract response content from LLM output using path: " + llmResultPath, e); } - if (responseContent == null) { - throw new IllegalStateException("No response content found in LLM output"); + if (filteredResult == null) { + throw new IllegalStateException("Extracted null content from LLM output using path: " + llmResultPath); } + // Convert to JSON string + String responseContent = StringUtils.toJson(filteredResult); + // Clean response content responseContent = cleanMarkdownFromJson(responseContent); @@ -423,64 +387,79 @@ private List parseMemoryDecisions(MLTaskResponse response) { public void summarizeMessages(MemoryConfiguration configuration, List messages, ActionListener listener) { if (messages == null || messages.isEmpty()) { listener.onResponse(""); - } else { - Map stringParameters = new HashMap<>(); - Map memoryParams = configuration.getParameters(); - String llmResultPath = (String) memoryParams.getOrDefault("llm_result_path", DEFAULT_LLM_RESULT_PATH); - Map sessionParams = (Map) memoryParams.getOrDefault("session", new HashMap<>()); - stringParameters.put("system_prompt", SESSION_SUMMARY_PROMPT); - stringParameters.putAll(getParameterMap(sessionParams)); - stringParameters.putIfAbsent("max_summary_size", "10"); + return; + } + + String llmModelId = configuration.getLlmId(); + if (llmModelId == null) { + log.error("LLM model is required for summarization but not configured"); + listener.onFailure(new IllegalStateException("LLM model is required for summarization")); + return; + } + + // Get appropriate formatter for the model + MessageFormatter formatter = getFormatterForModel(llmModelId); + log.debug("Using {} for model {}", formatter.getClass().getSimpleName(), llmModelId); + + Map memoryParams = configuration.getParameters(); + Map sessionParams = (Map) memoryParams.getOrDefault("session", new HashMap<>()); + + // Prepare 
additional parameters (non-formatter params) + Map additionalParams = new HashMap<>(); + additionalParams.putAll(getParameterMap(sessionParams)); + additionalParams.putIfAbsent("max_summary_size", "10"); + // Build message list + List allMessages = new ArrayList<>(messages); + + // Add user prompt if not in config + if (!sessionParams.containsKey("user_prompt_message")) { + MessageInput summaryPrompt = getMessageInput( + "Please summarize our conversation, not exceed " + additionalParams.get("max_summary_size") + " words" + ); + allMessages.add(summaryPrompt); + } + + // Format the request using the appropriate formatter + Map formatterParams = formatter.formatRequest(SESSION_SUMMARY_PROMPT, allMessages, null); + + // Combine formatter params with additional params + Map stringParameters = new HashMap<>(); + stringParameters.putAll(formatterParams); // system_prompt and messages from formatter + stringParameters.putAll(additionalParams); // max_summary_size, etc. + + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(stringParameters).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(); + MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder().modelId(llmModelId).mlInput(mlInput).build(); + + client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, ActionListener.wrap(response -> { try { - XContentBuilder messagesBuilder = jsonXContent.contentBuilder(); - messagesBuilder.startArray(); - for (MessageInput message : messages) { - message.toXContent(messagesBuilder, ToXContent.EMPTY_PARAMS); - } - if (sessionParams.containsKey("user_prompt_message")) { - Object userPromptMsg = sessionParams.get("user_prompt_message"); - if (userPromptMsg != null && userPromptMsg instanceof Map) { - messagesBuilder.map((Map) userPromptMsg); - } - } else { - MessageInput message = getMessageInput( - "Please summarize our conversation, not exceed " + 
stringParameters.get("max_summary_size") + " words" - ); - message.toXContent(messagesBuilder, ToXContent.EMPTY_PARAMS); - } - messagesBuilder.endArray(); - stringParameters.put("messages", messagesBuilder.toString()); - - RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(stringParameters).build(); - MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(); - MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest - .builder() - .modelId(configuration.getLlmId()) - .mlInput(mlInput) - .build(); - - client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, ActionListener.wrap(response -> { - try { - String summary = parseSessionSummary((ModelTensorOutput) response.getOutput(), llmResultPath); - listener.onResponse(summary); - } catch (Exception e) { - log.error("Failed to parse memory decisions from LLM response", e); - listener.onFailure(e); - } - }, e -> { - log.error("Failed to get memory decisions from LLM", e); - listener.onFailure(e); - })); + String summary = parseSessionSummary((ModelTensorOutput) response.getOutput(), llmModelId); + listener.onResponse(summary); } catch (Exception e) { + log.error("Failed to parse session summary from LLM response", e); listener.onFailure(e); } - } + }, e -> { + log.error("Failed to get session summary from LLM", e); + listener.onFailure(e); + })); } - private String parseSessionSummary(ModelTensorOutput output, String llmResultPath) { + private String parseSessionSummary(ModelTensorOutput output, String modelId) { Map dataAsMap = output.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap(); - Object filterdResult = JsonPath.read(dataAsMap, llmResultPath); + // Auto-generate the result path based on model's output schema + String llmResultPath = getAutoGeneratedLlmResultPath(modelId); + + // Gracefully handle malformed responses + Object filterdResult = null; + try { + filterdResult = 
JsonPath.read(dataAsMap, llmResultPath); + } catch (Exception e) { + log.warn("Failed to extract session summary from path '{}': {}. Returning empty summary", llmResultPath, e.getMessage()); + return ""; + } + String sessionSummary = null; if (filterdResult != null) { sessionSummary = StringUtils.toJson(filterdResult); @@ -518,6 +497,99 @@ else if (response.startsWith("```") && response.endsWith("```")) { return response; } + /** + * Get appropriate message formatter for a model based on its input schema. + * + * @param modelId The model ID to get formatter for + * @return MessageFormatter instance (never null, defaults to Claude) + */ + private MessageFormatter getFormatterForModel(String modelId) { + if (modelId == null) { + log.debug("No model ID provided, using default Claude formatter"); + return MessageFormatterFactory.getClaudeFormatter(); + } + + try { + Map modelInterface = modelCacheHelper.getModelInterface(modelId); + + if (modelInterface != null && modelInterface.containsKey("input")) { + String inputSchema = modelInterface.get("input"); + log.debug("Retrieved input schema for model {}", modelId); + return MessageFormatterFactory.getFormatter(inputSchema); + } + + log.debug("No input schema found for model {}, using default Claude formatter", modelId); + } catch (Exception e) { + log.warn("Failed to get formatter for model {}, using default: {}", modelId, e.getMessage()); + } + + return MessageFormatterFactory.getClaudeFormatter(); + } + + /** + * Determine the system prompt to use (default or custom). 
+ * + * @param strategy The memory strategy that may contain custom prompt + * @param listener The listener to call on validation failure + * @return The system prompt to use, or null if validation failed (listener already called) + */ + private String determineSystemPrompt(MemoryStrategy strategy, ActionListener> listener) { + // Determine default prompt based on strategy type + String defaultPrompt; + MemoryStrategyType type = strategy.getType(); + if (type == MemoryStrategyType.USER_PREFERENCE) { + defaultPrompt = USER_PREFERENCE_FACTS_EXTRACTION_PROMPT; + } else if (type == MemoryStrategyType.SUMMARY) { + defaultPrompt = SUMMARY_FACTS_EXTRACTION_PROMPT; + } else { + defaultPrompt = SEMANTIC_FACTS_EXTRACTION_PROMPT; + } + + // Check for custom prompt in strategy config + if (strategy.getStrategyConfig() == null || strategy.getStrategyConfig().isEmpty()) { + return defaultPrompt; + } + + Object customPrompt = strategy.getStrategyConfig().get("system_prompt"); + if (customPrompt == null || customPrompt.toString().isBlank()) { + return defaultPrompt; + } + + // Validate custom prompt format + if (!validatePromptFormat(customPrompt.toString())) { + log.error("Invalid custom prompt format - must specify JSON response format with 'facts' array"); + listener.onFailure(new IllegalArgumentException("Custom prompt must specify JSON response format with 'facts' array")); + return null; + } + + return customPrompt.toString(); + } + + /** + * Build full message list including additional messages from config. 
+ * + * @param messages Original messages from conversation + * @param strategy Memory strategy that may contain additional messages + * @return Full list of messages including any additional config messages + */ + private List buildFullMessageList(List messages, MemoryStrategy strategy) { + List allMessages = new ArrayList<>(messages); + + Map strategyConfig = strategy.getStrategyConfig(); + + // Add user prompt message if not in config + if (strategyConfig == null || !strategyConfig.containsKey("user_prompt_message")) { + MessageInput userPrompt = getMessageInput("Please extract information from our conversation so far"); + allMessages.add(userPrompt); + } + + // Always add JSON enforcement message + MessageInput enforcementMessage = getMessageInput(JSON_ENFORCEMENT_MESSAGE); + allMessages.add(enforcementMessage); + + return allMessages; + } + /** * Get the effective LLM model ID, checking strategy override first, then falling back to memory config. * @@ -537,4 +609,65 @@ private String getEffectiveLlmId(MemoryStrategy strategy, MemoryConfiguration me // Fall back to memory config return memoryConfig != null ? memoryConfig.getLlmId() : null; } + + /** + * Auto-generates the JSONPath for extracting LLM output from model responses. + * Uses the model's output schema to intelligently locate the text field. + * + * Algorithm: + * 1. If modelId is null → return Claude default path + * 2. Retrieve model's output schema from cache + * 3. Use LlmResultPathGenerator to find x-llm-output marked field + * 4. 
On any failure → return Claude default path (smart fallback) + * + * Examples: + * - OpenAI models: $.choices[0].message.content + * - Claude models: $.content[0].text + * - Unknown/missing schema: $.content[0].text (safe default) + * + * @param modelId The model ID to generate path for (can be null) + * @return JSONPath string for extracting LLM text output + */ + private String getAutoGeneratedLlmResultPath(String modelId) { + // No model specified → use Claude default + if (modelId == null) { + log.debug("No modelId provided, using Claude default path: {}", CLAUDE_SYSTEM_PROMPT_PATH); + return CLAUDE_SYSTEM_PROMPT_PATH; + } + + try { + // Try to get model's output schema from cache + Map modelInterface = modelCacheHelper.getModelInterface(modelId); + + if (modelInterface != null && modelInterface.containsKey("output")) { + String outputSchema = modelInterface.get("output"); + + // Use LlmResultPathGenerator to find the LLM output field + String generatedPath = LlmResultPathGenerator.generate(outputSchema); + + if (generatedPath != null && !generatedPath.isBlank()) { + log.debug("Auto-generated llm_result_path for model {}: {}", modelId, generatedPath); + return generatedPath; + } + } + + log + .debug( + "Could not auto-generate path for model {} (schema not found or generation failed), using Claude default: {}", + modelId, + CLAUDE_SYSTEM_PROMPT_PATH + ); + } catch (Exception e) { + log + .warn( + "Failed to auto-generate llm_result_path for model {}: {}. 
Using Claude default: {}", + modelId, + e.getMessage(), + CLAUDE_SYSTEM_PROMPT_PATH + ); + } + + // Smart fallback: Claude path works for most models + return CLAUDE_SYSTEM_PROMPT_PATH; + } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportAddMemoriesAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportAddMemoriesAction.java index 380e97f524..b6829f5607 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportAddMemoriesAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportAddMemoriesAction.java @@ -52,6 +52,7 @@ import org.opensearch.ml.common.transport.memorycontainer.memory.MemoryResult; import org.opensearch.ml.common.transport.memorycontainer.memory.MessageInput; import org.opensearch.ml.helper.MemoryContainerHelper; +import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; @@ -86,6 +87,7 @@ public TransportAddMemoriesAction( NamedXContentRegistry xContentRegistry, MLFeatureEnabledSetting mlFeatureEnabledSetting, MemoryContainerHelper memoryContainerHelper, + MLModelCacheHelper mlModelCacheHelper, ThreadPool threadPool ) { super(MLAddMemoriesAction.NAME, transportService, actionFilters, MLAddMemoriesRequest::new); @@ -94,7 +96,7 @@ public TransportAddMemoriesAction( this.memoryContainerHelper = memoryContainerHelper; // Initialize services - this.memoryProcessingService = new MemoryProcessingService(client, xContentRegistry); + this.memoryProcessingService = new MemoryProcessingService(client, xContentRegistry, mlModelCacheHelper); this.memorySearchService = new MemorySearchService(memoryContainerHelper); this.memoryOperationsService = new MemoryOperationsService(memoryContainerHelper); this.threadPool = threadPool; @@ -261,7 +263,7 @@ private void extractLongTermMemory( 
ActionListener actionListener ) { List messages = input.getMessages(); - log.debug("Processing {} messages for fact extraction", messages.size()); + log.debug("Processing {} messages for fact extraction", messages != null ? messages.size() : 0); List strategies = container.getConfiguration().getStrategies(); MemoryConfiguration memoryConfig = container.getConfiguration(); diff --git a/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingServiceAdditionalTests.java b/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingServiceAdditionalTests.java index 16c34f691d..1c89cf6ac6 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingServiceAdditionalTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingServiceAdditionalTests.java @@ -34,6 +34,7 @@ import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.memorycontainer.memory.MessageInput; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.transport.client.Client; public class MemoryProcessingServiceAdditionalTests { @@ -44,6 +45,9 @@ public class MemoryProcessingServiceAdditionalTests { @Mock private NamedXContentRegistry xContentRegistry; + @Mock + private MLModelCacheHelper mlModelCacheHelper; + @Mock private ActionListener> factsListener; @@ -58,7 +62,7 @@ public class MemoryProcessingServiceAdditionalTests { public void setup() { MockitoAnnotations.openMocks(this); memoryStrategy = new MemoryStrategy("id", true, MemoryStrategyType.SEMANTIC, Arrays.asList("user_id"), new HashMap<>()); - memoryProcessingService = new MemoryProcessingService(client, xContentRegistry); + memoryProcessingService = new MemoryProcessingService(client, xContentRegistry, mlModelCacheHelper); } @Test @@ -182,7 +186,6 @@ public void 
testExtractFactsFromConversation_WithProcessorChainExtractingJson() when(storageConfig.getLlmId()).thenReturn("llm-model-123"); MemoryStrategy strategy = new MemoryStrategy("id", true, MemoryStrategyType.SEMANTIC, Arrays.asList("user_id"), new HashMap<>()); - strategy.getStrategyConfig().put("llm_result_path", "$.content[0].text"); MLTaskResponse mockResponse = mock(MLTaskResponse.class); ModelTensorOutput mockOutput = mock(ModelTensorOutput.class); @@ -219,7 +222,6 @@ public void testExtractFactsFromConversation_WithProcessorChainExtractingNestedJ when(storageConfig.getLlmId()).thenReturn("llm-model-123"); MemoryStrategy strategy = new MemoryStrategy("id", true, MemoryStrategyType.SEMANTIC, Arrays.asList("user_id"), new HashMap<>()); - strategy.getStrategyConfig().put("llm_result_path", "$.content[0].text"); MLTaskResponse mockResponse = mock(MLTaskResponse.class); ModelTensorOutput mockOutput = mock(ModelTensorOutput.class); @@ -256,7 +258,6 @@ public void testExtractFactsFromConversation_WithProcessorChainMultipleTensors() when(storageConfig.getLlmId()).thenReturn("llm-model-123"); MemoryStrategy strategy = new MemoryStrategy("id", true, MemoryStrategyType.SEMANTIC, Arrays.asList("user_id"), new HashMap<>()); - strategy.getStrategyConfig().put("llm_result_path", "$.content[0].text"); MLTaskResponse mockResponse = mock(MLTaskResponse.class); ModelTensorOutput mockOutput = mock(ModelTensorOutput.class); @@ -299,7 +300,6 @@ public void testExtractFactsFromConversation_WithProcessorChainCleaningMarkdown( when(storageConfig.getLlmId()).thenReturn("llm-model-123"); MemoryStrategy strategy = new MemoryStrategy("id", true, MemoryStrategyType.SEMANTIC, Arrays.asList("user_id"), new HashMap<>()); - strategy.getStrategyConfig().put("llm_result_path", "$.content[0].text"); MLTaskResponse mockResponse = mock(MLTaskResponse.class); ModelTensorOutput mockOutput = mock(ModelTensorOutput.class); @@ -337,7 +337,6 @@ public void 
testExtractFactsFromConversation_WithProcessorChainEmptyFacts() { when(storageConfig.getLlmId()).thenReturn("llm-model-123"); MemoryStrategy strategy = new MemoryStrategy("id", true, MemoryStrategyType.SEMANTIC, Arrays.asList("user_id"), new HashMap<>()); - strategy.getStrategyConfig().put("llm_result_path", "$.content[0].text"); MLTaskResponse mockResponse = mock(MLTaskResponse.class); ModelTensorOutput mockOutput = mock(ModelTensorOutput.class); @@ -374,7 +373,6 @@ public void testExtractFactsFromConversation_WithProcessorChainComplexJsonStruct when(storageConfig.getLlmId()).thenReturn("llm-model-123"); MemoryStrategy strategy = new MemoryStrategy("id", true, MemoryStrategyType.SEMANTIC, Arrays.asList("user_id"), new HashMap<>()); - strategy.getStrategyConfig().put("llm_result_path", "$.content[0].text"); MLTaskResponse mockResponse = mock(MLTaskResponse.class); ModelTensorOutput mockOutput = mock(ModelTensorOutput.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingServiceTests.java b/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingServiceTests.java index 0ee9ad4351..62d00c9722 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingServiceTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingServiceTests.java @@ -44,6 +44,7 @@ import org.opensearch.ml.common.transport.memorycontainer.memory.MessageInput; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.transport.client.Client; public class MemoryProcessingServiceTests { @@ -54,6 +55,9 @@ public class MemoryProcessingServiceTests { @Mock private NamedXContentRegistry xContentRegistry; + @Mock + private MLModelCacheHelper mlModelCacheHelper; + @Mock private 
ActionListener> factsListener; @@ -72,8 +76,7 @@ public class MemoryProcessingServiceTests { public void setup() { MockitoAnnotations.openMocks(this); memoryStrategy = new MemoryStrategy("id", true, MemoryStrategyType.SEMANTIC, Arrays.asList("user_id"), new HashMap<>()); - memoryStrategy.getStrategyConfig().put("llm_result_path", "$"); - memoryProcessingService = new MemoryProcessingService(client, xContentRegistry); + memoryProcessingService = new MemoryProcessingService(client, xContentRegistry, mlModelCacheHelper); testContent = createTestContent("Hello"); when(memoryConfig.getParameters()).thenReturn(new HashMap<>()); } @@ -290,7 +293,6 @@ public void testExtractFactsFromConversation_ParseException() { Arrays.asList("user_id"), new HashMap<>() ); - memoryStrategy.getStrategyConfig().put("llm_result_path", "$.content[0].text"); memoryProcessingService.extractFactsFromConversation(messages, memoryStrategy, storageConfig, factsListener); @@ -547,7 +549,10 @@ public void testMakeMemoryDecisions_JsonCodeBlock() { ModelTensor mockTensor = mock(ModelTensor.class); Map dataMap = new HashMap<>(); - dataMap.put("response", "```json\n{\"memory_decisions\": []}\n```"); + // Use Claude format with content array + Map contentItem = new HashMap<>(); + contentItem.put("text", "```json\n{\"memory_decisions\": []}\n```"); + dataMap.put("content", Arrays.asList(contentItem)); when(mockResponse.getOutput()).thenReturn(mockOutput); when(mockOutput.getMlModelOutputs()).thenReturn(Arrays.asList(mockTensors)); @@ -578,7 +583,10 @@ public void testMakeMemoryDecisions_PlainCodeBlock() { ModelTensor mockTensor = mock(ModelTensor.class); Map dataMap = new HashMap<>(); - dataMap.put("response", "```\n{\"memory_decisions\": []}\n```"); + // Use Claude format with content array + Map contentItem = new HashMap<>(); + contentItem.put("text", "```\n{\"memory_decisions\": []}\n```"); + dataMap.put("content", Arrays.asList(contentItem)); when(mockResponse.getOutput()).thenReturn(mockOutput); 
when(mockOutput.getMlModelOutputs()).thenReturn(Arrays.asList(mockTensors)); @@ -907,6 +915,9 @@ public void testSummarizeMessages_WithMessages() { ActionListener stringListener = mock(ActionListener.class); List messages = Arrays.asList(MessageInput.builder().content(testContent).role("user").build()); + // Configure LLM ID for summarization + when(memoryConfig.getLlmId()).thenReturn("llm-model-123"); + doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); List mlModelOutputs = new ArrayList<>(); diff --git a/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/TransportAddMemoriesActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/TransportAddMemoriesActionTests.java index 7aaaae14c6..74a1a5020c 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/TransportAddMemoriesActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/TransportAddMemoriesActionTests.java @@ -52,6 +52,7 @@ import org.opensearch.ml.common.transport.memorycontainer.memory.MessageInput; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.helper.MemoryContainerHelper; +import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; @@ -91,6 +92,9 @@ public class TransportAddMemoriesActionTests { @Mock private MemoryContainerHelper memoryContainerHelper; + @Mock + private MLModelCacheHelper mlModelCacheHelper; + @Mock private MemoryProcessingService memoryProcessingService; @@ -137,6 +141,7 @@ public void setup() throws Exception { xContentRegistry, mlFeatureEnabledSetting, memoryContainerHelper, + mlModelCacheHelper, threadPool ); diff --git a/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java b/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java index 
e082571ad6..16e6f1f9cd 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java @@ -6,9 +6,7 @@ package org.opensearch.ml.utils; import static java.util.Collections.emptyMap; -import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE; -import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE; -import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE; +import static org.opensearch.ml.common.utils.ModelInterfaceUtils.ModelInterfaceSchema; import static org.opensearch.ml.utils.TestHelper.ML_ROLE; import java.io.IOException; @@ -71,42 +69,42 @@ public void testValidateSchema() throws IOException { @Test public void testValidateEmbeddingInputWithGeneralEmbeddingRemoteSchema() throws IOException { - String schema = BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE.get("input"); + String schema = ModelInterfaceSchema.BEDROCK_COHERE_EMBED_ENGLISH_V3.getInterface().get("input"); String json = "{\"text_docs\":[ \"today is sunny\", \"today is sunny\"]}"; MLNodeUtils.validateSchema(schema, json); } @Test public void testValidateRemoteInputWithGeneralEmbeddingRemoteSchema() throws IOException { - String schema = BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE.get("input"); + String schema = ModelInterfaceSchema.BEDROCK_COHERE_EMBED_ENGLISH_V3.getInterface().get("input"); String json = "{\"parameters\": {\"texts\": [\"Hello\",\"world\"]}}"; MLNodeUtils.validateSchema(schema, json); } @Test public void testValidateEmbeddingInputWithTitanTextRemoteSchema() throws IOException { - String schema = BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE.get("input"); + String schema = ModelInterfaceSchema.BEDROCK_TITAN_EMBED_TEXT_V1.getInterface().get("input"); String json = "{\"text_docs\":[ \"today is sunny\", \"today is sunny\"]}"; 
MLNodeUtils.validateSchema(schema, json); } @Test public void testValidateRemoteInputWithTitanTextRemoteSchema() throws IOException { - String schema = BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE.get("input"); + String schema = ModelInterfaceSchema.BEDROCK_TITAN_EMBED_TEXT_V1.getInterface().get("input"); String json = "{\"parameters\": {\"inputText\": \"Say this is a test\"}}"; MLNodeUtils.validateSchema(schema, json); } @Test public void testValidateEmbeddingInputWithTitanMultiModalRemoteSchema() throws IOException { - String schema = BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE.get("input"); + String schema = ModelInterfaceSchema.BEDROCK_TITAN_EMBED_MULTI_MODAL_V1.getInterface().get("input"); String json = "{\"text_docs\":[ \"today is sunny\", \"today is sunny\"]}"; MLNodeUtils.validateSchema(schema, json); } @Test public void testValidateRemoteInputWithTitanMultiModalRemoteSchema() throws IOException { - String schema = BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE.get("input"); + String schema = ModelInterfaceSchema.BEDROCK_TITAN_EMBED_MULTI_MODAL_V1.getInterface().get("input"); String json = "{\n" + " \"parameters\": {\n" + " \"inputText\": \"Say this is a test\",\n"