Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/134933.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 134933
summary: Add ContextualAI Rerank Service Implementation to the Inference API
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
9175000
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/9.2.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
semantic_search_ccs_support,9174000
contextual_ai_service,9175000
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"contextualai",
"deepseek",
"elastic",
"elasticsearch",
Expand Down Expand Up @@ -134,6 +135,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
"alibabacloud-ai-search",
"azureaistudio",
"cohere",
"contextualai",
"elasticsearch",
"googlevertexai",
"jinaai",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.contextualai.ContextualAiService;
import org.elasticsearch.xpack.inference.services.custom.CustomService;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
Expand Down Expand Up @@ -410,6 +411,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get(), context),
context -> new OpenAiService(httpFactory.get(), serviceComponents.get(), context),
context -> new CohereService(httpFactory.get(), serviceComponents.get(), context),
context -> new ContextualAiService(httpFactory.get(), serviceComponents.get(), context),
context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get(), context),
context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get(), context),
context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get(), context),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.contextualai;

import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.contextualai.action.ContextualAiActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.net.URI;
import java.util.Map;
import java.util.Objects;

public abstract class ContextualAiModel extends RateLimitGroupingModel {

private final SecureString apiKey;
private final ContextualAiRateLimitServiceSettings rateLimitServiceSettings;

public ContextualAiModel(
ModelConfigurations configurations,
ModelSecrets secrets,
@Nullable ApiKeySecrets apiKeySecrets,
ContextualAiRateLimitServiceSettings rateLimitServiceSettings
) {
super(configurations, secrets);

this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings);
apiKey = ServiceUtils.apiKey(apiKeySecrets);
}

protected ContextualAiModel(ContextualAiModel model, TaskSettings taskSettings) {
super(model, taskSettings);

rateLimitServiceSettings = model.rateLimitServiceSettings();
apiKey = model.apiKey();
}

protected ContextualAiModel(ContextualAiModel model, ServiceSettings serviceSettings) {
super(model, serviceSettings);

rateLimitServiceSettings = model.rateLimitServiceSettings();
apiKey = model.apiKey();
}

public SecureString apiKey() {
return apiKey;
}

public ContextualAiRateLimitServiceSettings rateLimitServiceSettings() {
return rateLimitServiceSettings;
}

public abstract ExecutableAction accept(ContextualAiActionVisitor creator, Map<String, Object> taskSettings);

public RateLimitSettings rateLimitSettings() {
return rateLimitServiceSettings.rateLimitSettings();
}

public int rateLimitGroupingHash() {
return apiKey().hashCode();
}

public URI baseUri() {
return rateLimitServiceSettings.uri();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.contextualai;

import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.net.URI;

public interface ContextualAiRateLimitServiceSettings {
RateLimitSettings rateLimitSettings();

URI uri();

String modelId();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.contextualai;

import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.contextualai.response.ContextualAiErrorResponseEntity;

/**
* Response handler for ContextualAI API calls.
*/
public class ContextualAiResponseHandler extends BaseResponseHandler {

public ContextualAiResponseHandler(String requestType, ResponseParser parseFunction, boolean supportsStreaming) {
super(requestType, parseFunction, ContextualAiErrorResponseEntity::fromResponse, supportsStreaming);
}

@Override
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
if (result.isSuccessfulResponse()) {
return;
}

// handle error codes
int statusCode = result.response().getStatusLine().getStatusCode();
if (statusCode == 500) {
throw new RetryException(true, buildError(SERVER_ERROR, request, result));
} else if (statusCode > 500) {
throw new RetryException(false, buildError(SERVER_ERROR, request, result));
} else if (statusCode == 429) {
throw new RetryException(true, buildError(RATE_LIMIT, request, result));
} else if (statusCode == 401) {
throw new RetryException(false, buildError(AUTHENTICATION, request, result));
} else if (statusCode >= 300 && statusCode < 400) {
throw new RetryException(false, buildError(REDIRECTION, request, result));
} else {
throw new RetryException(false, buildError(UNSUCCESSFUL, request, result));
}
}
}
Loading