diff --git a/docs/changelog/134933.yaml b/docs/changelog/134933.yaml new file mode 100644 index 0000000000000..ff884b2f888e4 --- /dev/null +++ b/docs/changelog/134933.yaml @@ -0,0 +1,5 @@ +pr: 134933 +summary: Add ContextualAI Rerank Service Implementation to the Inference API +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/resources/transport/definitions/referable/contextual_ai_service.csv b/server/src/main/resources/transport/definitions/referable/contextual_ai_service.csv new file mode 100644 index 0000000000000..a4676dc2fe444 --- /dev/null +++ b/server/src/main/resources/transport/definitions/referable/contextual_ai_service.csv @@ -0,0 +1 @@ +9175000 diff --git a/server/src/main/resources/transport/upper_bounds/9.2.csv b/server/src/main/resources/transport/upper_bounds/9.2.csv index 57900e0428e01..455c9526b5cfd 100644 --- a/server/src/main/resources/transport/upper_bounds/9.2.csv +++ b/server/src/main/resources/transport/upper_bounds/9.2.csv @@ -1 +1 @@ -semantic_search_ccs_support,9174000 +contextual_ai_service,9175000 diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 4bf874674df6d..f86c92c02db48 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -59,6 +59,7 @@ public void testGetServicesWithoutTaskType() throws IOException { "azureaistudio", "azureopenai", "cohere", + "contextualai", "deepseek", "elastic", "elasticsearch", @@ -134,6 +135,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { "alibabacloud-ai-search", "azureaistudio", "cohere", + "contextualai", "elasticsearch", "googlevertexai", "jinaai", diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 415b93443db5f..8d0b8486ba437 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -127,6 +127,7 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService; import org.elasticsearch.xpack.inference.services.cohere.CohereService; +import org.elasticsearch.xpack.inference.services.contextualai.ContextualAiService; import org.elasticsearch.xpack.inference.services.custom.CustomService; import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; @@ -410,6 +411,7 @@ public List getInferenceServiceFactories() { context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get(), context), context -> new OpenAiService(httpFactory.get(), serviceComponents.get(), context), context -> new CohereService(httpFactory.get(), serviceComponents.get(), context), + context -> new ContextualAiService(httpFactory.get(), serviceComponents.get(), context), context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get(), context), context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get(), context), context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get(), context), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiModel.java new file mode 100644 index 0000000000000..eaed55f1b17f7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiModel.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.contextualai; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.contextualai.action.ContextualAiActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +public abstract class ContextualAiModel extends RateLimitGroupingModel { + + private final SecureString apiKey; + private final ContextualAiRateLimitServiceSettings rateLimitServiceSettings; + + public ContextualAiModel( + ModelConfigurations configurations, + ModelSecrets secrets, + @Nullable ApiKeySecrets apiKeySecrets, + ContextualAiRateLimitServiceSettings rateLimitServiceSettings + ) { + super(configurations, secrets); + + this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings); + apiKey = ServiceUtils.apiKey(apiKeySecrets); + } + + protected ContextualAiModel(ContextualAiModel model, TaskSettings taskSettings) { + super(model, taskSettings); + + rateLimitServiceSettings = model.rateLimitServiceSettings(); + apiKey = model.apiKey(); + } + + protected ContextualAiModel(ContextualAiModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + + rateLimitServiceSettings = model.rateLimitServiceSettings(); + apiKey = model.apiKey(); + } + + public SecureString apiKey() { + return apiKey; + } + + public ContextualAiRateLimitServiceSettings rateLimitServiceSettings() { + return rateLimitServiceSettings; + } + + public abstract ExecutableAction accept(ContextualAiActionVisitor creator, Map taskSettings); + + public RateLimitSettings rateLimitSettings() { + return rateLimitServiceSettings.rateLimitSettings(); + } + + public int rateLimitGroupingHash() { + return apiKey().hashCode(); + } + + public URI baseUri() { + return rateLimitServiceSettings.uri(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiRateLimitServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiRateLimitServiceSettings.java new file mode 100644 index 0000000000000..1e93e12dc0fab --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiRateLimitServiceSettings.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.contextualai; + +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.net.URI; + +public interface ContextualAiRateLimitServiceSettings { + RateLimitSettings rateLimitSettings(); + + URI uri(); + + String modelId(); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiResponseHandler.java new file mode 100644 index 0000000000000..07f5762c5f599 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiResponseHandler.java @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.contextualai; + +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.contextualai.response.ContextualAiErrorResponseEntity; + +/** + * Response handler for ContextualAI API calls. + */ +public class ContextualAiResponseHandler extends BaseResponseHandler { + + public ContextualAiResponseHandler(String requestType, ResponseParser parseFunction, boolean supportsStreaming) { + super(requestType, parseFunction, ContextualAiErrorResponseEntity::fromResponse, supportsStreaming); + } + + @Override + protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException { + if (result.isSuccessfulResponse()) { + return; + } + + // handle error codes + int statusCode = result.response().getStatusLine().getStatusCode(); + if (statusCode == 500) { + throw new RetryException(true, buildError(SERVER_ERROR, request, result)); + } else if (statusCode > 500) { + throw new RetryException(false, buildError(SERVER_ERROR, request, result)); + } else if (statusCode == 429) { + throw new RetryException(true, buildError(RATE_LIMIT, request, result)); + } else if (statusCode == 401) { + throw new RetryException(false, buildError(AUTHENTICATION, request, result)); + } else if (statusCode >= 300 && statusCode < 400) { + throw new RetryException(false, buildError(REDIRECTION, request, result)); + } else { + throw new RetryException(false, buildError(UNSUCCESSFUL, request, result)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java new file mode 100644 index 0000000000000..c67fc328acdd2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java @@ -0,0 +1,268 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.contextualai; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.RerankingInferenceService; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.contextualai.action.ContextualAiActionCreator; +import org.elasticsearch.xpack.inference.services.contextualai.rerank.ContextualAiRerankModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Contextual AI inference service for reranking tasks. + * This service uses the Contextual AI REST API to perform document reranking. + */ +public class ContextualAiService extends SenderService implements RerankingInferenceService { + public static final String NAME = "contextualai"; + private static final String SERVICE_NAME = "Contextual AI"; + + private static final TransportVersion CONTEXTUAL_AI_SERVICE = TransportVersion.fromName("contextual_ai_service"); + + private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of(TaskType.RERANK); + + public ContextualAiService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public ContextualAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); + } + + @Override + public String name() { + return NAME; + } + + @Override + public EnumSet supportedTaskTypes() { + return SUPPORTED_TASK_TYPES; + } + + @Override + public void parseRequestConfig( + String inferenceEntityId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + ContextualAiRerankModel model = createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + ServiceUtils.throwIfNotEmptyMap(config, NAME); + ServiceUtils.throwIfNotEmptyMap(serviceSettingsMap, NAME); + ServiceUtils.throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + private static ContextualAiRerankModel createModel( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + if (taskType != TaskType.RERANK) { + throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + } + + return new ContextualAiRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context); + } + + @Override + public ContextualAiRerankModel parsePersistedConfigWithSecrets( + String inferenceEntityId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); + + return createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap, + ServiceUtils.parsePersistedConfigErrorMsg(inferenceEntityId, NAME), + ConfigurationParseContext.PERSISTENT + ); + } + + @Override + public ContextualAiRerankModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { + Map serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + return createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + null, + ServiceUtils.parsePersistedConfigErrorMsg(inferenceEntityId, NAME), + ConfigurationParseContext.PERSISTENT + ); + } + + @Override + public void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof ContextualAiRerankModel == false) { + listener.onFailure(ServiceUtils.createInvalidModelException(model)); + return; + } + + ContextualAiRerankModel contextualAiModel = (ContextualAiRerankModel) model; + var actionCreator = new ContextualAiActionCreator(getSender(), getServiceComponents()); + + var action = contextualAiModel.accept(actionCreator, taskSettings); + action.execute(inputs, timeout, listener); + } + + @Override + protected void doChunkedInfer( + Model model, + List inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + listener.onFailure(new ElasticsearchStatusException("Chunked inference is not supported for rerank task", RestStatus.BAD_REQUEST)); + } + + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + listener.onFailure(new ElasticsearchStatusException("Unified completion is not supported for rerank task", RestStatus.BAD_REQUEST)); + } + + @Override + protected void validateInputType(InputType inputType, Model model, ValidationException validationException) { + // Rerank accepts any input type + } + + @Override + public int rerankerWindowSize(String modelId) { + // Contextual AI rerank models have an 8000 token limit per document + // https://docs.contextual.ai/docs/rerank-api-reference + // Using 1 token = 0.75 words as a rough estimate, we get 6000 words + // allowing for some headroom, we set the window size below 6000 words + // https://github.com/elastic/elasticsearch/pull/134933#discussion_r2368608515 + return 5500; + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return CONTEXTUAL_AI_SERVICE; + } + + public static class Configuration { + public static InferenceServiceConfiguration get() { + return configuration.getOrCompute(); + } + + private static final LazyInitializable configuration = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + + configurationMap.put( + "model_id", + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription( + "The model ID to use for Contextual AI requests." + ) + .setLabel("Model ID") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(SERVICE_NAME) + .setTaskTypes(SUPPORTED_TASK_TYPES) + .setConfigurations(configurationMap) + .build(); + } + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/action/ContextualAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/action/ContextualAiActionCreator.java new file mode 100644 index 0000000000000..9964b36ab99a2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/action/ContextualAiActionCreator.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.contextualai.action; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.contextualai.ContextualAiResponseHandler; +import org.elasticsearch.xpack.inference.services.contextualai.request.ContextualAiRerankRequest; +import org.elasticsearch.xpack.inference.services.contextualai.rerank.ContextualAiRerankModel; +import org.elasticsearch.xpack.inference.services.contextualai.response.ContextualAiRerankResponseEntity; + +import java.util.Map; +import java.util.Objects; +import java.util.function.Function; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; + +/** + * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the ContextualAI model type. + */ +public class ContextualAiActionCreator implements ContextualAiActionVisitor { + + private static final ResponseHandler RERANK_HANDLER = new ContextualAiResponseHandler( + "contextualai rerank", + (request, response) -> ContextualAiRerankResponseEntity.fromResponse((ContextualAiRerankRequest) request, response), + false + ); + + private final Sender sender; + private final ServiceComponents serviceComponents; + + public ContextualAiActionCreator(Sender sender, ServiceComponents serviceComponents) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + } + + @Override + public ExecutableAction create(ContextualAiRerankModel model, Map taskSettings) { + var overriddenModel = ContextualAiRerankModel.of(model, taskSettings); + + Function requestCreator = rerankInput -> new ContextualAiRerankRequest( + rerankInput.getQuery(), + rerankInput.getChunks(), + rerankInput.getTopN(), + overriddenModel.getTaskSettings().getInstruction(), + overriddenModel + ); + + var requestManager = new GenericRequestManager<>( + serviceComponents.threadPool(), + overriddenModel, + RERANK_HANDLER, + requestCreator, + QueryAndDocsInputs.class + ); + + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("ContextualAI rerank"); + return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/action/ContextualAiActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/action/ContextualAiActionVisitor.java new file mode 100644 index 0000000000000..c02e00ebbf6cd --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/action/ContextualAiActionVisitor.java @@ -0,0 +1,17 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.contextualai.action; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.contextualai.rerank.ContextualAiRerankModel; + +import java.util.Map; + +public interface ContextualAiActionVisitor { + ExecutableAction create(ContextualAiRerankModel model, Map taskSettings); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/request/ContextualAiRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/request/ContextualAiRerankRequest.java new file mode 100644 index 0000000000000..307050e105dc6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/request/ContextualAiRerankRequest.java @@ -0,0 +1,108 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.contextualai.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.contextualai.rerank.ContextualAiRerankModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +public class ContextualAiRerankRequest implements Request { + + private static final Logger logger = LogManager.getLogger(ContextualAiRerankRequest.class); + + private final String query; + private final List documents; + private final Integer topN; + private final String instruction; + private final ContextualAiRerankModel model; + + public ContextualAiRerankRequest( + String query, + List documents, + @Nullable Integer topN, + @Nullable String instruction, + ContextualAiRerankModel model + ) { + this.query = Objects.requireNonNull(query); + this.documents = Objects.requireNonNull(documents); + this.topN = topN; + this.instruction = instruction; + this.model = Objects.requireNonNull(model); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(model.uri()); + + var requestEntity = new ContextualAiRerankRequestEntity(query, documents, getTopN(), instruction, model); + String requestJson; + try { + requestJson = Strings.toString(requestEntity); + logger.debug("ContextualAI JSON Request: {}", requestJson); + } catch (Exception e) { + throw new RuntimeException("Failed to serialize ContextualAI request entity", e); + } + + ByteArrayEntity byteEntity = new ByteArrayEntity(requestJson.getBytes(StandardCharsets.UTF_8)); + httpPost.setEntity(byteEntity); + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters()); + + decorateWithAuth(httpPost); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + void decorateWithAuth(HttpPost httpPost) { + SecureString apiKey = model.apiKey(); + if (apiKey != null) { + httpPost.setHeader(createAuthBearerHeader(apiKey)); + } + } + + @Override + public String getInferenceEntityId() { + return model != null ? model.getInferenceEntityId() : "unknown"; + } + + @Override + public URI getURI() { + return model != null ? model.uri() : null; + } + + public Integer getTopN() { + return topN != null ? topN : (model.getTaskSettings() != null ? model.getTaskSettings().getTopN() : null); + } + + @Override + public Request truncate() { + // Not applicable for rerank, only used in text embedding requests + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // Not applicable for rerank, only used in text embedding requests + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/request/ContextualAiRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/request/ContextualAiRerankRequestEntity.java new file mode 100644 index 0000000000000..c41d056733f20 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/request/ContextualAiRerankRequestEntity.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.contextualai.request; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.contextualai.rerank.ContextualAiRerankModel; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +/** + * Request entity for Contextual AI rerank API. + * Based on the API documentation at https://docs.contextual.ai/api-reference/rerank/rerank + */ +public class ContextualAiRerankRequestEntity implements ToXContentObject { + + private static final String MODEL_FIELD = "model"; + private static final String QUERY_FIELD = "query"; + private static final String DOCUMENTS_FIELD = "documents"; + private static final String TOP_N_FIELD = "top_n"; + private static final String INSTRUCTION_FIELD = "instruction"; + + private final String query; + private final List documents; + private final Integer topN; + private final String instruction; + private final ContextualAiRerankModel model; + + public ContextualAiRerankRequestEntity( + String query, + List documents, + @Nullable Integer topN, + @Nullable String instruction, + ContextualAiRerankModel model + ) { + this.query = Objects.requireNonNull(query); + this.documents = Objects.requireNonNull(documents); + this.topN = topN; + this.instruction = instruction; + this.model = Objects.requireNonNull(model); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + // Order fields to match ContextualAI API expectation: query, model, top_n, instruction, documents + builder.field(QUERY_FIELD, query); + builder.field(MODEL_FIELD, model.modelId()); + + // Add top_n field if specified + if (topN != null) { + builder.field(TOP_N_FIELD, topN); + } else if (model.getTaskSettings() != null && model.getTaskSettings().getTopN() != null) { + builder.field(TOP_N_FIELD, model.getTaskSettings().getTopN()); + } + + if (instruction != null) { + builder.field(INSTRUCTION_FIELD, instruction); + } + + builder.field(DOCUMENTS_FIELD, documents); + + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/rerank/ContextualAiRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/rerank/ContextualAiRerankModel.java new file mode 100644 index 0000000000000..2ec2b9e1f3d3d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/rerank/ContextualAiRerankModel.java @@ -0,0 +1,100 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.contextualai.rerank; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.contextualai.ContextualAiModel; +import org.elasticsearch.xpack.inference.services.contextualai.ContextualAiService; +import org.elasticsearch.xpack.inference.services.contextualai.action.ContextualAiActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.net.URI; +import java.util.Map; + +public class ContextualAiRerankModel extends ContextualAiModel { + public static ContextualAiRerankModel of(ContextualAiRerankModel model, Map taskSettings) { + var requestTaskSettings = ContextualAiRerankTaskSettings.fromMap(taskSettings); + return new ContextualAiRerankModel(model, ContextualAiRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + } + + public ContextualAiRerankModel( + String modelId, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + modelId, + ContextualAiRerankServiceSettings.fromMap(serviceSettings, context), + ContextualAiRerankTaskSettings.fromMap(taskSettings), + DefaultSecretSettings.fromMap(secrets) + ); + } + + public ContextualAiRerankModel( + String modelId, + ContextualAiRerankServiceSettings serviceSettings, + ContextualAiRerankTaskSettings taskSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + super( + new ModelConfigurations(modelId, TaskType.RERANK, ContextualAiService.NAME, serviceSettings, taskSettings), + new ModelSecrets(secretSettings), + secretSettings, + serviceSettings + ); + } + + private ContextualAiRerankModel(ContextualAiRerankModel model, ContextualAiRerankTaskSettings taskSettings) { + super(model, taskSettings); + } + + public ContextualAiRerankModel(ContextualAiRerankModel model, ContextualAiRerankServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + @Override + public ContextualAiRerankServiceSettings getServiceSettings() { + return (ContextualAiRerankServiceSettings) super.getServiceSettings(); + } + + @Override + public ContextualAiRerankTaskSettings getTaskSettings() { + return (ContextualAiRerankTaskSettings) super.getTaskSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + public URI uri() { + return getServiceSettings().uri(); + } + + public String modelId() { + return getServiceSettings().modelId(); + } + + /** + * Accepts a visitor to create an executable action. The returned action will not return documents in the response. + * @param visitor Interface for creating {@link ExecutableAction} instances for ContextualAI models. + * @param taskSettings Settings in the request to override the model's defaults + * @return the rerank action + */ + @Override + public ExecutableAction accept(ContextualAiActionVisitor visitor, Map taskSettings) { + return visitor.create(this, taskSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/rerank/ContextualAiRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/rerank/ContextualAiRerankServiceSettings.java new file mode 100644 index 0000000000000..212a3c4e03323 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/rerank/ContextualAiRerankServiceSettings.java @@ -0,0 +1,154 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.contextualai.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.contextualai.ContextualAiRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.contextualai.ContextualAiService; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; + +public class ContextualAiRerankServiceSettings extends FilteredXContentObject + implements + ContextualAiRateLimitServiceSettings, + ServiceSettings { + + public static final String NAME = "contextualai_rerank_service_settings"; + private static final String API_KEY = "api_key"; + private static final String MODEL_ID = "model_id"; + + // TODO: Make this configurable instead of hardcoded. Should support custom endpoints or different ContextualAI regions. + private static final String DEFAULT_URL = "https://api.contextual.ai/v1/rerank"; + + // Default rate limit settings - can be adjusted based on ContextualAI's actual limits + private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(1000); // 1000 requests per minute + + private final URI uri; + private final String modelId; + private final RateLimitSettings rateLimitSettings; + + public static ContextualAiRerankServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); + String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + ContextualAiService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + URI uri = url != null ? ServiceUtils.createUri(url) : ServiceUtils.createUri(DEFAULT_URL); + return new ContextualAiRerankServiceSettings(uri, modelId, rateLimitSettings); + } + + public ContextualAiRerankServiceSettings(URI uri, @Nullable String modelId, @Nullable RateLimitSettings rateLimitSettings) { + this.uri = Objects.requireNonNull(uri); + this.modelId = modelId; // Can be null for REQUEST context + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public ContextualAiRerankServiceSettings(StreamInput in) throws IOException { + this.uri = ServiceUtils.createUri(in.readString()); + this.modelId = in.readString(); + this.rateLimitSettings = Objects.requireNonNullElse(in.readOptionalWriteable(RateLimitSettings::new), DEFAULT_RATE_LIMIT_SETTINGS); + } + + public URI uri() { + return uri; + } + + public String modelId() { + return modelId; + } + + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field(URL, uri.toString()); + builder.field(MODEL_ID, modelId); + + rateLimitSettings.toXContent(builder, params); + + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(URL, uri.toString()); + builder.field(MODEL_ID, modelId); + + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(uri.toString()); + out.writeString(modelId); + out.writeOptionalWriteable(rateLimitSettings); + } + + @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + ContextualAiRerankServiceSettings that = (ContextualAiRerankServiceSettings) object; + return Objects.equals(uri, that.uri) + && Objects.equals(modelId, that.modelId) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(uri, modelId, rateLimitSettings); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.V_8_15_0; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/rerank/ContextualAiRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/rerank/ContextualAiRerankTaskSettings.java new file mode 100644 index 0000000000000..f105df262392d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/rerank/ContextualAiRerankTaskSettings.java @@ -0,0 +1,159 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.contextualai.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ServiceUtils; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; + +public class ContextualAiRerankTaskSettings implements TaskSettings { + + public static final String NAME = "contextualai_rerank_task_settings"; + public static final String RETURN_DOCUMENTS = "return_documents"; + public static final String TOP_N_DOCS_ONLY = "top_n"; + public static final String INSTRUCTION = "instruction"; + + // Default hardcoded instruction for reranking + private static final String DEFAULT_INSTRUCTION = "Rerank the given documents based on their relevance to the query."; + + public static final ContextualAiRerankTaskSettings EMPTY_SETTINGS = new ContextualAiRerankTaskSettings(null, null, null); + + public static ContextualAiRerankTaskSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + if (map == null || map.isEmpty()) { + return EMPTY_SETTINGS; + } + + Boolean returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, validationException); + Integer topN = extractOptionalPositiveInteger(map, TOP_N_DOCS_ONLY, ModelConfigurations.TASK_SETTINGS, validationException); + String instruction = ServiceUtils.extractOptionalString(map, INSTRUCTION, ModelConfigurations.TASK_SETTINGS, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new ContextualAiRerankTaskSettings(returnDocuments, topN, instruction); + } + + public static ContextualAiRerankTaskSettings of( + ContextualAiRerankTaskSettings originalSettings, + ContextualAiRerankTaskSettings requestSettings + ) { + var returnDocuments = requestSettings.getReturnDocuments() != null + ? requestSettings.getReturnDocuments() + : originalSettings.getReturnDocuments(); + var topN = requestSettings.getTopN() != null ? requestSettings.getTopN() : originalSettings.getTopN(); + var instruction = requestSettings.getInstruction() != null ? requestSettings.getInstruction() : originalSettings.getInstruction(); + + return new ContextualAiRerankTaskSettings(returnDocuments, topN, instruction); + } + + private final Boolean returnDocuments; + private final Integer topN; + private final String instruction; + + public ContextualAiRerankTaskSettings(@Nullable Boolean returnDocuments, @Nullable Integer topN, @Nullable String instruction) { + this.returnDocuments = returnDocuments; + this.topN = topN; + this.instruction = instruction; + } + + public ContextualAiRerankTaskSettings(StreamInput in) throws IOException { + this.returnDocuments = in.readOptionalBoolean(); + this.topN = in.readOptionalVInt(); + this.instruction = in.readOptionalString(); + } + + @Nullable + public Boolean getReturnDocuments() { + return returnDocuments; + } + + @Nullable + public Integer getTopN() { + return topN; + } + + // Return custom instruction if provided, otherwise use default + public String getInstruction() { + return instruction != null ? instruction : DEFAULT_INSTRUCTION; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalBoolean(returnDocuments); + out.writeOptionalVInt(topN); + out.writeOptionalString(instruction); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (returnDocuments != null) { + builder.field(RETURN_DOCUMENTS, returnDocuments); + } + if (topN != null) { + builder.field(TOP_N_DOCS_ONLY, topN); + } + if (instruction != null) { + builder.field(INSTRUCTION, instruction); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + ContextualAiRerankTaskSettings that = (ContextualAiRerankTaskSettings) object; + return Objects.equals(returnDocuments, that.returnDocuments) + && Objects.equals(topN, that.topN) + && Objects.equals(instruction, that.instruction); + } + + @Override + public int hashCode() { + return Objects.hash(returnDocuments, topN, instruction); + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + return fromMap(newSettings); + } + + @Override + public boolean isEmpty() { + return returnDocuments == null && topN == null && instruction == null; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.V_8_15_0; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/response/ContextualAiErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/response/ContextualAiErrorResponseEntity.java new file mode 100644 index 0000000000000..479a77ebe1409 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/response/ContextualAiErrorResponseEntity.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.contextualai.response; + +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; + +/** + * Error response entity for ContextualAI API calls. + */ +public class ContextualAiErrorResponseEntity extends ErrorResponse { + + private ContextualAiErrorResponseEntity(String errorMessage) { + super(errorMessage); + } + + public static ErrorResponse fromResponse(HttpResult response) { + // Simple error handling - just return the status line as the error message + return new ContextualAiErrorResponseEntity(response.response().getStatusLine().toString()); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/response/ContextualAiRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/response/ContextualAiRerankResponseEntity.java new file mode 100644 index 0000000000000..8cb5a46d14950 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/response/ContextualAiRerankResponseEntity.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.contextualai.response; + +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.services.contextualai.request.ContextualAiRerankRequest; + +import java.io.IOException; +import java.util.Comparator; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; + +/** + * Parses the Contextual AI rerank response. + * + * Based on the API documentation, the response should look like: + * + *
+ * {
+ *   "results": [
+ *     {
+ *       "index": 0,
+ *       "relevance_score": 0.95,
+ *       "document": "original document text if return_documents=true"
+ *     },
+ *     {
+ *       "index": 1,
+ *       "relevance_score": 0.85,
+ *       "document": "original document text if return_documents=true"
+ *     }
+ *   ]
+ * }
+ * 
+ */ +public class ContextualAiRerankResponseEntity { + + public static RankedDocsResults fromResponse(ContextualAiRerankRequest request, HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + moveToFirstToken(jsonParser); + var rankedDocs = doParse(jsonParser); + var rankedDocsByRelevanceStream = rankedDocs.stream() + .sorted(Comparator.comparingDouble(RankedDocsResults.RankedDoc::relevanceScore).reversed()); + var rankedDocStreamTopN = request.getTopN() == null + ? rankedDocsByRelevanceStream + : rankedDocsByRelevanceStream.limit(request.getTopN()); + return new RankedDocsResults(rankedDocStreamTopN.toList()); + } + } + + private static List doParse(XContentParser parser) throws IOException { + var responseParser = ResponseParser.PARSER; + var responseObject = responseParser.apply(parser, null); + return responseObject.results.stream() + .map(result -> new RankedDocsResults.RankedDoc(result.index, result.relevanceScore, result.document)) + .toList(); + } + + private record ResponseObject(List results) { + private static final ParseField RESULTS = new ParseField("results"); + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "contextualai_rerank_response", + true, + args -> { + @SuppressWarnings("unchecked") + List results = (List) args[0]; + return new ResponseObject(results); + } + ); + + static { + PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), RankedDocEntry.PARSER, RESULTS); + } + } + + private record RankedDocEntry(Integer index, Float relevanceScore, @Nullable String document) { + + private static final ParseField INDEX = new ParseField("index"); + private static final ParseField RELEVANCE_SCORE = new ParseField("relevance_score"); + private static final ParseField DOCUMENT = new ParseField("document"); + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "contextualai_ranked_doc", + true, + args -> new RankedDocEntry((Integer) args[0], (Float) args[1], (String) args[2]) + ); + + static { + PARSER.declareInt(ConstructingObjectParser.constructorArg(), INDEX); + PARSER.declareFloat(ConstructingObjectParser.constructorArg(), RELEVANCE_SCORE); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), DOCUMENT); + } + } + + private static class ResponseParser { + private static final ConstructingObjectParser PARSER = ResponseObject.PARSER; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiResponseHandlerTests.java new file mode 100644 index 0000000000000..be1c5848286b7 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiResponseHandlerTests.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.contextualai; + +import org.apache.http.Header; +import org.apache.http.HeaderElement; +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.hamcrest.MatcherAssert; + +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.core.Is.is; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ContextualAiResponseHandlerTests extends ESTestCase { + + public void testCheckForFailureStatusCode_DoesNotThrowFor200() { + callCheckForFailureStatusCode(200, "id"); + } + + public void testCheckForFailureStatusCode_ThrowsFor500_WithShouldRetryTrue() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(500, "id")); + assertTrue(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received a server error status code for request from inference entity id [id] status [500]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor503() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(503, "id")); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received a server error status code for request from inference entity id [id] status [503]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor429() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(429, "id")); + assertTrue(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received a rate limit status code for request from inference entity id [id] status [429]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.TOO_MANY_REQUESTS)); + } + + public void testCheckForFailureStatusCode_ThrowsFor401() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(401, "inferenceEntityId")); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString( + "Received an authentication error status code for request from inference entity id [inferenceEntityId] status [401]" + ) + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.UNAUTHORIZED)); + } + + public void testCheckForFailureStatusCode_ThrowsFor400() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(400, "id")); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received an unsuccessful status code for request from inference entity id [id] status [400]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor300() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(300, "id")); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Unhandled redirection for request from inference entity id [id] status [300]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.MULTIPLE_CHOICES)); + } + + private static void callCheckForFailureStatusCode(int statusCode, String modelId) { + callCheckForFailureStatusCode(statusCode, null, modelId); + } + + private static void callCheckForFailureStatusCode(int statusCode, @Nullable String errorMessage, String modelId) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + when(statusLine.toString()).thenReturn("HTTP/1.1 " + statusCode + " Error"); + + var httpResponse = mock(HttpResponse.class); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + var header = mock(Header.class); + when(header.getElements()).thenReturn(new HeaderElement[] {}); + when(httpResponse.getFirstHeader(anyString())).thenReturn(header); + + var mockRequest = mock(Request.class); + when(mockRequest.getInferenceEntityId()).thenReturn(modelId); + var httpResult = new HttpResult(httpResponse, errorMessage == null ? new byte[] {} : errorMessage.getBytes(StandardCharsets.UTF_8)); + var handler = new ContextualAiResponseHandler("", (request, result) -> null, false); + + handler.checkForFailureStatusCode(mockRequest, httpResult); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiServiceTests.java new file mode 100644 index 0000000000000..ed663c52489f9 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiServiceTests.java @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.contextualai; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; + +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.core.Is.is; +import static org.mockito.Mockito.mock; + +public class ContextualAiServiceTests extends ESTestCase { + + private ThreadPool threadPool; + + @Override + public void setUp() throws Exception { + super.setUp(); + threadPool = new TestThreadPool(getTestName()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + terminate(threadPool); + } + + public void testRerankerWindowSize() { + var service = createContextualAiService(); + assertThat(service.rerankerWindowSize("any-model"), is(5500)); + } + + private ContextualAiService createContextualAiService() { + return new ContextualAiService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/contextualai/response/ContextualAiErrorResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/contextualai/response/ContextualAiErrorResponseEntityTests.java new file mode 100644 index 0000000000000..6f11675b883fd --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/contextualai/response/ContextualAiErrorResponseEntityTests.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.contextualai.response; + +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ContextualAiErrorResponseEntityTests extends ESTestCase { + + public void testFromResponse() { + var statusLine = mock(StatusLine.class); + when(statusLine.toString()).thenReturn("HTTP/1.1 400 Bad Request"); + + var httpResponse = mock(HttpResponse.class); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + + var httpResult = new HttpResult(httpResponse, new byte[0]); + + ErrorResponse errorResponse = ContextualAiErrorResponseEntity.fromResponse(httpResult); + + assertThat(errorResponse.getErrorMessage(), is("HTTP/1.1 400 Bad Request")); + } + + public void testFromResponse_ServerError() { + var statusLine = mock(StatusLine.class); + when(statusLine.toString()).thenReturn("HTTP/1.1 500 Internal Server Error"); + + var httpResponse = mock(HttpResponse.class); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + + var httpResult = new HttpResult(httpResponse, new byte[0]); + + ErrorResponse errorResponse = ContextualAiErrorResponseEntity.fromResponse(httpResult); + + assertThat(errorResponse.getErrorMessage(), is("HTTP/1.1 500 Internal Server Error")); + } +}