diff --git a/docs/changelog/136624.yaml b/docs/changelog/136624.yaml new file mode 100644 index 0000000000000..e9e58bd91d953 --- /dev/null +++ b/docs/changelog/136624.yaml @@ -0,0 +1,5 @@ +pr: 136624 +summary: Added OpenShift AI text_embedding, completion, chat_completion and rerank support to the Inference Plugin +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/resources/transport/definitions/referable/ml_inference_openshift_ai_added.csv b/server/src/main/resources/transport/definitions/referable/ml_inference_openshift_ai_added.csv new file mode 100644 index 0000000000000..023d490c87211 --- /dev/null +++ b/server/src/main/resources/transport/definitions/referable/ml_inference_openshift_ai_added.csv @@ -0,0 +1 @@ +9218000 diff --git a/server/src/main/resources/transport/upper_bounds/9.3.csv b/server/src/main/resources/transport/upper_bounds/9.3.csv index b29f7625613b5..69abc63d4761f 100644 --- a/server/src/main/resources/transport/upper_bounds/9.3.csv +++ b/server/src/main/resources/transport/upper_bounds/9.3.csv @@ -1 +1 @@ -resharding_shard_summary_in_esql,9217000 +ml_inference_openshift_ai_added,9218000 diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index dd1012daffc03..2194c5b08122a 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -70,6 +70,7 @@ public void testGetServicesWithoutTaskType() throws IOException { "llama", "mistral", "openai", + "openshift_ai", "streaming_completion_test_service", "completion_test_service", "test_reranking_service", @@ -116,6 
+117,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { "llama", "mistral", "openai", + "openshift_ai", "text_embedding_test_service", "voyageai", "watsonxai" @@ -140,6 +142,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { "elasticsearch", "googlevertexai", "jinaai", + "openshift_ai", "test_reranking_service", "voyageai", "hugging_face", @@ -167,6 +170,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException { "googleaistudio", "googlevertexai", "openai", + "openshift_ai", "streaming_completion_test_service", "completion_test_service", "hugging_face", @@ -188,6 +192,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException { "deepseek", "elastic", "openai", + "openshift_ai", "streaming_completion_test_service", "hugging_face", "amazon_sagemaker", diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 3c4e48b95eccc..f39ca9f66f3ca 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -115,6 +115,10 @@ import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsServiceSettings; +import 
org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankTaskSettings; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -172,6 +176,7 @@ public static List getNamedWriteables() { addCustomNamedWriteables(namedWriteables); addLlamaNamedWriteables(namedWriteables); addAi21NamedWriteables(namedWriteables); + addOpenShiftAiNamedWriteables(namedWriteables); addUnifiedNamedWriteables(namedWriteables); @@ -446,6 +451,33 @@ private static void addOpenAiNamedWriteables(List ); } + private static void addOpenShiftAiNamedWriteables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + OpenShiftAiEmbeddingsServiceSettings.NAME, + OpenShiftAiEmbeddingsServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + OpenShiftAiChatCompletionServiceSettings.NAME, + OpenShiftAiChatCompletionServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + OpenShiftAiRerankServiceSettings.NAME, + OpenShiftAiRerankServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry(TaskSettings.class, OpenShiftAiRerankTaskSettings.NAME, OpenShiftAiRerankTaskSettings::new) + ); + } + private static void addHuggingFaceNamedWriteables(List namedWriteables) { namedWriteables.add( new NamedWriteableRegistry.Entry( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 9e1f17643ac73..c4025d0e6c4b0 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -164,6 +164,7 @@ import org.elasticsearch.xpack.inference.services.llama.LlamaService; import org.elasticsearch.xpack.inference.services.mistral.MistralService; import org.elasticsearch.xpack.inference.services.openai.OpenAiService; +import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiService; import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient; import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerService; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerConfiguration; @@ -492,6 +493,7 @@ public List getInferenceServiceFactories() { context -> new DeepSeekService(httpFactory.get(), serviceComponents.get(), context), context -> new LlamaService(httpFactory.get(), serviceComponents.get(), context), context -> new Ai21Service(httpFactory.get(), serviceComponents.get(), context), + context -> new OpenShiftAiService(httpFactory.get(), serviceComponents.get(), context), ElasticsearchInternalService::new, context -> new CustomService(httpFactory.get(), serviceComponents.get(), context) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java index 1af79a69839ac..062ef0f67059c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java @@ -14,6 +14,7 @@ public final class ServiceFields { public static final String SIMILARITY = "similarity"; public static final String DIMENSIONS = "dimensions"; + public static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user"; // 
Typically we use this to define the maximum tokens for the input text (text being sent to an integration) public static final String MAX_INPUT_TOKENS = "max_input_tokens"; public static final String URL = "url"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java new file mode 100644 index 0000000000000..80f5ea0bdcd6b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai; + +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; +import org.elasticsearch.xpack.inference.services.openshiftai.action.OpenShiftAiActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.Map; +import java.util.Objects; + +/** + * Represents an OpenShift AI model that can be used for inference tasks. + * This class extends RateLimitGroupingModel to handle rate limiting based on modelId and API key. 
+ */ +public abstract class OpenShiftAiModel extends RateLimitGroupingModel { + + protected OpenShiftAiModel(ModelConfigurations configurations, ModelSecrets secrets) { + super(configurations, secrets); + } + + protected OpenShiftAiModel(RateLimitGroupingModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + protected OpenShiftAiModel(RateLimitGroupingModel model, TaskSettings taskSettings) { + super(model, taskSettings); + } + + @Override + public RateLimitSettings rateLimitSettings() { + return getServiceSettings().rateLimitSettings(); + } + + @Override + public int rateLimitGroupingHash() { + return Objects.hash(getServiceSettings().uri(), getServiceSettings().modelId()); + } + + @Override + public OpenShiftAiServiceSettings getServiceSettings() { + return (OpenShiftAiServiceSettings) super.getServiceSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + /** + * Accepts a visitor to create an executable action for this OpenShift AI model. + * + * @param creator the visitor that creates the executable action + * @param taskSettings the task settings to be used for the executable action + * @return an {@link ExecutableAction} specific to this OpenShift AI model + */ + public abstract ExecutableAction accept(OpenShiftAiActionVisitor creator, Map taskSettings); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java new file mode 100644 index 0000000000000..d0b4f71017736 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java @@ -0,0 +1,423 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.RerankingInferenceService; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.core.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import 
org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.openshiftai.action.OpenShiftAiActionCreator; +import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModel; +import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.openshiftai.request.completion.OpenShiftAiChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static 
org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +/** + * OpenShiftAiService is an implementation of the {@link SenderService} and {@link RerankingInferenceService} that handles inference tasks + * using models deployed to OpenShift AI environment. + * The service uses {@link OpenShiftAiActionCreator} to create actions for executing inference requests. + */ +public class OpenShiftAiService extends SenderService implements RerankingInferenceService { + public static final String NAME = "openshift_ai"; + /** + * The optimal batch size depends on the model deployed in OpenShift AI. + * For OpenShift AI use a conservatively small max batch size as it is unknown what model is deployed. 
+ */ + static final int EMBEDDING_MAX_BATCH_SIZE = 20; + private static final String SERVICE_NAME = "OpenShift AI"; + private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of( + TaskType.TEXT_EMBEDDING, + TaskType.COMPLETION, + TaskType.CHAT_COMPLETION, + TaskType.RERANK + ); + private static final ResponseHandler CHAT_COMPLETION_HANDLER = new OpenShiftAiChatCompletionResponseHandler( + "OpenShift AI chat completions", + OpenAiChatCompletionResponseEntity::fromResponse + ); + + public OpenShiftAiService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public OpenShiftAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); + } + + @Override + protected void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof OpenShiftAiModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + var openShiftAiModel = (OpenShiftAiModel) model; + var actionCreator = new OpenShiftAiActionCreator(getSender(), getServiceComponents()); + openShiftAiModel.accept(actionCreator, taskSettings).execute(inputs, timeout, listener); + } + + @Override + protected void validateInputType(InputType inputType, Model model, ValidationException validationException) { + ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException); + } + + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof OpenShiftAiChatCompletionModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + OpenShiftAiChatCompletionModel chatCompletionModel 
= (OpenShiftAiChatCompletionModel) model; + var overriddenModel = OpenShiftAiChatCompletionModel.of(chatCompletionModel, inputs.getRequest().model()); + var manager = new GenericRequestManager<>( + getServiceComponents().threadPool(), + overriddenModel, + CHAT_COMPLETION_HANDLER, + unifiedChatInput -> new OpenShiftAiChatCompletionRequest(unifiedChatInput, overriddenModel), + UnifiedChatInput.class + ); + var errorMessage = OpenShiftAiActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId()); + var action = new SenderExecutableAction(getSender(), manager, errorMessage); + + action.execute(inputs, timeout, listener); + } + + @Override + protected void doChunkedInfer( + Model model, + List inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + if (model instanceof OpenShiftAiEmbeddingsModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + var openShiftAiEmbeddingsModel = (OpenShiftAiEmbeddingsModel) model; + var actionCreator = new OpenShiftAiActionCreator(getSender(), getServiceComponents()); + List batchedRequests = new EmbeddingRequestChunker<>( + inputs, + EMBEDDING_MAX_BATCH_SIZE, + openShiftAiEmbeddingsModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + + for (var request : batchedRequests) { + var action = openShiftAiEmbeddingsModel.accept(actionCreator, taskSettings); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + } + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public EnumSet supportedTaskTypes() { + return SUPPORTED_TASK_TYPES; + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String inferenceEntityId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + try { + Map 
serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + + OpenShiftAiModel model = createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + @Override + public OpenShiftAiModel parsePersistedConfigWithSecrets( + String inferenceEntityId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + secretSettingsMap, + taskSettingsMap, + chunkingSettings + ); + } + + @Override + public OpenShiftAiModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + ChunkingSettings chunkingSettingsMap = null; + if 
(TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettingsMap = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, null, taskSettingsMap, chunkingSettingsMap); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return OpenShiftAiUtils.ML_INFERENCE_OPENSHIFT_AI_ADDED; + } + + @Override + public Set supportedStreamingTasks() { + return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); + } + + private static OpenShiftAiModel createModel( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + @Nullable Map secretSettings, + Map taskSettings, + ChunkingSettings chunkingSettings, + ConfigurationParseContext context + ) { + return switch (taskType) { + case CHAT_COMPLETION, COMPLETION -> new OpenShiftAiChatCompletionModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + secretSettings, + context + ); + case TEXT_EMBEDDING -> new OpenShiftAiEmbeddingsModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + chunkingSettings, + secretSettings, + context + ); + case RERANK -> new OpenShiftAiRerankModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); + }; + } + + private OpenShiftAiModel createModelFromPersistent( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map secretSettings, + Map taskSettings, + ChunkingSettings chunkingSettings + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + secretSettings, + taskSettings, + chunkingSettings, + ConfigurationParseContext.PERSISTENT + ); + } + + @Override + public int rerankerWindowSize(String 
modelId) { + // OpenShift AI uses Cohere and JinaAI rerank protocols for reranking + // JinaAI rerank model has 131K tokens limit https://jina.ai/models/jina-reranker-v3/ + // Cohere rerank model truncates at 4096 tokens https://docs.cohere.com/reference/rerank + // We choose a conservative limit based on these two models + // Using 1 token = 0.75 words as a rough estimate, we get 3072 words allowing for some headroom, we set the window size below 3072 + return 2800; + } + + @Override + public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + if (model instanceof OpenShiftAiEmbeddingsModel embeddingsModel) { + var serviceSettings = embeddingsModel.getServiceSettings(); + var similarityFromModel = serviceSettings.similarity(); + var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel; + + var updatedServiceSettings = new OpenShiftAiEmbeddingsServiceSettings( + serviceSettings.modelId(), + serviceSettings.uri(), + embeddingSize, + similarityToUse, + serviceSettings.maxInputTokens(), + serviceSettings.rateLimitSettings(), + serviceSettings.dimensionsSetByUser() + ); + + return new OpenShiftAiEmbeddingsModel(embeddingsModel, updatedServiceSettings); + } else { + throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass()); + } + } + + /** + * Configuration class for the OpenShift AI inference service. + * It provides the settings and configurations required for the service. 
+ */ + public static class Configuration { + public static InferenceServiceConfiguration get() { + return CONFIGURATION.getOrCompute(); + } + + private static final LazyInitializable CONFIGURATION = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + + configurationMap.put( + URL, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription("The URL endpoint to use for the requests.") + .setLabel("URL") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + + configurationMap.put( + MODEL_ID, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription( + "The name of the model to use for the inference task." + ) + .setLabel("Model ID") + .setRequired(false) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(SERVICE_NAME) + .setTaskTypes(SUPPORTED_TASK_TYPES) + .setConfigurations(configurationMap) + .build(); + } + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceSettings.java new file mode 100644 index 0000000000000..03184031eff27 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceSettings.java @@ -0,0 +1,178 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; +import java.util.function.Function; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; + +/** + * Represents the settings for an OpenShift AI service. + * This class encapsulates the model ID, URI, and rate limit settings for the OpenShift AI service. 
 */
public abstract class OpenShiftAiServiceSettings extends FilteredXContentObject implements ServiceSettings {
    // There is no default rate limit for OpenShift AI, so we set a reasonable default of 3000 requests per minute
    protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);

    // Optional identifier of the deployed model; may be null.
    protected final String modelId;
    // Endpoint of the OpenShift AI deployment; never null.
    protected final URI uri;
    // Never null; defaults to DEFAULT_RATE_LIMIT_SETTINGS when the caller passes null.
    protected final RateLimitSettings rateLimitSettings;

    /**
     * Constructs a new OpenShiftAiServiceSettings from a StreamInput.
     * The read order must stay in sync with {@link #writeTo}.
     *
     * @param in the StreamInput to read from
     * @throws IOException if an I/O error occurs during reading
     */
    protected OpenShiftAiServiceSettings(StreamInput in) throws IOException {
        this.modelId = in.readOptionalString();
        this.uri = createUri(in.readString());
        this.rateLimitSettings = new RateLimitSettings(in);
    }

    /**
     * Constructs a new OpenShiftAiServiceSettings.
     *
     * @param modelId the ID of the model; may be null
     * @param uri the URI of the service; must not be null
     * @param rateLimitSettings the rate limit settings for the service; defaults to
     *        {@link #DEFAULT_RATE_LIMIT_SETTINGS} when null
     */
    protected OpenShiftAiServiceSettings(@Nullable String modelId, URI uri, @Nullable RateLimitSettings rateLimitSettings) {
        this.modelId = modelId;
        this.uri = Objects.requireNonNull(uri);
        this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        // Version gating is done via supportsVersion(); this accessor should be unreachable.
        assert false : "should never be called when supportsVersion is used";
        return OpenShiftAiUtils.ML_INFERENCE_OPENSHIFT_AI_ADDED;
    }

    @Override
    public boolean supportsVersion(TransportVersion version) {
        return OpenShiftAiUtils.supportsOpenShiftAi(version);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // Keep the write order in sync with the StreamInput constructor above.
        out.writeOptionalString(modelId);
        out.writeString(uri.toString());
        rateLimitSettings.writeTo(out);
    }

    @Override
    public String modelId() {
        return this.modelId;
    }

    /**
     * Returns the URI of the OpenShift AI service.
     *
     * @return the URI of the service
     */
    public URI uri() {
        return this.uri;
    }

    /**
     * Returns the rate limit settings for the OpenShift AI service.
     *
     * @return the rate limit settings
     */
    public RateLimitSettings rateLimitSettings() {
        return this.rateLimitSettings;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        toXContentFragmentOfExposedFields(builder, params);
        builder.endObject();
        return builder;
    }

    @Override
    protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
        // modelId is optional, so it is only rendered when present.
        if (modelId != null) {
            builder.field(MODEL_ID, modelId);
        }
        builder.field(URL, uri.toString());
        rateLimitSettings.toXContent(builder, params);
        return builder;
    }

    /**
     * Creates an instance of T from the provided map using the given factory function.
     * @param map the map containing the service settings
     * @param context the context for parsing configuration settings
     * @param factory the factory function to create an instance of T
     * @return an instance of T
     * @param <T> the type of {@link OpenShiftAiServiceSettings} to create
     * @throws ValidationException if any of the common settings are missing or invalid
     */
    protected static <T extends OpenShiftAiServiceSettings> T fromMap(
        Map<String, Object> map,
        ConfigurationParseContext context,
        Function<OpenShiftAiCommonServiceSettings, T> factory
    ) {
        var validationException = new ValidationException();
        var commonServiceSettings = extractOpenShiftAiCommonServiceSettings(map, context, validationException);

        if (validationException.validationErrors().isEmpty() == false) {
            throw validationException;
        }

        return factory.apply(commonServiceSettings);
    }

    /**
     * Extracts common OpenShift AI service settings from the provided map.
     * Errors are accumulated in {@code validationException} rather than thrown, so callers
     * can collect additional task-specific validation failures before failing.
     * @param map the map containing the service settings
     * @param context the context for parsing configuration settings
     * @param validationException the validation exception to collect validation errors
     * @return an instance of {@link OpenShiftAiCommonServiceSettings}
     */
    protected static OpenShiftAiCommonServiceSettings extractOpenShiftAiCommonServiceSettings(
        Map<String, Object> map,
        ConfigurationParseContext context,
        ValidationException validationException
    ) {
        var model = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
        var uri = extractUri(map, URL, validationException);
        var rateLimitSettings = RateLimitSettings.of(
            map,
            DEFAULT_RATE_LIMIT_SETTINGS,
            validationException,
            OpenShiftAiService.NAME,
            context
        );
        return new OpenShiftAiCommonServiceSettings(model, uri, rateLimitSettings);
    }

    /** Intermediate holder for the settings shared by every OpenShift AI task type. */
    protected record OpenShiftAiCommonServiceSettings(String model, URI uri, RateLimitSettings rateLimitSettings) {}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiUtils.java
new file mode 100644
index 0000000000000..d339c47b52c8b
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiUtils.java
@@ -0,0 +1,37 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.openshiftai;

import org.elasticsearch.TransportVersion;

/**
 * Utility class for OpenShift AI related version checks.
+ */ +public final class OpenShiftAiUtils { + + /** + * TransportVersion indicating when OpenShift AI features were added. + */ + public static final TransportVersion ML_INFERENCE_OPENSHIFT_AI_ADDED = TransportVersion.fromName("ml_inference_openshift_ai_added"); + + /** + * Checks if the given TransportVersion supports OpenShift AI features. + * + * @param version the TransportVersion to check + * @return true if OpenShift AI features are supported, false otherwise + */ + public static boolean supportsOpenShiftAi(TransportVersion version) { + return version.supports(ML_INFERENCE_OPENSHIFT_AI_ADDED); + } + + /** + * Private constructor to prevent instantiation. + */ + private OpenShiftAiUtils() {} + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreator.java new file mode 100644 index 0000000000000..52319d12f1112 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreator.java @@ -0,0 +1,142 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
 */

package org.elasticsearch.xpack.inference.services.openshiftai.action;

import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIResponseHandler;
import org.elasticsearch.xpack.inference.services.jinaai.response.JinaAIRerankResponseEntity;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsResponseHandler;
import org.elasticsearch.xpack.inference.services.openshiftai.request.completion.OpenShiftAiChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.openshiftai.request.embeddings.OpenShiftAiEmbeddingsRequest;
// NOTE(review): package segment "rarank" looks like a typo of "rerank" — confirm against the request package layout.
import org.elasticsearch.xpack.inference.services.openshiftai.request.rarank.OpenShiftAiRerankRequest;
import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankModel;

import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.common.Truncator.truncate;

/**
 * Creates executable actions for OpenShift AI models.
 * This class implements the {@link OpenShiftAiActionVisitor} interface to provide specific action creation methods.
 */
public class OpenShiftAiActionCreator implements OpenShiftAiActionVisitor {

    private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE =
        "Failed to send OpenShift AI %s request from inference entity id [%s]";
    private static final String COMPLETION_ERROR_PREFIX = "OpenShift AI completions";
    // Role given to plain completion inputs when they are wrapped into a unified chat request.
    private static final String USER_ROLE = "user";

    // Embeddings and completion responses reuse the OpenAI wire format parsers.
    private static final ResponseHandler EMBEDDINGS_HANDLER = new OpenShiftAiEmbeddingsResponseHandler(
        "OpenShift AI text embedding",
        OpenAiEmbeddingsResponseEntity::fromResponse
    );
    private static final ResponseHandler COMPLETION_HANDLER = new OpenShiftAiCompletionResponseHandler(
        "OpenShift AI completion",
        OpenAiChatCompletionResponseEntity::fromResponse
    );
    // OpenShift AI Rerank task uses the same response format as JinaAI, therefore we can reuse the JinaAIResponseHandler
    private static final ResponseHandler RERANK_HANDLER = new JinaAIResponseHandler(
        "OpenShift AI rerank",
        (request, response) -> JinaAIRerankResponseEntity.fromResponse(response)
    );

    private final Sender sender;
    private final ServiceComponents serviceComponents;

    /**
     * Constructs a new OpenShiftAiActionCreator.
     *
     * @param sender the sender to use for executing actions
     * @param serviceComponents the service components providing necessary services
     */
    public OpenShiftAiActionCreator(Sender sender, ServiceComponents serviceComponents) {
        this.sender = Objects.requireNonNull(sender);
        this.serviceComponents = Objects.requireNonNull(serviceComponents);
    }

    /**
     * Creates an embeddings action; inputs are truncated to the model's max input tokens
     * before the request is built.
     */
    @Override
    public ExecutableAction create(OpenShiftAiEmbeddingsModel model) {
        var manager = new GenericRequestManager<>(
            serviceComponents.threadPool(),
            model,
            EMBEDDINGS_HANDLER,
            embeddingsInput -> new OpenShiftAiEmbeddingsRequest(
                serviceComponents.truncator(),
                truncate(embeddingsInput.getTextInputs(), model.getServiceSettings().maxInputTokens()),
                model
            ),
            EmbeddingsInput.class
        );

        var errorMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, model.getInferenceEntityId());
        return new SenderExecutableAction(sender, manager, errorMessage);
    }

    /**
     * Creates a chat completion action; plain inputs are wrapped as a single
     * {@code user}-role message via {@link UnifiedChatInput}.
     */
    @Override
    public ExecutableAction create(OpenShiftAiChatCompletionModel model) {
        var manager = new GenericRequestManager<>(
            serviceComponents.threadPool(),
            model,
            COMPLETION_HANDLER,
            inputs -> new OpenShiftAiChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
            ChatCompletionInput.class
        );

        var errorMessage = buildErrorMessage(TaskType.COMPLETION, model.getInferenceEntityId());
        // Completion only accepts a single input, hence the single-input action wrapper.
        return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX);
    }

    /**
     * Creates a rerank action; per-request task settings may override the model's settings
     * via {@link OpenShiftAiRerankModel#of}.
     */
    @Override
    public ExecutableAction create(OpenShiftAiRerankModel model, Map<String, Object> taskSettings) {
        var overriddenModel = OpenShiftAiRerankModel.of(model, taskSettings);
        var manager = new GenericRequestManager<>(
            serviceComponents.threadPool(),
            overriddenModel,
            RERANK_HANDLER,
            inputs -> new OpenShiftAiRerankRequest(
                inputs.getQuery(),
                inputs.getChunks(),
                inputs.getReturnDocuments(),
                inputs.getTopN(),
                overriddenModel
            ),
            QueryAndDocsInputs.class
        );
        var errorMessage = buildErrorMessage(TaskType.RERANK, overriddenModel.getInferenceEntityId());
        return new SenderExecutableAction(sender, manager, errorMessage);
    }

    /**
     * Builds an error message for failed requests.
     *
     * @param requestType the type of request that failed
     * @param inferenceId the inference entity ID associated with the request
     * @return a formatted error message
     */
    public static String buildErrorMessage(TaskType requestType, String inferenceId) {
        return format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, requestType.toString(), inferenceId);
    }
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionVisitor.java
new file mode 100644
index 0000000000000..fd06807bcb174
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionVisitor.java
@@ -0,0 +1,47 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.openshiftai.action;

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankModel;

import java.util.Map;

/**
 * Visitor interface for creating executable actions for OpenShift AI inference models.
 * This interface defines methods to create actions for embeddings, reranking and completion models.
 */
public interface OpenShiftAiActionVisitor {

    /**
     * Creates an executable action for the given OpenShift AI embeddings model.
     *
     * @param model The OpenShift AI embeddings model.
     * @return An executable action for the embeddings model.
     */
    ExecutableAction create(OpenShiftAiEmbeddingsModel model);

    /**
     * Creates an executable action for the given OpenShift AI chat completion model.
     *
     * @param model The OpenShift AI chat completion model.
     * @return An executable action for the chat completion model.
     */
    ExecutableAction create(OpenShiftAiChatCompletionModel model);

    /**
     * Creates an executable action for the given OpenShift AI rerank model.
     *
     * @param model The OpenShift AI rerank model.
     * @param taskSettings The task settings for the rerank action; may override the model's settings.
     * @return An executable action for the rerank model.
     */
    ExecutableAction create(OpenShiftAiRerankModel model, Map<String, Object> taskSettings);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java
new file mode 100644
index 0000000000000..181c25c5c04ec
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java
@@ -0,0 +1,122 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.openshiftai.completion;

import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.SecretSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiModel;
import org.elasticsearch.xpack.inference.services.openshiftai.action.OpenShiftAiActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

import java.util.Map;
import java.util.Objects;

/**
 * Represents an OpenShift AI chat completion model.
 * This class extends the {@link OpenShiftAiModel} and provides specific configurations for chat completion tasks.
 */
public class OpenShiftAiChatCompletionModel extends OpenShiftAiModel {

    /**
     * Constructor for creating an OpenShiftAiChatCompletionModel with specified parameters.
     * Service settings and secrets are parsed from their raw map representations.
     * @param inferenceEntityId the unique identifier for the inference entity
     * @param taskType the type of task this model is designed for
     * @param service the name of the inference service
     * @param serviceSettings the settings for the inference service, specific to chat completion
     * @param secrets the secret settings for the model, such as API keys or tokens
     * @param context the context for parsing configuration settings
     */
    public OpenShiftAiChatCompletionModel(
        String inferenceEntityId,
        TaskType taskType,
        String service,
        Map<String, Object> serviceSettings,
        Map<String, Object> secrets,
        ConfigurationParseContext context
    ) {
        this(
            inferenceEntityId,
            taskType,
            service,
            OpenShiftAiChatCompletionServiceSettings.fromMap(serviceSettings, context),
            DefaultSecretSettings.fromMap(secrets)
        );
    }

    /**
     * Constructor for creating an OpenShiftAiChatCompletionModel with specified parameters.
     * Chat completion has no task settings, so {@link EmptyTaskSettings#INSTANCE} is always used.
     * @param inferenceEntityId the unique identifier for the inference entity
     * @param taskType the type of task this model is designed for
     * @param service the name of the inference service
     * @param serviceSettings the settings for the inference service, specific to chat completion
     * @param secrets the secret settings for the model
     */
    public OpenShiftAiChatCompletionModel(
        String inferenceEntityId,
        TaskType taskType,
        String service,
        OpenShiftAiChatCompletionServiceSettings serviceSettings,
        SecretSettings secrets
    ) {
        super(
            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE),
            new ModelSecrets(secrets)
        );
    }

    /**
     * Factory method to create an OpenShiftAiChatCompletionModel with potential overrides from a UnifiedCompletionRequest.
     * If the request does not specify a model ID, the original model is returned.
     *
     * @param model the original OpenShiftAiChatCompletionModel
     * @param modelId the model ID specified in the request, which may override the original model's ID
     * @return a new OpenShiftAiChatCompletionModel with the overridden model ID, or the original model
     *         when no override is specified or the override matches the original's ID
     */
    public static OpenShiftAiChatCompletionModel of(OpenShiftAiChatCompletionModel model, String modelId) {
        if (modelId == null || Objects.equals(model.getServiceSettings().modelId(), modelId)) {
            // If no model ID is specified in the request, or if it matches the original model's ID, return the original model.
            return model;
        }

        var originalModelServiceSettings = model.getServiceSettings();
        var overriddenServiceSettings = new OpenShiftAiChatCompletionServiceSettings(
            modelId,
            originalModelServiceSettings.uri(),
            originalModelServiceSettings.rateLimitSettings()
        );

        return new OpenShiftAiChatCompletionModel(
            model.getInferenceEntityId(),
            model.getTaskType(),
            model.getConfigurations().getService(),
            overriddenServiceSettings,
            model.getSecretSettings()
        );
    }

    /**
     * Returns the service settings specific to OpenShift AI chat completion.
     *
     * @return the OpenShiftAiChatCompletionServiceSettings associated with this model
     */
    @Override
    public OpenShiftAiChatCompletionServiceSettings getServiceSettings() {
        return (OpenShiftAiChatCompletionServiceSettings) super.getServiceSettings();
    }

    /**
     * Dispatches to the visitor's chat completion action factory.
     */
    @Override
    public ExecutableAction accept(OpenShiftAiActionVisitor creator, Map<String, Object> taskSettings) {
        // Chat completion models do not have task settings, so we ignore the taskSettings parameter.
        return creator.create(this);
    }
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandler.java
new file mode 100644
index 0000000000000..00d313709035b
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandler.java
@@ -0,0 +1,28 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.openshiftai.completion;

import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorParserContract;
import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorResponseUtils;
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;

/**
 * Handles streaming chat completion responses and error parsing for OpenShift AI inference endpoints.
 * Adapts the OpenAI handler to support OpenShift AI's error schema.
 */
public class OpenShiftAiChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {

    // Error-type label attached to parsed OpenShift AI error responses.
    private static final String OPENSHIFT_AI_ERROR = "openshift_ai_error";
    private static final UnifiedChatCompletionErrorParserContract OPENSHIFT_AI_ERROR_PARSER = UnifiedChatCompletionErrorResponseUtils
        .createErrorParserWithStringify(OPENSHIFT_AI_ERROR);

    /**
     * Constructs a handler that parses successful responses with {@code parseFunction} and
     * error bodies with the OpenShift AI error parser (used for both streaming and
     * non-streaming error paths).
     *
     * @param requestType the type of request being handled, used in error reporting
     * @param parseFunction the function to parse successful responses
     */
    public OpenShiftAiChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
        super(requestType, parseFunction, OPENSHIFT_AI_ERROR_PARSER::parse, OPENSHIFT_AI_ERROR_PARSER);
    }
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java
new file mode 100644
index 0000000000000..9ab56ea9ec208
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java
@@ -0,0 +1,102 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.services.openshiftai.completion; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; + +/** + * Represents the settings for an OpenShift AI chat completion service. + * This class encapsulates the model ID, URI, and rate limit settings for the OpenShift AI chat completion service. + */ +public class OpenShiftAiChatCompletionServiceSettings extends OpenShiftAiServiceSettings { + public static final String NAME = "openshift_ai_completion_service_settings"; + + /** + * Creates a new instance of OpenShiftAiChatCompletionServiceSettings from a map of settings. + * + * @param map the map containing the service settings + * @param context the context for parsing configuration settings + * @return a new instance of OpenShiftAiChatCompletionServiceSettings + * @throws ValidationException if required fields are missing or invalid + */ + public static OpenShiftAiChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { + return fromMap( + map, + context, + commonServiceSettings -> new OpenShiftAiChatCompletionServiceSettings( + commonServiceSettings.model(), + commonServiceSettings.uri(), + commonServiceSettings.rateLimitSettings() + ) + ); + } + + /** + * Constructs a new OpenShiftAiChatCompletionServiceSettings from a StreamInput. 
+ * + * @param in the StreamInput to read from + * @throws IOException if an I/O error occurs during reading + */ + public OpenShiftAiChatCompletionServiceSettings(StreamInput in) throws IOException { + super(in); + } + + /** + * Constructs a new OpenShiftAiChatCompletionServiceSettings. + * + * @param modelId the ID of the model + * @param uri the URI of the service + * @param rateLimitSettings the rate limit settings for the service + */ + public OpenShiftAiChatCompletionServiceSettings(@Nullable String modelId, URI uri, @Nullable RateLimitSettings rateLimitSettings) { + super(modelId, uri, rateLimitSettings); + } + + /** + * Constructs a new OpenShiftAiChatCompletionServiceSettings. + * + * @param modelId the ID of the model + * @param url the URL of the OpenShift AI service + * @param rateLimitSettings the rate limit settings for the service + */ + public OpenShiftAiChatCompletionServiceSettings(@Nullable String modelId, String url, @Nullable RateLimitSettings rateLimitSettings) { + super(modelId, createUri(url), rateLimitSettings); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OpenShiftAiChatCompletionServiceSettings that = (OpenShiftAiChatCompletionServiceSettings) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(uri, that.uri) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, uri, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiCompletionResponseHandler.java new file mode 100644 index 0000000000000..2d32d146aec43 --- /dev/null +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiCompletionResponseHandler.java
@@ -0,0 +1,28 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.openshiftai.completion;

import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;

/**
 * Handles non-streaming completion responses for OpenShift AI models, extending the OpenAI completion response handler.
 */
public class OpenShiftAiCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {

    /**
     * Constructs an OpenShiftAiCompletionResponseHandler with the specified request type and response parser.
     *
     * @param requestType The type of request being handled (e.g., "Openshift AI completions").
     * @param parseFunction The function to parse the response.
     */
    public OpenShiftAiCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
        // Error bodies are parsed with the generic ErrorResponse parser (OpenAI-compatible shape).
        super(requestType, parseFunction, ErrorResponse::fromResponse);
    }
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java
new file mode 100644
index 0000000000000..796cd60a932a5
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java
@@ -0,0 +1,87 @@
/*
 * Copyright Elasticsearch B.V.
and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.openshiftai.embeddings;

import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.SecretSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiModel;
import org.elasticsearch.xpack.inference.services.openshiftai.action.OpenShiftAiActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

import java.util.Map;

/**
 * Represents an OpenShift AI embeddings model for inference.
 * This class extends the {@link OpenShiftAiModel} and provides specific configurations and settings for embeddings tasks.
 */
public class OpenShiftAiEmbeddingsModel extends OpenShiftAiModel {

    /**
     * Constructor that parses service settings and secrets from their raw map representations.
     *
     * @param inferenceEntityId the unique identifier for the inference entity
     * @param taskType the type of task this model is designed for
     * @param service the name of the inference service
     * @param serviceSettings the raw service settings map
     * @param chunkingSettings the chunking settings for processing input data
     * @param secrets the raw secret settings map
     * @param context the context for parsing configuration settings
     */
    public OpenShiftAiEmbeddingsModel(
        String inferenceEntityId,
        TaskType taskType,
        String service,
        Map<String, Object> serviceSettings,
        ChunkingSettings chunkingSettings,
        Map<String, Object> secrets,
        ConfigurationParseContext context
    ) {
        this(
            inferenceEntityId,
            taskType,
            service,
            OpenShiftAiEmbeddingsServiceSettings.fromMap(serviceSettings, context),
            chunkingSettings,
            DefaultSecretSettings.fromMap(secrets)
        );
    }

    /**
     * Copy constructor that replaces the service settings of an existing model.
     *
     * @param model the model to copy
     * @param serviceSettings the replacement service settings
     */
    public OpenShiftAiEmbeddingsModel(OpenShiftAiEmbeddingsModel model, OpenShiftAiEmbeddingsServiceSettings serviceSettings) {
        super(model, serviceSettings);
    }

    /**
     * Constructor for creating an OpenShiftAiEmbeddingsModel with specified parameters.
     * Embeddings have no task settings, so {@link EmptyTaskSettings#INSTANCE} is always used.
     *
     * @param inferenceEntityId the unique identifier for the inference entity
     * @param taskType the type of task this model is designed for
     * @param service the name of the inference service
     * @param serviceSettings the settings for the inference service, specific to embeddings
     * @param chunkingSettings the chunking settings for processing input data
     * @param secrets the secret settings for the model, such as API keys or tokens
     */
    public OpenShiftAiEmbeddingsModel(
        String inferenceEntityId,
        TaskType taskType,
        String service,
        OpenShiftAiEmbeddingsServiceSettings serviceSettings,
        ChunkingSettings chunkingSettings,
        SecretSettings secrets
    ) {
        super(
            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings),
            new ModelSecrets(secrets)
        );
    }

    @Override
    public OpenShiftAiEmbeddingsServiceSettings getServiceSettings() {
        return (OpenShiftAiEmbeddingsServiceSettings) super.getServiceSettings();
    }

    /**
     * Dispatches to the visitor's embeddings action factory.
     */
    @Override
    public ExecutableAction accept(OpenShiftAiActionVisitor creator, Map<String, Object> taskSettings) {
        // Embeddings models do not have task settings, so we ignore the taskSettings parameter.
        return creator.create(this);
    }
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsResponseHandler.java
new file mode 100644
index 0000000000000..ad2db4571a47a
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsResponseHandler.java
@@ -0,0 +1,29 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements.
Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.openshiftai.embeddings;

import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.services.openai.OpenAiResponseHandler;

/**
 * Handles responses for OpenShift AI embeddings requests, parsing the response and handling errors.
 * This class extends OpenAiResponseHandler to provide specific functionality for OpenShift AI embeddings.
 */
public class OpenShiftAiEmbeddingsResponseHandler extends OpenAiResponseHandler {

    /**
     * Constructs a new OpenShiftAiEmbeddingsResponseHandler with the specified request type and response parser.
     *
     * @param requestType the type of request this handler will process
     * @param parseFunction the function to parse the response
     */
    public OpenShiftAiEmbeddingsResponseHandler(String requestType, ResponseParser parseFunction) {
        // Final boolean flag passed to the parent handler is false here; embeddings responses are not streamed.
        super(requestType, parseFunction, ErrorResponse::fromResponse, false);
    }
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java
new file mode 100644
index 0000000000000..b9343179db185
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java
@@ -0,0 +1,245 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
package org.elasticsearch.xpack.inference.services.openshiftai.embeddings;

import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.InferenceUtils;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiServiceSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.io.IOException;
import java.net.URI;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS_SET_BY_USER;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;

/**
 * Settings for the OpenShift AI embeddings service.
 * Encapsulates the configuration required to use OpenShift AI for generating embeddings:
 * the common model/URI/rate-limit settings plus embedding-specific fields
 * (dimensions, similarity, max input tokens, and whether dimensions were user-supplied).
 */
public class OpenShiftAiEmbeddingsServiceSettings extends OpenShiftAiServiceSettings {
    public static final String NAME = "openshift_ai_embeddings_service_settings";

    private final Integer dimensions;
    private final SimilarityMeasure similarity;
    private final Integer maxInputTokens;
    // Tracks whether the user explicitly supplied `dimensions`; only then is the field
    // sent on the wire request (see OpenShiftAiEmbeddingsRequestEntity).
    private final Boolean dimensionsSetByUser;

    /**
     * Creates a new instance of OpenShiftAiEmbeddingsServiceSettings from a map of settings.
     *
     * @param map the map containing the settings
     * @param context the context for parsing configuration settings
     * @return a new instance of OpenShiftAiEmbeddingsServiceSettings
     * @throws ValidationException if any required fields are missing or invalid
     */
    public static OpenShiftAiEmbeddingsServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
        var validationException = new ValidationException();
        var commonServiceSettings = extractOpenShiftAiCommonServiceSettings(map, context, validationException);
        var dimensions = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException);
        var similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
        var maxInputTokens = extractOptionalPositiveInteger(
            map,
            MAX_INPUT_TOKENS,
            ModelConfigurations.SERVICE_SETTINGS,
            validationException
        );
        var dimensionsSetByUser = extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException);
        switch (context) {
            case REQUEST -> {
                // dimensions_set_by_user is an internal bookkeeping field; users may not set it directly.
                if (dimensionsSetByUser != null) {
                    validationException.addValidationError(
                        ServiceUtils.invalidSettingError(DIMENSIONS_SET_BY_USER, ModelConfigurations.SERVICE_SETTINGS)
                    );
                }
                dimensionsSetByUser = dimensions != null;
            }
            case PERSISTENT -> {
                // Persisted settings must always carry the flag written at creation time.
                if (dimensionsSetByUser == null) {
                    validationException.addValidationError(
                        InferenceUtils.missingSettingErrorMsg(DIMENSIONS_SET_BY_USER, ModelConfigurations.SERVICE_SETTINGS)
                    );
                }
            }
        }
        if (validationException.validationErrors().isEmpty() == false) {
            throw validationException;
        }

        return new OpenShiftAiEmbeddingsServiceSettings(
            commonServiceSettings.model(),
            commonServiceSettings.uri(),
            dimensions,
            similarity,
            maxInputTokens,
            commonServiceSettings.rateLimitSettings(),
            dimensionsSetByUser
        );
    }

    /**
     * Constructs a new OpenShiftAiEmbeddingsServiceSettings from a StreamInput.
     *
     * @param in the StreamInput to read from
     * @throws IOException if an I/O error occurs during reading
     */
    public OpenShiftAiEmbeddingsServiceSettings(StreamInput in) throws IOException {
        super(in);
        this.dimensions = in.readOptionalVInt();
        this.similarity = in.readOptionalEnum(SimilarityMeasure.class);
        this.maxInputTokens = in.readOptionalVInt();
        this.dimensionsSetByUser = in.readBoolean();
    }

    /**
     * Constructs a new OpenShiftAiEmbeddingsServiceSettings with the specified parameters.
     *
     * @param modelId the identifier for the model, can be null
     * @param uri the URI of the OpenShift AI service
     * @param dimensions the number of dimensions for the embeddings, can be null
     * @param similarity the similarity measure to use, can be null
     * @param maxInputTokens the maximum number of input tokens, can be null
     * @param rateLimitSettings the rate limit settings for the service, can be null
     * @param dimensionsSetByUser indicates if dimensions were set by the user; must not be null
     */
    public OpenShiftAiEmbeddingsServiceSettings(
        @Nullable String modelId,
        URI uri,
        @Nullable Integer dimensions,
        @Nullable SimilarityMeasure similarity,
        @Nullable Integer maxInputTokens,
        @Nullable RateLimitSettings rateLimitSettings,
        Boolean dimensionsSetByUser
    ) {
        super(modelId, uri, rateLimitSettings);
        this.dimensions = dimensions;
        this.similarity = similarity;
        this.maxInputTokens = maxInputTokens;
        this.dimensionsSetByUser = Objects.requireNonNull(dimensionsSetByUser);
    }

    /**
     * Convenience constructor that accepts the service URL as a string.
     *
     * @param modelId the identifier for the model, can be null (fixed: now annotated
     *                {@code @Nullable} consistently with the URI-based constructor)
     * @param url the URL of the OpenShift AI service
     * @param dimensions the number of dimensions for the embeddings, can be null
     * @param similarity the similarity measure to use, can be null
     * @param maxInputTokens the maximum number of input tokens, can be null
     * @param rateLimitSettings the rate limit settings for the service, can be null
     * @param dimensionsSetByUser indicates if dimensions were set by the user; must not be null
     */
    public OpenShiftAiEmbeddingsServiceSettings(
        @Nullable String modelId,
        String url,
        @Nullable Integer dimensions,
        @Nullable SimilarityMeasure similarity,
        @Nullable Integer maxInputTokens,
        @Nullable RateLimitSettings rateLimitSettings,
        Boolean dimensionsSetByUser
    ) {
        this(modelId, createUri(url), dimensions, similarity, maxInputTokens, rateLimitSettings, dimensionsSetByUser);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public Integer dimensions() {
        return this.dimensions;
    }

    @Override
    public Boolean dimensionsSetByUser() {
        return this.dimensionsSetByUser;
    }

    @Override
    public SimilarityMeasure similarity() {
        return this.similarity;
    }

    @Override
    public DenseVectorFieldMapper.ElementType elementType() {
        return DenseVectorFieldMapper.ElementType.FLOAT;
    }

    /**
     * Returns the maximum number of input tokens allowed for this service.
     *
     * @return the maximum input tokens, or null if not specified
     */
    public Integer maxInputTokens() {
        return this.maxInputTokens;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeOptionalVInt(dimensions);
        // Similarity is translated to a value the receiving node's transport version understands.
        out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion()));
        out.writeOptionalVInt(maxInputTokens);
        out.writeBoolean(dimensionsSetByUser);
    }

    @Override
    protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
        super.toXContentFragmentOfExposedFields(builder, params);

        if (dimensions != null) {
            builder.field(DIMENSIONS, dimensions);
        }
        if (similarity != null) {
            builder.field(SIMILARITY, similarity);
        }
        if (maxInputTokens != null) {
            builder.field(MAX_INPUT_TOKENS, maxInputTokens);
        }
        builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser);
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        OpenShiftAiEmbeddingsServiceSettings that = (OpenShiftAiEmbeddingsServiceSettings) o;
        return Objects.equals(modelId, that.modelId)
            && Objects.equals(uri, that.uri)
            && Objects.equals(dimensions, that.dimensions)
            && Objects.equals(maxInputTokens, that.maxInputTokens)
            && Objects.equals(similarity, that.similarity)
            && Objects.equals(rateLimitSettings, that.rateLimitSettings)
            && Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser);
    }

    @Override
    public int hashCode() {
        return Objects.hash(modelId, uri, dimensions, maxInputTokens, similarity, rateLimitSettings, dimensionsSetByUser);
    }
}
package org.elasticsearch.xpack.inference.services.openshiftai.request.completion;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModel;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;

/**
 * Builds the HTTP request for an OpenShift AI chat completion call:
 * a JSON POST to the model's configured URI with a bearer-token auth header.
 */
public class OpenShiftAiChatCompletionRequest implements Request {

    private final OpenShiftAiChatCompletionModel model;
    private final UnifiedChatInput chatInput;

    /**
     * @param chatInput the messages and parameters for the completion request; must not be null
     * @param model     the chat completion model configuration; must not be null
     */
    public OpenShiftAiChatCompletionRequest(UnifiedChatInput chatInput, OpenShiftAiChatCompletionModel model) {
        this.chatInput = Objects.requireNonNull(chatInput);
        this.model = Objects.requireNonNull(model);
    }

    @Override
    public HttpRequest createHttpRequest() {
        var request = new HttpPost(model.getServiceSettings().uri());

        // Serialize the unified chat input (plus the configured model id) as the JSON body.
        var body = Strings.toString(new OpenShiftAiChatCompletionRequestEntity(chatInput, model.getServiceSettings().modelId()));
        request.setEntity(new ByteArrayEntity(body.getBytes(StandardCharsets.UTF_8)));

        request.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
        request.setHeader(createAuthBearerHeader(model.getSecretSettings().apiKey()));

        return new HttpRequest(request, getInferenceEntityId());
    }

    @Override
    public URI getURI() {
        return model.getServiceSettings().uri();
    }

    @Override
    public Request truncate() {
        // Chat completion inputs are never truncated.
        return this;
    }

    @Override
    public boolean[] getTruncationInfo() {
        // Chat completion inputs are never truncated.
        return null;
    }

    @Override
    public String getInferenceEntityId() {
        return model.getInferenceEntityId();
    }

    @Override
    public boolean isStreaming() {
        return chatInput.stream();
    }
}
package org.elasticsearch.xpack.inference.services.openshiftai.request.completion;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;

import java.io.IOException;

/**
 * Request body for OpenShift AI chat completion calls.
 * Wraps the shared {@link UnifiedChatCompletionRequestEntity} and injects the configured model id.
 */
public class OpenShiftAiChatCompletionRequestEntity implements ToXContentObject {

    private final String modelId;
    private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;

    /**
     * @param unifiedChatInput the unified chat input to serialize
     * @param modelId          the model id to embed in the request body, may be null
     */
    public OpenShiftAiChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, @Nullable String modelId) {
        this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
        this.modelId = modelId;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        // Delegate serialization to the unified entity, overriding the model id and
        // using the max_tokens / skip-stream-options parameter variant.
        unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokensAndSkipStreamOptionsField(modelId, params));
        builder.endObject();
        return builder;
    }
}
package org.elasticsearch.xpack.inference.services.openshiftai.request.embeddings;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsModel;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;

/**
 * Builds the HTTP request for an OpenShift AI embeddings call:
 * a JSON POST to the model's configured URI with a bearer-token auth header.
 * Supports input truncation via {@link Truncator}.
 */
public class OpenShiftAiEmbeddingsRequest implements Request {
    private final OpenShiftAiEmbeddingsModel model;
    private final Truncator.TruncationResult truncationResult;
    private final Truncator truncator;

    /**
     * Constructs a new OpenShiftAiEmbeddingsRequest with the specified truncator, input, and model.
     *
     * @param truncator the truncator used to shorten the input when {@link #truncate()} is called; must not be null
     * @param input     the (possibly already truncated) input; must not be null
     * @param model     the embeddings model configuration; must not be null
     */
    public OpenShiftAiEmbeddingsRequest(Truncator truncator, Truncator.TruncationResult input, OpenShiftAiEmbeddingsModel model) {
        // Fail fast on null arguments, consistent with OpenShiftAiChatCompletionRequest;
        // previously a null would only surface later inside createHttpRequest().
        this.model = Objects.requireNonNull(model);
        this.truncator = Objects.requireNonNull(truncator);
        this.truncationResult = Objects.requireNonNull(input);
    }

    @Override
    public HttpRequest createHttpRequest() {
        HttpPost httpPost = new HttpPost(model.getServiceSettings().uri());

        // dimensions is only serialized when the user explicitly configured it
        // (see OpenShiftAiEmbeddingsRequestEntity).
        ByteArrayEntity byteEntity = new ByteArrayEntity(
            Strings.toString(
                new OpenShiftAiEmbeddingsRequestEntity(
                    truncationResult.input(),
                    model.getServiceSettings().modelId(),
                    model.getServiceSettings().dimensions(),
                    model.getServiceSettings().dimensionsSetByUser()
                )
            ).getBytes(StandardCharsets.UTF_8)
        );
        httpPost.setEntity(byteEntity);

        httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
        httpPost.setHeader(createAuthBearerHeader(model.getSecretSettings().apiKey()));

        return new HttpRequest(httpPost, getInferenceEntityId());
    }

    @Override
    public URI getURI() {
        return model.getServiceSettings().uri();
    }

    @Override
    public Request truncate() {
        // Produce a new request with a further-truncated copy of the input.
        var truncatedInput = truncator.truncate(truncationResult.input());
        return new OpenShiftAiEmbeddingsRequest(truncator, truncatedInput, model);
    }

    @Override
    public boolean[] getTruncationInfo() {
        // Defensive copy so callers cannot mutate internal truncation state.
        return truncationResult.truncated().clone();
    }

    @Override
    public String getInferenceEntityId() {
        return model.getInferenceEntityId();
    }
}
package org.elasticsearch.xpack.inference.services.openshiftai.request.embeddings;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

/**
 * JSON request body for OpenShift AI embeddings calls.
 *
 * @param input               the texts to embed; must not be null
 * @param modelId             optional model id to include in the body
 * @param dimensions          optional embedding dimensions
 * @param dimensionsSetByUser only when true (and dimensions is non-null) is the
 *                            dimensions field written to the request
 */
public record OpenShiftAiEmbeddingsRequestEntity(
    List<String> input,
    @Nullable String modelId,
    @Nullable Integer dimensions,
    boolean dimensionsSetByUser
) implements ToXContentObject {

    private static final String INPUT_FIELD = "input";
    private static final String MODEL_FIELD = "model";
    private static final String DIMENSIONS_FIELD = "dimensions";

    public OpenShiftAiEmbeddingsRequestEntity {
        Objects.requireNonNull(input);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(INPUT_FIELD, input);
        if (modelId != null) {
            builder.field(MODEL_FIELD, modelId);
        }
        // Only send dimensions the user explicitly configured; service-discovered
        // dimensions are kept out of the request.
        if (dimensionsSetByUser && dimensions != null) {
            builder.field(DIMENSIONS_FIELD, dimensions);
        }
        builder.endObject();
        return builder;
    }
}
package org.elasticsearch.xpack.inference.services.openshiftai.request.rarank;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankTaskSettings;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

/**
 * JSON request body for OpenShift AI rerank calls.
 *
 * NOTE(review): the package segment "rarank" looks like a typo for "rerank", and the
 * "AI" casing differs from the "Ai" convention used by the sibling classes
 * (e.g. OpenShiftAiRerankRequest) — consider renaming both before merge.
 *
 * @param modelId         optional model id to include in the body
 * @param query           the query to rank documents against; must not be null
 * @param documents       the documents to be reranked; must not be null
 * @param returnDocuments optionally ask the service to echo documents back
 * @param topN            optionally limit the number of ranked results
 */
public record OpenShiftAIRerankRequestEntity(
    @Nullable String modelId,
    String query,
    List<String> documents,
    @Nullable Boolean returnDocuments,
    @Nullable Integer topN
) implements ToXContentObject {

    private static final String MODEL_FIELD = "model";
    private static final String DOCUMENTS_FIELD = "documents";
    private static final String QUERY_FIELD = "query";

    public OpenShiftAIRerankRequestEntity {
        Objects.requireNonNull(query);
        Objects.requireNonNull(documents);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();

        // The model field is optional for OpenShift AI.
        if (modelId != null) {
            builder.field(MODEL_FIELD, modelId);
        }
        builder.field(QUERY_FIELD, query);
        builder.field(DOCUMENTS_FIELD, documents);

        // Field names for the optional knobs are shared with the task settings class.
        if (topN != null) {
            builder.field(OpenShiftAiRerankTaskSettings.TOP_N, topN);
        }
        if (returnDocuments != null) {
            builder.field(OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments);
        }

        builder.endObject();
        return builder;
    }
}
package org.elasticsearch.xpack.inference.services.openshiftai.request.rarank;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankModel;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;

/**
 * Builds the HTTP request for an OpenShift AI rerank call:
 * a JSON POST to the model's configured URI with a bearer-token auth header.
 *
 * @param query           the query string to rerank against; must not be null
 * @param input           the documents to rerank; must not be null
 * @param returnDocuments per-request override for returning documents, may be null
 * @param topN            per-request override for result count, may be null
 * @param model           the rerank model configuration; must not be null
 */
public record OpenShiftAiRerankRequest(
    String query,
    List<String> input,
    @Nullable Boolean returnDocuments,
    @Nullable Integer topN,
    OpenShiftAiRerankModel model
) implements Request {

    public OpenShiftAiRerankRequest {
        Objects.requireNonNull(input);
        Objects.requireNonNull(query);
        Objects.requireNonNull(model);
    }

    @Override
    public HttpRequest createHttpRequest() {
        var request = new HttpPost(getURI());

        var body = Strings.toString(
            new OpenShiftAIRerankRequestEntity(model.getServiceSettings().modelId(), query, input, returnDocuments(), topN())
        );
        request.setEntity(new ByteArrayEntity(body.getBytes(StandardCharsets.UTF_8)));
        request.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
        request.setHeader(createAuthBearerHeader(model.getSecretSettings().apiKey()));

        return new HttpRequest(request, getInferenceEntityId());
    }

    @Override
    public String getInferenceEntityId() {
        return model.getInferenceEntityId();
    }

    @Override
    public URI getURI() {
        return model.getServiceSettings().uri();
    }

    /** The request-level topN takes precedence; otherwise fall back to the model's task settings. */
    public Integer topN() {
        return topN != null ? topN : model.getTaskSettings().getTopN();
    }

    /** The request-level returnDocuments takes precedence; otherwise fall back to the model's task settings. */
    public Boolean returnDocuments() {
        return returnDocuments != null ? returnDocuments : model.getTaskSettings().getReturnDocuments();
    }

    @Override
    public Request truncate() {
        // Truncation only applies to text embedding requests.
        return this;
    }

    @Override
    public boolean[] getTruncationInfo() {
        // Truncation only applies to text embedding requests.
        return null;
    }
}
+ */ +public class OpenShiftAiRerankModel extends OpenShiftAiModel { + + /** + * Creates a new {@link OpenShiftAiRerankModel} with updated task settings if they differ from the existing ones. + * @param model the existing OpenShift AI rerank model + * @param taskSettings the new task settings to apply + * @return a new {@link OpenShiftAiRerankModel} with updated task settings, or the original model if settings are unchanged + */ + public static OpenShiftAiRerankModel of(OpenShiftAiRerankModel model, Map taskSettings) { + var requestTaskSettings = OpenShiftAiRerankTaskSettings.fromMap(taskSettings); + if (requestTaskSettings.isEmpty() || requestTaskSettings.equals(model.getTaskSettings())) { + return model; + } + return new OpenShiftAiRerankModel(model, OpenShiftAiRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + } + + public OpenShiftAiRerankModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + OpenShiftAiRerankServiceSettings.fromMap(serviceSettings, context), + OpenShiftAiRerankTaskSettings.fromMap(taskSettings), + DefaultSecretSettings.fromMap(secrets) + ); + } + + // should only be used for testing + OpenShiftAiRerankModel( + String inferenceEntityId, + TaskType taskType, + String service, + OpenShiftAiRerankServiceSettings serviceSettings, + OpenShiftAiRerankTaskSettings taskSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secretSettings) + ); + } + + private OpenShiftAiRerankModel(OpenShiftAiRerankModel model, OpenShiftAiRerankTaskSettings taskSettings) { + super(model, taskSettings); + } + + @Override + public OpenShiftAiRerankServiceSettings getServiceSettings() { + return (OpenShiftAiRerankServiceSettings) 
super.getServiceSettings(); + } + + @Override + public OpenShiftAiRerankTaskSettings getTaskSettings() { + return (OpenShiftAiRerankTaskSettings) super.getTaskSettings(); + } + + @Override + public ExecutableAction accept(OpenShiftAiActionVisitor visitor, Map taskSettings) { + return visitor.create(this, taskSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java new file mode 100644 index 0000000000000..f6cf673da33ce --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java @@ -0,0 +1,102 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.rerank; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; + +/** + * Represents the settings for an OpenShift AI rerank service. + * This class encapsulates the model ID, URI, and rate limit settings for the OpenShift AI rerank service. 
+ */ +public class OpenShiftAiRerankServiceSettings extends OpenShiftAiServiceSettings { + public static final String NAME = "openshift_ai_rerank_service_settings"; + + /** + * Creates a new instance of OpenShiftAiRerankServiceSettings from a map of settings. + * + * @param map the map containing the service settings + * @param context the context for parsing configuration settings + * @return a new instance of OpenShiftAiRerankServiceSettings + * @throws ValidationException if required fields are missing or invalid + */ + public static OpenShiftAiRerankServiceSettings fromMap(Map map, ConfigurationParseContext context) { + return fromMap( + map, + context, + commonServiceSettings -> new OpenShiftAiRerankServiceSettings( + commonServiceSettings.model(), + commonServiceSettings.uri(), + commonServiceSettings.rateLimitSettings() + ) + ); + } + + /** + * Constructs a new OpenShiftAiRerankServiceSettings from a StreamInput. + * + * @param in the StreamInput to read from + * @throws IOException if an I/O error occurs during reading + */ + public OpenShiftAiRerankServiceSettings(StreamInput in) throws IOException { + super(in); + } + + /** + * Constructs a new OpenShiftAiRerankServiceSettings with the specified model ID, URI, and rate limit settings. + * + * @param modelId the ID of the model + * @param uri the URI of the service + * @param rateLimitSettings the rate limit settings for the service + */ + public OpenShiftAiRerankServiceSettings(@Nullable String modelId, URI uri, @Nullable RateLimitSettings rateLimitSettings) { + super(modelId, uri, rateLimitSettings); + } + + /** + * Constructs a new OpenShiftAiRerankServiceSettings with the specified model ID and URL. + * The rate limit settings will be set to the default value. 
+ * + * @param modelId the ID of the model + * @param url the URL of the service + */ + public OpenShiftAiRerankServiceSettings(@Nullable String modelId, String url, @Nullable RateLimitSettings rateLimitSettings) { + this(modelId, createUri(url), rateLimitSettings); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OpenShiftAiRerankServiceSettings that = (OpenShiftAiRerankServiceSettings) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(uri, that.uri) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, uri, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettings.java new file mode 100644 index 0000000000000..d0c5fbbf3a0f1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettings.java @@ -0,0 +1,188 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.openshiftai.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiUtils; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; + +/** + * Defines the task settings for the OpenShift AI rerank service. + */ +public class OpenShiftAiRerankTaskSettings implements TaskSettings { + + public static final String NAME = "openshift_ai_rerank_task_settings"; + public static final String RETURN_DOCUMENTS = "return_documents"; + public static final String TOP_N = "top_n"; + + private static final OpenShiftAiRerankTaskSettings EMPTY_SETTINGS = new OpenShiftAiRerankTaskSettings(null, null); + + /** + * Creates a new {@link OpenShiftAiRerankTaskSettings} from a map of settings. 
+ * @param map the map of settings + * @return a constructed {@link OpenShiftAiRerankTaskSettings} + * @throws ValidationException if any of the settings are invalid + */ + public static OpenShiftAiRerankTaskSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + if (map == null || map.isEmpty()) { + return EMPTY_SETTINGS; + } + + Boolean returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, validationException); + Integer topN = extractOptionalPositiveInteger(map, TOP_N, ModelConfigurations.TASK_SETTINGS, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return of(topN, returnDocuments); + } + + /** + * Creates a new {@link OpenShiftAiRerankTaskSettings} by using non-null fields from the request settings over the original settings. + * + * @param originalSettings the settings stored as part of the inference entity configuration + * @param requestTaskSettings the settings passed in within the task_settings field of the request + * @return a constructed {@link OpenShiftAiRerankTaskSettings} + */ + public static OpenShiftAiRerankTaskSettings of( + OpenShiftAiRerankTaskSettings originalSettings, + OpenShiftAiRerankTaskSettings requestTaskSettings + ) { + if (requestTaskSettings.isEmpty() || originalSettings.equals(requestTaskSettings)) { + return originalSettings; + } + return new OpenShiftAiRerankTaskSettings( + requestTaskSettings.getTopN() != null ? requestTaskSettings.getTopN() : originalSettings.getTopN(), + requestTaskSettings.getReturnDocuments() != null + ? requestTaskSettings.getReturnDocuments() + : originalSettings.getReturnDocuments() + ); + } + + /** + * Creates a new {@link OpenShiftAiRerankTaskSettings} with the specified settings. 
+ * + * @param topN the number of top documents to return + * @param returnDocuments whether to return the documents + * @return a constructed {@link OpenShiftAiRerankTaskSettings} + */ + public static OpenShiftAiRerankTaskSettings of(@Nullable Integer topN, @Nullable Boolean returnDocuments) { + if (topN == null && returnDocuments == null) { + return EMPTY_SETTINGS; + } + return new OpenShiftAiRerankTaskSettings(topN, returnDocuments); + } + + private final Integer topN; + private final Boolean returnDocuments; + + /** + * Constructs a new {@link OpenShiftAiRerankTaskSettings} by reading from a {@link StreamInput}. + * + * @param in the stream input to read from + * @throws IOException if an I/O error occurs + */ + public OpenShiftAiRerankTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalInt(), in.readOptionalBoolean()); + } + + /** + * Constructs a new {@link OpenShiftAiRerankTaskSettings} with the specified settings. + * + * @param topN the number of top documents to return + * @param doReturnDocuments whether to return the documents + */ + public OpenShiftAiRerankTaskSettings(@Nullable Integer topN, @Nullable Boolean doReturnDocuments) { + this.topN = topN; + this.returnDocuments = doReturnDocuments; + } + + @Override + public boolean isEmpty() { + return topN == null && returnDocuments == null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (topN != null) { + builder.field(TOP_N, topN); + } + if (returnDocuments != null) { + builder.field(RETURN_DOCUMENTS, returnDocuments); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + assert false : "should never be called when supportsVersion is used"; + return OpenShiftAiUtils.ML_INFERENCE_OPENSHIFT_AI_ADDED; + } + + @Override + public boolean 
supportsVersion(TransportVersion version) { + return OpenShiftAiUtils.supportsOpenShiftAi(version); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalInt(topN); + out.writeOptionalBoolean(returnDocuments); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OpenShiftAiRerankTaskSettings that = (OpenShiftAiRerankTaskSettings) o; + return Objects.equals(returnDocuments, that.returnDocuments) && Objects.equals(topN, that.topN); + } + + @Override + public int hashCode() { + return Objects.hash(returnDocuments, topN); + } + + public Integer getTopN() { + return topN; + } + + public Boolean getReturnDocuments() { + return returnDocuments; + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + OpenShiftAiRerankTaskSettings updatedSettings = OpenShiftAiRerankTaskSettings.fromMap(new HashMap<>(newSettings)); + return OpenShiftAiRerankTaskSettings.of(this, updatedSettings); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java new file mode 100644 index 0000000000000..131fa98f97563 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java @@ -0,0 +1,888 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.openshiftai; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionTestUtils; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.RerankingInferenceService; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import 
org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests; +import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModel; +import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.Arrays; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static 
org.elasticsearch.ExceptionsHelper.unwrapCause; +import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; +import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION; +import static org.elasticsearch.inference.TaskType.COMPLETION; +import static org.elasticsearch.inference.TaskType.RERANK; +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettingsTests.buildServiceSettingsMap; +import static org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModelTests.createChatCompletionModel; +import static org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionServiceSettingsTests.getServiceSettingsMap; +import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.hamcrest.Matchers.nullValue; +import static 
org.mockito.Mockito.mock; + +public class OpenShiftAiServiceTests extends AbstractInferenceServiceTests { + private static final String API_KEY_FIELD_NAME = "api_key"; + private static final String INPUT_FIELD_NAME = "input"; + private static final String MODEL_FIELD_NAME = "model"; + private static final String URL_VALUE = "http://www.abc.com"; + private static final String MODEL_VALUE = "some_model"; + private static final String ROLE_VALUE = "user"; + private static final String API_KEY_VALUE = "test_api_key"; + private static final String INFERENCE_ID_VALUE = "id"; + private static final int DIMENSIONS_VALUE = 1536; + private static final int MAX_INPUT_TOKENS_VALUE = 512; + private static final String FIRST_PART_OF_INPUT_VALUE = "abc"; + private static final String SECOND_PART_OF_INPUT_VALUE = "def"; + + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + public OpenShiftAiServiceTests() { + super(createTestConfiguration()); + } + + public static TestConfiguration createTestConfiguration() { + return new TestConfiguration.Builder( + new CommonConfig( + TaskType.TEXT_EMBEDDING, + TaskType.SPARSE_EMBEDDING, + EnumSet.of(TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION, RERANK) + ) { + + @Override + protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + return OpenShiftAiServiceTests.createService(threadPool, clientManager); + } + + @Override + protected Map createServiceSettingsMap(TaskType taskType) { + return OpenShiftAiServiceTests.createServiceSettingsMap(taskType); + } + + @Override + protected Map createTaskSettingsMap() { + return new HashMap<>(); + } + + @Override + protected Map createSecretSettingsMap() { + return OpenShiftAiServiceTests.createSecretSettingsMap(); + } + + @Override + protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { + OpenShiftAiServiceTests.assertModel(model, taskType, 
modelIncludesSecrets); + } + + @Override + protected EnumSet supportedStreamingTasks() { + return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); + } + } + ).enableUpdateModelTests(new UpdateModelConfiguration() { + @Override + protected OpenShiftAiEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarityMeasure) { + return createInternalEmbeddingModel(similarityMeasure); + } + }).build(); + } + + private static void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { + switch (taskType) { + case TEXT_EMBEDDING -> assertTextEmbeddingModel(model, modelIncludesSecrets); + case COMPLETION -> assertCompletionModel(model, modelIncludesSecrets); + case CHAT_COMPLETION -> assertChatCompletionModel(model, modelIncludesSecrets); + default -> fail(Strings.format("unexpected task type [%s]", taskType)); + } + } + + private static void assertTextEmbeddingModel(Model model, boolean modelIncludesSecrets) { + var openShiftAiModel = assertCommonModelFields(model, modelIncludesSecrets); + + assertThat(openShiftAiModel.getTaskType(), is(TaskType.TEXT_EMBEDDING)); + assertThat(model, instanceOf(OpenShiftAiEmbeddingsModel.class)); + var embeddingsModel = (OpenShiftAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(DIMENSIONS_VALUE)); + assertThat(embeddingsModel.getServiceSettings().similarity(), is(SimilarityMeasure.COSINE)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(MAX_INPUT_TOKENS_VALUE)); + } + + private static OpenShiftAiModel assertCommonModelFields(Model model, boolean modelIncludesSecrets) { + assertThat(model, instanceOf(OpenShiftAiModel.class)); + + var openShiftAiModel = (OpenShiftAiModel) model; + assertThat(openShiftAiModel.getServiceSettings().modelId(), is(MODEL_VALUE)); + assertThat(openShiftAiModel.getServiceSettings().uri.toString(), is(URL_VALUE)); + assertThat(openShiftAiModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + + if 
(modelIncludesSecrets) { + assertThat(openShiftAiModel.getSecretSettings().apiKey(), is(new SecureString(API_KEY_VALUE.toCharArray()))); + } + + return openShiftAiModel; + } + + private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) { + var openShiftAiModel = assertCommonModelFields(model, modelIncludesSecrets); + assertThat(openShiftAiModel.getTaskType(), is(TaskType.COMPLETION)); + } + + private static void assertChatCompletionModel(Model model, boolean modelIncludesSecrets) { + var openShiftAiModel = assertCommonModelFields(model, modelIncludesSecrets); + assertThat(openShiftAiModel.getTaskType(), is(TaskType.CHAT_COMPLETION)); + } + + public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + return new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); + } + + private static Map createServiceSettingsMap(TaskType taskType) { + Map settingsMap = new HashMap<>(Map.of(ServiceFields.URL, URL_VALUE, ServiceFields.MODEL_ID, MODEL_VALUE)); + + if (taskType == TaskType.TEXT_EMBEDDING) { + settingsMap.putAll( + Map.of( + ServiceFields.SIMILARITY, + SimilarityMeasure.COSINE.toString(), + ServiceFields.DIMENSIONS, + DIMENSIONS_VALUE, + ServiceFields.MAX_INPUT_TOKENS, + MAX_INPUT_TOKENS_VALUE + ) + ); + } + + return settingsMap; + } + + private static Map createSecretSettingsMap() { + return new HashMap<>(Map.of(API_KEY_FIELD_NAME, API_KEY_VALUE)); + } + + private static OpenShiftAiEmbeddingsModel createInternalEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure) { + return new OpenShiftAiEmbeddingsModel( + INFERENCE_ID_VALUE, + TaskType.TEXT_EMBEDDING, + OpenShiftAiService.NAME, + new OpenShiftAiEmbeddingsServiceSettings( + MODEL_VALUE, + URL_VALUE, + DIMENSIONS_VALUE, + similarityMeasure, + MAX_INPUT_TOKENS_VALUE, + new RateLimitSettings(10_000), + 
true + ), + createRandomChunkingSettings(), + new DefaultSecretSettings(new SecureString(API_KEY_VALUE.toCharArray())) + ); + } + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityExecutors()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + var chunkingSettings = createRandomChunkingSettings(); + try (var service = createService()) { + ActionListener modelVerificationActionListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(OpenShiftAiEmbeddingsModel.class)); + + var embeddingsModel = (OpenShiftAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is(URL_VALUE)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings().asMap(), is(chunkingSettings.asMap())); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY_VALUE)); + }, e -> fail("parse request should not fail " + e.getMessage())); + + service.parseRequestConfig( + INFERENCE_ID_VALUE, + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getServiceSettingsMap(MODEL_VALUE, URL_VALUE), + chunkingSettings.asMap(), + getSecretSettingsMap(API_KEY_VALUE) + ), + modelVerificationActionListener + ); + } + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + try (var service = createService()) { + ActionListener modelVerificationActionListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(OpenShiftAiEmbeddingsModel.class)); + + var 
embeddingsModel = (OpenShiftAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is(URL_VALUE)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), is(ChunkingSettingsBuilder.DEFAULT_SETTINGS)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY_VALUE)); + }, e -> fail("parse request should not fail " + e.getMessage())); + + service.parseRequestConfig( + INFERENCE_ID_VALUE, + TaskType.TEXT_EMBEDDING, + getRequestConfigMap(getServiceSettingsMap(MODEL_VALUE, URL_VALUE), getSecretSettingsMap(API_KEY_VALUE)), + modelVerificationActionListener + ); + } + } + + public void testParseRequestConfig_WithoutModelId_Success() throws IOException { + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap(m -> { + assertThat(m, instanceOf(OpenShiftAiChatCompletionModel.class)); + + var chatCompletionModel = (OpenShiftAiChatCompletionModel) m; + + assertThat(chatCompletionModel.getServiceSettings().uri().toString(), is(URL_VALUE)); + assertThat(chatCompletionModel.getServiceSettings().modelId(), is(nullValue())); + assertThat(chatCompletionModel.getSecretSettings().apiKey().toString(), is(API_KEY_VALUE)); + + }, e -> fail("parse request should not fail " + e.getMessage())); + + service.parseRequestConfig( + INFERENCE_ID_VALUE, + TaskType.CHAT_COMPLETION, + getRequestConfigMap(getServiceSettingsMap(null, URL_VALUE), getSecretSettingsMap(API_KEY_VALUE)), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_WithoutUrl_ThrowsException() throws IOException { + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap( + m -> fail("Expected exception, but got model: " + m), + exception -> { + assertThat(exception, instanceOf(ValidationException.class)); + assertThat( + exception.getMessage(), + is("Validation Failed: 1: [service_settings] does not contain the required setting 
[url];") + ); + } + ); + + service.parseRequestConfig( + INFERENCE_ID_VALUE, + TaskType.CHAT_COMPLETION, + getRequestConfigMap(getServiceSettingsMap(MODEL_VALUE, null), getSecretSettingsMap(API_KEY_VALUE)), + modelVerificationListener + ); + } + } + + public void testUnifiedCompletionInfer() throws Exception { + // The escapes are because the streaming response must be on a single line + String responseJson = """ + data: {\ + "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26",\ + "choices": [{\ + "delta": {\ + "content": "Deep",\ + "function_call": null,\ + "refusal": null,\ + "role": "assistant",\ + "tool_calls": null\ + },\ + "finish_reason": null,\ + "index": 0,\ + "logprobs": null\ + }\ + ],\ + "created": 1750158492,\ + "model": "llama3.2:3b",\ + "object": "chat.completion.chunk",\ + "service_tier": null,\ + "system_fingerprint": "fp_ollama",\ + "usage": null\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = createChatCompletionModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of( + new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), ROLE_VALUE, null, null) + ) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(XContentHelper.stripWhitespace(""" + { + "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26", + "choices": [{ + "delta": { + "content": "Deep", + "role": "assistant" + }, + "index": 0 + } + ], + "model": "llama3.2:3b", + 
"object": "chat.completion.chunk" + } + """)); + } + } + + public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception { + String responseJson = """ + { + "detail": "Not Found" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); + var latch = new CountDownLatch(1); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of( + new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), ROLE_VALUE, null, null) + ) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + ActionListener.runAfter(ActionTestUtils.assertNoSuccessListener(e -> { + try (var builder = XContentFactory.jsonBuilder()) { + var t = unwrapCause(e); + assertThat(t, isA(UnifiedChatCompletionException.class)); + ((UnifiedChatCompletionException) t).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + assertThat(json, is(Strings.format(XContentHelper.stripWhitespace(""" + { + "error" : { + "code" : "not_found", + "message" : "Resource not found at [%s] for request from inference entity id [inferenceEntityId] status \ + [404]. 
Error message: [{\\n \\"detail\\": \\"Not Found\\"\\n}\\n]", + "type" : "openshift_ai_error" + } + }"""), getUrl(webServer)))); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }), latch::countDown) + ); + assertThat(latch.await(30, TimeUnit.SECONDS), is(true)); + } + } + + public void testMidStreamUnifiedCompletionError() throws Exception { + String responseJson = """ + data: {"error": {"message": "400: Invalid value: Model 'llama3.12:3b' not found"}} + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(XContentHelper.stripWhitespace(""" + { + "error": { + "message": "Received an error response for request from inference entity id [inferenceEntityId].\ + Error message: [{\\"error\\": {\\"message\\": \\"400: Invalid value: Model 'llama3.12:3b' not found\\"}}]", + "type": "openshift_ai_error" + } + } + """)); + } + + public void testInfer_StreamRequest() throws Exception { + String responseJson = """ + data: {\ + "id": "chatcmpl-2c57e3888b1a4e80a0c708889546288e",\ + "object": "chat.completion.chunk",\ + "created": 1760082951,\ + "model": "llama-31-8b-instruct",\ + "choices": [{\ + "index": 0,\ + "delta": {\ + "role": "assistant",\ + "content": "Deep"\ + },\ + "logprobs": null,\ + "finish_reason": null\ + }\ + ]\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + streamCompletion().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"Deep"}]}"""); + } + + private void testStreamError(String expectedResponse) throws Exception { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); + PlainActionFuture listener = new PlainActionFuture<>(); + 
service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of( + new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), ROLE_VALUE, null, null) + ) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> { + e = unwrapCause(e); + assertThat(e, isA(UnifiedChatCompletionException.class)); + try (var builder = XContentFactory.jsonBuilder()) { + ((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + + assertThat(json, is(expectedResponse)); + } + }); + } + } + + public void testInfer_StreamRequest_ErrorResponse() { + String responseJson = """ + { + "detail": "Not Found" + }"""; + webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); + + var e = assertThrows(ElasticsearchStatusException.class, this::streamCompletion); + assertThat(e.status(), equalTo(RestStatus.NOT_FOUND)); + assertThat(e.getMessage(), equalTo(Strings.format(""" + Resource not found at [%s] for request from inference entity id [inferenceEntityId] status [404]. 
Error message: [{ + "detail": "Not Found" + }]""", getUrl(webServer)))); + } + + public void testInfer_StreamRequestRetry() throws Exception { + webServer.enqueue(new MockResponse().setResponseCode(503).setBody(""" + { + "error": { + "message": "server busy" + } + }""")); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(""" + data: {\ + "id": "chatcmpl-2c57e3888b1a4e80a0c708889546288e",\ + "object": "chat.completion.chunk",\ + "created": 1760082951,\ + "model": "llama-31-8b-instruct",\ + "choices": [{\ + "index": 0,\ + "delta": {\ + "role": "assistant",\ + "content": "Deep"\ + },\ + "logprobs": null,\ + "finish_reason": null\ + }\ + ]\ + } + + """)); + + streamCompletion().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"Deep"}]}"""); + } + + public void testSupportsStreaming() throws IOException { + try (var service = new OpenShiftAiService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); + assertThat(service.canStream(TaskType.ANY), is(false)); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap() throws IOException { + try (var service = createService()) { + var secretSettings = getSecretSettingsMap(API_KEY_VALUE); + secretSettings.put("extra_key", "value"); + + var config = getRequestConfigMap(getEmbeddingsServiceSettingsMap(), secretSettings); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("Configuration contains settings [{extra_key=value}] unknown to the [openshift_ai] service") + ); + } + ); + + service.parseRequestConfig(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, config, modelVerificationListener); + } + } + + public void 
testChunkedInfer_ChunkingSettingsNotSet() throws IOException { + var model = OpenShiftAiEmbeddingsModelTests.createModel( + getUrl(webServer), + API_KEY_VALUE, + MODEL_VALUE, + 1234, + false, + DIMENSIONS_VALUE, + null + ); + + testChunkedInfer(model); + } + + public void testChunkedInfer_ChunkingSettingsSet() throws IOException { + var model = OpenShiftAiEmbeddingsModelTests.createModel( + getUrl(webServer), + API_KEY_VALUE, + MODEL_VALUE, + 1234, + false, + DIMENSIONS_VALUE, + createRandomChunkingSettings() + ); + + testChunkedInfer(model); + } + + public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + + String responseJson = """ + { + "id": "embd-45e6d99b97a645c0af96653598069cd9", + "object": "list", + "created": 1760085467, + "model": "gritlm-7b", + "data": [ + { + "index": 0, + "object": "embedding", + "embedding": [ + 0.0089111328125, + -0.007049560546875 + ] + }, + { + "index": 1, + "object": "embedding", + "embedding": [ + -0.008544921875, + -0.0230712890625 + ] + } + ], + "usage": { + "prompt_tokens": 7, + "total_tokens": 7, + "completion_tokens": 0, + "prompt_tokens_details": null + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(new ChunkInferenceInput(FIRST_PART_OF_INPUT_VALUE), new ChunkInferenceInput(SECOND_PART_OF_INPUT_VALUE)), + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + + assertThat(results, hasSize(2)); + { + assertThat(results.getFirst(), Matchers.instanceOf(ChunkedInferenceEmbedding.class)); + var 
floatResult = (ChunkedInferenceEmbedding) results.getFirst(); + assertThat(floatResult.chunks(), hasSize(1)); + assertThat(floatResult.chunks().getFirst().embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); + assertThat( + Arrays.equals( + new float[] { 0.0089111328125f, -0.007049560546875f }, + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().getFirst().embedding()).values() + ), + is(true) + ); + } + { + assertThat(results.get(1), Matchers.instanceOf(ChunkedInferenceEmbedding.class)); + var floatResult = (ChunkedInferenceEmbedding) results.get(1); + assertThat(floatResult.chunks(), hasSize(1)); + assertThat(floatResult.chunks().getFirst().embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); + assertThat( + Arrays.equals( + new float[] { -0.008544921875f, -0.0230712890625f }, + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().getFirst().embedding()).values() + ), + is(true) + ); + } + + assertThat(webServer.requests(), hasSize(1)); + assertThat(webServer.requests().getFirst().getUri().getQuery(), is(nullValue())); + assertThat( + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat( + webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), + is(Strings.format("Bearer %s", API_KEY_VALUE)) + ); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(FIRST_PART_OF_INPUT_VALUE, SECOND_PART_OF_INPUT_VALUE))); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); + } + } + + public void testGetConfiguration() throws Exception { + try (var service = createService()) { + String content = XContentHelper.stripWhitespace(""" + { + "service": "openshift_ai", + "name": "OpenShift AI", + "task_types": ["text_embedding", "rerank", "completion", "chat_completion"], + 
"configurations": { + "api_key": { + "description": "API Key for the provider you're connecting to.", + "label": "API Key", + "required": true, + "sensitive": true, + "updatable": true, + "type": "str", + "supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"] + }, + "model_id": { + "description": "The name of the model to use for the inference task.", + "label": "Model ID", + "required": false, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"] + }, + "rate_limit.requests_per_minute": { + "description": "Minimize the number of rate limit errors.", + "label": "Rate Limit", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"] + }, + "url": { + "description": "The URL endpoint to use for the requests.", + "label": "URL", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"] + } + } + } + """); + InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( + new BytesArray(content), + XContentType.JSON + ); + boolean humanReadable = true; + BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); + InferenceServiceConfiguration serviceConfiguration = service.getConfiguration(); + assertToXContentEquivalent( + originalBytes, + toXContent(serviceConfiguration, XContentType.JSON, humanReadable), + XContentType.JSON + ); + } + } + + private InferenceEventsAssertion streamCompletion() throws Exception { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + 
var model = OpenShiftAiChatCompletionModelTests.createCompletionModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + null, + null, + List.of(FIRST_PART_OF_INPUT_VALUE), + true, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + return InferenceEventsAssertion.assertThat(listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT)).hasFinishedStream(); + } + } + + private OpenShiftAiService createService() { + return new OpenShiftAiService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); + } + + private Map getRequestConfigMap( + Map serviceSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; + } + + private Map getRequestConfigMap(Map serviceSettings, Map secretSettings) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings)); + } + + private static Map getEmbeddingsServiceSettingsMap() { + return buildServiceSettingsMap(INFERENCE_ID_VALUE, URL_VALUE, SimilarityMeasure.COSINE.toString(), null, null, null); + } + + @Override + public InferenceService createInferenceService() { + return createService(); + } + + @Override + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + assertThat(rerankingInferenceService.rerankerWindowSize(MODEL_VALUE), is(2800)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java new file mode 100644 index 0000000000000..de8650ca1b9a6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -0,0 +1,1022 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.action; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockRequest; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests; +import org.elasticsearch.xpack.inference.InputTypeTests; +import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import 
org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankModelTests; +import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankTaskSettings; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests.buildExpectationRerank; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; +import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModelTests.createCompletionModel; +import static 
org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsModelTests.createModel; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Mockito.mock; + +public class OpenShiftAiActionCreatorTests extends ESTestCase { + + // Field names + private static final String N_FIELD_NAME = "n"; + private static final String STREAM_FIELD_NAME = "stream"; + private static final String MESSAGES_FIELD_NAME = "messages"; + private static final String ROLE_FIELD_NAME = "role"; + private static final String CONTENT_FIELD_NAME = "content"; + + private static final String DOCUMENTS_FIELD_NAME = "documents"; + private static final String MODEL_FIELD_NAME = "model"; + private static final String QUERY_FIELD_NAME = "query"; + private static final String TOP_N_FIELD_NAME = "top_n"; + private static final String RETURN_DOCUMENTS_FIELD_NAME = "return_documents"; + + private static final String INPUT_FIELD_NAME = "input"; + + // Test values + private static final String API_KEY_VALUE = "test_api_key"; + private static final String MODEL_VALUE = "some_model"; + private static final String QUERY_VALUE = "popular name"; + private static final String ROLE_VALUE = "user"; + private static final String INPUT_ENTRY_VALUE = "abcd"; + private static final List INPUT_VALUE = List.of(INPUT_ENTRY_VALUE); + private static final String INPUT_TO_TRUNCATE = "super long input"; + private static final List EMBEDDINGS_VALUE = List.of(new float[] { 0.0123F, -0.0123F }); + private static final List DOCUMENTS_VALUE = List.of("Luke"); + private static final int N_VALUE = 1; + private static final boolean RETURN_DOCUMENTS_DEFAULT_VALUE = true; + private static final boolean RETURN_DOCUMENTS_OVERRIDDEN_VALUE = false; + private static final int TOP_N_DEFAULT_VALUE = 2; + private static 
final int TOP_N_OVERRIDDEN_VALUE = 1; + private static final List RERANK_EXPECTATIONS_WITH_TEXT_TWO_RESULTS = List.of( + new RankedDocsResultsTests.RerankExpectation(Map.of("text", "awgawgawgawg", "index", 1, "relevance_score", 0.9921875f)), + new RankedDocsResultsTests.RerankExpectation(Map.of("text", "awdawdawda", "index", 0, "relevance_score", 0.4921875f)) + ); + private static final List RERANK_EXPECTATIONS_NO_TEXT_SINGLE_RESULT = List.of( + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 0, "relevance_score", 0.4921875f)) + ); + + // Settings with no retries + private static final Settings NO_RETRY_SETTINGS = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + + // Mock server and client manager + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityExecutors()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testCreate_OpenShiftAiEmbeddingsModel() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "embd-45e6d99b97a645c0af96653598069cd9", + "object": "list", + "created": 1760085467, + "model": "gritlm-7b", + "data": [ + { + "index": 0, + "object": "embedding", + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "usage": { + "prompt_tokens": 7, + "total_tokens": 7, + "completion_tokens": 0, + "prompt_tokens_details": null + } + } + """; + 
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); + var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(INPUT_VALUE, InputTypeTests.randomWithNull()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationFloat(EMBEDDINGS_VALUE))); + assertThat(webServer.requests(), hasSize(1)); + + var request = webServer.requests().getFirst(); + assertThat(request.getUri().getQuery(), is(nullValue())); + assertContentTypeAndAuthorization(request); + + var requestMap = entityAsMap(request.getBody()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(INPUT_VALUE)); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); + } + } + + private static void assertContentTypeAndAuthorization(MockRequest request) { + assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters())); + assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo(Strings.format("Bearer %s", API_KEY_VALUE))); + } + + public void testCreate_OpenShiftAiEmbeddingsModel_FailsFromInvalidResponseFormat() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "embd-45e6d99b97a645c0af96653598069cd9", + "object": "list", + "created": 1760085467, + "model": "gritlm-7b", + "data_does_not_exist": [ + { + "index": 0, + "object": "embedding", + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "usage": { + 
"prompt_tokens": 7, + "total_tokens": 7, + "completion_tokens": 0, + "prompt_tokens_details": null + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); + var action = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)).create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(INPUT_VALUE, InputTypeTests.randomWithNull()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT) + ); + + var failureCauseMessage = "Required [data]"; + assertThat( + thrownException.getMessage(), + is( + Strings.format( + "Failed to send OpenShift AI text_embedding request from inference entity id [inferenceEntityId]. Cause: %s", + failureCauseMessage + ) + ) + ); + assertThat(thrownException.getCause().getMessage(), is(failureCauseMessage)); + } + } + + public void testCreate_OpenShiftAiChatCompletionModel() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "chatcmpl-921d2eb8f3bc46dd8f4cb0502a4608a7", + "object": "chat.completion", + "created": 1760082857, + "model": "llama-31-8b-instruct", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "reasoning_content": null, + "content": "Hello there, how may I assist you today?", + "tool_calls": [] + }, + "logprobs": null, + "finish_reason": "length", + "stop_reason": null + } + ], + "usage": { + "prompt_tokens": 40, + "total_tokens": 140, + "completion_tokens": 100, + "prompt_tokens_details": null + }, + "prompt_logprobs": null, + "kv_transfer_params": null + } + """; + + webServer.enqueue(new 
MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createCompletionModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); + var action = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)).create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(INPUT_VALUE), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?")))); + assertThat(webServer.requests(), hasSize(1)); + + var request = webServer.requests().getFirst(); + assertContentTypeAndAuthorization(request); + + var requestMap = entityAsMap(request.getBody()); + assertThat( + requestMap.get(MESSAGES_FIELD_NAME), + is(List.of(Map.of(ROLE_FIELD_NAME, ROLE_VALUE, CONTENT_FIELD_NAME, INPUT_ENTRY_VALUE))) + ); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); + assertThat(requestMap.get(N_FIELD_NAME), is(N_VALUE)); + assertThat(requestMap.get(STREAM_FIELD_NAME), is(false)); + } + } + + public void testCreate_OpenShiftAiChatCompletionModel_FailsFromInvalidResponseFormat() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "chatcmpl-921d2eb8f3bc46dd8f4cb0502a4608a7", + "object": "chat.completion", + "created": 1760082857, + "model": "llama-31-8b-instruct", + "not_choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "reasoning_content": null, + "content": "Hello there, how may I assist you today?", + "tool_calls": [] + }, + "logprobs": null, + "finish_reason": "length", + "stop_reason": null + } + ], + "usage": { + "prompt_tokens": 40, + "total_tokens": 140, + "completion_tokens": 100, + 
"prompt_tokens_details": null + }, + "prompt_logprobs": null, + "kv_transfer_params": null + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createCompletionModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); + var action = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)).create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(INPUT_VALUE), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT) + ); + var failureCauseMessage = "Required [choices]"; + assertThat( + thrownException.getMessage(), + is( + Strings.format( + "Failed to send OpenShift AI completion request from inference entity id [inferenceEntityId]. Cause: %s", + failureCauseMessage + ) + ) + ); + assertThat(thrownException.getCause().getMessage(), is(failureCauseMessage)); + } + } + + public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + var contentTooLargeErrorMessage = """ + This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;\ + 0 for the completion). 
Please reduce your prompt; or completion length."""; + + String responseJsonContentTooLarge = Strings.format(""" + { + "error": { + "message": "%s", + "type": "content_too_large", + "param": null, + "code": null + } + } + """, contentTooLargeErrorMessage); + + String responseJson = """ + { + "id": "embd-45e6d99b97a645c0af96653598069cd9", + "object": "list", + "created": 1760085467, + "model": "gritlm-7b", + "data": [ + { + "index": 0, + "object": "embedding", + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "usage": { + "prompt_tokens": 7, + "total_tokens": 7, + "completion_tokens": 0, + "prompt_tokens_details": null + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(413).setBody(responseJsonContentTooLarge)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); + var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(INPUT_VALUE, InputTypeTests.randomWithNull()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationFloat(EMBEDDINGS_VALUE))); + assertThat(webServer.requests(), hasSize(2)); + { + assertThat(webServer.requests().getFirst().getUri().getQuery(), is(nullValue())); + assertContentTypeAndAuthorization(webServer.requests().getFirst()); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(INPUT_VALUE)); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); + } + { + assertThat(webServer.requests().get(1).getUri().getQuery(), is(nullValue())); + 
assertContentTypeAndAuthorization(webServer.requests().get(1)); + + var requestMap = entityAsMap(webServer.requests().get(1).getBody()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT_ENTRY_VALUE.substring(0, 2)))); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); + } + } + } + + public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCodeWithContentTooLargeMessage() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + var contentTooLargeErrorMessage = """ + This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;\ + 0 for the completion). Please reduce your prompt; or completion length."""; + + String responseJsonContentTooLarge = Strings.format(""" + { + "error": { + "message": "%s", + "type": "content_too_large", + "param": null, + "code": null + } + } + """, contentTooLargeErrorMessage); + + var responseJson = """ + { + "id": "embd-45e6d99b97a645c0af96653598069cd9", + "object": "list", + "created": 1760085467, + "model": "gritlm-7b", + "data": [ + { + "index": 0, + "object": "embedding", + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "usage": { + "prompt_tokens": 7, + "total_tokens": 7, + "completion_tokens": 0, + "prompt_tokens_details": null + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(400).setBody(responseJsonContentTooLarge)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); + var action = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)).create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(INPUT_VALUE, InputTypeTests.randomWithNull()), + 
InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationFloat(EMBEDDINGS_VALUE))); + assertThat(webServer.requests(), hasSize(2)); + + { + var firstRequest = webServer.requests().getFirst(); + assertContentTypeAndAuthorization(firstRequest); + var firstRequestMap = entityAsMap(firstRequest.getBody()); + assertThat(firstRequestMap, aMapWithSize(2)); + assertThat(firstRequestMap.get(INPUT_FIELD_NAME), is(INPUT_VALUE)); + assertThat(firstRequestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); + } + { + var secondRequest = webServer.requests().get(1); + assertContentTypeAndAuthorization(secondRequest); + var secondRequestMap = entityAsMap(secondRequest.getBody()); + assertThat(secondRequestMap, aMapWithSize(2)); + assertThat(secondRequestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT_ENTRY_VALUE.substring(0, 2)))); + assertThat(secondRequestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); + } + } + } + + public void testExecute_TruncatesInputBeforeSending() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + var responseJson = """ + { + "id": "embd-45e6d99b97a645c0af96653598069cd9", + "object": "list", + "created": 1760085467, + "model": "gritlm-7b", + "data": [ + { + "index": 0, + "object": "embedding", + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "usage": { + "prompt_tokens": 7, + "total_tokens": 7, + "completion_tokens": 0, + "prompt_tokens_details": null + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + // truncated to 1 token = 3 characters + var model = createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE, 1); + var action = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)).create(model); + + var listener = new 
PlainActionFuture(); + action.execute( + new EmbeddingsInput(List.of(INPUT_TO_TRUNCATE), InputTypeTests.randomWithNull()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationFloat(EMBEDDINGS_VALUE))); + assertThat(webServer.requests(), hasSize(1)); + + var request = webServer.requests().getFirst(); + assertThat(request.getUri().getQuery(), is(nullValue())); + assertContentTypeAndAuthorization(request); + + var requestMap = entityAsMap(request.getBody()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT_TO_TRUNCATE.substring(0, 3)))); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); + } + } + + public void testCreate_OpenShiftAiRerankModel_WithTaskSettings() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "rerank-d300256dd02b4c63b8a2bc34dcdad845", + "model": "bge-reranker-v2-m3", + "usage": { + "total_tokens": 30 + }, + "results": [ + { + "index": 1, + "document": { + "text": "awgawgawgawg" + }, + "relevance_score": 0.9921875 + }, + { + "index": 0, + "document": { + "text": "awdawdawda" + }, + "relevance_score": 0.4921875 + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = OpenShiftAiRerankModelTests.createModel( + getUrl(webServer), + API_KEY_VALUE, + MODEL_VALUE, + TOP_N_DEFAULT_VALUE, + RETURN_DOCUMENTS_DEFAULT_VALUE + ); + var actionCreator = new OpenShiftAiActionCreator( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) + ); + var action = actionCreator.create(model, null); + + PlainActionFuture listener = new 
PlainActionFuture<>(); + action.execute( + new QueryAndDocsInputs(QUERY_VALUE, DOCUMENTS_VALUE, null, null, false), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + assertThat(result.asMap(), is(buildExpectationRerank(RERANK_EXPECTATIONS_WITH_TEXT_TWO_RESULTS))); + } + assertRerankActionCreator(TOP_N_DEFAULT_VALUE, RETURN_DOCUMENTS_DEFAULT_VALUE, MODEL_VALUE); + } + + public void testCreate_OpenShiftAiRerankModel_WithOverriddenTaskSettings() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "rerank-d300256dd02b4c63b8a2bc34dcdad845", + "model": "bge-reranker-v2-m3", + "usage": { + "total_tokens": 10 + }, + "results": [ + { + "index": 0, + "relevance_score": 0.4921875 + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = OpenShiftAiRerankModelTests.createModel( + getUrl(webServer), + API_KEY_VALUE, + MODEL_VALUE, + TOP_N_DEFAULT_VALUE, + RETURN_DOCUMENTS_DEFAULT_VALUE + ); + var actionCreator = new OpenShiftAiActionCreator( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) + ); + var action = actionCreator.create( + model, + new HashMap<>( + Map.of( + OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS, + RETURN_DOCUMENTS_OVERRIDDEN_VALUE, + OpenShiftAiRerankTaskSettings.TOP_N, + TOP_N_OVERRIDDEN_VALUE + ) + ) + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new QueryAndDocsInputs(QUERY_VALUE, DOCUMENTS_VALUE, null, null, false), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + assertThat(result.asMap(), 
is(buildExpectationRerank(RERANK_EXPECTATIONS_NO_TEXT_SINGLE_RESULT))); + } + assertRerankActionCreator(TOP_N_OVERRIDDEN_VALUE, RETURN_DOCUMENTS_OVERRIDDEN_VALUE, MODEL_VALUE); + } + + public void testCreate_OpenShiftAiRerankModel_NoTaskSettingsWithModelId() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "rerank-d300256dd02b4c63b8a2bc34dcdad845", + "model": "bge-reranker-v2-m3", + "usage": { + "total_tokens": 30 + }, + "results": [ + { + "index": 1, + "document": { + "text": "awgawgawgawg" + }, + "relevance_score": 0.9921875 + }, + { + "index": 0, + "document": { + "text": "awdawdawda" + }, + "relevance_score": 0.4921875 + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE, null, null); + var actionCreator = new OpenShiftAiActionCreator( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) + ); + var action = actionCreator.create(model, null); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new QueryAndDocsInputs(QUERY_VALUE, DOCUMENTS_VALUE, null, null, false), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + assertThat(result.asMap(), is(buildExpectationRerank(RERANK_EXPECTATIONS_WITH_TEXT_TWO_RESULTS))); + } + assertRerankActionCreator(null, null, MODEL_VALUE); + } + + public void testCreate_OpenShiftAiRerankModel_NoTaskSettings_NoModelId() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); + + try (var sender = createSender(senderFactory)) { + 
sender.startSynchronously(); + + String responseJson = """ + { + "id": "rerank-d300256dd02b4c63b8a2bc34dcdad845", + "model": "bge-reranker-v2-m3", + "usage": { + "total_tokens": 30 + }, + "results": [ + { + "index": 1, + "document": { + "text": "awgawgawgawg" + }, + "relevance_score": 0.9921875 + }, + { + "index": 0, + "document": { + "text": "awdawdawda" + }, + "relevance_score": 0.4921875 + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY_VALUE, null, null, null); + var actionCreator = new OpenShiftAiActionCreator( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) + ); + var action = actionCreator.create(model, null); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new QueryAndDocsInputs(QUERY_VALUE, DOCUMENTS_VALUE, null, null, false), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + assertThat(result.asMap(), is(buildExpectationRerank(RERANK_EXPECTATIONS_WITH_TEXT_TWO_RESULTS))); + } + assertRerankActionCreator(null, null, null); + } + + public void testCreate_OpenShiftAiRerankModel_NoTaskSettings_WithRequestParameters() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "rerank-d300256dd02b4c63b8a2bc34dcdad845", + "model": "bge-reranker-v2-m3", + "usage": { + "total_tokens": 30 + }, + "results": [ + { + "index": 1, + "document": { + "text": "awgawgawgawg" + }, + "relevance_score": 0.9921875 + }, + { + "index": 0, + "document": { + "text": "awdawdawda" + }, + "relevance_score": 0.4921875 + } + ] + } + """; + webServer.enqueue(new 
MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE, null, null); + var actionCreator = new OpenShiftAiActionCreator( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) + ); + var action = actionCreator.create(model, null); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new QueryAndDocsInputs(QUERY_VALUE, DOCUMENTS_VALUE, RETURN_DOCUMENTS_DEFAULT_VALUE, TOP_N_DEFAULT_VALUE, false), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + assertThat(result.asMap(), is(buildExpectationRerank(RERANK_EXPECTATIONS_WITH_TEXT_TWO_RESULTS))); + } + assertRerankActionCreator(TOP_N_DEFAULT_VALUE, RETURN_DOCUMENTS_DEFAULT_VALUE, MODEL_VALUE); + } + + public void testCreate_OpenShiftAiRerankModel_WithTaskSettings_WithRequestParametersPrioritized() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "rerank-d300256dd02b4c63b8a2bc34dcdad845", + "model": "bge-reranker-v2-m3", + "usage": { + "total_tokens": 10 + }, + "results": [ + { + "index": 0, + "relevance_score": 0.4921875 + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = OpenShiftAiRerankModelTests.createModel( + getUrl(webServer), + API_KEY_VALUE, + MODEL_VALUE, + TOP_N_DEFAULT_VALUE, + RETURN_DOCUMENTS_DEFAULT_VALUE + ); + var actionCreator = new OpenShiftAiActionCreator( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) + ); + var action = actionCreator.create(model, null); + + PlainActionFuture 
listener = new PlainActionFuture<>(); + action.execute( + new QueryAndDocsInputs(QUERY_VALUE, DOCUMENTS_VALUE, RETURN_DOCUMENTS_OVERRIDDEN_VALUE, TOP_N_OVERRIDDEN_VALUE, false), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + assertThat(result.asMap(), is(buildExpectationRerank(RERANK_EXPECTATIONS_NO_TEXT_SINGLE_RESULT))); + } + assertRerankActionCreator(TOP_N_OVERRIDDEN_VALUE, RETURN_DOCUMENTS_OVERRIDDEN_VALUE, MODEL_VALUE); + } + + public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "rerank-d300256dd02b4c63b8a2bc34dcdad845", + "model": "bge-reranker-v2-m3", + "usage": { + "total_tokens": 30 + }, + "not_results": [ + { + "index": 1, + "document": { + "text": "awgawgawgawg" + }, + "relevance_score": 0.9921875 + }, + { + "index": 0, + "document": { + "text": "awdawdawda" + }, + "relevance_score": 0.4921875 + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = OpenShiftAiRerankModelTests.createModel( + getUrl(webServer), + API_KEY_VALUE, + MODEL_VALUE, + TOP_N_DEFAULT_VALUE, + RETURN_DOCUMENTS_DEFAULT_VALUE + ); + var actionCreator = new OpenShiftAiActionCreator( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) + ); + var action = actionCreator.create(model, null); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new QueryAndDocsInputs(QUERY_VALUE, DOCUMENTS_VALUE, null, null, false), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchException.class, () -> 
listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("Failed to send OpenShift AI rerank request from inference entity id [inferenceEntityId]. Cause: Required [results]") + ); + } + } + + private void assertRerankActionCreator( + @Nullable Integer expectedTopN, + @Nullable Boolean expectedReturnDocuments, + @Nullable String expectedModel + ) throws IOException { + assertThat(webServer.requests(), hasSize(1)); + assertThat(webServer.requests().getFirst().getUri().getQuery(), is(nullValue())); + assertContentTypeAndAuthorization(webServer.requests().getFirst()); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + int fieldCount = 2; + assertThat(requestMap.get(DOCUMENTS_FIELD_NAME), is(DOCUMENTS_VALUE)); + assertThat(requestMap.get(QUERY_FIELD_NAME), is(QUERY_VALUE)); + if (expectedTopN != null) { + assertThat(requestMap.get(TOP_N_FIELD_NAME), is(expectedTopN)); + fieldCount++; + } + if (expectedReturnDocuments != null) { + assertThat(requestMap.get(RETURN_DOCUMENTS_FIELD_NAME), is(expectedReturnDocuments)); + fieldCount++; + } + if (expectedModel != null) { + assertThat(requestMap.get(MODEL_FIELD_NAME), is(expectedModel)); + fieldCount++; + } + assertThat(requestMap, aMapWithSize(fieldCount)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java new file mode 100644 index 0000000000000..927d01cc2f6c0 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.completion; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class OpenShiftAiChatCompletionModelTests extends ESTestCase { + + private static final String MODEL_VALUE = "model_name"; + private static final String API_KEY_VALUE = "test_api_key"; + private static final String URL_VALUE = "http://www.abc.com"; + private static final String ALTERNATE_MODEL_VALUE = "different_model"; + + public static OpenShiftAiChatCompletionModel createCompletionModel(String url, String apiKey, String modelName) { + return createModelWithTaskType(url, apiKey, modelName, TaskType.COMPLETION); + } + + public static OpenShiftAiChatCompletionModel createChatCompletionModel(String url, String apiKey, String modelName) { + return createModelWithTaskType(url, apiKey, modelName, TaskType.CHAT_COMPLETION); + } + + public static OpenShiftAiChatCompletionModel createModelWithTaskType(String url, String apiKey, String modelName, TaskType taskType) { + return new OpenShiftAiChatCompletionModel( + "inferenceEntityId", + taskType, + "service", + new OpenShiftAiChatCompletionServiceSettings(modelName, url, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public void testOverrideWith_UnifiedCompletionRequest_KeepsSameModelId() { + var model = createCompletionModel(URL_VALUE, API_KEY_VALUE, MODEL_VALUE); + var overriddenModel = OpenShiftAiChatCompletionModel.of(model, MODEL_VALUE); + + assertThat(overriddenModel, is(sameInstance(model))); + } + + public void 
testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() { + var model = createCompletionModel(URL_VALUE, API_KEY_VALUE, MODEL_VALUE); + var overriddenModel = OpenShiftAiChatCompletionModel.of(model, ALTERNATE_MODEL_VALUE); + + assertThat(overriddenModel.getServiceSettings().modelId(), is(ALTERNATE_MODEL_VALUE)); + } + + public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { + var model = createCompletionModel(URL_VALUE, API_KEY_VALUE, MODEL_VALUE); + var overriddenModel = OpenShiftAiChatCompletionModel.of(model, null); + + assertThat(overriddenModel, is(sameInstance(model))); + } + + public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() { + var model = createCompletionModel(URL_VALUE, API_KEY_VALUE, null); + var overriddenModel = OpenShiftAiChatCompletionModel.of(model, null); + + assertThat(overriddenModel, is(sameInstance(model))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java new file mode 100644 index 0000000000000..fe850ae98b945 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java @@ -0,0 +1,160 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.openshiftai.completion; + +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; + +import static org.elasticsearch.ExceptionsHelper.unwrapCause; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class OpenShiftAiChatCompletionResponseHandlerTests extends ESTestCase { + private static final String URL_VALUE = "http://www.abc.com"; + private static final String INFERENCE_ID = "id"; + private final OpenShiftAiChatCompletionResponseHandler responseHandler = new OpenShiftAiChatCompletionResponseHandler( + "chat completions", + (a, b) -> mock() + ); + + public void testFailNotFound() throws IOException { + var responseJson = XContentHelper.stripWhitespace(""" + { + "detail": "Not Found" + } + """); + + var errorJson = invalidResponseJson(responseJson, 404); + + assertThat(errorJson, is(XContentHelper.stripWhitespace(Strings.format(""" + { + "error" : { + "code" : "not_found", + "message" : "Resource not found at [%s] for request from inference entity id [%s] \ + status [404]. 
Error message: [{\\"detail\\":\\"Not Found\\"}]", + "type" : "openshift_ai_error" + } + }""", URL_VALUE, INFERENCE_ID)))); + } + + public void testFailBadRequest() throws IOException { + var responseJson = XContentHelper.stripWhitespace(""" + { + "object": "error", + "message": "[{'type': 'missing', 'loc': ('body', 'messages'), 'msg': 'Field required', \ + 'input': {'model': 'llama-31-8b-instruct', 'messages': [{'role': 'user', 'content': 'What is deep learning?'}], \ + 'max_tokens': 2, 'stream': True}}]", + "type": "Bad Request", + "param": null, + "code": 400 + } + """); + + var errorJson = invalidResponseJson(responseJson, 400); + + assertThat(errorJson, is(XContentHelper.stripWhitespace(Strings.format(""" + { + "error": { + "code": "bad_request", + "message": "Received a bad request status code for request from inference entity id [%s] status [400]. Error message: \ + [{\\"object\\":\\"error\\",\\"message\\":\\"[{'type': 'missing', 'loc': ('body', 'messages'), 'msg': 'Field required', \ + 'input': {'model': 'llama-31-8b-instruct', 'messages': [{'role': 'user', 'content': 'What is deep learning?'}], \ + 'max_tokens': 2, 'stream': True}}]\\",\\"type\\":\\"Bad Request\\",\\"param\\":null,\\"code\\":400}]", + "type": "openshift_ai_error" + } + } + """, INFERENCE_ID)))); + } + + public void testFailValidationWithInvalidJson() throws IOException { + var responseJson = """ + what? this isn't a json + """; + + var errorJson = invalidResponseJson(responseJson, 500); + + assertThat(errorJson, is(XContentHelper.stripWhitespace(Strings.format(""" + { + "error": { + "code": "bad_request", + "message": "Received a server error status code for request from inference entity id [%s] status [500]. \ + Error message: [what? 
this isn't a json\\n]", + "type": "openshift_ai_error" + } + } + """, INFERENCE_ID)))); + } + + private String invalidResponseJson(String responseJson, int statusCode) throws IOException { + var exception = invalidResponse(responseJson, statusCode); + assertThat(exception, isA(RetryException.class)); + assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class)); + return toJson((UnifiedChatCompletionException) unwrapCause(exception)); + } + + private Exception invalidResponse(String responseJson, int statusCode) { + return expectThrows( + RetryException.class, + () -> responseHandler.validateResponse( + mock(), + mock(), + mockRequest(), + new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + } + + private static Request mockRequest() throws URISyntaxException { + var request = mock(Request.class); + when(request.getInferenceEntityId()).thenReturn(INFERENCE_ID); + when(request.isStreaming()).thenReturn(true); + when(request.getURI()).thenReturn(new URI(URL_VALUE)); + return request; + } + + private static HttpResponse mockErrorResponse(int statusCode) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + + var response = mock(HttpResponse.class); + when(response.getStatusLine()).thenReturn(statusLine); + + return response; + } + + private String toJson(UnifiedChatCompletionException e) throws IOException { + try (var builder = XContentFactory.jsonBuilder()) { + e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + } + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java new file mode 100644 index 0000000000000..83d4fb6cb2538 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java @@ -0,0 +1,238 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; + +import java.io.IOException; +import java.net.URI; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class OpenShiftAiChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase< + OpenShiftAiChatCompletionServiceSettings> { + + private static final String MODEL_VALUE = 
"some_model"; + private static final String URL_VALUE = "http://www.abc.com"; + private static final String INVALID_URL_VALUE = "^^^"; + private static final int RATE_LIMIT = 2; + + public void testFromMap_AllFields_Success() { + var serviceSettings = OpenShiftAiChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_VALUE, + ServiceFields.URL, + URL_VALUE, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is(new OpenShiftAiChatCompletionServiceSettings(MODEL_VALUE, URL_VALUE, new RateLimitSettings(RATE_LIMIT))) + ); + } + + public void testFromMap_MissingModelId_Success() { + var serviceSettings = OpenShiftAiChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + URL_VALUE, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat(serviceSettings, is(new OpenShiftAiChatCompletionServiceSettings(null, URL_VALUE, new RateLimitSettings(RATE_LIMIT)))); + } + + public void testFromMap_MissingUrl_ThrowsException() { + testFromMap_InvalidUrl( + Map.of( + ServiceFields.MODEL_ID, + MODEL_VALUE, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + "Validation Failed: 1: [service_settings] does not contain the required setting [url];" + ); + } + + public void testFromMap_InvalidUrl_ThrowsException() { + testFromMap_InvalidUrl( + Map.of( + ServiceFields.URL, + INVALID_URL_VALUE, + ServiceFields.MODEL_ID, + MODEL_VALUE, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + """ + Validation Failed: 1: [service_settings] Invalid url [^^^] received for field [url]. \ + Error: unable to parse url [^^^]. 
Reason: Illegal character in path;""" + ); + } + + public void testFromMap_EmptyUrl_ThrowsException() { + testFromMap_InvalidUrl( + Map.of( + ServiceFields.URL, + "", + ServiceFields.MODEL_ID, + MODEL_VALUE, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + "Validation Failed: 1: [service_settings] Invalid value empty string. [url] must be a non-empty string;" + ); + } + + private static void testFromMap_InvalidUrl(Map serviceSettingsMap, String expectedErrorMessage) { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiChatCompletionServiceSettings.fromMap(new HashMap<>(serviceSettingsMap), ConfigurationParseContext.PERSISTENT) + ); + + assertThat(thrownException.getMessage(), containsString(expectedErrorMessage)); + } + + public void testFromMap_MissingRateLimit_Success() { + var serviceSettings = OpenShiftAiChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_VALUE, ServiceFields.URL, URL_VALUE)), + ConfigurationParseContext.PERSISTENT + ); + + assertThat(serviceSettings, is(new OpenShiftAiChatCompletionServiceSettings(MODEL_VALUE, URL_VALUE, null))); + } + + public void testToXContent_WritesAllValues() throws IOException { + var serviceSettings = OpenShiftAiChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_VALUE, + ServiceFields.URL, + URL_VALUE, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + var expected = XContentHelper.stripWhitespace(Strings.format(""" + { + "model_id": "%s", + "url": "%s", + "rate_limit": { + "requests_per_minute": 2 + } + } + """, MODEL_VALUE, 
URL_VALUE)); + + assertThat(xContentResult, is(expected)); + } + + public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws IOException { + var serviceSettings = OpenShiftAiChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.URL, URL_VALUE)), + ConfigurationParseContext.PERSISTENT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + var expected = XContentHelper.stripWhitespace(Strings.format(""" + { + "url": "%s", + "rate_limit": { + "requests_per_minute": 3000 + } + } + """, URL_VALUE)); + assertThat(xContentResult, is(expected)); + } + + @Override + protected Writeable.Reader instanceReader() { + return OpenShiftAiChatCompletionServiceSettings::new; + } + + @Override + protected OpenShiftAiChatCompletionServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected OpenShiftAiChatCompletionServiceSettings mutateInstance(OpenShiftAiChatCompletionServiceSettings instance) + throws IOException { + String modelId = instance.modelId(); + URI uri = instance.uri(); + RateLimitSettings rateLimitSettings = instance.rateLimitSettings(); + + switch (between(0, 2)) { + case 0 -> modelId = randomValueOtherThan(modelId, () -> randomAlphaOfLengthOrNull(8)); + case 1 -> uri = randomValueOtherThan(uri, () -> ServiceUtils.createUri(randomAlphaOfLength(15))); + case 2 -> rateLimitSettings = randomValueOtherThan(rateLimitSettings, RateLimitSettingsTests::createRandom); + default -> throw new AssertionError("Illegal randomisation branch"); + } + return new OpenShiftAiChatCompletionServiceSettings(modelId, uri, rateLimitSettings); + } + + @Override + protected OpenShiftAiChatCompletionServiceSettings mutateInstanceForVersion( + OpenShiftAiChatCompletionServiceSettings instance, + TransportVersion version + ) { + return instance; + } + + private static 
OpenShiftAiChatCompletionServiceSettings createRandom() { + var modelId = randomAlphaOfLengthOrNull(8); + var url = randomAlphaOfLength(15); + return new OpenShiftAiChatCompletionServiceSettings(modelId, ServiceUtils.createUri(url), RateLimitSettingsTests.createRandom()); + } + + public static Map getServiceSettingsMap(String model, String url) { + var map = new HashMap(); + + map.put(ServiceFields.MODEL_ID, model); + map.put(ServiceFields.URL, url); + + return map; + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModelTests.java new file mode 100644 index 0000000000000..5a78b35c86231 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModelTests.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.openshiftai.embeddings; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +public class OpenShiftAiEmbeddingsModelTests extends ESTestCase { + + public static OpenShiftAiEmbeddingsModel createModel(String url, String apiKey, @Nullable String modelId) { + return createModel(url, apiKey, modelId, 1234); + } + + public static OpenShiftAiEmbeddingsModel createModel(String url, String apiKey, @Nullable String modelId, int maxInputTokens) { + return createModel(url, apiKey, modelId, maxInputTokens, false, 1536, null); + } + + public static OpenShiftAiEmbeddingsModel createModel( + String url, + String apiKey, + @Nullable String modelId, + @Nullable Integer maxInputTokens, + @Nullable Boolean dimensionsSetByUser, + @Nullable Integer dimensions, + @Nullable ChunkingSettings chunkingSettings + ) { + return new OpenShiftAiEmbeddingsModel( + "inferenceEntityId", + TaskType.TEXT_EMBEDDING, + "service", + new OpenShiftAiEmbeddingsServiceSettings( + modelId, + url, + dimensions, + SimilarityMeasure.DOT_PRODUCT, + maxInputTokens, + null, + dimensionsSetByUser + ), + chunkingSettings, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..a56e6bb135fc8 --- /dev/null +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java @@ -0,0 +1,657 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; +import org.hamcrest.CoreMatchers; + +import java.io.IOException; +import java.net.URI; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class OpenShiftAiEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase { + private static final String MODEL_VALUE = "some_model"; + private static final String CORRECT_URL_VALUE = "http://www.abc.com"; + private static final String INVALID_URL_VALUE = "^^^"; + private static final int 
DIMENSIONS_VALUE = 384; + private static final SimilarityMeasure SIMILARITY_MEASURE_VALUE = SimilarityMeasure.DOT_PRODUCT; + private static final int MAX_INPUT_TOKENS_VALUE = 128; + private static final int RATE_LIMIT_VALUE = 2; + + public void testFromMap_AllFields_Success() { + var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), + true + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new OpenShiftAiEmbeddingsServiceSettings( + MODEL_VALUE, + CORRECT_URL_VALUE, + DIMENSIONS_VALUE, + SIMILARITY_MEASURE_VALUE, + MAX_INPUT_TOKENS_VALUE, + new RateLimitSettings(RATE_LIMIT_VALUE), + true + ) + ) + ); + } + + public void testFromMap_NoModelId_Success() { + var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + null, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), + false + ), + ConfigurationParseContext.PERSISTENT + ); + assertThat( + serviceSettings, + is( + new OpenShiftAiEmbeddingsServiceSettings( + null, + CORRECT_URL_VALUE, + DIMENSIONS_VALUE, + SIMILARITY_MEASURE_VALUE, + MAX_INPUT_TOKENS_VALUE, + new RateLimitSettings(RATE_LIMIT_VALUE), + false + ) + ) + ); + } + + public void testFromMap_NoUrl_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_VALUE, + null, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), + false + ), + 
ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [url];") + ); + } + + public void testFromMap_EmptyUrl_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_VALUE, + "", + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), + false + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value empty string. [url] must be a non-empty string;") + ); + } + + public void testFromMap_InvalidUrl_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_VALUE, + INVALID_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), + false + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat(thrownException.getMessage(), containsString(Strings.format(""" + Validation Failed: 1: [service_settings] Invalid url [%s] received for field [url]. \ + Error: unable to parse url [%s]. 
Reason: Illegal character in path;""", INVALID_URL_VALUE, INVALID_URL_VALUE))); + } + + public void testFromMap_NoSimilarity_Success() { + var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_VALUE, + CORRECT_URL_VALUE, + null, + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), + false + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new OpenShiftAiEmbeddingsServiceSettings( + MODEL_VALUE, + CORRECT_URL_VALUE, + DIMENSIONS_VALUE, + null, + MAX_INPUT_TOKENS_VALUE, + new RateLimitSettings(RATE_LIMIT_VALUE), + false + ) + ) + ); + } + + public void testFromMap_InvalidSimilarity_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_VALUE, + CORRECT_URL_VALUE, + "by_size", + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), + false + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat(thrownException.getMessage(), containsString(""" + Validation Failed: 1: [service_settings] Invalid value [by_size] received. 
\ + [similarity] must be one of [cosine, dot_product, l2_norm];""")); + } + + public void testFromMap_NoDimensions_SetByUserFalse_Persistent_Success() { + var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + null, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), + false + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new OpenShiftAiEmbeddingsServiceSettings( + MODEL_VALUE, + CORRECT_URL_VALUE, + null, + SIMILARITY_MEASURE_VALUE, + MAX_INPUT_TOKENS_VALUE, + new RateLimitSettings(RATE_LIMIT_VALUE), + false + ) + ) + ); + } + + public void testFromMap_WithDimensions_SetByUserFalse_Persistent_Success() { + var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), + false + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new OpenShiftAiEmbeddingsServiceSettings( + MODEL_VALUE, + CORRECT_URL_VALUE, + DIMENSIONS_VALUE, + SIMILARITY_MEASURE_VALUE, + MAX_INPUT_TOKENS_VALUE, + new RateLimitSettings(RATE_LIMIT_VALUE), + false + ) + ) + ); + } + + public void testFromMap_WithDimensions_SetByUserNull_Persistent_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), + null + ), + ConfigurationParseContext.PERSISTENT + ) + ); + + assertThat( + 
thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [dimensions_set_by_user];") + ); + } + + public void testFromMap_NoDimensions_SetByUserNull_Request_Success() { + var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + null, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), + null + ), + ConfigurationParseContext.REQUEST + ); + + assertThat( + serviceSettings, + is( + new OpenShiftAiEmbeddingsServiceSettings( + MODEL_VALUE, + CORRECT_URL_VALUE, + null, + SIMILARITY_MEASURE_VALUE, + MAX_INPUT_TOKENS_VALUE, + new RateLimitSettings(RATE_LIMIT_VALUE), + false + ) + ) + ); + } + + public void testFromMap_WithDimensions_SetByUserNull_Request_Success() { + var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), + null + ), + ConfigurationParseContext.REQUEST + ); + + assertThat( + serviceSettings, + is( + new OpenShiftAiEmbeddingsServiceSettings( + MODEL_VALUE, + CORRECT_URL_VALUE, + DIMENSIONS_VALUE, + SIMILARITY_MEASURE_VALUE, + MAX_INPUT_TOKENS_VALUE, + new RateLimitSettings(RATE_LIMIT_VALUE), + true + ) + ) + ); + } + + public void testFromMap_WithDimensions_SetByUserTrue_Request_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), + true + ), + 
ConfigurationParseContext.REQUEST + ) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not allow the setting [dimensions_set_by_user];") + ); + } + + public void testFromMap_ZeroDimensions_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + 0, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), + false + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [dimensions] must be a positive integer;") + ); + } + + public void testFromMap_NegativeDimensions_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + -10, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), + false + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [-10]. 
[dimensions] must be a positive integer;") + ); + } + + public void testFromMap_NoInputTokens_Success() { + var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + null, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), + false + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new OpenShiftAiEmbeddingsServiceSettings( + MODEL_VALUE, + CORRECT_URL_VALUE, + DIMENSIONS_VALUE, + SIMILARITY_MEASURE_VALUE, + null, + new RateLimitSettings(RATE_LIMIT_VALUE), + false + ) + ) + ); + } + + public void testFromMap_ZeroInputTokens_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + 0, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), + false + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [max_input_tokens] must be a positive integer;") + ); + } + + public void testFromMap_NegativeInputTokens_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + -10, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), + false + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [-10]. 
[max_input_tokens] must be a positive integer;") + ); + } + + public void testFromMap_NoRateLimit_Success() { + var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + null, + false + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new OpenShiftAiEmbeddingsServiceSettings( + MODEL_VALUE, + CORRECT_URL_VALUE, + DIMENSIONS_VALUE, + SIMILARITY_MEASURE_VALUE, + MAX_INPUT_TOKENS_VALUE, + new RateLimitSettings(3000), + false + ) + ) + ); + } + + public void testToXContent_WritesAllValues() throws IOException { + var entity = new OpenShiftAiEmbeddingsServiceSettings( + MODEL_VALUE, + CORRECT_URL_VALUE, + DIMENSIONS_VALUE, + SIMILARITY_MEASURE_VALUE, + MAX_INPUT_TOKENS_VALUE, + new RateLimitSettings(3), + false + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(XContentHelper.stripWhitespace(Strings.format(""" + { + "model_id": "%s", + "url": "%s", + "rate_limit": { + "requests_per_minute": 3 + }, + "dimensions": 384, + "similarity": "dot_product", + "max_input_tokens": 128, + "dimensions_set_by_user": false + } + """, MODEL_VALUE, CORRECT_URL_VALUE)))); + } + + @Override + protected Writeable.Reader instanceReader() { + return OpenShiftAiEmbeddingsServiceSettings::new; + } + + @Override + protected OpenShiftAiEmbeddingsServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected OpenShiftAiEmbeddingsServiceSettings mutateInstance(OpenShiftAiEmbeddingsServiceSettings instance) throws IOException { + String modelId = instance.modelId(); + URI uri = instance.uri(); + Integer dimensions = instance.dimensions(); + SimilarityMeasure similarity = instance.similarity(); + Integer 
maxInputTokens = instance.maxInputTokens(); + RateLimitSettings rateLimitSettings = instance.rateLimitSettings(); + Boolean dimensionsSetByUser = instance.dimensionsSetByUser(); + + switch (between(0, 6)) { + case 0 -> modelId = randomValueOtherThan(modelId, () -> randomAlphaOfLengthOrNull(8)); + case 1 -> uri = randomValueOtherThan(uri, () -> ServiceUtils.createUri(randomAlphaOfLength(15))); + case 2 -> dimensions = randomValueOtherThan(dimensions, () -> randomBoolean() ? randomIntBetween(32, 256) : null); + case 3 -> similarity = randomValueOtherThan(similarity, () -> randomBoolean() ? randomFrom(SimilarityMeasure.values()) : null); + case 4 -> maxInputTokens = randomValueOtherThan(maxInputTokens, () -> randomBoolean() ? randomIntBetween(128, 256) : null); + case 5 -> rateLimitSettings = randomValueOtherThan(rateLimitSettings, RateLimitSettingsTests::createRandom); + case 6 -> dimensionsSetByUser = randomValueOtherThan(dimensionsSetByUser, ESTestCase::randomBoolean); + default -> throw new AssertionError("Illegal randomisation branch"); + } + return new OpenShiftAiEmbeddingsServiceSettings( + modelId, + uri, + dimensions, + similarity, + maxInputTokens, + rateLimitSettings, + dimensionsSetByUser + ); + } + + private static OpenShiftAiEmbeddingsServiceSettings createRandom() { + var modelId = randomAlphaOfLength(8); + var url = randomAlphaOfLength(15); + var similarityMeasure = randomBoolean() ? randomFrom(SimilarityMeasure.values()) : null; + var dimensions = randomBoolean() ? randomIntBetween(32, 256) : null; + var maxInputTokens = randomBoolean() ? 
randomIntBetween(128, 256) : null; + boolean dimensionsSetByUser = randomBoolean(); + return new OpenShiftAiEmbeddingsServiceSettings( + modelId, + url, + dimensions, + similarityMeasure, + maxInputTokens, + RateLimitSettingsTests.createRandom(), + dimensionsSetByUser + ); + } + + public static HashMap buildServiceSettingsMap( + @Nullable String modelId, + @Nullable String url, + @Nullable String similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + @Nullable HashMap rateLimitSettings, + @Nullable Boolean dimensionsSetByUser + ) { + HashMap result = new HashMap<>(); + if (modelId != null) { + result.put(ServiceFields.MODEL_ID, modelId); + } + if (url != null) { + result.put(ServiceFields.URL, url); + } + if (similarity != null) { + result.put(ServiceFields.SIMILARITY, similarity); + } + if (dimensions != null) { + result.put(ServiceFields.DIMENSIONS, dimensions); + } + if (maxInputTokens != null) { + result.put(ServiceFields.MAX_INPUT_TOKENS, maxInputTokens); + } + if (rateLimitSettings != null) { + result.put(RateLimitSettings.FIELD_NAME, rateLimitSettings); + } + if (dimensionsSetByUser != null) { + result.put(ServiceFields.DIMENSIONS_SET_BY_USER, dimensionsSetByUser); + } + return result; + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..264c3bababb7c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntityTests.java @@ -0,0 +1,106 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.request.completion; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; + +import java.io.IOException; +import java.util.ArrayList; + +import static org.hamcrest.Matchers.is; + +public class OpenShiftAiChatCompletionRequestEntityTests extends ESTestCase { + private static final String ROLE_VALUE = "user"; + + public void testSerializationWithModelIdStreaming() throws IOException { + testSerialization("modelId", true, """ + { + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "model": "modelId", + "n": 1, + "stream": true + } + """); + } + + public void testSerializationWithModelIdNonStreaming() throws IOException { + testSerialization("modelId", false, """ + { + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "model": "modelId", + "n": 1, + "stream": false + } + """); + } + + public void testSerializationWithoutModelIdStreaming() throws IOException { + testSerialization(null, true, """ + { + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "n": 1, + "stream": true + } + """); + } + + public void testSerializationWithoutModelIdNonStreaming() throws IOException { + testSerialization(null, false, """ + { + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "n": 1, + "stream": false + } + """); + } + + private static void testSerialization(String modelId, boolean isStreaming, String expectedJson) 
throws IOException { + var message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + ROLE_VALUE, + null, + null + ); + + var messageList = new ArrayList(); + messageList.add(message); + + var unifiedRequest = UnifiedCompletionRequest.of(messageList); + var unifiedChatInput = new UnifiedChatInput(unifiedRequest, isStreaming); + + var entity = new OpenShiftAiChatCompletionRequestEntity(unifiedChatInput, modelId); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertThat(Strings.toString(builder), is(XContentHelper.stripWhitespace(expectedJson))); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java new file mode 100644 index 0000000000000..2f6a78076e21c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java @@ -0,0 +1,78 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.openshiftai.request.completion; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModelTests; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.sameInstance; + +public class OpenShiftAiChatCompletionRequestTests extends ESTestCase { + + // Completion field names + private static final String N_FIELD_NAME = "n"; + private static final String STREAM_FIELD_NAME = "stream"; + private static final String MESSAGES_FIELD_NAME = "messages"; + private static final String ROLE_FIELD_NAME = "role"; + private static final String CONTENT_FIELD_NAME = "content"; + private static final String MODEL_FIELD_NAME = "model"; + + // Test values + private static final String URL_VALUE = "http://www.abc.com"; + private static final String MODEL_VALUE = "some_model"; + private static final String ROLE_VALUE = "user"; + private static final String API_KEY_VALUE = "test_api_key"; + + public void testCreateRequest_WithStreaming() throws IOException { + String input = randomAlphaOfLength(15); + var request = createRequest(MODEL_VALUE, URL_VALUE, API_KEY_VALUE, input, true); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var requestMap = 
entityAsMap(httpPost.getEntity().getContent()); + assertThat(request.getURI().toString(), is(URL_VALUE)); + assertThat(requestMap.get(STREAM_FIELD_NAME), is(true)); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); + assertThat(requestMap.get(N_FIELD_NAME), is(1)); + assertThat(requestMap.get(MESSAGES_FIELD_NAME), is(List.of(Map.of(ROLE_FIELD_NAME, ROLE_VALUE, CONTENT_FIELD_NAME, input)))); + assertThat(requestMap, aMapWithSize(4)); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is(Strings.format("Bearer %s", API_KEY_VALUE))); + } + + public void testTruncate_DoesNotReduceInputTextSize() { + String input = randomAlphaOfLength(5); + var request = createRequest(MODEL_VALUE, URL_VALUE, API_KEY_VALUE, input, true); + assertThat(request.truncate(), is(sameInstance(request))); + } + + public void testTruncationInfo_ReturnsNull() { + var request = createRequest(MODEL_VALUE, URL_VALUE, API_KEY_VALUE, randomAlphaOfLength(5), true); + assertThat(request.getTruncationInfo(), is(nullValue())); + } + + public static OpenShiftAiChatCompletionRequest createRequest(String modelId, String url, String apiKey, String input, boolean stream) { + var chatCompletionModel = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(url, apiKey, modelId); + return new OpenShiftAiChatCompletionRequest(new UnifiedChatInput(List.of(input), ROLE_VALUE, stream), chatCompletionModel); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..506dc072937c8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntityTests.java @@ 
-0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.request.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; + +public class OpenShiftAiEmbeddingsRequestEntityTests extends ESTestCase { + + private static final String MODEL_VALUE = "some_model"; + private static final List INPUT_VALUE = List.of("some input"); + private static final int DIMENSIONS_VALUE = 100; + + public void testXContent_DoesNotWriteDimensionsWhenNullAndSetByUserIsFalse() throws IOException { + var entity = new OpenShiftAiEmbeddingsRequestEntity(INPUT_VALUE, MODEL_VALUE, null, false); + testXContent_DoesNotWriteDimensions(entity); + } + + public void testXContent_DoesNotWriteDimensionsWhenNotSetByUser() throws IOException { + var entity = new OpenShiftAiEmbeddingsRequestEntity(INPUT_VALUE, MODEL_VALUE, DIMENSIONS_VALUE, false); + testXContent_DoesNotWriteDimensions(entity); + } + + public void testXContent_DoesNotWriteDimensionsWhenNull_EvenIfSetByUserIsTrue() throws IOException { + var entity = new OpenShiftAiEmbeddingsRequestEntity(INPUT_VALUE, MODEL_VALUE, null, true); + testXContent_DoesNotWriteDimensions(entity); + } + + private static void testXContent_DoesNotWriteDimensions(OpenShiftAiEmbeddingsRequestEntity entity) throws IOException { + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + 
String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(XContentHelper.stripWhitespace(""" + { + "input": ["some input"], + "model": "some_model" + } + """))); + } + + public void testXContent_DoesNotWriteModelWhenItIsNull() throws IOException { + var entity = new OpenShiftAiEmbeddingsRequestEntity(INPUT_VALUE, null, null, false); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(XContentHelper.stripWhitespace(""" + { + "input": ["some input"] + } + """))); + } + + public void testXContent_WritesDimensionsWhenNonNull_AndSetByUserIsTrue() throws IOException { + var entity = new OpenShiftAiEmbeddingsRequestEntity(INPUT_VALUE, MODEL_VALUE, DIMENSIONS_VALUE, true); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(XContentHelper.stripWhitespace(""" + { + "input": ["some input"], + "model": "some_model", + "dimensions": 100 + } + """))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java new file mode 100644 index 0000000000000..87defb2301f6e --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java @@ -0,0 +1,134 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.request.embeddings; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsModelTests; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; + +public class OpenShiftAiEmbeddingsRequestTests extends ESTestCase { + + private static final String INPUT_FIELD_NAME = "input"; + private static final String MODEL_FIELD_NAME = "model"; + + private static final String MODEL_VALUE = "some_model"; + private static final String INPUT_VALUE = "ABCD"; + private static final String URL_VALUE = "http://www.abc.com"; + private static final String API_KEY_VALUE = "test_api_key"; + private static final int DIMENSIONS_VALUE = 384; + + public void testCreateRequest_NoDimensions_DimensionsSetByUserFalse_Success() throws IOException { + testCreateRequest_Success(null, false, null); + } + + public void testCreateRequest_NoDimensions_DimensionsSetByUserTrue_Success() throws IOException { + testCreateRequest_Success(null, true, null); + } + + public void 
testCreateRequest_WithDimensions_DimensionsSetByUserFalse_Success() throws IOException { + testCreateRequest_Success(DIMENSIONS_VALUE, false, null); + } + + public void testCreateRequest_WithDimensions_DimensionsSetByUserTrue_Success() throws IOException { + testCreateRequest_Success(DIMENSIONS_VALUE, true, DIMENSIONS_VALUE); + } + + private void testCreateRequest_Success(Integer dimensions, boolean dimensionsSetByUser, Integer expectedDimensions) throws IOException { + var request = createRequest(dimensions, dimensionsSetByUser); + var httpRequest = request.createHttpRequest(); + var httpPost = validateRequestUrlAndContentType(httpRequest); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT_VALUE))); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); + assertThat(requestMap.get(ServiceFields.DIMENSIONS), is(expectedDimensions)); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is(Strings.format("Bearer %s", API_KEY_VALUE))); + } + + public void testCreateRequest_NoModel_Success() throws IOException { + var request = createRequest(null, false, null); + var httpRequest = request.createHttpRequest(); + var httpPost = validateRequestUrlAndContentType(httpRequest); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT_VALUE))); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(nullValue())); + assertThat(requestMap.get(ServiceFields.DIMENSIONS), is(nullValue())); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is(Strings.format("Bearer %s", API_KEY_VALUE))); + + } + + public void testTruncate_ReducesInputTextSizeByHalf() throws IOException { + var request = createRequest(null, false); + var truncatedRequest = request.truncate(); + + var httpRequest = truncatedRequest.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), 
instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT_VALUE.substring(0, INPUT_VALUE.length() / 2)))); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); + + } + + public void testIsTruncated_ReturnsTrue() { + var request = createRequest(null, false); + assertThat(request.getTruncationInfo()[0], is(false)); + + var truncatedRequest = request.truncate(); + assertThat(truncatedRequest.getTruncationInfo()[0], is(true)); + } + + private HttpPost validateRequestUrlAndContentType(HttpRequest request) { + assertThat(request.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) request.httpRequestBase(); + assertThat(httpPost.getURI().toString(), is(URL_VALUE)); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaTypeWithoutParameters())); + return httpPost; + } + + private static OpenShiftAiEmbeddingsRequest createRequest(Integer dimensions, Boolean dimensionsSetByUser) { + return createRequest(dimensions, dimensionsSetByUser, MODEL_VALUE); + } + + private static OpenShiftAiEmbeddingsRequest createRequest(Integer dimensions, Boolean dimensionsSetByUser, String modelId) { + var embeddingsModel = OpenShiftAiEmbeddingsModelTests.createModel( + URL_VALUE, + API_KEY_VALUE, + modelId, + null, + dimensionsSetByUser, + dimensions, + null + ); + return new OpenShiftAiEmbeddingsRequest( + TruncatorTests.createTruncator(), + new Truncator.TruncationResult(List.of(INPUT_VALUE), new boolean[] { false }), + embeddingsModel + ); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntityTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntityTests.java new file mode 100644 index 0000000000000..920ceba6696da --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntityTests.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.request.rarank; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.common.xcontent.XContentHelper.stripWhitespace; +import static org.hamcrest.Matchers.is; + +public class OpenShiftAIRerankRequestEntityTests extends ESTestCase { + private static final List DOCUMENT_VALUE = List.of("some document"); + private static final String QUERY_VALUE = "some query"; + private static final String MODEL_VALUE = "some_model"; + private static final Integer TOP_N_VALUE = 8; + private static final Boolean RETURN_DOCUMENTS_VALUE = true; + + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = new OpenShiftAIRerankRequestEntity(MODEL_VALUE, QUERY_VALUE, DOCUMENT_VALUE, RETURN_DOCUMENTS_VALUE, TOP_N_VALUE); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = Strings.toString(builder); + String expected = """ + { + 
"model": "some_model", + "query": "some query", + "documents": ["some document"], + "top_n": 8, + "return_documents": true + } + """; + assertThat(stripWhitespace(expected), is(result)); + } + + public void testXContent_WritesMinimalFields() throws IOException { + var entity = new OpenShiftAIRerankRequestEntity(null, QUERY_VALUE, DOCUMENT_VALUE, null, null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = Strings.toString(builder); + String expected = """ + { + "query": "some query", + "documents": ["some document"] + } + """; + assertThat(stripWhitespace(expected), is(result)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java new file mode 100644 index 0000000000000..dba3c26eab97f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java @@ -0,0 +1,153 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.openshiftai.request.rarank; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankModelTests; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class OpenShiftAiRerankRequestTests extends ESTestCase { + + private static final String TOP_N_FIELD_NAME = "top_n"; + private static final String RETURN_DOCUMENTS_FIELD_NAME = "return_documents"; + private static final String MODEL_FIELD_NAME = "model"; + private static final String DOCUMENTS_FIELD_NAME = "documents"; + private static final String QUERY_FIELD_NAME = "query"; + + private static final String URL_VALUE = "http://www.abc.com"; + private static final String DOCUMENT_VALUE = "some document"; + private static final String QUERY_VALUE = "some query"; + private static final String MODEL_VALUE = "some_model"; + private static final Integer TOP_N_VALUE = 8; + private static final Boolean RETURN_DOCUMENTS_VALUE = false; + private static final String API_KEY_VALUE = "test_api_key"; + + public void testCreateRequest_WithMinimalFieldsSet() throws IOException { + testCreateRequest(createRequest(null, null, null, null, null), null, null, null); + } + + public void testCreateRequest_TaskSettingsWithTopN() throws IOException { + testCreateRequest(createRequest(TOP_N_VALUE, null, null, null, null), TOP_N_VALUE, null, null); + } + + public void testCreateRequest_TaskSettingsWithReturnDocuments() throws IOException { + 
testCreateRequest(createRequest(null, RETURN_DOCUMENTS_VALUE, null, null, null), null, RETURN_DOCUMENTS_VALUE, null); + } + + public void testCreateRequest_TaskSettingsWithModelId() throws IOException { + testCreateRequest(createRequest(null, null, MODEL_VALUE, null, null), null, null, MODEL_VALUE); + } + + public void testCreateRequest_TaskSettingsWithAllFields() throws IOException { + testCreateRequest( + createRequest(TOP_N_VALUE, RETURN_DOCUMENTS_VALUE, MODEL_VALUE, null, null), + TOP_N_VALUE, + RETURN_DOCUMENTS_VALUE, + MODEL_VALUE + ); + } + + public void testCreateRequest_RequestSettingsOverrideTaskSettings() throws IOException { + testCreateRequest( + createRequest(1, true, MODEL_VALUE, TOP_N_VALUE, RETURN_DOCUMENTS_VALUE), + TOP_N_VALUE, + RETURN_DOCUMENTS_VALUE, + MODEL_VALUE + ); + } + + public void testCreateRequest_RequestSettingsOverrideNullTaskSettings() throws IOException { + testCreateRequest( + createRequest(null, null, MODEL_VALUE, TOP_N_VALUE, RETURN_DOCUMENTS_VALUE), + TOP_N_VALUE, + RETURN_DOCUMENTS_VALUE, + MODEL_VALUE + ); + } + + public void testCreateRequest_ReturnDocumentsFromTaskSettings_TopNFromRequest() throws IOException { + testCreateRequest( + createRequest(null, RETURN_DOCUMENTS_VALUE, MODEL_VALUE, TOP_N_VALUE, null), + TOP_N_VALUE, + RETURN_DOCUMENTS_VALUE, + MODEL_VALUE + ); + } + + public void testCreateRequest_TopNFromTaskSettings_ReturnDocumentsFromRequest() throws IOException { + testCreateRequest( + createRequest(TOP_N_VALUE, null, MODEL_VALUE, null, RETURN_DOCUMENTS_VALUE), + TOP_N_VALUE, + RETURN_DOCUMENTS_VALUE, + MODEL_VALUE + ); + } + + private void testCreateRequest( + OpenShiftAiRerankRequest request, + Integer expectedTopN, + Boolean expectedReturnDocuments, + String expectedModelId + ) throws IOException { + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + 
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaTypeWithoutParameters())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(Strings.format("Bearer %s", API_KEY_VALUE))); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + + assertThat(requestMap.get(DOCUMENTS_FIELD_NAME), is(List.of(DOCUMENT_VALUE))); + assertThat(requestMap.get(QUERY_FIELD_NAME), is(QUERY_VALUE)); + int itemsCount = 2; + if (expectedTopN != null) { + assertThat(requestMap.get(TOP_N_FIELD_NAME), is(expectedTopN)); + itemsCount++; + } + if (expectedReturnDocuments != null) { + assertThat(requestMap.get(RETURN_DOCUMENTS_FIELD_NAME), is(expectedReturnDocuments)); + itemsCount++; + } + if (expectedModelId != null) { + assertThat(requestMap.get(MODEL_FIELD_NAME), is(expectedModelId)); + itemsCount++; + } + assertThat(requestMap, aMapWithSize(itemsCount)); + } + + private static OpenShiftAiRerankRequest createRequest( + @Nullable Integer taskSettingsTopN, + @Nullable Boolean taskSettingsReturnDocuments, + @Nullable String modelId, + @Nullable Integer requestTopN, + @Nullable Boolean requestReturnDocuments + ) { + var rerankModel = OpenShiftAiRerankModelTests.createModel( + URL_VALUE, + API_KEY_VALUE, + modelId, + taskSettingsTopN, + taskSettingsReturnDocuments + ); + return new OpenShiftAiRerankRequest(QUERY_VALUE, List.of(DOCUMENT_VALUE), requestReturnDocuments, requestTopN, rerankModel); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java new file mode 100644 index 0000000000000..f553399d90c98 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java @@ -0,0 +1,108 @@ +/* + * 
Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.rerank; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS; +import static org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankTaskSettings.TOP_N; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class OpenShiftAiRerankModelTests extends ESTestCase { + + private static final String URL_VALUE = "http://www.abc.com"; + private static final String API_KEY_VALUE = "test_api_key"; + private static final String MODEL_VALUE = "some_model"; + private static final int TOP_N_VALUE = 4; + private static final boolean RETURN_DOCUMENTS_VALUE = false; + + public static OpenShiftAiRerankModel createModel( + String url, + String apiKey, + @Nullable String modelId, + @Nullable Integer topN, + @Nullable Boolean doReturnDocuments + ) { + return new OpenShiftAiRerankModel( + "inferenceEntityId", + TaskType.RERANK, + "service", + new OpenShiftAiRerankServiceSettings(modelId, url, null), + new OpenShiftAiRerankTaskSettings(topN, doReturnDocuments), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public void testOverrideWith_SameParams_KeepsSameModel() { + testOverrideWith_KeepsSameModel(buildTaskSettingsMap(2, true)); + } + + public void 
testOverrideWith_EmptyParams_KeepsSameModel() { + testOverrideWith_KeepsSameModel(buildTaskSettingsMap(null, null)); + } + + private static void testOverrideWith_KeepsSameModel(Map<String, Object> taskSettings) { + var model = createModel(URL_VALUE, API_KEY_VALUE, MODEL_VALUE, 2, true); + var overriddenModel = OpenShiftAiRerankModel.of(model, taskSettings); + assertThat(overriddenModel, is(sameInstance(model))); + } + + public void testOverrideWith_DifferentParams_OverridesAllTaskSettings() { + testOverrideWith_DifferentParams(buildTaskSettingsMap(TOP_N_VALUE, RETURN_DOCUMENTS_VALUE), TOP_N_VALUE, RETURN_DOCUMENTS_VALUE); + } + + public void testOverrideWith_DifferentParams_OverridesOnlyReturnDocuments() { + testOverrideWith_DifferentParams(buildTaskSettingsMap(null, RETURN_DOCUMENTS_VALUE), 2, RETURN_DOCUMENTS_VALUE); + } + + public void testOverrideWith_DifferentParams_OverridesOnlyTopN() { + testOverrideWith_DifferentParams(buildTaskSettingsMap(TOP_N_VALUE, null), TOP_N_VALUE, true); + } + + public void testOverrideWith_DifferentParams_OverridesNullValues() { + var model = createModel(URL_VALUE, API_KEY_VALUE, MODEL_VALUE, null, null); + var overriddenModel = OpenShiftAiRerankModel.of(model, buildTaskSettingsMap(TOP_N_VALUE, RETURN_DOCUMENTS_VALUE)); + + assertThat(overriddenModel.getTaskSettings().getTopN(), is(TOP_N_VALUE)); + assertThat(overriddenModel.getTaskSettings().getReturnDocuments(), is(RETURN_DOCUMENTS_VALUE)); + } + + private static void testOverrideWith_DifferentParams( + Map<String, Object> taskSettings, + int expectedTopN, + boolean expectedReturnDocuments + ) { + var model = createModel(URL_VALUE, API_KEY_VALUE, MODEL_VALUE, 2, true); + var overriddenModel = OpenShiftAiRerankModel.of(model, taskSettings); + + assertThat(overriddenModel.getTaskSettings().getTopN(), is(expectedTopN)); + assertThat(overriddenModel.getTaskSettings().getReturnDocuments(), is(expectedReturnDocuments)); + } + + public static Map<String, Object> buildTaskSettingsMap(@Nullable Integer topN, @Nullable Boolean
returnDocuments) { + final var map = new HashMap<String, Object>(); + + if (returnDocuments != null) { + map.put(RETURN_DOCUMENTS, returnDocuments); + } + + if (topN != null) { + map.put(TOP_N, topN); + } + + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettingsTests.java new file mode 100644 index 0000000000000..4e7c3ea3605b1 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettingsTests.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.services.openshiftai.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; + +import java.io.IOException; +import java.net.URI; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; + +public class OpenShiftAiRerankServiceSettingsTests extends AbstractBWCWireSerializationTestCase { + + private static OpenShiftAiRerankServiceSettings createRandom() { + var modelId = randomAlphaOfLengthOrNull(8); + var url = randomAlphaOfLength(15); + return new OpenShiftAiRerankServiceSettings(modelId, ServiceUtils.createUri(url), RateLimitSettingsTests.createRandom()); + } + + public void testToXContent_WritesAllFields() throws IOException { + var url = "http://www.abc.com"; + var model = "model"; + var rateLimitSettings = new RateLimitSettings(100); + + assertXContentEquals(new OpenShiftAiRerankServiceSettings(model, url, rateLimitSettings), """ + { + "model_id":"model", + "url":"http://www.abc.com", + "rate_limit": { + "requests_per_minute": 100 + } + } + """); + } + + public void testToXContent_WritesDefaultRateLimitAndOmitsModelIdIfNotSet() throws IOException { + var url = "http://www.abc.com"; + + assertXContentEquals(new 
OpenShiftAiRerankServiceSettings(null, url, null), """ + { + "url":"http://www.abc.com", + "rate_limit": { + "requests_per_minute": 3000 + } + } + """); + } + + private static void assertXContentEquals(OpenShiftAiRerankServiceSettings serviceSettings, String expectedString) throws IOException { + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(expectedString)); + } + + @Override + protected Writeable.Reader instanceReader() { + return OpenShiftAiRerankServiceSettings::new; + } + + @Override + protected OpenShiftAiRerankServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected OpenShiftAiRerankServiceSettings mutateInstance(OpenShiftAiRerankServiceSettings instance) throws IOException { + String modelId = instance.modelId(); + URI uri = instance.uri(); + RateLimitSettings rateLimitSettings = instance.rateLimitSettings(); + + switch (between(0, 2)) { + case 0 -> modelId = randomValueOtherThan(modelId, () -> randomAlphaOfLengthOrNull(8)); + case 1 -> uri = randomValueOtherThan(uri, () -> ServiceUtils.createUri(randomAlphaOfLength(15))); + case 2 -> rateLimitSettings = randomValueOtherThan(rateLimitSettings, RateLimitSettingsTests::createRandom); + default -> throw new AssertionError("Illegal randomisation branch"); + } + return new OpenShiftAiRerankServiceSettings(modelId, uri, rateLimitSettings); + } + + @Override + protected OpenShiftAiRerankServiceSettings mutateInstanceForVersion( + OpenShiftAiRerankServiceSettings instance, + TransportVersion version + ) { + return instance; + } + + public static Map getServiceSettingsMap(@Nullable String url, @Nullable String model) { + return new HashMap<>(OpenShiftAiChatCompletionServiceSettingsTests.getServiceSettingsMap(url, model)); + } + +} diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java new file mode 100644 index 0000000000000..f2f2bc0f66e88 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java @@ -0,0 +1,173 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; +import static org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankModelTests.buildTaskSettingsMap; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.sameInstance; + +public class OpenShiftAiRerankTaskSettingsTests extends AbstractBWCWireSerializationTestCase { + public static OpenShiftAiRerankTaskSettings createRandom() { + var 
returnDocuments = randomOptionalBoolean(); + var topNDocsOnly = randomBoolean() ? randomIntBetween(1, 10) : null; + + return new OpenShiftAiRerankTaskSettings(topNDocsOnly, returnDocuments); + } + + public void testFromMap_WithValidValues_ReturnsSettings() { + var settings = OpenShiftAiRerankTaskSettings.fromMap(buildTaskSettingsMap(5, true)); + assertThat(settings.getReturnDocuments(), is(true)); + assertThat(settings.getTopN(), is(5)); + } + + public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() { + var settings = OpenShiftAiRerankTaskSettings.fromMap(Map.of()); + assertThat(settings.getReturnDocuments(), is(nullValue())); + assertThat(settings.getTopN(), is(nullValue())); + } + + public void testFromMap_WithInvalidReturnDocuments_ThrowsValidationException() { + Map<String, Object> taskMap = Map.of( + OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS, + "invalid", + OpenShiftAiRerankTaskSettings.TOP_N, + 5 + ); + var thrownException = expectThrows(ValidationException.class, () -> OpenShiftAiRerankTaskSettings.fromMap(new HashMap<>(taskMap))); + assertThat(thrownException.getMessage(), containsString("field [return_documents] is not of the expected type")); + } + + public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() { + Map<String, Object> taskMap = Map.of( + OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS, + true, + OpenShiftAiRerankTaskSettings.TOP_N, + "invalid" + ); + var thrownException = expectThrows(ValidationException.class, () -> OpenShiftAiRerankTaskSettings.fromMap(new HashMap<>(taskMap))); + assertThat(thrownException.getMessage(), containsString("field [top_n] is not of the expected type")); + } + + public void testUpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() { + var initialSettings = new OpenShiftAiRerankTaskSettings(5, true); + OpenShiftAiRerankTaskSettings updatedSettings = (OpenShiftAiRerankTaskSettings) initialSettings.updatedTaskSettings(Map.of()); + assertThat(initialSettings, is(sameInstance(updatedSettings))); + } + + public void 
testUpdatedTaskSettings_WithNewReturnDocuments_ReturnsUpdatedSettings() { + var initialSettings = new OpenShiftAiRerankTaskSettings(5, true); + OpenShiftAiRerankTaskSettings updatedSettings = (OpenShiftAiRerankTaskSettings) initialSettings.updatedTaskSettings( + buildTaskSettingsMap(null, false) + ); + assertThat(updatedSettings.getReturnDocuments(), is(false)); + assertThat(initialSettings.getTopN(), is(updatedSettings.getTopN())); + } + + public void testUpdatedTaskSettings_WithNewTopNDocsOnly_ReturnsUpdatedSettings() { + var initialSettings = new OpenShiftAiRerankTaskSettings(5, true); + OpenShiftAiRerankTaskSettings updatedSettings = (OpenShiftAiRerankTaskSettings) initialSettings.updatedTaskSettings( + buildTaskSettingsMap(7, null) + ); + assertThat(updatedSettings.getTopN(), is(7)); + assertThat(updatedSettings.getReturnDocuments(), is(initialSettings.getReturnDocuments())); + } + + public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings() { + var initialSettings = new OpenShiftAiRerankTaskSettings(5, true); + OpenShiftAiRerankTaskSettings updatedSettings = (OpenShiftAiRerankTaskSettings) initialSettings.updatedTaskSettings( + buildTaskSettingsMap(7, false) + ); + assertThat(updatedSettings.getReturnDocuments(), is(false)); + assertThat(updatedSettings.getTopN(), is(7)); + } + + public void testToXContent_WritesAllValues() throws IOException { + testToXContent(2, true, """ + { + "top_n":2, + "return_documents":true + } + """); + } + + public void testToXContent_EmptyValues() throws IOException { + testToXContent(null, null, """ + {} + """); + } + + public void testToXContent_OnlyTopN() throws IOException { + testToXContent(2, null, """ + { + "top_n":2 + } + """); + } + + public void testToXContent_OnlyReturnDocuments() throws IOException { + testToXContent(null, true, """ + { + "return_documents":true + } + """); + } + + private static void testToXContent(Integer topN, Boolean doReturnDocuments, String expectedString) throws 
IOException { + var taskSettings = new OpenShiftAiRerankTaskSettings(topN, doReturnDocuments); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + taskSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(expectedString)); + } + + @Override + protected Writeable.Reader instanceReader() { + return OpenShiftAiRerankTaskSettings::new; + } + + @Override + protected OpenShiftAiRerankTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected OpenShiftAiRerankTaskSettings mutateInstance(OpenShiftAiRerankTaskSettings instance) throws IOException { + Integer topN = instance.getTopN(); + Boolean returnDocuments = instance.getReturnDocuments(); + switch (between(0, 1)) { + case 0 -> topN = randomValueOtherThan(topN, () -> randomBoolean() ? randomIntBetween(1, 10) : null); + case 1 -> returnDocuments = randomValueOtherThan(returnDocuments, ESTestCase::randomOptionalBoolean); + default -> throw new AssertionError("Illegal randomisation branch"); + } + return new OpenShiftAiRerankTaskSettings(topN, returnDocuments); + } + + @Override + protected OpenShiftAiRerankTaskSettings mutateInstanceForVersion(OpenShiftAiRerankTaskSettings instance, TransportVersion version) { + return instance; + } +}