From 3c671238bea5b3dbfbd14b40dec714b1fc1424b7 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Wed, 15 Oct 2025 16:45:09 +0300 Subject: [PATCH 01/70] Implement OpenShift AI integration for chat completion, embeddings, and reranking --- .../ml_inference_openshift_ai_added.csv | 1 + .../resources/transport/upper_bounds/9.3.csv | 2 +- .../InferenceNamedWriteablesProvider.java | 32 ++ .../xpack/inference/InferencePlugin.java | 2 + .../openshiftai/OpenShiftAiModel.java | 59 +++ .../openshiftai/OpenShiftAiService.java | 424 ++++++++++++++++++ .../OpenShiftAiServiceSettings.java | 122 +++++ .../openshiftai/OpenShiftAiUtils.java | 21 + .../action/OpenShiftAiActionCreator.java | 143 ++++++ .../action/OpenShiftAiActionVisitor.java | 47 ++ .../OpenShiftAiChatCompletionModel.java | 126 ++++++ ...nShiftAiChatCompletionResponseHandler.java | 28 ++ ...nShiftAiChatCompletionServiceSettings.java | 104 +++++ .../OpenShiftAiCompletionResponseHandler.java | 28 ++ .../OpenShiftAiEmbeddingsModel.java | 91 ++++ .../OpenShiftAiEmbeddingsResponseHandler.java | 29 ++ .../OpenShiftAiEmbeddingsServiceSettings.java | 245 ++++++++++ .../OpenShiftAiChatCompletionRequest.java | 89 ++++ ...penShiftAiChatCompletionRequestEntity.java | 40 ++ .../OpenShiftAiEmbeddingsRequest.java | 90 ++++ .../OpenShiftAiEmbeddingsRequestEntity.java | 47 ++ .../OpenShiftAIRerankRequestEntity.java | 61 +++ .../rarank/OpenShiftAiRerankRequest.java | 91 ++++ .../rerank/OpenShiftAIRerankTaskSettings.java | 182 ++++++++ .../rerank/OpenShiftAiRerankModel.java | 85 ++++ .../OpenShiftAiRerankServiceSettings.java | 116 +++++ 26 files changed, 2304 insertions(+), 1 deletion(-) create mode 100644 server/src/main/resources/transport/definitions/referable/ml_inference_openshift_ai_added.csv create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiUtils.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreator.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionVisitor.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandler.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiCompletionResponseHandler.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsResponseHandler.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAIRerankTaskSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java diff --git a/server/src/main/resources/transport/definitions/referable/ml_inference_openshift_ai_added.csv b/server/src/main/resources/transport/definitions/referable/ml_inference_openshift_ai_added.csv new file mode 100644 index 0000000000000..f9b9e54ceb668 --- /dev/null +++ b/server/src/main/resources/transport/definitions/referable/ml_inference_openshift_ai_added.csv @@ -0,0 +1 @@ +9197000 diff --git a/server/src/main/resources/transport/upper_bounds/9.3.csv b/server/src/main/resources/transport/upper_bounds/9.3.csv index af0ee6ebf047e..fb9ecdb9630a1 100644 --- a/server/src/main/resources/transport/upper_bounds/9.3.csv +++ b/server/src/main/resources/transport/upper_bounds/9.3.csv @@ -1 +1 @@ -ilm_downsample_force_merge,9196000 +ml_inference_openshift_ai_added,9197000 diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 951bb9e5802c9..0ae024ac6fe88 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -113,6 +113,10 @@ import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAIRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankServiceSettings; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -170,6 +174,7 @@ public static List getNamedWriteables() { addCustomNamedWriteables(namedWriteables); addLlamaNamedWriteables(namedWriteables); addAi21NamedWriteables(namedWriteables); + addOpenShiftAiNamedWriteables(namedWriteables); addUnifiedNamedWriteables(namedWriteables); @@ -440,6 +445,33 @@ private static void addOpenAiNamedWriteables(List ); } + private static void addOpenShiftAiNamedWriteables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + OpenShiftAiEmbeddingsServiceSettings.NAME, + OpenShiftAiEmbeddingsServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + OpenShiftAiChatCompletionServiceSettings.NAME, + OpenShiftAiChatCompletionServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + OpenShiftAiRerankServiceSettings.NAME, + OpenShiftAiRerankServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry(TaskSettings.class, OpenShiftAIRerankTaskSettings.NAME, OpenShiftAIRerankTaskSettings::new) + ); + } + private static void addHuggingFaceNamedWriteables(List namedWriteables) { namedWriteables.add( new NamedWriteableRegistry.Entry( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 60592c5dd1dbd..810915cab47d9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -144,6 +144,7 @@ import org.elasticsearch.xpack.inference.services.llama.LlamaService; import org.elasticsearch.xpack.inference.services.mistral.MistralService; import org.elasticsearch.xpack.inference.services.openai.OpenAiService; +import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiService; import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient; import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerService; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerConfiguration; @@ -426,6 +427,7 @@ public List getInferenceServiceFactories() { context -> new DeepSeekService(httpFactory.get(), serviceComponents.get(), context), context -> new LlamaService(httpFactory.get(), serviceComponents.get(), context), context -> new Ai21Service(httpFactory.get(), serviceComponents.get(), context), + context -> new OpenShiftAiService(httpFactory.get(), serviceComponents.get(), context), ElasticsearchInternalService::new, context -> new CustomService(httpFactory.get(), serviceComponents.get(), context) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java new file mode 100644 index 0000000000000..97664e3f5fcea --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai; + +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.Objects; + +/** + * Represents an OpenShift AI modelId that can be used for inference tasks. + * This class extends RateLimitGroupingModel to handle rate limiting based on modelId and API key. + */ +public abstract class OpenShiftAiModel extends RateLimitGroupingModel { + protected RateLimitSettings rateLimitSettings; + + protected OpenShiftAiModel(ModelConfigurations configurations, ModelSecrets secrets) { + super(configurations, secrets); + } + + protected OpenShiftAiModel(RateLimitGroupingModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + protected OpenShiftAiModel(RateLimitGroupingModel model, TaskSettings taskSettings) { + super(model, taskSettings); + } + + @Override + public RateLimitSettings rateLimitSettings() { + return getServiceSettings().rateLimitSettings(); + } + + @Override + public int rateLimitGroupingHash() { + return Objects.hash(getServiceSettings().uri, getServiceSettings().modelId(), getSecretSettings().apiKey()); + } + + @Override + public OpenShiftAiServiceSettings getServiceSettings() { + return (OpenShiftAiServiceSettings) super.getServiceSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java new file mode 100644 index 0000000000000..a27ba0d18c095 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java @@ -0,0 +1,424 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.RerankingInferenceService; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.openshiftai.action.OpenShiftAiActionCreator; +import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModel; +import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.openshiftai.request.completion.OpenShiftAiChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +/** + * OpenShiftAiService is an implementation of the SenderService that handles inference tasks + * using models deployed to OpenShift AI environment. + * The service uses OpenShiftAiActionCreator to create actions for executing inference requests. + */ +public class OpenShiftAiService extends SenderService implements RerankingInferenceService { + public static final String NAME = "openshiftai"; + /** + * The optimal batch size depends on the hardware the model is deployed on. + * For OpenShift AI use a conservatively small max batch size as it is + * unknown how the model is deployed + */ + static final int EMBEDDING_MAX_BATCH_SIZE = 20; + private static final String SERVICE_NAME = "OpenShift AI"; + private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of( + TaskType.TEXT_EMBEDDING, + TaskType.COMPLETION, + TaskType.CHAT_COMPLETION, + TaskType.RERANK + ); + private static final ResponseHandler CHAT_COMPLETION_HANDLER = new OpenShiftAiChatCompletionResponseHandler( + "OpenShift AI chat completions", + OpenAiChatCompletionResponseEntity::fromResponse + ); + + public OpenShiftAiService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public OpenShiftAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); + } + + @Override + protected void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + TimeValue timeout, + ActionListener listener + ) { + var actionCreator = new OpenShiftAiActionCreator(getSender(), getServiceComponents()); + + switch (Objects.requireNonNull(model)) { + case OpenShiftAiChatCompletionModel chatCompletionModel -> chatCompletionModel.accept(actionCreator) + .execute(inputs, timeout, listener); + case OpenShiftAiEmbeddingsModel embeddingsModel -> embeddingsModel.accept(actionCreator).execute(inputs, timeout, listener); + case OpenShiftAiRerankModel rerankModel -> rerankModel.accept(actionCreator, taskSettings).execute(inputs, timeout, listener); + default -> listener.onFailure(createInvalidModelException(model)); + } + } + + @Override + protected void validateInputType(InputType inputType, Model model, ValidationException validationException) { + ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException); + } + + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof OpenShiftAiChatCompletionModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + OpenShiftAiChatCompletionModel chatCompletionModel = (OpenShiftAiChatCompletionModel) model; + var overriddenModel = OpenShiftAiChatCompletionModel.of(chatCompletionModel, inputs.getRequest()); + var manager = new GenericRequestManager<>( + getServiceComponents().threadPool(), + overriddenModel, + CHAT_COMPLETION_HANDLER, + unifiedChatInput -> new OpenShiftAiChatCompletionRequest(unifiedChatInput, overriddenModel), + UnifiedChatInput.class + ); + var errorMessage = OpenShiftAiActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId()); + var action = new SenderExecutableAction(getSender(), manager, errorMessage); + + action.execute(inputs, timeout, listener); + } + + @Override + protected void doChunkedInfer( + Model model, + List inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + if (model instanceof OpenShiftAiEmbeddingsModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + var openShiftAiEmbeddingsModel = (OpenShiftAiEmbeddingsModel) model; + var actionCreator = new OpenShiftAiActionCreator(getSender(), getServiceComponents()); + List batchedRequests = new EmbeddingRequestChunker<>( + inputs, + EMBEDDING_MAX_BATCH_SIZE, + openShiftAiEmbeddingsModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + + for (var request : batchedRequests) { + var action = openShiftAiEmbeddingsModel.accept(actionCreator); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + } + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public EnumSet supportedTaskTypes() { + return SUPPORTED_TASK_TYPES; + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String inferenceEntityId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + + OpenShiftAiModel model = createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + @Override + public OpenShiftAiModel parsePersistedConfigWithSecrets( + String inferenceEntityId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + secretSettingsMap, + taskSettingsMap, + chunkingSettings + ); + } + + @Override + public OpenShiftAiModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + ChunkingSettings chunkingSettingsMap = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettingsMap = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, null, taskSettingsMap, chunkingSettingsMap); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return OpenShiftAiUtils.ML_INFERENCE_OPENSHIFT_AI_ADDED; + } + + @Override + public Set supportedStreamingTasks() { + return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); + } + + private static OpenShiftAiModel createModel( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + @Nullable Map secretSettings, + Map taskSettings, + ChunkingSettings chunkingSettings, + ConfigurationParseContext context + ) { + switch (taskType) { + case CHAT_COMPLETION, COMPLETION: + return new OpenShiftAiChatCompletionModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context); + case TEXT_EMBEDDING: + return new OpenShiftAiEmbeddingsModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + chunkingSettings, + secretSettings, + context + ); + case RERANK: + return new OpenShiftAiRerankModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); + default: + throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); + } + } + + private OpenShiftAiModel createModelFromPersistent( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map secretSettings, + Map taskSettings, + ChunkingSettings chunkingSettings + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + secretSettings, + taskSettings, + chunkingSettings, + ConfigurationParseContext.PERSISTENT + ); + } + + @Override + public int rerankerWindowSize(String modelId) { + // Cohere rerank model truncates at 4096 tokens https://docs.cohere.com/reference/rerank + // Using 1 token = 0.75 words as a rough estimate, we get 3072 words + // allowing for some headroom, we set the window size below 3072 + return 2800; + } + + @Override + public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + if (model instanceof OpenShiftAiEmbeddingsModel embeddingsModel) { + var serviceSettings = embeddingsModel.getServiceSettings(); + var similarityFromModel = serviceSettings.similarity(); + var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel; + + var updatedServiceSettings = new OpenShiftAiEmbeddingsServiceSettings( + serviceSettings.modelId(), + serviceSettings.uri(), + embeddingSize, + similarityToUse, + serviceSettings.maxInputTokens(), + serviceSettings.rateLimitSettings(), + serviceSettings.dimensionsSetByUser() + ); + + return new OpenShiftAiEmbeddingsModel(embeddingsModel, updatedServiceSettings); + } else { + throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass()); + } + } + + /** + * Configuration class for the OpenShift AI inference service. + * It provides the settings and configurations required for the service. + */ + public static class Configuration { + public static InferenceServiceConfiguration get() { + return CONFIGURATION.getOrCompute(); + } + + private static final LazyInitializable CONFIGURATION = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + + configurationMap.put( + URL, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription( + "The absolute URL of the external service to send requests to." + ) + .setLabel("URL") + .setRequired(false) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + + configurationMap.put( + MODEL_ID, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription( + "The name of the model to use for the inference task." + ) + .setLabel("Model ID") + .setRequired(false) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(SERVICE_NAME) + .setTaskTypes(SUPPORTED_TASK_TYPES) + .setConfigurations(configurationMap) + .build(); + } + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceSettings.java new file mode 100644 index 0000000000000..247af8f7bc3dd --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceSettings.java @@ -0,0 +1,122 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; + +/** + * Represents the settings for an OpenShift AI service. + * This class encapsulates the model ID, URI, and rate limit settings for the OpenShift AI service. + */ +public abstract class OpenShiftAiServiceSettings extends FilteredXContentObject implements ServiceSettings { + // There is no default rate limit for OpenShift AI, so we set a reasonable default of 3000 requests per minute + protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); + + protected final String modelId; + protected final URI uri; + protected final RateLimitSettings rateLimitSettings; + + /** + * Constructs a new OpenShiftAiServiceSettings from a StreamInput. + * + * @param in the StreamInput to read from + * @throws IOException if an I/O error occurs during reading + */ + protected OpenShiftAiServiceSettings(StreamInput in) throws IOException { + this.modelId = in.readOptionalString(); + this.uri = createUri(in.readString()); + this.rateLimitSettings = new RateLimitSettings(in); + } + + /** + * Constructs a new OpenShiftAiServiceSettings. + * + * @param modelId the ID of the modelId + * @param uri the URI of the service + * @param rateLimitSettings the rate limit settings for the service + */ + protected OpenShiftAiServiceSettings(@Nullable String modelId, URI uri, @Nullable RateLimitSettings rateLimitSettings) { + this.modelId = modelId; + this.uri = uri; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + assert false : "should never be called when supportsVersion is used"; + return OpenShiftAiUtils.ML_INFERENCE_OPENSHIFT_AI_ADDED; + } + + @Override + public boolean supportsVersion(TransportVersion version) { + return OpenShiftAiUtils.supportsOpenShiftAi(version); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(modelId); + out.writeString(uri.toString()); + rateLimitSettings.writeTo(out); + } + + @Override + public String modelId() { + return this.modelId; + } + + /** + * Returns the URI of the OpenShift AI service. + * + * @return the URI of the service + */ + public URI uri() { + return this.uri; + } + + /** + * Returns the rate limit settings for the OpenShift AI service. + * + * @return the rate limit settings + */ + public RateLimitSettings rateLimitSettings() { + return this.rateLimitSettings; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentFragmentOfExposedFields(builder, params); + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + if (modelId != null) { + builder.field(MODEL_ID, modelId); + } + builder.field(URL, uri.toString()); + rateLimitSettings.toXContent(builder, params); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiUtils.java new file mode 100644 index 0000000000000..e0227b532e61c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiUtils.java @@ -0,0 +1,21 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai; + +import org.elasticsearch.TransportVersion; + +public final class OpenShiftAiUtils { + public static final TransportVersion ML_INFERENCE_OPENSHIFT_AI_ADDED = TransportVersion.fromName("ml_inference_openshift_ai_added"); + + public static boolean supportsOpenShiftAi(TransportVersion version) { + return version.supports(ML_INFERENCE_OPENSHIFT_AI_ADDED); + } + + private OpenShiftAiUtils() {} + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreator.java new file mode 100644 index 0000000000000..3913ced6f1a2b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreator.java @@ -0,0 +1,143 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.action; + +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.cohere.CohereResponseHandler; +import org.elasticsearch.xpack.inference.services.cohere.response.CohereRankedResponseEntity; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModel; +import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsResponseHandler; +import org.elasticsearch.xpack.inference.services.openshiftai.request.completion.OpenShiftAiChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.openshiftai.request.embeddings.OpenShiftAiEmbeddingsRequest; +import org.elasticsearch.xpack.inference.services.openshiftai.request.rarank.OpenShiftAiRerankRequest; +import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankModel; + +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.common.Truncator.truncate; + +/** + * Creates executable actions for OpenShift AI models. + * This class implements the {@link OpenShiftAiActionVisitor} interface to provide specific action creation methods. + */ +public class OpenShiftAiActionCreator implements OpenShiftAiActionVisitor { + + private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = + "Failed to send OpenShift AI %s request from inference entity id [%s]"; + private static final String COMPLETION_ERROR_PREFIX = "OpenShift AI completions"; + private static final String USER_ROLE = "user"; + + private static final ResponseHandler EMBEDDINGS_HANDLER = new OpenShiftAiEmbeddingsResponseHandler( + "OpenShift AI text embedding", + OpenAiEmbeddingsResponseEntity::fromResponse + ); + private static final ResponseHandler COMPLETION_HANDLER = new OpenShiftAiCompletionResponseHandler( + "OpenShift AI completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); + // OpenShift AI Rerank task uses the same response format as Cohere, therefore we can reuse the CohereResponseHandler + private static final ResponseHandler RERANK_HANDLER = new CohereResponseHandler( + "OpenShift AI rerank", + (request, response) -> CohereRankedResponseEntity.fromResponse(response), + false + ); + + private final Sender sender; + private final ServiceComponents serviceComponents; + + /** + * Constructs a new OpenShiftAiActionCreator. + * + * @param sender the sender to use for executing actions + * @param serviceComponents the service components providing necessary services + */ + public OpenShiftAiActionCreator(Sender sender, ServiceComponents serviceComponents) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + } + + @Override + public ExecutableAction create(OpenShiftAiEmbeddingsModel model) { + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + EMBEDDINGS_HANDLER, + embeddingsInput -> new OpenShiftAiEmbeddingsRequest( + serviceComponents.truncator(), + truncate(embeddingsInput.getInputs(), model.getServiceSettings().maxInputTokens()), + model + ), + EmbeddingsInput.class + ); + + var errorMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, model.getInferenceEntityId()); + return new SenderExecutableAction(sender, manager, errorMessage); + } + + @Override + public ExecutableAction create(OpenShiftAiChatCompletionModel model) { + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + COMPLETION_HANDLER, + inputs -> new OpenShiftAiChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), + ChatCompletionInput.class + ); + + var errorMessage = buildErrorMessage(TaskType.COMPLETION, model.getInferenceEntityId()); + return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX); + } + + @Override + public ExecutableAction create(OpenShiftAiRerankModel model, Map taskSettings) { + var overriddenModel = OpenShiftAiRerankModel.of(model, taskSettings); + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + overriddenModel, + RERANK_HANDLER, + inputs -> new OpenShiftAiRerankRequest( + inputs.getQuery(), + inputs.getChunks(), + inputs.getReturnDocuments(), + inputs.getTopN(), + model + ), + QueryAndDocsInputs.class + ); + var errorMessage = buildErrorMessage(TaskType.RERANK, model.getInferenceEntityId()); + return new SenderExecutableAction(sender, manager, errorMessage); + } + + /** + * Builds an error message for failed requests. + * + * @param requestType the type of request that failed + * @param inferenceId the inference entity ID associated with the request + * @return a formatted error message + */ + public static String buildErrorMessage(TaskType requestType, String inferenceId) { + return format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, requestType.toString(), inferenceId); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionVisitor.java new file mode 100644 index 0000000000000..fd06807bcb174 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionVisitor.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.action; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModel; +import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankModel; + +import java.util.Map; + +/** + * Visitor interface for creating executable actions for OpenShift AI inference models. + * This interface defines methods to create actions for embeddings, reranking and completion models. + */ +public interface OpenShiftAiActionVisitor { + + /** + * Creates an executable action for the given OpenShift AI embeddings model. + * + * @param model The OpenShift AI embeddings model. + * @return An executable action for the embeddings model. + */ + ExecutableAction create(OpenShiftAiEmbeddingsModel model); + + /** + * Creates an executable action for the given OpenShift AI chat completion model. + * + * @param model The OpenShift AI chat completion model. + * @return An executable action for the chat completion model. + */ + ExecutableAction create(OpenShiftAiChatCompletionModel model); + + /** + * Creates an executable action for the given OpenShift AI rerank model. + * + * @param model The OpenShift AI rerank model. + * @param taskSettings The task settings for the rerank action. + * @return An executable action for the rerank model. + */ + ExecutableAction create(OpenShiftAiRerankModel model, Map taskSettings); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java new file mode 100644 index 0000000000000..d029ab2e1bcad --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java @@ -0,0 +1,126 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.completion; + +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiModel; +import org.elasticsearch.xpack.inference.services.openshiftai.action.OpenShiftAiActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.Map; + +/** + * Represents an OpenShift AI chat completion model. + * This class extends the OpenShiftAiModel and provides specific configurations for chat completion tasks. + */ +public class OpenShiftAiChatCompletionModel extends OpenShiftAiModel { + + /** + * Constructor for creating a OpenShiftAiChatCompletionModel with specified parameters. + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to chat completion + * @param secrets the secret settings for the model, such as API keys or tokens + * @param context the context for parsing configuration settings + */ + public OpenShiftAiChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + OpenShiftAiChatCompletionServiceSettings.fromMap(serviceSettings, context), + DefaultSecretSettings.fromMap(secrets) + ); + } + + /** + * Constructor for creating an OpenShiftAiChatCompletionModel with specified parameters. + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to chat completion + * @param secrets the secret settings for the model + */ + public OpenShiftAiChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + OpenShiftAiChatCompletionServiceSettings serviceSettings, + SecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE), + new ModelSecrets(secrets) + ); + } + + /** + * Factory method to create an OpenShiftAiChatCompletionModel with potential overrides from a UnifiedCompletionRequest. + * If the request does not specify a model ID, the original model is returned. + * + * @param model the original OpenShiftAiChatCompletionModel + * @param request the UnifiedCompletionRequest containing potential overrides + * @return a new OpenShiftAiChatCompletionModel with overridden settings or the original model ID if no overrides are specified + */ + public static OpenShiftAiChatCompletionModel of(OpenShiftAiChatCompletionModel model, UnifiedCompletionRequest request) { + if (request.model() == null) { + // If no model ID is specified in the request, return the original model + return model; + } + + var originalModelServiceSettings = model.getServiceSettings(); + var overriddenServiceSettings = new OpenShiftAiChatCompletionServiceSettings( + request.model(), + originalModelServiceSettings.uri(), + originalModelServiceSettings.rateLimitSettings() + ); + + return new OpenShiftAiChatCompletionModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getConfigurations().getService(), + overriddenServiceSettings, + model.getSecretSettings() + ); + } + + /** + * Returns the service settings specific to OpenShift AI chat completion. + * + * @return the OpenShiftAiChatCompletionServiceSettings associated with this model + */ + @Override + public OpenShiftAiChatCompletionServiceSettings getServiceSettings() { + return (OpenShiftAiChatCompletionServiceSettings) super.getServiceSettings(); + } + + /** + * Accepts a visitor that creates an executable action for this OpenShift AI chat completion. + * + * @param creator the visitor that creates the executable action + * @return an ExecutableAction representing this model + */ + public ExecutableAction accept(OpenShiftAiActionVisitor creator) { + return creator.create(this); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandler.java new file mode 100644 index 0000000000000..8d9837e8915c1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandler.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.completion; + +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorParserContract; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorResponseUtils; +import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; + +/** + * Handles streaming chat completion responses and error parsing for OpenShift AI inference endpoints. + * Adapts the OpenAI handler to support OpenShift AI's error schema. + */ +public class OpenShiftAiChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler { + + private static final String OPENSHIFT_AI_ERROR = "openshiftai_error"; + private static final UnifiedChatCompletionErrorParserContract OPENSHIFT_AI_ERROR_PARSER = UnifiedChatCompletionErrorResponseUtils + .createErrorParserWithStringify(OPENSHIFT_AI_ERROR); + + public OpenShiftAiChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, OPENSHIFT_AI_ERROR_PARSER::parse, OPENSHIFT_AI_ERROR_PARSER); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java new file mode 100644 index 0000000000000..c8980843b3969 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java @@ -0,0 +1,104 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.completion; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiService; +import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; + +/** + * Represents the settings for a OpenShift AI chat completion service. + * This class encapsulates the model ID, URI, and rate limit settings for the OpenShift AI chat completion service. + */ +public class OpenShiftAiChatCompletionServiceSettings extends OpenShiftAiServiceSettings { + public static final String NAME = "openshiftai_completion_service_settings"; + + /** + * Creates a new instance of OpenShiftAiChatCompletionServiceSettings from a map of settings. + * + * @param map the map containing the service settings + * @param context the context for parsing configuration settings + * @return a new instance of OpenShiftAiChatCompletionServiceSettings + * @throws ValidationException if required fields are missing or invalid + */ + public static OpenShiftAiChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + var model = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var uri = extractUri(map, URL, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + OpenShiftAiService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new OpenShiftAiChatCompletionServiceSettings(model, uri, rateLimitSettings); + } + + /** + * Constructs a new OpenShiftAiChatCompletionServiceSettings from a StreamInput. + * + * @param in the StreamInput to read from + * @throws IOException if an I/O error occurs during reading + */ + public OpenShiftAiChatCompletionServiceSettings(StreamInput in) throws IOException { + super(in); + } + + /** + * Constructs a new OpenShiftAiChatCompletionServiceSettings. + * + * @param modelId the ID of the model ID + * @param uri the URI of the service + * @param rateLimitSettings the rate limit settings for the service + */ + public OpenShiftAiChatCompletionServiceSettings(@Nullable String modelId, URI uri, @Nullable RateLimitSettings rateLimitSettings) { + super(modelId, uri, rateLimitSettings); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OpenShiftAiChatCompletionServiceSettings that = (OpenShiftAiChatCompletionServiceSettings) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(uri, that.uri) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, uri, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiCompletionResponseHandler.java new file mode 100644 index 0000000000000..c1703df534daf --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiCompletionResponseHandler.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.completion; + +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler; + +/** + * Handles non-streaming completion responses for OpenShift AI models, extending the OpenAI completion response handler. + */ +public class OpenShiftAiCompletionResponseHandler extends OpenAiChatCompletionResponseHandler { + + /** + * Constructs a OpenShiftAiCompletionResponseHandler with the specified request type and response parser. + * + * @param requestType The type of request being handled (e.g., "llama completions"). + * @param parseFunction The function to parse the response. + */ + public OpenShiftAiCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, ErrorResponse::fromResponse); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java new file mode 100644 index 0000000000000..6311bd82a068e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.embeddings; + +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiModel; +import org.elasticsearch.xpack.inference.services.openshiftai.action.OpenShiftAiActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.Map; + +/** + * Represents a OpenShift AI embeddings model for inference. + * This class extends the OpenShiftAiModel and provides specific configurations and settings for embeddings tasks. + */ +public class OpenShiftAiEmbeddingsModel extends OpenShiftAiModel { + + public OpenShiftAiEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + ChunkingSettings chunkingSettings, + Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + OpenShiftAiEmbeddingsServiceSettings.fromMap(serviceSettings, context), + chunkingSettings, + DefaultSecretSettings.fromMap(secrets) + ); + } + + public OpenShiftAiEmbeddingsModel(OpenShiftAiEmbeddingsModel model, OpenShiftAiEmbeddingsServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + /** + * Constructor for creating a OpenShiftAiEmbeddingsModel with specified parameters. + * + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to embeddings + * @param chunkingSettings the chunking settings for processing input data + * @param secrets the secret settings for the model, such as API keys or tokens + */ + public OpenShiftAiEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + OpenShiftAiEmbeddingsServiceSettings serviceSettings, + ChunkingSettings chunkingSettings, + SecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings), + new ModelSecrets(secrets) + ); + } + + @Override + public OpenShiftAiEmbeddingsServiceSettings getServiceSettings() { + return (OpenShiftAiEmbeddingsServiceSettings) super.getServiceSettings(); + } + + /** + * Accepts a visitor to create an executable action for this OpenShift AI embeddings model. + * + * @param creator the visitor that creates the executable action + * @return an ExecutableAction representing the OpenShift AI embeddings model + */ + public ExecutableAction accept(OpenShiftAiActionVisitor creator) { + return creator.create(this); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsResponseHandler.java new file mode 100644 index 0000000000000..ad2db4571a47a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsResponseHandler.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.embeddings; + +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.services.openai.OpenAiResponseHandler; + +/** + * Handles responses for OpenShift AI embeddings requests, parsing the response and handling errors. + * This class extends OpenAiResponseHandler to provide specific functionality for OpenShift AI embeddings. + */ +public class OpenShiftAiEmbeddingsResponseHandler extends OpenAiResponseHandler { + + /** + * Constructs a new OpenShiftAiEmbeddingsResponseHandler with the specified request type and response parser. + * + * @param requestType the type of request this handler will process + * @param parseFunction the function to parse the response + */ + public OpenShiftAiEmbeddingsResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, ErrorResponse::fromResponse, false); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java new file mode 100644 index 0000000000000..08ced7ab0e60a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java @@ -0,0 +1,245 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.embeddings; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiService; +import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; + +/** + * Settings for the OpenShift AI embeddings service. + * This class encapsulates the configuration settings required to use OpenShift AI for generating embeddings. + */ +public class OpenShiftAiEmbeddingsServiceSettings extends OpenShiftAiServiceSettings { + public static final String NAME = "openshiftai_embeddings_service_settings"; + static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user"; + + private final Integer dimensions; + private final SimilarityMeasure similarity; + private final Integer maxInputTokens; + private final Boolean dimensionsSetByUser; + + /** + * Creates a new instance of OpenShiftAiEmbeddingsServiceSettings from a map of settings. + * + * @param map the map containing the settings + * @param context the context for parsing configuration settings + * @return a new instance of OpenShiftAiEmbeddingsServiceSettings + * @throws ValidationException if any required fields are missing or invalid + */ + public static OpenShiftAiEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + var model = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var uri = extractUri(map, URL, validationException); + var dimensions = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); + var similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + var maxInputTokens = extractOptionalPositiveInteger( + map, + MAX_INPUT_TOKENS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + var rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + OpenShiftAiService.NAME, + context + ); + Boolean dimensionsSetByUser = switch (context) { + case REQUEST -> dimensions != null; + case PERSISTENT -> extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException); + }; + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new OpenShiftAiEmbeddingsServiceSettings( + model, + uri, + dimensions, + similarity, + maxInputTokens, + rateLimitSettings, + dimensionsSetByUser + ); + } + + /** + * Constructs a new OpenShiftAiEmbeddingsServiceSettings from a StreamInput. + * + * @param in the StreamInput to read from + * @throws IOException if an I/O error occurs during reading + */ + public OpenShiftAiEmbeddingsServiceSettings(StreamInput in) throws IOException { + super(in); + this.dimensions = in.readOptionalVInt(); + this.similarity = in.readOptionalEnum(SimilarityMeasure.class); + this.maxInputTokens = in.readOptionalVInt(); + this.dimensionsSetByUser = in.readOptionalBoolean(); + } + + /** + * Constructs a new OpenShiftAiEmbeddingsServiceSettings with the specified parameters. + * + * @param modelId the identifier for the model + * @param uri the URI of the OpenShift AI service + * @param dimensions the number of dimensions for the embeddings, can be null + * @param similarity the similarity measure to use, can be null + * @param maxInputTokens the maximum number of input tokens, can be null + * @param rateLimitSettings the rate limit settings for the service, can be null + * @param dimensionsSetByUser indicates if dimensions were set by the user, can be null + */ + public OpenShiftAiEmbeddingsServiceSettings( + @Nullable String modelId, + URI uri, + @Nullable Integer dimensions, + @Nullable SimilarityMeasure similarity, + @Nullable Integer maxInputTokens, + @Nullable RateLimitSettings rateLimitSettings, + @Nullable Boolean dimensionsSetByUser + ) { + super(modelId, uri, rateLimitSettings); + this.dimensions = dimensions; + this.similarity = similarity; + this.maxInputTokens = maxInputTokens; + this.dimensionsSetByUser = dimensionsSetByUser; + } + + /** + * Constructs a new OpenShiftAiEmbeddingsServiceSettings with the specified parameters. + * + * @param modelId the identifier for the model + * @param url the URL of the OpenShift AI service + * @param dimensions the number of dimensions for the embeddings, can be null + * @param similarity the similarity measure to use, can be null + * @param maxInputTokens the maximum number of input tokens, can be null + * @param rateLimitSettings the rate limit settings for the service, can be null + * @param dimensionsSetByUser indicates if dimensions were set by the user, can be null + */ + public OpenShiftAiEmbeddingsServiceSettings( + String modelId, + String url, + @Nullable Integer dimensions, + @Nullable SimilarityMeasure similarity, + @Nullable Integer maxInputTokens, + @Nullable RateLimitSettings rateLimitSettings, + @Nullable Boolean dimensionsSetByUser + ) { + this(modelId, createUri(url), dimensions, similarity, maxInputTokens, rateLimitSettings, dimensionsSetByUser); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public Integer dimensions() { + return this.dimensions; + } + + @Override + public Boolean dimensionsSetByUser() { + return this.dimensionsSetByUser; + } + + @Override + public SimilarityMeasure similarity() { + return this.similarity; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + + /** + * Returns the maximum number of input tokens allowed for this service. + * + * @return the maximum input tokens, or null if not specified + */ + public Integer maxInputTokens() { + return this.maxInputTokens; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeOptionalVInt(dimensions); + out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); + out.writeOptionalVInt(maxInputTokens); + out.writeOptionalBoolean(dimensionsSetByUser); + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + super.toXContentFragmentOfExposedFields(builder, params); + + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + if (dimensionsSetByUser != null) { + builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); + } + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OpenShiftAiEmbeddingsServiceSettings that = (OpenShiftAiEmbeddingsServiceSettings) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(uri, that.uri) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(similarity, that.similarity) + && Objects.equals(rateLimitSettings, that.rateLimitSettings) + && Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, uri, dimensions, maxInputTokens, similarity, rateLimitSettings, dimensionsSetByUser); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequest.java new file mode 100644 index 0000000000000..9e4796e28ac3d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequest.java @@ -0,0 +1,89 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.request.completion; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +/** + * OpenShift AI Chat Completion Request + * This class is responsible for creating a request to the OpenShift AI chat completion model. + * It constructs an HTTP POST request with the necessary headers and body content. + */ +public class OpenShiftAiChatCompletionRequest implements Request { + + private final OpenShiftAiChatCompletionModel model; + private final UnifiedChatInput chatInput; + + /** + * Constructs a new OpenShiftAiChatCompletionRequest with the specified chat input and model. + * + * @param chatInput the chat input containing the messages and parameters for the completion request + * @param model the OpenShift AI chat completion model to be used for the request + */ + public OpenShiftAiChatCompletionRequest(UnifiedChatInput chatInput, OpenShiftAiChatCompletionModel model) { + this.chatInput = Objects.requireNonNull(chatInput); + this.model = Objects.requireNonNull(model); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(model.getServiceSettings().uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new OpenShiftAiChatCompletionRequestEntity(chatInput, model.getServiceSettings().modelId())) + .getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters()); + httpPost.setHeader(createAuthBearerHeader(model.getSecretSettings().apiKey())); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return model.getServiceSettings().uri(); + } + + @Override + public Request truncate() { + // No truncation for OpenShift AI chat completions + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // No truncation for OpenShift AI chat completions + return null; + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public boolean isStreaming() { + return chatInput.stream(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntity.java new file mode 100644 index 0000000000000..1a53c4180b97a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntity.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.request.completion; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; + +import java.io.IOException; + +/** + * OpenShiftAiChatCompletionRequestEntity is responsible for creating the request entity for OpenShift AI chat completion. + * It implements ToXContentObject to allow serialization to XContent format. + */ +public class OpenShiftAiChatCompletionRequestEntity implements ToXContentObject { + + private final String modelId; + private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; + + public OpenShiftAiChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, @Nullable String modelId) { + this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput); + this.modelId = modelId; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(modelId, params)); + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequest.java new file mode 100644 index 0000000000000..d3fe2bc685c3f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequest.java @@ -0,0 +1,90 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.request.embeddings; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +/** + * OpenShift AI Embeddings Request + * This class is responsible for creating a request to the OpenShift AI embeddings endpoint. + * It constructs an HTTP POST request with the necessary headers and body content. + */ +public class OpenShiftAiEmbeddingsRequest implements Request { + private final OpenShiftAiEmbeddingsModel model; + private final Truncator.TruncationResult truncationResult; + private final Truncator truncator; + + /** + * Constructs a new OpenShiftAiEmbeddingsRequest with the specified truncator, input, and model. + * + * @param truncator the truncator to handle input truncation + * @param input the input to be truncated + * @param model the OpenShiftId embeddings model to be used for the request + */ + public OpenShiftAiEmbeddingsRequest(Truncator truncator, Truncator.TruncationResult input, OpenShiftAiEmbeddingsModel model) { + this.model = model; + this.truncator = truncator; + this.truncationResult = input; + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(model.getServiceSettings().uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString( + new OpenShiftAiEmbeddingsRequestEntity( + truncationResult.input(), + model.getServiceSettings().modelId(), + model.getServiceSettings().dimensions(), + model.getServiceSettings().dimensionsSetByUser() + ) + ).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters()); + httpPost.setHeader(createAuthBearerHeader(model.getSecretSettings().apiKey())); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return model.getServiceSettings().uri(); + } + + @Override + public Request truncate() { + var truncatedInput = truncator.truncate(truncationResult.input()); + return new OpenShiftAiEmbeddingsRequest(truncator, truncatedInput, model); + } + + @Override + public boolean[] getTruncationInfo() { + return truncationResult.truncated().clone(); + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..3b06fe5592582 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntity.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.request.embeddings; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record OpenShiftAiEmbeddingsRequestEntity( + List input, + @Nullable String modelId, + @Nullable Integer dimensions, + boolean dimensionsSetByUser +) implements ToXContentObject { + + private static final String INPUT_FIELD = "input"; + private static final String MODEL_FIELD = "model"; + private static final String DIMENSIONS_FIELD = "dimensions"; + + public OpenShiftAiEmbeddingsRequestEntity { + Objects.requireNonNull(input); + Objects.requireNonNull(modelId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INPUT_FIELD, input); + if (modelId != null) { + builder.field(MODEL_FIELD, modelId); + } + if (dimensionsSetByUser && dimensions != null) { + builder.field(DIMENSIONS_FIELD, dimensions); + } + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntity.java new file mode 100644 index 0000000000000..9ddfcd7f1d574 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntity.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.request.rarank; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAIRerankTaskSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record OpenShiftAIRerankRequestEntity( + @Nullable String modelId, + String query, + List documents, + @Nullable Boolean returnDocuments, + @Nullable Integer topN +) implements ToXContentObject { + + private static final String MODEL_FIELD = "model"; + private static final String DOCUMENTS_FIELD = "documents"; + private static final String QUERY_FIELD = "query"; + + public OpenShiftAIRerankRequestEntity { + Objects.requireNonNull(query); + Objects.requireNonNull(documents); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + // model field is optional for OpenShift AI + if (modelId != null) { + builder.field(MODEL_FIELD, modelId); + } + builder.field(QUERY_FIELD, query); + builder.field(DOCUMENTS_FIELD, documents); + + // prefer the root level top_n over task settings + if (topN != null) { + builder.field(OpenShiftAIRerankTaskSettings.TOP_N, topN); + } + + // prefer the root level return_documents over task settings + if (returnDocuments != null) { + builder.field(OpenShiftAIRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments); + } + + builder.endObject(); + return builder; + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequest.java new file mode 100644 index 0000000000000..62094b2b627f9 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequest.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.request.rarank; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +/** + * Represents a request to the OpenShift AI rerank service. + * This class constructs the HTTP request with the necessary headers and body content. + */ +public record OpenShiftAiRerankRequest( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + OpenShiftAiRerankModel model +) implements Request { + + public OpenShiftAiRerankRequest { + Objects.requireNonNull(input); + Objects.requireNonNull(query); + Objects.requireNonNull(model); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(getURI()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString( + new OpenShiftAIRerankRequestEntity(model.getServiceSettings().modelId(), query, input, returnDocuments(), topN()) + ).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters()); + + httpPost.setHeader(createAuthBearerHeader(model.getSecretSettings().apiKey())); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public URI getURI() { + return model.getServiceSettings().uri(); + } + + public Integer topN() { + return topN != null ? topN : model.getTaskSettings().getTopN(); + } + + public Boolean returnDocuments() { + return returnDocuments != null ? returnDocuments : model.getTaskSettings().getReturnDocuments(); + } + + @Override + public Request truncate() { + // Not applicable for rerank, only used in text embedding requests + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // Not applicable for rerank, only used in text embedding requests + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAIRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAIRerankTaskSettings.java new file mode 100644 index 0000000000000..4f991af1e4514 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAIRerankTaskSettings.java @@ -0,0 +1,182 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiUtils; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; + +/** + * Defines the task settings for the OpenShift AI rerank service. + */ +public class OpenShiftAIRerankTaskSettings implements TaskSettings { + + public static final String NAME = "openshiftai_rerank_task_settings"; + public static final String RETURN_DOCUMENTS = "return_documents"; + public static final String TOP_N = "top_n"; + + private static final OpenShiftAIRerankTaskSettings EMPTY_SETTINGS = new OpenShiftAIRerankTaskSettings(null, null); + + /** + * Creates a new {@link OpenShiftAIRerankTaskSettings} from a map of settings. + * @param map the map of settings + * @return a constructed {@link OpenShiftAIRerankTaskSettings} + * @throws ValidationException if any of the settings are invalid + */ + public static OpenShiftAIRerankTaskSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + if (map == null || map.isEmpty()) { + return EMPTY_SETTINGS; + } + + Boolean returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, validationException); + Integer topN = extractOptionalPositiveInteger(map, TOP_N, ModelConfigurations.TASK_SETTINGS, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return of(topN, returnDocuments); + } + + /** + * Creates a new {@link OpenShiftAIRerankTaskSettings} by using non-null fields from the request settings over the original settings. + * + * @param originalSettings the settings stored as part of the inference entity configuration + * @param requestTaskSettings the settings passed in within the task_settings field of the request + * @return a constructed {@link OpenShiftAIRerankTaskSettings} + */ + public static OpenShiftAIRerankTaskSettings of( + OpenShiftAIRerankTaskSettings originalSettings, + OpenShiftAIRerankTaskSettings requestTaskSettings + ) { + return new OpenShiftAIRerankTaskSettings( + requestTaskSettings.getTopN() != null ? requestTaskSettings.getTopN() : originalSettings.getTopN(), + requestTaskSettings.getReturnDocuments() != null + ? requestTaskSettings.getReturnDocuments() + : originalSettings.getReturnDocuments() + ); + } + + /** + * Creates a new {@link OpenShiftAIRerankTaskSettings} with the specified settings. + * + * @param topN the number of top documents to return + * @param returnDocuments whether to return the documents + * @return a constructed {@link OpenShiftAIRerankTaskSettings} + */ + public static OpenShiftAIRerankTaskSettings of(@Nullable Integer topN, @Nullable Boolean returnDocuments) { + return new OpenShiftAIRerankTaskSettings(topN, returnDocuments); + } + + private final Integer topN; + private final Boolean returnDocuments; + + /** + * Constructs a new {@link OpenShiftAIRerankTaskSettings} by reading from a {@link StreamInput}. + * + * @param in the stream input to read from + * @throws IOException if an I/O error occurs + */ + public OpenShiftAIRerankTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalInt(), in.readOptionalBoolean()); + } + + /** + * Constructs a new {@link OpenShiftAIRerankTaskSettings} with the specified settings. + * + * @param topN the number of top documents to return + * @param doReturnDocuments whether to return the documents + */ + public OpenShiftAIRerankTaskSettings(@Nullable Integer topN, @Nullable Boolean doReturnDocuments) { + this.topN = topN; + this.returnDocuments = doReturnDocuments; + } + + @Override + public boolean isEmpty() { + return topN == null && returnDocuments == null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (topN != null) { + builder.field(TOP_N, topN); + } + if (returnDocuments != null) { + builder.field(RETURN_DOCUMENTS, returnDocuments); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + assert false : "should never be called when supportsVersion is used"; + return OpenShiftAiUtils.ML_INFERENCE_OPENSHIFT_AI_ADDED; + } + + @Override + public boolean supportsVersion(TransportVersion version) { + return OpenShiftAiUtils.supportsOpenShiftAi(version); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalInt(topN); + out.writeOptionalBoolean(returnDocuments); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OpenShiftAIRerankTaskSettings that = (OpenShiftAIRerankTaskSettings) o; + return Objects.equals(returnDocuments, that.returnDocuments) && Objects.equals(topN, that.topN); + } + + @Override + public int hashCode() { + return Objects.hash(returnDocuments, topN); + } + + public Integer getTopN() { + return topN; + } + + public Boolean getReturnDocuments() { + return returnDocuments; + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + OpenShiftAIRerankTaskSettings updatedSettings = OpenShiftAIRerankTaskSettings.fromMap(new HashMap<>(newSettings)); + return OpenShiftAIRerankTaskSettings.of(this, updatedSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModel.java new file mode 100644 index 0000000000000..09ad4f5db2802 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModel.java @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.rerank; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiModel; +import org.elasticsearch.xpack.inference.services.openshiftai.action.OpenShiftAiActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.Map; + +public class OpenShiftAiRerankModel extends OpenShiftAiModel { + public static OpenShiftAiRerankModel of(OpenShiftAiRerankModel model, Map taskSettings) { + var requestTaskSettings = OpenShiftAIRerankTaskSettings.fromMap(taskSettings); + return new OpenShiftAiRerankModel(model, OpenShiftAIRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + } + + public OpenShiftAiRerankModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + OpenShiftAiRerankServiceSettings.fromMap(serviceSettings, context), + OpenShiftAIRerankTaskSettings.fromMap(taskSettings), + DefaultSecretSettings.fromMap(secrets) + ); + } + + // should only be used for testing + OpenShiftAiRerankModel( + String inferenceEntityId, + TaskType taskType, + String service, + OpenShiftAiRerankServiceSettings serviceSettings, + OpenShiftAIRerankTaskSettings taskSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secretSettings) + ); + } + + private OpenShiftAiRerankModel(OpenShiftAiRerankModel model, OpenShiftAIRerankTaskSettings taskSettings) { + super(model, taskSettings); + } + + @Override + public OpenShiftAiRerankServiceSettings getServiceSettings() { + return (OpenShiftAiRerankServiceSettings) super.getServiceSettings(); + } + + @Override + public OpenShiftAIRerankTaskSettings getTaskSettings() { + return (OpenShiftAIRerankTaskSettings) super.getTaskSettings(); + } + + /** + * Accepts a visitor to create an executable action. The returned action will not return documents in the response. + * @param visitor Interface for creating {@link ExecutableAction} instances for IBM watsonx models. + * @param taskSettings Settings in the request to override the model's defaults + * @return the rerank action + */ + public ExecutableAction accept(OpenShiftAiActionVisitor visitor, Map taskSettings) { + return visitor.create(this, taskSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java new file mode 100644 index 0000000000000..773269acb4010 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java @@ -0,0 +1,116 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.rerank; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiService; +import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; + +/** + * Represents the settings for a OpenShift AI chat rerank service. + * This class encapsulates the model ID, URI, and rate limit settings for the OpenShift AI chat rerank service. + */ +public class OpenShiftAiRerankServiceSettings extends OpenShiftAiServiceSettings { + public static final String NAME = "openshiftai_rerank_service_settings"; + + /** + * Creates a new instance of OpenShiftAiRerankServiceSettings from a map of settings. + * + * @param map the map containing the service settings + * @param context the context for parsing configuration settings + * @return a new instance of OpenShiftAiRerankServiceSettings + * @throws ValidationException if required fields are missing or invalid + */ + public static OpenShiftAiRerankServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + var model = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var uri = extractUri(map, URL, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + OpenShiftAiService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new OpenShiftAiRerankServiceSettings(model, uri, rateLimitSettings); + } + + /** + * Constructs a new OpenShiftAiRerankServiceSettings from a StreamInput. + * + * @param in the StreamInput to read from + * @throws IOException if an I/O error occurs during reading + */ + public OpenShiftAiRerankServiceSettings(StreamInput in) throws IOException { + super(in); + } + + /** + * Constructs a new OpenShiftAiRerankServiceSettings with the specified model ID, URI, and rate limit settings. + * + * @param modelId the ID of the model + * @param uri the URI of the service + * @param rateLimitSettings the rate limit settings for the service + */ + public OpenShiftAiRerankServiceSettings(@Nullable String modelId, URI uri, @Nullable RateLimitSettings rateLimitSettings) { + super(modelId, uri, rateLimitSettings); + } + + /** + * Constructs a new OpenShiftAiRerankServiceSettings with the specified model ID and URL. + * The rate limit settings will be set to the default value. + * + * @param modelId the ID of the modelId + * @param url the URL of the service + */ + public OpenShiftAiRerankServiceSettings(@Nullable String modelId, String url, @Nullable RateLimitSettings rateLimitSettings) { + this(modelId, createUri(url), rateLimitSettings); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OpenShiftAiRerankServiceSettings that = (OpenShiftAiRerankServiceSettings) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(uri, that.uri) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, uri, rateLimitSettings); + } +} From fdb22ff6f8320a864ec2e2c6ca83ac57a6c40bc0 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Wed, 15 Oct 2025 17:13:58 +0300 Subject: [PATCH 02/70] Refactor OpenShift AI service settings to use underscores in constant names and add changelog --- docs/changelog/136624.yaml | 5 ++ .../openshiftai/OpenShiftAiService.java | 57 ++++++++++--------- ...nShiftAiChatCompletionResponseHandler.java | 2 +- ...nShiftAiChatCompletionServiceSettings.java | 2 +- .../OpenShiftAiEmbeddingsServiceSettings.java | 2 +- .../rerank/OpenShiftAIRerankTaskSettings.java | 2 +- .../OpenShiftAiRerankServiceSettings.java | 2 +- 7 files changed, 40 insertions(+), 32 deletions(-) create mode 100644 docs/changelog/136624.yaml diff --git a/docs/changelog/136624.yaml b/docs/changelog/136624.yaml new file mode 100644 index 0000000000000..e9e58bd91d953 --- /dev/null +++ b/docs/changelog/136624.yaml @@ -0,0 +1,5 @@ +pr: 136624 +summary: Added OpenShift AI text_embedding, completion, chat_completion and rerank support to the Inference Plugin +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java index a27ba0d18c095..1b52afe63403e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java @@ -74,7 +74,7 @@ * The service uses OpenShiftAiActionCreator to create actions for executing inference requests. */ public class OpenShiftAiService extends SenderService implements RerankingInferenceService { - public static final String NAME = "openshiftai"; + public static final String NAME = "openshift_ai"; /** * The optimal batch size depends on the hardware the model is deployed on. * For OpenShift AI use a conservatively small max batch size as it is @@ -293,32 +293,35 @@ private static OpenShiftAiModel createModel( ChunkingSettings chunkingSettings, ConfigurationParseContext context ) { - switch (taskType) { - case CHAT_COMPLETION, COMPLETION: - return new OpenShiftAiChatCompletionModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context); - case TEXT_EMBEDDING: - return new OpenShiftAiEmbeddingsModel( - inferenceEntityId, - taskType, - NAME, - serviceSettings, - chunkingSettings, - secretSettings, - context - ); - case RERANK: - return new OpenShiftAiRerankModel( - inferenceEntityId, - taskType, - NAME, - serviceSettings, - taskSettings, - secretSettings, - context - ); - default: - throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); - } + return switch (taskType) { + case CHAT_COMPLETION, COMPLETION -> new OpenShiftAiChatCompletionModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + secretSettings, + context + ); + case TEXT_EMBEDDING -> new OpenShiftAiEmbeddingsModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + chunkingSettings, + secretSettings, + context + ); + case RERANK -> new OpenShiftAiRerankModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); + }; } private OpenShiftAiModel createModelFromPersistent( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandler.java index 8d9837e8915c1..00d313709035b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandler.java @@ -18,7 +18,7 @@ */ public class OpenShiftAiChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler { - private static final String OPENSHIFT_AI_ERROR = "openshiftai_error"; + private static final String OPENSHIFT_AI_ERROR = "openshift_ai_error"; private static final UnifiedChatCompletionErrorParserContract OPENSHIFT_AI_ERROR_PARSER = UnifiedChatCompletionErrorResponseUtils .createErrorParserWithStringify(OPENSHIFT_AI_ERROR); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java index c8980843b3969..1ac451b59ca0b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java @@ -31,7 +31,7 @@ * This class encapsulates the model ID, URI, and rate limit settings for the OpenShift AI chat completion service. */ public class OpenShiftAiChatCompletionServiceSettings extends OpenShiftAiServiceSettings { - public static final String NAME = "openshiftai_completion_service_settings"; + public static final String NAME = "openshift_ai_completion_service_settings"; /** * Creates a new instance of OpenShiftAiChatCompletionServiceSettings from a map of settings. diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java index 08ced7ab0e60a..3db3586bfbc35 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java @@ -42,7 +42,7 @@ * This class encapsulates the configuration settings required to use OpenShift AI for generating embeddings. */ public class OpenShiftAiEmbeddingsServiceSettings extends OpenShiftAiServiceSettings { - public static final String NAME = "openshiftai_embeddings_service_settings"; + public static final String NAME = "openshift_ai_embeddings_service_settings"; static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user"; private final Integer dimensions; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAIRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAIRerankTaskSettings.java index 4f991af1e4514..aa8396178e4a4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAIRerankTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAIRerankTaskSettings.java @@ -30,7 +30,7 @@ */ public class OpenShiftAIRerankTaskSettings implements TaskSettings { - public static final String NAME = "openshiftai_rerank_task_settings"; + public static final String NAME = "openshift_ai_rerank_task_settings"; public static final String RETURN_DOCUMENTS = "return_documents"; public static final String TOP_N = "top_n"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java index 773269acb4010..bd58a67ab299c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java @@ -32,7 +32,7 @@ * This class encapsulates the model ID, URI, and rate limit settings for the OpenShift AI chat rerank service. */ public class OpenShiftAiRerankServiceSettings extends OpenShiftAiServiceSettings { - public static final String NAME = "openshiftai_rerank_service_settings"; + public static final String NAME = "openshift_ai_rerank_service_settings"; /** * Creates a new instance of OpenShiftAiRerankServiceSettings from a map of settings. From b268e08530f513de1f16b2040bd58ace5230a97a Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Thu, 16 Oct 2025 13:13:53 +0300 Subject: [PATCH 03/70] Add constructor to OpenShiftAiChatCompletionServiceSettings for URL handling --- .../OpenShiftAiChatCompletionServiceSettings.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java index 1ac451b59ca0b..85f8d775d8aba 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java @@ -23,6 +23,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; @@ -82,6 +83,17 @@ public OpenShiftAiChatCompletionServiceSettings(@Nullable String modelId, URI ur super(modelId, uri, rateLimitSettings); } + /** + * Constructs a new OpenShiftAiChatCompletionServiceSettings. + * + * @param modelId the ID of the model ID + * @param url the URL of the OpenShift AI service + * @param rateLimitSettings the rate limit settings for the service + */ + public OpenShiftAiChatCompletionServiceSettings(@Nullable String modelId, String url, @Nullable RateLimitSettings rateLimitSettings) { + super(modelId, createUri(url), rateLimitSettings); + } + @Override public String getWriteableName() { return NAME; From 9cae6b1808cf28b3fe0d8dbb75d11b7f5b3a6b75 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Thu, 16 Oct 2025 17:55:25 +0300 Subject: [PATCH 04/70] Add unit tests --- .../action/OpenShiftAiActionCreatorTests.java | 762 ++++++++++++++++++ .../OpenShiftAiChatCompletionModelTests.java | 33 + .../OpenShiftAiEmbeddingsModelTests.java | 33 + .../rerank/OpenShiftAiRerankModelTests.java | 28 + 4 files changed, 856 insertions(+) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModelTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java new file mode 100644 index 0000000000000..85fea4f40b149 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -0,0 +1,762 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.action; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests; +import org.elasticsearch.xpack.inference.InputTypeTests; +import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankModelTests; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests.buildExpectationRerank; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; +import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModelTests.createCompletionModel; +import static org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsModelTests.createModel; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class OpenShiftAiActionCreatorTests extends ESTestCase { + + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityExecutors()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testCreate_OpenShiftAiEmbeddingsModel() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "embd-45e6d99b97a645c0af96653598069cd9", + "object": "list", + "created": 1760085467, + "model": "gritlm-7b", + "data": [ + { + "index": 0, + "object": "embedding", + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "usage": { + "prompt_tokens": 7, + "total_tokens": 7, + "completion_tokens": 0, + "prompt_tokens_details": null + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel(getUrl(webServer), "secret", "model"); + var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(2)); + assertThat(requestMap.get("input"), is(List.of("abc"))); + assertThat(requestMap.get("model"), is("model")); + } + } + + public void testCreate_OpenShiftAiEmbeddingsModel_FailsFromInvalidResponseFormat() throws IOException { + // timeout as zero for no retries + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "object": "list", + "data_does_not_exist": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel(getUrl(webServer), "secret", "model"); + var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var failureCauseMessage = "Required [data]"; + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is( + format( + "Failed to send OpenShift AI text_embedding request from inference entity id [inferenceEntityId]. Cause: %s", + failureCauseMessage + ) + ) + ); + assertThat(thrownException.getCause().getMessage(), is(failureCauseMessage)); + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(2)); + assertThat(requestMap.get("input"), is(List.of("abc"))); + assertThat(requestMap.get("model"), is("model")); + } + } + + public void testCreate_OpenShiftAiChatCompletionModel() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "chatcmpl-921d2eb8f3bc46dd8f4cb0502a4608a7", + "object": "chat.completion", + "created": 1760082857, + "model": "llama-31-8b-instruct", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "reasoning_content": null, + "content": "Hello there, how may I assist you today?", + "tool_calls": [] + }, + "logprobs": null, + "finish_reason": "length", + "stop_reason": null + } + ], + "usage": { + "prompt_tokens": 40, + "total_tokens": 140, + "completion_tokens": 100, + "prompt_tokens_details": null + }, + "prompt_logprobs": null, + "kv_transfer_params": null + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createCompletionModel(getUrl(webServer), "secret", "model"); + var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?")))); + assertThat(webServer.requests(), hasSize(1)); + + var request = webServer.requests().get(0); + + assertNull(request.getUri().getQuery()); + assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters())); + assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(4)); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); + } + } + + public void testCreate_OpenShiftAiChatCompletionModel_FailsFromInvalidResponseFormat() throws IOException { + // timeout as zero for no retries + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "chatcmpl-921d2eb8f3bc46dd8f4cb0502a4608a7", + "object": "chat.completion", + "created": 1760082857, + "model": "llama-31-8b-instruct", + "not_choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "reasoning_content": null, + "content": "Hello there, how may I assist you today?", + "tool_calls": [] + }, + "logprobs": null, + "finish_reason": "length", + "stop_reason": null + } + ], + "usage": { + "prompt_tokens": 40, + "total_tokens": 140, + "completion_tokens": 100, + "prompt_tokens_details": null + }, + "prompt_logprobs": null, + "kv_transfer_params": null + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createCompletionModel(getUrl(webServer), "secret", "model"); + var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var failureCauseMessage = "Required [choices]"; + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is( + format( + "Failed to send OpenShift AI completion request from inference entity id [inferenceEntityId]. Cause: %s", + failureCauseMessage + ) + ) + ); + assertThat(thrownException.getCause().getMessage(), is(failureCauseMessage)); + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(4)); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); + } + } + + public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + var contentTooLargeErrorMessage = + "This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;" + + "0 for the completion). Please reduce your prompt; or completion length."; + + String responseJsonContentTooLarge = Strings.format(""" + { + "error": { + "message": "%s", + "type": "content_too_large", + "param": null, + "code": null + } + } + """, contentTooLargeErrorMessage); + + String responseJson = """ + { + "id": "embd-45e6d99b97a645c0af96653598069cd9", + "object": "list", + "created": 1760085467, + "model": "gritlm-7b", + "data": [ + { + "index": 0, + "object": "embedding", + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "usage": { + "prompt_tokens": 7, + "total_tokens": 7, + "completion_tokens": 0, + "prompt_tokens_details": null + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(413).setBody(responseJsonContentTooLarge)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel(getUrl(webServer), "secret", "model"); + var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); + assertThat(webServer.requests(), hasSize(2)); + { + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(2)); + assertThat(requestMap.get("input"), is(List.of("abcd"))); + assertThat(requestMap.get("model"), is("model")); + } + { + assertNull(webServer.requests().get(1).getUri().getQuery()); + assertThat( + webServer.requests().get(1).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(1).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(1).getBody()); + assertThat(requestMap.size(), is(2)); + assertThat(requestMap.get("input"), is(List.of("ab"))); + assertThat(requestMap.get("model"), is("model")); + } + } + } + + public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + var contentTooLargeErrorMessage = + "This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;" + + "0 for the completion). Please reduce your prompt; or completion length."; + + String responseJsonContentTooLarge = Strings.format(""" + { + "error": { + "message": "%s", + "type": "content_too_large", + "param": null, + "code": null + } + } + """, contentTooLargeErrorMessage); + + String responseJson = """ + { + "id": "embd-45e6d99b97a645c0af96653598069cd9", + "object": "list", + "created": 1760085467, + "model": "gritlm-7b", + "data": [ + { + "index": 0, + "object": "embedding", + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "usage": { + "prompt_tokens": 7, + "total_tokens": 7, + "completion_tokens": 0, + "prompt_tokens_details": null + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(400).setBody(responseJsonContentTooLarge)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel(getUrl(webServer), "secret", "model"); + var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); + assertThat(webServer.requests(), hasSize(2)); + { + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(2)); + assertThat(requestMap.get("input"), is(List.of("abcd"))); + assertThat(requestMap.get("model"), is("model")); + } + { + assertNull(webServer.requests().get(1).getUri().getQuery()); + assertThat( + webServer.requests().get(1).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(1).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(1).getBody()); + assertThat(requestMap.size(), is(2)); + assertThat(requestMap.get("input"), is(List.of("ab"))); + assertThat(requestMap.get("model"), is("model")); + } + } + } + + public void testExecute_TruncatesInputBeforeSending() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "embd-45e6d99b97a645c0af96653598069cd9", + "object": "list", + "created": 1760085467, + "model": "gritlm-7b", + "data": [ + { + "index": 0, + "object": "embedding", + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "usage": { + "prompt_tokens": 7, + "total_tokens": 7, + "completion_tokens": 0, + "prompt_tokens_details": null + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + // truncated to 1 token = 3 characters + var model = createModel(getUrl(webServer), "secret", "model", 1); + var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(List.of("super long input"), InputTypeTests.randomWithNull()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(2)); + assertThat(requestMap.get("input"), is(List.of("sup"))); + assertThat(requestMap.get("model"), is("model")); + } + } + + public void testCreate_OpenShiftAiRerankModel() throws IOException { + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + List documents = List.of("Luke"); + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "rerank-d300256dd02b4c63b8a2bc34dcdad845", + "model": "bge-reranker-v2-m3", + "usage": { + "total_tokens": 30 + }, + "results": [ + { + "index": 1, + "document": { + "text": "awgawgawgawg" + }, + "relevance_score": 0.9921875 + }, + { + "index": 0, + "document": { + "text": "awdawdawda" + }, + "relevance_score": 0.4921875 + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), "secret", "model"); + var actionCreator = new OpenShiftAiActionCreator( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator()) + ); + var action = actionCreator.create(model, null); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new QueryAndDocsInputs("popular name", documents, null, null, false), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + assertThat( + result.asMap(), + is( + buildExpectationRerank( + List.of( + new RankedDocsResultsTests.RerankExpectation( + Map.of("text", "awgawgawgawg", "index", 1, "relevance_score", 0.9921875f) + ), + new RankedDocsResultsTests.RerankExpectation( + Map.of("text", "awdawdawda", "index", 0, "relevance_score", 0.4921875f) + ) + ) + ) + ) + ); + } + assertRerankActionCreator(documents); + } + + public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() throws IOException { + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + List documents = List.of("Luke"); + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "rerank-d300256dd02b4c63b8a2bc34dcdad845", + "model": "bge-reranker-v2-m3", + "usage": { + "total_tokens": 30 + }, + "not_results": [ + { + "index": 1, + "document": { + "text": "awgawgawgawg" + }, + "relevance_score": 0.9921875 + }, + { + "index": 0, + "document": { + "text": "awdawdawda" + }, + "relevance_score": 0.4921875 + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), "secret", "model"); + var actionCreator = new OpenShiftAiActionCreator( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator()) + ); + var action = actionCreator.create(model, null); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new QueryAndDocsInputs("popular name", documents, null, null, false), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getMessage(), is(""" + Failed to send OpenShift AI rerank request from inference entity id [inferenceEntityId]. Cause: Failed to find required\ + field [results] in Cohere rerank response""")); + } + assertRerankActionCreator(documents); + } + + private void assertRerankActionCreator(List documents) + throws IOException { + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(5)); + assertThat(requestMap.get("documents"), is(documents)); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("query"), is("popular name")); + assertThat(requestMap.get("top_n"), is(2)); + assertThat(requestMap.get("return_documents"), is(true)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java new file mode 100644 index 0000000000000..454108e8255a6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.completion; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +public class OpenShiftAiChatCompletionModelTests extends ESTestCase { + public static OpenShiftAiChatCompletionModel createCompletionModel(String url, String apiKey, String modelName) { + return createModelWithTaskType(url, apiKey, modelName, TaskType.COMPLETION); + } + + public static OpenShiftAiChatCompletionModel createChatCompletionModel(String url, String apiKey, String modelName) { + return createModelWithTaskType(url, apiKey, modelName, TaskType.CHAT_COMPLETION); + } + + public static OpenShiftAiChatCompletionModel createModelWithTaskType(String url, String apiKey, String modelName, TaskType taskType) { + return new OpenShiftAiChatCompletionModel( + "inferenceEntityId", + taskType, + "service", + new OpenShiftAiChatCompletionServiceSettings(modelName, url, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModelTests.java new file mode 100644 index 0000000000000..f9ad1206ef477 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModelTests.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.embeddings; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +public class OpenShiftAiEmbeddingsModelTests extends ESTestCase { + + public static OpenShiftAiEmbeddingsModel createModel(String url, String apiKey, @Nullable String modelId) { + return createModel(url, apiKey, modelId, 1234); + } + + public static OpenShiftAiEmbeddingsModel createModel(String url, String apiKey, @Nullable String modelId, int maxInputTokens) { + return new OpenShiftAiEmbeddingsModel( + "inferenceEntityId", + TaskType.TEXT_EMBEDDING, + "service", + new OpenShiftAiEmbeddingsServiceSettings(modelId, url, 1536, SimilarityMeasure.DOT_PRODUCT, maxInputTokens, null, false), + null, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java new file mode 100644 index 0000000000000..73e5156dea270 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.rerank; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +public class OpenShiftAiRerankModelTests extends ESTestCase { + + public static OpenShiftAiRerankModel createModel(String url, String apiKey, @Nullable String modelId) { + return new OpenShiftAiRerankModel( + "inferenceEntityId", + TaskType.RERANK, + "service", + new OpenShiftAiRerankServiceSettings(modelId, url, null), + new OpenShiftAIRerankTaskSettings(2, true), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } +} From f804331d48c551011c09dc022a1293947792ae89 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 16 Oct 2025 15:06:00 +0000 Subject: [PATCH 05/70] [CI] Auto commit changes from spotless --- .../openshiftai/action/OpenShiftAiActionCreatorTests.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index 85fea4f40b149..87f7f993a8f1a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -741,8 +741,7 @@ public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() t assertRerankActionCreator(documents); } - private void assertRerankActionCreator(List documents) - throws IOException { + private void assertRerankActionCreator(List documents) throws IOException { assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat( From af2fcd6ed16d26170217cfe6a8b455ac6ec1ecd0 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 17 Oct 2025 13:52:11 +0300 Subject: [PATCH 06/70] Add tests for UnifiedCompletionRequest model ID overrides in OpenShift AI chat completion --- .../OpenShiftAiChatCompletionModelTests.java | 95 +++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java index 454108e8255a6..fb5c0189072c8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java @@ -9,9 +9,14 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import java.util.List; + +import static org.hamcrest.Matchers.is; + public class OpenShiftAiChatCompletionModelTests extends ESTestCase { public static OpenShiftAiChatCompletionModel createCompletionModel(String url, String apiKey, String modelName) { return createModelWithTaskType(url, apiKey, modelName, TaskType.COMPLETION); @@ -30,4 +35,94 @@ public static OpenShiftAiChatCompletionModel createModelWithTaskType(String url, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } + + public void testOverrideWith_UnifiedCompletionRequest_KeepsSameModelId() { + var model = createCompletionModel("url", "api_key", "model_name"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "model_name", // same model + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = OpenShiftAiChatCompletionModel.of(model, request); + + assertThat(overriddenModel, is(model)); + } + + public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() { + var model = createCompletionModel("url", "api_key", "model_name"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "different_model", // overriding model + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = OpenShiftAiChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); + } + + public void testOverrideWith_UnifiedCompletionRequest_OverridesNullModelId() { + var model = createCompletionModel("url", "api_key", null); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "different_model", // overriding model + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = OpenShiftAiChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); + } + + public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() { + var model = createCompletionModel("url", "api_key", null); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + null, // not overriding model + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = OpenShiftAiChatCompletionModel.of(model, request); + + assertNull(overriddenModel.getServiceSettings().modelId()); + } + + public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { + var model = createCompletionModel("url", "api_key", "model_name"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + null, // not overriding model + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = OpenShiftAiChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("model_name")); + } } From b98d8d65608fec4aa25cd34d9a1ca322418812a2 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 17 Oct 2025 14:43:00 +0300 Subject: [PATCH 07/70] Add unit tests for OpenShiftAiChatCompletionResponseHandler --- ...tAiChatCompletionResponseHandlerTests.java | 156 ++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java new file mode 100644 index 0000000000000..114b6563382fe --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java @@ -0,0 +1,156 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.completion; + +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; + +import static org.elasticsearch.ExceptionsHelper.unwrapCause; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class OpenShiftAiChatCompletionResponseHandlerTests extends ESTestCase { + private final OpenShiftAiChatCompletionResponseHandler responseHandler = new OpenShiftAiChatCompletionResponseHandler( + "chat completions", + (a, b) -> mock() + ); + + public void testFailNotFound() throws IOException { + var responseJson = XContentHelper.stripWhitespace(""" + { + "detail": "Not Found" + } + """); + + var errorJson = invalidResponseJson(responseJson, 404); + + assertThat(errorJson, is(XContentHelper.stripWhitespace(""" + { + "error" : { + "code" : "not_found", + "message" : "Resource not found at [https://api.llama.ai/v1/chat/completions] for request from inference entity id [id] \ + status [404]. Error message: [{\\"detail\\":\\"Not Found\\"}]", + "type" : "openshift_ai_error" + } + }"""))); + } + + public void testFailBadRequest() throws IOException { + var responseJson = XContentHelper.stripWhitespace(""" + { + "object": "error", + "message": "[{'type': 'missing', 'loc': ('body', 'messages'), 'msg': 'Field required', 'input': {'model': 'llama-31-8b-ins\ + truct', '1messages': [{'role': 'user', 'content': 'What is deep learning?'}], 'max_tokens': 2, 'stream': True}}]", + "type": "Bad Request", + "param": null, + "code": 400 + } + """); + + var errorJson = invalidResponseJson(responseJson, 400); + + assertThat(errorJson, is(XContentHelper.stripWhitespace(""" + { + "error": { + "code": "bad_request", + "message": "Received a bad request status code for request from inference entity id [id] status [400].\ + Error message: [{\\"object\\":\\"error\\",\\"message\\":\\"[{'type': 'missing', 'loc': ('body', 'messages'), 'msg': 'Field r\ + equired', 'input': {'model': 'llama-31-8b-ins truct', '1messages': [{'role': 'user', 'content': 'What is deep learning?'}]\ + , 'max_tokens': 2, 'stream': True}}]\\",\\"type\\":\\"Bad Request\\",\\"param\\":null,\\"code\\":400}]", + "type": "openshift_ai_error" + } + } + """))); + } + + public void testFailValidationWithInvalidJson() throws IOException { + var responseJson = """ + what? this isn't a json + """; + + var errorJson = invalidResponseJson(responseJson, 500); + + assertThat(errorJson, is(XContentHelper.stripWhitespace(""" + { + "error": { + "code": "bad_request", + "message": "Received a server error status code for request from inference entity id [id] status [500]. Error message: \ + [what? this isn't a json\\n]", + "type": "openshift_ai_error" + } + } + """))); + } + + private String invalidResponseJson(String responseJson, int statusCode) throws IOException { + var exception = invalidResponse(responseJson, statusCode); + assertThat(exception, isA(RetryException.class)); + assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class)); + return toJson((UnifiedChatCompletionException) unwrapCause(exception)); + } + + private Exception invalidResponse(String responseJson, int statusCode) { + return expectThrows( + RetryException.class, + () -> responseHandler.validateResponse( + mock(), + mock(), + mockRequest(), + new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + } + + private static Request mockRequest() throws URISyntaxException { + var request = mock(Request.class); + when(request.getInferenceEntityId()).thenReturn("id"); + when(request.isStreaming()).thenReturn(true); + when(request.getURI()).thenReturn(new URI("https://api.llama.ai/v1/chat/completions")); + return request; + } + + private static HttpResponse mockErrorResponse(int statusCode) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + + var response = mock(HttpResponse.class); + when(response.getStatusLine()).thenReturn(statusLine); + + return response; + } + + private String toJson(UnifiedChatCompletionException e) throws IOException { + try (var builder = XContentFactory.jsonBuilder()) { + e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + } + } + +} From b19342fd53f96e9842b85ca4876dc402cdd5bbd0 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 17 Oct 2025 16:06:45 +0300 Subject: [PATCH 08/70] Add unit tests for OpenShiftAiChatCompletionServiceSettings --- ...tAiChatCompletionServiceSettingsTests.java | 197 ++++++++++++++++++ 1 file changed, 197 insertions(+) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java new file mode 100644 index 0000000000000..5109bf69a6281 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java @@ -0,0 +1,197 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class OpenShiftAiChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase< + OpenShiftAiChatCompletionServiceSettings> { + + public static final String MODEL_ID = "some model"; + public static final String CORRECT_URL = "https://www.elastic.co"; + public static final int RATE_LIMIT = 2; + + public void testFromMap_AllFields_Success() { + var serviceSettings = OpenShiftAiChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + ServiceFields.URL, + CORRECT_URL, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is(new OpenShiftAiChatCompletionServiceSettings(MODEL_ID, CORRECT_URL, new RateLimitSettings(RATE_LIMIT))) + ); + } + + public void testFromMap_MissingModelId_Success() { + var serviceSettings = OpenShiftAiChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + CORRECT_URL, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat(serviceSettings, is(new OpenShiftAiChatCompletionServiceSettings(null, CORRECT_URL, new RateLimitSettings(RATE_LIMIT)))); + } + + public void testFromMap_MissingUrl_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [url];") + ); + } + + public void testFromMap_MissingRateLimit_Success() { + var serviceSettings = OpenShiftAiChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID, ServiceFields.URL, CORRECT_URL)), + ConfigurationParseContext.PERSISTENT + ); + + assertThat(serviceSettings, is(new OpenShiftAiChatCompletionServiceSettings(MODEL_ID, CORRECT_URL, null))); + } + + public void testToXContent_WritesAllValues() throws IOException { + var serviceSettings = OpenShiftAiChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + ServiceFields.URL, + CORRECT_URL, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + var expected = XContentHelper.stripWhitespace(""" + { + "model_id": "some model", + "url": "https://www.elastic.co", + "rate_limit": { + "requests_per_minute": 2 + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws IOException { + var serviceSettings = OpenShiftAiChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.URL, CORRECT_URL)), + ConfigurationParseContext.PERSISTENT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + var expected = XContentHelper.stripWhitespace(""" + { + "url": "https://www.elastic.co", + "rate_limit": { + "requests_per_minute": 3000 + } + } + """); + assertThat(xContentResult, is(expected)); + } + + @Override + protected Writeable.Reader instanceReader() { + return OpenShiftAiChatCompletionServiceSettings::new; + } + + @Override + protected OpenShiftAiChatCompletionServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected OpenShiftAiChatCompletionServiceSettings mutateInstance(OpenShiftAiChatCompletionServiceSettings instance) + throws IOException { + return randomValueOtherThan(instance, OpenShiftAiChatCompletionServiceSettingsTests::createRandom); + } + + @Override + protected OpenShiftAiChatCompletionServiceSettings mutateInstanceForVersion( + OpenShiftAiChatCompletionServiceSettings instance, + TransportVersion version + ) { + return instance; + } + + private static OpenShiftAiChatCompletionServiceSettings createRandom() { + var modelId = randomAlphaOfLengthOrNull(8); + var url = randomAlphaOfLength(15); + return new OpenShiftAiChatCompletionServiceSettings(modelId, ServiceUtils.createUri(url), RateLimitSettingsTests.createRandom()); + } + + public static Map getServiceSettingsMap(String model, String url) { + var map = new HashMap(); + + map.put(ServiceFields.MODEL_ID, model); + map.put(ServiceFields.URL, url); + + return map; + } + +} From 6af168cff04991d7e7b7fc4d5efdba16c2b6294f Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 17 Oct 2025 16:33:46 +0300 Subject: [PATCH 09/70] Update request type description in OpenShiftAiCompletionResponseHandler --- .../completion/OpenShiftAiCompletionResponseHandler.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiCompletionResponseHandler.java index c1703df534daf..65522fc4495bf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiCompletionResponseHandler.java @@ -19,7 +19,7 @@ public class OpenShiftAiCompletionResponseHandler extends OpenAiChatCompletionRe /** * Constructs a OpenShiftAiCompletionResponseHandler with the specified request type and response parser. * - * @param requestType The type of request being handled (e.g., "llama completions"). + * @param requestType The type of request being handled (e.g., "Openshift AI completions"). * @param parseFunction The function to parse the response. */ public OpenShiftAiCompletionResponseHandler(String requestType, ResponseParser parseFunction) { From aadbfde0f9b9f8ff57b689bf9b9cccb839723808 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 17 Oct 2025 19:36:14 +0300 Subject: [PATCH 10/70] Refactor OpenShiftAiEmbeddingsServiceSettings to improve validation logic and update dimensionsSetByUser handling --- .../OpenShiftAiEmbeddingsServiceSettings.java | 40 +- ...ShiftAiEmbeddingsServiceSettingsTests.java | 649 ++++++++++++++++++ 2 files changed, 676 insertions(+), 13 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java index 3db3586bfbc35..60fe0ce62f86d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java @@ -15,7 +15,9 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.InferenceUtils; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiService; import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -78,10 +80,24 @@ public static OpenShiftAiEmbeddingsServiceSettings fromMap(Map m OpenShiftAiService.NAME, context ); - Boolean dimensionsSetByUser = switch (context) { - case REQUEST -> dimensions != null; - case PERSISTENT -> extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException); - }; + Boolean dimensionsSetByUser = extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException); + switch (context) { + case REQUEST -> { + if (dimensionsSetByUser != null) { + validationException.addValidationError( + ServiceUtils.invalidSettingError(DIMENSIONS_SET_BY_USER, ModelConfigurations.SERVICE_SETTINGS) + ); + } + dimensionsSetByUser = dimensions != null; + } + case PERSISTENT -> { + if (dimensionsSetByUser == null) { + validationException.addValidationError( + InferenceUtils.missingSettingErrorMsg(DIMENSIONS_SET_BY_USER, ModelConfigurations.SERVICE_SETTINGS) + ); + } + } + } if (validationException.validationErrors().isEmpty() == false) { throw validationException; } @@ -108,7 +124,7 @@ public OpenShiftAiEmbeddingsServiceSettings(StreamInput in) throws IOException { this.dimensions = in.readOptionalVInt(); this.similarity = in.readOptionalEnum(SimilarityMeasure.class); this.maxInputTokens = in.readOptionalVInt(); - this.dimensionsSetByUser = in.readOptionalBoolean(); + this.dimensionsSetByUser = in.readBoolean(); } /** @@ -120,7 +136,7 @@ public OpenShiftAiEmbeddingsServiceSettings(StreamInput in) throws IOException { * @param similarity the similarity measure to use, can be null * @param maxInputTokens the maximum number of input tokens, can be null * @param rateLimitSettings the rate limit settings for the service, can be null - * @param dimensionsSetByUser indicates if dimensions were set by the user, can be null + * @param dimensionsSetByUser indicates if dimensions were set by the user */ public OpenShiftAiEmbeddingsServiceSettings( @Nullable String modelId, @@ -129,7 +145,7 @@ public OpenShiftAiEmbeddingsServiceSettings( @Nullable SimilarityMeasure similarity, @Nullable Integer maxInputTokens, @Nullable RateLimitSettings rateLimitSettings, - @Nullable Boolean dimensionsSetByUser + Boolean dimensionsSetByUser ) { super(modelId, uri, rateLimitSettings); this.dimensions = dimensions; @@ -147,7 +163,7 @@ public OpenShiftAiEmbeddingsServiceSettings( * @param similarity the similarity measure to use, can be null * @param maxInputTokens the maximum number of input tokens, can be null * @param rateLimitSettings the rate limit settings for the service, can be null - * @param dimensionsSetByUser indicates if dimensions were set by the user, can be null + * @param dimensionsSetByUser indicates if dimensions were set by the user */ public OpenShiftAiEmbeddingsServiceSettings( String modelId, @@ -156,7 +172,7 @@ public OpenShiftAiEmbeddingsServiceSettings( @Nullable SimilarityMeasure similarity, @Nullable Integer maxInputTokens, @Nullable RateLimitSettings rateLimitSettings, - @Nullable Boolean dimensionsSetByUser + Boolean dimensionsSetByUser ) { this(modelId, createUri(url), dimensions, similarity, maxInputTokens, rateLimitSettings, dimensionsSetByUser); } @@ -201,7 +217,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalVInt(dimensions); out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); out.writeOptionalVInt(maxInputTokens); - out.writeOptionalBoolean(dimensionsSetByUser); + out.writeBoolean(dimensionsSetByUser); } @Override @@ -217,9 +233,7 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil if (maxInputTokens != null) { builder.field(MAX_INPUT_TOKENS, maxInputTokens); } - if (dimensionsSetByUser != null) { - builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); - } + builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); return builder; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..dde32cc099e43 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java @@ -0,0 +1,649 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.ByteArrayStreamInput; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; +import org.hamcrest.CoreMatchers; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class OpenShiftAiEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase { + private static final String MODEL_ID = "some model"; + private static final String CORRECT_URL = "https://www.elastic.co"; + private static final int DIMENSIONS = 384; + private static final SimilarityMeasure SIMILARITY_MEASURE = SimilarityMeasure.DOT_PRODUCT; + private static final int MAX_INPUT_TOKENS = 128; + private static final int RATE_LIMIT = 2; + + public void testFromMap_AllFields_Success() { + var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + true + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new OpenShiftAiEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(RATE_LIMIT), + true + ) + ) + ); + } + + public void testFromMap_NoModelId_Success() { + var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + null, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + false + ), + ConfigurationParseContext.PERSISTENT + ); + assertThat( + serviceSettings, + is( + new OpenShiftAiEmbeddingsServiceSettings( + null, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(RATE_LIMIT), + false + ) + ) + ); + } + + public void testFromMap_NoUrl_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + null, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + false + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [url];") + ); + } + + public void testFromMap_EmptyUrl_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + "", + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + false + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value empty string. [url] must be a non-empty string;") + ); + } + + public void testFromMap_InvalidUrl_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + "^^^", + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + false + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString( + "Validation Failed: 1: [service_settings] Invalid url [^^^] received for field [url]. " + + "Error: unable to parse url [^^^]. Reason: Illegal character in path;" + ) + ); + } + + public void testFromMap_NoSimilarity_Success() { + var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + null, + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + false + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new OpenShiftAiEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + null, + MAX_INPUT_TOKENS, + new RateLimitSettings(RATE_LIMIT), + false + ) + ) + ); + } + + public void testFromMap_InvalidSimilarity_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + "by_size", + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + false + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString( + "Validation Failed: 1: [service_settings] Invalid value [by_size] received. " + + "[similarity] must be one of [cosine, dot_product, l2_norm];" + ) + ); + } + + public void testFromMap_NoDimensions_SetByUserFalse_Persistent_Success() { + var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + null, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + false + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new OpenShiftAiEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + null, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(RATE_LIMIT), + false + ) + ) + ); + } + + public void testFromMap_Persistent_WithDimensions_SetByUserFalse_Persistent_Success() { + var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + false + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new OpenShiftAiEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(RATE_LIMIT), + false + ) + ) + ); + } + + public void testFromMap_WithDimensions_SetByUserNull_Persistent_Success() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + null + ), + ConfigurationParseContext.PERSISTENT + ) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [dimensions_set_by_user];") + ); + } + + public void testFromMap_NoDimensions_SetByUserNull_Request_Success() { + var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + null, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + null + ), + ConfigurationParseContext.REQUEST + ); + + assertThat( + serviceSettings, + is( + new OpenShiftAiEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + null, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(RATE_LIMIT), + false + ) + ) + ); + } + + public void testFromMap_WithDimensions_SetByUserNull_Request_Success() { + var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + null + ), + ConfigurationParseContext.REQUEST + ); + + assertThat( + serviceSettings, + is( + new OpenShiftAiEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(RATE_LIMIT), + true + ) + ) + ); + } + + public void testFromMap_WithDimensions_SetByUserTrue_Request_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + true + ), + ConfigurationParseContext.REQUEST + ) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not allow the setting [dimensions_set_by_user];") + ); + } + + public void testFromMap_ZeroDimensions_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + 0, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + false + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [dimensions] must be a positive integer;") + ); + } + + public void testFromMap_NegativeDimensions_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + -10, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + false + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [-10]. [dimensions] must be a positive integer;") + ); + } + + public void testFromMap_NoInputTokens_Success() { + var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + null, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + false + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new OpenShiftAiEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + null, + new RateLimitSettings(RATE_LIMIT), + false + ) + ) + ); + } + + public void testFromMap_ZeroInputTokens_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + 0, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + false + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [max_input_tokens] must be a positive integer;") + ); + } + + public void testFromMap_NegativeInputTokens_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + -10, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + false + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [-10]. [max_input_tokens] must be a positive integer;") + ); + } + + public void testFromMap_NoRateLimit_Success() { + var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap(MODEL_ID, CORRECT_URL, SIMILARITY_MEASURE.toString(), DIMENSIONS, MAX_INPUT_TOKENS, null, false), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new OpenShiftAiEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(3000), + false + ) + ) + ); + } + + public void testToXContent_WritesAllValues() throws IOException { + var entity = new OpenShiftAiEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(3), + false + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(XContentHelper.stripWhitespace(""" + { + "model_id": "some model", + "url": "https://www.elastic.co", + "rate_limit": { + "requests_per_minute": 3 + }, + "dimensions": 384, + "similarity": "dot_product", + "max_input_tokens": 128, + "dimensions_set_by_user": false + } + """))); + } + + public void testStreamInputAndOutput_WritesValuesCorrectly() throws IOException { + var outputBuffer = new BytesStreamOutput(); + var settings = new OpenShiftAiEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(3), + false + ); + settings.writeTo(outputBuffer); + + var outputBufferRef = outputBuffer.bytes(); + var inputBuffer = new ByteArrayStreamInput(outputBufferRef.array()); + + var settingsFromBuffer = new OpenShiftAiEmbeddingsServiceSettings(inputBuffer); + + assertEquals(settings, settingsFromBuffer); + } + + @Override + protected Writeable.Reader instanceReader() { + return OpenShiftAiEmbeddingsServiceSettings::new; + } + + @Override + protected OpenShiftAiEmbeddingsServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected OpenShiftAiEmbeddingsServiceSettings mutateInstance(OpenShiftAiEmbeddingsServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, OpenShiftAiEmbeddingsServiceSettingsTests::createRandom); + } + + private static OpenShiftAiEmbeddingsServiceSettings createRandom() { + var modelId = randomAlphaOfLength(8); + var url = randomAlphaOfLength(15); + var similarityMeasure = randomFrom(SimilarityMeasure.values()); + var dimensions = randomIntBetween(32, 256); + var maxInputTokens = randomIntBetween(128, 256); + return new OpenShiftAiEmbeddingsServiceSettings( + modelId, + url, + dimensions, + similarityMeasure, + maxInputTokens, + RateLimitSettingsTests.createRandom(), + randomBoolean() + ); + } + + public static HashMap buildServiceSettingsMap( + @Nullable String modelId, + @Nullable String url, + @Nullable String similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + @Nullable HashMap rateLimitSettings, + @Nullable Boolean dimensionsSetByUser + ) { + HashMap result = new HashMap<>(); + if (modelId != null) { + result.put(ServiceFields.MODEL_ID, modelId); + } + if (url != null) { + result.put(ServiceFields.URL, url); + } + if (similarity != null) { + result.put(ServiceFields.SIMILARITY, similarity); + } + if (dimensions != null) { + result.put(ServiceFields.DIMENSIONS, dimensions); + } + if (maxInputTokens != null) { + result.put(ServiceFields.MAX_INPUT_TOKENS, maxInputTokens); + } + if (rateLimitSettings != null) { + result.put(RateLimitSettings.FIELD_NAME, rateLimitSettings); + } + if (dimensionsSetByUser != null) { + result.put("dimensions_set_by_user", dimensionsSetByUser); + } + return result; + } + +} From fc5c182ad021521a54eed47584adb93a8e8e60c5 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 17 Oct 2025 20:05:44 +0300 Subject: [PATCH 11/70] Update OpenShiftAiChatCompletionRequestEntity to use new method for max tokens and add unit tests for request creation and validation --- ...penShiftAiChatCompletionRequestEntity.java | 2 +- ...OpenShiftAiChatCompletionRequestTests.java | 59 +++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntity.java index 1a53c4180b97a..e8194116c403c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntity.java @@ -33,7 +33,7 @@ public OpenShiftAiChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(modelId, params)); + unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokensAndSkipStreamOptionsField(modelId, params)); builder.endObject(); return builder; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java new file mode 100644 index 0000000000000..ce2c8a946e3a6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.request.completion; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModelTests; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class OpenShiftAiChatCompletionRequestTests extends ESTestCase { + public void testCreateRequest_WithStreaming() throws IOException { + String input = randomAlphaOfLength(15); + var request = createRequest("model", "url", "secret", input, true); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(request.getURI().toString(), is("url")); + assertThat(requestMap.get("stream"), is(true)); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertNull(requestMap.get("stream_options")); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input)))); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + } + + public void testTruncate_DoesNotReduceInputTextSize() { + String input = randomAlphaOfLength(5); + var request = createRequest("model", "url", "secret", input, true); + assertThat(request.truncate(), is(request)); + } + + public void testTruncationInfo_ReturnsNull() { + var request = createRequest("model", "url", "secret", randomAlphaOfLength(5), true); + assertNull(request.getTruncationInfo()); + } + + public static OpenShiftAiChatCompletionRequest createRequest(String modelId, String url, String apiKey, String input, boolean stream) { + var chatCompletionModel = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(url, apiKey, modelId); + return new OpenShiftAiChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel); + } + +} From d6646440c2ccfeb0b89a8201f33e5a6c7cbdfcd2 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 17 Oct 2025 20:13:00 +0300 Subject: [PATCH 12/70] Add unit tests for OpenShiftAiChatCompletionRequestEntity serialization --- ...iftAiChatCompletionRequestEntityTests.java | 99 +++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntityTests.java diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..084d6e986c527 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntityTests.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.request.completion; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; + +import java.io.IOException; +import java.util.ArrayList; + +public class OpenShiftAiChatCompletionRequestEntityTests extends ESTestCase { + private static final String ROLE = "user"; + + public void testSerializationWithModelIdStreaming() throws IOException { + testSerialization("modelId", true, """ + { + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "model": "modelId", + "n": 1, + "stream": true + } + """); + } + + public void testSerializationWithModelIdNonStreaming() throws IOException { + testSerialization("modelId", false, """ + { + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "model": "modelId", + "n": 1, + "stream": false + } + """); + } + + public void testSerializationWithoutModelIdStreaming() throws IOException { + testSerialization(null, true, """ + { + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "n": 1, + "stream": true + } + """); + } + + public void testSerializationWithoutModelIdNonStreaming() throws IOException { + testSerialization(null, false, """ + { + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "n": 1, + "stream": false + } + """); + } + + private static void testSerialization(String modelId, boolean isStreaming, String expectedJson) throws IOException { + var message = new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("Hello, world!"), ROLE, null, null); + + var messageList = new ArrayList(); + messageList.add(message); + + var unifiedRequest = UnifiedCompletionRequest.of(messageList); + var unifiedChatInput = new UnifiedChatInput(unifiedRequest, isStreaming); + + var entity = new OpenShiftAiChatCompletionRequestEntity(unifiedChatInput, modelId); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder)); + } + +} From e6d407992e501d2d6a4e1b25f6e4df1c10c45725 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Mon, 20 Oct 2025 13:31:45 +0300 Subject: [PATCH 13/70] Add unit tests for OpenShiftAiEmbeddingsRequest and update model creation logic --- .../OpenShiftAiEmbeddingsRequestEntity.java | 1 - .../OpenShiftAiEmbeddingsModelTests.java | 21 ++- .../OpenShiftAiEmbeddingsRequestTests.java | 121 ++++++++++++++++++ 3 files changed, 141 insertions(+), 2 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntity.java index 3b06fe5592582..3268001d54f1d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntity.java @@ -28,7 +28,6 @@ public record OpenShiftAiEmbeddingsRequestEntity( public OpenShiftAiEmbeddingsRequestEntity { Objects.requireNonNull(input); - Objects.requireNonNull(modelId); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModelTests.java index f9ad1206ef477..f815f31848ebf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModelTests.java @@ -21,11 +21,30 @@ public static OpenShiftAiEmbeddingsModel createModel(String url, String apiKey, } public static OpenShiftAiEmbeddingsModel createModel(String url, String apiKey, @Nullable String modelId, int maxInputTokens) { + return createModel(url, apiKey, modelId, maxInputTokens, false, 1536); + } + + public static OpenShiftAiEmbeddingsModel createModel( + String url, + String apiKey, + @Nullable String modelId, + @Nullable Integer maxInputTokens, + @Nullable Boolean dimensionsSetByUser, + @Nullable Integer dimensions + ) { return new OpenShiftAiEmbeddingsModel( "inferenceEntityId", TaskType.TEXT_EMBEDDING, "service", - new OpenShiftAiEmbeddingsServiceSettings(modelId, url, 1536, SimilarityMeasure.DOT_PRODUCT, maxInputTokens, null, false), + new OpenShiftAiEmbeddingsServiceSettings( + modelId, + url, + dimensions, + SimilarityMeasure.DOT_PRODUCT, + maxInputTokens, + null, + dimensionsSetByUser + ), null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java new file mode 100644 index 0000000000000..8ec14b2ef0fd5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.request.embeddings; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsModelTests; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class OpenShiftAiEmbeddingsRequestTests extends ESTestCase { + + public void testCreateRequest_NoDimensions_DimensionsSetByUserFalse_Success() throws IOException { + testCreateRequest_Success(null, false, null); + } + + public void testCreateRequest_NoDimensions_DimensionsSetByUserTrue_Success() throws IOException { + testCreateRequest_Success(null, true, null); + } + + public void testCreateRequest_WithDimensions_DimensionsSetByUserFalse_Success() throws IOException { + testCreateRequest_Success(384, false, null); + } + + public void testCreateRequest_WithDimensions_DimensionsSetByUserTrue_Success() throws IOException { + testCreateRequest_Success(384, true, 384); + } + + private void testCreateRequest_Success(Integer dimensions, boolean dimensionsSetByUser, Integer expectedDimensions) throws IOException { + var request = createRequest(dimensions, dimensionsSetByUser); + var httpRequest = request.createHttpRequest(); + var httpPost = validateRequestUrlAndContentType(httpRequest); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap.get("input"), is(List.of("ABCD"))); + assertThat(requestMap.get("model"), is("llama-embed")); + assertThat(requestMap.get("dimensions"), is(expectedDimensions)); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer apikey")); + } + + public void testCreateRequest_NoModel_Success() throws IOException { + var request = createRequest(null, false, null); + var httpRequest = request.createHttpRequest(); + var httpPost = validateRequestUrlAndContentType(httpRequest); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap.get("input"), is(List.of("ABCD"))); + assertNull(requestMap.get("model")); + assertNull(requestMap.get("dimensions")); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer apikey")); + + } + + public void testTruncate_ReducesInputTextSizeByHalf() throws IOException { + var request = createRequest(null, false); + var truncatedRequest = request.truncate(); + + var httpRequest = truncatedRequest.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get("input"), is(List.of("AB"))); + assertThat(requestMap.get("model"), is("llama-embed")); + + } + + public void testIsTruncated_ReturnsTrue() { + var request = createRequest(null, false); + assertFalse(request.getTruncationInfo()[0]); + + var truncatedRequest = request.truncate(); + assertTrue(truncatedRequest.getTruncationInfo()[0]); + } + + private HttpPost validateRequestUrlAndContentType(HttpRequest request) { + assertThat(request.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) request.httpRequestBase(); + assertThat(httpPost.getURI().toString(), is("url")); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaTypeWithoutParameters())); + return httpPost; + } + + private static OpenShiftAiEmbeddingsRequest createRequest(Integer dimensions, Boolean dimensionsSetByUser) { + return createRequest(dimensions, dimensionsSetByUser, "llama-embed"); + } + + private static OpenShiftAiEmbeddingsRequest createRequest(Integer dimensions, Boolean dimensionsSetByUser, String modelId) { + var embeddingsModel = OpenShiftAiEmbeddingsModelTests.createModel( + "url", + "apikey", + modelId, + dimensions, + dimensionsSetByUser, + dimensions + ); + return new OpenShiftAiEmbeddingsRequest( + TruncatorTests.createTruncator(), + new Truncator.TruncationResult(List.of("ABCD"), new boolean[] { false }), + embeddingsModel + ); + } + +} From fb3109463a90b27b5db0a2ed66e68db974341f5e Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Mon, 20 Oct 2025 15:01:49 +0300 Subject: [PATCH 14/70] Add unit tests for OpenShiftAiEmbeddingsRequestEntity --- .../action/OpenShiftAiActionCreatorTests.java | 2 +- ...enShiftAiEmbeddingsRequestEntityTests.java | 81 +++++++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntityTests.java diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index 87f7f993a8f1a..fe04d818dcc02 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -42,8 +42,8 @@ import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests.buildExpectationRerank; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..c46de709f7eec --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntityTests.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.request.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; + +public class OpenShiftAiEmbeddingsRequestEntityTests extends ESTestCase { + + public void testXContent_DoesNotWriteDimensionsWhenNullAndSetByUserIsFalse() throws IOException { + var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of("abc"), "model", null, false); + testXContent_DoesNotWriteDimensions(entity); + } + + public void testXContent_DoesNotWriteDimensionsWhenNotSetByUser() throws IOException { + var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of("abc"), "model", 100, false); + testXContent_DoesNotWriteDimensions(entity); + } + + public void testXContent_DoesNotWriteDimensionsWhenNull_EvenIfSetByUserIsTrue() throws IOException { + var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of("abc"), "model", null, true); + testXContent_DoesNotWriteDimensions(entity); + } + + private static void testXContent_DoesNotWriteDimensions(OpenShiftAiEmbeddingsRequestEntity entity) throws IOException { + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(XContentHelper.stripWhitespace(""" + { + "input": ["abc"], + "model": "model" + } + """))); + } + + public void testXContent_DoesNotWriteModelWhenItIsNull() throws IOException { + var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of("abc"), null, null, false); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(XContentHelper.stripWhitespace(""" + { + "input": ["abc"] + } + """))); + } + + public void testXContent_WritesDimensionsWhenNonNull_AndSetByUserIsTrue() throws IOException { + var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of("abc"), "model", 100, true); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(XContentHelper.stripWhitespace(""" + { + "input": ["abc"], + "model": "model", + "dimensions": 100 + } + """))); + } +} From 8e7833780aa0bce37f6796d335ed7ad5e43d2b61 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Mon, 20 Oct 2025 18:40:58 +0300 Subject: [PATCH 15/70] Fix Typo in OpenShiftAiRerankTaskSettings, add tests for request models --- .../InferenceNamedWriteablesProvider.java | 4 +- .../OpenShiftAIRerankRequestEntity.java | 16 ++- .../rarank/OpenShiftAiRerankRequest.java | 6 +- .../rerank/OpenShiftAiRerankModel.java | 14 +-- ...ava => OpenShiftAiRerankTaskSettings.java} | 44 +++---- .../OpenShiftAIRerankRequestEntityTests.java | 63 ++++++++++ .../rarank/OpenShiftAiRerankRequestTests.java | 108 ++++++++++++++++++ .../rerank/OpenShiftAiRerankModelTests.java | 12 +- 8 files changed, 229 insertions(+), 38 deletions(-) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/{OpenShiftAIRerankTaskSettings.java => OpenShiftAiRerankTaskSettings.java} (78%) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 8fec44654a820..125b310717984 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -115,8 +115,8 @@ import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsServiceSettings; -import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAIRerankTaskSettings; import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankTaskSettings; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -472,7 +472,7 @@ private static void addOpenShiftAiNamedWriteables(List taskSettings) { - var requestTaskSettings = OpenShiftAIRerankTaskSettings.fromMap(taskSettings); - return new OpenShiftAiRerankModel(model, OpenShiftAIRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + var requestTaskSettings = OpenShiftAiRerankTaskSettings.fromMap(taskSettings); + return new OpenShiftAiRerankModel(model, OpenShiftAiRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); } public OpenShiftAiRerankModel( @@ -39,7 +39,7 @@ public OpenShiftAiRerankModel( taskType, service, OpenShiftAiRerankServiceSettings.fromMap(serviceSettings, context), - OpenShiftAIRerankTaskSettings.fromMap(taskSettings), + OpenShiftAiRerankTaskSettings.fromMap(taskSettings), DefaultSecretSettings.fromMap(secrets) ); } @@ -50,7 +50,7 @@ public OpenShiftAiRerankModel( TaskType taskType, String service, OpenShiftAiRerankServiceSettings serviceSettings, - OpenShiftAIRerankTaskSettings taskSettings, + OpenShiftAiRerankTaskSettings taskSettings, @Nullable DefaultSecretSettings secretSettings ) { super( @@ -59,7 +59,7 @@ public OpenShiftAiRerankModel( ); } - private OpenShiftAiRerankModel(OpenShiftAiRerankModel model, OpenShiftAIRerankTaskSettings taskSettings) { + private OpenShiftAiRerankModel(OpenShiftAiRerankModel model, OpenShiftAiRerankTaskSettings taskSettings) { super(model, taskSettings); } @@ -69,8 +69,8 @@ public OpenShiftAiRerankServiceSettings getServiceSettings() { } @Override - public OpenShiftAIRerankTaskSettings getTaskSettings() { - return (OpenShiftAIRerankTaskSettings) super.getTaskSettings(); + public OpenShiftAiRerankTaskSettings getTaskSettings() { + return (OpenShiftAiRerankTaskSettings) super.getTaskSettings(); } /** diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAIRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettings.java similarity index 78% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAIRerankTaskSettings.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettings.java index aa8396178e4a4..b285a9c0a5072 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAIRerankTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettings.java @@ -28,21 +28,21 @@ /** * Defines the task settings for the OpenShift AI rerank service. */ -public class OpenShiftAIRerankTaskSettings implements TaskSettings { +public class OpenShiftAiRerankTaskSettings implements TaskSettings { public static final String NAME = "openshift_ai_rerank_task_settings"; public static final String RETURN_DOCUMENTS = "return_documents"; public static final String TOP_N = "top_n"; - private static final OpenShiftAIRerankTaskSettings EMPTY_SETTINGS = new OpenShiftAIRerankTaskSettings(null, null); + private static final OpenShiftAiRerankTaskSettings EMPTY_SETTINGS = new OpenShiftAiRerankTaskSettings(null, null); /** - * Creates a new {@link OpenShiftAIRerankTaskSettings} from a map of settings. + * Creates a new {@link OpenShiftAiRerankTaskSettings} from a map of settings. * @param map the map of settings - * @return a constructed {@link OpenShiftAIRerankTaskSettings} + * @return a constructed {@link OpenShiftAiRerankTaskSettings} * @throws ValidationException if any of the settings are invalid */ - public static OpenShiftAIRerankTaskSettings fromMap(Map map) { + public static OpenShiftAiRerankTaskSettings fromMap(Map map) { ValidationException validationException = new ValidationException(); if (map == null || map.isEmpty()) { @@ -60,17 +60,17 @@ public static OpenShiftAIRerankTaskSettings fromMap(Map map) { } /** - * Creates a new {@link OpenShiftAIRerankTaskSettings} by using non-null fields from the request settings over the original settings. + * Creates a new {@link OpenShiftAiRerankTaskSettings} by using non-null fields from the request settings over the original settings. * * @param originalSettings the settings stored as part of the inference entity configuration * @param requestTaskSettings the settings passed in within the task_settings field of the request - * @return a constructed {@link OpenShiftAIRerankTaskSettings} + * @return a constructed {@link OpenShiftAiRerankTaskSettings} */ - public static OpenShiftAIRerankTaskSettings of( - OpenShiftAIRerankTaskSettings originalSettings, - OpenShiftAIRerankTaskSettings requestTaskSettings + public static OpenShiftAiRerankTaskSettings of( + OpenShiftAiRerankTaskSettings originalSettings, + OpenShiftAiRerankTaskSettings requestTaskSettings ) { - return new OpenShiftAIRerankTaskSettings( + return new OpenShiftAiRerankTaskSettings( requestTaskSettings.getTopN() != null ? requestTaskSettings.getTopN() : originalSettings.getTopN(), requestTaskSettings.getReturnDocuments() != null ? requestTaskSettings.getReturnDocuments() @@ -79,36 +79,36 @@ public static OpenShiftAIRerankTaskSettings of( } /** - * Creates a new {@link OpenShiftAIRerankTaskSettings} with the specified settings. + * Creates a new {@link OpenShiftAiRerankTaskSettings} with the specified settings. * * @param topN the number of top documents to return * @param returnDocuments whether to return the documents - * @return a constructed {@link OpenShiftAIRerankTaskSettings} + * @return a constructed {@link OpenShiftAiRerankTaskSettings} */ - public static OpenShiftAIRerankTaskSettings of(@Nullable Integer topN, @Nullable Boolean returnDocuments) { - return new OpenShiftAIRerankTaskSettings(topN, returnDocuments); + public static OpenShiftAiRerankTaskSettings of(@Nullable Integer topN, @Nullable Boolean returnDocuments) { + return new OpenShiftAiRerankTaskSettings(topN, returnDocuments); } private final Integer topN; private final Boolean returnDocuments; /** - * Constructs a new {@link OpenShiftAIRerankTaskSettings} by reading from a {@link StreamInput}. + * Constructs a new {@link OpenShiftAiRerankTaskSettings} by reading from a {@link StreamInput}. * * @param in the stream input to read from * @throws IOException if an I/O error occurs */ - public OpenShiftAIRerankTaskSettings(StreamInput in) throws IOException { + public OpenShiftAiRerankTaskSettings(StreamInput in) throws IOException { this(in.readOptionalInt(), in.readOptionalBoolean()); } /** - * Constructs a new {@link OpenShiftAIRerankTaskSettings} with the specified settings. + * Constructs a new {@link OpenShiftAiRerankTaskSettings} with the specified settings. * * @param topN the number of top documents to return * @param doReturnDocuments whether to return the documents */ - public OpenShiftAIRerankTaskSettings(@Nullable Integer topN, @Nullable Boolean doReturnDocuments) { + public OpenShiftAiRerankTaskSettings(@Nullable Integer topN, @Nullable Boolean doReturnDocuments) { this.topN = topN; this.returnDocuments = doReturnDocuments; } @@ -157,7 +157,7 @@ public void writeTo(StreamOutput out) throws IOException { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - OpenShiftAIRerankTaskSettings that = (OpenShiftAIRerankTaskSettings) o; + OpenShiftAiRerankTaskSettings that = (OpenShiftAiRerankTaskSettings) o; return Objects.equals(returnDocuments, that.returnDocuments) && Objects.equals(topN, that.topN); } @@ -176,7 +176,7 @@ public Boolean getReturnDocuments() { @Override public TaskSettings updatedTaskSettings(Map newSettings) { - OpenShiftAIRerankTaskSettings updatedSettings = OpenShiftAIRerankTaskSettings.fromMap(new HashMap<>(newSettings)); - return OpenShiftAIRerankTaskSettings.of(this, updatedSettings); + OpenShiftAiRerankTaskSettings updatedSettings = OpenShiftAiRerankTaskSettings.fromMap(new HashMap<>(newSettings)); + return OpenShiftAiRerankTaskSettings.of(this, updatedSettings); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntityTests.java new file mode 100644 index 0000000000000..4ec7d391a08f1 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntityTests.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.request.rarank; + +import junit.framework.TestCase; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.common.xcontent.XContentHelper.stripWhitespace; + +public class OpenShiftAIRerankRequestEntityTests extends TestCase { + private static final String INPUT = "documents"; + private static final String QUERY = "query"; + private static final String MODEL = "model"; + private static final Integer TOP_N = 8; + private static final Boolean RETURN_DOCUMENTS = true; + + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = new OpenShiftAIRerankRequestEntity(MODEL, QUERY, List.of(INPUT), RETURN_DOCUMENTS, TOP_N); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = Strings.toString(builder); + String expected = """ + { + "model": "model", + "query": "query", + "documents": ["documents"], + "top_n": 8, + "return_documents": true + } + """; + assertEquals(stripWhitespace(expected), result); + } + + public void testXContent_WritesMinimalFields() throws IOException { + var entity = new OpenShiftAIRerankRequestEntity(null, QUERY, List.of(INPUT), null, null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = Strings.toString(builder); + String expected = """ + { + "query": "query", + "documents": ["documents"] + } + """; + assertEquals(stripWhitespace(expected), result); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java new file mode 100644 index 0000000000000..53246118ac1f8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java @@ -0,0 +1,108 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.request.rarank; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankModelTests; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class OpenShiftAiRerankRequestTests extends ESTestCase { + private static final String INPUT = "documents"; + private static final String QUERY = "query"; + private static final String MODEL_ID = "modelId"; + private static final Integer TOP_N = 8; + private static final Boolean RETURN_TEXT = false; + + private static final String AUTH_HEADER_VALUE = "Bearer secret"; + + public void testCreateRequest_WithMinimalFieldsSet() throws IOException { + testCreateRequest(null, null, null, createRequest(null, null, null)); + } + + public void testCreateRequest_WithTopN() throws IOException { + testCreateRequest(TOP_N, null, null, createRequest(TOP_N, null, null)); + } + + public void testCreateRequest_WithReturnDocuments() throws IOException { + testCreateRequest(null, RETURN_TEXT, null, createRequest(null, RETURN_TEXT, null)); + } + + public void testCreateRequest_WithModelId() throws IOException { + testCreateRequest(null, null, MODEL_ID, createRequest(null, null, MODEL_ID)); + } + + public void testCreateRequest_AllFields() throws IOException { + testCreateRequest(TOP_N, RETURN_TEXT, MODEL_ID, createRequest(TOP_N, RETURN_TEXT, MODEL_ID)); + } + + public void testCreateRequest_AllFields_OverridesTaskSettings() throws IOException { + testCreateRequest(TOP_N, RETURN_TEXT, MODEL_ID, createRequestWithDifferentTaskSettings(TOP_N, RETURN_TEXT)); + } + + public void testCreateRequest_AllFields_KeepsTaskSettings() throws IOException { + testCreateRequest(1, true, MODEL_ID, createRequestWithDifferentTaskSettings(null, null)); + } + + private void testCreateRequest(Integer topN, Boolean returnDocuments, String modelId, OpenShiftAiRerankRequest request) + throws IOException { + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaTypeWithoutParameters())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + + assertThat(requestMap.get(INPUT), is(List.of(INPUT))); + assertThat(requestMap.get(QUERY), is(QUERY)); + int itemsCount = 2; + if (topN != null) { + assertThat(requestMap.get("top_n"), is(topN)); + itemsCount++; + } + if (returnDocuments != null) { + assertThat(requestMap.get("return_documents"), is(returnDocuments)); + itemsCount++; + } + if (modelId != null) { + assertThat(requestMap.get("model"), is(modelId)); + itemsCount++; + } + assertThat(requestMap, aMapWithSize(itemsCount)); + } + + private static OpenShiftAiRerankRequest createRequest( + @Nullable Integer topN, + @Nullable Boolean returnDocuments, + @Nullable String modelId + ) { + var rerankModel = OpenShiftAiRerankModelTests.createModel(randomAlphaOfLength(10), "secret", modelId, topN, returnDocuments); + return new OpenShiftAiRerankRequest(QUERY, List.of(INPUT), returnDocuments, topN, rerankModel); + } + + private static OpenShiftAiRerankRequest createRequestWithDifferentTaskSettings( + @Nullable Integer topN, + @Nullable Boolean returnDocuments + ) { + var rerankModel = OpenShiftAiRerankModelTests.createModel(randomAlphaOfLength(10), "secret", MODEL_ID, 1, true); + return new OpenShiftAiRerankRequest(QUERY, List.of(INPUT), returnDocuments, topN, rerankModel); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java index 73e5156dea270..c2f692b44e7c7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java @@ -16,12 +16,22 @@ public class OpenShiftAiRerankModelTests extends ESTestCase { public static OpenShiftAiRerankModel createModel(String url, String apiKey, @Nullable String modelId) { + return createModel(url, apiKey, modelId, 2, true); + } + + public static OpenShiftAiRerankModel createModel( + String url, + String apiKey, + @Nullable String modelId, + @Nullable Integer topN, + @Nullable Boolean doReturnDocuments + ) { return new OpenShiftAiRerankModel( "inferenceEntityId", TaskType.RERANK, "service", new OpenShiftAiRerankServiceSettings(modelId, url, null), - new OpenShiftAIRerankTaskSettings(2, true), + new OpenShiftAiRerankTaskSettings(topN, doReturnDocuments), new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } From ce2cf92910cceffd2aafc5a28d082d0d79114107 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Mon, 20 Oct 2025 19:13:50 +0300 Subject: [PATCH 16/70] Add unit tests for OpenShiftAiRerankServiceSettings and OpenShiftAiRerankTaskSettings --- ...OpenShiftAiRerankServiceSettingsTests.java | 84 ++++++++++++ .../OpenShiftAiRerankTaskSettingsTests.java | 123 ++++++++++++++++++ 2 files changed, 207 insertions(+) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettingsTests.java new file mode 100644 index 0000000000000..c3a5ae587d67e --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettingsTests.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; + +public class OpenShiftAiRerankServiceSettingsTests extends AbstractBWCWireSerializationTestCase { + + private static OpenShiftAiRerankServiceSettings createRandom() { + var modelId = randomAlphaOfLengthOrNull(8); + var url = randomAlphaOfLength(15); + return new OpenShiftAiRerankServiceSettings(modelId, ServiceUtils.createUri(url), RateLimitSettingsTests.createRandom()); + } + + public void testToXContent_WritesAllValues() throws IOException { + var url = "http://www.abc.com"; + var model = "model"; + + var serviceSettings = new OpenShiftAiRerankServiceSettings(model, url, null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model_id":"model", + "url":"http://www.abc.com", + "rate_limit": { + "requests_per_minute": 3000 + } + } + """)); + } + + @Override + protected Writeable.Reader instanceReader() { + return OpenShiftAiRerankServiceSettings::new; + } + + @Override + protected OpenShiftAiRerankServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected OpenShiftAiRerankServiceSettings mutateInstance(OpenShiftAiRerankServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, OpenShiftAiRerankServiceSettingsTests::createRandom); + } + + @Override + protected OpenShiftAiRerankServiceSettings mutateInstanceForVersion( + OpenShiftAiRerankServiceSettings instance, + TransportVersion version + ) { + return instance; + } + + public static Map getServiceSettingsMap(@Nullable String url, @Nullable String model) { + return new HashMap<>(OpenShiftAiChatCompletionServiceSettingsTests.getServiceSettingsMap(url, model)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java new file mode 100644 index 0000000000000..f4b40580fc936 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java @@ -0,0 +1,123 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; + +public class OpenShiftAiRerankTaskSettingsTests extends AbstractBWCWireSerializationTestCase { + public static OpenShiftAiRerankTaskSettings createRandom() { + var returnDocuments = randomOptionalBoolean(); + var topNDocsOnly = randomBoolean() ? randomIntBetween(1, 10) : null; + + return new OpenShiftAiRerankTaskSettings(topNDocsOnly, returnDocuments); + } + + public void testFromMap_WithValidValues_ReturnsSettings() { + Map taskMap = Map.of( + OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS, + true, + OpenShiftAiRerankTaskSettings.TOP_N, + 5 + ); + var settings = OpenShiftAiRerankTaskSettings.fromMap(new HashMap<>(taskMap)); + assertTrue(settings.getReturnDocuments()); + assertEquals(5, settings.getTopN().intValue()); + } + + public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() { + var settings = OpenShiftAiRerankTaskSettings.fromMap(Map.of()); + assertNull(settings.getReturnDocuments()); + assertNull(settings.getTopN()); + } + + public void testFromMap_WithInvalidReturnDocuments_ThrowsValidationException() { + Map taskMap = Map.of( + OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS, + "invalid", + OpenShiftAiRerankTaskSettings.TOP_N, + 5 + ); + var thrownException = expectThrows(ValidationException.class, () -> OpenShiftAiRerankTaskSettings.fromMap(new HashMap<>(taskMap))); + assertThat(thrownException.getMessage(), containsString("field [return_documents] is not of the expected type")); + } + + public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() { + Map taskMap = Map.of( + OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS, + true, + OpenShiftAiRerankTaskSettings.TOP_N, + "invalid" + ); + var thrownException = expectThrows(ValidationException.class, () -> OpenShiftAiRerankTaskSettings.fromMap(new HashMap<>(taskMap))); + assertThat(thrownException.getMessage(), containsString("field [top_n] is not of the expected type")); + } + + public void UpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() { + var initialSettings = new OpenShiftAiRerankTaskSettings(5, true); + OpenShiftAiRerankTaskSettings updatedSettings = (OpenShiftAiRerankTaskSettings) initialSettings.updatedTaskSettings(Map.of()); + assertEquals(initialSettings, updatedSettings); + } + + public void testUpdatedTaskSettings_WithNewReturnDocuments_ReturnsUpdatedSettings() { + var initialSettings = new OpenShiftAiRerankTaskSettings(5, true); + Map newSettings = Map.of(OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS, false); + OpenShiftAiRerankTaskSettings updatedSettings = (OpenShiftAiRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertFalse(updatedSettings.getReturnDocuments()); + assertEquals(initialSettings.getTopN(), updatedSettings.getTopN()); + } + + public void testUpdatedTaskSettings_WithNewTopNDocsOnly_ReturnsUpdatedSettings() { + var initialSettings = new OpenShiftAiRerankTaskSettings(5, true); + Map newSettings = Map.of(OpenShiftAiRerankTaskSettings.TOP_N, 7); + OpenShiftAiRerankTaskSettings updatedSettings = (OpenShiftAiRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertEquals(7, updatedSettings.getTopN().intValue()); + assertEquals(initialSettings.getReturnDocuments(), updatedSettings.getReturnDocuments()); + } + + public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings() { + var initialSettings = new OpenShiftAiRerankTaskSettings(5, true); + Map newSettings = Map.of( + OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS, + false, + OpenShiftAiRerankTaskSettings.TOP_N, + 7 + ); + OpenShiftAiRerankTaskSettings updatedSettings = (OpenShiftAiRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertFalse(updatedSettings.getReturnDocuments()); + assertEquals(7, updatedSettings.getTopN().intValue()); + } + + @Override + protected Writeable.Reader instanceReader() { + return OpenShiftAiRerankTaskSettings::new; + } + + @Override + protected OpenShiftAiRerankTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected OpenShiftAiRerankTaskSettings mutateInstance(OpenShiftAiRerankTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, OpenShiftAiRerankTaskSettingsTests::createRandom); + } + + @Override + protected OpenShiftAiRerankTaskSettings mutateInstanceForVersion(OpenShiftAiRerankTaskSettings instance, TransportVersion version) { + return instance; + } +} From 6c6dfe5745f014fc33079e48f51f6d830d6e0a60 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Mon, 20 Oct 2025 16:38:28 +0000 Subject: [PATCH 17/70] [CI] Auto commit changes from spotless --- .../rerank/OpenShiftAiRerankTaskSettingsTests.java | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java index f4b40580fc936..991b2ad0ddd65 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java @@ -27,12 +27,7 @@ public static OpenShiftAiRerankTaskSettings createRandom() { } public void testFromMap_WithValidValues_ReturnsSettings() { - Map taskMap = Map.of( - OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS, - true, - OpenShiftAiRerankTaskSettings.TOP_N, - 5 - ); + Map taskMap = Map.of(OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS, true, OpenShiftAiRerankTaskSettings.TOP_N, 5); var settings = OpenShiftAiRerankTaskSettings.fromMap(new HashMap<>(taskMap)); assertTrue(settings.getReturnDocuments()); assertEquals(5, settings.getTopN().intValue()); From 52d439fe231b08c8b01901cf466f88b941acbb83 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Tue, 21 Oct 2025 00:11:09 +0300 Subject: [PATCH 18/70] Add unit tests for OpenShiftAiRerankServiceSettings and OpenShiftAiRerankTaskSettings --- .../openshiftai/OpenShiftAiService.java | 6 +- .../openshiftai/OpenShiftAiServiceTests.java | 862 ++++++++++++++++++ .../OpenShiftAiRerankTaskSettingsTests.java | 2 +- 3 files changed, 865 insertions(+), 5 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java index 1b52afe63403e..a1edd4d6a9940 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java @@ -389,11 +389,9 @@ public static InferenceServiceConfiguration get() { configurationMap.put( URL, - new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription( - "The absolute URL of the external service to send requests to." - ) + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription("The URL endpoint to use for the requests.") .setLabel("URL") - .setRequired(false) + .setRequired(true) .setSensitive(false) .setUpdatable(false) .setType(SettingsConfigurationFieldType.STRING) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java new file mode 100644 index 0000000000000..9f3d0f09701b3 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java @@ -0,0 +1,862 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openshiftai; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionTestUtils; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.RerankingInferenceService; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests; +import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModel; +import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.hamcrest.CoreMatchers; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.Arrays; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.ExceptionsHelper.unwrapCause; +import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; +import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION; +import static org.elasticsearch.inference.TaskType.COMPLETION; +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettingsTests.buildServiceSettingsMap; +import static org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModelTests.createChatCompletionModel; +import static org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionServiceSettingsTests.getServiceSettingsMap; +import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.isA; +import static org.mockito.Mockito.mock; + +public class OpenShiftAiServiceTests extends AbstractInferenceServiceTests { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + public OpenShiftAiServiceTests() { + super(createTestConfiguration()); + } + + public static TestConfiguration createTestConfiguration() { + return new TestConfiguration.Builder( + new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, EnumSet.of(TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION)) { + + @Override + protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + return OpenShiftAiServiceTests.createService(threadPool, clientManager); + } + + @Override + protected Map createServiceSettingsMap(TaskType taskType) { + return OpenShiftAiServiceTests.createServiceSettingsMap(taskType); + } + + @Override + protected Map createTaskSettingsMap() { + return new HashMap<>(); + } + + @Override + protected Map createSecretSettingsMap() { + return OpenShiftAiServiceTests.createSecretSettingsMap(); + } + + @Override + protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { + OpenShiftAiServiceTests.assertModel(model, taskType, modelIncludesSecrets); + } + + @Override + protected EnumSet supportedStreamingTasks() { + return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); + } + } + ).enableUpdateModelTests(new UpdateModelConfiguration() { + @Override + protected OpenShiftAiEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarityMeasure) { + return createInternalEmbeddingModel(similarityMeasure); + } + }).build(); + } + + private static void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { + switch (taskType) { + case TEXT_EMBEDDING -> assertTextEmbeddingModel(model, modelIncludesSecrets); + case COMPLETION -> assertCompletionModel(model, modelIncludesSecrets); + case CHAT_COMPLETION -> assertChatCompletionModel(model, modelIncludesSecrets); + default -> fail("unexpected task type [" + taskType + "]"); + } + } + + private static void assertTextEmbeddingModel(Model model, boolean modelIncludesSecrets) { + var openShiftAiModel = assertCommonModelFields(model, modelIncludesSecrets); + + assertThat(openShiftAiModel.getTaskType(), Matchers.is(TaskType.TEXT_EMBEDDING)); + } + + private static OpenShiftAiModel assertCommonModelFields(Model model, boolean modelIncludesSecrets) { + assertThat(model, instanceOf(OpenShiftAiModel.class)); + + var openShiftAiModel = (OpenShiftAiModel) model; + assertThat(openShiftAiModel.getServiceSettings().modelId(), is("model_id")); + assertThat(openShiftAiModel.getServiceSettings().uri.toString(), Matchers.is("http://www.abc.com")); + assertThat(openShiftAiModel.getTaskSettings(), Matchers.is(EmptyTaskSettings.INSTANCE)); + + if (modelIncludesSecrets) { + assertThat(openShiftAiModel.getSecretSettings().apiKey(), Matchers.is(new SecureString("secret".toCharArray()))); + } + + return openShiftAiModel; + } + + private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) { + var openShiftAiModel = assertCommonModelFields(model, modelIncludesSecrets); + assertThat(openShiftAiModel.getTaskType(), Matchers.is(TaskType.COMPLETION)); + } + + private static void assertChatCompletionModel(Model model, boolean modelIncludesSecrets) { + var openShiftAiModel = assertCommonModelFields(model, modelIncludesSecrets); + assertThat(openShiftAiModel.getTaskType(), Matchers.is(TaskType.CHAT_COMPLETION)); + } + + public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + return new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); + } + + private static Map createServiceSettingsMap(TaskType taskType) { + Map settingsMap = new HashMap<>( + Map.of(ServiceFields.URL, "http://www.abc.com", ServiceFields.MODEL_ID, "model_id") + ); + + if (taskType == TaskType.TEXT_EMBEDDING) { + settingsMap.putAll( + Map.of( + ServiceFields.SIMILARITY, + SimilarityMeasure.COSINE.toString(), + ServiceFields.DIMENSIONS, + 1536, + ServiceFields.MAX_INPUT_TOKENS, + 512 + ) + ); + } + + return settingsMap; + } + + private static Map createSecretSettingsMap() { + return new HashMap<>(Map.of("api_key", "secret")); + } + + private static OpenShiftAiEmbeddingsModel createInternalEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure) { + var inferenceId = "inference_id"; + + return new OpenShiftAiEmbeddingsModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + OpenShiftAiService.NAME, + new OpenShiftAiEmbeddingsServiceSettings( + "model_id", + "http://www.abc.com", + 1536, + similarityMeasure, + 512, + new RateLimitSettings(10_000), + true + ), + ChunkingSettingsTests.createRandomChunkingSettings(), + new DefaultSecretSettings(new SecureString("secret".toCharArray())) + ); + } + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityExecutors()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + try (var service = createService()) { + ActionListener modelVerificationActionListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(OpenShiftAiEmbeddingsModel.class)); + + var embeddingsModel = (OpenShiftAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, e -> fail("parse request should not fail " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getServiceSettingsMap("model", "url"), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + modelVerificationActionListener + ); + } + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + try (var service = createService()) { + ActionListener modelVerificationActionListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(OpenShiftAiEmbeddingsModel.class)); + + var embeddingsModel = (OpenShiftAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, e -> fail("parse request should not fail " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getServiceSettingsMap("model", "url"), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + modelVerificationActionListener + ); + } + } + + public void testParseRequestConfig_ThrowsException_WithoutModelId() throws IOException { + var url = "url"; + var secret = "secret"; + + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap(m -> { + assertThat(m, instanceOf(OpenShiftAiChatCompletionModel.class)); + + var chatCompletionModel = (OpenShiftAiChatCompletionModel) m; + + assertThat(chatCompletionModel.getServiceSettings().uri().toString(), is(url)); + assertNull(chatCompletionModel.getServiceSettings().modelId()); + assertThat(chatCompletionModel.getSecretSettings().apiKey().toString(), is("secret")); + + }, exception -> { + assertThat(exception, instanceOf(ValidationException.class)); + assertThat( + exception.getMessage(), + is("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];") + ); + }); + + service.parseRequestConfig( + "id", + TaskType.CHAT_COMPLETION, + getRequestConfigMap(getServiceSettingsMap(null, url), getSecretSettingsMap(secret)), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_ThrowsException_WithoutUrl() throws IOException { + var model = "model"; + var secret = "secret"; + + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap(m -> { + assertThat(m, instanceOf(OpenShiftAiChatCompletionModel.class)); + + var chatCompletionModel = (OpenShiftAiChatCompletionModel) m; + + assertThat(chatCompletionModel.getServiceSettings().modelId(), is(model)); + assertNull(chatCompletionModel.getServiceSettings().modelId()); + assertThat(chatCompletionModel.getSecretSettings().apiKey().toString(), is("secret")); + + }, exception -> { + assertThat(exception, instanceOf(ValidationException.class)); + assertThat( + exception.getMessage(), + is("Validation Failed: 1: [service_settings] does not contain the required setting [url];") + ); + }); + + service.parseRequestConfig( + "id", + TaskType.CHAT_COMPLETION, + getRequestConfigMap(getServiceSettingsMap(model, null), getSecretSettingsMap(secret)), + modelVerificationListener + ); + } + } + + public void testUnifiedCompletionInfer() throws Exception { + // The escapes are because the streaming response must be on a single line + String responseJson = """ + data: {\ + "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26",\ + "choices": [{\ + "delta": {\ + "content": "Deep",\ + "function_call": null,\ + "refusal": null,\ + "role": "assistant",\ + "tool_calls": null\ + },\ + "finish_reason": null,\ + "index": 0,\ + "logprobs": null\ + }\ + ],\ + "created": 1750158492,\ + "model": "llama3.2:3b",\ + "object": "chat.completion.chunk",\ + "service_tier": null,\ + "system_fingerprint": "fp_ollama",\ + "usage": null\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = createChatCompletionModel(getUrl(webServer), "secret", "model"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(XContentHelper.stripWhitespace(""" + { + "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26", + "choices": [{ + "delta": { + "content": "Deep", + "role": "assistant" + }, + "index": 0 + } + ], + "model": "llama3.2:3b", + "object": "chat.completion.chunk" + } + """)); + } + } + + public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception { + String responseJson = """ + { + "detail": "Not Found" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); + var latch = new CountDownLatch(1); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + ActionListener.runAfter(ActionTestUtils.assertNoSuccessListener(e -> { + try (var builder = XContentFactory.jsonBuilder()) { + var t = unwrapCause(e); + assertThat(t, isA(UnifiedChatCompletionException.class)); + ((UnifiedChatCompletionException) t).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + assertThat(json, is(String.format(Locale.ROOT, XContentHelper.stripWhitespace(""" + { + "error" : { + "code" : "not_found", + "message" : "Resource not found at [%s] for request from inference entity id [inferenceEntityId] status \ + [404]. Error message: [{\\n \\"detail\\": \\"Not Found\\"\\n}\\n]", + "type" : "openshift_ai_error" + } + }"""), getUrl(webServer)))); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }), latch::countDown) + ); + assertTrue(latch.await(30, TimeUnit.SECONDS)); + } + } + + public void testMidStreamUnifiedCompletionError() throws Exception { + String responseJson = """ + data: {"error": {"message": "400: Invalid value: Model 'llama3.12:3b' not found"}} + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(XContentHelper.stripWhitespace(""" + { + "error": { + "message": "Received an error response for request from inference entity id [inferenceEntityId].\ + Error message: [{\\"error\\": {\\"message\\": \\"400: Invalid value: Model 'llama3.12:3b' not found\\"}}]", + "type": "openshift_ai_error" + } + } + """)); + } + + public void testInfer_StreamRequest() throws Exception { + String responseJson = """ + data: {\ + "id": "chatcmpl-2c57e3888b1a4e80a0c708889546288e",\ + "object": "chat.completion.chunk",\ + "created": 1760082951,\ + "model": "llama-31-8b-instruct",\ + "choices": [{\ + "index": 0,\ + "delta": {\ + "role": "assistant",\ + "content": "Deep"\ + },\ + "logprobs": null,\ + "finish_reason": null\ + }\ + ]\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + streamCompletion().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"Deep"}]}"""); + } + + private void testStreamError(String expectedResponse) throws Exception { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> { + e = unwrapCause(e); + assertThat(e, isA(UnifiedChatCompletionException.class)); + try (var builder = XContentFactory.jsonBuilder()) { + ((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + + assertThat(json, is(expectedResponse)); + } + }); + } + } + + public void testInfer_StreamRequest_ErrorResponse() { + String responseJson = """ + { + "detail": "Not Found" + }"""; + webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); + + var e = assertThrows(ElasticsearchStatusException.class, this::streamCompletion); + assertThat(e.status(), equalTo(RestStatus.NOT_FOUND)); + assertThat(e.getMessage(), equalTo(String.format(Locale.ROOT, """ + Resource not found at [%s] for request from inference entity id [inferenceEntityId] status [404]. Error message: [{ + "detail": "Not Found" + }]""", getUrl(webServer)))); + } + + public void testInfer_StreamRequestRetry() throws Exception { + webServer.enqueue(new MockResponse().setResponseCode(503).setBody(""" + { + "error": { + "message": "server busy" + } + }""")); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(""" + data: {\ + "id": "chatcmpl-2c57e3888b1a4e80a0c708889546288e",\ + "object": "chat.completion.chunk",\ + "created": 1760082951,\ + "model": "llama-31-8b-instruct",\ + "choices": [{\ + "index": 0,\ + "delta": {\ + "role": "assistant",\ + "content": "Deep"\ + },\ + "logprobs": null,\ + "finish_reason": null\ + }\ + ]\ + } + + """)); + + streamCompletion().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"Deep"}]}"""); + } + + public void testSupportsStreaming() throws IOException { + try (var service = new OpenShiftAiService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap() throws IOException { + try (var service = createService()) { + var secretSettings = getSecretSettingsMap("secret"); + secretSettings.put("extra_key", "value"); + + var config = getRequestConfigMap(getEmbeddingsServiceSettingsMap(), secretSettings); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("Configuration contains settings [{extra_key=value}] unknown to the [openshift_ai] service") + ); + } + ); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); + } + } + + public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { + var model = OpenShiftAiEmbeddingsModelTests.createModel(getUrl(webServer), "api_key", "model"); + + testChunkedInfer(model); + } + + public void testChunkedInfer_ChunkingSettingsSet() throws IOException { + var model = OpenShiftAiEmbeddingsModelTests.createModel(getUrl(webServer), "api_key", "model"); + + testChunkedInfer(model); + } + + public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + + String responseJson = """ + { + "id": "embd-45e6d99b97a645c0af96653598069cd9", + "object": "list", + "created": 1760085467, + "model": "gritlm-7b", + "data": [ + { + "index": 0, + "object": "embedding", + "embedding": [ + 0.0089111328125, + -0.007049560546875 + ] + }, + { + "index": 1, + "object": "embedding", + "embedding": [ + -0.008544921875, + -0.0230712890625 + ] + } + ], + "usage": { + "prompt_tokens": 7, + "total_tokens": 7, + "completion_tokens": 0, + "prompt_tokens_details": null + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(new ChunkInferenceInput("abc"), new ChunkInferenceInput("def")), + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + + assertThat(results, hasSize(2)); + { + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); + var floatResult = (ChunkedInferenceEmbedding) results.get(0); + assertThat(floatResult.chunks(), hasSize(1)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); + assertTrue( + Arrays.equals( + new float[] { 0.0089111328125f, -0.007049560546875f }, + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ) + ); + } + { + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); + var floatResult = (ChunkedInferenceEmbedding) results.get(1); + assertThat(floatResult.chunks(), hasSize(1)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); + assertTrue( + Arrays.equals( + new float[] { -0.008544921875f, -0.0230712890625f }, + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ) + ); + } + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer api_key")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), Matchers.is(2)); + assertThat(requestMap.get("input"), Matchers.is(List.of("abc", "def"))); + assertThat(requestMap.get("model"), Matchers.is("model")); + } + } + + public void testGetConfiguration() throws Exception { + try (var service = createService()) { + String content = XContentHelper.stripWhitespace(""" + { + "service": "openshift_ai", + "name": "OpenShift AI", + "task_types": ["text_embedding", "rerank", "completion", "chat_completion"], + "configurations": { + "api_key": { + "description": "API Key for the provider you're connecting to.", + "label": "API Key", + "required": true, + "sensitive": true, + "updatable": true, + "type": "str", + "supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"] + }, + "model_id": { + "description": "The name of the model to use for the inference task.", + "label": "Model ID", + "required": false, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"] + }, + "rate_limit.requests_per_minute": { + "description": "Minimize the number of rate limit errors.", + "label": "Rate Limit", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"] + }, + "url": { + "description": "The URL endpoint to use for the requests.", + "label": "URL", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"] + } + } + } + """); + InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( + new BytesArray(content), + XContentType.JSON + ); + boolean humanReadable = true; + BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); + InferenceServiceConfiguration serviceConfiguration = service.getConfiguration(); + assertToXContentEquivalent( + originalBytes, + toXContent(serviceConfiguration, XContentType.JSON, humanReadable), + XContentType.JSON + ); + } + } + + private InferenceEventsAssertion streamCompletion() throws Exception { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = OpenShiftAiChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + null, + null, + List.of("abc"), + true, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); + } + } + + private OpenShiftAiService createService() { + return new OpenShiftAiService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); + } + + private Map getRequestConfigMap( + Map serviceSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; + } + + private Map getRequestConfigMap(Map serviceSettings, Map secretSettings) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings)); + } + + private static Map getEmbeddingsServiceSettingsMap() { + return buildServiceSettingsMap("id", "url", SimilarityMeasure.COSINE.toString(), null, null, null); + } + + @Override + public InferenceService createInferenceService() { + return createService(); + } + + @Override + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + assertThat(rerankingInferenceService.rerankerWindowSize("any model"), is(2800)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java index 991b2ad0ddd65..6111a71d503c2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java @@ -61,7 +61,7 @@ public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() { assertThat(thrownException.getMessage(), containsString("field [top_n] is not of the expected type")); } - public void UpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() { + public void testUpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() { var initialSettings = new OpenShiftAiRerankTaskSettings(5, true); OpenShiftAiRerankTaskSettings updatedSettings = (OpenShiftAiRerankTaskSettings) initialSettings.updatedTaskSettings(Map.of()); assertEquals(initialSettings, updatedSettings); From d63f84fde63f8308ace00a88b8b56c813794bbb4 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Tue, 21 Oct 2025 11:09:10 +0300 Subject: [PATCH 19/70] Refactor tests in OpenShiftAIRerankRequestEntityTests and OpenShiftAiServiceTests for improved readability and accuracy --- .../services/openshiftai/OpenShiftAiServiceTests.java | 10 ++-------- .../rarank/OpenShiftAIRerankRequestEntityTests.java | 10 +++++----- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java index 9f3d0f09701b3..0adf54a8d38f4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java @@ -305,7 +305,7 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsN } } - public void testParseRequestConfig_ThrowsException_WithoutModelId() throws IOException { + public void testParseRequestConfig_Success_WithoutModelId() throws IOException { var url = "url"; var secret = "secret"; @@ -319,13 +319,7 @@ public void testParseRequestConfig_ThrowsException_WithoutModelId() throws IOExc assertNull(chatCompletionModel.getServiceSettings().modelId()); assertThat(chatCompletionModel.getSecretSettings().apiKey().toString(), is("secret")); - }, exception -> { - assertThat(exception, instanceOf(ValidationException.class)); - assertThat( - exception.getMessage(), - is("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];") - ); - }); + }, e -> fail("parse request should not fail " + e.getMessage())); service.parseRequestConfig( "id", diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntityTests.java index 4ec7d391a08f1..35637228a2776 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntityTests.java @@ -7,9 +7,8 @@ package org.elasticsearch.xpack.inference.services.openshiftai.request.rarank; -import junit.framework.TestCase; - import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; @@ -19,8 +18,9 @@ import java.util.List; import static org.elasticsearch.common.xcontent.XContentHelper.stripWhitespace; +import static org.hamcrest.Matchers.is; -public class OpenShiftAIRerankRequestEntityTests extends TestCase { +public class OpenShiftAIRerankRequestEntityTests extends ESTestCase { private static final String INPUT = "documents"; private static final String QUERY = "query"; private static final String MODEL = "model"; @@ -42,7 +42,7 @@ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException "return_documents": true } """; - assertEquals(stripWhitespace(expected), result); + assertThat(stripWhitespace(expected), is(result)); } public void testXContent_WritesMinimalFields() throws IOException { @@ -57,7 +57,7 @@ public void testXContent_WritesMinimalFields() throws IOException { "documents": ["documents"] } """; - assertEquals(stripWhitespace(expected), result); + assertThat(stripWhitespace(expected), is(result)); } } From 0a6da54e53b1727d564a849e6bca147465da70cb Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Tue, 21 Oct 2025 11:37:12 +0300 Subject: [PATCH 20/70] Enhance OpenShift AI service with detailed comments and utility class documentation --- .../services/openshiftai/OpenShiftAiService.java | 6 ++++-- .../services/openshiftai/OpenShiftAiUtils.java | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java index a1edd4d6a9940..1d8960e0d480a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java @@ -345,9 +345,11 @@ private OpenShiftAiModel createModelFromPersistent( @Override public int rerankerWindowSize(String modelId) { + // OpenShift AI uses Cohere and JinaAI rerank protocols for reranking + // JinaAI rerank model has 8000 tokens limit length https://jina.ai/models/jina-reranker-v2-base-multilingual // Cohere rerank model truncates at 4096 tokens https://docs.cohere.com/reference/rerank - // Using 1 token = 0.75 words as a rough estimate, we get 3072 words - // allowing for some headroom, we set the window size below 3072 + // We choose a conservative limit based on these two models + // Using 1 token = 0.75 words as a rough estimate, we get 3072 words allowing for some headroom, we set the window size below 3072 return 2800; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiUtils.java index e0227b532e61c..d339c47b52c8b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiUtils.java @@ -9,13 +9,29 @@ import org.elasticsearch.TransportVersion; +/** + * Utility class for OpenShift AI related version checks. + */ public final class OpenShiftAiUtils { + + /** + * TransportVersion indicating when OpenShift AI features were added. + */ public static final TransportVersion ML_INFERENCE_OPENSHIFT_AI_ADDED = TransportVersion.fromName("ml_inference_openshift_ai_added"); + /** + * Checks if the given TransportVersion supports OpenShift AI features. + * + * @param version the TransportVersion to check + * @return true if OpenShift AI features are supported, false otherwise + */ public static boolean supportsOpenShiftAi(TransportVersion version) { return version.supports(ML_INFERENCE_OPENSHIFT_AI_ADDED); } + /** + * Private constructor to prevent instantiation. + */ private OpenShiftAiUtils() {} } From 55bf99d7a3e5ff1fb1cefb7515919fe778bcc2b8 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Thu, 23 Oct 2025 15:48:04 +0300 Subject: [PATCH 21/70] Refactor OpenShiftAiModel --- .../xpack/inference/services/openshiftai/OpenShiftAiModel.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java index 97664e3f5fcea..f49e44d4973ff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java @@ -43,7 +43,7 @@ public RateLimitSettings rateLimitSettings() { @Override public int rateLimitGroupingHash() { - return Objects.hash(getServiceSettings().uri, getServiceSettings().modelId(), getSecretSettings().apiKey()); + return Objects.hash(getServiceSettings().uri(), getServiceSettings().modelId()); } @Override From d73f5dab71f573d6a7624808c663dc56c7a13c8d Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Thu, 23 Oct 2025 18:41:03 +0300 Subject: [PATCH 22/70] Remove unused rateLimitSettings field from OpenShiftAiModel --- .../xpack/inference/services/openshiftai/OpenShiftAiModel.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java index f49e44d4973ff..1a5fb0f803e81 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java @@ -22,7 +22,6 @@ * This class extends RateLimitGroupingModel to handle rate limiting based on modelId and API key. */ public abstract class OpenShiftAiModel extends RateLimitGroupingModel { - protected RateLimitSettings rateLimitSettings; protected OpenShiftAiModel(ModelConfigurations configurations, ModelSecrets secrets) { super(configurations, secrets); From 636d72a367bfcb3bb1f1fe8cb3b4cfa882912b36 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 24 Oct 2025 17:54:31 +0300 Subject: [PATCH 23/70] Fix JavaDoc and update transport version --- .../definitions/referable/ml_inference_openshift_ai_added.csv | 2 +- server/src/main/resources/transport/upper_bounds/9.3.csv | 2 +- .../request/embeddings/OpenShiftAiEmbeddingsRequest.java | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/server/src/main/resources/transport/definitions/referable/ml_inference_openshift_ai_added.csv b/server/src/main/resources/transport/definitions/referable/ml_inference_openshift_ai_added.csv index d8997da1b2882..9f1d7fea747d7 100644 --- a/server/src/main/resources/transport/definitions/referable/ml_inference_openshift_ai_added.csv +++ b/server/src/main/resources/transport/definitions/referable/ml_inference_openshift_ai_added.csv @@ -1 +1 @@ -9201000 +9203000 diff --git a/server/src/main/resources/transport/upper_bounds/9.3.csv b/server/src/main/resources/transport/upper_bounds/9.3.csv index a3e52d8099898..98436d7a6f735 100644 --- a/server/src/main/resources/transport/upper_bounds/9.3.csv +++ b/server/src/main/resources/transport/upper_bounds/9.3.csv @@ -1 +1 @@ -min_transport_version,9202000 +ml_inference_openshift_ai_added,9203000 diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequest.java index d3fe2bc685c3f..b1a31f60cfc98 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequest.java @@ -37,7 +37,7 @@ public class OpenShiftAiEmbeddingsRequest implements Request { * * @param truncator the truncator to handle input truncation * @param input the input to be truncated - * @param model the OpenShiftId embeddings model to be used for the request + * @param model the OpenShift AI embeddings model to be used for the request */ public OpenShiftAiEmbeddingsRequest(Truncator truncator, Truncator.TruncationResult input, OpenShiftAiEmbeddingsModel model) { this.model = model; From 95e7ef5b664b3df828e11ac0ad1e5033c1ed7cf4 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 24 Oct 2025 18:23:15 +0300 Subject: [PATCH 24/70] Add JavaDoc comments and null checks for OpenShift AI request entity --- .../embeddings/OpenShiftAiEmbeddingsRequestEntity.java | 4 ++++ .../openshiftai/request/rarank/OpenShiftAiRerankRequest.java | 1 + 2 files changed, 5 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntity.java index 3268001d54f1d..7aa045f7fc5ab 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntity.java @@ -15,6 +15,10 @@ import java.util.List; import java.util.Objects; +/** + * OpenShiftAiEmbeddingsRequestEntity is responsible for creating the request entity for OpenShift AI embeddings. + * It implements ToXContentObject to allow serialization to XContent format. + */ public record OpenShiftAiEmbeddingsRequestEntity( List input, @Nullable String modelId, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequest.java index 414e72d7dd06c..56ef806982121 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequest.java @@ -44,6 +44,7 @@ public record OpenShiftAiRerankRequest( public OpenShiftAiRerankRequest { Objects.requireNonNull(input); Objects.requireNonNull(query); + Objects.requireNonNull(model); } @Override From 291f40a8cb5e888a1a722d20bfc1b0d35efbf989 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Wed, 29 Oct 2025 16:32:47 +0200 Subject: [PATCH 25/70] Enhance mutation logic in OpenShift AI service settings tests --- ...tAiChatCompletionServiceSettingsTests.java | 13 +++++++- ...ShiftAiEmbeddingsServiceSettingsTests.java | 31 ++++++++++++++++++- ...OpenShiftAiRerankServiceSettingsTests.java | 14 ++++++++- .../OpenShiftAiRerankTaskSettingsTests.java | 10 +++++- 4 files changed, 64 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java index 5109bf69a6281..5df17de6d190e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import java.io.IOException; +import java.net.URI; import java.util.HashMap; import java.util.Map; @@ -168,7 +169,17 @@ protected OpenShiftAiChatCompletionServiceSettings createTestInstance() { @Override protected OpenShiftAiChatCompletionServiceSettings mutateInstance(OpenShiftAiChatCompletionServiceSettings instance) throws IOException { - return randomValueOtherThan(instance, OpenShiftAiChatCompletionServiceSettingsTests::createRandom); + String modelId = instance.modelId(); + URI uri = instance.uri(); + RateLimitSettings rateLimitSettings = instance.rateLimitSettings(); + + switch (between(0, 2)) { + case 0 -> modelId = randomValueOtherThan(modelId, () -> randomAlphaOfLengthOrNull(8)); + case 1 -> uri = randomValueOtherThan(uri, () -> ServiceUtils.createUri(randomAlphaOfLength(15))); + case 2 -> rateLimitSettings = randomValueOtherThan(rateLimitSettings, RateLimitSettingsTests::createRandom); + default -> throw new AssertionError("Illegal randomisation branch"); + } + return new OpenShiftAiChatCompletionServiceSettings(modelId, uri, rateLimitSettings); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java index dde32cc099e43..779d127a707a0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java @@ -16,16 +16,19 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import org.hamcrest.CoreMatchers; import java.io.IOException; +import java.net.URI; import java.util.HashMap; import java.util.Map; @@ -592,7 +595,33 @@ protected OpenShiftAiEmbeddingsServiceSettings createTestInstance() { @Override protected OpenShiftAiEmbeddingsServiceSettings mutateInstance(OpenShiftAiEmbeddingsServiceSettings instance) throws IOException { - return randomValueOtherThan(instance, OpenShiftAiEmbeddingsServiceSettingsTests::createRandom); + String modelId = instance.modelId(); + URI uri = instance.uri(); + Integer dimensions = instance.dimensions(); + SimilarityMeasure similarity = instance.similarity(); + Integer maxInputTokens = instance.maxInputTokens(); + RateLimitSettings rateLimitSettings = instance.rateLimitSettings(); + Boolean dimensionsSetByUser = instance.dimensionsSetByUser(); + + switch (between(0, 6)) { + case 0 -> modelId = randomValueOtherThan(modelId, () -> randomAlphaOfLengthOrNull(8)); + case 1 -> uri = randomValueOtherThan(uri, () -> ServiceUtils.createUri(randomAlphaOfLength(15))); + case 2 -> dimensions = randomValueOtherThan(dimensions, () -> randomIntBetween(32, 256)); + case 3 -> similarity = randomValueOtherThan(similarity, () -> randomFrom(SimilarityMeasure.values())); + case 4 -> maxInputTokens = randomValueOtherThan(maxInputTokens, () -> randomIntBetween(128, 256)); + case 5 -> rateLimitSettings = randomValueOtherThan(rateLimitSettings, RateLimitSettingsTests::createRandom); + case 6 -> dimensionsSetByUser = randomValueOtherThan(dimensionsSetByUser, ESTestCase::randomBoolean); + default -> throw new AssertionError("Illegal randomisation branch"); + } + return new OpenShiftAiEmbeddingsServiceSettings( + modelId, + uri, + dimensions, + similarity, + maxInputTokens, + rateLimitSettings, + dimensionsSetByUser + ); } private static OpenShiftAiEmbeddingsServiceSettings createRandom() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettingsTests.java index c3a5ae587d67e..c45940c084322 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettingsTests.java @@ -17,9 +17,11 @@ import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import java.io.IOException; +import java.net.URI; import java.util.HashMap; import java.util.Map; @@ -66,7 +68,17 @@ protected OpenShiftAiRerankServiceSettings createTestInstance() { @Override protected OpenShiftAiRerankServiceSettings mutateInstance(OpenShiftAiRerankServiceSettings instance) throws IOException { - return randomValueOtherThan(instance, OpenShiftAiRerankServiceSettingsTests::createRandom); + String modelId = instance.modelId(); + URI uri = instance.uri(); + RateLimitSettings rateLimitSettings = instance.rateLimitSettings(); + + switch (between(0, 2)) { + case 0 -> modelId = randomValueOtherThan(modelId, () -> randomAlphaOfLengthOrNull(8)); + case 1 -> uri = randomValueOtherThan(uri, () -> ServiceUtils.createUri(randomAlphaOfLength(15))); + case 2 -> rateLimitSettings = randomValueOtherThan(rateLimitSettings, RateLimitSettingsTests::createRandom); + default -> throw new AssertionError("Illegal randomisation branch"); + } + return new OpenShiftAiRerankServiceSettings(modelId, uri, rateLimitSettings); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java index 6111a71d503c2..2cff47df2588d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import java.io.IOException; @@ -108,7 +109,14 @@ protected OpenShiftAiRerankTaskSettings createTestInstance() { @Override protected OpenShiftAiRerankTaskSettings mutateInstance(OpenShiftAiRerankTaskSettings instance) throws IOException { - return randomValueOtherThan(instance, OpenShiftAiRerankTaskSettingsTests::createRandom); + Integer topN = instance.getTopN(); + Boolean returnDocuments = instance.getReturnDocuments(); + switch (between(0, 1)) { + case 0 -> topN = randomValueOtherThan(topN, () -> randomBoolean() ? randomIntBetween(1, 10) : null); + case 1 -> returnDocuments = randomValueOtherThan(returnDocuments, ESTestCase::randomOptionalBoolean); + default -> throw new AssertionError("Illegal randomisation branch"); + } + return new OpenShiftAiRerankTaskSettings(topN, returnDocuments); } @Override From e55f9fd191def2c991c924f0882a8f3e66b2a1d3 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Thu, 30 Oct 2025 15:48:53 +0200 Subject: [PATCH 26/70] Update Transport Versions and refactor unit tests --- .../ml_inference_openshift_ai_added.csv | 2 +- .../resources/transport/upper_bounds/9.3.csv | 2 +- ...ShiftAiEmbeddingsServiceSettingsTests.java | 35 ++++++++++--------- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/server/src/main/resources/transport/definitions/referable/ml_inference_openshift_ai_added.csv b/server/src/main/resources/transport/definitions/referable/ml_inference_openshift_ai_added.csv index a79975c00dea2..44b76d9df26ec 100644 --- a/server/src/main/resources/transport/definitions/referable/ml_inference_openshift_ai_added.csv +++ b/server/src/main/resources/transport/definitions/referable/ml_inference_openshift_ai_added.csv @@ -1 +1 @@ -9204000 +9206000 diff --git a/server/src/main/resources/transport/upper_bounds/9.3.csv b/server/src/main/resources/transport/upper_bounds/9.3.csv index 1acd7ceede226..9d5e32b68b2a5 100644 --- a/server/src/main/resources/transport/upper_bounds/9.3.csv +++ b/server/src/main/resources/transport/upper_bounds/9.3.csv @@ -1 +1 @@ -text_similarity_rank_docs_explain_chunks,9205000 +ml_inference_openshift_ai_added,9206000 diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java index 779d127a707a0..67dc7d95a5a0e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java @@ -102,7 +102,7 @@ public void testFromMap_NoModelId_Success() { ); } - public void testFromMap_NoUrl_Failure() { + public void testFromMap_NoUrl_ThrowsException() { var thrownException = expectThrows( ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( @@ -124,7 +124,7 @@ public void testFromMap_NoUrl_Failure() { ); } - public void testFromMap_EmptyUrl_Failure() { + public void testFromMap_EmptyUrl_ThrowsException() { var thrownException = expectThrows( ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( @@ -146,7 +146,7 @@ public void testFromMap_EmptyUrl_Failure() { ); } - public void testFromMap_InvalidUrl_Failure() { + public void testFromMap_InvalidUrl_ThrowsException() { var thrownException = expectThrows( ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( @@ -201,7 +201,7 @@ public void testFromMap_NoSimilarity_Success() { ); } - public void testFromMap_InvalidSimilarity_Failure() { + public void testFromMap_InvalidSimilarity_ThrowsException() { var thrownException = expectThrows( ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( @@ -286,7 +286,7 @@ public void testFromMap_Persistent_WithDimensions_SetByUserFalse_Persistent_Succ ); } - public void testFromMap_WithDimensions_SetByUserNull_Persistent_Success() { + public void testFromMap_WithDimensions_SetByUserNull_Persistent_ThrowsException() { var thrownException = expectThrows( ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( @@ -369,7 +369,7 @@ public void testFromMap_WithDimensions_SetByUserNull_Request_Success() { ); } - public void testFromMap_WithDimensions_SetByUserTrue_Request_Failure() { + public void testFromMap_WithDimensions_SetByUserTrue_Request_ThrowsException() { var thrownException = expectThrows( ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( @@ -392,7 +392,7 @@ public void testFromMap_WithDimensions_SetByUserTrue_Request_Failure() { ); } - public void testFromMap_ZeroDimensions_Failure() { + public void testFromMap_ZeroDimensions_ThrowsException() { var thrownException = expectThrows( ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( @@ -414,7 +414,7 @@ public void testFromMap_ZeroDimensions_Failure() { ); } - public void testFromMap_NegativeDimensions_Failure() { + public void testFromMap_NegativeDimensions_ThrowsException() { var thrownException = expectThrows( ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( @@ -466,7 +466,7 @@ public void testFromMap_NoInputTokens_Success() { ); } - public void testFromMap_ZeroInputTokens_Failure() { + public void testFromMap_ZeroInputTokens_ThrowsException() { var thrownException = expectThrows( ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( @@ -488,7 +488,7 @@ public void testFromMap_ZeroInputTokens_Failure() { ); } - public void testFromMap_NegativeInputTokens_Failure() { + public void testFromMap_NegativeInputTokens_ThrowsException() { var thrownException = expectThrows( ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( @@ -606,9 +606,9 @@ protected OpenShiftAiEmbeddingsServiceSettings mutateInstance(OpenShiftAiEmbeddi switch (between(0, 6)) { case 0 -> modelId = randomValueOtherThan(modelId, () -> randomAlphaOfLengthOrNull(8)); case 1 -> uri = randomValueOtherThan(uri, () -> ServiceUtils.createUri(randomAlphaOfLength(15))); - case 2 -> dimensions = randomValueOtherThan(dimensions, () -> randomIntBetween(32, 256)); - case 3 -> similarity = randomValueOtherThan(similarity, () -> randomFrom(SimilarityMeasure.values())); - case 4 -> maxInputTokens = randomValueOtherThan(maxInputTokens, () -> randomIntBetween(128, 256)); + case 2 -> dimensions = randomValueOtherThan(dimensions, () -> randomBoolean() ? randomIntBetween(32, 256) : null); + case 3 -> similarity = randomValueOtherThan(similarity, () -> randomBoolean() ? randomFrom(SimilarityMeasure.values()) : null); + case 4 -> maxInputTokens = randomValueOtherThan(maxInputTokens, () -> randomBoolean() ? randomIntBetween(128, 256) : null); case 5 -> rateLimitSettings = randomValueOtherThan(rateLimitSettings, RateLimitSettingsTests::createRandom); case 6 -> dimensionsSetByUser = randomValueOtherThan(dimensionsSetByUser, ESTestCase::randomBoolean); default -> throw new AssertionError("Illegal randomisation branch"); @@ -627,9 +627,10 @@ protected OpenShiftAiEmbeddingsServiceSettings mutateInstance(OpenShiftAiEmbeddi private static OpenShiftAiEmbeddingsServiceSettings createRandom() { var modelId = randomAlphaOfLength(8); var url = randomAlphaOfLength(15); - var similarityMeasure = randomFrom(SimilarityMeasure.values()); - var dimensions = randomIntBetween(32, 256); - var maxInputTokens = randomIntBetween(128, 256); + var similarityMeasure = randomBoolean() ? randomFrom(SimilarityMeasure.values()) : null; + var dimensions = randomBoolean() ? randomIntBetween(32, 256) : null; + var maxInputTokens = randomBoolean() ? randomIntBetween(128, 256) : null; + boolean dimensionsSetByUser = randomBoolean(); return new OpenShiftAiEmbeddingsServiceSettings( modelId, url, @@ -637,7 +638,7 @@ private static OpenShiftAiEmbeddingsServiceSettings createRandom() { similarityMeasure, maxInputTokens, RateLimitSettingsTests.createRandom(), - randomBoolean() + dimensionsSetByUser ); } From c46a1e160b0c0c456892acc1c2ab9928e133afcc Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 31 Oct 2025 12:20:07 +0200 Subject: [PATCH 27/70] Fix OpenShiftAiServiceTests, update transport version --- .../referable/ml_inference_openshift_ai_added.csv | 2 +- server/src/main/resources/transport/upper_bounds/9.3.csv | 2 +- .../services/openshiftai/OpenShiftAiServiceTests.java | 7 ++++++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/server/src/main/resources/transport/definitions/referable/ml_inference_openshift_ai_added.csv b/server/src/main/resources/transport/definitions/referable/ml_inference_openshift_ai_added.csv index 44b76d9df26ec..be5ad4bf02772 100644 --- a/server/src/main/resources/transport/definitions/referable/ml_inference_openshift_ai_added.csv +++ b/server/src/main/resources/transport/definitions/referable/ml_inference_openshift_ai_added.csv @@ -1 +1 @@ -9206000 +9210000 diff --git a/server/src/main/resources/transport/upper_bounds/9.3.csv b/server/src/main/resources/transport/upper_bounds/9.3.csv index 49f1bcb0a8eb0..160b644d5a99a 100644 --- a/server/src/main/resources/transport/upper_bounds/9.3.csv +++ b/server/src/main/resources/transport/upper_bounds/9.3.csv @@ -1 +1 @@ -ilm_searchable_snapshot_opt_out_clone,9209000 +ml_inference_openshift_ai_added,9210000 diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java index 0adf54a8d38f4..e732dc2c329be 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java @@ -80,6 +80,7 @@ import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION; import static org.elasticsearch.inference.TaskType.COMPLETION; +import static org.elasticsearch.inference.TaskType.RERANK; import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; @@ -112,7 +113,11 @@ public OpenShiftAiServiceTests() { public static TestConfiguration createTestConfiguration() { return new TestConfiguration.Builder( - new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, EnumSet.of(TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION)) { + new CommonConfig( + TaskType.TEXT_EMBEDDING, + TaskType.SPARSE_EMBEDDING, + EnumSet.of(TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION, RERANK) + ) { @Override protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { From 53e8118b28f0a6d32621b6fb9922828604339b6d Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 31 Oct 2025 17:27:44 +0200 Subject: [PATCH 28/70] Enhance OpenShift AI model tests to support chunking settings and improve request parsing --- .../openshiftai/OpenShiftAiServiceTests.java | 52 +++++++++---------- .../OpenShiftAiEmbeddingsModelTests.java | 8 +-- .../OpenShiftAiEmbeddingsRequestTests.java | 3 +- 3 files changed, 32 insertions(+), 31 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java index e732dc2c329be..625358f9cae2b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java @@ -42,7 +42,6 @@ import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; @@ -84,6 +83,7 @@ import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; @@ -243,7 +243,7 @@ private static OpenShiftAiEmbeddingsModel createInternalEmbeddingModel(@Nullable new RateLimitSettings(10_000), true ), - ChunkingSettingsTests.createRandomChunkingSettings(), + createRandomChunkingSettings(), new DefaultSecretSettings(new SecureString("secret".toCharArray())) ); } @@ -300,17 +300,13 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsN service.parseRequestConfig( "id", TaskType.TEXT_EMBEDDING, - getRequestConfigMap( - getServiceSettingsMap("model", "url"), - createRandomChunkingSettingsMap(), - getSecretSettingsMap("secret") - ), + getRequestConfigMap(getServiceSettingsMap("model", "url"), getSecretSettingsMap("secret")), modelVerificationActionListener ); } } - public void testParseRequestConfig_Success_WithoutModelId() throws IOException { + public void testParseRequestConfig_WithoutModelId_Success() throws IOException { var url = "url"; var secret = "secret"; @@ -335,27 +331,21 @@ public void testParseRequestConfig_Success_WithoutModelId() throws IOException { } } - public void testParseRequestConfig_ThrowsException_WithoutUrl() throws IOException { + public void testParseRequestConfig_WithoutUrl_ThrowsException() throws IOException { var model = "model"; var secret = "secret"; try (var service = createService()) { - ActionListener modelVerificationListener = ActionListener.wrap(m -> { - assertThat(m, instanceOf(OpenShiftAiChatCompletionModel.class)); - - var chatCompletionModel = (OpenShiftAiChatCompletionModel) m; - - assertThat(chatCompletionModel.getServiceSettings().modelId(), is(model)); - assertNull(chatCompletionModel.getServiceSettings().modelId()); - assertThat(chatCompletionModel.getSecretSettings().apiKey().toString(), is("secret")); - - }, exception -> { - assertThat(exception, instanceOf(ValidationException.class)); - assertThat( - exception.getMessage(), - is("Validation Failed: 1: [service_settings] does not contain the required setting [url];") - ); - }); + ActionListener modelVerificationListener = ActionListener.wrap( + m -> fail("Expected exception, but got model: " + m), + exception -> { + assertThat(exception, instanceOf(ValidationException.class)); + assertThat( + exception.getMessage(), + is("Validation Failed: 1: [service_settings] does not contain the required setting [url];") + ); + } + ); service.parseRequestConfig( "id", @@ -629,13 +619,21 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSe } public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { - var model = OpenShiftAiEmbeddingsModelTests.createModel(getUrl(webServer), "api_key", "model"); + var model = OpenShiftAiEmbeddingsModelTests.createModel(getUrl(webServer), "api_key", "model", 1234, false, 1536, null); testChunkedInfer(model); } public void testChunkedInfer_ChunkingSettingsSet() throws IOException { - var model = OpenShiftAiEmbeddingsModelTests.createModel(getUrl(webServer), "api_key", "model"); + var model = OpenShiftAiEmbeddingsModelTests.createModel( + getUrl(webServer), + "api_key", + "model", + 1234, + false, + 1536, + createRandomChunkingSettings() + ); testChunkedInfer(model); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModelTests.java index f815f31848ebf..5a78b35c86231 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModelTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -21,7 +22,7 @@ public static OpenShiftAiEmbeddingsModel createModel(String url, String apiKey, } public static OpenShiftAiEmbeddingsModel createModel(String url, String apiKey, @Nullable String modelId, int maxInputTokens) { - return createModel(url, apiKey, modelId, maxInputTokens, false, 1536); + return createModel(url, apiKey, modelId, maxInputTokens, false, 1536, null); } public static OpenShiftAiEmbeddingsModel createModel( @@ -30,7 +31,8 @@ public static OpenShiftAiEmbeddingsModel createModel( @Nullable String modelId, @Nullable Integer maxInputTokens, @Nullable Boolean dimensionsSetByUser, - @Nullable Integer dimensions + @Nullable Integer dimensions, + @Nullable ChunkingSettings chunkingSettings ) { return new OpenShiftAiEmbeddingsModel( "inferenceEntityId", @@ -45,7 +47,7 @@ public static OpenShiftAiEmbeddingsModel createModel( null, dimensionsSetByUser ), - null, + chunkingSettings, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java index 8ec14b2ef0fd5..d01885bde188e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java @@ -109,7 +109,8 @@ private static OpenShiftAiEmbeddingsRequest createRequest(Integer dimensions, Bo modelId, dimensions, dimensionsSetByUser, - dimensions + dimensions, + null ); return new OpenShiftAiEmbeddingsRequest( TruncatorTests.createTruncator(), From d68a636dc991345b4caf05430a4b924e590fb0a6 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Wed, 5 Nov 2025 11:38:12 +0200 Subject: [PATCH 29/70] Fix PR comments --- .../services/openshiftai/OpenShiftAiModel.java | 2 +- .../services/openshiftai/OpenShiftAiService.java | 12 +++++------- .../completion/OpenShiftAiChatCompletionModel.java | 2 +- .../OpenShiftAiChatCompletionServiceSettings.java | 2 +- .../OpenShiftAiCompletionResponseHandler.java | 2 +- .../embeddings/OpenShiftAiEmbeddingsModel.java | 4 ++-- .../rerank/OpenShiftAiRerankServiceSettings.java | 2 +- 7 files changed, 12 insertions(+), 14 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java index 1a5fb0f803e81..be3cb00f2459d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java @@ -18,7 +18,7 @@ import java.util.Objects; /** - * Represents an OpenShift AI modelId that can be used for inference tasks. + * Represents an OpenShift AI model that can be used for inference tasks. * This class extends RateLimitGroupingModel to handle rate limiting based on modelId and API key. */ public abstract class OpenShiftAiModel extends RateLimitGroupingModel { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java index 7d4a82736a9ac..d0dcc61fae6b6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java @@ -57,7 +57,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; @@ -69,16 +68,15 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; /** - * OpenShiftAiService is an implementation of the SenderService that handles inference tasks + * OpenShiftAiService is an implementation of the {@link SenderService} and {@link RerankingInferenceService} that handles inference tasks * using models deployed to OpenShift AI environment. - * The service uses OpenShiftAiActionCreator to create actions for executing inference requests. + * The service uses {@link OpenShiftAiActionCreator} to create actions for executing inference requests. */ public class OpenShiftAiService extends SenderService implements RerankingInferenceService { public static final String NAME = "openshift_ai"; /** - * The optimal batch size depends on the hardware the model is deployed on. - * For OpenShift AI use a conservatively small max batch size as it is - * unknown how the model is deployed + * The optimal batch size depends on the model deployed in OpenShift AI. + * For OpenShift AI use a conservatively small max batch size as it is unknown what model is deployed. */ static final int EMBEDDING_MAX_BATCH_SIZE = 20; private static final String SERVICE_NAME = "OpenShift AI"; @@ -115,7 +113,7 @@ protected void doInfer( ) { var actionCreator = new OpenShiftAiActionCreator(getSender(), getServiceComponents()); - switch (Objects.requireNonNull(model)) { + switch (model) { case OpenShiftAiChatCompletionModel chatCompletionModel -> chatCompletionModel.accept(actionCreator) .execute(inputs, timeout, listener); case OpenShiftAiEmbeddingsModel embeddingsModel -> embeddingsModel.accept(actionCreator).execute(inputs, timeout, listener); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java index d029ab2e1bcad..24537b5c812e1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java @@ -28,7 +28,7 @@ public class OpenShiftAiChatCompletionModel extends OpenShiftAiModel { /** - * Constructor for creating a OpenShiftAiChatCompletionModel with specified parameters. + * Constructor for creating an OpenShiftAiChatCompletionModel with specified parameters. * @param inferenceEntityId the unique identifier for the inference entity * @param taskType the type of task this model is designed for * @param service the name of the inference service diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java index 85f8d775d8aba..2c49bc7bfc161 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java @@ -28,7 +28,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; /** - * Represents the settings for a OpenShift AI chat completion service. + * Represents the settings for an OpenShift AI chat completion service. * This class encapsulates the model ID, URI, and rate limit settings for the OpenShift AI chat completion service. */ public class OpenShiftAiChatCompletionServiceSettings extends OpenShiftAiServiceSettings { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiCompletionResponseHandler.java index 65522fc4495bf..2d32d146aec43 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiCompletionResponseHandler.java @@ -17,7 +17,7 @@ public class OpenShiftAiCompletionResponseHandler extends OpenAiChatCompletionResponseHandler { /** - * Constructs a OpenShiftAiCompletionResponseHandler with the specified request type and response parser. + * Constructs an OpenShiftAiCompletionResponseHandler with the specified request type and response parser. * * @param requestType The type of request being handled (e.g., "Openshift AI completions"). * @param parseFunction The function to parse the response. diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java index 6311bd82a068e..0bc8fdbda6b8d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java @@ -22,7 +22,7 @@ import java.util.Map; /** - * Represents a OpenShift AI embeddings model for inference. + * Represents an OpenShift AI embeddings model for inference. * This class extends the OpenShiftAiModel and provides specific configurations and settings for embeddings tasks. */ public class OpenShiftAiEmbeddingsModel extends OpenShiftAiModel { @@ -51,7 +51,7 @@ public OpenShiftAiEmbeddingsModel(OpenShiftAiEmbeddingsModel model, OpenShiftAiE } /** - * Constructor for creating a OpenShiftAiEmbeddingsModel with specified parameters. + * Constructor for creating an OpenShiftAiEmbeddingsModel with specified parameters. * * @param inferenceEntityId the unique identifier for the inference entity * @param taskType the type of task this model is designed for diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java index bd58a67ab299c..6e64e9f8babf6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java @@ -28,7 +28,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; /** - * Represents the settings for a OpenShift AI chat rerank service. + * Represents the settings for an OpenShift AI chat rerank service. * This class encapsulates the model ID, URI, and rate limit settings for the OpenShift AI chat rerank service. */ public class OpenShiftAiRerankServiceSettings extends OpenShiftAiServiceSettings { From 7f404590dbf26e3365292340e91e44f751ed903d Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Wed, 5 Nov 2025 14:23:50 +0200 Subject: [PATCH 30/70] Refactor accept methods in OpenShift AI models to include task settings parameter --- .../services/openshiftai/OpenShiftAiModel.java | 11 +++++++++++ .../services/openshiftai/OpenShiftAiService.java | 16 +++++++--------- .../OpenShiftAiChatCompletionModel.java | 10 +++------- .../embeddings/OpenShiftAiEmbeddingsModel.java | 10 +++------- .../rerank/OpenShiftAiRerankModel.java | 7 +------ 5 files changed, 25 insertions(+), 29 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java index be3cb00f2459d..80f5ea0bdcd6b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiModel.java @@ -11,10 +11,13 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; +import org.elasticsearch.xpack.inference.services.openshiftai.action.OpenShiftAiActionVisitor; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import java.util.Map; import java.util.Objects; /** @@ -55,4 +58,12 @@ public DefaultSecretSettings getSecretSettings() { return (DefaultSecretSettings) super.getSecretSettings(); } + /** + * Accepts a visitor to create an executable action for this OpenShift AI model. + * + * @param creator the visitor that creates the executable action + * @param taskSettings the task settings to be used for the executable action + * @return an {@link ExecutableAction} specific to this OpenShift AI model + */ + public abstract ExecutableAction accept(OpenShiftAiActionVisitor creator, Map taskSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java index d0dcc61fae6b6..72574ecb180ff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java @@ -111,15 +111,13 @@ protected void doInfer( TimeValue timeout, ActionListener listener ) { - var actionCreator = new OpenShiftAiActionCreator(getSender(), getServiceComponents()); - - switch (model) { - case OpenShiftAiChatCompletionModel chatCompletionModel -> chatCompletionModel.accept(actionCreator) - .execute(inputs, timeout, listener); - case OpenShiftAiEmbeddingsModel embeddingsModel -> embeddingsModel.accept(actionCreator).execute(inputs, timeout, listener); - case OpenShiftAiRerankModel rerankModel -> rerankModel.accept(actionCreator, taskSettings).execute(inputs, timeout, listener); - default -> listener.onFailure(createInvalidModelException(model)); + if (model instanceof OpenShiftAiModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; } + var openShiftAiModel = (OpenShiftAiModel) model; + var actionCreator = new OpenShiftAiActionCreator(getSender(), getServiceComponents()); + openShiftAiModel.accept(actionCreator, taskSettings).execute(inputs, timeout, listener); } @Override @@ -176,7 +174,7 @@ protected void doChunkedInfer( ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { - var action = openShiftAiEmbeddingsModel.accept(actionCreator); + var action = openShiftAiEmbeddingsModel.accept(actionCreator, taskSettings); action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java index 24537b5c812e1..5af38dd653a87 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java @@ -114,13 +114,9 @@ public OpenShiftAiChatCompletionServiceSettings getServiceSettings() { return (OpenShiftAiChatCompletionServiceSettings) super.getServiceSettings(); } - /** - * Accepts a visitor that creates an executable action for this OpenShift AI chat completion. - * - * @param creator the visitor that creates the executable action - * @return an ExecutableAction representing this model - */ - public ExecutableAction accept(OpenShiftAiActionVisitor creator) { + @Override + public ExecutableAction accept(OpenShiftAiActionVisitor creator, Map taskSettings) { + // Chat completion models do not have task settings, so we ignore the taskSettings parameter. return creator.create(this); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java index 0bc8fdbda6b8d..74ba91c9424d0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java @@ -79,13 +79,9 @@ public OpenShiftAiEmbeddingsServiceSettings getServiceSettings() { return (OpenShiftAiEmbeddingsServiceSettings) super.getServiceSettings(); } - /** - * Accepts a visitor to create an executable action for this OpenShift AI embeddings model. - * - * @param creator the visitor that creates the executable action - * @return an ExecutableAction representing the OpenShift AI embeddings model - */ - public ExecutableAction accept(OpenShiftAiActionVisitor creator) { + @Override + public ExecutableAction accept(OpenShiftAiActionVisitor creator, Map taskSettings) { + // Embeddings models do not have task settings, so we ignore the taskSettings parameter. return creator.create(this); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModel.java index f5195eee72d1a..5fb9c05042f89 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModel.java @@ -73,12 +73,7 @@ public OpenShiftAiRerankTaskSettings getTaskSettings() { return (OpenShiftAiRerankTaskSettings) super.getTaskSettings(); } - /** - * Accepts a visitor to create an executable action. The returned action will not return documents in the response. - * @param visitor Interface for creating {@link ExecutableAction} instances for IBM watsonx models. - * @param taskSettings Settings in the request to override the model's defaults - * @return the rerank action - */ + @Override public ExecutableAction accept(OpenShiftAiActionVisitor visitor, Map taskSettings) { return visitor.create(this, taskSettings); } From ffd491d214c3b8b7ec45b7b4ac6f5c293877b12f Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Wed, 5 Nov 2025 17:47:38 +0200 Subject: [PATCH 31/70] Refactor OpenShift AI model to use model ID directly instead of UnifiedCompletionRequest --- .../openshiftai/OpenShiftAiService.java | 2 +- .../OpenShiftAiChatCompletionModel.java | 12 +-- .../OpenShiftAiChatCompletionModelTests.java | 80 +------------------ 3 files changed, 11 insertions(+), 83 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java index 72574ecb180ff..0e4e1ed37929f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java @@ -138,7 +138,7 @@ protected void doUnifiedCompletionInfer( } OpenShiftAiChatCompletionModel chatCompletionModel = (OpenShiftAiChatCompletionModel) model; - var overriddenModel = OpenShiftAiChatCompletionModel.of(chatCompletionModel, inputs.getRequest()); + var overriddenModel = OpenShiftAiChatCompletionModel.of(chatCompletionModel, inputs.getRequest().model()); var manager = new GenericRequestManager<>( getServiceComponents().threadPool(), overriddenModel, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java index 5af38dd653a87..a2e8277960dd3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java @@ -12,7 +12,6 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiModel; @@ -20,6 +19,7 @@ import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import java.util.Map; +import java.util.Objects; /** * Represents an OpenShift AI chat completion model. @@ -79,18 +79,18 @@ public OpenShiftAiChatCompletionModel( * If the request does not specify a model ID, the original model is returned. * * @param model the original OpenShiftAiChatCompletionModel - * @param request the UnifiedCompletionRequest containing potential overrides + * @param modelId the model ID specified in the request, which may override the original model's ID * @return a new OpenShiftAiChatCompletionModel with overridden settings or the original model ID if no overrides are specified */ - public static OpenShiftAiChatCompletionModel of(OpenShiftAiChatCompletionModel model, UnifiedCompletionRequest request) { - if (request.model() == null) { - // If no model ID is specified in the request, return the original model + public static OpenShiftAiChatCompletionModel of(OpenShiftAiChatCompletionModel model, String modelId) { + if (modelId == null || Objects.equals(model.getServiceSettings().modelId(), modelId)) { + // If no model ID is specified in the request, or if it matches the original model's ID, return the original model. return model; } var originalModelServiceSettings = model.getServiceSettings(); var overriddenServiceSettings = new OpenShiftAiChatCompletionServiceSettings( - request.model(), + modelId, originalModelServiceSettings.uri(), originalModelServiceSettings.rateLimitSettings() ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java index fb5c0189072c8..0e580f1371e65 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java @@ -9,12 +9,9 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; -import java.util.List; - import static org.hamcrest.Matchers.is; public class OpenShiftAiChatCompletionModelTests extends ESTestCase { @@ -38,90 +35,21 @@ public static OpenShiftAiChatCompletionModel createModelWithTaskType(String url, public void testOverrideWith_UnifiedCompletionRequest_KeepsSameModelId() { var model = createCompletionModel("url", "api_key", "model_name"); - var request = new UnifiedCompletionRequest( - List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), - "model_name", // same model - null, - null, - null, - null, - null, - null - ); - - var overriddenModel = OpenShiftAiChatCompletionModel.of(model, request); + var overriddenModel = OpenShiftAiChatCompletionModel.of(model, "model_name"); - assertThat(overriddenModel, is(model)); + assertThat(overriddenModel.getServiceSettings().modelId(), is("model_name")); } public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() { var model = createCompletionModel("url", "api_key", "model_name"); - var request = new UnifiedCompletionRequest( - List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), - "different_model", // overriding model - null, - null, - null, - null, - null, - null - ); - - var overriddenModel = OpenShiftAiChatCompletionModel.of(model, request); + var overriddenModel = OpenShiftAiChatCompletionModel.of(model, "different_model"); assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); } - public void testOverrideWith_UnifiedCompletionRequest_OverridesNullModelId() { - var model = createCompletionModel("url", "api_key", null); - var request = new UnifiedCompletionRequest( - List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), - "different_model", // overriding model - null, - null, - null, - null, - null, - null - ); - - var overriddenModel = OpenShiftAiChatCompletionModel.of(model, request); - - assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); - } - - public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() { - var model = createCompletionModel("url", "api_key", null); - var request = new UnifiedCompletionRequest( - List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), - null, // not overriding model - null, - null, - null, - null, - null, - null - ); - - var overriddenModel = OpenShiftAiChatCompletionModel.of(model, request); - - assertNull(overriddenModel.getServiceSettings().modelId()); - } - public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { var model = createCompletionModel("url", "api_key", "model_name"); - var request = new UnifiedCompletionRequest( - List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), - null, // not overriding model - null, - null, - null, - null, - null, - null - ); - - var overriddenModel = OpenShiftAiChatCompletionModel.of(model, request); + var overriddenModel = OpenShiftAiChatCompletionModel.of(model, null); assertThat(overriddenModel.getServiceSettings().modelId(), is("model_name")); } From 4fc255676f91bb6bcbd0103cecb8700288285226 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Wed, 5 Nov 2025 17:52:05 +0200 Subject: [PATCH 32/70] Update JinaAI rerank model token limit in rerankerWindowSize method --- .../inference/services/openshiftai/OpenShiftAiService.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java index 0e4e1ed37929f..d0b4f71017736 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java @@ -342,7 +342,7 @@ private OpenShiftAiModel createModelFromPersistent( @Override public int rerankerWindowSize(String modelId) { // OpenShift AI uses Cohere and JinaAI rerank protocols for reranking - // JinaAI rerank model has 8000 tokens limit length https://jina.ai/models/jina-reranker-v2-base-multilingual + // JinaAI rerank model has 131K tokens limit https://jina.ai/models/jina-reranker-v3/ // Cohere rerank model truncates at 4096 tokens https://docs.cohere.com/reference/rerank // We choose a conservative limit based on these two models // Using 1 token = 0.75 words as a rough estimate, we get 3072 words allowing for some headroom, we set the window size below 3072 From dbc1c56a140e49683811fb2e121ed9a74e2ddf50 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Thu, 6 Nov 2025 11:36:21 +0200 Subject: [PATCH 33/70] Enhance documentation for OpenShift AI models and add task settings handling in rerank model --- .../OpenShiftAiChatCompletionModel.java | 2 +- .../OpenShiftAiEmbeddingsModel.java | 2 +- .../rerank/OpenShiftAiRerankModel.java | 14 ++++ .../rerank/OpenShiftAiRerankModelTests.java | 69 +++++++++++++++++++ .../OpenShiftAiRerankTaskSettingsTests.java | 58 ++++++++++++++++ 5 files changed, 143 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java index a2e8277960dd3..181c25c5c04ec 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java @@ -23,7 +23,7 @@ /** * Represents an OpenShift AI chat completion model. - * This class extends the OpenShiftAiModel and provides specific configurations for chat completion tasks. + * This class extends the {@link OpenShiftAiModel} and provides specific configurations for chat completion tasks. */ public class OpenShiftAiChatCompletionModel extends OpenShiftAiModel { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java index 74ba91c9424d0..796cd60a932a5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java @@ -23,7 +23,7 @@ /** * Represents an OpenShift AI embeddings model for inference. - * This class extends the OpenShiftAiModel and provides specific configurations and settings for embeddings tasks. + * This class extends the {@link OpenShiftAiModel} and provides specific configurations and settings for embeddings tasks. */ public class OpenShiftAiEmbeddingsModel extends OpenShiftAiModel { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModel.java index 5fb9c05042f89..94d9047e5d39a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModel.java @@ -19,9 +19,23 @@ import java.util.Map; +/** + * Represents an OpenShift AI rerank model. + * This class extends the {@link OpenShiftAiModel} and provides specific configurations for rerank tasks. + */ public class OpenShiftAiRerankModel extends OpenShiftAiModel { + + /** + * Creates a new {@link OpenShiftAiRerankModel} with updated task settings if they differ from the existing ones. + * @param model the existing OpenShift AI rerank model + * @param taskSettings the new task settings to apply + * @return a new {@link OpenShiftAiRerankModel} with updated task settings, or the original model if settings are unchanged + */ public static OpenShiftAiRerankModel of(OpenShiftAiRerankModel model, Map taskSettings) { var requestTaskSettings = OpenShiftAiRerankTaskSettings.fromMap(taskSettings); + if (requestTaskSettings.isEmpty() || requestTaskSettings.equals(model.getTaskSettings())) { + return model; + } return new OpenShiftAiRerankModel(model, OpenShiftAiRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java index c2f692b44e7c7..7e46a5124ad2f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java @@ -13,6 +13,13 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS; +import static org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankTaskSettings.TOP_N; +import static org.hamcrest.Matchers.is; + public class OpenShiftAiRerankModelTests extends ESTestCase { public static OpenShiftAiRerankModel createModel(String url, String apiKey, @Nullable String modelId) { @@ -35,4 +42,66 @@ public static OpenShiftAiRerankModel createModel( new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } + + public void testOverrideWith_SameParams_KeepsSameModel() { + testOverrideWith_KeepsSameModel(buildTaskSettingsMap(2, true)); + } + + public void testOverrideWith_EmptyParams_KeepsSameModel() { + testOverrideWith_KeepsSameModel(buildTaskSettingsMap(null, null)); + } + + private static void testOverrideWith_KeepsSameModel(Map taskSettings) { + var model = createModel("url", "api_key", "model_name", 2, true); + var overriddenModel = OpenShiftAiRerankModel.of(model, taskSettings); + + assertThat(overriddenModel.getTaskSettings().getTopN(), is(2)); + assertThat(overriddenModel.getTaskSettings().getReturnDocuments(), is(true)); + } + + public void testOverrideWith_DifferentParams_OverridesAllTaskSettings() { + testOverrideWith_DifferentParams(buildTaskSettingsMap(4, false), 4, false); + } + + public void testOverrideWith_DifferentParams_OverridesOnlyReturnDocuments() { + testOverrideWith_DifferentParams(buildTaskSettingsMap(null, false), 2, false); + } + + public void testOverrideWith_DifferentParams_OverridesOnlyTopN() { + testOverrideWith_DifferentParams(buildTaskSettingsMap(4, null), 4, true); + } + + public void testOverrideWith_DifferentParams_OverridesNullValues() { + var model = createModel("url", "api_key", "model_name", null, null); + var overriddenModel = OpenShiftAiRerankModel.of(model, buildTaskSettingsMap(4, false)); + + assertThat(overriddenModel.getTaskSettings().getTopN(), is(4)); + assertThat(overriddenModel.getTaskSettings().getReturnDocuments(), is(false)); + } + + private static void testOverrideWith_DifferentParams( + Map taskSettings, + int expectedTopN, + boolean expectedReturnDocuments + ) { + var model = createModel("url", "api_key", "model_name", 2, true); + var overriddenModel = OpenShiftAiRerankModel.of(model, taskSettings); + + assertThat(overriddenModel.getTaskSettings().getTopN(), is(expectedTopN)); + assertThat(overriddenModel.getTaskSettings().getReturnDocuments(), is(expectedReturnDocuments)); + } + + private static Map buildTaskSettingsMap(@Nullable Integer topN, @Nullable Boolean returnDocuments) { + final var map = new HashMap(); + + if (returnDocuments != null) { + map.put(RETURN_DOCUMENTS, returnDocuments); + } + + if (topN != null) { + map.put(TOP_N, topN); + } + + return map; + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java index 2cff47df2588d..de12ea465587d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java @@ -8,15 +8,20 @@ package org.elasticsearch.xpack.inference.services.openshiftai.rerank; import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import java.io.IOException; import java.util.HashMap; import java.util.Map; +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; import static org.hamcrest.Matchers.containsString; public class OpenShiftAiRerankTaskSettingsTests extends AbstractBWCWireSerializationTestCase { @@ -97,6 +102,59 @@ public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings assertEquals(7, updatedSettings.getTopN().intValue()); } + public void testToXContent_WritesAllValues() throws IOException { + Integer topN = 2; + Boolean doReturnDocuments = true; + + testToXContent(topN, doReturnDocuments, """ + { + "top_n":2, + "return_documents":true + } + """); + } + + public void testToXContent_EmptyValues() throws IOException { + Integer topN = null; + Boolean doReturnDocuments = null; + + testToXContent(topN, doReturnDocuments, """ + {} + """); + } + + public void testToXContent_OnlyTopN() throws IOException { + Integer topN = 2; + Boolean doReturnDocuments = null; + + testToXContent(topN, doReturnDocuments, """ + { + "top_n":2 + } + """); + } + + public void testToXContent_OnlyReturnDocuments() throws IOException { + Integer topN = null; + Boolean doReturnDocuments = true; + + testToXContent(topN, doReturnDocuments, """ + { + "return_documents":true + } + """); + } + + private static void testToXContent(Integer topN, Boolean doReturnDocuments, String expectedString) throws IOException { + var taskSettings = new OpenShiftAiRerankTaskSettings(topN, doReturnDocuments); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + taskSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(expectedString)); + } + @Override protected Writeable.Reader instanceReader() { return OpenShiftAiRerankTaskSettings::new; From e9fbce778cb5308736a86cf0eb2b5ebdcd1f2e28 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Thu, 6 Nov 2025 14:23:43 +0200 Subject: [PATCH 34/70] Refactor OpenShift AI Rerank handler to use JinaAIResponseHandler, use overridden model --- .../action/OpenShiftAiActionCreator.java | 15 +++++++-------- .../action/OpenShiftAiActionCreatorTests.java | 3 +-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreator.java index 3913ced6f1a2b..23a6c56efacf9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreator.java @@ -19,8 +19,8 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ServiceComponents; -import org.elasticsearch.xpack.inference.services.cohere.CohereResponseHandler; -import org.elasticsearch.xpack.inference.services.cohere.response.CohereRankedResponseEntity; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIResponseHandler; +import org.elasticsearch.xpack.inference.services.jinaai.response.JinaAIRerankResponseEntity; import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.openai.response.OpenAiEmbeddingsResponseEntity; import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModel; @@ -57,11 +57,10 @@ public class OpenShiftAiActionCreator implements OpenShiftAiActionVisitor { "OpenShift AI completion", OpenAiChatCompletionResponseEntity::fromResponse ); - // OpenShift AI Rerank task uses the same response format as Cohere, therefore we can reuse the CohereResponseHandler - private static final ResponseHandler RERANK_HANDLER = new CohereResponseHandler( + // OpenShift AI Rerank task uses the same response format as JinaAI, therefore we can reuse the JinaAIResponseHandler + private static final ResponseHandler RERANK_HANDLER = new JinaAIResponseHandler( "OpenShift AI rerank", - (request, response) -> CohereRankedResponseEntity.fromResponse(response), - false + (request, response) -> JinaAIRerankResponseEntity.fromResponse(response) ); private final Sender sender; @@ -122,11 +121,11 @@ public ExecutableAction create(OpenShiftAiRerankModel model, Map inputs.getChunks(), inputs.getReturnDocuments(), inputs.getTopN(), - model + overriddenModel ), QueryAndDocsInputs.class ); - var errorMessage = buildErrorMessage(TaskType.RERANK, model.getInferenceEntityId()); + var errorMessage = buildErrorMessage(TaskType.RERANK, overriddenModel.getInferenceEntityId()); return new SenderExecutableAction(sender, manager, errorMessage); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index fe04d818dcc02..051aef57814c1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -735,8 +735,7 @@ public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() t var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(thrownException.getMessage(), is(""" - Failed to send OpenShift AI rerank request from inference entity id [inferenceEntityId]. Cause: Failed to find required\ - field [results] in Cohere rerank response""")); + Failed to send OpenShift AI rerank request from inference entity id [inferenceEntityId]. Cause: Required [results]""")); } assertRerankActionCreator(documents); } From aeec397e50e14025e538ade9056083b032abec0d Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Thu, 6 Nov 2025 14:49:57 +0200 Subject: [PATCH 35/70] Fix parameter documentation for modelId in OpenShift AI service settings classes --- .../services/openshiftai/OpenShiftAiServiceSettings.java | 2 +- .../OpenShiftAiChatCompletionServiceSettings.java | 4 ++-- .../rerank/OpenShiftAiRerankServiceSettings.java | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceSettings.java index 247af8f7bc3dd..3b5a9f5f62b63 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceSettings.java @@ -51,7 +51,7 @@ protected OpenShiftAiServiceSettings(StreamInput in) throws IOException { /** * Constructs a new OpenShiftAiServiceSettings. * - * @param modelId the ID of the modelId + * @param modelId the ID of the model * @param uri the URI of the service * @param rateLimitSettings the rate limit settings for the service */ diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java index 2c49bc7bfc161..8e10daeca869a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java @@ -75,7 +75,7 @@ public OpenShiftAiChatCompletionServiceSettings(StreamInput in) throws IOExcepti /** * Constructs a new OpenShiftAiChatCompletionServiceSettings. * - * @param modelId the ID of the model ID + * @param modelId the ID of the model * @param uri the URI of the service * @param rateLimitSettings the rate limit settings for the service */ @@ -86,7 +86,7 @@ public OpenShiftAiChatCompletionServiceSettings(@Nullable String modelId, URI ur /** * Constructs a new OpenShiftAiChatCompletionServiceSettings. * - * @param modelId the ID of the model ID + * @param modelId the ID of the model * @param url the URL of the OpenShift AI service * @param rateLimitSettings the rate limit settings for the service */ diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java index 6e64e9f8babf6..c49c130da9f04 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java @@ -28,8 +28,8 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; /** - * Represents the settings for an OpenShift AI chat rerank service. - * This class encapsulates the model ID, URI, and rate limit settings for the OpenShift AI chat rerank service. + * Represents the settings for an OpenShift AI rerank service. + * This class encapsulates the model ID, URI, and rate limit settings for the OpenShift AI rerank service. */ public class OpenShiftAiRerankServiceSettings extends OpenShiftAiServiceSettings { public static final String NAME = "openshift_ai_rerank_service_settings"; @@ -87,7 +87,7 @@ public OpenShiftAiRerankServiceSettings(@Nullable String modelId, URI uri, @Null * Constructs a new OpenShiftAiRerankServiceSettings with the specified model ID and URL. * The rate limit settings will be set to the default value. * - * @param modelId the ID of the modelId + * @param modelId the ID of the model * @param url the URL of the service */ public OpenShiftAiRerankServiceSettings(@Nullable String modelId, String url, @Nullable RateLimitSettings rateLimitSettings) { From 88139369c79597fb8797c7fcd89edd918e34541b Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Thu, 6 Nov 2025 16:36:36 +0200 Subject: [PATCH 36/70] Refactor OpenShift AI service settings to streamline common settings extraction and validation --- .../OpenShiftAiServiceSettings.java | 56 +++++++++++++++++++ ...nShiftAiChatCompletionServiceSettings.java | 28 +++------- .../OpenShiftAiEmbeddingsServiceSettings.java | 26 ++------- .../OpenShiftAiRerankServiceSettings.java | 28 +++------- 4 files changed, 76 insertions(+), 62 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceSettings.java index 3b5a9f5f62b63..1c668bb212904 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceSettings.java @@ -8,21 +8,28 @@ package org.elasticsearch.xpack.inference.services.openshiftai; import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.io.IOException; import java.net.URI; +import java.util.Map; import java.util.Objects; +import java.util.function.Function; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; /** * Represents the settings for an OpenShift AI service. @@ -119,4 +126,53 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil rateLimitSettings.toXContent(builder, params); return builder; } + + /** + * Creates an instance of T from the provided map using the given factory function. + * @param map the map containing the service settings + * @param context the context for parsing configuration settings + * @param factory the factory function to create an instance of T + * @return an instance of T + * @param the type of {@link OpenShiftAiServiceSettings} to create + */ + protected static T fromMap( + Map map, + ConfigurationParseContext context, + Function factory + ) { + var validationException = new ValidationException(); + var commonServiceSettings = extractOpenShiftAiCommonServiceSettings(map, context, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return factory.apply(commonServiceSettings); + } + + /** + * Extracts common OpenShift AI service settings from the provided map. + * @param map the map containing the service settings + * @param context the context for parsing configuration settings + * @param validationException the validation exception to collect validation errors + * @return an instance of {@link OpenShiftAiCommonServiceSettings} + */ + protected static OpenShiftAiCommonServiceSettings extractOpenShiftAiCommonServiceSettings( + Map map, + ConfigurationParseContext context, + ValidationException validationException + ) { + var model = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var uri = extractUri(map, URL, validationException); + var rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + OpenShiftAiService.NAME, + context + ); + return new OpenShiftAiCommonServiceSettings(model, uri, rateLimitSettings); + } + + protected record OpenShiftAiCommonServiceSettings(String model, URI uri, RateLimitSettings rateLimitSettings) {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java index 8e10daeca869a..9ab56ea9ec208 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettings.java @@ -10,9 +10,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; -import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiService; import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -21,11 +19,7 @@ import java.util.Map; import java.util.Objects; -import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; -import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; /** * Represents the settings for an OpenShift AI chat completion service. @@ -43,23 +37,15 @@ public class OpenShiftAiChatCompletionServiceSettings extends OpenShiftAiService * @throws ValidationException if required fields are missing or invalid */ public static OpenShiftAiChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { - ValidationException validationException = new ValidationException(); - - var model = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); - var uri = extractUri(map, URL, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of( + return fromMap( map, - DEFAULT_RATE_LIMIT_SETTINGS, - validationException, - OpenShiftAiService.NAME, - context + context, + commonServiceSettings -> new OpenShiftAiChatCompletionServiceSettings( + commonServiceSettings.model(), + commonServiceSettings.uri(), + commonServiceSettings.rateLimitSettings() + ) ); - - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } - - return new OpenShiftAiChatCompletionServiceSettings(model, uri, rateLimitSettings); } /** diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java index 60fe0ce62f86d..384d11e32ebfb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java @@ -18,7 +18,6 @@ import org.elasticsearch.xpack.core.inference.InferenceUtils; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceUtils; -import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiService; import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -29,15 +28,11 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; -import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; -import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; /** * Settings for the OpenShift AI embeddings service. @@ -61,10 +56,8 @@ public class OpenShiftAiEmbeddingsServiceSettings extends OpenShiftAiServiceSett * @throws ValidationException if any required fields are missing or invalid */ public static OpenShiftAiEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { - ValidationException validationException = new ValidationException(); - - var model = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); - var uri = extractUri(map, URL, validationException); + var validationException = new ValidationException(); + var commonServiceSettings = extractOpenShiftAiCommonServiceSettings(map, context, validationException); var dimensions = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); var similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); var maxInputTokens = extractOptionalPositiveInteger( @@ -73,14 +66,7 @@ public static OpenShiftAiEmbeddingsServiceSettings fromMap(Map m ModelConfigurations.SERVICE_SETTINGS, validationException ); - var rateLimitSettings = RateLimitSettings.of( - map, - DEFAULT_RATE_LIMIT_SETTINGS, - validationException, - OpenShiftAiService.NAME, - context - ); - Boolean dimensionsSetByUser = extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException); + var dimensionsSetByUser = extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException); switch (context) { case REQUEST -> { if (dimensionsSetByUser != null) { @@ -103,12 +89,12 @@ public static OpenShiftAiEmbeddingsServiceSettings fromMap(Map m } return new OpenShiftAiEmbeddingsServiceSettings( - model, - uri, + commonServiceSettings.model(), + commonServiceSettings.uri(), dimensions, similarity, maxInputTokens, - rateLimitSettings, + commonServiceSettings.rateLimitSettings(), dimensionsSetByUser ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java index c49c130da9f04..f6cf673da33ce 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettings.java @@ -10,9 +10,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; -import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiService; import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -21,11 +19,7 @@ import java.util.Map; import java.util.Objects; -import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; -import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; /** * Represents the settings for an OpenShift AI rerank service. @@ -43,23 +37,15 @@ public class OpenShiftAiRerankServiceSettings extends OpenShiftAiServiceSettings * @throws ValidationException if required fields are missing or invalid */ public static OpenShiftAiRerankServiceSettings fromMap(Map map, ConfigurationParseContext context) { - ValidationException validationException = new ValidationException(); - - var model = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); - var uri = extractUri(map, URL, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of( + return fromMap( map, - DEFAULT_RATE_LIMIT_SETTINGS, - validationException, - OpenShiftAiService.NAME, - context + context, + commonServiceSettings -> new OpenShiftAiRerankServiceSettings( + commonServiceSettings.model(), + commonServiceSettings.uri(), + commonServiceSettings.rateLimitSettings() + ) ); - - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } - - return new OpenShiftAiRerankServiceSettings(model, uri, rateLimitSettings); } /** From 544ed20c2132c76457de0d8300e885e2a9bf754e Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Thu, 6 Nov 2025 17:05:55 +0200 Subject: [PATCH 37/70] Add check for empty or unchanged task settings in OpenShift AI rerank task settings --- .../openshiftai/rerank/OpenShiftAiRerankTaskSettings.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettings.java index b285a9c0a5072..d0c5fbbf3a0f1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettings.java @@ -70,6 +70,9 @@ public static OpenShiftAiRerankTaskSettings of( OpenShiftAiRerankTaskSettings originalSettings, OpenShiftAiRerankTaskSettings requestTaskSettings ) { + if (requestTaskSettings.isEmpty() || originalSettings.equals(requestTaskSettings)) { + return originalSettings; + } return new OpenShiftAiRerankTaskSettings( requestTaskSettings.getTopN() != null ? requestTaskSettings.getTopN() : originalSettings.getTopN(), requestTaskSettings.getReturnDocuments() != null @@ -86,6 +89,9 @@ public static OpenShiftAiRerankTaskSettings of( * @return a constructed {@link OpenShiftAiRerankTaskSettings} */ public static OpenShiftAiRerankTaskSettings of(@Nullable Integer topN, @Nullable Boolean returnDocuments) { + if (topN == null && returnDocuments == null) { + return EMPTY_SETTINGS; + } return new OpenShiftAiRerankTaskSettings(topN, returnDocuments); } From 8b5c4071cc7970c1feab8d5f6a71887f536e1da7 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Thu, 6 Nov 2025 21:44:29 +0200 Subject: [PATCH 38/70] Replace TIMEOUT constant with ESTestCase.TEST_REQUEST_TIMEOUT in OpenShift AI tests --- .../openshiftai/OpenShiftAiServiceTests.java | 11 ++++---- .../action/OpenShiftAiActionCreatorTests.java | 26 +++++++++++-------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java index 625358f9cae2b..b8da86d8a6932 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java @@ -19,7 +19,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; -import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; @@ -35,6 +34,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -102,7 +102,6 @@ import static org.mockito.Mockito.mock; public class OpenShiftAiServiceTests extends AbstractInferenceServiceTests { - private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; @@ -398,7 +397,7 @@ public void testUnifiedCompletionInfer() throws Exception { listener ); - var result = listener.actionGet(TIMEOUT); + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(XContentHelper.stripWhitespace(""" { "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26", @@ -522,7 +521,7 @@ private void testStreamError(String expectedResponse) throws Exception { listener ); - var result = listener.actionGet(TIMEOUT); + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> { e = unwrapCause(e); @@ -688,7 +687,7 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio listener ); - var results = listener.actionGet(TIMEOUT); + var results = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertThat(results, hasSize(2)); { @@ -811,7 +810,7 @@ private InferenceEventsAssertion streamCompletion() throws Exception { listener ); - return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); + return InferenceEventsAssertion.assertThat(listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT)).hasFinishedStream(); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index 051aef57814c1..02b1e52ff7a30 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -38,7 +38,6 @@ import java.io.IOException; import java.util.List; import java.util.Map; -import java.util.concurrent.TimeUnit; import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; @@ -61,7 +60,6 @@ public class OpenShiftAiActionCreatorTests extends ESTestCase { - private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; @@ -123,7 +121,7 @@ public void testCreate_OpenShiftAiEmbeddingsModel() throws IOException { listener ); - var result = listener.actionGet(TIMEOUT); + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); assertThat(webServer.requests(), hasSize(1)); @@ -187,7 +185,10 @@ public void testCreate_OpenShiftAiEmbeddingsModel_FailsFromInvalidResponseFormat ); var failureCauseMessage = "Required [data]"; - var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT) + ); assertThat( thrownException.getMessage(), is( @@ -260,7 +261,7 @@ public void testCreate_OpenShiftAiChatCompletionModel() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); - var result = listener.actionGet(TIMEOUT); + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?")))); assertThat(webServer.requests(), hasSize(1)); @@ -332,7 +333,10 @@ public void testCreate_OpenShiftAiChatCompletionModel_FailsFromInvalidResponseFo action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var failureCauseMessage = "Required [choices]"; - var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT) + ); assertThat( thrownException.getMessage(), is( @@ -420,7 +424,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC listener ); - var result = listener.actionGet(TIMEOUT); + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); assertThat(webServer.requests(), hasSize(2)); @@ -512,7 +516,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC listener ); - var result = listener.actionGet(TIMEOUT); + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); assertThat(webServer.requests(), hasSize(2)); @@ -589,7 +593,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { listener ); - var result = listener.actionGet(TIMEOUT); + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); assertThat(webServer.requests(), hasSize(1)); @@ -660,7 +664,7 @@ public void testCreate_OpenShiftAiRerankModel() throws IOException { listener ); - var result = listener.actionGet(TIMEOUT); + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertThat( result.asMap(), is( @@ -733,7 +737,7 @@ public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() t listener ); - var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT)); assertThat(thrownException.getMessage(), is(""" Failed to send OpenShift AI rerank request from inference entity id [inferenceEntityId]. Cause: Required [results]""")); } From 9aaddfd51b4da4fca5d8df04b0bf610be081499e Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Thu, 6 Nov 2025 23:00:44 +0200 Subject: [PATCH 39/70] Refactor OpenShift AI service tests to use constants for URL, model ID, and API key --- .../openshiftai/OpenShiftAiServiceTests.java | 126 ++++++++---------- 1 file changed, 57 insertions(+), 69 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java index b8da86d8a6932..754ffdffed051 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java @@ -60,7 +60,6 @@ import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.hamcrest.CoreMatchers; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; @@ -94,14 +93,19 @@ import static org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModelTests.createChatCompletionModel; import static org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionServiceSettingsTests.getServiceSettingsMap; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; -import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.isA; import static org.mockito.Mockito.mock; public class OpenShiftAiServiceTests extends AbstractInferenceServiceTests { + private static final String URL = "http://www.abc.com"; + private static final String MODEL_ID = "model_id"; + private static final String USER_ROLE = "user"; + private static final String API_KEY = "secret"; + private static final String INFERENCE_ID = "id"; private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; @@ -168,19 +172,19 @@ private static void assertModel(Model model, TaskType taskType, boolean modelInc private static void assertTextEmbeddingModel(Model model, boolean modelIncludesSecrets) { var openShiftAiModel = assertCommonModelFields(model, modelIncludesSecrets); - assertThat(openShiftAiModel.getTaskType(), Matchers.is(TaskType.TEXT_EMBEDDING)); + assertThat(openShiftAiModel.getTaskType(), is(TaskType.TEXT_EMBEDDING)); } private static OpenShiftAiModel assertCommonModelFields(Model model, boolean modelIncludesSecrets) { assertThat(model, instanceOf(OpenShiftAiModel.class)); var openShiftAiModel = (OpenShiftAiModel) model; - assertThat(openShiftAiModel.getServiceSettings().modelId(), is("model_id")); - assertThat(openShiftAiModel.getServiceSettings().uri.toString(), Matchers.is("http://www.abc.com")); - assertThat(openShiftAiModel.getTaskSettings(), Matchers.is(EmptyTaskSettings.INSTANCE)); + assertThat(openShiftAiModel.getServiceSettings().modelId(), is(MODEL_ID)); + assertThat(openShiftAiModel.getServiceSettings().uri.toString(), is(URL)); + assertThat(openShiftAiModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); if (modelIncludesSecrets) { - assertThat(openShiftAiModel.getSecretSettings().apiKey(), Matchers.is(new SecureString("secret".toCharArray()))); + assertThat(openShiftAiModel.getSecretSettings().apiKey(), is(new SecureString(API_KEY.toCharArray()))); } return openShiftAiModel; @@ -188,12 +192,12 @@ private static OpenShiftAiModel assertCommonModelFields(Model model, boolean mod private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) { var openShiftAiModel = assertCommonModelFields(model, modelIncludesSecrets); - assertThat(openShiftAiModel.getTaskType(), Matchers.is(TaskType.COMPLETION)); + assertThat(openShiftAiModel.getTaskType(), is(TaskType.COMPLETION)); } private static void assertChatCompletionModel(Model model, boolean modelIncludesSecrets) { var openShiftAiModel = assertCommonModelFields(model, modelIncludesSecrets); - assertThat(openShiftAiModel.getTaskType(), Matchers.is(TaskType.CHAT_COMPLETION)); + assertThat(openShiftAiModel.getTaskType(), is(TaskType.CHAT_COMPLETION)); } public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { @@ -202,9 +206,7 @@ public static SenderService createService(ThreadPool threadPool, HttpClientManag } private static Map createServiceSettingsMap(TaskType taskType) { - Map settingsMap = new HashMap<>( - Map.of(ServiceFields.URL, "http://www.abc.com", ServiceFields.MODEL_ID, "model_id") - ); + Map settingsMap = new HashMap<>(Map.of(ServiceFields.URL, URL, ServiceFields.MODEL_ID, MODEL_ID)); if (taskType == TaskType.TEXT_EMBEDDING) { settingsMap.putAll( @@ -223,27 +225,17 @@ private static Map createServiceSettingsMap(TaskType taskType) { } private static Map createSecretSettingsMap() { - return new HashMap<>(Map.of("api_key", "secret")); + return new HashMap<>(Map.of("api_key", API_KEY)); } private static OpenShiftAiEmbeddingsModel createInternalEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure) { - var inferenceId = "inference_id"; - return new OpenShiftAiEmbeddingsModel( - inferenceId, + INFERENCE_ID, TaskType.TEXT_EMBEDDING, OpenShiftAiService.NAME, - new OpenShiftAiEmbeddingsServiceSettings( - "model_id", - "http://www.abc.com", - 1536, - similarityMeasure, - 512, - new RateLimitSettings(10_000), - true - ), + new OpenShiftAiEmbeddingsServiceSettings(MODEL_ID, URL, 1536, similarityMeasure, 512, new RateLimitSettings(10_000), true), createRandomChunkingSettings(), - new DefaultSecretSettings(new SecureString("secret".toCharArray())) + new DefaultSecretSettings(new SecureString(API_KEY.toCharArray())) ); } @@ -267,19 +259,15 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsP assertThat(model, instanceOf(OpenShiftAiEmbeddingsModel.class)); var embeddingsModel = (OpenShiftAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is(URL)); assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY)); }, e -> fail("parse request should not fail " + e.getMessage())); service.parseRequestConfig( - "id", + INFERENCE_ID, TaskType.TEXT_EMBEDDING, - getRequestConfigMap( - getServiceSettingsMap("model", "url"), - createRandomChunkingSettingsMap(), - getSecretSettingsMap("secret") - ), + getRequestConfigMap(getServiceSettingsMap(MODEL_ID, URL), createRandomChunkingSettingsMap(), getSecretSettingsMap(API_KEY)), modelVerificationActionListener ); } @@ -291,49 +279,43 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsN assertThat(model, instanceOf(OpenShiftAiEmbeddingsModel.class)); var embeddingsModel = (OpenShiftAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is(URL)); assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY)); }, e -> fail("parse request should not fail " + e.getMessage())); service.parseRequestConfig( - "id", + INFERENCE_ID, TaskType.TEXT_EMBEDDING, - getRequestConfigMap(getServiceSettingsMap("model", "url"), getSecretSettingsMap("secret")), + getRequestConfigMap(getServiceSettingsMap(MODEL_ID, URL), getSecretSettingsMap(API_KEY)), modelVerificationActionListener ); } } public void testParseRequestConfig_WithoutModelId_Success() throws IOException { - var url = "url"; - var secret = "secret"; - try (var service = createService()) { ActionListener modelVerificationListener = ActionListener.wrap(m -> { assertThat(m, instanceOf(OpenShiftAiChatCompletionModel.class)); var chatCompletionModel = (OpenShiftAiChatCompletionModel) m; - assertThat(chatCompletionModel.getServiceSettings().uri().toString(), is(url)); + assertThat(chatCompletionModel.getServiceSettings().uri().toString(), is(URL)); assertNull(chatCompletionModel.getServiceSettings().modelId()); - assertThat(chatCompletionModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(chatCompletionModel.getSecretSettings().apiKey().toString(), is(API_KEY)); }, e -> fail("parse request should not fail " + e.getMessage())); service.parseRequestConfig( - "id", + INFERENCE_ID, TaskType.CHAT_COMPLETION, - getRequestConfigMap(getServiceSettingsMap(null, url), getSecretSettingsMap(secret)), + getRequestConfigMap(getServiceSettingsMap(null, URL), getSecretSettingsMap(API_KEY)), modelVerificationListener ); } } public void testParseRequestConfig_WithoutUrl_ThrowsException() throws IOException { - var model = "model"; - var secret = "secret"; - try (var service = createService()) { ActionListener modelVerificationListener = ActionListener.wrap( m -> fail("Expected exception, but got model: " + m), @@ -347,9 +329,9 @@ public void testParseRequestConfig_WithoutUrl_ThrowsException() throws IOExcepti ); service.parseRequestConfig( - "id", + INFERENCE_ID, TaskType.CHAT_COMPLETION, - getRequestConfigMap(getServiceSettingsMap(model, null), getSecretSettingsMap(secret)), + getRequestConfigMap(getServiceSettingsMap(MODEL_ID, null), getSecretSettingsMap(API_KEY)), modelVerificationListener ); } @@ -386,12 +368,14 @@ public void testUnifiedCompletionInfer() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - var model = createChatCompletionModel(getUrl(webServer), "secret", "model"); + var model = createChatCompletionModel(getUrl(webServer), API_KEY, MODEL_ID); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( model, UnifiedCompletionRequest.of( - List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + List.of( + new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), USER_ROLE, null, null) + ) ), InferenceAction.Request.DEFAULT_TIMEOUT, listener @@ -426,12 +410,14 @@ public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - var model = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); + var model = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), API_KEY, MODEL_ID); var latch = new CountDownLatch(1); service.unifiedCompletionInfer( model, UnifiedCompletionRequest.of( - List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + List.of( + new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), USER_ROLE, null, null) + ) ), InferenceAction.Request.DEFAULT_TIMEOUT, ActionListener.runAfter(ActionTestUtils.assertNoSuccessListener(e -> { @@ -510,12 +496,14 @@ public void testInfer_StreamRequest() throws Exception { private void testStreamError(String expectedResponse) throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - var model = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); + var model = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), API_KEY, MODEL_ID); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( model, UnifiedCompletionRequest.of( - List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + List.of( + new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), USER_ROLE, null, null) + ) ), InferenceAction.Request.DEFAULT_TIMEOUT, listener @@ -597,7 +585,7 @@ public void testSupportsStreaming() throws IOException { public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap() throws IOException { try (var service = createService()) { - var secretSettings = getSecretSettingsMap("secret"); + var secretSettings = getSecretSettingsMap(API_KEY); secretSettings.put("extra_key", "value"); var config = getRequestConfigMap(getEmbeddingsServiceSettingsMap(), secretSettings); @@ -613,12 +601,12 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSe } ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); + service.parseRequestConfig(INFERENCE_ID, TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } } public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { - var model = OpenShiftAiEmbeddingsModelTests.createModel(getUrl(webServer), "api_key", "model", 1234, false, 1536, null); + var model = OpenShiftAiEmbeddingsModelTests.createModel(getUrl(webServer), API_KEY, MODEL_ID, 1234, false, 1536, null); testChunkedInfer(model); } @@ -626,8 +614,8 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { public void testChunkedInfer_ChunkingSettingsSet() throws IOException { var model = OpenShiftAiEmbeddingsModelTests.createModel( getUrl(webServer), - "api_key", - "model", + API_KEY, + MODEL_ID, 1234, false, 1536, @@ -691,7 +679,7 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio assertThat(results, hasSize(2)); { - assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); + assertThat(results.get(0), Matchers.instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); @@ -703,7 +691,7 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio ); } { - assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); + assertThat(results.get(1), Matchers.instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); @@ -721,12 +709,12 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer api_key")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap.size(), Matchers.is(2)); - assertThat(requestMap.get("input"), Matchers.is(List.of("abc", "def"))); - assertThat(requestMap.get("model"), Matchers.is("model")); + assertThat(requestMap.size(), is(2)); + assertThat(requestMap.get("input"), is(List.of("abc", "def"))); + assertThat(requestMap.get("model"), is(MODEL_ID)); } } @@ -795,7 +783,7 @@ public void testGetConfiguration() throws Exception { private InferenceEventsAssertion streamCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - var model = OpenShiftAiChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); + var model = OpenShiftAiChatCompletionModelTests.createCompletionModel(getUrl(webServer), API_KEY, MODEL_ID); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( model, @@ -842,7 +830,7 @@ private Map getRequestConfigMap(Map serviceSetti } private static Map getEmbeddingsServiceSettingsMap() { - return buildServiceSettingsMap("id", "url", SimilarityMeasure.COSINE.toString(), null, null, null); + return buildServiceSettingsMap(INFERENCE_ID, URL, SimilarityMeasure.COSINE.toString(), null, null, null); } @Override From bd3cf05836ad5974dbea544566eeb84eb9c6a91f Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 7 Nov 2025 00:06:30 +0200 Subject: [PATCH 40/70] Add assertions for OpenShift AI embeddings model service settings --- .../services/openshiftai/OpenShiftAiServiceTests.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java index 754ffdffed051..ebd56afd4e283 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java @@ -173,6 +173,11 @@ private static void assertTextEmbeddingModel(Model model, boolean modelIncludesS var openShiftAiModel = assertCommonModelFields(model, modelIncludesSecrets); assertThat(openShiftAiModel.getTaskType(), is(TaskType.TEXT_EMBEDDING)); + assertThat(model, instanceOf(OpenShiftAiEmbeddingsModel.class)); + var embeddingsModel = (OpenShiftAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1536)); + assertThat(embeddingsModel.getServiceSettings().similarity(), is(SimilarityMeasure.COSINE)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); } private static OpenShiftAiModel assertCommonModelFields(Model model, boolean modelIncludesSecrets) { From 0902e9af61eb97debd0cd0592682b0685be23116 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 7 Nov 2025 13:30:44 +0200 Subject: [PATCH 41/70] Refactor assertions in OpenShift AI tests to use Hamcrest matchers --- .../openshiftai/OpenShiftAiServiceTests.java | 20 +++++++++++-------- .../OpenShiftAiEmbeddingsRequestTests.java | 2 +- .../OpenShiftAiRerankTaskSettingsTests.java | 5 +++-- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java index ebd56afd4e283..b8535a1f209a1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java @@ -42,6 +42,7 @@ import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; @@ -83,7 +84,6 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; -import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; @@ -259,6 +259,7 @@ public void shutdown() throws IOException { } public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + var chunkingSettingsMap = createRandomChunkingSettings(); try (var service = createService()) { ActionListener modelVerificationActionListener = ActionListener.wrap(model -> { assertThat(model, instanceOf(OpenShiftAiEmbeddingsModel.class)); @@ -266,13 +267,14 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsP var embeddingsModel = (OpenShiftAiEmbeddingsModel) model; assertThat(embeddingsModel.getServiceSettings().uri().toString(), is(URL)); assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings().asMap(), is(chunkingSettingsMap.asMap())); assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY)); }, e -> fail("parse request should not fail " + e.getMessage())); service.parseRequestConfig( INFERENCE_ID, TaskType.TEXT_EMBEDDING, - getRequestConfigMap(getServiceSettingsMap(MODEL_ID, URL), createRandomChunkingSettingsMap(), getSecretSettingsMap(API_KEY)), + getRequestConfigMap(getServiceSettingsMap(MODEL_ID, URL), chunkingSettingsMap.asMap(), getSecretSettingsMap(API_KEY)), modelVerificationActionListener ); } @@ -285,7 +287,7 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsN var embeddingsModel = (OpenShiftAiEmbeddingsModel) model; assertThat(embeddingsModel.getServiceSettings().uri().toString(), is(URL)); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), is(ChunkingSettingsBuilder.DEFAULT_SETTINGS)); assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY)); }, e -> fail("parse request should not fail " + e.getMessage())); @@ -451,7 +453,7 @@ public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception { } }), latch::countDown) ); - assertTrue(latch.await(30, TimeUnit.SECONDS)); + assertThat(latch.await(30, TimeUnit.SECONDS), is(true)); } } @@ -688,11 +690,12 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); - assertTrue( + assertThat( Arrays.equals( new float[] { 0.0089111328125f, -0.007049560546875f }, ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() - ) + ), + is(true) ); } { @@ -700,11 +703,12 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); - assertTrue( + assertThat( Arrays.equals( new float[] { -0.008544921875f, -0.0230712890625f }, ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() - ) + ), + is(true) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java index d01885bde188e..98113774654b8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java @@ -87,7 +87,7 @@ public void testIsTruncated_ReturnsTrue() { assertFalse(request.getTruncationInfo()[0]); var truncatedRequest = request.truncate(); - assertTrue(truncatedRequest.getTruncationInfo()[0]); + assertThat(truncatedRequest.getTruncationInfo()[0], is(true)); } private HttpPost validateRequestUrlAndContentType(HttpRequest request) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java index de12ea465587d..11bf2660b5f28 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java @@ -23,6 +23,7 @@ import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; public class OpenShiftAiRerankTaskSettingsTests extends AbstractBWCWireSerializationTestCase { public static OpenShiftAiRerankTaskSettings createRandom() { @@ -35,8 +36,8 @@ public static OpenShiftAiRerankTaskSettings createRandom() { public void testFromMap_WithValidValues_ReturnsSettings() { Map taskMap = Map.of(OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS, true, OpenShiftAiRerankTaskSettings.TOP_N, 5); var settings = OpenShiftAiRerankTaskSettings.fromMap(new HashMap<>(taskMap)); - assertTrue(settings.getReturnDocuments()); - assertEquals(5, settings.getTopN().intValue()); + assertThat(settings.getReturnDocuments(), is(true)); + assertThat(settings.getTopN().intValue(), is(5)); } public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() { From a1a56b6f2e723099924f0d28f343895892bbf164 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 7 Nov 2025 14:29:16 +0200 Subject: [PATCH 42/70] Update assertions in OpenShift AI chat completion model tests to use sameInstance matcher --- .../OpenShiftAiChatCompletionModelTests.java | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java index 0e580f1371e65..6604d473e945a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java @@ -13,8 +13,14 @@ import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; public class OpenShiftAiChatCompletionModelTests extends ESTestCase { + + private static final String MODEL_ID = "model_name"; + private static final String API_KEY = "api_key"; + private static final String URL = "url"; + public static OpenShiftAiChatCompletionModel createCompletionModel(String url, String apiKey, String modelName) { return createModelWithTaskType(url, apiKey, modelName, TaskType.COMPLETION); } @@ -34,23 +40,30 @@ public static OpenShiftAiChatCompletionModel createModelWithTaskType(String url, } public void testOverrideWith_UnifiedCompletionRequest_KeepsSameModelId() { - var model = createCompletionModel("url", "api_key", "model_name"); - var overriddenModel = OpenShiftAiChatCompletionModel.of(model, "model_name"); + var model = createCompletionModel(URL, API_KEY, MODEL_ID); + var overriddenModel = OpenShiftAiChatCompletionModel.of(model, MODEL_ID); - assertThat(overriddenModel.getServiceSettings().modelId(), is("model_name")); + assertThat(overriddenModel, is(sameInstance(model))); } public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() { - var model = createCompletionModel("url", "api_key", "model_name"); + var model = createCompletionModel(URL, API_KEY, MODEL_ID); var overriddenModel = OpenShiftAiChatCompletionModel.of(model, "different_model"); assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); } public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { - var model = createCompletionModel("url", "api_key", "model_name"); + var model = createCompletionModel(URL, API_KEY, MODEL_ID); + var overriddenModel = OpenShiftAiChatCompletionModel.of(model, null); + + assertThat(overriddenModel, is(sameInstance(model))); + } + + public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() { + var model = createCompletionModel(URL, API_KEY, null); var overriddenModel = OpenShiftAiChatCompletionModel.of(model, null); - assertThat(overriddenModel.getServiceSettings().modelId(), is("model_name")); + assertThat(overriddenModel, is(sameInstance(model))); } } From 6e7fbb449811f447453ebfb1983f69c37fad1506 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 7 Nov 2025 14:51:48 +0200 Subject: [PATCH 43/70] Fix formatting issues in error messages for OpenShift AI chat completion response tests --- ...iftAiChatCompletionResponseHandlerTests.java | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java index 114b6563382fe..7e9cc24781c6f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java @@ -60,8 +60,9 @@ public void testFailBadRequest() throws IOException { var responseJson = XContentHelper.stripWhitespace(""" { "object": "error", - "message": "[{'type': 'missing', 'loc': ('body', 'messages'), 'msg': 'Field required', 'input': {'model': 'llama-31-8b-ins\ - truct', '1messages': [{'role': 'user', 'content': 'What is deep learning?'}], 'max_tokens': 2, 'stream': True}}]", + "message": "[{'type': 'missing', 'loc': ('body', 'messages'), 'msg': 'Field required', \ + 'input': {'model': 'llama-31-8b-instruct', 'messages': [{'role': 'user', 'content': 'What is deep learning?'}], \ + 'max_tokens': 2, 'stream': True}}]", "type": "Bad Request", "param": null, "code": 400 @@ -74,10 +75,10 @@ public void testFailBadRequest() throws IOException { { "error": { "code": "bad_request", - "message": "Received a bad request status code for request from inference entity id [id] status [400].\ - Error message: [{\\"object\\":\\"error\\",\\"message\\":\\"[{'type': 'missing', 'loc': ('body', 'messages'), 'msg': 'Field r\ - equired', 'input': {'model': 'llama-31-8b-ins truct', '1messages': [{'role': 'user', 'content': 'What is deep learning?'}]\ - , 'max_tokens': 2, 'stream': True}}]\\",\\"type\\":\\"Bad Request\\",\\"param\\":null,\\"code\\":400}]", + "message": "Received a bad request status code for request from inference entity id [id] status [400]. Error message: \ + [{\\"object\\":\\"error\\",\\"message\\":\\"[{'type': 'missing', 'loc': ('body', 'messages'), 'msg': 'Field required', \ + 'input': {'model': 'llama-31-8b-instruct', 'messages': [{'role': 'user', 'content': 'What is deep learning?'}], \ + 'max_tokens': 2, 'stream': True}}]\\",\\"type\\":\\"Bad Request\\",\\"param\\":null,\\"code\\":400}]", "type": "openshift_ai_error" } } @@ -95,8 +96,8 @@ public void testFailValidationWithInvalidJson() throws IOException { { "error": { "code": "bad_request", - "message": "Received a server error status code for request from inference entity id [id] status [500]. Error message: \ - [what? this isn't a json\\n]", + "message": "Received a server error status code for request from inference entity id [id] status [500]. \ + Error message: [what? this isn't a json\\n]", "type": "openshift_ai_error" } } From 8c5524cfdaee592559cc9c10429b8248aff371de Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 7 Nov 2025 19:55:04 +0200 Subject: [PATCH 44/70] Refactor OpenShift AI tests to use constants for model ID, API key, and input values --- .../action/OpenShiftAiActionCreatorTests.java | 93 +++++++++---------- ...tAiChatCompletionResponseHandlerTests.java | 18 ++-- ...tAiChatCompletionServiceSettingsTests.java | 10 +- ...ShiftAiEmbeddingsServiceSettingsTests.java | 19 ++-- 4 files changed, 68 insertions(+), 72 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index 02b1e52ff7a30..4379ad99fd216 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -60,6 +60,11 @@ public class OpenShiftAiActionCreatorTests extends ESTestCase { + private static final String MODEL_ID = "model"; + private static final String API_KEY = "secret"; + private static final String QUERY = "popular name"; + private static final String INPUT = "abc"; + public static final String USER_ROLE = "user"; private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; @@ -110,13 +115,13 @@ public void testCreate_OpenShiftAiEmbeddingsModel() throws IOException { """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createModel(getUrl(webServer), "secret", "model"); + var model = createModel(getUrl(webServer), API_KEY, MODEL_ID); var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); var action = actionCreator.create(model); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(INPUT), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -130,12 +135,12 @@ public void testCreate_OpenShiftAiEmbeddingsModel() throws IOException { webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); assertThat(requestMap.size(), is(2)); - assertThat(requestMap.get("input"), is(List.of("abc"))); - assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("input"), is(List.of(INPUT))); + assertThat(requestMap.get("model"), is(MODEL_ID)); } } @@ -173,13 +178,13 @@ public void testCreate_OpenShiftAiEmbeddingsModel_FailsFromInvalidResponseFormat """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createModel(getUrl(webServer), "secret", "model"); + var model = createModel(getUrl(webServer), API_KEY, MODEL_ID); var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); var action = actionCreator.create(model); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(INPUT), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -206,12 +211,12 @@ public void testCreate_OpenShiftAiEmbeddingsModel_FailsFromInvalidResponseFormat webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); assertThat(requestMap.size(), is(2)); - assertThat(requestMap.get("input"), is(List.of("abc"))); - assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("input"), is(List.of(INPUT))); + assertThat(requestMap.get("model"), is(MODEL_ID)); } } @@ -254,12 +259,12 @@ public void testCreate_OpenShiftAiChatCompletionModel() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createCompletionModel(getUrl(webServer), "secret", "model"); + var model = createCompletionModel(getUrl(webServer), API_KEY, MODEL_ID); var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); var action = actionCreator.create(model); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(INPUT)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); @@ -270,12 +275,12 @@ public void testCreate_OpenShiftAiChatCompletionModel() throws IOException { assertNull(request.getUri().getQuery()); assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters())); - assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); assertThat(requestMap.size(), is(4)); - assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); - assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", USER_ROLE, "content", INPUT)))); + assertThat(requestMap.get("model"), is(MODEL_ID)); assertThat(requestMap.get("n"), is(1)); assertThat(requestMap.get("stream"), is(false)); } @@ -325,12 +330,12 @@ public void testCreate_OpenShiftAiChatCompletionModel_FailsFromInvalidResponseFo """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createCompletionModel(getUrl(webServer), "secret", "model"); + var model = createCompletionModel(getUrl(webServer), API_KEY, MODEL_ID); var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); var action = actionCreator.create(model); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(INPUT)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var failureCauseMessage = "Required [choices]"; var thrownException = expectThrows( @@ -354,12 +359,12 @@ public void testCreate_OpenShiftAiChatCompletionModel_FailsFromInvalidResponseFo webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); assertThat(requestMap.size(), is(4)); - assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); - assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", USER_ROLE, "content", INPUT)))); + assertThat(requestMap.get("model"), is(MODEL_ID)); assertThat(requestMap.get("n"), is(1)); assertThat(requestMap.get("stream"), is(false)); } @@ -413,7 +418,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC webServer.enqueue(new MockResponse().setResponseCode(413).setBody(responseJsonContentTooLarge)); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createModel(getUrl(webServer), "secret", "model"); + var model = createModel(getUrl(webServer), API_KEY, MODEL_ID); var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); var action = actionCreator.create(model); @@ -434,12 +439,12 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); assertThat(requestMap.size(), is(2)); assertThat(requestMap.get("input"), is(List.of("abcd"))); - assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("model"), is(MODEL_ID)); } { assertNull(webServer.requests().get(1).getUri().getQuery()); @@ -447,12 +452,12 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC webServer.requests().get(1).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(1).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(1).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); var requestMap = entityAsMap(webServer.requests().get(1).getBody()); assertThat(requestMap.size(), is(2)); assertThat(requestMap.get("input"), is(List.of("ab"))); - assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("model"), is(MODEL_ID)); } } } @@ -505,7 +510,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC webServer.enqueue(new MockResponse().setResponseCode(400).setBody(responseJsonContentTooLarge)); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createModel(getUrl(webServer), "secret", "model"); + var model = createModel(getUrl(webServer), API_KEY, MODEL_ID); var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); var action = actionCreator.create(model); @@ -526,12 +531,12 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); assertThat(requestMap.size(), is(2)); assertThat(requestMap.get("input"), is(List.of("abcd"))); - assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("model"), is(MODEL_ID)); } { assertNull(webServer.requests().get(1).getUri().getQuery()); @@ -539,12 +544,12 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC webServer.requests().get(1).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(1).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(1).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); var requestMap = entityAsMap(webServer.requests().get(1).getBody()); assertThat(requestMap.size(), is(2)); assertThat(requestMap.get("input"), is(List.of("ab"))); - assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("model"), is(MODEL_ID)); } } } @@ -582,7 +587,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); // truncated to 1 token = 3 characters - var model = createModel(getUrl(webServer), "secret", "model", 1); + var model = createModel(getUrl(webServer), API_KEY, MODEL_ID, 1); var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); var action = actionCreator.create(model); @@ -602,12 +607,12 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); assertThat(requestMap.size(), is(2)); assertThat(requestMap.get("input"), is(List.of("sup"))); - assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("model"), is(MODEL_ID)); } } @@ -650,7 +655,7 @@ public void testCreate_OpenShiftAiRerankModel() throws IOException { """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), "secret", "model"); + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_ID); var actionCreator = new OpenShiftAiActionCreator( sender, new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator()) @@ -658,11 +663,7 @@ public void testCreate_OpenShiftAiRerankModel() throws IOException { var action = actionCreator.create(model, null); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute( - new QueryAndDocsInputs("popular name", documents, null, null, false), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new QueryAndDocsInputs(QUERY, documents, null, null, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertThat( @@ -723,7 +724,7 @@ public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() t """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), "secret", "model"); + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_ID); var actionCreator = new OpenShiftAiActionCreator( sender, new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator()) @@ -731,11 +732,7 @@ public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() t var action = actionCreator.create(model, null); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute( - new QueryAndDocsInputs("popular name", documents, null, null, false), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new QueryAndDocsInputs(QUERY, documents, null, null, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT)); assertThat(thrownException.getMessage(), is(""" @@ -751,13 +748,13 @@ private void assertRerankActionCreator(List documents) throws IOExceptio webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); assertThat(requestMap.size(), is(5)); assertThat(requestMap.get("documents"), is(documents)); - assertThat(requestMap.get("model"), is("model")); - assertThat(requestMap.get("query"), is("popular name")); + assertThat(requestMap.get("model"), is(MODEL_ID)); + assertThat(requestMap.get("query"), is(QUERY)); assertThat(requestMap.get("top_n"), is(2)); assertThat(requestMap.get("return_documents"), is(true)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java index 7e9cc24781c6f..972d7297a7a1f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java @@ -31,6 +31,8 @@ import static org.mockito.Mockito.when; public class OpenShiftAiChatCompletionResponseHandlerTests extends ESTestCase { + private static final String URL = "https://api.openshift.ai/v1/chat/completions"; + private static final String INFERENCE_ID = "id"; private final OpenShiftAiChatCompletionResponseHandler responseHandler = new OpenShiftAiChatCompletionResponseHandler( "chat completions", (a, b) -> mock() @@ -49,11 +51,11 @@ public void testFailNotFound() throws IOException { { "error" : { "code" : "not_found", - "message" : "Resource not found at [https://api.llama.ai/v1/chat/completions] for request from inference entity id [id] \ + "message" : "Resource not found at [%s] for request from inference entity id [%s] \ status [404]. Error message: [{\\"detail\\":\\"Not Found\\"}]", "type" : "openshift_ai_error" } - }"""))); + }""".formatted(URL, INFERENCE_ID)))); } public void testFailBadRequest() throws IOException { @@ -75,14 +77,14 @@ public void testFailBadRequest() throws IOException { { "error": { "code": "bad_request", - "message": "Received a bad request status code for request from inference entity id [id] status [400]. Error message: \ + "message": "Received a bad request status code for request from inference entity id [%s] status [400]. Error message: \ [{\\"object\\":\\"error\\",\\"message\\":\\"[{'type': 'missing', 'loc': ('body', 'messages'), 'msg': 'Field required', \ 'input': {'model': 'llama-31-8b-instruct', 'messages': [{'role': 'user', 'content': 'What is deep learning?'}], \ 'max_tokens': 2, 'stream': True}}]\\",\\"type\\":\\"Bad Request\\",\\"param\\":null,\\"code\\":400}]", "type": "openshift_ai_error" } } - """))); + """.formatted(INFERENCE_ID)))); } public void testFailValidationWithInvalidJson() throws IOException { @@ -96,12 +98,12 @@ public void testFailValidationWithInvalidJson() throws IOException { { "error": { "code": "bad_request", - "message": "Received a server error status code for request from inference entity id [id] status [500]. \ + "message": "Received a server error status code for request from inference entity id [%s] status [500]. \ Error message: [what? this isn't a json\\n]", "type": "openshift_ai_error" } } - """))); + """.formatted(INFERENCE_ID)))); } private String invalidResponseJson(String responseJson, int statusCode) throws IOException { @@ -125,9 +127,9 @@ private Exception invalidResponse(String responseJson, int statusCode) { private static Request mockRequest() throws URISyntaxException { var request = mock(Request.class); - when(request.getInferenceEntityId()).thenReturn("id"); + when(request.getInferenceEntityId()).thenReturn(INFERENCE_ID); when(request.isStreaming()).thenReturn(true); - when(request.getURI()).thenReturn(new URI("https://api.llama.ai/v1/chat/completions")); + when(request.getURI()).thenReturn(new URI(URL)); return request; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java index 5df17de6d190e..c3ebf6cfd67b5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java @@ -125,13 +125,13 @@ public void testToXContent_WritesAllValues() throws IOException { String xContentResult = Strings.toString(builder); var expected = XContentHelper.stripWhitespace(""" { - "model_id": "some model", - "url": "https://www.elastic.co", + "model_id": "%s", + "url": "%s", "rate_limit": { "requests_per_minute": 2 } } - """); + """.formatted(MODEL_ID, CORRECT_URL)); assertThat(xContentResult, is(expected)); } @@ -147,12 +147,12 @@ public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws String xContentResult = Strings.toString(builder); var expected = XContentHelper.stripWhitespace(""" { - "url": "https://www.elastic.co", + "url": "%s", "rate_limit": { "requests_per_minute": 3000 } } - """); + """.formatted(CORRECT_URL)); assertThat(xContentResult, is(expected)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java index 67dc7d95a5a0e..59ae564773ccd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java @@ -38,6 +38,7 @@ public class OpenShiftAiEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase { private static final String MODEL_ID = "some model"; private static final String CORRECT_URL = "https://www.elastic.co"; + private static final String INVALID_URL = "^^^"; private static final int DIMENSIONS = 384; private static final SimilarityMeasure SIMILARITY_MEASURE = SimilarityMeasure.DOT_PRODUCT; private static final int MAX_INPUT_TOKENS = 128; @@ -152,7 +153,7 @@ public void testFromMap_InvalidUrl_ThrowsException() { () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( buildServiceSettingsMap( MODEL_ID, - "^^^", + INVALID_URL, SIMILARITY_MEASURE.toString(), DIMENSIONS, MAX_INPUT_TOKENS, @@ -162,13 +163,9 @@ public void testFromMap_InvalidUrl_ThrowsException() { ConfigurationParseContext.PERSISTENT ) ); - assertThat( - thrownException.getMessage(), - containsString( - "Validation Failed: 1: [service_settings] Invalid url [^^^] received for field [url]. " - + "Error: unable to parse url [^^^]. Reason: Illegal character in path;" - ) - ); + assertThat(thrownException.getMessage(), containsString(""" + Validation Failed: 1: [service_settings] Invalid url [%s] received for field [url]. \ + Error: unable to parse url [%s]. Reason: Illegal character in path;""".formatted(INVALID_URL, INVALID_URL))); } public void testFromMap_NoSimilarity_Success() { @@ -549,8 +546,8 @@ public void testToXContent_WritesAllValues() throws IOException { assertThat(xContentResult, CoreMatchers.is(XContentHelper.stripWhitespace(""" { - "model_id": "some model", - "url": "https://www.elastic.co", + "model_id": "%s", + "url": "%s", "rate_limit": { "requests_per_minute": 3 }, @@ -559,7 +556,7 @@ public void testToXContent_WritesAllValues() throws IOException { "max_input_tokens": 128, "dimensions_set_by_user": false } - """))); + """.formatted(MODEL_ID, CORRECT_URL)))); } public void testStreamInputAndOutput_WritesValuesCorrectly() throws IOException { From e9b5d97acc458c72ba1f6b8de8c9e136ed06d436 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 7 Nov 2025 20:17:29 +0200 Subject: [PATCH 45/70] Refactor OpenShift AI chat completion tests to use constants for URL, model ID, user role, and API key --- ...OpenShiftAiChatCompletionRequestTests.java | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java index ce2c8a946e3a6..d5f580720cc00 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java @@ -20,40 +20,47 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; public class OpenShiftAiChatCompletionRequestTests extends ESTestCase { + + private static final String URL = "url"; + private static final String MODEL_ID = "model"; + private static final String USER_ROLE = "user"; + private static final String API_KEY = "secret"; + public void testCreateRequest_WithStreaming() throws IOException { String input = randomAlphaOfLength(15); - var request = createRequest("model", "url", "secret", input, true); + var request = createRequest(MODEL_ID, URL, API_KEY, input, true); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) httpRequest.httpRequestBase(); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(request.getURI().toString(), is("url")); + assertThat(request.getURI().toString(), is(URL)); assertThat(requestMap.get("stream"), is(true)); - assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("model"), is(MODEL_ID)); assertThat(requestMap.get("n"), is(1)); assertNull(requestMap.get("stream_options")); - assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input)))); - assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", USER_ROLE, "content", input)))); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer %s".formatted(API_KEY))); } public void testTruncate_DoesNotReduceInputTextSize() { String input = randomAlphaOfLength(5); - var request = createRequest("model", "url", "secret", input, true); - assertThat(request.truncate(), is(request)); + var request = createRequest(MODEL_ID, URL, API_KEY, input, true); + assertThat(request.truncate(), is(sameInstance(request))); } public void testTruncationInfo_ReturnsNull() { - var request = createRequest("model", "url", "secret", randomAlphaOfLength(5), true); + var request = createRequest(MODEL_ID, URL, API_KEY, randomAlphaOfLength(5), true); assertNull(request.getTruncationInfo()); } public static OpenShiftAiChatCompletionRequest createRequest(String modelId, String url, String apiKey, String input, boolean stream) { var chatCompletionModel = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(url, apiKey, modelId); - return new OpenShiftAiChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel); + return new OpenShiftAiChatCompletionRequest(new UnifiedChatInput(List.of(input), USER_ROLE, stream), chatCompletionModel); } } From 356bd84805774616b09107b7325d9bb395987739 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 7 Nov 2025 20:36:09 +0200 Subject: [PATCH 46/70] Refactor OpenShift AI tests to use getFirst() for request retrieval --- .../openshiftai/OpenShiftAiServiceTests.java | 20 +++---- .../action/OpenShiftAiActionCreatorTests.java | 60 +++++++++---------- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java index b8535a1f209a1..44d49d9123a7a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java @@ -686,14 +686,14 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio assertThat(results, hasSize(2)); { - assertThat(results.get(0), Matchers.instanceOf(ChunkedInferenceEmbedding.class)); - var floatResult = (ChunkedInferenceEmbedding) results.get(0); + assertThat(results.getFirst(), Matchers.instanceOf(ChunkedInferenceEmbedding.class)); + var floatResult = (ChunkedInferenceEmbedding) results.getFirst(); assertThat(floatResult.chunks(), hasSize(1)); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().getFirst().embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertThat( Arrays.equals( new float[] { 0.0089111328125f, -0.007049560546875f }, - ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().getFirst().embedding()).values() ), is(true) ); @@ -702,25 +702,25 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio assertThat(results.get(1), Matchers.instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().getFirst().embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertThat( Arrays.equals( new float[] { -0.008544921875f, -0.0230712890625f }, - ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().getFirst().embedding()).values() ), is(true) ); } assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(2)); assertThat(requestMap.get("input"), is(List.of("abc", "def"))); assertThat(requestMap.get("model"), is(MODEL_ID)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index 4379ad99fd216..97960cfb49d6a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -130,14 +130,14 @@ public void testCreate_OpenShiftAiEmbeddingsModel() throws IOException { assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(2)); assertThat(requestMap.get("input"), is(List.of(INPUT))); assertThat(requestMap.get("model"), is(MODEL_ID)); @@ -206,14 +206,14 @@ public void testCreate_OpenShiftAiEmbeddingsModel_FailsFromInvalidResponseFormat assertThat(thrownException.getCause().getMessage(), is(failureCauseMessage)); assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(2)); assertThat(requestMap.get("input"), is(List.of(INPUT))); assertThat(requestMap.get("model"), is(MODEL_ID)); @@ -271,13 +271,13 @@ public void testCreate_OpenShiftAiChatCompletionModel() throws IOException { assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?")))); assertThat(webServer.requests(), hasSize(1)); - var request = webServer.requests().get(0); + var request = webServer.requests().getFirst(); assertNull(request.getUri().getQuery()); assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters())); assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(4)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", USER_ROLE, "content", INPUT)))); assertThat(requestMap.get("model"), is(MODEL_ID)); @@ -354,14 +354,14 @@ public void testCreate_OpenShiftAiChatCompletionModel_FailsFromInvalidResponseFo assertThat(thrownException.getCause().getMessage(), is(failureCauseMessage)); assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(4)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", USER_ROLE, "content", INPUT)))); assertThat(requestMap.get("model"), is(MODEL_ID)); @@ -434,14 +434,14 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); assertThat(webServer.requests(), hasSize(2)); { - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(2)); assertThat(requestMap.get("input"), is(List.of("abcd"))); assertThat(requestMap.get("model"), is(MODEL_ID)); @@ -526,14 +526,14 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); assertThat(webServer.requests(), hasSize(2)); { - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(2)); assertThat(requestMap.get("input"), is(List.of("abcd"))); assertThat(requestMap.get("model"), is(MODEL_ID)); @@ -602,14 +602,14 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(2)); assertThat(requestMap.get("input"), is(List.of("sup"))); assertThat(requestMap.get("model"), is(MODEL_ID)); @@ -743,14 +743,14 @@ public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() t private void assertRerankActionCreator(List documents) throws IOException { assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(5)); assertThat(requestMap.get("documents"), is(documents)); assertThat(requestMap.get("model"), is(MODEL_ID)); From 401ce6b54460d194e29a05b29b53ec399028db4f Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 7 Nov 2025 20:53:53 +0200 Subject: [PATCH 47/70] Remove redundant request body assertions in OpenShift AI action creator tests --- .../openshiftai/action/OpenShiftAiActionCreatorTests.java | 5 ----- 1 file changed, 5 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index 97960cfb49d6a..a1f83afbaabba 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -212,11 +212,6 @@ public void testCreate_OpenShiftAiEmbeddingsModel_FailsFromInvalidResponseFormat equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); - - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat(requestMap.size(), is(2)); - assertThat(requestMap.get("input"), is(List.of(INPUT))); - assertThat(requestMap.get("model"), is(MODEL_ID)); } } From 40d07ae205564426b4ff1f055e30123bcc773fdf Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 7 Nov 2025 21:03:17 +0200 Subject: [PATCH 48/70] Remove redundant assertions in OpenShift AI action creator tests --- .../action/OpenShiftAiActionCreatorTests.java | 23 ------------------- 1 file changed, 23 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index a1f83afbaabba..2ecf0d3252c61 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -204,14 +204,6 @@ public void testCreate_OpenShiftAiEmbeddingsModel_FailsFromInvalidResponseFormat ) ); assertThat(thrownException.getCause().getMessage(), is(failureCauseMessage)); - - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); - assertThat( - webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), - equalTo(XContentType.JSON.mediaTypeWithoutParameters()) - ); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); } } @@ -347,21 +339,6 @@ public void testCreate_OpenShiftAiChatCompletionModel_FailsFromInvalidResponseFo ) ); assertThat(thrownException.getCause().getMessage(), is(failureCauseMessage)); - - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); - assertThat( - webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), - equalTo(XContentType.JSON.mediaTypeWithoutParameters()) - ); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); - - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat(requestMap.size(), is(4)); - assertThat(requestMap.get("messages"), is(List.of(Map.of("role", USER_ROLE, "content", INPUT)))); - assertThat(requestMap.get("model"), is(MODEL_ID)); - assertThat(requestMap.get("n"), is(1)); - assertThat(requestMap.get("stream"), is(false)); } } From a10dc51de54a031929dbf61e245c8008165725f8 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 7 Nov 2025 21:18:15 +0200 Subject: [PATCH 49/70] Rename input variables in OpenShift AI action creator tests for clarity --- .../action/OpenShiftAiActionCreatorTests.java | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index 2ecf0d3252c61..8d8d8b8e1d6f5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -63,8 +63,9 @@ public class OpenShiftAiActionCreatorTests extends ESTestCase { private static final String MODEL_ID = "model"; private static final String API_KEY = "secret"; private static final String QUERY = "popular name"; - private static final String INPUT = "abc"; - public static final String USER_ROLE = "user"; + private static final String USER_ROLE = "user"; + private static final String FULL_INPUT = "abcd"; + private static final String HALF_OF_INPUT = "ab"; private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; @@ -121,7 +122,7 @@ public void testCreate_OpenShiftAiEmbeddingsModel() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(INPUT), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(FULL_INPUT), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -139,7 +140,7 @@ public void testCreate_OpenShiftAiEmbeddingsModel() throws IOException { var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(2)); - assertThat(requestMap.get("input"), is(List.of(INPUT))); + assertThat(requestMap.get("input"), is(List.of(FULL_INPUT))); assertThat(requestMap.get("model"), is(MODEL_ID)); } } @@ -184,7 +185,7 @@ public void testCreate_OpenShiftAiEmbeddingsModel_FailsFromInvalidResponseFormat PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(INPUT), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(FULL_INPUT), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -251,7 +252,7 @@ public void testCreate_OpenShiftAiChatCompletionModel() throws IOException { var action = actionCreator.create(model); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(List.of(INPUT)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(FULL_INPUT)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); @@ -266,7 +267,7 @@ public void testCreate_OpenShiftAiChatCompletionModel() throws IOException { var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(4)); - assertThat(requestMap.get("messages"), is(List.of(Map.of("role", USER_ROLE, "content", INPUT)))); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", USER_ROLE, "content", FULL_INPUT)))); assertThat(requestMap.get("model"), is(MODEL_ID)); assertThat(requestMap.get("n"), is(1)); assertThat(requestMap.get("stream"), is(false)); @@ -322,7 +323,7 @@ public void testCreate_OpenShiftAiChatCompletionModel_FailsFromInvalidResponseFo var action = actionCreator.create(model); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(List.of(INPUT)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(FULL_INPUT)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var failureCauseMessage = "Required [choices]"; var thrownException = expectThrows( @@ -396,7 +397,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(FULL_INPUT), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -415,7 +416,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(2)); - assertThat(requestMap.get("input"), is(List.of("abcd"))); + assertThat(requestMap.get("input"), is(List.of(FULL_INPUT))); assertThat(requestMap.get("model"), is(MODEL_ID)); } { @@ -428,7 +429,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC var requestMap = entityAsMap(webServer.requests().get(1).getBody()); assertThat(requestMap.size(), is(2)); - assertThat(requestMap.get("input"), is(List.of("ab"))); + assertThat(requestMap.get("input"), is(List.of(HALF_OF_INPUT))); assertThat(requestMap.get("model"), is(MODEL_ID)); } } @@ -488,7 +489,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(FULL_INPUT), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -507,7 +508,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(2)); - assertThat(requestMap.get("input"), is(List.of("abcd"))); + assertThat(requestMap.get("input"), is(List.of(FULL_INPUT))); assertThat(requestMap.get("model"), is(MODEL_ID)); } { @@ -520,7 +521,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC var requestMap = entityAsMap(webServer.requests().get(1).getBody()); assertThat(requestMap.size(), is(2)); - assertThat(requestMap.get("input"), is(List.of("ab"))); + assertThat(requestMap.get("input"), is(List.of(HALF_OF_INPUT))); assertThat(requestMap.get("model"), is(MODEL_ID)); } } From 662ccc614752f960d2f8a81f806c8646e8947f26 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 7 Nov 2025 21:33:01 +0200 Subject: [PATCH 50/70] Refactor error message assertions in OpenShift AI tests for improved readability --- .../openshiftai/OpenShiftAiServiceTests.java | 2 +- .../action/OpenShiftAiActionCreatorTests.java | 14 +++++++------- .../OpenShiftAiEmbeddingsServiceSettingsTests.java | 10 +++------- 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java index 44d49d9123a7a..feb994d83189a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java @@ -165,7 +165,7 @@ private static void assertModel(Model model, TaskType taskType, boolean modelInc case TEXT_EMBEDDING -> assertTextEmbeddingModel(model, modelIncludesSecrets); case COMPLETION -> assertCompletionModel(model, modelIncludesSecrets); case CHAT_COMPLETION -> assertChatCompletionModel(model, modelIncludesSecrets); - default -> fail("unexpected task type [" + taskType + "]"); + default -> fail("unexpected task type [%s]".formatted(taskType)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index 8d8d8b8e1d6f5..d0c54795b4817 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -349,9 +349,9 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC try (var sender = createSender(senderFactory)) { sender.startSynchronously(); - var contentTooLargeErrorMessage = - "This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;" - + "0 for the completion). Please reduce your prompt; or completion length."; + var contentTooLargeErrorMessage = """ + This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;\ + 0 for the completion). Please reduce your prompt; or completion length."""; String responseJsonContentTooLarge = Strings.format(""" { @@ -435,15 +435,15 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC } } - public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException { + public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCodeWithContentTooLargeMessage() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { sender.startSynchronously(); - var contentTooLargeErrorMessage = - "This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;" - + "0 for the completion). Please reduce your prompt; or completion length."; + var contentTooLargeErrorMessage = """ + This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;\ + 0 for the completion). Please reduce your prompt; or completion length."""; String responseJsonContentTooLarge = Strings.format(""" { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java index 59ae564773ccd..98b34a922b445 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java @@ -214,13 +214,9 @@ public void testFromMap_InvalidSimilarity_ThrowsException() { ConfigurationParseContext.PERSISTENT ) ); - assertThat( - thrownException.getMessage(), - containsString( - "Validation Failed: 1: [service_settings] Invalid value [by_size] received. " - + "[similarity] must be one of [cosine, dot_product, l2_norm];" - ) - ); + assertThat(thrownException.getMessage(), containsString(""" + Validation Failed: 1: [service_settings] Invalid value [by_size] received. \ + [similarity] must be one of [cosine, dot_product, l2_norm];""")); } public void testFromMap_NoDimensions_SetByUserFalse_Persistent_Success() { From 00d803f7b5fdb6487063d05a3ed0207b2e759e9f Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 7 Nov 2025 21:56:34 +0200 Subject: [PATCH 51/70] Refactor OpenShift AI action creator tests to use NO_RETRY_SETTINGS for consistency --- .../action/OpenShiftAiActionCreatorTests.java | 39 +++++++------------ 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index d0c54795b4817..01334b9f73013 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -66,6 +66,11 @@ public class OpenShiftAiActionCreatorTests extends ESTestCase { private static final String USER_ROLE = "user"; private static final String FULL_INPUT = "abcd"; private static final String HALF_OF_INPUT = "ab"; + private static final Settings NO_RETRY_SETTINGS = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; @@ -147,12 +152,7 @@ public void testCreate_OpenShiftAiEmbeddingsModel() throws IOException { public void testCreate_OpenShiftAiEmbeddingsModel_FailsFromInvalidResponseFormat() throws IOException { // timeout as zero for no retries - var settings = buildSettingsWithRetryFields( - TimeValue.timeValueMillis(1), - TimeValue.timeValueMinutes(1), - TimeValue.timeValueSeconds(0) - ); - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); try (var sender = createSender(senderFactory)) { sender.startSynchronously(); @@ -276,12 +276,7 @@ public void testCreate_OpenShiftAiChatCompletionModel() throws IOException { public void testCreate_OpenShiftAiChatCompletionModel_FailsFromInvalidResponseFormat() throws IOException { // timeout as zero for no retries - var settings = buildSettingsWithRetryFields( - TimeValue.timeValueMillis(1), - TimeValue.timeValueMinutes(1), - TimeValue.timeValueSeconds(0) - ); - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); try (var sender = createSender(senderFactory)) { sender.startSynchronously(); @@ -590,12 +585,8 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { } public void testCreate_OpenShiftAiRerankModel() throws IOException { - var settings = buildSettingsWithRetryFields( - TimeValue.timeValueMillis(1), - TimeValue.timeValueMinutes(1), - TimeValue.timeValueSeconds(0) - ); - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + // timeout as zero for no retries + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); List documents = List.of("Luke"); try (var sender = createSender(senderFactory)) { @@ -631,7 +622,7 @@ public void testCreate_OpenShiftAiRerankModel() throws IOException { var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_ID); var actionCreator = new OpenShiftAiActionCreator( sender, - new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator()) + new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) ); var action = actionCreator.create(model, null); @@ -659,12 +650,8 @@ public void testCreate_OpenShiftAiRerankModel() throws IOException { } public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() throws IOException { - var settings = buildSettingsWithRetryFields( - TimeValue.timeValueMillis(1), - TimeValue.timeValueMinutes(1), - TimeValue.timeValueSeconds(0) - ); - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + // timeout as zero for no retries + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); List documents = List.of("Luke"); try (var sender = createSender(senderFactory)) { @@ -700,7 +687,7 @@ public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() t var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_ID); var actionCreator = new OpenShiftAiActionCreator( sender, - new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator()) + new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) ); var action = actionCreator.create(model, null); From 9b4d560b3c968744b2f146a266b8e16853d35f6d Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 7 Nov 2025 22:22:14 +0200 Subject: [PATCH 52/70] Refactor OpenShift AI action creator tests to improve variable naming for clarity --- .../action/OpenShiftAiActionCreatorTests.java | 32 +++++++++---------- ...tAiChatCompletionServiceSettingsTests.java | 6 ++-- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index 01334b9f73013..be078df35cfb4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -64,13 +64,13 @@ public class OpenShiftAiActionCreatorTests extends ESTestCase { private static final String API_KEY = "secret"; private static final String QUERY = "popular name"; private static final String USER_ROLE = "user"; - private static final String FULL_INPUT = "abcd"; - private static final String HALF_OF_INPUT = "ab"; + private static final String INPUT = "abcd"; private static final Settings NO_RETRY_SETTINGS = buildSettingsWithRetryFields( TimeValue.timeValueMillis(1), TimeValue.timeValueMinutes(1), TimeValue.timeValueSeconds(0) ); + private static final String INPUT_TO_TRUNCATE = "super long input"; private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; @@ -127,7 +127,7 @@ public void testCreate_OpenShiftAiEmbeddingsModel() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(FULL_INPUT), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(INPUT), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -145,7 +145,7 @@ public void testCreate_OpenShiftAiEmbeddingsModel() throws IOException { var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(2)); - assertThat(requestMap.get("input"), is(List.of(FULL_INPUT))); + assertThat(requestMap.get("input"), is(List.of(INPUT))); assertThat(requestMap.get("model"), is(MODEL_ID)); } } @@ -185,7 +185,7 @@ public void testCreate_OpenShiftAiEmbeddingsModel_FailsFromInvalidResponseFormat PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(FULL_INPUT), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(INPUT), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -252,7 +252,7 @@ public void testCreate_OpenShiftAiChatCompletionModel() throws IOException { var action = actionCreator.create(model); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(List.of(FULL_INPUT)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(INPUT)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); @@ -267,7 +267,7 @@ public void testCreate_OpenShiftAiChatCompletionModel() throws IOException { var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(4)); - assertThat(requestMap.get("messages"), is(List.of(Map.of("role", USER_ROLE, "content", FULL_INPUT)))); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", USER_ROLE, "content", INPUT)))); assertThat(requestMap.get("model"), is(MODEL_ID)); assertThat(requestMap.get("n"), is(1)); assertThat(requestMap.get("stream"), is(false)); @@ -318,7 +318,7 @@ public void testCreate_OpenShiftAiChatCompletionModel_FailsFromInvalidResponseFo var action = actionCreator.create(model); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(List.of(FULL_INPUT)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(INPUT)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var failureCauseMessage = "Required [choices]"; var thrownException = expectThrows( @@ -392,7 +392,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(FULL_INPUT), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(INPUT), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -411,7 +411,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(2)); - assertThat(requestMap.get("input"), is(List.of(FULL_INPUT))); + assertThat(requestMap.get("input"), is(List.of(INPUT))); assertThat(requestMap.get("model"), is(MODEL_ID)); } { @@ -424,7 +424,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC var requestMap = entityAsMap(webServer.requests().get(1).getBody()); assertThat(requestMap.size(), is(2)); - assertThat(requestMap.get("input"), is(List.of(HALF_OF_INPUT))); + assertThat(requestMap.get("input"), is(List.of(INPUT.substring(0, 2)))); assertThat(requestMap.get("model"), is(MODEL_ID)); } } @@ -484,7 +484,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(FULL_INPUT), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(INPUT), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -503,7 +503,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(2)); - assertThat(requestMap.get("input"), is(List.of(FULL_INPUT))); + assertThat(requestMap.get("input"), is(List.of(INPUT))); assertThat(requestMap.get("model"), is(MODEL_ID)); } { @@ -516,7 +516,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC var requestMap = entityAsMap(webServer.requests().get(1).getBody()); assertThat(requestMap.size(), is(2)); - assertThat(requestMap.get("input"), is(List.of(HALF_OF_INPUT))); + assertThat(requestMap.get("input"), is(List.of(INPUT.substring(0, 2)))); assertThat(requestMap.get("model"), is(MODEL_ID)); } } @@ -561,7 +561,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("super long input"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(INPUT_TO_TRUNCATE), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -579,7 +579,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(2)); - assertThat(requestMap.get("input"), is(List.of("sup"))); + assertThat(requestMap.get("input"), is(List.of(INPUT_TO_TRUNCATE.substring(0, 3)))); assertThat(requestMap.get("model"), is(MODEL_ID)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java index c3ebf6cfd67b5..efd78b1bbe047 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java @@ -33,9 +33,9 @@ public class OpenShiftAiChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase< OpenShiftAiChatCompletionServiceSettings> { - public static final String MODEL_ID = "some model"; - public static final String CORRECT_URL = "https://www.elastic.co"; - public static final int RATE_LIMIT = 2; + private static final String MODEL_ID = "some model"; + private static final String CORRECT_URL = "https://www.elastic.co"; + private static final int RATE_LIMIT = 2; public void testFromMap_AllFields_Success() { var serviceSettings = OpenShiftAiChatCompletionServiceSettings.fromMap( From dc3a27fbcc8573d5a9f5cf1ede0bba74b33a88bb Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Sat, 8 Nov 2025 00:17:24 +0200 Subject: [PATCH 53/70] Enhance OpenShift AI action creator tests with additional task settings scenarios and improved assertions --- .../action/OpenShiftAiActionCreatorTests.java | 257 +++++++++++++++++- 1 file changed, 250 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index be078df35cfb4..a86dd42d60034 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; @@ -32,10 +33,12 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankModelTests; +import org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankTaskSettings; import org.junit.After; import org.junit.Before; import java.io.IOException; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -584,7 +587,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { } } - public void testCreate_OpenShiftAiRerankModel() throws IOException { + public void testCreate_OpenShiftAiRerankModel_WithTaskSettings() throws IOException { // timeout as zero for no retries var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); @@ -646,7 +649,236 @@ public void testCreate_OpenShiftAiRerankModel() throws IOException { ) ); } - assertRerankActionCreator(documents); + assertRerankActionCreator(documents, 2, true); + } + + public void testCreate_OpenShiftAiRerankModel_WithOverriddenTaskSettings() throws IOException { + // timeout as zero for no retries + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); + + List documents = List.of("Luke"); + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "rerank-d300256dd02b4c63b8a2bc34dcdad845", + "model": "bge-reranker-v2-m3", + "usage": { + "total_tokens": 10 + }, + "results": [ + { + "index": 0, + "relevance_score": 0.4921875 + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_ID); + var actionCreator = new OpenShiftAiActionCreator( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) + ); + var action = actionCreator.create( + model, + new HashMap<>(Map.of(OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS, false, OpenShiftAiRerankTaskSettings.TOP_N, 1)) + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new QueryAndDocsInputs(QUERY, documents, null, null, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + assertThat( + result.asMap(), + is( + buildExpectationRerank( + List.of(new RankedDocsResultsTests.RerankExpectation(Map.of("index", 0, "relevance_score", 0.4921875f))) + ) + ) + ); + } + assertRerankActionCreator(documents, 1, false); + } + + public void testCreate_OpenShiftAiRerankModel_NoTaskSettings() throws IOException { + // timeout as zero for no retries + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); + + List documents = List.of("Luke"); + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "rerank-d300256dd02b4c63b8a2bc34dcdad845", + "model": "bge-reranker-v2-m3", + "usage": { + "total_tokens": 30 + }, + "results": [ + { + "index": 1, + "document": { + "text": "awgawgawgawg" + }, + "relevance_score": 0.9921875 + }, + { + "index": 0, + "document": { + "text": "awdawdawda" + }, + "relevance_score": 0.4921875 + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_ID, null, null); + var actionCreator = new OpenShiftAiActionCreator( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) + ); + var action = actionCreator.create(model, null); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new QueryAndDocsInputs(QUERY, documents, null, null, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + assertThat( + result.asMap(), + is( + buildExpectationRerank( + List.of( + new RankedDocsResultsTests.RerankExpectation( + Map.of("text", "awgawgawgawg", "index", 1, "relevance_score", 0.9921875f) + ), + new RankedDocsResultsTests.RerankExpectation( + Map.of("text", "awdawdawda", "index", 0, "relevance_score", 0.4921875f) + ) + ) + ) + ) + ); + } + assertRerankActionCreator(documents, null, null); + } + + public void testCreate_OpenShiftAiRerankModel_NoTaskSettings_WithRequestParameters() throws IOException { + // timeout as zero for no retries + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); + + List documents = List.of("Luke"); + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "rerank-d300256dd02b4c63b8a2bc34dcdad845", + "model": "bge-reranker-v2-m3", + "usage": { + "total_tokens": 30 + }, + "results": [ + { + "index": 1, + "document": { + "text": "awgawgawgawg" + }, + "relevance_score": 0.9921875 + }, + { + "index": 0, + "document": { + "text": "awdawdawda" + }, + "relevance_score": 0.4921875 + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_ID, null, null); + var actionCreator = new OpenShiftAiActionCreator( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) + ); + var action = actionCreator.create(model, null); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new QueryAndDocsInputs(QUERY, documents, true, 2, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + assertThat( + result.asMap(), + is( + buildExpectationRerank( + List.of( + new RankedDocsResultsTests.RerankExpectation( + Map.of("text", "awgawgawgawg", "index", 1, "relevance_score", 0.9921875f) + ), + new RankedDocsResultsTests.RerankExpectation( + Map.of("text", "awdawdawda", "index", 0, "relevance_score", 0.4921875f) + ) + ) + ) + ) + ); + } + assertRerankActionCreator(documents, 2, true); + } + + public void testCreate_OpenShiftAiRerankModel_WithTaskSettings_WithRequestParametersPrioritized() throws IOException { + // timeout as zero for no retries + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); + + List documents = List.of("Luke"); + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "rerank-d300256dd02b4c63b8a2bc34dcdad845", + "model": "bge-reranker-v2-m3", + "usage": { + "total_tokens": 10 + }, + "results": [ + { + "index": 0, + "relevance_score": 0.4921875 + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_ID); + var actionCreator = new OpenShiftAiActionCreator( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) + ); + var action = actionCreator.create(model, null); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new QueryAndDocsInputs(QUERY, documents, false, 1, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + assertThat( + result.asMap(), + is( + buildExpectationRerank( + List.of(new RankedDocsResultsTests.RerankExpectation(Map.of("index", 0, "relevance_score", 0.4921875f))) + ) + ) + ); + } + assertRerankActionCreator(documents, 1, false); } public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() throws IOException { @@ -698,10 +930,14 @@ public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() t assertThat(thrownException.getMessage(), is(""" Failed to send OpenShift AI rerank request from inference entity id [inferenceEntityId]. Cause: Required [results]""")); } - assertRerankActionCreator(documents); + assertRerankActionCreator(documents, 2, true); } - private void assertRerankActionCreator(List documents) throws IOException { + private void assertRerankActionCreator( + List documents, + @Nullable Integer expectedTopN, + @Nullable Boolean expectedReturnDocuments + ) throws IOException { assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().getFirst().getUri().getQuery()); assertThat( @@ -711,11 +947,18 @@ private void assertRerankActionCreator(List documents) throws IOExceptio assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat(requestMap.size(), is(5)); + int fieldCount = 3; assertThat(requestMap.get("documents"), is(documents)); assertThat(requestMap.get("model"), is(MODEL_ID)); assertThat(requestMap.get("query"), is(QUERY)); - assertThat(requestMap.get("top_n"), is(2)); - assertThat(requestMap.get("return_documents"), is(true)); + if (expectedTopN != null) { + assertThat(requestMap.get("top_n"), is(expectedTopN)); + fieldCount++; + } + if (expectedReturnDocuments != null) { + assertThat(requestMap.get("return_documents"), is(expectedReturnDocuments)); + fieldCount++; + } + assertThat(requestMap.size(), is(fieldCount)); } } From c752a84fb364ebe0e99b0223e6d2e5b3477c4dbf Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Sat, 8 Nov 2025 00:25:32 +0200 Subject: [PATCH 54/70] Refactor error message assertion in OpenShift AI tests for improved formatting --- .../openshiftai/action/OpenShiftAiActionCreatorTests.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index a86dd42d60034..09963c35ba685 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -927,8 +927,10 @@ public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() t action.execute(new QueryAndDocsInputs(QUERY, documents, null, null, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT)); - assertThat(thrownException.getMessage(), is(""" - Failed to send OpenShift AI rerank request from inference entity id [inferenceEntityId]. Cause: Required [results]""")); + assertThat( + thrownException.getMessage(), + is("Failed to send OpenShift AI rerank request from inference entity id [inferenceEntityId]. Cause: Required [results]") + ); } assertRerankActionCreator(documents, 2, true); } From b1c243be6a3cff205fe78b73601d11938a55cc60 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Sat, 8 Nov 2025 00:31:53 +0200 Subject: [PATCH 55/70] Remove redundant assertion in OpenShift AI action creator tests for clarity --- .../openshiftai/action/OpenShiftAiActionCreatorTests.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index 09963c35ba685..c54a6456a5a51 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -932,7 +932,6 @@ public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() t is("Failed to send OpenShift AI rerank request from inference entity id [inferenceEntityId]. Cause: Required [results]") ); } - assertRerankActionCreator(documents, 2, true); } private void assertRerankActionCreator( From 9de3fd68b2b84790a3dd8f8bd531b05e3c780796 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Sat, 8 Nov 2025 00:45:01 +0200 Subject: [PATCH 56/70] Ensure non-null values for user-defined dimensions and URI in OpenShift AI service settings --- .../services/openshiftai/OpenShiftAiServiceSettings.java | 2 +- .../embeddings/OpenShiftAiEmbeddingsServiceSettings.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceSettings.java index 1c668bb212904..03184031eff27 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceSettings.java @@ -64,7 +64,7 @@ protected OpenShiftAiServiceSettings(StreamInput in) throws IOException { */ protected OpenShiftAiServiceSettings(@Nullable String modelId, URI uri, @Nullable RateLimitSettings rateLimitSettings) { this.modelId = modelId; - this.uri = uri; + this.uri = Objects.requireNonNull(uri); this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java index 384d11e32ebfb..1000219633a64 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java @@ -137,7 +137,7 @@ public OpenShiftAiEmbeddingsServiceSettings( this.dimensions = dimensions; this.similarity = similarity; this.maxInputTokens = maxInputTokens; - this.dimensionsSetByUser = dimensionsSetByUser; + this.dimensionsSetByUser = Objects.requireNonNull(dimensionsSetByUser); } /** From 5f0cc265c37614afdc9693d3ccf5495c38f65fb4 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Sat, 8 Nov 2025 00:48:31 +0200 Subject: [PATCH 57/70] Remove redundant test for OpenShift AI embeddings service settings serialization --- ...ShiftAiEmbeddingsServiceSettingsTests.java | 23 ------------------- 1 file changed, 23 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java index 98b34a922b445..26299fbbcbd39 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java @@ -9,8 +9,6 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; -import org.elasticsearch.common.io.stream.ByteArrayStreamInput; -import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; @@ -555,27 +553,6 @@ public void testToXContent_WritesAllValues() throws IOException { """.formatted(MODEL_ID, CORRECT_URL)))); } - public void testStreamInputAndOutput_WritesValuesCorrectly() throws IOException { - var outputBuffer = new BytesStreamOutput(); - var settings = new OpenShiftAiEmbeddingsServiceSettings( - MODEL_ID, - CORRECT_URL, - DIMENSIONS, - SIMILARITY_MEASURE, - MAX_INPUT_TOKENS, - new RateLimitSettings(3), - false - ); - settings.writeTo(outputBuffer); - - var outputBufferRef = outputBuffer.bytes(); - var inputBuffer = new ByteArrayStreamInput(outputBufferRef.array()); - - var settingsFromBuffer = new OpenShiftAiEmbeddingsServiceSettings(inputBuffer); - - assertEquals(settings, settingsFromBuffer); - } - @Override protected Writeable.Reader instanceReader() { return OpenShiftAiEmbeddingsServiceSettings::new; From 01eb0acd8d7d78cf0f5d9de2272e7b229ddc0eb0 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Sat, 8 Nov 2025 17:51:24 +0200 Subject: [PATCH 58/70] Refactor assertions in OpenShift AI tests for consistency and clarity --- .../openshiftai/OpenShiftAiServiceTests.java | 7 +- .../action/OpenShiftAiActionCreatorTests.java | 17 ++--- ...iftAiChatCompletionRequestEntityTests.java | 4 +- ...OpenShiftAiChatCompletionRequestTests.java | 5 +- .../OpenShiftAiEmbeddingsRequestTests.java | 7 +- .../rerank/OpenShiftAiRerankModelTests.java | 7 +- .../OpenShiftAiRerankTaskSettingsTests.java | 64 ++++++++----------- 7 files changed, 52 insertions(+), 59 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java index feb994d83189a..7cb9151655a28 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java @@ -98,6 +98,7 @@ import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.isA; +import static org.hamcrest.Matchers.nullValue; import static org.mockito.Mockito.mock; public class OpenShiftAiServiceTests extends AbstractInferenceServiceTests { @@ -308,7 +309,7 @@ public void testParseRequestConfig_WithoutModelId_Success() throws IOException { var chatCompletionModel = (OpenShiftAiChatCompletionModel) m; assertThat(chatCompletionModel.getServiceSettings().uri().toString(), is(URL)); - assertNull(chatCompletionModel.getServiceSettings().modelId()); + assertThat(chatCompletionModel.getServiceSettings().modelId(), is(nullValue())); assertThat(chatCompletionModel.getSecretSettings().apiKey().toString(), is(API_KEY)); }, e -> fail("parse request should not fail " + e.getMessage())); @@ -586,7 +587,7 @@ public void testInfer_StreamRequestRetry() throws Exception { public void testSupportsStreaming() throws IOException { try (var service = new OpenShiftAiService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); - assertFalse(service.canStream(TaskType.ANY)); + assertThat(service.canStream(TaskType.ANY), is(false)); } } @@ -713,7 +714,7 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio } assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertThat(webServer.requests().getFirst().getUri().getQuery(), is(nullValue())); assertThat( webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index c54a6456a5a51..2da1d76d05439 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -59,6 +59,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; import static org.mockito.Mockito.mock; public class OpenShiftAiActionCreatorTests extends ESTestCase { @@ -139,7 +140,7 @@ public void testCreate_OpenShiftAiEmbeddingsModel() throws IOException { assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertThat(webServer.requests().getFirst().getUri().getQuery(), is(nullValue())); assertThat( webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) @@ -264,7 +265,7 @@ public void testCreate_OpenShiftAiChatCompletionModel() throws IOException { var request = webServer.requests().getFirst(); - assertNull(request.getUri().getQuery()); + assertThat(request.getUri().getQuery(), is(nullValue())); assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters())); assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); @@ -405,7 +406,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); assertThat(webServer.requests(), hasSize(2)); { - assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertThat(webServer.requests().getFirst().getUri().getQuery(), is(nullValue())); assertThat( webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) @@ -418,7 +419,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC assertThat(requestMap.get("model"), is(MODEL_ID)); } { - assertNull(webServer.requests().get(1).getUri().getQuery()); + assertThat(webServer.requests().get(1).getUri().getQuery(), is(nullValue())); assertThat( webServer.requests().get(1).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) @@ -497,7 +498,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); assertThat(webServer.requests(), hasSize(2)); { - assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertThat(webServer.requests().getFirst().getUri().getQuery(), is(nullValue())); assertThat( webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) @@ -510,7 +511,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC assertThat(requestMap.get("model"), is(MODEL_ID)); } { - assertNull(webServer.requests().get(1).getUri().getQuery()); + assertThat(webServer.requests().get(1).getUri().getQuery(), is(nullValue())); assertThat( webServer.requests().get(1).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) @@ -573,7 +574,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertThat(webServer.requests().getFirst().getUri().getQuery(), is(nullValue())); assertThat( webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) @@ -940,7 +941,7 @@ private void assertRerankActionCreator( @Nullable Boolean expectedReturnDocuments ) throws IOException { assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertThat(webServer.requests().getFirst().getUri().getQuery(), is(nullValue())); assertThat( webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntityTests.java index 084d6e986c527..50c1135644bc3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntityTests.java @@ -19,6 +19,8 @@ import java.io.IOException; import java.util.ArrayList; +import static org.hamcrest.Matchers.is; + public class OpenShiftAiChatCompletionRequestEntityTests extends ESTestCase { private static final String ROLE = "user"; @@ -93,7 +95,7 @@ private static void testSerialization(String modelId, boolean isStreaming, Strin XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder)); + assertThat(Strings.toString(builder), is(XContentHelper.stripWhitespace(expectedJson))); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java index d5f580720cc00..3ec6bfd3b1974 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java @@ -20,6 +20,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.sameInstance; public class OpenShiftAiChatCompletionRequestTests extends ESTestCase { @@ -42,7 +43,7 @@ public void testCreateRequest_WithStreaming() throws IOException { assertThat(requestMap.get("stream"), is(true)); assertThat(requestMap.get("model"), is(MODEL_ID)); assertThat(requestMap.get("n"), is(1)); - assertNull(requestMap.get("stream_options")); + assertThat(requestMap.get("stream_options"), is(nullValue())); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", USER_ROLE, "content", input)))); assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer %s".formatted(API_KEY))); } @@ -55,7 +56,7 @@ public void testTruncate_DoesNotReduceInputTextSize() { public void testTruncationInfo_ReturnsNull() { var request = createRequest(MODEL_ID, URL, API_KEY, randomAlphaOfLength(5), true); - assertNull(request.getTruncationInfo()); + assertThat(request.getTruncationInfo(), is(nullValue())); } public static OpenShiftAiChatCompletionRequest createRequest(String modelId, String url, String apiKey, String input, boolean stream) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java index 98113774654b8..f8acf0b2a4afe 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java @@ -23,6 +23,7 @@ import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; public class OpenShiftAiEmbeddingsRequestTests extends ESTestCase { @@ -61,8 +62,8 @@ public void testCreateRequest_NoModel_Success() throws IOException { var requestMap = entityAsMap(httpPost.getEntity().getContent()); assertThat(requestMap.get("input"), is(List.of("ABCD"))); - assertNull(requestMap.get("model")); - assertNull(requestMap.get("dimensions")); + assertThat(requestMap.get("model"), is(nullValue())); + assertThat(requestMap.get("dimensions"), is(nullValue())); assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer apikey")); } @@ -84,7 +85,7 @@ public void testTruncate_ReducesInputTextSizeByHalf() throws IOException { public void testIsTruncated_ReturnsTrue() { var request = createRequest(null, false); - assertFalse(request.getTruncationInfo()[0]); + assertThat(request.getTruncationInfo()[0], is(false)); var truncatedRequest = request.truncate(); assertThat(truncatedRequest.getTruncationInfo()[0], is(true)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java index 7e46a5124ad2f..c96361d336557 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java @@ -19,6 +19,7 @@ import static org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS; import static org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankTaskSettings.TOP_N; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; public class OpenShiftAiRerankModelTests extends ESTestCase { @@ -54,9 +55,7 @@ public void testOverrideWith_EmptyParams_KeepsSameModel() { private static void testOverrideWith_KeepsSameModel(Map taskSettings) { var model = createModel("url", "api_key", "model_name", 2, true); var overriddenModel = OpenShiftAiRerankModel.of(model, taskSettings); - - assertThat(overriddenModel.getTaskSettings().getTopN(), is(2)); - assertThat(overriddenModel.getTaskSettings().getReturnDocuments(), is(true)); + assertThat(overriddenModel, is(sameInstance(model))); } public void testOverrideWith_DifferentParams_OverridesAllTaskSettings() { @@ -91,7 +90,7 @@ private static void testOverrideWith_DifferentParams( assertThat(overriddenModel.getTaskSettings().getReturnDocuments(), is(expectedReturnDocuments)); } - private static Map buildTaskSettingsMap(@Nullable Integer topN, @Nullable Boolean returnDocuments) { + public static Map buildTaskSettingsMap(@Nullable Integer topN, @Nullable Boolean returnDocuments) { final var map = new HashMap(); if (returnDocuments != null) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java index 11bf2660b5f28..f2f2bc0f66e88 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java @@ -22,8 +22,11 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; +import static org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankModelTests.buildTaskSettingsMap; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.sameInstance; public class OpenShiftAiRerankTaskSettingsTests extends AbstractBWCWireSerializationTestCase { public static OpenShiftAiRerankTaskSettings createRandom() { @@ -34,16 +37,15 @@ public static OpenShiftAiRerankTaskSettings createRandom() { } public void testFromMap_WithValidValues_ReturnsSettings() { - Map taskMap = Map.of(OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS, true, OpenShiftAiRerankTaskSettings.TOP_N, 5); - var settings = OpenShiftAiRerankTaskSettings.fromMap(new HashMap<>(taskMap)); + var settings = OpenShiftAiRerankTaskSettings.fromMap(buildTaskSettingsMap(5, true)); assertThat(settings.getReturnDocuments(), is(true)); - assertThat(settings.getTopN().intValue(), is(5)); + assertThat(settings.getTopN(), is(5)); } public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() { var settings = OpenShiftAiRerankTaskSettings.fromMap(Map.of()); - assertNull(settings.getReturnDocuments()); - assertNull(settings.getTopN()); + assertThat(settings.getReturnDocuments(), is(nullValue())); + assertThat(settings.getTopN(), is(nullValue())); } public void testFromMap_WithInvalidReturnDocuments_ThrowsValidationException() { @@ -71,43 +73,38 @@ public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() { public void testUpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() { var initialSettings = new OpenShiftAiRerankTaskSettings(5, true); OpenShiftAiRerankTaskSettings updatedSettings = (OpenShiftAiRerankTaskSettings) initialSettings.updatedTaskSettings(Map.of()); - assertEquals(initialSettings, updatedSettings); + assertThat(initialSettings, is(sameInstance(updatedSettings))); } public void testUpdatedTaskSettings_WithNewReturnDocuments_ReturnsUpdatedSettings() { var initialSettings = new OpenShiftAiRerankTaskSettings(5, true); - Map newSettings = Map.of(OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS, false); - OpenShiftAiRerankTaskSettings updatedSettings = (OpenShiftAiRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); - assertFalse(updatedSettings.getReturnDocuments()); - assertEquals(initialSettings.getTopN(), updatedSettings.getTopN()); + OpenShiftAiRerankTaskSettings updatedSettings = (OpenShiftAiRerankTaskSettings) initialSettings.updatedTaskSettings( + buildTaskSettingsMap(null, false) + ); + assertThat(updatedSettings.getReturnDocuments(), is(false)); + assertThat(initialSettings.getTopN(), is(updatedSettings.getTopN())); } public void testUpdatedTaskSettings_WithNewTopNDocsOnly_ReturnsUpdatedSettings() { var initialSettings = new OpenShiftAiRerankTaskSettings(5, true); - Map newSettings = Map.of(OpenShiftAiRerankTaskSettings.TOP_N, 7); - OpenShiftAiRerankTaskSettings updatedSettings = (OpenShiftAiRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); - assertEquals(7, updatedSettings.getTopN().intValue()); - assertEquals(initialSettings.getReturnDocuments(), updatedSettings.getReturnDocuments()); + OpenShiftAiRerankTaskSettings updatedSettings = (OpenShiftAiRerankTaskSettings) initialSettings.updatedTaskSettings( + buildTaskSettingsMap(7, null) + ); + assertThat(updatedSettings.getTopN(), is(7)); + assertThat(updatedSettings.getReturnDocuments(), is(initialSettings.getReturnDocuments())); } public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings() { var initialSettings = new OpenShiftAiRerankTaskSettings(5, true); - Map newSettings = Map.of( - OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS, - false, - OpenShiftAiRerankTaskSettings.TOP_N, - 7 + OpenShiftAiRerankTaskSettings updatedSettings = (OpenShiftAiRerankTaskSettings) initialSettings.updatedTaskSettings( + buildTaskSettingsMap(7, false) ); - OpenShiftAiRerankTaskSettings updatedSettings = (OpenShiftAiRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); - assertFalse(updatedSettings.getReturnDocuments()); - assertEquals(7, updatedSettings.getTopN().intValue()); + assertThat(updatedSettings.getReturnDocuments(), is(false)); + assertThat(updatedSettings.getTopN(), is(7)); } public void testToXContent_WritesAllValues() throws IOException { - Integer topN = 2; - Boolean doReturnDocuments = true; - - testToXContent(topN, doReturnDocuments, """ + testToXContent(2, true, """ { "top_n":2, "return_documents":true @@ -116,19 +113,13 @@ public void testToXContent_WritesAllValues() throws IOException { } public void testToXContent_EmptyValues() throws IOException { - Integer topN = null; - Boolean doReturnDocuments = null; - - testToXContent(topN, doReturnDocuments, """ + testToXContent(null, null, """ {} """); } public void testToXContent_OnlyTopN() throws IOException { - Integer topN = 2; - Boolean doReturnDocuments = null; - - testToXContent(topN, doReturnDocuments, """ + testToXContent(2, null, """ { "top_n":2 } @@ -136,10 +127,7 @@ public void testToXContent_OnlyTopN() throws IOException { } public void testToXContent_OnlyReturnDocuments() throws IOException { - Integer topN = null; - Boolean doReturnDocuments = true; - - testToXContent(topN, doReturnDocuments, """ + testToXContent(null, true, """ { "return_documents":true } From c75b710812ef8e82925fe4d9228e0b4217068d1b Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Sat, 8 Nov 2025 18:13:43 +0200 Subject: [PATCH 59/70] Enhance OpenShift AI service settings tests for clarity and completeness --- ...OpenShiftAiRerankServiceSettingsTests.java | 32 ++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettingsTests.java index c45940c084322..4e7c3ea3605b1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankServiceSettingsTests.java @@ -35,25 +35,41 @@ private static OpenShiftAiRerankServiceSettings createRandom() { return new OpenShiftAiRerankServiceSettings(modelId, ServiceUtils.createUri(url), RateLimitSettingsTests.createRandom()); } - public void testToXContent_WritesAllValues() throws IOException { + public void testToXContent_WritesAllFields() throws IOException { var url = "http://www.abc.com"; var model = "model"; + var rateLimitSettings = new RateLimitSettings(100); - var serviceSettings = new OpenShiftAiRerankServiceSettings(model, url, null); + assertXContentEquals(new OpenShiftAiRerankServiceSettings(model, url, rateLimitSettings), """ + { + "model_id":"model", + "url":"http://www.abc.com", + "rate_limit": { + "requests_per_minute": 100 + } + } + """); + } - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - serviceSettings.toXContent(builder, null); - String xContentResult = Strings.toString(builder); + public void testToXContent_WritesDefaultRateLimitAndOmitsModelIdIfNotSet() throws IOException { + var url = "http://www.abc.com"; - assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + assertXContentEquals(new OpenShiftAiRerankServiceSettings(null, url, null), """ { - "model_id":"model", "url":"http://www.abc.com", "rate_limit": { "requests_per_minute": 3000 } } - """)); + """); + } + + private static void assertXContentEquals(OpenShiftAiRerankServiceSettings serviceSettings, String expectedString) throws IOException { + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(expectedString)); } @Override From 27654f804d9e1fa590d7e063bfc797008ee1c097 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Sat, 8 Nov 2025 18:44:21 +0200 Subject: [PATCH 60/70] Refactor OpenShift AI request tests to improve variable naming and assertion clarity --- .../openshiftai/OpenShiftAiServiceTests.java | 3 +- .../action/OpenShiftAiActionCreatorTests.java | 17 +++++---- ...OpenShiftAiChatCompletionRequestTests.java | 2 + .../rarank/OpenShiftAiRerankRequestTests.java | 37 ++++++++++--------- 4 files changed, 33 insertions(+), 26 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java index 7cb9151655a28..25b0ad9e9094d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java @@ -93,6 +93,7 @@ import static org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModelTests.createChatCompletionModel; import static org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionServiceSettingsTests.getServiceSettingsMap; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; +import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -722,7 +723,7 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat(requestMap.size(), is(2)); + assertThat(requestMap, aMapWithSize(2)); assertThat(requestMap.get("input"), is(List.of("abc", "def"))); assertThat(requestMap.get("model"), is(MODEL_ID)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index 2da1d76d05439..5765d6505f033 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -56,6 +56,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModelTests.createCompletionModel; import static org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsModelTests.createModel; +import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -148,7 +149,7 @@ public void testCreate_OpenShiftAiEmbeddingsModel() throws IOException { assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat(requestMap.size(), is(2)); + assertThat(requestMap, aMapWithSize(2)); assertThat(requestMap.get("input"), is(List.of(INPUT))); assertThat(requestMap.get("model"), is(MODEL_ID)); } @@ -270,7 +271,7 @@ public void testCreate_OpenShiftAiChatCompletionModel() throws IOException { assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat(requestMap.size(), is(4)); + assertThat(requestMap, aMapWithSize(4)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", USER_ROLE, "content", INPUT)))); assertThat(requestMap.get("model"), is(MODEL_ID)); assertThat(requestMap.get("n"), is(1)); @@ -414,7 +415,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat(requestMap.size(), is(2)); + assertThat(requestMap, aMapWithSize(2)); assertThat(requestMap.get("input"), is(List.of(INPUT))); assertThat(requestMap.get("model"), is(MODEL_ID)); } @@ -427,7 +428,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC assertThat(webServer.requests().get(1).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); var requestMap = entityAsMap(webServer.requests().get(1).getBody()); - assertThat(requestMap.size(), is(2)); + assertThat(requestMap, aMapWithSize(2)); assertThat(requestMap.get("input"), is(List.of(INPUT.substring(0, 2)))); assertThat(requestMap.get("model"), is(MODEL_ID)); } @@ -506,7 +507,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat(requestMap.size(), is(2)); + assertThat(requestMap, aMapWithSize(2)); assertThat(requestMap.get("input"), is(List.of(INPUT))); assertThat(requestMap.get("model"), is(MODEL_ID)); } @@ -519,7 +520,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC assertThat(webServer.requests().get(1).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); var requestMap = entityAsMap(webServer.requests().get(1).getBody()); - assertThat(requestMap.size(), is(2)); + assertThat(requestMap, aMapWithSize(2)); assertThat(requestMap.get("input"), is(List.of(INPUT.substring(0, 2)))); assertThat(requestMap.get("model"), is(MODEL_ID)); } @@ -582,7 +583,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat(requestMap.size(), is(2)); + assertThat(requestMap, aMapWithSize(2)); assertThat(requestMap.get("input"), is(List.of(INPUT_TO_TRUNCATE.substring(0, 3)))); assertThat(requestMap.get("model"), is(MODEL_ID)); } @@ -961,6 +962,6 @@ private void assertRerankActionCreator( assertThat(requestMap.get("return_documents"), is(expectedReturnDocuments)); fieldCount++; } - assertThat(requestMap.size(), is(fieldCount)); + assertThat(requestMap, aMapWithSize(fieldCount)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java index 3ec6bfd3b1974..9bcebce08a5c7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java @@ -18,6 +18,7 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; @@ -45,6 +46,7 @@ public void testCreateRequest_WithStreaming() throws IOException { assertThat(requestMap.get("n"), is(1)); assertThat(requestMap.get("stream_options"), is(nullValue())); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", USER_ROLE, "content", input)))); + assertThat(requestMap, aMapWithSize(4)); assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer %s".formatted(API_KEY))); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java index 53246118ac1f8..b5838563cb769 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java @@ -27,9 +27,8 @@ public class OpenShiftAiRerankRequestTests extends ESTestCase { private static final String QUERY = "query"; private static final String MODEL_ID = "modelId"; private static final Integer TOP_N = 8; - private static final Boolean RETURN_TEXT = false; - - private static final String AUTH_HEADER_VALUE = "Bearer secret"; + private static final Boolean RETURN_DOCUMENTS = false; + private static final String API_KEY = "secret"; public void testCreateRequest_WithMinimalFieldsSet() throws IOException { testCreateRequest(null, null, null, createRequest(null, null, null)); @@ -40,7 +39,7 @@ public void testCreateRequest_WithTopN() throws IOException { } public void testCreateRequest_WithReturnDocuments() throws IOException { - testCreateRequest(null, RETURN_TEXT, null, createRequest(null, RETURN_TEXT, null)); + testCreateRequest(null, RETURN_DOCUMENTS, null, createRequest(null, RETURN_DOCUMENTS, null)); } public void testCreateRequest_WithModelId() throws IOException { @@ -48,42 +47,46 @@ public void testCreateRequest_WithModelId() throws IOException { } public void testCreateRequest_AllFields() throws IOException { - testCreateRequest(TOP_N, RETURN_TEXT, MODEL_ID, createRequest(TOP_N, RETURN_TEXT, MODEL_ID)); + testCreateRequest(TOP_N, RETURN_DOCUMENTS, MODEL_ID, createRequest(TOP_N, RETURN_DOCUMENTS, MODEL_ID)); } public void testCreateRequest_AllFields_OverridesTaskSettings() throws IOException { - testCreateRequest(TOP_N, RETURN_TEXT, MODEL_ID, createRequestWithDifferentTaskSettings(TOP_N, RETURN_TEXT)); + testCreateRequest(TOP_N, RETURN_DOCUMENTS, MODEL_ID, createRequestWithDifferentTaskSettings(TOP_N, RETURN_DOCUMENTS)); } public void testCreateRequest_AllFields_KeepsTaskSettings() throws IOException { testCreateRequest(1, true, MODEL_ID, createRequestWithDifferentTaskSettings(null, null)); } - private void testCreateRequest(Integer topN, Boolean returnDocuments, String modelId, OpenShiftAiRerankRequest request) - throws IOException { + private void testCreateRequest( + Integer expectedTopN, + Boolean expectedReturnDocuments, + String expectedModelId, + OpenShiftAiRerankRequest request + ) throws IOException { var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) httpRequest.httpRequestBase(); assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaTypeWithoutParameters())); - assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer %s".formatted(API_KEY))); var requestMap = entityAsMap(httpPost.getEntity().getContent()); assertThat(requestMap.get(INPUT), is(List.of(INPUT))); assertThat(requestMap.get(QUERY), is(QUERY)); int itemsCount = 2; - if (topN != null) { - assertThat(requestMap.get("top_n"), is(topN)); + if (expectedTopN != null) { + assertThat(requestMap.get("top_n"), is(expectedTopN)); itemsCount++; } - if (returnDocuments != null) { - assertThat(requestMap.get("return_documents"), is(returnDocuments)); + if (expectedReturnDocuments != null) { + assertThat(requestMap.get("return_documents"), is(expectedReturnDocuments)); itemsCount++; } - if (modelId != null) { - assertThat(requestMap.get("model"), is(modelId)); + if (expectedModelId != null) { + assertThat(requestMap.get("model"), is(expectedModelId)); itemsCount++; } assertThat(requestMap, aMapWithSize(itemsCount)); @@ -94,7 +97,7 @@ private static OpenShiftAiRerankRequest createRequest( @Nullable Boolean returnDocuments, @Nullable String modelId ) { - var rerankModel = OpenShiftAiRerankModelTests.createModel(randomAlphaOfLength(10), "secret", modelId, topN, returnDocuments); + var rerankModel = OpenShiftAiRerankModelTests.createModel(randomAlphaOfLength(10), API_KEY, modelId, topN, returnDocuments); return new OpenShiftAiRerankRequest(QUERY, List.of(INPUT), returnDocuments, topN, rerankModel); } @@ -102,7 +105,7 @@ private static OpenShiftAiRerankRequest createRequestWithDifferentTaskSettings( @Nullable Integer topN, @Nullable Boolean returnDocuments ) { - var rerankModel = OpenShiftAiRerankModelTests.createModel(randomAlphaOfLength(10), "secret", MODEL_ID, 1, true); + var rerankModel = OpenShiftAiRerankModelTests.createModel(randomAlphaOfLength(10), API_KEY, MODEL_ID, 1, true); return new OpenShiftAiRerankRequest(QUERY, List.of(INPUT), returnDocuments, topN, rerankModel); } } From 73e75dc8fa0beab9296b93caa6ac99802d4b6ed7 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Sat, 8 Nov 2025 19:38:41 +0200 Subject: [PATCH 61/70] Refactor OpenShift AI test constants for improved clarity and consistency --- .../OpenShiftAiChatCompletionModelTests.java | 2 +- ...iftAiChatCompletionRequestEntityTests.java | 9 ++- ...OpenShiftAiChatCompletionRequestTests.java | 14 ++--- ...enShiftAiEmbeddingsRequestEntityTests.java | 23 ++++--- .../OpenShiftAiEmbeddingsRequestTests.java | 39 +++++++----- .../OpenShiftAIRerankRequestEntityTests.java | 20 +++--- .../rarank/OpenShiftAiRerankRequestTests.java | 61 ++++++++++++------- 7 files changed, 101 insertions(+), 67 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java index 6604d473e945a..8eccbb0b2b080 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java @@ -19,7 +19,7 @@ public class OpenShiftAiChatCompletionModelTests extends ESTestCase { private static final String MODEL_ID = "model_name"; private static final String API_KEY = "api_key"; - private static final String URL = "url"; + private static final String URL = "some_url"; public static OpenShiftAiChatCompletionModel createCompletionModel(String url, String apiKey, String modelName) { return createModelWithTaskType(url, apiKey, modelName, TaskType.COMPLETION); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntityTests.java index 50c1135644bc3..000d9876b05c7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntityTests.java @@ -22,7 +22,7 @@ import static org.hamcrest.Matchers.is; public class OpenShiftAiChatCompletionRequestEntityTests extends ESTestCase { - private static final String ROLE = "user"; + private static final String USER_ROLE_VALUE = "user"; public void testSerializationWithModelIdStreaming() throws IOException { testSerialization("modelId", true, """ @@ -83,7 +83,12 @@ public void testSerializationWithoutModelIdNonStreaming() throws IOException { } private static void testSerialization(String modelId, boolean isStreaming, String expectedJson) throws IOException { - var message = new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("Hello, world!"), ROLE, null, null); + var message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + USER_ROLE_VALUE, + null, + null + ); var messageList = new ArrayList(); messageList.add(message); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java index 9bcebce08a5c7..76e826f873d8d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java @@ -26,23 +26,23 @@ public class OpenShiftAiChatCompletionRequestTests extends ESTestCase { - private static final String URL = "url"; - private static final String MODEL_ID = "model"; + private static final String URL_VALUE = "some_url"; + private static final String MODEL_VALUE = "some model"; private static final String USER_ROLE = "user"; private static final String API_KEY = "secret"; public void testCreateRequest_WithStreaming() throws IOException { String input = randomAlphaOfLength(15); - var request = createRequest(MODEL_ID, URL, API_KEY, input, true); + var request = createRequest(MODEL_VALUE, URL_VALUE, API_KEY, input, true); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) httpRequest.httpRequestBase(); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(request.getURI().toString(), is(URL)); + assertThat(request.getURI().toString(), is(URL_VALUE)); assertThat(requestMap.get("stream"), is(true)); - assertThat(requestMap.get("model"), is(MODEL_ID)); + assertThat(requestMap.get("model"), is(MODEL_VALUE)); assertThat(requestMap.get("n"), is(1)); assertThat(requestMap.get("stream_options"), is(nullValue())); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", USER_ROLE, "content", input)))); @@ -52,12 +52,12 @@ public void testCreateRequest_WithStreaming() throws IOException { public void testTruncate_DoesNotReduceInputTextSize() { String input = randomAlphaOfLength(5); - var request = createRequest(MODEL_ID, URL, API_KEY, input, true); + var request = createRequest(MODEL_VALUE, URL_VALUE, API_KEY, input, true); assertThat(request.truncate(), is(sameInstance(request))); } public void testTruncationInfo_ReturnsNull() { - var request = createRequest(MODEL_ID, URL, API_KEY, randomAlphaOfLength(5), true); + var request = createRequest(MODEL_VALUE, URL_VALUE, API_KEY, randomAlphaOfLength(5), true); assertThat(request.getTruncationInfo(), is(nullValue())); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntityTests.java index c46de709f7eec..afa06a834fa6b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntityTests.java @@ -21,18 +21,21 @@ public class OpenShiftAiEmbeddingsRequestEntityTests extends ESTestCase { + private static final String MODEL = "some model"; + private static final String INPUT = "some input"; + public void testXContent_DoesNotWriteDimensionsWhenNullAndSetByUserIsFalse() throws IOException { - var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of("abc"), "model", null, false); + var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of(INPUT), MODEL, null, false); testXContent_DoesNotWriteDimensions(entity); } public void testXContent_DoesNotWriteDimensionsWhenNotSetByUser() throws IOException { - var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of("abc"), "model", 100, false); + var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of(INPUT), MODEL, 100, false); testXContent_DoesNotWriteDimensions(entity); } public void testXContent_DoesNotWriteDimensionsWhenNull_EvenIfSetByUserIsTrue() throws IOException { - var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of("abc"), "model", null, true); + var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of(INPUT), MODEL, null, true); testXContent_DoesNotWriteDimensions(entity); } @@ -43,14 +46,14 @@ private static void testXContent_DoesNotWriteDimensions(OpenShiftAiEmbeddingsReq assertThat(xContentResult, is(XContentHelper.stripWhitespace(""" { - "input": ["abc"], - "model": "model" + "input": ["some input"], + "model": "some model" } """))); } public void testXContent_DoesNotWriteModelWhenItIsNull() throws IOException { - var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of("abc"), null, null, false); + var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of(INPUT), null, null, false); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -58,13 +61,13 @@ public void testXContent_DoesNotWriteModelWhenItIsNull() throws IOException { assertThat(xContentResult, is(XContentHelper.stripWhitespace(""" { - "input": ["abc"] + "input": ["some input"] } """))); } public void testXContent_WritesDimensionsWhenNonNull_AndSetByUserIsTrue() throws IOException { - var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of("abc"), "model", 100, true); + var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of(INPUT), MODEL, 100, true); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -72,8 +75,8 @@ public void testXContent_WritesDimensionsWhenNonNull_AndSetByUserIsTrue() throws assertThat(xContentResult, is(XContentHelper.stripWhitespace(""" { - "input": ["abc"], - "model": "model", + "input": ["some input"], + "model": "some model", "dimensions": 100 } """))); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java index f8acf0b2a4afe..704f8222ae014 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java @@ -27,6 +27,15 @@ public class OpenShiftAiEmbeddingsRequestTests extends ESTestCase { + private static final String INPUT_FIELD_NAME = "input"; + private static final String MODEL_FIELD_NAME = "model"; + private static final String DIMENSIONS_FIELD_NAME = "dimensions"; + + private static final String MODEL_VALUE = "some model"; + private static final String INPUT_VALUE = "ABCD"; + private static final String URL_VALUE = "some_url"; + private static final String API_KEY = "some api key"; + public void testCreateRequest_NoDimensions_DimensionsSetByUserFalse_Success() throws IOException { testCreateRequest_Success(null, false, null); } @@ -49,10 +58,10 @@ private void testCreateRequest_Success(Integer dimensions, boolean dimensionsSet var httpPost = validateRequestUrlAndContentType(httpRequest); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap.get("input"), is(List.of("ABCD"))); - assertThat(requestMap.get("model"), is("llama-embed")); - assertThat(requestMap.get("dimensions"), is(expectedDimensions)); - assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer apikey")); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT_VALUE))); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); + assertThat(requestMap.get(DIMENSIONS_FIELD_NAME), is(expectedDimensions)); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer %s".formatted(API_KEY))); } public void testCreateRequest_NoModel_Success() throws IOException { @@ -61,10 +70,10 @@ public void testCreateRequest_NoModel_Success() throws IOException { var httpPost = validateRequestUrlAndContentType(httpRequest); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap.get("input"), is(List.of("ABCD"))); - assertThat(requestMap.get("model"), is(nullValue())); - assertThat(requestMap.get("dimensions"), is(nullValue())); - assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer apikey")); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT_VALUE))); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(nullValue())); + assertThat(requestMap.get(DIMENSIONS_FIELD_NAME), is(nullValue())); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer %s".formatted(API_KEY))); } @@ -78,8 +87,8 @@ public void testTruncate_ReducesInputTextSizeByHalf() throws IOException { var httpPost = (HttpPost) httpRequest.httpRequestBase(); var requestMap = entityAsMap(httpPost.getEntity().getContent()); assertThat(requestMap, aMapWithSize(2)); - assertThat(requestMap.get("input"), is(List.of("AB"))); - assertThat(requestMap.get("model"), is("llama-embed")); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT_VALUE.substring(0, 2)))); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); } @@ -94,19 +103,19 @@ public void testIsTruncated_ReturnsTrue() { private HttpPost validateRequestUrlAndContentType(HttpRequest request) { assertThat(request.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) request.httpRequestBase(); - assertThat(httpPost.getURI().toString(), is("url")); + assertThat(httpPost.getURI().toString(), is(URL_VALUE)); assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaTypeWithoutParameters())); return httpPost; } private static OpenShiftAiEmbeddingsRequest createRequest(Integer dimensions, Boolean dimensionsSetByUser) { - return createRequest(dimensions, dimensionsSetByUser, "llama-embed"); + return createRequest(dimensions, dimensionsSetByUser, MODEL_VALUE); } private static OpenShiftAiEmbeddingsRequest createRequest(Integer dimensions, Boolean dimensionsSetByUser, String modelId) { var embeddingsModel = OpenShiftAiEmbeddingsModelTests.createModel( - "url", - "apikey", + URL_VALUE, + API_KEY, modelId, dimensions, dimensionsSetByUser, @@ -115,7 +124,7 @@ private static OpenShiftAiEmbeddingsRequest createRequest(Integer dimensions, Bo ); return new OpenShiftAiEmbeddingsRequest( TruncatorTests.createTruncator(), - new Truncator.TruncationResult(List.of("ABCD"), new boolean[] { false }), + new Truncator.TruncationResult(List.of(INPUT_VALUE), new boolean[] { false }), embeddingsModel ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntityTests.java index 35637228a2776..f3deea999e5b7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntityTests.java @@ -21,23 +21,23 @@ import static org.hamcrest.Matchers.is; public class OpenShiftAIRerankRequestEntityTests extends ESTestCase { - private static final String INPUT = "documents"; - private static final String QUERY = "query"; - private static final String MODEL = "model"; + private static final String DOCUMENT = "some document"; + private static final String QUERY = "some query"; + private static final String MODEL = "some model"; private static final Integer TOP_N = 8; private static final Boolean RETURN_DOCUMENTS = true; public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { - var entity = new OpenShiftAIRerankRequestEntity(MODEL, QUERY, List.of(INPUT), RETURN_DOCUMENTS, TOP_N); + var entity = new OpenShiftAIRerankRequestEntity(MODEL, QUERY, List.of(DOCUMENT), RETURN_DOCUMENTS, TOP_N); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); String result = Strings.toString(builder); String expected = """ { - "model": "model", - "query": "query", - "documents": ["documents"], + "model": "some model", + "query": "some query", + "documents": ["some document"], "top_n": 8, "return_documents": true } @@ -46,15 +46,15 @@ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException } public void testXContent_WritesMinimalFields() throws IOException { - var entity = new OpenShiftAIRerankRequestEntity(null, QUERY, List.of(INPUT), null, null); + var entity = new OpenShiftAIRerankRequestEntity(null, QUERY, List.of(DOCUMENT), null, null); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); String result = Strings.toString(builder); String expected = """ { - "query": "query", - "documents": ["documents"] + "query": "some query", + "documents": ["some document"] } """; assertThat(stripWhitespace(expected), is(result)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java index b5838563cb769..2922f8e6bc2cd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java @@ -23,39 +23,56 @@ import static org.hamcrest.Matchers.is; public class OpenShiftAiRerankRequestTests extends ESTestCase { - private static final String INPUT = "documents"; - private static final String QUERY = "query"; - private static final String MODEL_ID = "modelId"; - private static final Integer TOP_N = 8; - private static final Boolean RETURN_DOCUMENTS = false; - private static final String API_KEY = "secret"; + + private static final String TOP_N_FIELD_NAME = "top_n"; + private static final String RETURN_DOCUMENTS_FIELD_NAME = "return_documents"; + private static final String MODEL_FIELD_NAME = "model"; + private static final String DOCUMENTS_FIELD_NAME = "documents"; + private static final String QUERY_FIELD_NAME = "query"; + + private static final String DOCUMENT_VALUE = "some document"; + private static final String QUERY_VALUE = "some query"; + private static final String MODEL_VALUE = "some model"; + private static final Integer TOP_N_VALUE = 8; + private static final Boolean RETURN_DOCUMENTS_VALUE = false; + private static final String API_KEY_VALUE = "some api key"; public void testCreateRequest_WithMinimalFieldsSet() throws IOException { testCreateRequest(null, null, null, createRequest(null, null, null)); } public void testCreateRequest_WithTopN() throws IOException { - testCreateRequest(TOP_N, null, null, createRequest(TOP_N, null, null)); + testCreateRequest(TOP_N_VALUE, null, null, createRequest(TOP_N_VALUE, null, null)); } public void testCreateRequest_WithReturnDocuments() throws IOException { - testCreateRequest(null, RETURN_DOCUMENTS, null, createRequest(null, RETURN_DOCUMENTS, null)); + testCreateRequest(null, RETURN_DOCUMENTS_VALUE, null, createRequest(null, RETURN_DOCUMENTS_VALUE, null)); } public void testCreateRequest_WithModelId() throws IOException { - testCreateRequest(null, null, MODEL_ID, createRequest(null, null, MODEL_ID)); + testCreateRequest(null, null, MODEL_VALUE, createRequest(null, null, MODEL_VALUE)); } public void testCreateRequest_AllFields() throws IOException { - testCreateRequest(TOP_N, RETURN_DOCUMENTS, MODEL_ID, createRequest(TOP_N, RETURN_DOCUMENTS, MODEL_ID)); + testCreateRequest( + TOP_N_VALUE, + RETURN_DOCUMENTS_VALUE, + MODEL_VALUE, + createRequest(TOP_N_VALUE, RETURN_DOCUMENTS_VALUE, MODEL_VALUE) + ); } public void testCreateRequest_AllFields_OverridesTaskSettings() throws IOException { - testCreateRequest(TOP_N, RETURN_DOCUMENTS, MODEL_ID, createRequestWithDifferentTaskSettings(TOP_N, RETURN_DOCUMENTS)); + testCreateRequest( + TOP_N_VALUE, + RETURN_DOCUMENTS_VALUE, + MODEL_VALUE, + createRequestWithDifferentTaskSettings(TOP_N_VALUE, RETURN_DOCUMENTS_VALUE) + ); } public void testCreateRequest_AllFields_KeepsTaskSettings() throws IOException { - testCreateRequest(1, true, MODEL_ID, createRequestWithDifferentTaskSettings(null, null)); + testCreateRequest(1, true, MODEL_VALUE, createRequestWithDifferentTaskSettings(null, null)); } private void testCreateRequest( @@ -70,23 +87,23 @@ private void testCreateRequest( var httpPost = (HttpPost) httpRequest.httpRequestBase(); assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaTypeWithoutParameters())); - assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer %s".formatted(API_KEY))); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer %s".formatted(API_KEY_VALUE))); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap.get(INPUT), is(List.of(INPUT))); - assertThat(requestMap.get(QUERY), is(QUERY)); + assertThat(requestMap.get(DOCUMENTS_FIELD_NAME), is(List.of(DOCUMENT_VALUE))); + assertThat(requestMap.get(QUERY_FIELD_NAME), is(QUERY_VALUE)); int itemsCount = 2; if (expectedTopN != null) { - assertThat(requestMap.get("top_n"), is(expectedTopN)); + assertThat(requestMap.get(TOP_N_FIELD_NAME), is(expectedTopN)); itemsCount++; } if (expectedReturnDocuments != null) { - assertThat(requestMap.get("return_documents"), is(expectedReturnDocuments)); + assertThat(requestMap.get(RETURN_DOCUMENTS_FIELD_NAME), is(expectedReturnDocuments)); itemsCount++; } if (expectedModelId != null) { - assertThat(requestMap.get("model"), is(expectedModelId)); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(expectedModelId)); itemsCount++; } assertThat(requestMap, aMapWithSize(itemsCount)); @@ -97,15 +114,15 @@ private static OpenShiftAiRerankRequest createRequest( @Nullable Boolean returnDocuments, @Nullable String modelId ) { - var rerankModel = OpenShiftAiRerankModelTests.createModel(randomAlphaOfLength(10), API_KEY, modelId, topN, returnDocuments); - return new OpenShiftAiRerankRequest(QUERY, List.of(INPUT), returnDocuments, topN, rerankModel); + var rerankModel = OpenShiftAiRerankModelTests.createModel(randomAlphaOfLength(10), API_KEY_VALUE, modelId, topN, returnDocuments); + return new OpenShiftAiRerankRequest(QUERY_VALUE, List.of(DOCUMENT_VALUE), returnDocuments, topN, rerankModel); } private static OpenShiftAiRerankRequest createRequestWithDifferentTaskSettings( @Nullable Integer topN, @Nullable Boolean returnDocuments ) { - var rerankModel = OpenShiftAiRerankModelTests.createModel(randomAlphaOfLength(10), API_KEY, MODEL_ID, 1, true); - return new OpenShiftAiRerankRequest(QUERY, List.of(INPUT), returnDocuments, topN, rerankModel); + var rerankModel = OpenShiftAiRerankModelTests.createModel(randomAlphaOfLength(10), API_KEY_VALUE, MODEL_VALUE, 1, true); + return new OpenShiftAiRerankRequest(QUERY_VALUE, List.of(DOCUMENT_VALUE), returnDocuments, topN, rerankModel); } } From 1de3b0646a940cb8db4a2401461fe82faad62037 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Mon, 10 Nov 2025 20:26:40 +0200 Subject: [PATCH 62/70] Refactor OpenShift AI action creator tests for improved readability and consistency --- .../action/OpenShiftAiActionCreatorTests.java | 408 ++++++++---------- 1 file changed, 187 insertions(+), 221 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index 5765d6505f033..43bbdc5a68644 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockRequest; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -42,7 +43,6 @@ import java.util.List; import java.util.Map; -import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests.buildExpectationRerank; @@ -65,17 +65,49 @@ public class OpenShiftAiActionCreatorTests extends ESTestCase { - private static final String MODEL_ID = "model"; - private static final String API_KEY = "secret"; - private static final String QUERY = "popular name"; - private static final String USER_ROLE = "user"; + // Completion field names + private static final String N_FIELD_NAME = "n"; + private static final String STREAM_FIELD_NAME = "stream"; + private static final String MESSAGES_FIELD_NAME = "messages"; + private static final String ROLE_FIELD_NAME = "role"; + private static final String CONTENT_FIELD_NAME = "content"; + + // Rerank field names + private static final String DOCUMENTS_FIELD_NAME = "documents"; + private static final String MODEL_FIELD_NAME = "model"; + private static final String QUERY_FIELD_NAME = "query"; + private static final String TOP_N_FIELD_NAME = "top_n"; + private static final String RETURN_DOCUMENTS_FIELD_NAME = "return_documents"; + + // Embeddings field names + private static final String INPUT_FIELD_NAME = "input"; + + // Test values + private static final String API_KEY = "test-api-key"; + private static final String MODEL_TEST_VALUE = "some model"; + private static final String QUERY_TEST_VALUE = "popular name"; + private static final String ROLE_TEST_VALUE = "user"; private static final String INPUT = "abcd"; + private static final List INPUT_TEST_VALUE = List.of(INPUT); + private static final String INPUT_TO_TRUNCATE = "super long input"; + private static final List EMBEDDINGS_TEST_VALUE = List.of(new float[] { 0.0123F, -0.0123F }); + private static final List DOCUMENTS_TEST_VALUE = List.of("Luke"); + private static final List RERANK_EXPECTATIONS_WITH_TEXT_TWO_RESULTS = List.of( + new RankedDocsResultsTests.RerankExpectation(Map.of("text", "awgawgawgawg", "index", 1, "relevance_score", 0.9921875f)), + new RankedDocsResultsTests.RerankExpectation(Map.of("text", "awdawdawda", "index", 0, "relevance_score", 0.4921875f)) + ); + private static final List RERANK_EXPECTATIONS_NO_TEXT_SINGLE_RESULT = List.of( + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 0, "relevance_score", 0.4921875f)) + ); + + // Settings with no retries private static final Settings NO_RETRY_SETTINGS = buildSettingsWithRetryFields( TimeValue.timeValueMillis(1), TimeValue.timeValueMinutes(1), TimeValue.timeValueSeconds(0) ); - private static final String INPUT_TO_TRUNCATE = "super long input"; + + // Mock server and client manager private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; @@ -126,37 +158,39 @@ public void testCreate_OpenShiftAiEmbeddingsModel() throws IOException { """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createModel(getUrl(webServer), API_KEY, MODEL_ID); + var model = createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); var action = actionCreator.create(model); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(INPUT), InputTypeTests.randomWithNull()), + new EmbeddingsInput(INPUT_TEST_VALUE, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); - assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); + assertThat(result.asMap(), is(buildExpectationFloat(EMBEDDINGS_TEST_VALUE))); assertThat(webServer.requests(), hasSize(1)); - assertThat(webServer.requests().getFirst().getUri().getQuery(), is(nullValue())); - assertThat( - webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), - equalTo(XContentType.JSON.mediaTypeWithoutParameters()) - ); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + var request = webServer.requests().getFirst(); + assertThat(request.getUri().getQuery(), is(nullValue())); + assertContentTypeAndAuthorization(request); + + var requestMap = entityAsMap(request.getBody()); assertThat(requestMap, aMapWithSize(2)); - assertThat(requestMap.get("input"), is(List.of(INPUT))); - assertThat(requestMap.get("model"), is(MODEL_ID)); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(INPUT_TEST_VALUE)); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_TEST_VALUE)); } } + private static void assertContentTypeAndAuthorization(MockRequest request) { + assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters())); + assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); + } + public void testCreate_OpenShiftAiEmbeddingsModel_FailsFromInvalidResponseFormat() throws IOException { - // timeout as zero for no retries var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); try (var sender = createSender(senderFactory)) { @@ -164,47 +198,50 @@ public void testCreate_OpenShiftAiEmbeddingsModel_FailsFromInvalidResponseFormat String responseJson = """ { - "object": "list", - "data_does_not_exist": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } + "id": "embd-45e6d99b97a645c0af96653598069cd9", + "object": "list", + "created": 1760085467, + "model": "gritlm-7b", + "data_does_not_exist": [ + { + "index": 0, + "object": "embedding", + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "usage": { + "prompt_tokens": 7, + "total_tokens": 7, + "completion_tokens": 0, + "prompt_tokens_details": null + } } """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createModel(getUrl(webServer), API_KEY, MODEL_ID); - var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); - var action = actionCreator.create(model); + var model = createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); + var action = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)).create(model); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(INPUT), InputTypeTests.randomWithNull()), + new EmbeddingsInput(INPUT_TEST_VALUE, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); - var failureCauseMessage = "Required [data]"; var thrownException = expectThrows( ElasticsearchStatusException.class, () -> listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT) ); + + var failureCauseMessage = "Required [data]"; assertThat( thrownException.getMessage(), is( - format( - "Failed to send OpenShift AI text_embedding request from inference entity id [inferenceEntityId]. Cause: %s", + "Failed to send OpenShift AI text_embedding request from inference entity id [inferenceEntityId]. Cause: %s".formatted( failureCauseMessage ) ) @@ -252,12 +289,11 @@ public void testCreate_OpenShiftAiChatCompletionModel() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createCompletionModel(getUrl(webServer), API_KEY, MODEL_ID); - var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); - var action = actionCreator.create(model); + var model = createCompletionModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); + var action = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)).create(model); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(List.of(INPUT)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(INPUT_TEST_VALUE), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); @@ -265,22 +301,20 @@ public void testCreate_OpenShiftAiChatCompletionModel() throws IOException { assertThat(webServer.requests(), hasSize(1)); var request = webServer.requests().getFirst(); + assertContentTypeAndAuthorization(request); - assertThat(request.getUri().getQuery(), is(nullValue())); - assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters())); - assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); - - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat(requestMap, aMapWithSize(4)); - assertThat(requestMap.get("messages"), is(List.of(Map.of("role", USER_ROLE, "content", INPUT)))); - assertThat(requestMap.get("model"), is(MODEL_ID)); - assertThat(requestMap.get("n"), is(1)); - assertThat(requestMap.get("stream"), is(false)); + var requestMap = entityAsMap(request.getBody()); + assertThat( + requestMap.get(MESSAGES_FIELD_NAME), + is(List.of(Map.of(ROLE_FIELD_NAME, ROLE_TEST_VALUE, CONTENT_FIELD_NAME, INPUT))) + ); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_TEST_VALUE)); + assertThat(requestMap.get(N_FIELD_NAME), is(1)); + assertThat(requestMap.get(STREAM_FIELD_NAME), is(false)); } } public void testCreate_OpenShiftAiChatCompletionModel_FailsFromInvalidResponseFormat() throws IOException { - // timeout as zero for no retries var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); try (var sender = createSender(senderFactory)) { @@ -318,23 +352,21 @@ public void testCreate_OpenShiftAiChatCompletionModel_FailsFromInvalidResponseFo """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createCompletionModel(getUrl(webServer), API_KEY, MODEL_ID); - var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); - var action = actionCreator.create(model); + var model = createCompletionModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); + var action = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)).create(model); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(List.of(INPUT)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(INPUT_TEST_VALUE), InferenceAction.Request.DEFAULT_TIMEOUT, listener); - var failureCauseMessage = "Required [choices]"; var thrownException = expectThrows( ElasticsearchStatusException.class, () -> listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT) ); + var failureCauseMessage = "Required [choices]"; assertThat( thrownException.getMessage(), is( - format( - "Failed to send OpenShift AI completion request from inference entity id [inferenceEntityId]. Cause: %s", + "Failed to send OpenShift AI completion request from inference entity id [inferenceEntityId]. Cause: %s".formatted( failureCauseMessage ) ) @@ -391,46 +423,38 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC webServer.enqueue(new MockResponse().setResponseCode(413).setBody(responseJsonContentTooLarge)); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createModel(getUrl(webServer), API_KEY, MODEL_ID); + var model = createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); var action = actionCreator.create(model); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(INPUT), InputTypeTests.randomWithNull()), + new EmbeddingsInput(INPUT_TEST_VALUE, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); - assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); + assertThat(result.asMap(), is(buildExpectationFloat(EMBEDDINGS_TEST_VALUE))); assertThat(webServer.requests(), hasSize(2)); { assertThat(webServer.requests().getFirst().getUri().getQuery(), is(nullValue())); - assertThat( - webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), - equalTo(XContentType.JSON.mediaTypeWithoutParameters()) - ); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); + assertContentTypeAndAuthorization(webServer.requests().getFirst()); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap, aMapWithSize(2)); - assertThat(requestMap.get("input"), is(List.of(INPUT))); - assertThat(requestMap.get("model"), is(MODEL_ID)); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(INPUT_TEST_VALUE)); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_TEST_VALUE)); } { assertThat(webServer.requests().get(1).getUri().getQuery(), is(nullValue())); - assertThat( - webServer.requests().get(1).getHeader(HttpHeaders.CONTENT_TYPE), - equalTo(XContentType.JSON.mediaTypeWithoutParameters()) - ); - assertThat(webServer.requests().get(1).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); + assertContentTypeAndAuthorization(webServer.requests().get(1)); var requestMap = entityAsMap(webServer.requests().get(1).getBody()); assertThat(requestMap, aMapWithSize(2)); - assertThat(requestMap.get("input"), is(List.of(INPUT.substring(0, 2)))); - assertThat(requestMap.get("model"), is(MODEL_ID)); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT.substring(0, 2)))); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_TEST_VALUE)); } } } @@ -456,7 +480,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC } """, contentTooLargeErrorMessage); - String responseJson = """ + var responseJson = """ { "id": "embd-45e6d99b97a645c0af96653598069cd9", "object": "list", @@ -483,46 +507,36 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC webServer.enqueue(new MockResponse().setResponseCode(400).setBody(responseJsonContentTooLarge)); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createModel(getUrl(webServer), API_KEY, MODEL_ID); - var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); - var action = actionCreator.create(model); + var model = createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); + var action = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)).create(model); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(INPUT), InputTypeTests.randomWithNull()), + new EmbeddingsInput(INPUT_TEST_VALUE, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); - assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); + assertThat(result.asMap(), is(buildExpectationFloat(EMBEDDINGS_TEST_VALUE))); assertThat(webServer.requests(), hasSize(2)); - { - assertThat(webServer.requests().getFirst().getUri().getQuery(), is(nullValue())); - assertThat( - webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), - equalTo(XContentType.JSON.mediaTypeWithoutParameters()) - ); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat(requestMap, aMapWithSize(2)); - assertThat(requestMap.get("input"), is(List.of(INPUT))); - assertThat(requestMap.get("model"), is(MODEL_ID)); + { + var firstRequest = webServer.requests().getFirst(); + assertContentTypeAndAuthorization(firstRequest); + var firstRequestMap = entityAsMap(firstRequest.getBody()); + assertThat(firstRequestMap, aMapWithSize(2)); + assertThat(firstRequestMap.get(INPUT_FIELD_NAME), is(INPUT_TEST_VALUE)); + assertThat(firstRequestMap.get(MODEL_FIELD_NAME), is(MODEL_TEST_VALUE)); } { - assertThat(webServer.requests().get(1).getUri().getQuery(), is(nullValue())); - assertThat( - webServer.requests().get(1).getHeader(HttpHeaders.CONTENT_TYPE), - equalTo(XContentType.JSON.mediaTypeWithoutParameters()) - ); - assertThat(webServer.requests().get(1).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); - - var requestMap = entityAsMap(webServer.requests().get(1).getBody()); - assertThat(requestMap, aMapWithSize(2)); - assertThat(requestMap.get("input"), is(List.of(INPUT.substring(0, 2)))); - assertThat(requestMap.get("model"), is(MODEL_ID)); + var secondRequest = webServer.requests().get(1); + assertContentTypeAndAuthorization(secondRequest); + var secondRequestMap = entityAsMap(secondRequest.getBody()); + assertThat(secondRequestMap, aMapWithSize(2)); + assertThat(secondRequestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT.substring(0, 2)))); + assertThat(secondRequestMap.get(MODEL_FIELD_NAME), is(MODEL_TEST_VALUE)); } } } @@ -533,7 +547,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { try (var sender = createSender(senderFactory)) { sender.startSynchronously(); - String responseJson = """ + var responseJson = """ { "id": "embd-45e6d99b97a645c0af96653598069cd9", "object": "list", @@ -560,11 +574,10 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); // truncated to 1 token = 3 characters - var model = createModel(getUrl(webServer), API_KEY, MODEL_ID, 1); - var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); - var action = actionCreator.create(model); + var model = createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE, 1); + var action = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)).create(model); - PlainActionFuture listener = new PlainActionFuture<>(); + var listener = new PlainActionFuture(); action.execute( new EmbeddingsInput(List.of(INPUT_TO_TRUNCATE), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, @@ -573,27 +586,23 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); - assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); + assertThat(result.asMap(), is(buildExpectationFloat(EMBEDDINGS_TEST_VALUE))); assertThat(webServer.requests(), hasSize(1)); - assertThat(webServer.requests().getFirst().getUri().getQuery(), is(nullValue())); - assertThat( - webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), - equalTo(XContentType.JSON.mediaTypeWithoutParameters()) - ); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + var request = webServer.requests().getFirst(); + assertThat(request.getUri().getQuery(), is(nullValue())); + assertContentTypeAndAuthorization(request); + + var requestMap = entityAsMap(request.getBody()); assertThat(requestMap, aMapWithSize(2)); - assertThat(requestMap.get("input"), is(List.of(INPUT_TO_TRUNCATE.substring(0, 3)))); - assertThat(requestMap.get("model"), is(MODEL_ID)); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT_TO_TRUNCATE.substring(0, 3)))); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_TEST_VALUE)); } } public void testCreate_OpenShiftAiRerankModel_WithTaskSettings() throws IOException { - // timeout as zero for no retries var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); - List documents = List.of("Luke"); try (var sender = createSender(senderFactory)) { sender.startSynchronously(); @@ -624,7 +633,7 @@ public void testCreate_OpenShiftAiRerankModel_WithTaskSettings() throws IOExcept """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_ID); + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); var actionCreator = new OpenShiftAiActionCreator( sender, new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) @@ -632,33 +641,22 @@ public void testCreate_OpenShiftAiRerankModel_WithTaskSettings() throws IOExcept var action = actionCreator.create(model, null); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new QueryAndDocsInputs(QUERY, documents, null, null, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new QueryAndDocsInputs(QUERY_TEST_VALUE, DOCUMENTS_TEST_VALUE, null, null, false), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); - assertThat( - result.asMap(), - is( - buildExpectationRerank( - List.of( - new RankedDocsResultsTests.RerankExpectation( - Map.of("text", "awgawgawgawg", "index", 1, "relevance_score", 0.9921875f) - ), - new RankedDocsResultsTests.RerankExpectation( - Map.of("text", "awdawdawda", "index", 0, "relevance_score", 0.4921875f) - ) - ) - ) - ) - ); + assertThat(result.asMap(), is(buildExpectationRerank(RERANK_EXPECTATIONS_WITH_TEXT_TWO_RESULTS))); } - assertRerankActionCreator(documents, 2, true); + assertRerankActionCreator(DOCUMENTS_TEST_VALUE, 2, true); } public void testCreate_OpenShiftAiRerankModel_WithOverriddenTaskSettings() throws IOException { - // timeout as zero for no retries var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); - List documents = List.of("Luke"); + List documents = DOCUMENTS_TEST_VALUE; try (var sender = createSender(senderFactory)) { sender.startSynchronously(); @@ -679,7 +677,7 @@ public void testCreate_OpenShiftAiRerankModel_WithOverriddenTaskSettings() throw """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_ID); + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); var actionCreator = new OpenShiftAiActionCreator( sender, new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) @@ -690,26 +688,22 @@ public void testCreate_OpenShiftAiRerankModel_WithOverriddenTaskSettings() throw ); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new QueryAndDocsInputs(QUERY, documents, null, null, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new QueryAndDocsInputs(QUERY_TEST_VALUE, documents, null, null, false), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); - assertThat( - result.asMap(), - is( - buildExpectationRerank( - List.of(new RankedDocsResultsTests.RerankExpectation(Map.of("index", 0, "relevance_score", 0.4921875f))) - ) - ) - ); + assertThat(result.asMap(), is(buildExpectationRerank(RERANK_EXPECTATIONS_NO_TEXT_SINGLE_RESULT))); } assertRerankActionCreator(documents, 1, false); } public void testCreate_OpenShiftAiRerankModel_NoTaskSettings() throws IOException { - // timeout as zero for no retries var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); - List documents = List.of("Luke"); + List documents = DOCUMENTS_TEST_VALUE; try (var sender = createSender(senderFactory)) { sender.startSynchronously(); @@ -740,7 +734,7 @@ public void testCreate_OpenShiftAiRerankModel_NoTaskSettings() throws IOExceptio """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_ID, null, null); + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE, null, null); var actionCreator = new OpenShiftAiActionCreator( sender, new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) @@ -748,33 +742,21 @@ public void testCreate_OpenShiftAiRerankModel_NoTaskSettings() throws IOExceptio var action = actionCreator.create(model, null); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new QueryAndDocsInputs(QUERY, documents, null, null, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new QueryAndDocsInputs(QUERY_TEST_VALUE, documents, null, null, false), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); - assertThat( - result.asMap(), - is( - buildExpectationRerank( - List.of( - new RankedDocsResultsTests.RerankExpectation( - Map.of("text", "awgawgawgawg", "index", 1, "relevance_score", 0.9921875f) - ), - new RankedDocsResultsTests.RerankExpectation( - Map.of("text", "awdawdawda", "index", 0, "relevance_score", 0.4921875f) - ) - ) - ) - ) - ); + assertThat(result.asMap(), is(buildExpectationRerank(RERANK_EXPECTATIONS_WITH_TEXT_TWO_RESULTS))); } assertRerankActionCreator(documents, null, null); } public void testCreate_OpenShiftAiRerankModel_NoTaskSettings_WithRequestParameters() throws IOException { - // timeout as zero for no retries var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); - List documents = List.of("Luke"); try (var sender = createSender(senderFactory)) { sender.startSynchronously(); @@ -805,7 +787,7 @@ public void testCreate_OpenShiftAiRerankModel_NoTaskSettings_WithRequestParamete """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_ID, null, null); + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE, null, null); var actionCreator = new OpenShiftAiActionCreator( sender, new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) @@ -813,33 +795,22 @@ public void testCreate_OpenShiftAiRerankModel_NoTaskSettings_WithRequestParamete var action = actionCreator.create(model, null); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new QueryAndDocsInputs(QUERY, documents, true, 2, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new QueryAndDocsInputs(QUERY_TEST_VALUE, DOCUMENTS_TEST_VALUE, true, 2, false), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); - assertThat( - result.asMap(), - is( - buildExpectationRerank( - List.of( - new RankedDocsResultsTests.RerankExpectation( - Map.of("text", "awgawgawgawg", "index", 1, "relevance_score", 0.9921875f) - ), - new RankedDocsResultsTests.RerankExpectation( - Map.of("text", "awdawdawda", "index", 0, "relevance_score", 0.4921875f) - ) - ) - ) - ) - ); + assertThat(result.asMap(), is(buildExpectationRerank(RERANK_EXPECTATIONS_WITH_TEXT_TWO_RESULTS))); } - assertRerankActionCreator(documents, 2, true); + assertRerankActionCreator(DOCUMENTS_TEST_VALUE, 2, true); } public void testCreate_OpenShiftAiRerankModel_WithTaskSettings_WithRequestParametersPrioritized() throws IOException { - // timeout as zero for no retries var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); - List documents = List.of("Luke"); + List documents = DOCUMENTS_TEST_VALUE; try (var sender = createSender(senderFactory)) { sender.startSynchronously(); @@ -860,7 +831,7 @@ public void testCreate_OpenShiftAiRerankModel_WithTaskSettings_WithRequestParame """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_ID); + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); var actionCreator = new OpenShiftAiActionCreator( sender, new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) @@ -868,26 +839,21 @@ public void testCreate_OpenShiftAiRerankModel_WithTaskSettings_WithRequestParame var action = actionCreator.create(model, null); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new QueryAndDocsInputs(QUERY, documents, false, 1, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new QueryAndDocsInputs(QUERY_TEST_VALUE, documents, false, 1, false), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); - assertThat( - result.asMap(), - is( - buildExpectationRerank( - List.of(new RankedDocsResultsTests.RerankExpectation(Map.of("index", 0, "relevance_score", 0.4921875f))) - ) - ) - ); + assertThat(result.asMap(), is(buildExpectationRerank(RERANK_EXPECTATIONS_NO_TEXT_SINGLE_RESULT))); } assertRerankActionCreator(documents, 1, false); } public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() throws IOException { - // timeout as zero for no retries var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); - List documents = List.of("Luke"); try (var sender = createSender(senderFactory)) { sender.startSynchronously(); @@ -918,7 +884,7 @@ public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() t """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_ID); + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); var actionCreator = new OpenShiftAiActionCreator( sender, new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) @@ -926,7 +892,11 @@ public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() t var action = actionCreator.create(model, null); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new QueryAndDocsInputs(QUERY, documents, null, null, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new QueryAndDocsInputs(QUERY_TEST_VALUE, DOCUMENTS_TEST_VALUE, null, null, false), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT)); assertThat( @@ -943,23 +913,19 @@ private void assertRerankActionCreator( ) throws IOException { assertThat(webServer.requests(), hasSize(1)); assertThat(webServer.requests().getFirst().getUri().getQuery(), is(nullValue())); - assertThat( - webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), - equalTo(XContentType.JSON.mediaTypeWithoutParameters()) - ); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); + assertContentTypeAndAuthorization(webServer.requests().getFirst()); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); int fieldCount = 3; - assertThat(requestMap.get("documents"), is(documents)); - assertThat(requestMap.get("model"), is(MODEL_ID)); - assertThat(requestMap.get("query"), is(QUERY)); + assertThat(requestMap.get(DOCUMENTS_FIELD_NAME), is(documents)); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_TEST_VALUE)); + assertThat(requestMap.get(QUERY_FIELD_NAME), is(QUERY_TEST_VALUE)); if (expectedTopN != null) { - assertThat(requestMap.get("top_n"), is(expectedTopN)); + assertThat(requestMap.get(TOP_N_FIELD_NAME), is(expectedTopN)); fieldCount++; } if (expectedReturnDocuments != null) { - assertThat(requestMap.get("return_documents"), is(expectedReturnDocuments)); + assertThat(requestMap.get(RETURN_DOCUMENTS_FIELD_NAME), is(expectedReturnDocuments)); fieldCount++; } assertThat(requestMap, aMapWithSize(fieldCount)); From 9a7586f10fef0e121e2b938114de8e2eed8fa0d7 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Mon, 10 Nov 2025 23:24:48 +0200 Subject: [PATCH 63/70] Add DIMENSIONS_SET_BY_USER constant and refactor variable names for clarity in OpenShift AI integration --- .../inference/services/ServiceFields.java | 1 + .../OpenShiftAiEmbeddingsServiceSettings.java | 2 +- .../openshiftai/OpenShiftAiServiceTests.java | 124 ++++--- .../action/OpenShiftAiActionCreatorTests.java | 120 ++++--- .../OpenShiftAiChatCompletionModelTests.java | 21 +- ...tAiChatCompletionResponseHandlerTests.java | 6 +- ...tAiChatCompletionServiceSettingsTests.java | 30 +- ...ShiftAiEmbeddingsServiceSettingsTests.java | 328 +++++++++--------- ...iftAiChatCompletionRequestEntityTests.java | 4 +- ...OpenShiftAiChatCompletionRequestTests.java | 36 +- ...enShiftAiEmbeddingsRequestEntityTests.java | 19 +- .../OpenShiftAiEmbeddingsRequestTests.java | 23 +- .../OpenShiftAIRerankRequestEntityTests.java | 16 +- .../rarank/OpenShiftAiRerankRequestTests.java | 88 +++-- .../rerank/OpenShiftAiRerankModelTests.java | 24 +- 15 files changed, 458 insertions(+), 384 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java index 1af79a69839ac..062ef0f67059c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java @@ -14,6 +14,7 @@ public final class ServiceFields { public static final String SIMILARITY = "similarity"; public static final String DIMENSIONS = "dimensions"; + public static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user"; // Typically we use this to define the maximum tokens for the input text (text being sent to an integration) public static final String MAX_INPUT_TOKENS = "max_input_tokens"; public static final String URL = "url"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java index 1000219633a64..b9343179db185 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettings.java @@ -27,6 +27,7 @@ import java.util.Objects; import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS_SET_BY_USER; import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; @@ -40,7 +41,6 @@ */ public class OpenShiftAiEmbeddingsServiceSettings extends OpenShiftAiServiceSettings { public static final String NAME = "openshift_ai_embeddings_service_settings"; - static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user"; private final Integer dimensions; private final SimilarityMeasure similarity; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java index 25b0ad9e9094d..ad0db05493aa5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java @@ -103,11 +103,17 @@ import static org.mockito.Mockito.mock; public class OpenShiftAiServiceTests extends AbstractInferenceServiceTests { - private static final String URL = "http://www.abc.com"; - private static final String MODEL_ID = "model_id"; - private static final String USER_ROLE = "user"; - private static final String API_KEY = "secret"; - private static final String INFERENCE_ID = "id"; + private static final String URL_VALUE = "http://www.abc.com"; + private static final String MODEL_VALUE = "some_model"; + private static final String ROLE_VALUE = "user"; + private static final String API_KEY_VALUE = "test_api_key"; + private static final String INFERENCE_ID_VALUE = "id"; + private static final String API_KEY_FIELD_NAME = "api_key"; + private static final int DIMENSIONS_VALUE = 1536; + private static final int MAX_INPUT_TOKENS_VALUE = 512; + private static final String INPUT_FIELD_NAME = "input"; + private static final String MODEL_FIELD_NAME = "model"; + private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; @@ -177,21 +183,21 @@ private static void assertTextEmbeddingModel(Model model, boolean modelIncludesS assertThat(openShiftAiModel.getTaskType(), is(TaskType.TEXT_EMBEDDING)); assertThat(model, instanceOf(OpenShiftAiEmbeddingsModel.class)); var embeddingsModel = (OpenShiftAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1536)); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(DIMENSIONS_VALUE)); assertThat(embeddingsModel.getServiceSettings().similarity(), is(SimilarityMeasure.COSINE)); - assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(MAX_INPUT_TOKENS_VALUE)); } private static OpenShiftAiModel assertCommonModelFields(Model model, boolean modelIncludesSecrets) { assertThat(model, instanceOf(OpenShiftAiModel.class)); var openShiftAiModel = (OpenShiftAiModel) model; - assertThat(openShiftAiModel.getServiceSettings().modelId(), is(MODEL_ID)); - assertThat(openShiftAiModel.getServiceSettings().uri.toString(), is(URL)); + assertThat(openShiftAiModel.getServiceSettings().modelId(), is(MODEL_VALUE)); + assertThat(openShiftAiModel.getServiceSettings().uri.toString(), is(URL_VALUE)); assertThat(openShiftAiModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); if (modelIncludesSecrets) { - assertThat(openShiftAiModel.getSecretSettings().apiKey(), is(new SecureString(API_KEY.toCharArray()))); + assertThat(openShiftAiModel.getSecretSettings().apiKey(), is(new SecureString(API_KEY_VALUE.toCharArray()))); } return openShiftAiModel; @@ -213,7 +219,7 @@ public static SenderService createService(ThreadPool threadPool, HttpClientManag } private static Map createServiceSettingsMap(TaskType taskType) { - Map settingsMap = new HashMap<>(Map.of(ServiceFields.URL, URL, ServiceFields.MODEL_ID, MODEL_ID)); + Map settingsMap = new HashMap<>(Map.of(ServiceFields.URL, URL_VALUE, ServiceFields.MODEL_ID, MODEL_VALUE)); if (taskType == TaskType.TEXT_EMBEDDING) { settingsMap.putAll( @@ -221,9 +227,9 @@ private static Map createServiceSettingsMap(TaskType taskType) { ServiceFields.SIMILARITY, SimilarityMeasure.COSINE.toString(), ServiceFields.DIMENSIONS, - 1536, + DIMENSIONS_VALUE, ServiceFields.MAX_INPUT_TOKENS, - 512 + MAX_INPUT_TOKENS_VALUE ) ); } @@ -232,17 +238,25 @@ private static Map createServiceSettingsMap(TaskType taskType) { } private static Map createSecretSettingsMap() { - return new HashMap<>(Map.of("api_key", API_KEY)); + return new HashMap<>(Map.of(API_KEY_FIELD_NAME, API_KEY_VALUE)); } private static OpenShiftAiEmbeddingsModel createInternalEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure) { return new OpenShiftAiEmbeddingsModel( - INFERENCE_ID, + INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, OpenShiftAiService.NAME, - new OpenShiftAiEmbeddingsServiceSettings(MODEL_ID, URL, 1536, similarityMeasure, 512, new RateLimitSettings(10_000), true), + new OpenShiftAiEmbeddingsServiceSettings( + MODEL_VALUE, + URL_VALUE, + DIMENSIONS_VALUE, + similarityMeasure, + MAX_INPUT_TOKENS_VALUE, + new RateLimitSettings(10_000), + true + ), createRandomChunkingSettings(), - new DefaultSecretSettings(new SecureString(API_KEY.toCharArray())) + new DefaultSecretSettings(new SecureString(API_KEY_VALUE.toCharArray())) ); } @@ -267,16 +281,20 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsP assertThat(model, instanceOf(OpenShiftAiEmbeddingsModel.class)); var embeddingsModel = (OpenShiftAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is(URL)); + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is(URL_VALUE)); assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); assertThat(embeddingsModel.getConfigurations().getChunkingSettings().asMap(), is(chunkingSettingsMap.asMap())); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY_VALUE)); }, e -> fail("parse request should not fail " + e.getMessage())); service.parseRequestConfig( - INFERENCE_ID, + INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, - getRequestConfigMap(getServiceSettingsMap(MODEL_ID, URL), chunkingSettingsMap.asMap(), getSecretSettingsMap(API_KEY)), + getRequestConfigMap( + getServiceSettingsMap(MODEL_VALUE, URL_VALUE), + chunkingSettingsMap.asMap(), + getSecretSettingsMap(API_KEY_VALUE) + ), modelVerificationActionListener ); } @@ -288,15 +306,15 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsN assertThat(model, instanceOf(OpenShiftAiEmbeddingsModel.class)); var embeddingsModel = (OpenShiftAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is(URL)); + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is(URL_VALUE)); assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), is(ChunkingSettingsBuilder.DEFAULT_SETTINGS)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY_VALUE)); }, e -> fail("parse request should not fail " + e.getMessage())); service.parseRequestConfig( - INFERENCE_ID, + INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, - getRequestConfigMap(getServiceSettingsMap(MODEL_ID, URL), getSecretSettingsMap(API_KEY)), + getRequestConfigMap(getServiceSettingsMap(MODEL_VALUE, URL_VALUE), getSecretSettingsMap(API_KEY_VALUE)), modelVerificationActionListener ); } @@ -309,16 +327,16 @@ public void testParseRequestConfig_WithoutModelId_Success() throws IOException { var chatCompletionModel = (OpenShiftAiChatCompletionModel) m; - assertThat(chatCompletionModel.getServiceSettings().uri().toString(), is(URL)); + assertThat(chatCompletionModel.getServiceSettings().uri().toString(), is(URL_VALUE)); assertThat(chatCompletionModel.getServiceSettings().modelId(), is(nullValue())); - assertThat(chatCompletionModel.getSecretSettings().apiKey().toString(), is(API_KEY)); + assertThat(chatCompletionModel.getSecretSettings().apiKey().toString(), is(API_KEY_VALUE)); }, e -> fail("parse request should not fail " + e.getMessage())); service.parseRequestConfig( - INFERENCE_ID, + INFERENCE_ID_VALUE, TaskType.CHAT_COMPLETION, - getRequestConfigMap(getServiceSettingsMap(null, URL), getSecretSettingsMap(API_KEY)), + getRequestConfigMap(getServiceSettingsMap(null, URL_VALUE), getSecretSettingsMap(API_KEY_VALUE)), modelVerificationListener ); } @@ -338,9 +356,9 @@ public void testParseRequestConfig_WithoutUrl_ThrowsException() throws IOExcepti ); service.parseRequestConfig( - INFERENCE_ID, + INFERENCE_ID_VALUE, TaskType.CHAT_COMPLETION, - getRequestConfigMap(getServiceSettingsMap(MODEL_ID, null), getSecretSettingsMap(API_KEY)), + getRequestConfigMap(getServiceSettingsMap(MODEL_VALUE, null), getSecretSettingsMap(API_KEY_VALUE)), modelVerificationListener ); } @@ -377,13 +395,13 @@ public void testUnifiedCompletionInfer() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - var model = createChatCompletionModel(getUrl(webServer), API_KEY, MODEL_ID); + var model = createChatCompletionModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( model, UnifiedCompletionRequest.of( List.of( - new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), USER_ROLE, null, null) + new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), ROLE_VALUE, null, null) ) ), InferenceAction.Request.DEFAULT_TIMEOUT, @@ -419,13 +437,13 @@ public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - var model = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), API_KEY, MODEL_ID); + var model = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); var latch = new CountDownLatch(1); service.unifiedCompletionInfer( model, UnifiedCompletionRequest.of( List.of( - new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), USER_ROLE, null, null) + new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), ROLE_VALUE, null, null) ) ), InferenceAction.Request.DEFAULT_TIMEOUT, @@ -505,13 +523,13 @@ public void testInfer_StreamRequest() throws Exception { private void testStreamError(String expectedResponse) throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - var model = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), API_KEY, MODEL_ID); + var model = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( model, UnifiedCompletionRequest.of( List.of( - new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), USER_ROLE, null, null) + new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), ROLE_VALUE, null, null) ) ), InferenceAction.Request.DEFAULT_TIMEOUT, @@ -594,7 +612,7 @@ public void testSupportsStreaming() throws IOException { public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap() throws IOException { try (var service = createService()) { - var secretSettings = getSecretSettingsMap(API_KEY); + var secretSettings = getSecretSettingsMap(API_KEY_VALUE); secretSettings.put("extra_key", "value"); var config = getRequestConfigMap(getEmbeddingsServiceSettingsMap(), secretSettings); @@ -610,12 +628,20 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSe } ); - service.parseRequestConfig(INFERENCE_ID, TaskType.TEXT_EMBEDDING, config, modelVerificationListener); + service.parseRequestConfig(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } } public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { - var model = OpenShiftAiEmbeddingsModelTests.createModel(getUrl(webServer), API_KEY, MODEL_ID, 1234, false, 1536, null); + var model = OpenShiftAiEmbeddingsModelTests.createModel( + getUrl(webServer), + API_KEY_VALUE, + MODEL_VALUE, + 1234, + false, + DIMENSIONS_VALUE, + null + ); testChunkedInfer(model); } @@ -623,11 +649,11 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { public void testChunkedInfer_ChunkingSettingsSet() throws IOException { var model = OpenShiftAiEmbeddingsModelTests.createModel( getUrl(webServer), - API_KEY, - MODEL_ID, + API_KEY_VALUE, + MODEL_VALUE, 1234, false, - 1536, + DIMENSIONS_VALUE, createRandomChunkingSettings() ); @@ -720,12 +746,12 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), is("Bearer %s".formatted(API_KEY_VALUE))); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap, aMapWithSize(2)); - assertThat(requestMap.get("input"), is(List.of("abc", "def"))); - assertThat(requestMap.get("model"), is(MODEL_ID)); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of("abc", "def"))); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); } } @@ -794,7 +820,7 @@ public void testGetConfiguration() throws Exception { private InferenceEventsAssertion streamCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - var model = OpenShiftAiChatCompletionModelTests.createCompletionModel(getUrl(webServer), API_KEY, MODEL_ID); + var model = OpenShiftAiChatCompletionModelTests.createCompletionModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( model, @@ -841,7 +867,7 @@ private Map getRequestConfigMap(Map serviceSetti } private static Map getEmbeddingsServiceSettingsMap() { - return buildServiceSettingsMap(INFERENCE_ID, URL, SimilarityMeasure.COSINE.toString(), null, null, null); + return buildServiceSettingsMap(INFERENCE_ID_VALUE, URL_VALUE, SimilarityMeasure.COSINE.toString(), null, null, null); } @Override @@ -851,7 +877,7 @@ public InferenceService createInferenceService() { @Override protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { - assertThat(rerankingInferenceService.rerankerWindowSize("any model"), is(2800)); + assertThat(rerankingInferenceService.rerankerWindowSize(MODEL_VALUE), is(2800)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index 43bbdc5a68644..9a8cea882b9de 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -65,33 +65,31 @@ public class OpenShiftAiActionCreatorTests extends ESTestCase { - // Completion field names + // Field names private static final String N_FIELD_NAME = "n"; private static final String STREAM_FIELD_NAME = "stream"; private static final String MESSAGES_FIELD_NAME = "messages"; private static final String ROLE_FIELD_NAME = "role"; private static final String CONTENT_FIELD_NAME = "content"; - // Rerank field names private static final String DOCUMENTS_FIELD_NAME = "documents"; private static final String MODEL_FIELD_NAME = "model"; private static final String QUERY_FIELD_NAME = "query"; private static final String TOP_N_FIELD_NAME = "top_n"; private static final String RETURN_DOCUMENTS_FIELD_NAME = "return_documents"; - // Embeddings field names private static final String INPUT_FIELD_NAME = "input"; // Test values - private static final String API_KEY = "test-api-key"; - private static final String MODEL_TEST_VALUE = "some model"; - private static final String QUERY_TEST_VALUE = "popular name"; - private static final String ROLE_TEST_VALUE = "user"; - private static final String INPUT = "abcd"; - private static final List INPUT_TEST_VALUE = List.of(INPUT); + private static final String API_KEY_VALUE = "test_api_key"; + private static final String MODEL_VALUE = "some_model"; + private static final String QUERY_VALUE = "popular name"; + private static final String ROLE_VALUE = "user"; + private static final String INPUT_ENTRY_VALUE = "abcd"; + private static final List INPUT_VALUE = List.of(INPUT_ENTRY_VALUE); private static final String INPUT_TO_TRUNCATE = "super long input"; - private static final List EMBEDDINGS_TEST_VALUE = List.of(new float[] { 0.0123F, -0.0123F }); - private static final List DOCUMENTS_TEST_VALUE = List.of("Luke"); + private static final List EMBEDDINGS_VALUE = List.of(new float[] { 0.0123F, -0.0123F }); + private static final List DOCUMENTS_VALUE = List.of("Luke"); private static final List RERANK_EXPECTATIONS_WITH_TEXT_TWO_RESULTS = List.of( new RankedDocsResultsTests.RerankExpectation(Map.of("text", "awgawgawgawg", "index", 1, "relevance_score", 0.9921875f)), new RankedDocsResultsTests.RerankExpectation(Map.of("text", "awdawdawda", "index", 0, "relevance_score", 0.4921875f)) @@ -158,20 +156,20 @@ public void testCreate_OpenShiftAiEmbeddingsModel() throws IOException { """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); + var model = createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); var action = actionCreator.create(model); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(INPUT_TEST_VALUE, InputTypeTests.randomWithNull()), + new EmbeddingsInput(INPUT_VALUE, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); - assertThat(result.asMap(), is(buildExpectationFloat(EMBEDDINGS_TEST_VALUE))); + assertThat(result.asMap(), is(buildExpectationFloat(EMBEDDINGS_VALUE))); assertThat(webServer.requests(), hasSize(1)); var request = webServer.requests().getFirst(); @@ -180,14 +178,14 @@ public void testCreate_OpenShiftAiEmbeddingsModel() throws IOException { var requestMap = entityAsMap(request.getBody()); assertThat(requestMap, aMapWithSize(2)); - assertThat(requestMap.get(INPUT_FIELD_NAME), is(INPUT_TEST_VALUE)); - assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_TEST_VALUE)); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(INPUT_VALUE)); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); } } private static void assertContentTypeAndAuthorization(MockRequest request) { assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters())); - assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY))); + assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY_VALUE))); } public void testCreate_OpenShiftAiEmbeddingsModel_FailsFromInvalidResponseFormat() throws IOException { @@ -222,12 +220,12 @@ public void testCreate_OpenShiftAiEmbeddingsModel_FailsFromInvalidResponseFormat """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); + var model = createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); var action = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)).create(model); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(INPUT_TEST_VALUE, InputTypeTests.randomWithNull()), + new EmbeddingsInput(INPUT_VALUE, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -289,11 +287,11 @@ public void testCreate_OpenShiftAiChatCompletionModel() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createCompletionModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); + var model = createCompletionModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); var action = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)).create(model); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(INPUT_TEST_VALUE), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(INPUT_VALUE), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); @@ -306,9 +304,9 @@ public void testCreate_OpenShiftAiChatCompletionModel() throws IOException { var requestMap = entityAsMap(request.getBody()); assertThat( requestMap.get(MESSAGES_FIELD_NAME), - is(List.of(Map.of(ROLE_FIELD_NAME, ROLE_TEST_VALUE, CONTENT_FIELD_NAME, INPUT))) + is(List.of(Map.of(ROLE_FIELD_NAME, ROLE_VALUE, CONTENT_FIELD_NAME, INPUT_ENTRY_VALUE))) ); - assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_TEST_VALUE)); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); assertThat(requestMap.get(N_FIELD_NAME), is(1)); assertThat(requestMap.get(STREAM_FIELD_NAME), is(false)); } @@ -352,11 +350,11 @@ public void testCreate_OpenShiftAiChatCompletionModel_FailsFromInvalidResponseFo """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createCompletionModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); + var model = createCompletionModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); var action = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)).create(model); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(INPUT_TEST_VALUE), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(INPUT_VALUE), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows( ElasticsearchStatusException.class, @@ -423,20 +421,20 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC webServer.enqueue(new MockResponse().setResponseCode(413).setBody(responseJsonContentTooLarge)); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); + var model = createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); var actionCreator = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)); var action = actionCreator.create(model); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(INPUT_TEST_VALUE, InputTypeTests.randomWithNull()), + new EmbeddingsInput(INPUT_VALUE, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); - assertThat(result.asMap(), is(buildExpectationFloat(EMBEDDINGS_TEST_VALUE))); + assertThat(result.asMap(), is(buildExpectationFloat(EMBEDDINGS_VALUE))); assertThat(webServer.requests(), hasSize(2)); { assertThat(webServer.requests().getFirst().getUri().getQuery(), is(nullValue())); @@ -444,8 +442,8 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap, aMapWithSize(2)); - assertThat(requestMap.get(INPUT_FIELD_NAME), is(INPUT_TEST_VALUE)); - assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_TEST_VALUE)); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(INPUT_VALUE)); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); } { assertThat(webServer.requests().get(1).getUri().getQuery(), is(nullValue())); @@ -453,8 +451,8 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC var requestMap = entityAsMap(webServer.requests().get(1).getBody()); assertThat(requestMap, aMapWithSize(2)); - assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT.substring(0, 2)))); - assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_TEST_VALUE)); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT_ENTRY_VALUE.substring(0, 2)))); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); } } } @@ -507,19 +505,19 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC webServer.enqueue(new MockResponse().setResponseCode(400).setBody(responseJsonContentTooLarge)); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); + var model = createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); var action = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)).create(model); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(INPUT_TEST_VALUE, InputTypeTests.randomWithNull()), + new EmbeddingsInput(INPUT_VALUE, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); - assertThat(result.asMap(), is(buildExpectationFloat(EMBEDDINGS_TEST_VALUE))); + assertThat(result.asMap(), is(buildExpectationFloat(EMBEDDINGS_VALUE))); assertThat(webServer.requests(), hasSize(2)); { @@ -527,16 +525,16 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC assertContentTypeAndAuthorization(firstRequest); var firstRequestMap = entityAsMap(firstRequest.getBody()); assertThat(firstRequestMap, aMapWithSize(2)); - assertThat(firstRequestMap.get(INPUT_FIELD_NAME), is(INPUT_TEST_VALUE)); - assertThat(firstRequestMap.get(MODEL_FIELD_NAME), is(MODEL_TEST_VALUE)); + assertThat(firstRequestMap.get(INPUT_FIELD_NAME), is(INPUT_VALUE)); + assertThat(firstRequestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); } { var secondRequest = webServer.requests().get(1); assertContentTypeAndAuthorization(secondRequest); var secondRequestMap = entityAsMap(secondRequest.getBody()); assertThat(secondRequestMap, aMapWithSize(2)); - assertThat(secondRequestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT.substring(0, 2)))); - assertThat(secondRequestMap.get(MODEL_FIELD_NAME), is(MODEL_TEST_VALUE)); + assertThat(secondRequestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT_ENTRY_VALUE.substring(0, 2)))); + assertThat(secondRequestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); } } } @@ -574,7 +572,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); // truncated to 1 token = 3 characters - var model = createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE, 1); + var model = createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE, 1); var action = new OpenShiftAiActionCreator(sender, createWithEmptySettings(threadPool)).create(model); var listener = new PlainActionFuture(); @@ -586,7 +584,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); - assertThat(result.asMap(), is(buildExpectationFloat(EMBEDDINGS_TEST_VALUE))); + assertThat(result.asMap(), is(buildExpectationFloat(EMBEDDINGS_VALUE))); assertThat(webServer.requests(), hasSize(1)); var request = webServer.requests().getFirst(); @@ -596,7 +594,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { var requestMap = entityAsMap(request.getBody()); assertThat(requestMap, aMapWithSize(2)); assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT_TO_TRUNCATE.substring(0, 3)))); - assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_TEST_VALUE)); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); } } @@ -633,7 +631,7 @@ public void testCreate_OpenShiftAiRerankModel_WithTaskSettings() throws IOExcept """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); var actionCreator = new OpenShiftAiActionCreator( sender, new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) @@ -642,7 +640,7 @@ public void testCreate_OpenShiftAiRerankModel_WithTaskSettings() throws IOExcept PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new QueryAndDocsInputs(QUERY_TEST_VALUE, DOCUMENTS_TEST_VALUE, null, null, false), + new QueryAndDocsInputs(QUERY_VALUE, DOCUMENTS_VALUE, null, null, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -650,13 +648,13 @@ public void testCreate_OpenShiftAiRerankModel_WithTaskSettings() throws IOExcept var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertThat(result.asMap(), is(buildExpectationRerank(RERANK_EXPECTATIONS_WITH_TEXT_TWO_RESULTS))); } - assertRerankActionCreator(DOCUMENTS_TEST_VALUE, 2, true); + assertRerankActionCreator(DOCUMENTS_VALUE, 2, true); } public void testCreate_OpenShiftAiRerankModel_WithOverriddenTaskSettings() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); - List documents = DOCUMENTS_TEST_VALUE; + List documents = DOCUMENTS_VALUE; try (var sender = createSender(senderFactory)) { sender.startSynchronously(); @@ -677,7 +675,7 @@ public void testCreate_OpenShiftAiRerankModel_WithOverriddenTaskSettings() throw """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); var actionCreator = new OpenShiftAiActionCreator( sender, new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) @@ -689,7 +687,7 @@ public void testCreate_OpenShiftAiRerankModel_WithOverriddenTaskSettings() throw PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new QueryAndDocsInputs(QUERY_TEST_VALUE, documents, null, null, false), + new QueryAndDocsInputs(QUERY_VALUE, documents, null, null, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -703,7 +701,7 @@ public void testCreate_OpenShiftAiRerankModel_WithOverriddenTaskSettings() throw public void testCreate_OpenShiftAiRerankModel_NoTaskSettings() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); - List documents = DOCUMENTS_TEST_VALUE; + List documents = DOCUMENTS_VALUE; try (var sender = createSender(senderFactory)) { sender.startSynchronously(); @@ -734,7 +732,7 @@ public void testCreate_OpenShiftAiRerankModel_NoTaskSettings() throws IOExceptio """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE, null, null); + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE, null, null); var actionCreator = new OpenShiftAiActionCreator( sender, new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) @@ -743,7 +741,7 @@ public void testCreate_OpenShiftAiRerankModel_NoTaskSettings() throws IOExceptio PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new QueryAndDocsInputs(QUERY_TEST_VALUE, documents, null, null, false), + new QueryAndDocsInputs(QUERY_VALUE, documents, null, null, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -787,7 +785,7 @@ public void testCreate_OpenShiftAiRerankModel_NoTaskSettings_WithRequestParamete """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE, null, null); + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE, null, null); var actionCreator = new OpenShiftAiActionCreator( sender, new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) @@ -796,7 +794,7 @@ public void testCreate_OpenShiftAiRerankModel_NoTaskSettings_WithRequestParamete PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new QueryAndDocsInputs(QUERY_TEST_VALUE, DOCUMENTS_TEST_VALUE, true, 2, false), + new QueryAndDocsInputs(QUERY_VALUE, DOCUMENTS_VALUE, true, 2, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -804,13 +802,13 @@ public void testCreate_OpenShiftAiRerankModel_NoTaskSettings_WithRequestParamete var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertThat(result.asMap(), is(buildExpectationRerank(RERANK_EXPECTATIONS_WITH_TEXT_TWO_RESULTS))); } - assertRerankActionCreator(DOCUMENTS_TEST_VALUE, 2, true); + assertRerankActionCreator(DOCUMENTS_VALUE, 2, true); } public void testCreate_OpenShiftAiRerankModel_WithTaskSettings_WithRequestParametersPrioritized() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); - List documents = DOCUMENTS_TEST_VALUE; + List documents = DOCUMENTS_VALUE; try (var sender = createSender(senderFactory)) { sender.startSynchronously(); @@ -831,7 +829,7 @@ public void testCreate_OpenShiftAiRerankModel_WithTaskSettings_WithRequestParame """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); var actionCreator = new OpenShiftAiActionCreator( sender, new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) @@ -840,7 +838,7 @@ public void testCreate_OpenShiftAiRerankModel_WithTaskSettings_WithRequestParame PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new QueryAndDocsInputs(QUERY_TEST_VALUE, documents, false, 1, false), + new QueryAndDocsInputs(QUERY_VALUE, documents, false, 1, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -884,7 +882,7 @@ public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() t """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY, MODEL_TEST_VALUE); + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); var actionCreator = new OpenShiftAiActionCreator( sender, new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) @@ -893,7 +891,7 @@ public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() t PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new QueryAndDocsInputs(QUERY_TEST_VALUE, DOCUMENTS_TEST_VALUE, null, null, false), + new QueryAndDocsInputs(QUERY_VALUE, DOCUMENTS_VALUE, null, null, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -918,8 +916,8 @@ private void assertRerankActionCreator( var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); int fieldCount = 3; assertThat(requestMap.get(DOCUMENTS_FIELD_NAME), is(documents)); - assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_TEST_VALUE)); - assertThat(requestMap.get(QUERY_FIELD_NAME), is(QUERY_TEST_VALUE)); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); + assertThat(requestMap.get(QUERY_FIELD_NAME), is(QUERY_VALUE)); if (expectedTopN != null) { assertThat(requestMap.get(TOP_N_FIELD_NAME), is(expectedTopN)); fieldCount++; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java index 8eccbb0b2b080..0b39073ab7544 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java @@ -17,9 +17,10 @@ public class OpenShiftAiChatCompletionModelTests extends ESTestCase { - private static final String MODEL_ID = "model_name"; - private static final String API_KEY = "api_key"; - private static final String URL = "some_url"; + private static final String MODEL_VALUE = "model_name"; + private static final String API_KEY_VALUE = "api_key"; + private static final String URL_VALUE = "http://www.abc.com"; + private static final String ALTERNATE_MODEL_VALUE = "different_model"; public static OpenShiftAiChatCompletionModel createCompletionModel(String url, String apiKey, String modelName) { return createModelWithTaskType(url, apiKey, modelName, TaskType.COMPLETION); @@ -40,28 +41,28 @@ public static OpenShiftAiChatCompletionModel createModelWithTaskType(String url, } public void testOverrideWith_UnifiedCompletionRequest_KeepsSameModelId() { - var model = createCompletionModel(URL, API_KEY, MODEL_ID); - var overriddenModel = OpenShiftAiChatCompletionModel.of(model, MODEL_ID); + var model = createCompletionModel(URL_VALUE, API_KEY_VALUE, MODEL_VALUE); + var overriddenModel = OpenShiftAiChatCompletionModel.of(model, MODEL_VALUE); assertThat(overriddenModel, is(sameInstance(model))); } public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() { - var model = createCompletionModel(URL, API_KEY, MODEL_ID); - var overriddenModel = OpenShiftAiChatCompletionModel.of(model, "different_model"); + var model = createCompletionModel(URL_VALUE, API_KEY_VALUE, MODEL_VALUE); + var overriddenModel = OpenShiftAiChatCompletionModel.of(model, ALTERNATE_MODEL_VALUE); - assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); + assertThat(overriddenModel.getServiceSettings().modelId(), is(ALTERNATE_MODEL_VALUE)); } public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { - var model = createCompletionModel(URL, API_KEY, MODEL_ID); + var model = createCompletionModel(URL_VALUE, API_KEY_VALUE, MODEL_VALUE); var overriddenModel = OpenShiftAiChatCompletionModel.of(model, null); assertThat(overriddenModel, is(sameInstance(model))); } public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() { - var model = createCompletionModel(URL, API_KEY, null); + var model = createCompletionModel(URL_VALUE, API_KEY_VALUE, null); var overriddenModel = OpenShiftAiChatCompletionModel.of(model, null); assertThat(overriddenModel, is(sameInstance(model))); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java index 972d7297a7a1f..046d9feecf420 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java @@ -31,7 +31,7 @@ import static org.mockito.Mockito.when; public class OpenShiftAiChatCompletionResponseHandlerTests extends ESTestCase { - private static final String URL = "https://api.openshift.ai/v1/chat/completions"; + private static final String URL_VALUE = "http://www.abc.com"; private static final String INFERENCE_ID = "id"; private final OpenShiftAiChatCompletionResponseHandler responseHandler = new OpenShiftAiChatCompletionResponseHandler( "chat completions", @@ -55,7 +55,7 @@ public void testFailNotFound() throws IOException { status [404]. Error message: [{\\"detail\\":\\"Not Found\\"}]", "type" : "openshift_ai_error" } - }""".formatted(URL, INFERENCE_ID)))); + }""".formatted(URL_VALUE, INFERENCE_ID)))); } public void testFailBadRequest() throws IOException { @@ -129,7 +129,7 @@ private static Request mockRequest() throws URISyntaxException { var request = mock(Request.class); when(request.getInferenceEntityId()).thenReturn(INFERENCE_ID); when(request.isStreaming()).thenReturn(true); - when(request.getURI()).thenReturn(new URI(URL)); + when(request.getURI()).thenReturn(new URI(URL_VALUE)); return request; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java index efd78b1bbe047..111181cf59c25 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java @@ -33,8 +33,8 @@ public class OpenShiftAiChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase< OpenShiftAiChatCompletionServiceSettings> { - private static final String MODEL_ID = "some model"; - private static final String CORRECT_URL = "https://www.elastic.co"; + private static final String MODEL_VALUE = "some_model"; + private static final String URL_VALUE = "http://www.abc.com"; private static final int RATE_LIMIT = 2; public void testFromMap_AllFields_Success() { @@ -42,9 +42,9 @@ public void testFromMap_AllFields_Success() { new HashMap<>( Map.of( ServiceFields.MODEL_ID, - MODEL_ID, + MODEL_VALUE, ServiceFields.URL, - CORRECT_URL, + URL_VALUE, RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) ) @@ -54,7 +54,7 @@ public void testFromMap_AllFields_Success() { assertThat( serviceSettings, - is(new OpenShiftAiChatCompletionServiceSettings(MODEL_ID, CORRECT_URL, new RateLimitSettings(RATE_LIMIT))) + is(new OpenShiftAiChatCompletionServiceSettings(MODEL_VALUE, URL_VALUE, new RateLimitSettings(RATE_LIMIT))) ); } @@ -63,7 +63,7 @@ public void testFromMap_MissingModelId_Success() { new HashMap<>( Map.of( ServiceFields.URL, - CORRECT_URL, + URL_VALUE, RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) ) @@ -71,7 +71,7 @@ public void testFromMap_MissingModelId_Success() { ConfigurationParseContext.PERSISTENT ); - assertThat(serviceSettings, is(new OpenShiftAiChatCompletionServiceSettings(null, CORRECT_URL, new RateLimitSettings(RATE_LIMIT)))); + assertThat(serviceSettings, is(new OpenShiftAiChatCompletionServiceSettings(null, URL_VALUE, new RateLimitSettings(RATE_LIMIT)))); } public void testFromMap_MissingUrl_ThrowsException() { @@ -81,7 +81,7 @@ public void testFromMap_MissingUrl_ThrowsException() { new HashMap<>( Map.of( ServiceFields.MODEL_ID, - MODEL_ID, + MODEL_VALUE, RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) ) @@ -98,11 +98,11 @@ public void testFromMap_MissingUrl_ThrowsException() { public void testFromMap_MissingRateLimit_Success() { var serviceSettings = OpenShiftAiChatCompletionServiceSettings.fromMap( - new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID, ServiceFields.URL, CORRECT_URL)), + new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_VALUE, ServiceFields.URL, URL_VALUE)), ConfigurationParseContext.PERSISTENT ); - assertThat(serviceSettings, is(new OpenShiftAiChatCompletionServiceSettings(MODEL_ID, CORRECT_URL, null))); + assertThat(serviceSettings, is(new OpenShiftAiChatCompletionServiceSettings(MODEL_VALUE, URL_VALUE, null))); } public void testToXContent_WritesAllValues() throws IOException { @@ -110,9 +110,9 @@ public void testToXContent_WritesAllValues() throws IOException { new HashMap<>( Map.of( ServiceFields.MODEL_ID, - MODEL_ID, + MODEL_VALUE, ServiceFields.URL, - CORRECT_URL, + URL_VALUE, RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) ) @@ -131,14 +131,14 @@ public void testToXContent_WritesAllValues() throws IOException { "requests_per_minute": 2 } } - """.formatted(MODEL_ID, CORRECT_URL)); + """.formatted(MODEL_VALUE, URL_VALUE)); assertThat(xContentResult, is(expected)); } public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws IOException { var serviceSettings = OpenShiftAiChatCompletionServiceSettings.fromMap( - new HashMap<>(Map.of(ServiceFields.URL, CORRECT_URL)), + new HashMap<>(Map.of(ServiceFields.URL, URL_VALUE)), ConfigurationParseContext.PERSISTENT ); @@ -152,7 +152,7 @@ public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws "requests_per_minute": 3000 } } - """.formatted(CORRECT_URL)); + """.formatted(URL_VALUE)); assertThat(xContentResult, is(expected)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java index 26299fbbcbd39..231445b5e70d4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java @@ -34,23 +34,23 @@ import static org.hamcrest.Matchers.is; public class OpenShiftAiEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase { - private static final String MODEL_ID = "some model"; - private static final String CORRECT_URL = "https://www.elastic.co"; - private static final String INVALID_URL = "^^^"; - private static final int DIMENSIONS = 384; - private static final SimilarityMeasure SIMILARITY_MEASURE = SimilarityMeasure.DOT_PRODUCT; - private static final int MAX_INPUT_TOKENS = 128; - private static final int RATE_LIMIT = 2; + private static final String MODEL_VALUE = "some_model"; + private static final String CORRECT_URL_VALUE = "http://www.abc.com"; + private static final String INVALID_URL_VALUE = "^^^"; + private static final int DIMENSIONS_VALUE = 384; + private static final SimilarityMeasure SIMILARITY_MEASURE_VALUE = SimilarityMeasure.DOT_PRODUCT; + private static final int MAX_INPUT_TOKENS_VALUE = 128; + private static final int RATE_LIMIT_VALUE = 2; public void testFromMap_AllFields_Success() { var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( buildServiceSettingsMap( - MODEL_ID, - CORRECT_URL, - SIMILARITY_MEASURE.toString(), - DIMENSIONS, - MAX_INPUT_TOKENS, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), true ), ConfigurationParseContext.PERSISTENT @@ -60,12 +60,12 @@ public void testFromMap_AllFields_Success() { serviceSettings, is( new OpenShiftAiEmbeddingsServiceSettings( - MODEL_ID, - CORRECT_URL, - DIMENSIONS, - SIMILARITY_MEASURE, - MAX_INPUT_TOKENS, - new RateLimitSettings(RATE_LIMIT), + MODEL_VALUE, + CORRECT_URL_VALUE, + DIMENSIONS_VALUE, + SIMILARITY_MEASURE_VALUE, + MAX_INPUT_TOKENS_VALUE, + new RateLimitSettings(RATE_LIMIT_VALUE), true ) ) @@ -76,11 +76,11 @@ public void testFromMap_NoModelId_Success() { var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( buildServiceSettingsMap( null, - CORRECT_URL, - SIMILARITY_MEASURE.toString(), - DIMENSIONS, - MAX_INPUT_TOKENS, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), false ), ConfigurationParseContext.PERSISTENT @@ -90,11 +90,11 @@ public void testFromMap_NoModelId_Success() { is( new OpenShiftAiEmbeddingsServiceSettings( null, - CORRECT_URL, - DIMENSIONS, - SIMILARITY_MEASURE, - MAX_INPUT_TOKENS, - new RateLimitSettings(RATE_LIMIT), + CORRECT_URL_VALUE, + DIMENSIONS_VALUE, + SIMILARITY_MEASURE_VALUE, + MAX_INPUT_TOKENS_VALUE, + new RateLimitSettings(RATE_LIMIT_VALUE), false ) ) @@ -106,12 +106,12 @@ public void testFromMap_NoUrl_ThrowsException() { ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( buildServiceSettingsMap( - MODEL_ID, + MODEL_VALUE, null, - SIMILARITY_MEASURE.toString(), - DIMENSIONS, - MAX_INPUT_TOKENS, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), false ), ConfigurationParseContext.PERSISTENT @@ -128,12 +128,12 @@ public void testFromMap_EmptyUrl_ThrowsException() { ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( buildServiceSettingsMap( - MODEL_ID, + MODEL_VALUE, "", - SIMILARITY_MEASURE.toString(), - DIMENSIONS, - MAX_INPUT_TOKENS, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), false ), ConfigurationParseContext.PERSISTENT @@ -150,12 +150,12 @@ public void testFromMap_InvalidUrl_ThrowsException() { ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( buildServiceSettingsMap( - MODEL_ID, - INVALID_URL, - SIMILARITY_MEASURE.toString(), - DIMENSIONS, - MAX_INPUT_TOKENS, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + MODEL_VALUE, + INVALID_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), false ), ConfigurationParseContext.PERSISTENT @@ -163,18 +163,18 @@ public void testFromMap_InvalidUrl_ThrowsException() { ); assertThat(thrownException.getMessage(), containsString(""" Validation Failed: 1: [service_settings] Invalid url [%s] received for field [url]. \ - Error: unable to parse url [%s]. Reason: Illegal character in path;""".formatted(INVALID_URL, INVALID_URL))); + Error: unable to parse url [%s]. Reason: Illegal character in path;""".formatted(INVALID_URL_VALUE, INVALID_URL_VALUE))); } public void testFromMap_NoSimilarity_Success() { var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( buildServiceSettingsMap( - MODEL_ID, - CORRECT_URL, + MODEL_VALUE, + CORRECT_URL_VALUE, null, - DIMENSIONS, - MAX_INPUT_TOKENS, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), false ), ConfigurationParseContext.PERSISTENT @@ -184,12 +184,12 @@ public void testFromMap_NoSimilarity_Success() { serviceSettings, is( new OpenShiftAiEmbeddingsServiceSettings( - MODEL_ID, - CORRECT_URL, - DIMENSIONS, + MODEL_VALUE, + CORRECT_URL_VALUE, + DIMENSIONS_VALUE, null, - MAX_INPUT_TOKENS, - new RateLimitSettings(RATE_LIMIT), + MAX_INPUT_TOKENS_VALUE, + new RateLimitSettings(RATE_LIMIT_VALUE), false ) ) @@ -201,12 +201,12 @@ public void testFromMap_InvalidSimilarity_ThrowsException() { ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( buildServiceSettingsMap( - MODEL_ID, - CORRECT_URL, + MODEL_VALUE, + CORRECT_URL_VALUE, "by_size", - DIMENSIONS, - MAX_INPUT_TOKENS, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), false ), ConfigurationParseContext.PERSISTENT @@ -220,12 +220,12 @@ public void testFromMap_InvalidSimilarity_ThrowsException() { public void testFromMap_NoDimensions_SetByUserFalse_Persistent_Success() { var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( buildServiceSettingsMap( - MODEL_ID, - CORRECT_URL, - SIMILARITY_MEASURE.toString(), + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), null, - MAX_INPUT_TOKENS, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), false ), ConfigurationParseContext.PERSISTENT @@ -235,12 +235,12 @@ public void testFromMap_NoDimensions_SetByUserFalse_Persistent_Success() { serviceSettings, is( new OpenShiftAiEmbeddingsServiceSettings( - MODEL_ID, - CORRECT_URL, + MODEL_VALUE, + CORRECT_URL_VALUE, null, - SIMILARITY_MEASURE, - MAX_INPUT_TOKENS, - new RateLimitSettings(RATE_LIMIT), + SIMILARITY_MEASURE_VALUE, + MAX_INPUT_TOKENS_VALUE, + new RateLimitSettings(RATE_LIMIT_VALUE), false ) ) @@ -250,12 +250,12 @@ public void testFromMap_NoDimensions_SetByUserFalse_Persistent_Success() { public void testFromMap_Persistent_WithDimensions_SetByUserFalse_Persistent_Success() { var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( buildServiceSettingsMap( - MODEL_ID, - CORRECT_URL, - SIMILARITY_MEASURE.toString(), - DIMENSIONS, - MAX_INPUT_TOKENS, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), false ), ConfigurationParseContext.PERSISTENT @@ -265,12 +265,12 @@ public void testFromMap_Persistent_WithDimensions_SetByUserFalse_Persistent_Succ serviceSettings, is( new OpenShiftAiEmbeddingsServiceSettings( - MODEL_ID, - CORRECT_URL, - DIMENSIONS, - SIMILARITY_MEASURE, - MAX_INPUT_TOKENS, - new RateLimitSettings(RATE_LIMIT), + MODEL_VALUE, + CORRECT_URL_VALUE, + DIMENSIONS_VALUE, + SIMILARITY_MEASURE_VALUE, + MAX_INPUT_TOKENS_VALUE, + new RateLimitSettings(RATE_LIMIT_VALUE), false ) ) @@ -282,12 +282,12 @@ public void testFromMap_WithDimensions_SetByUserNull_Persistent_ThrowsException( ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( buildServiceSettingsMap( - MODEL_ID, - CORRECT_URL, - SIMILARITY_MEASURE.toString(), - DIMENSIONS, - MAX_INPUT_TOKENS, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), null ), ConfigurationParseContext.PERSISTENT @@ -303,12 +303,12 @@ public void testFromMap_WithDimensions_SetByUserNull_Persistent_ThrowsException( public void testFromMap_NoDimensions_SetByUserNull_Request_Success() { var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( buildServiceSettingsMap( - MODEL_ID, - CORRECT_URL, - SIMILARITY_MEASURE.toString(), + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), null, - MAX_INPUT_TOKENS, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), null ), ConfigurationParseContext.REQUEST @@ -318,12 +318,12 @@ public void testFromMap_NoDimensions_SetByUserNull_Request_Success() { serviceSettings, is( new OpenShiftAiEmbeddingsServiceSettings( - MODEL_ID, - CORRECT_URL, + MODEL_VALUE, + CORRECT_URL_VALUE, null, - SIMILARITY_MEASURE, - MAX_INPUT_TOKENS, - new RateLimitSettings(RATE_LIMIT), + SIMILARITY_MEASURE_VALUE, + MAX_INPUT_TOKENS_VALUE, + new RateLimitSettings(RATE_LIMIT_VALUE), false ) ) @@ -333,12 +333,12 @@ public void testFromMap_NoDimensions_SetByUserNull_Request_Success() { public void testFromMap_WithDimensions_SetByUserNull_Request_Success() { var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( buildServiceSettingsMap( - MODEL_ID, - CORRECT_URL, - SIMILARITY_MEASURE.toString(), - DIMENSIONS, - MAX_INPUT_TOKENS, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), null ), ConfigurationParseContext.REQUEST @@ -348,12 +348,12 @@ public void testFromMap_WithDimensions_SetByUserNull_Request_Success() { serviceSettings, is( new OpenShiftAiEmbeddingsServiceSettings( - MODEL_ID, - CORRECT_URL, - DIMENSIONS, - SIMILARITY_MEASURE, - MAX_INPUT_TOKENS, - new RateLimitSettings(RATE_LIMIT), + MODEL_VALUE, + CORRECT_URL_VALUE, + DIMENSIONS_VALUE, + SIMILARITY_MEASURE_VALUE, + MAX_INPUT_TOKENS_VALUE, + new RateLimitSettings(RATE_LIMIT_VALUE), true ) ) @@ -365,12 +365,12 @@ public void testFromMap_WithDimensions_SetByUserTrue_Request_ThrowsException() { ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( buildServiceSettingsMap( - MODEL_ID, - CORRECT_URL, - SIMILARITY_MEASURE.toString(), - DIMENSIONS, - MAX_INPUT_TOKENS, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), true ), ConfigurationParseContext.REQUEST @@ -388,12 +388,12 @@ public void testFromMap_ZeroDimensions_ThrowsException() { ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( buildServiceSettingsMap( - MODEL_ID, - CORRECT_URL, - SIMILARITY_MEASURE.toString(), + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), 0, - MAX_INPUT_TOKENS, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), false ), ConfigurationParseContext.PERSISTENT @@ -410,12 +410,12 @@ public void testFromMap_NegativeDimensions_ThrowsException() { ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( buildServiceSettingsMap( - MODEL_ID, - CORRECT_URL, - SIMILARITY_MEASURE.toString(), + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), -10, - MAX_INPUT_TOKENS, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + MAX_INPUT_TOKENS_VALUE, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), false ), ConfigurationParseContext.PERSISTENT @@ -430,12 +430,12 @@ public void testFromMap_NegativeDimensions_ThrowsException() { public void testFromMap_NoInputTokens_Success() { var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( buildServiceSettingsMap( - MODEL_ID, - CORRECT_URL, - SIMILARITY_MEASURE.toString(), - DIMENSIONS, + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, null, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), false ), ConfigurationParseContext.PERSISTENT @@ -445,12 +445,12 @@ public void testFromMap_NoInputTokens_Success() { serviceSettings, is( new OpenShiftAiEmbeddingsServiceSettings( - MODEL_ID, - CORRECT_URL, - DIMENSIONS, - SIMILARITY_MEASURE, + MODEL_VALUE, + CORRECT_URL_VALUE, + DIMENSIONS_VALUE, + SIMILARITY_MEASURE_VALUE, null, - new RateLimitSettings(RATE_LIMIT), + new RateLimitSettings(RATE_LIMIT_VALUE), false ) ) @@ -462,12 +462,12 @@ public void testFromMap_ZeroInputTokens_ThrowsException() { ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( buildServiceSettingsMap( - MODEL_ID, - CORRECT_URL, - SIMILARITY_MEASURE.toString(), - DIMENSIONS, + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, 0, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), false ), ConfigurationParseContext.PERSISTENT @@ -484,12 +484,12 @@ public void testFromMap_NegativeInputTokens_ThrowsException() { ValidationException.class, () -> OpenShiftAiEmbeddingsServiceSettings.fromMap( buildServiceSettingsMap( - MODEL_ID, - CORRECT_URL, - SIMILARITY_MEASURE.toString(), - DIMENSIONS, + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, -10, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)), + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT_VALUE)), false ), ConfigurationParseContext.PERSISTENT @@ -503,7 +503,15 @@ public void testFromMap_NegativeInputTokens_ThrowsException() { public void testFromMap_NoRateLimit_Success() { var serviceSettings = OpenShiftAiEmbeddingsServiceSettings.fromMap( - buildServiceSettingsMap(MODEL_ID, CORRECT_URL, SIMILARITY_MEASURE.toString(), DIMENSIONS, MAX_INPUT_TOKENS, null, false), + buildServiceSettingsMap( + MODEL_VALUE, + CORRECT_URL_VALUE, + SIMILARITY_MEASURE_VALUE.toString(), + DIMENSIONS_VALUE, + MAX_INPUT_TOKENS_VALUE, + null, + false + ), ConfigurationParseContext.PERSISTENT ); @@ -511,11 +519,11 @@ public void testFromMap_NoRateLimit_Success() { serviceSettings, is( new OpenShiftAiEmbeddingsServiceSettings( - MODEL_ID, - CORRECT_URL, - DIMENSIONS, - SIMILARITY_MEASURE, - MAX_INPUT_TOKENS, + MODEL_VALUE, + CORRECT_URL_VALUE, + DIMENSIONS_VALUE, + SIMILARITY_MEASURE_VALUE, + MAX_INPUT_TOKENS_VALUE, new RateLimitSettings(3000), false ) @@ -525,11 +533,11 @@ public void testFromMap_NoRateLimit_Success() { public void testToXContent_WritesAllValues() throws IOException { var entity = new OpenShiftAiEmbeddingsServiceSettings( - MODEL_ID, - CORRECT_URL, - DIMENSIONS, - SIMILARITY_MEASURE, - MAX_INPUT_TOKENS, + MODEL_VALUE, + CORRECT_URL_VALUE, + DIMENSIONS_VALUE, + SIMILARITY_MEASURE_VALUE, + MAX_INPUT_TOKENS_VALUE, new RateLimitSettings(3), false ); @@ -550,7 +558,7 @@ public void testToXContent_WritesAllValues() throws IOException { "max_input_tokens": 128, "dimensions_set_by_user": false } - """.formatted(MODEL_ID, CORRECT_URL)))); + """.formatted(MODEL_VALUE, CORRECT_URL_VALUE)))); } @Override @@ -641,7 +649,7 @@ public static HashMap buildServiceSettingsMap( result.put(RateLimitSettings.FIELD_NAME, rateLimitSettings); } if (dimensionsSetByUser != null) { - result.put("dimensions_set_by_user", dimensionsSetByUser); + result.put(ServiceFields.DIMENSIONS_SET_BY_USER, dimensionsSetByUser); } return result; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntityTests.java index 000d9876b05c7..264c3bababb7c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestEntityTests.java @@ -22,7 +22,7 @@ import static org.hamcrest.Matchers.is; public class OpenShiftAiChatCompletionRequestEntityTests extends ESTestCase { - private static final String USER_ROLE_VALUE = "user"; + private static final String ROLE_VALUE = "user"; public void testSerializationWithModelIdStreaming() throws IOException { testSerialization("modelId", true, """ @@ -85,7 +85,7 @@ public void testSerializationWithoutModelIdNonStreaming() throws IOException { private static void testSerialization(String modelId, boolean isStreaming, String expectedJson) throws IOException { var message = new UnifiedCompletionRequest.Message( new UnifiedCompletionRequest.ContentString("Hello, world!"), - USER_ROLE_VALUE, + ROLE_VALUE, null, null ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java index 76e826f873d8d..7b408e8c132f9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java @@ -26,14 +26,23 @@ public class OpenShiftAiChatCompletionRequestTests extends ESTestCase { - private static final String URL_VALUE = "some_url"; - private static final String MODEL_VALUE = "some model"; - private static final String USER_ROLE = "user"; - private static final String API_KEY = "secret"; + // Completion field names + private static final String N_FIELD_NAME = "n"; + private static final String STREAM_FIELD_NAME = "stream"; + private static final String MESSAGES_FIELD_NAME = "messages"; + private static final String ROLE_FIELD_NAME = "role"; + private static final String CONTENT_FIELD_NAME = "content"; + private static final String MODEL_FIELD_NAME = "model"; + + // Test values + private static final String URL_VALUE = "http://www.abc.com"; + private static final String MODEL_VALUE = "some_model"; + private static final String ROLE_VALUE = "user"; + private static final String API_KEY_VALUE = "test_api_key"; public void testCreateRequest_WithStreaming() throws IOException { String input = randomAlphaOfLength(15); - var request = createRequest(MODEL_VALUE, URL_VALUE, API_KEY, input, true); + var request = createRequest(MODEL_VALUE, URL_VALUE, API_KEY_VALUE, input, true); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -41,29 +50,28 @@ public void testCreateRequest_WithStreaming() throws IOException { var requestMap = entityAsMap(httpPost.getEntity().getContent()); assertThat(request.getURI().toString(), is(URL_VALUE)); - assertThat(requestMap.get("stream"), is(true)); - assertThat(requestMap.get("model"), is(MODEL_VALUE)); - assertThat(requestMap.get("n"), is(1)); - assertThat(requestMap.get("stream_options"), is(nullValue())); - assertThat(requestMap.get("messages"), is(List.of(Map.of("role", USER_ROLE, "content", input)))); + assertThat(requestMap.get(STREAM_FIELD_NAME), is(true)); + assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); + assertThat(requestMap.get(N_FIELD_NAME), is(1)); + assertThat(requestMap.get(MESSAGES_FIELD_NAME), is(List.of(Map.of(ROLE_FIELD_NAME, ROLE_VALUE, CONTENT_FIELD_NAME, input)))); assertThat(requestMap, aMapWithSize(4)); - assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer %s".formatted(API_KEY))); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer %s".formatted(API_KEY_VALUE))); } public void testTruncate_DoesNotReduceInputTextSize() { String input = randomAlphaOfLength(5); - var request = createRequest(MODEL_VALUE, URL_VALUE, API_KEY, input, true); + var request = createRequest(MODEL_VALUE, URL_VALUE, API_KEY_VALUE, input, true); assertThat(request.truncate(), is(sameInstance(request))); } public void testTruncationInfo_ReturnsNull() { - var request = createRequest(MODEL_VALUE, URL_VALUE, API_KEY, randomAlphaOfLength(5), true); + var request = createRequest(MODEL_VALUE, URL_VALUE, API_KEY_VALUE, randomAlphaOfLength(5), true); assertThat(request.getTruncationInfo(), is(nullValue())); } public static OpenShiftAiChatCompletionRequest createRequest(String modelId, String url, String apiKey, String input, boolean stream) { var chatCompletionModel = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(url, apiKey, modelId); - return new OpenShiftAiChatCompletionRequest(new UnifiedChatInput(List.of(input), USER_ROLE, stream), chatCompletionModel); + return new OpenShiftAiChatCompletionRequest(new UnifiedChatInput(List.of(input), ROLE_VALUE, stream), chatCompletionModel); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntityTests.java index afa06a834fa6b..506dc072937c8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestEntityTests.java @@ -21,21 +21,22 @@ public class OpenShiftAiEmbeddingsRequestEntityTests extends ESTestCase { - private static final String MODEL = "some model"; - private static final String INPUT = "some input"; + private static final String MODEL_VALUE = "some_model"; + private static final List INPUT_VALUE = List.of("some input"); + private static final int DIMENSIONS_VALUE = 100; public void testXContent_DoesNotWriteDimensionsWhenNullAndSetByUserIsFalse() throws IOException { - var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of(INPUT), MODEL, null, false); + var entity = new OpenShiftAiEmbeddingsRequestEntity(INPUT_VALUE, MODEL_VALUE, null, false); testXContent_DoesNotWriteDimensions(entity); } public void testXContent_DoesNotWriteDimensionsWhenNotSetByUser() throws IOException { - var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of(INPUT), MODEL, 100, false); + var entity = new OpenShiftAiEmbeddingsRequestEntity(INPUT_VALUE, MODEL_VALUE, DIMENSIONS_VALUE, false); testXContent_DoesNotWriteDimensions(entity); } public void testXContent_DoesNotWriteDimensionsWhenNull_EvenIfSetByUserIsTrue() throws IOException { - var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of(INPUT), MODEL, null, true); + var entity = new OpenShiftAiEmbeddingsRequestEntity(INPUT_VALUE, MODEL_VALUE, null, true); testXContent_DoesNotWriteDimensions(entity); } @@ -47,13 +48,13 @@ private static void testXContent_DoesNotWriteDimensions(OpenShiftAiEmbeddingsReq assertThat(xContentResult, is(XContentHelper.stripWhitespace(""" { "input": ["some input"], - "model": "some model" + "model": "some_model" } """))); } public void testXContent_DoesNotWriteModelWhenItIsNull() throws IOException { - var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of(INPUT), null, null, false); + var entity = new OpenShiftAiEmbeddingsRequestEntity(INPUT_VALUE, null, null, false); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -67,7 +68,7 @@ public void testXContent_DoesNotWriteModelWhenItIsNull() throws IOException { } public void testXContent_WritesDimensionsWhenNonNull_AndSetByUserIsTrue() throws IOException { - var entity = new OpenShiftAiEmbeddingsRequestEntity(List.of(INPUT), MODEL, 100, true); + var entity = new OpenShiftAiEmbeddingsRequestEntity(INPUT_VALUE, MODEL_VALUE, DIMENSIONS_VALUE, true); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -76,7 +77,7 @@ public void testXContent_WritesDimensionsWhenNonNull_AndSetByUserIsTrue() throws assertThat(xContentResult, is(XContentHelper.stripWhitespace(""" { "input": ["some input"], - "model": "some model", + "model": "some_model", "dimensions": 100 } """))); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java index 704f8222ae014..efa411cc9cdfb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.common.TruncatorTests; import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsModelTests; import java.io.IOException; @@ -29,12 +30,12 @@ public class OpenShiftAiEmbeddingsRequestTests extends ESTestCase { private static final String INPUT_FIELD_NAME = "input"; private static final String MODEL_FIELD_NAME = "model"; - private static final String DIMENSIONS_FIELD_NAME = "dimensions"; - private static final String MODEL_VALUE = "some model"; + private static final String MODEL_VALUE = "some_model"; private static final String INPUT_VALUE = "ABCD"; - private static final String URL_VALUE = "some_url"; - private static final String API_KEY = "some api key"; + private static final String URL_VALUE = "http://www.abc.com"; + private static final String API_KEY_VALUE = "test_api_key"; + private static final int DIMENSIONS_VALUE = 384; public void testCreateRequest_NoDimensions_DimensionsSetByUserFalse_Success() throws IOException { testCreateRequest_Success(null, false, null); @@ -45,11 +46,11 @@ public void testCreateRequest_NoDimensions_DimensionsSetByUserTrue_Success() thr } public void testCreateRequest_WithDimensions_DimensionsSetByUserFalse_Success() throws IOException { - testCreateRequest_Success(384, false, null); + testCreateRequest_Success(DIMENSIONS_VALUE, false, null); } public void testCreateRequest_WithDimensions_DimensionsSetByUserTrue_Success() throws IOException { - testCreateRequest_Success(384, true, 384); + testCreateRequest_Success(DIMENSIONS_VALUE, true, DIMENSIONS_VALUE); } private void testCreateRequest_Success(Integer dimensions, boolean dimensionsSetByUser, Integer expectedDimensions) throws IOException { @@ -60,8 +61,8 @@ private void testCreateRequest_Success(Integer dimensions, boolean dimensionsSet var requestMap = entityAsMap(httpPost.getEntity().getContent()); assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT_VALUE))); assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); - assertThat(requestMap.get(DIMENSIONS_FIELD_NAME), is(expectedDimensions)); - assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer %s".formatted(API_KEY))); + assertThat(requestMap.get(ServiceFields.DIMENSIONS), is(expectedDimensions)); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer %s".formatted(API_KEY_VALUE))); } public void testCreateRequest_NoModel_Success() throws IOException { @@ -72,8 +73,8 @@ public void testCreateRequest_NoModel_Success() throws IOException { var requestMap = entityAsMap(httpPost.getEntity().getContent()); assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT_VALUE))); assertThat(requestMap.get(MODEL_FIELD_NAME), is(nullValue())); - assertThat(requestMap.get(DIMENSIONS_FIELD_NAME), is(nullValue())); - assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer %s".formatted(API_KEY))); + assertThat(requestMap.get(ServiceFields.DIMENSIONS), is(nullValue())); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer %s".formatted(API_KEY_VALUE))); } @@ -115,7 +116,7 @@ private static OpenShiftAiEmbeddingsRequest createRequest(Integer dimensions, Bo private static OpenShiftAiEmbeddingsRequest createRequest(Integer dimensions, Boolean dimensionsSetByUser, String modelId) { var embeddingsModel = OpenShiftAiEmbeddingsModelTests.createModel( URL_VALUE, - API_KEY, + API_KEY_VALUE, modelId, dimensions, dimensionsSetByUser, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntityTests.java index f3deea999e5b7..920ceba6696da 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAIRerankRequestEntityTests.java @@ -21,21 +21,21 @@ import static org.hamcrest.Matchers.is; public class OpenShiftAIRerankRequestEntityTests extends ESTestCase { - private static final String DOCUMENT = "some document"; - private static final String QUERY = "some query"; - private static final String MODEL = "some model"; - private static final Integer TOP_N = 8; - private static final Boolean RETURN_DOCUMENTS = true; + private static final List DOCUMENT_VALUE = List.of("some document"); + private static final String QUERY_VALUE = "some query"; + private static final String MODEL_VALUE = "some_model"; + private static final Integer TOP_N_VALUE = 8; + private static final Boolean RETURN_DOCUMENTS_VALUE = true; public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { - var entity = new OpenShiftAIRerankRequestEntity(MODEL, QUERY, List.of(DOCUMENT), RETURN_DOCUMENTS, TOP_N); + var entity = new OpenShiftAIRerankRequestEntity(MODEL_VALUE, QUERY_VALUE, DOCUMENT_VALUE, RETURN_DOCUMENTS_VALUE, TOP_N_VALUE); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); String result = Strings.toString(builder); String expected = """ { - "model": "some model", + "model": "some_model", "query": "some query", "documents": ["some document"], "top_n": 8, @@ -46,7 +46,7 @@ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException } public void testXContent_WritesMinimalFields() throws IOException { - var entity = new OpenShiftAIRerankRequestEntity(null, QUERY, List.of(DOCUMENT), null, null); + var entity = new OpenShiftAIRerankRequestEntity(null, QUERY_VALUE, DOCUMENT_VALUE, null, null); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java index 2922f8e6bc2cd..ba4d4c4aa0b34 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java @@ -30,56 +30,80 @@ public class OpenShiftAiRerankRequestTests extends ESTestCase { private static final String DOCUMENTS_FIELD_NAME = "documents"; private static final String QUERY_FIELD_NAME = "query"; + private static final String URL_VALUE = "http://www.abc.com"; private static final String DOCUMENT_VALUE = "some document"; private static final String QUERY_VALUE = "some query"; - private static final String MODEL_VALUE = "some model"; + private static final String MODEL_VALUE = "some_model"; private static final Integer TOP_N_VALUE = 8; private static final Boolean RETURN_DOCUMENTS_VALUE = false; - private static final String API_KEY_VALUE = "some api key"; + private static final String API_KEY_VALUE = "test_api_key"; public void testCreateRequest_WithMinimalFieldsSet() throws IOException { - testCreateRequest(null, null, null, createRequest(null, null, null)); + testCreateRequest(createRequest(null, null, null, null, null), null, null, null); } - public void testCreateRequest_WithTopN() throws IOException { - testCreateRequest(TOP_N_VALUE, null, null, createRequest(TOP_N_VALUE, null, null)); + public void testCreateRequest_TaskSettingsWithTopN() throws IOException { + testCreateRequest(createRequest(TOP_N_VALUE, null, null, null, null), TOP_N_VALUE, null, null); } - public void testCreateRequest_WithReturnDocuments() throws IOException { - testCreateRequest(null, RETURN_DOCUMENTS_VALUE, null, createRequest(null, RETURN_DOCUMENTS_VALUE, null)); + public void testCreateRequest_TaskSettingsWithReturnDocuments() throws IOException { + testCreateRequest(createRequest(null, RETURN_DOCUMENTS_VALUE, null, null, null), null, RETURN_DOCUMENTS_VALUE, null); } - public void testCreateRequest_WithModelId() throws IOException { - testCreateRequest(null, null, MODEL_VALUE, createRequest(null, null, MODEL_VALUE)); + public void testCreateRequest_TaskSettingsWithModelId() throws IOException { + testCreateRequest(createRequest(null, null, MODEL_VALUE, null, null), null, null, MODEL_VALUE); } - public void testCreateRequest_AllFields() throws IOException { + public void testCreateRequest_TaskSettingsWithAllFields() throws IOException { testCreateRequest( + createRequest(TOP_N_VALUE, RETURN_DOCUMENTS_VALUE, MODEL_VALUE, null, null), TOP_N_VALUE, RETURN_DOCUMENTS_VALUE, - MODEL_VALUE, - createRequest(TOP_N_VALUE, RETURN_DOCUMENTS_VALUE, MODEL_VALUE) + MODEL_VALUE ); } - public void testCreateRequest_AllFields_OverridesTaskSettings() throws IOException { + public void testCreateRequest_RequestSettingsOverrideTaskSettings() throws IOException { testCreateRequest( + createRequest(1, true, MODEL_VALUE, TOP_N_VALUE, RETURN_DOCUMENTS_VALUE), TOP_N_VALUE, RETURN_DOCUMENTS_VALUE, - MODEL_VALUE, - createRequestWithDifferentTaskSettings(TOP_N_VALUE, RETURN_DOCUMENTS_VALUE) + MODEL_VALUE ); } - public void testCreateRequest_AllFields_KeepsTaskSettings() throws IOException { - testCreateRequest(1, true, MODEL_VALUE, createRequestWithDifferentTaskSettings(null, null)); + public void testCreateRequest_RequestSettingsOverrideNullTaskSettings() throws IOException { + testCreateRequest( + createRequest(null, null, MODEL_VALUE, TOP_N_VALUE, RETURN_DOCUMENTS_VALUE), + TOP_N_VALUE, + RETURN_DOCUMENTS_VALUE, + MODEL_VALUE + ); + } + + public void testCreateRequest_ReturnDocumentsFromTaskSettings_TopNFromRequest() throws IOException { + testCreateRequest( + createRequest(null, RETURN_DOCUMENTS_VALUE, MODEL_VALUE, TOP_N_VALUE, null), + TOP_N_VALUE, + RETURN_DOCUMENTS_VALUE, + MODEL_VALUE + ); + } + + public void testCreateRequest_TopNFromTaskSettings_ReturnDocumentsFromRequest() throws IOException { + testCreateRequest( + createRequest(TOP_N_VALUE, null, MODEL_VALUE, null, RETURN_DOCUMENTS_VALUE), + TOP_N_VALUE, + RETURN_DOCUMENTS_VALUE, + MODEL_VALUE + ); } private void testCreateRequest( + OpenShiftAiRerankRequest request, Integer expectedTopN, Boolean expectedReturnDocuments, - String expectedModelId, - OpenShiftAiRerankRequest request + String expectedModelId ) throws IOException { var httpRequest = request.createHttpRequest(); @@ -110,19 +134,19 @@ private void testCreateRequest( } private static OpenShiftAiRerankRequest createRequest( - @Nullable Integer topN, - @Nullable Boolean returnDocuments, - @Nullable String modelId + @Nullable Integer taskSettingsTopN, + @Nullable Boolean taskSettingsReturnDocuments, + @Nullable String modelId, + @Nullable Integer requestTopN, + @Nullable Boolean requestReturnDocuments ) { - var rerankModel = OpenShiftAiRerankModelTests.createModel(randomAlphaOfLength(10), API_KEY_VALUE, modelId, topN, returnDocuments); - return new OpenShiftAiRerankRequest(QUERY_VALUE, List.of(DOCUMENT_VALUE), returnDocuments, topN, rerankModel); - } - - private static OpenShiftAiRerankRequest createRequestWithDifferentTaskSettings( - @Nullable Integer topN, - @Nullable Boolean returnDocuments - ) { - var rerankModel = OpenShiftAiRerankModelTests.createModel(randomAlphaOfLength(10), API_KEY_VALUE, MODEL_VALUE, 1, true); - return new OpenShiftAiRerankRequest(QUERY_VALUE, List.of(DOCUMENT_VALUE), returnDocuments, topN, rerankModel); + var rerankModel = OpenShiftAiRerankModelTests.createModel( + URL_VALUE, + API_KEY_VALUE, + modelId, + taskSettingsTopN, + taskSettingsReturnDocuments + ); + return new OpenShiftAiRerankRequest(QUERY_VALUE, List.of(DOCUMENT_VALUE), requestReturnDocuments, requestTopN, rerankModel); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java index c96361d336557..0d210428a46a6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java @@ -23,6 +23,12 @@ public class OpenShiftAiRerankModelTests extends ESTestCase { + private static final String URL_VALUE = "http://www.abc.com"; + private static final String API_KEY_VALUE = "test_api_key"; + private static final String MODEL_VALUE = "some_model"; + private static final int TOP_N_VALUE = 4; + private static final boolean RETURN_DOCUMENTS_VALUE = false; + public static OpenShiftAiRerankModel createModel(String url, String apiKey, @Nullable String modelId) { return createModel(url, apiKey, modelId, 2, true); } @@ -53,29 +59,29 @@ public void testOverrideWith_EmptyParams_KeepsSameModel() { } private static void testOverrideWith_KeepsSameModel(Map taskSettings) { - var model = createModel("url", "api_key", "model_name", 2, true); + var model = createModel(URL_VALUE, API_KEY_VALUE, MODEL_VALUE, 2, true); var overriddenModel = OpenShiftAiRerankModel.of(model, taskSettings); assertThat(overriddenModel, is(sameInstance(model))); } public void testOverrideWith_DifferentParams_OverridesAllTaskSettings() { - testOverrideWith_DifferentParams(buildTaskSettingsMap(4, false), 4, false); + testOverrideWith_DifferentParams(buildTaskSettingsMap(TOP_N_VALUE, RETURN_DOCUMENTS_VALUE), TOP_N_VALUE, RETURN_DOCUMENTS_VALUE); } public void testOverrideWith_DifferentParams_OverridesOnlyReturnDocuments() { - testOverrideWith_DifferentParams(buildTaskSettingsMap(null, false), 2, false); + testOverrideWith_DifferentParams(buildTaskSettingsMap(null, RETURN_DOCUMENTS_VALUE), 2, RETURN_DOCUMENTS_VALUE); } public void testOverrideWith_DifferentParams_OverridesOnlyTopN() { - testOverrideWith_DifferentParams(buildTaskSettingsMap(4, null), 4, true); + testOverrideWith_DifferentParams(buildTaskSettingsMap(TOP_N_VALUE, null), TOP_N_VALUE, true); } public void testOverrideWith_DifferentParams_OverridesNullValues() { - var model = createModel("url", "api_key", "model_name", null, null); - var overriddenModel = OpenShiftAiRerankModel.of(model, buildTaskSettingsMap(4, false)); + var model = createModel(URL_VALUE, API_KEY_VALUE, MODEL_VALUE, null, null); + var overriddenModel = OpenShiftAiRerankModel.of(model, buildTaskSettingsMap(TOP_N_VALUE, RETURN_DOCUMENTS_VALUE)); - assertThat(overriddenModel.getTaskSettings().getTopN(), is(4)); - assertThat(overriddenModel.getTaskSettings().getReturnDocuments(), is(false)); + assertThat(overriddenModel.getTaskSettings().getTopN(), is(TOP_N_VALUE)); + assertThat(overriddenModel.getTaskSettings().getReturnDocuments(), is(RETURN_DOCUMENTS_VALUE)); } private static void testOverrideWith_DifferentParams( @@ -83,7 +89,7 @@ private static void testOverrideWith_DifferentParams( int expectedTopN, boolean expectedReturnDocuments ) { - var model = createModel("url", "api_key", "model_name", 2, true); + var model = createModel(URL_VALUE, API_KEY_VALUE, MODEL_VALUE, 2, true); var overriddenModel = OpenShiftAiRerankModel.of(model, taskSettings); assertThat(overriddenModel.getTaskSettings().getTopN(), is(expectedTopN)); From bb54f4b1b7cb0e69bdf4fa02657b28d3a7e9efd9 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Mon, 10 Nov 2025 23:43:31 +0200 Subject: [PATCH 64/70] Update OpenShift AI embeddings request tests to pass null for dimensions --- .../request/embeddings/OpenShiftAiEmbeddingsRequestTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java index efa411cc9cdfb..e03431d97b836 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java @@ -118,7 +118,7 @@ private static OpenShiftAiEmbeddingsRequest createRequest(Integer dimensions, Bo URL_VALUE, API_KEY_VALUE, modelId, - dimensions, + null, dimensionsSetByUser, dimensions, null From 5fd79a7ec0419523fdc7b0048f26852b4d769c00 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Tue, 11 Nov 2025 00:00:18 +0200 Subject: [PATCH 65/70] Refactor OpenShift AI test constants for improved clarity and consistency --- .../openshiftai/OpenShiftAiServiceTests.java | 14 ++++++++------ .../OpenShiftAiChatCompletionModelTests.java | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java index ad0db05493aa5..12e6e453243ab 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java @@ -103,16 +103,18 @@ import static org.mockito.Mockito.mock; public class OpenShiftAiServiceTests extends AbstractInferenceServiceTests { + private static final String API_KEY_FIELD_NAME = "api_key"; + private static final String INPUT_FIELD_NAME = "some input"; + private static final String MODEL_FIELD_NAME = "some model"; private static final String URL_VALUE = "http://www.abc.com"; private static final String MODEL_VALUE = "some_model"; private static final String ROLE_VALUE = "user"; private static final String API_KEY_VALUE = "test_api_key"; private static final String INFERENCE_ID_VALUE = "id"; - private static final String API_KEY_FIELD_NAME = "api_key"; private static final int DIMENSIONS_VALUE = 1536; private static final int MAX_INPUT_TOKENS_VALUE = 512; - private static final String INPUT_FIELD_NAME = "input"; - private static final String MODEL_FIELD_NAME = "model"; + private static final String FIRST_PART_OF_INPUT_VALUE = "abc"; + public static final String SECOND_PART_OF_INPUT_VALUE = "def"; private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; @@ -703,7 +705,7 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("abc"), new ChunkInferenceInput("def")), + List.of(new ChunkInferenceInput(FIRST_PART_OF_INPUT_VALUE), new ChunkInferenceInput(SECOND_PART_OF_INPUT_VALUE)), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -750,7 +752,7 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap, aMapWithSize(2)); - assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of("abc", "def"))); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(FIRST_PART_OF_INPUT_VALUE, SECOND_PART_OF_INPUT_VALUE))); assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); } } @@ -827,7 +829,7 @@ private InferenceEventsAssertion streamCompletion() throws Exception { null, null, null, - List.of("abc"), + List.of(FIRST_PART_OF_INPUT_VALUE), true, new HashMap<>(), InputType.INGEST, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java index 0b39073ab7544..927d01cc2f6c0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java @@ -18,7 +18,7 @@ public class OpenShiftAiChatCompletionModelTests extends ESTestCase { private static final String MODEL_VALUE = "model_name"; - private static final String API_KEY_VALUE = "api_key"; + private static final String API_KEY_VALUE = "test_api_key"; private static final String URL_VALUE = "http://www.abc.com"; private static final String ALTERNATE_MODEL_VALUE = "different_model"; From 6aa70a6ed27acd620b488266e01cdcafe0dc40e8 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Tue, 11 Nov 2025 00:38:28 +0200 Subject: [PATCH 66/70] Refactor OpenShift AI test field names for clarity and consistency --- .../services/openshiftai/OpenShiftAiServiceTests.java | 6 +++--- .../embeddings/OpenShiftAiEmbeddingsRequestTests.java | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java index 12e6e453243ab..0fed0d3792e05 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java @@ -104,8 +104,8 @@ public class OpenShiftAiServiceTests extends AbstractInferenceServiceTests { private static final String API_KEY_FIELD_NAME = "api_key"; - private static final String INPUT_FIELD_NAME = "some input"; - private static final String MODEL_FIELD_NAME = "some model"; + private static final String INPUT_FIELD_NAME = "input"; + private static final String MODEL_FIELD_NAME = "model"; private static final String URL_VALUE = "http://www.abc.com"; private static final String MODEL_VALUE = "some_model"; private static final String ROLE_VALUE = "user"; @@ -114,7 +114,7 @@ public class OpenShiftAiServiceTests extends AbstractInferenceServiceTests { private static final int DIMENSIONS_VALUE = 1536; private static final int MAX_INPUT_TOKENS_VALUE = 512; private static final String FIRST_PART_OF_INPUT_VALUE = "abc"; - public static final String SECOND_PART_OF_INPUT_VALUE = "def"; + private static final String SECOND_PART_OF_INPUT_VALUE = "def"; private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java index e03431d97b836..6e3a5d14a3490 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java @@ -88,7 +88,7 @@ public void testTruncate_ReducesInputTextSizeByHalf() throws IOException { var httpPost = (HttpPost) httpRequest.httpRequestBase(); var requestMap = entityAsMap(httpPost.getEntity().getContent()); assertThat(requestMap, aMapWithSize(2)); - assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT_VALUE.substring(0, 2)))); + assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT_VALUE.substring(0, INPUT_VALUE.length() / 2)))); assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); } From e5c58b4fb7f6eab8a486594837269826b8bb5646 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Tue, 11 Nov 2025 00:59:38 +0200 Subject: [PATCH 67/70] Add validation tests for invalid and empty URL in OpenShift AI settings --- ...tAiChatCompletionServiceSettingsTests.java | 60 ++++++++++++++----- 1 file changed, 45 insertions(+), 15 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java index 111181cf59c25..2e90e750e2460 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java @@ -35,6 +35,7 @@ public class OpenShiftAiChatCompletionServiceSettingsTests extends AbstractBWCWi private static final String MODEL_VALUE = "some_model"; private static final String URL_VALUE = "http://www.abc.com"; + private static final String INVALID_URL_VALUE = "^^^"; private static final int RATE_LIMIT = 2; public void testFromMap_AllFields_Success() { @@ -75,25 +76,54 @@ public void testFromMap_MissingModelId_Success() { } public void testFromMap_MissingUrl_ThrowsException() { + testFromMap_InvalidUrl( + Map.of( + ServiceFields.MODEL_ID, + MODEL_VALUE, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + "Validation Failed: 1: [service_settings] does not contain the required setting [url];" + ); + } + + public void testFromMap_InvalidUrl_ThrowsException() { + testFromMap_InvalidUrl( + Map.of( + ServiceFields.URL, + INVALID_URL_VALUE, + ServiceFields.MODEL_ID, + MODEL_VALUE, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + """ + Validation Failed: 1: [service_settings] Invalid url [^^^] received for field [url]. \ + Error: unable to parse url [^^^]. Reason: Illegal character in path;""" + ); + } + + public void testFromMap_EmptyUrl_ThrowsException() { + testFromMap_InvalidUrl( + Map.of( + ServiceFields.URL, + "", + ServiceFields.MODEL_ID, + MODEL_VALUE, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + "Validation Failed: 1: [service_settings] Invalid value empty string. [url] must be a non-empty string;" + ); + } + + private static void testFromMap_InvalidUrl(Map serviceSettingsMap, String expectedErrorMessage) { var thrownException = expectThrows( ValidationException.class, - () -> OpenShiftAiChatCompletionServiceSettings.fromMap( - new HashMap<>( - Map.of( - ServiceFields.MODEL_ID, - MODEL_VALUE, - RateLimitSettings.FIELD_NAME, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) - ) - ), - ConfigurationParseContext.PERSISTENT - ) + () -> OpenShiftAiChatCompletionServiceSettings.fromMap(new HashMap<>(serviceSettingsMap), ConfigurationParseContext.PERSISTENT) ); - assertThat( - thrownException.getMessage(), - containsString("Validation Failed: 1: [service_settings] does not contain the required setting [url];") - ); + assertThat(thrownException.getMessage(), containsString(expectedErrorMessage)); } public void testFromMap_MissingRateLimit_Success() { From 0e1b14b623a2647c9efcd314dd1c5a08163502ce Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Wed, 12 Nov 2025 02:36:42 +0200 Subject: [PATCH 68/70] Refactor OpenShift AI test constants for improved clarity and consistency --- .../openshiftai/OpenShiftAiServiceTests.java | 19 ++- .../action/OpenShiftAiActionCreatorTests.java | 145 ++++++++++++++---- ...tAiChatCompletionResponseHandlerTests.java | 13 +- ...tAiChatCompletionServiceSettingsTests.java | 8 +- ...ShiftAiEmbeddingsServiceSettingsTests.java | 8 +- ...OpenShiftAiChatCompletionRequestTests.java | 3 +- .../OpenShiftAiEmbeddingsRequestTests.java | 5 +- .../rarank/OpenShiftAiRerankRequestTests.java | 3 +- .../rerank/OpenShiftAiRerankModelTests.java | 4 - 9 files changed, 151 insertions(+), 57 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java index 0fed0d3792e05..131fa98f97563 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; @@ -70,7 +71,6 @@ import java.util.EnumSet; import java.util.HashMap; import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -175,7 +175,7 @@ private static void assertModel(Model model, TaskType taskType, boolean modelInc case TEXT_EMBEDDING -> assertTextEmbeddingModel(model, modelIncludesSecrets); case COMPLETION -> assertCompletionModel(model, modelIncludesSecrets); case CHAT_COMPLETION -> assertChatCompletionModel(model, modelIncludesSecrets); - default -> fail("unexpected task type [%s]".formatted(taskType)); + default -> fail(Strings.format("unexpected task type [%s]", taskType)); } } @@ -277,7 +277,7 @@ public void shutdown() throws IOException { } public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - var chunkingSettingsMap = createRandomChunkingSettings(); + var chunkingSettings = createRandomChunkingSettings(); try (var service = createService()) { ActionListener modelVerificationActionListener = ActionListener.wrap(model -> { assertThat(model, instanceOf(OpenShiftAiEmbeddingsModel.class)); @@ -285,7 +285,7 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsP var embeddingsModel = (OpenShiftAiEmbeddingsModel) model; assertThat(embeddingsModel.getServiceSettings().uri().toString(), is(URL_VALUE)); assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings().asMap(), is(chunkingSettingsMap.asMap())); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings().asMap(), is(chunkingSettings.asMap())); assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY_VALUE)); }, e -> fail("parse request should not fail " + e.getMessage())); @@ -294,7 +294,7 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsP TaskType.TEXT_EMBEDDING, getRequestConfigMap( getServiceSettingsMap(MODEL_VALUE, URL_VALUE), - chunkingSettingsMap.asMap(), + chunkingSettings.asMap(), getSecretSettingsMap(API_KEY_VALUE) ), modelVerificationActionListener @@ -461,7 +461,7 @@ public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception { } }); var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); - assertThat(json, is(String.format(Locale.ROOT, XContentHelper.stripWhitespace(""" + assertThat(json, is(Strings.format(XContentHelper.stripWhitespace(""" { "error" : { "code" : "not_found", @@ -568,7 +568,7 @@ public void testInfer_StreamRequest_ErrorResponse() { var e = assertThrows(ElasticsearchStatusException.class, this::streamCompletion); assertThat(e.status(), equalTo(RestStatus.NOT_FOUND)); - assertThat(e.getMessage(), equalTo(String.format(Locale.ROOT, """ + assertThat(e.getMessage(), equalTo(Strings.format(""" Resource not found at [%s] for request from inference entity id [inferenceEntityId] status [404]. Error message: [{ "detail": "Not Found" }]""", getUrl(webServer)))); @@ -748,7 +748,10 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), is("Bearer %s".formatted(API_KEY_VALUE))); + assertThat( + webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), + is(Strings.format("Bearer %s", API_KEY_VALUE)) + ); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap, aMapWithSize(2)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java index 9a8cea882b9de..de8650ca1b9a6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java @@ -90,6 +90,11 @@ public class OpenShiftAiActionCreatorTests extends ESTestCase { private static final String INPUT_TO_TRUNCATE = "super long input"; private static final List EMBEDDINGS_VALUE = List.of(new float[] { 0.0123F, -0.0123F }); private static final List DOCUMENTS_VALUE = List.of("Luke"); + private static final int N_VALUE = 1; + private static final boolean RETURN_DOCUMENTS_DEFAULT_VALUE = true; + private static final boolean RETURN_DOCUMENTS_OVERRIDDEN_VALUE = false; + private static final int TOP_N_DEFAULT_VALUE = 2; + private static final int TOP_N_OVERRIDDEN_VALUE = 1; private static final List RERANK_EXPECTATIONS_WITH_TEXT_TWO_RESULTS = List.of( new RankedDocsResultsTests.RerankExpectation(Map.of("text", "awgawgawgawg", "index", 1, "relevance_score", 0.9921875f)), new RankedDocsResultsTests.RerankExpectation(Map.of("text", "awdawdawda", "index", 0, "relevance_score", 0.4921875f)) @@ -185,7 +190,7 @@ public void testCreate_OpenShiftAiEmbeddingsModel() throws IOException { private static void assertContentTypeAndAuthorization(MockRequest request) { assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters())); - assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer %s".formatted(API_KEY_VALUE))); + assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo(Strings.format("Bearer %s", API_KEY_VALUE))); } public void testCreate_OpenShiftAiEmbeddingsModel_FailsFromInvalidResponseFormat() throws IOException { @@ -239,7 +244,8 @@ public void testCreate_OpenShiftAiEmbeddingsModel_FailsFromInvalidResponseFormat assertThat( thrownException.getMessage(), is( - "Failed to send OpenShift AI text_embedding request from inference entity id [inferenceEntityId]. Cause: %s".formatted( + Strings.format( + "Failed to send OpenShift AI text_embedding request from inference entity id [inferenceEntityId]. Cause: %s", failureCauseMessage ) ) @@ -307,7 +313,7 @@ public void testCreate_OpenShiftAiChatCompletionModel() throws IOException { is(List.of(Map.of(ROLE_FIELD_NAME, ROLE_VALUE, CONTENT_FIELD_NAME, INPUT_ENTRY_VALUE))) ); assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); - assertThat(requestMap.get(N_FIELD_NAME), is(1)); + assertThat(requestMap.get(N_FIELD_NAME), is(N_VALUE)); assertThat(requestMap.get(STREAM_FIELD_NAME), is(false)); } } @@ -364,7 +370,8 @@ public void testCreate_OpenShiftAiChatCompletionModel_FailsFromInvalidResponseFo assertThat( thrownException.getMessage(), is( - "Failed to send OpenShift AI completion request from inference entity id [inferenceEntityId]. Cause: %s".formatted( + Strings.format( + "Failed to send OpenShift AI completion request from inference entity id [inferenceEntityId]. Cause: %s", failureCauseMessage ) ) @@ -631,7 +638,13 @@ public void testCreate_OpenShiftAiRerankModel_WithTaskSettings() throws IOExcept """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); + var model = OpenShiftAiRerankModelTests.createModel( + getUrl(webServer), + API_KEY_VALUE, + MODEL_VALUE, + TOP_N_DEFAULT_VALUE, + RETURN_DOCUMENTS_DEFAULT_VALUE + ); var actionCreator = new OpenShiftAiActionCreator( sender, new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) @@ -648,13 +661,12 @@ public void testCreate_OpenShiftAiRerankModel_WithTaskSettings() throws IOExcept var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertThat(result.asMap(), is(buildExpectationRerank(RERANK_EXPECTATIONS_WITH_TEXT_TWO_RESULTS))); } - assertRerankActionCreator(DOCUMENTS_VALUE, 2, true); + assertRerankActionCreator(TOP_N_DEFAULT_VALUE, RETURN_DOCUMENTS_DEFAULT_VALUE, MODEL_VALUE); } public void testCreate_OpenShiftAiRerankModel_WithOverriddenTaskSettings() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); - List documents = DOCUMENTS_VALUE; try (var sender = createSender(senderFactory)) { sender.startSynchronously(); @@ -675,19 +687,32 @@ public void testCreate_OpenShiftAiRerankModel_WithOverriddenTaskSettings() throw """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); + var model = OpenShiftAiRerankModelTests.createModel( + getUrl(webServer), + API_KEY_VALUE, + MODEL_VALUE, + TOP_N_DEFAULT_VALUE, + RETURN_DOCUMENTS_DEFAULT_VALUE + ); var actionCreator = new OpenShiftAiActionCreator( sender, new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) ); var action = actionCreator.create( model, - new HashMap<>(Map.of(OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS, false, OpenShiftAiRerankTaskSettings.TOP_N, 1)) + new HashMap<>( + Map.of( + OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS, + RETURN_DOCUMENTS_OVERRIDDEN_VALUE, + OpenShiftAiRerankTaskSettings.TOP_N, + TOP_N_OVERRIDDEN_VALUE + ) + ) ); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new QueryAndDocsInputs(QUERY_VALUE, documents, null, null, false), + new QueryAndDocsInputs(QUERY_VALUE, DOCUMENTS_VALUE, null, null, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -695,13 +720,12 @@ public void testCreate_OpenShiftAiRerankModel_WithOverriddenTaskSettings() throw var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertThat(result.asMap(), is(buildExpectationRerank(RERANK_EXPECTATIONS_NO_TEXT_SINGLE_RESULT))); } - assertRerankActionCreator(documents, 1, false); + assertRerankActionCreator(TOP_N_OVERRIDDEN_VALUE, RETURN_DOCUMENTS_OVERRIDDEN_VALUE, MODEL_VALUE); } - public void testCreate_OpenShiftAiRerankModel_NoTaskSettings() throws IOException { + public void testCreate_OpenShiftAiRerankModel_NoTaskSettingsWithModelId() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); - List documents = DOCUMENTS_VALUE; try (var sender = createSender(senderFactory)) { sender.startSynchronously(); @@ -741,7 +765,60 @@ public void testCreate_OpenShiftAiRerankModel_NoTaskSettings() throws IOExceptio PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new QueryAndDocsInputs(QUERY_VALUE, documents, null, null, false), + new QueryAndDocsInputs(QUERY_VALUE, DOCUMENTS_VALUE, null, null, false), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + assertThat(result.asMap(), is(buildExpectationRerank(RERANK_EXPECTATIONS_WITH_TEXT_TWO_RESULTS))); + } + assertRerankActionCreator(null, null, MODEL_VALUE); + } + + public void testCreate_OpenShiftAiRerankModel_NoTaskSettings_NoModelId() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "id": "rerank-d300256dd02b4c63b8a2bc34dcdad845", + "model": "bge-reranker-v2-m3", + "usage": { + "total_tokens": 30 + }, + "results": [ + { + "index": 1, + "document": { + "text": "awgawgawgawg" + }, + "relevance_score": 0.9921875 + }, + { + "index": 0, + "document": { + "text": "awdawdawda" + }, + "relevance_score": 0.4921875 + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY_VALUE, null, null, null); + var actionCreator = new OpenShiftAiActionCreator( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) + ); + var action = actionCreator.create(model, null); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new QueryAndDocsInputs(QUERY_VALUE, DOCUMENTS_VALUE, null, null, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -749,7 +826,7 @@ public void testCreate_OpenShiftAiRerankModel_NoTaskSettings() throws IOExceptio var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertThat(result.asMap(), is(buildExpectationRerank(RERANK_EXPECTATIONS_WITH_TEXT_TWO_RESULTS))); } - assertRerankActionCreator(documents, null, null); + assertRerankActionCreator(null, null, null); } public void testCreate_OpenShiftAiRerankModel_NoTaskSettings_WithRequestParameters() throws IOException { @@ -794,7 +871,7 @@ public void testCreate_OpenShiftAiRerankModel_NoTaskSettings_WithRequestParamete PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new QueryAndDocsInputs(QUERY_VALUE, DOCUMENTS_VALUE, true, 2, false), + new QueryAndDocsInputs(QUERY_VALUE, DOCUMENTS_VALUE, RETURN_DOCUMENTS_DEFAULT_VALUE, TOP_N_DEFAULT_VALUE, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -802,13 +879,12 @@ public void testCreate_OpenShiftAiRerankModel_NoTaskSettings_WithRequestParamete var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertThat(result.asMap(), is(buildExpectationRerank(RERANK_EXPECTATIONS_WITH_TEXT_TWO_RESULTS))); } - assertRerankActionCreator(DOCUMENTS_VALUE, 2, true); + assertRerankActionCreator(TOP_N_DEFAULT_VALUE, RETURN_DOCUMENTS_DEFAULT_VALUE, MODEL_VALUE); } public void testCreate_OpenShiftAiRerankModel_WithTaskSettings_WithRequestParametersPrioritized() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, NO_RETRY_SETTINGS); - List documents = DOCUMENTS_VALUE; try (var sender = createSender(senderFactory)) { sender.startSynchronously(); @@ -829,7 +905,13 @@ public void testCreate_OpenShiftAiRerankModel_WithTaskSettings_WithRequestParame """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); + var model = OpenShiftAiRerankModelTests.createModel( + getUrl(webServer), + API_KEY_VALUE, + MODEL_VALUE, + TOP_N_DEFAULT_VALUE, + RETURN_DOCUMENTS_DEFAULT_VALUE + ); var actionCreator = new OpenShiftAiActionCreator( sender, new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) @@ -838,7 +920,7 @@ public void testCreate_OpenShiftAiRerankModel_WithTaskSettings_WithRequestParame PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new QueryAndDocsInputs(QUERY_VALUE, documents, false, 1, false), + new QueryAndDocsInputs(QUERY_VALUE, DOCUMENTS_VALUE, RETURN_DOCUMENTS_OVERRIDDEN_VALUE, TOP_N_OVERRIDDEN_VALUE, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -846,7 +928,7 @@ public void testCreate_OpenShiftAiRerankModel_WithTaskSettings_WithRequestParame var result = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertThat(result.asMap(), is(buildExpectationRerank(RERANK_EXPECTATIONS_NO_TEXT_SINGLE_RESULT))); } - assertRerankActionCreator(documents, 1, false); + assertRerankActionCreator(TOP_N_OVERRIDDEN_VALUE, RETURN_DOCUMENTS_OVERRIDDEN_VALUE, MODEL_VALUE); } public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() throws IOException { @@ -882,7 +964,13 @@ public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() t """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = OpenShiftAiRerankModelTests.createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); + var model = OpenShiftAiRerankModelTests.createModel( + getUrl(webServer), + API_KEY_VALUE, + MODEL_VALUE, + TOP_N_DEFAULT_VALUE, + RETURN_DOCUMENTS_DEFAULT_VALUE + ); var actionCreator = new OpenShiftAiActionCreator( sender, new ServiceComponents(threadPool, mockThrottlerManager(), NO_RETRY_SETTINGS, TruncatorTests.createTruncator()) @@ -905,18 +993,17 @@ public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() t } private void assertRerankActionCreator( - List documents, @Nullable Integer expectedTopN, - @Nullable Boolean expectedReturnDocuments + @Nullable Boolean expectedReturnDocuments, + @Nullable String expectedModel ) throws IOException { assertThat(webServer.requests(), hasSize(1)); assertThat(webServer.requests().getFirst().getUri().getQuery(), is(nullValue())); assertContentTypeAndAuthorization(webServer.requests().getFirst()); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - int fieldCount = 3; - assertThat(requestMap.get(DOCUMENTS_FIELD_NAME), is(documents)); - assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); + int fieldCount = 2; + assertThat(requestMap.get(DOCUMENTS_FIELD_NAME), is(DOCUMENTS_VALUE)); assertThat(requestMap.get(QUERY_FIELD_NAME), is(QUERY_VALUE)); if (expectedTopN != null) { assertThat(requestMap.get(TOP_N_FIELD_NAME), is(expectedTopN)); @@ -926,6 +1013,10 @@ private void assertRerankActionCreator( assertThat(requestMap.get(RETURN_DOCUMENTS_FIELD_NAME), is(expectedReturnDocuments)); fieldCount++; } + if (expectedModel != null) { + assertThat(requestMap.get(MODEL_FIELD_NAME), is(expectedModel)); + fieldCount++; + } assertThat(requestMap, aMapWithSize(fieldCount)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java index 046d9feecf420..fe850ae98b945 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionResponseHandlerTests.java @@ -11,6 +11,7 @@ import org.apache.http.StatusLine; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Strings; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; @@ -47,7 +48,7 @@ public void testFailNotFound() throws IOException { var errorJson = invalidResponseJson(responseJson, 404); - assertThat(errorJson, is(XContentHelper.stripWhitespace(""" + assertThat(errorJson, is(XContentHelper.stripWhitespace(Strings.format(""" { "error" : { "code" : "not_found", @@ -55,7 +56,7 @@ public void testFailNotFound() throws IOException { status [404]. Error message: [{\\"detail\\":\\"Not Found\\"}]", "type" : "openshift_ai_error" } - }""".formatted(URL_VALUE, INFERENCE_ID)))); + }""", URL_VALUE, INFERENCE_ID)))); } public void testFailBadRequest() throws IOException { @@ -73,7 +74,7 @@ public void testFailBadRequest() throws IOException { var errorJson = invalidResponseJson(responseJson, 400); - assertThat(errorJson, is(XContentHelper.stripWhitespace(""" + assertThat(errorJson, is(XContentHelper.stripWhitespace(Strings.format(""" { "error": { "code": "bad_request", @@ -84,7 +85,7 @@ public void testFailBadRequest() throws IOException { "type": "openshift_ai_error" } } - """.formatted(INFERENCE_ID)))); + """, INFERENCE_ID)))); } public void testFailValidationWithInvalidJson() throws IOException { @@ -94,7 +95,7 @@ public void testFailValidationWithInvalidJson() throws IOException { var errorJson = invalidResponseJson(responseJson, 500); - assertThat(errorJson, is(XContentHelper.stripWhitespace(""" + assertThat(errorJson, is(XContentHelper.stripWhitespace(Strings.format(""" { "error": { "code": "bad_request", @@ -103,7 +104,7 @@ public void testFailValidationWithInvalidJson() throws IOException { "type": "openshift_ai_error" } } - """.formatted(INFERENCE_ID)))); + """, INFERENCE_ID)))); } private String invalidResponseJson(String responseJson, int statusCode) throws IOException { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java index 2e90e750e2460..83d4fb6cb2538 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionServiceSettingsTests.java @@ -153,7 +153,7 @@ public void testToXContent_WritesAllValues() throws IOException { XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); serviceSettings.toXContent(builder, null); String xContentResult = Strings.toString(builder); - var expected = XContentHelper.stripWhitespace(""" + var expected = XContentHelper.stripWhitespace(Strings.format(""" { "model_id": "%s", "url": "%s", @@ -161,7 +161,7 @@ public void testToXContent_WritesAllValues() throws IOException { "requests_per_minute": 2 } } - """.formatted(MODEL_VALUE, URL_VALUE)); + """, MODEL_VALUE, URL_VALUE)); assertThat(xContentResult, is(expected)); } @@ -175,14 +175,14 @@ public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); serviceSettings.toXContent(builder, null); String xContentResult = Strings.toString(builder); - var expected = XContentHelper.stripWhitespace(""" + var expected = XContentHelper.stripWhitespace(Strings.format(""" { "url": "%s", "rate_limit": { "requests_per_minute": 3000 } } - """.formatted(URL_VALUE)); + """, URL_VALUE)); assertThat(xContentResult, is(expected)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java index 231445b5e70d4..a56e6bb135fc8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsServiceSettingsTests.java @@ -161,9 +161,9 @@ public void testFromMap_InvalidUrl_ThrowsException() { ConfigurationParseContext.PERSISTENT ) ); - assertThat(thrownException.getMessage(), containsString(""" + assertThat(thrownException.getMessage(), containsString(Strings.format(""" Validation Failed: 1: [service_settings] Invalid url [%s] received for field [url]. \ - Error: unable to parse url [%s]. Reason: Illegal character in path;""".formatted(INVALID_URL_VALUE, INVALID_URL_VALUE))); + Error: unable to parse url [%s]. Reason: Illegal character in path;""", INVALID_URL_VALUE, INVALID_URL_VALUE))); } public void testFromMap_NoSimilarity_Success() { @@ -546,7 +546,7 @@ public void testToXContent_WritesAllValues() throws IOException { entity.toXContent(builder, null); String xContentResult = Strings.toString(builder); - assertThat(xContentResult, CoreMatchers.is(XContentHelper.stripWhitespace(""" + assertThat(xContentResult, CoreMatchers.is(XContentHelper.stripWhitespace(Strings.format(""" { "model_id": "%s", "url": "%s", @@ -558,7 +558,7 @@ public void testToXContent_WritesAllValues() throws IOException { "max_input_tokens": 128, "dimensions_set_by_user": false } - """.formatted(MODEL_VALUE, CORRECT_URL_VALUE)))); + """, MODEL_VALUE, CORRECT_URL_VALUE)))); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java index 7b408e8c132f9..2f6a78076e21c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/completion/OpenShiftAiChatCompletionRequestTests.java @@ -9,6 +9,7 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModelTests; @@ -55,7 +56,7 @@ public void testCreateRequest_WithStreaming() throws IOException { assertThat(requestMap.get(N_FIELD_NAME), is(1)); assertThat(requestMap.get(MESSAGES_FIELD_NAME), is(List.of(Map.of(ROLE_FIELD_NAME, ROLE_VALUE, CONTENT_FIELD_NAME, input)))); assertThat(requestMap, aMapWithSize(4)); - assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer %s".formatted(API_KEY_VALUE))); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is(Strings.format("Bearer %s", API_KEY_VALUE))); } public void testTruncate_DoesNotReduceInputTextSize() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java index 6e3a5d14a3490..87defb2301f6e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/embeddings/OpenShiftAiEmbeddingsRequestTests.java @@ -9,6 +9,7 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.common.Truncator; @@ -62,7 +63,7 @@ private void testCreateRequest_Success(Integer dimensions, boolean dimensionsSet assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT_VALUE))); assertThat(requestMap.get(MODEL_FIELD_NAME), is(MODEL_VALUE)); assertThat(requestMap.get(ServiceFields.DIMENSIONS), is(expectedDimensions)); - assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer %s".formatted(API_KEY_VALUE))); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is(Strings.format("Bearer %s", API_KEY_VALUE))); } public void testCreateRequest_NoModel_Success() throws IOException { @@ -74,7 +75,7 @@ public void testCreateRequest_NoModel_Success() throws IOException { assertThat(requestMap.get(INPUT_FIELD_NAME), is(List.of(INPUT_VALUE))); assertThat(requestMap.get(MODEL_FIELD_NAME), is(nullValue())); assertThat(requestMap.get(ServiceFields.DIMENSIONS), is(nullValue())); - assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer %s".formatted(API_KEY_VALUE))); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is(Strings.format("Bearer %s", API_KEY_VALUE))); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java index ba4d4c4aa0b34..dba3c26eab97f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/request/rarank/OpenShiftAiRerankRequestTests.java @@ -9,6 +9,7 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; import org.elasticsearch.core.Nullable; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; @@ -111,7 +112,7 @@ private void testCreateRequest( var httpPost = (HttpPost) httpRequest.httpRequestBase(); assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaTypeWithoutParameters())); - assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer %s".formatted(API_KEY_VALUE))); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(Strings.format("Bearer %s", API_KEY_VALUE))); var requestMap = entityAsMap(httpPost.getEntity().getContent()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java index 0d210428a46a6..f553399d90c98 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java @@ -29,10 +29,6 @@ public class OpenShiftAiRerankModelTests extends ESTestCase { private static final int TOP_N_VALUE = 4; private static final boolean RETURN_DOCUMENTS_VALUE = false; - public static OpenShiftAiRerankModel createModel(String url, String apiKey, @Nullable String modelId) { - return createModel(url, apiKey, modelId, 2, true); - } - public static OpenShiftAiRerankModel createModel( String url, String apiKey, From 70bebeb363d3aca086d966f56fcc45dc531dba90 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Wed, 12 Nov 2025 20:17:29 +0200 Subject: [PATCH 69/70] Add "openshift_ai" to various service lists in InferenceGetServicesIT --- .../xpack/inference/InferenceGetServicesIT.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index dd1012daffc03..2194c5b08122a 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -70,6 +70,7 @@ public void testGetServicesWithoutTaskType() throws IOException { "llama", "mistral", "openai", + "openshift_ai", "streaming_completion_test_service", "completion_test_service", "test_reranking_service", @@ -116,6 +117,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { "llama", "mistral", "openai", + "openshift_ai", "text_embedding_test_service", "voyageai", "watsonxai" @@ -140,6 +142,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { "elasticsearch", "googlevertexai", "jinaai", + "openshift_ai", "test_reranking_service", "voyageai", "hugging_face", @@ -167,6 +170,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException { "googleaistudio", "googlevertexai", "openai", + "openshift_ai", "streaming_completion_test_service", "completion_test_service", "hugging_face", @@ -188,6 +192,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException { "deepseek", "elastic", "openai", + "openshift_ai", "streaming_completion_test_service", "hugging_face", "amazon_sagemaker", From b30282da546de7d75336602d74857c8e2a895232 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Wed, 12 Nov 2025 20:27:50 +0200 Subject: [PATCH 70/70] Fix embeddings input handling in OpenShiftAiActionCreator --- .../services/openshiftai/action/OpenShiftAiActionCreator.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreator.java index 23a6c56efacf9..52319d12f1112 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreator.java @@ -85,7 +85,7 @@ public ExecutableAction create(OpenShiftAiEmbeddingsModel model) { EMBEDDINGS_HANDLER, embeddingsInput -> new OpenShiftAiEmbeddingsRequest( serviceComponents.truncator(), - truncate(embeddingsInput.getInputs(), model.getServiceSettings().maxInputTokens()), + truncate(embeddingsInput.getTextInputs(), model.getServiceSettings().maxInputTokens()), model ), EmbeddingsInput.class