Skip to content

Commit 734a6b7

Browse files
[ML] Add ContextualAI inference service (#134933)
1 parent aaeca25 commit 734a6b7

File tree

21 files changed

+1468
-1
lines changed

21 files changed

+1468
-1
lines changed

docs/changelog/134933.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 134933
2+
summary: Add ContextualAI Rerank Service Implementation to the Inference API
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
9175000
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
semantic_search_ccs_support,9174000
1+
contextual_ai_service,9175000

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
5959
"azureaistudio",
6060
"azureopenai",
6161
"cohere",
62+
"contextualai",
6263
"deepseek",
6364
"elastic",
6465
"elasticsearch",
@@ -134,6 +135,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
134135
"alibabacloud-ai-search",
135136
"azureaistudio",
136137
"cohere",
138+
"contextualai",
137139
"elasticsearch",
138140
"googlevertexai",
139141
"jinaai",

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@
127127
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
128128
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
129129
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
130+
import org.elasticsearch.xpack.inference.services.contextualai.ContextualAiService;
130131
import org.elasticsearch.xpack.inference.services.custom.CustomService;
131132
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
132133
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
@@ -410,6 +411,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
410411
context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get(), context),
411412
context -> new OpenAiService(httpFactory.get(), serviceComponents.get(), context),
412413
context -> new CohereService(httpFactory.get(), serviceComponents.get(), context),
414+
context -> new ContextualAiService(httpFactory.get(), serviceComponents.get(), context),
413415
context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get(), context),
414416
context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get(), context),
415417
context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get(), context),
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.contextualai;
9+
10+
import org.elasticsearch.common.settings.SecureString;
11+
import org.elasticsearch.core.Nullable;
12+
import org.elasticsearch.inference.ModelConfigurations;
13+
import org.elasticsearch.inference.ModelSecrets;
14+
import org.elasticsearch.inference.ServiceSettings;
15+
import org.elasticsearch.inference.TaskSettings;
16+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
17+
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
18+
import org.elasticsearch.xpack.inference.services.ServiceUtils;
19+
import org.elasticsearch.xpack.inference.services.contextualai.action.ContextualAiActionVisitor;
20+
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
21+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
22+
23+
import java.net.URI;
24+
import java.util.Map;
25+
import java.util.Objects;
26+
27+
public abstract class ContextualAiModel extends RateLimitGroupingModel {
28+
29+
private final SecureString apiKey;
30+
private final ContextualAiRateLimitServiceSettings rateLimitServiceSettings;
31+
32+
public ContextualAiModel(
33+
ModelConfigurations configurations,
34+
ModelSecrets secrets,
35+
@Nullable ApiKeySecrets apiKeySecrets,
36+
ContextualAiRateLimitServiceSettings rateLimitServiceSettings
37+
) {
38+
super(configurations, secrets);
39+
40+
this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings);
41+
apiKey = ServiceUtils.apiKey(apiKeySecrets);
42+
}
43+
44+
protected ContextualAiModel(ContextualAiModel model, TaskSettings taskSettings) {
45+
super(model, taskSettings);
46+
47+
rateLimitServiceSettings = model.rateLimitServiceSettings();
48+
apiKey = model.apiKey();
49+
}
50+
51+
protected ContextualAiModel(ContextualAiModel model, ServiceSettings serviceSettings) {
52+
super(model, serviceSettings);
53+
54+
rateLimitServiceSettings = model.rateLimitServiceSettings();
55+
apiKey = model.apiKey();
56+
}
57+
58+
public SecureString apiKey() {
59+
return apiKey;
60+
}
61+
62+
public ContextualAiRateLimitServiceSettings rateLimitServiceSettings() {
63+
return rateLimitServiceSettings;
64+
}
65+
66+
public abstract ExecutableAction accept(ContextualAiActionVisitor creator, Map<String, Object> taskSettings);
67+
68+
public RateLimitSettings rateLimitSettings() {
69+
return rateLimitServiceSettings.rateLimitSettings();
70+
}
71+
72+
public int rateLimitGroupingHash() {
73+
return apiKey().hashCode();
74+
}
75+
76+
public URI baseUri() {
77+
return rateLimitServiceSettings.uri();
78+
}
79+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.contextualai;
9+
10+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
11+
12+
import java.net.URI;
13+
14+
public interface ContextualAiRateLimitServiceSettings {
15+
RateLimitSettings rateLimitSettings();
16+
17+
URI uri();
18+
19+
String modelId();
20+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.contextualai;
9+
10+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
11+
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
12+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
13+
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
14+
import org.elasticsearch.xpack.inference.external.request.Request;
15+
import org.elasticsearch.xpack.inference.services.contextualai.response.ContextualAiErrorResponseEntity;
16+
17+
/**
18+
* Response handler for ContextualAI API calls.
19+
*/
20+
public class ContextualAiResponseHandler extends BaseResponseHandler {
21+
22+
public ContextualAiResponseHandler(String requestType, ResponseParser parseFunction, boolean supportsStreaming) {
23+
super(requestType, parseFunction, ContextualAiErrorResponseEntity::fromResponse, supportsStreaming);
24+
}
25+
26+
@Override
27+
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
28+
if (result.isSuccessfulResponse()) {
29+
return;
30+
}
31+
32+
// handle error codes
33+
int statusCode = result.response().getStatusLine().getStatusCode();
34+
if (statusCode == 500) {
35+
throw new RetryException(true, buildError(SERVER_ERROR, request, result));
36+
} else if (statusCode > 500) {
37+
throw new RetryException(false, buildError(SERVER_ERROR, request, result));
38+
} else if (statusCode == 429) {
39+
throw new RetryException(true, buildError(RATE_LIMIT, request, result));
40+
} else if (statusCode == 401) {
41+
throw new RetryException(false, buildError(AUTHENTICATION, request, result));
42+
} else if (statusCode >= 300 && statusCode < 400) {
43+
throw new RetryException(false, buildError(REDIRECTION, request, result));
44+
} else {
45+
throw new RetryException(false, buildError(UNSUCCESSFUL, request, result));
46+
}
47+
}
48+
}

0 commit comments

Comments
 (0)