Skip to content

Commit a7205ae

Browse files
Add AI21 support to Inference Plugin (#131238)
* Add Ai21ActionVisitor interface * Implement AI21 chat completion service and related components * Make MODEL_FIELD and OBJECT_FIELD optional in OpenAiUnifiedStreamingProcessor * Add changelog * Add unit tests for Ai21 chat completion service and settings * Add unit tests for Ai21ChatCompletionResponseHandler * Make API_COMPLETIONS_PATH public and add unit tests for Ai21ChatCompletionRequestEntity and Ai21ChatCompletionRequest * Refactor error handling and add unit tests for Ai21 service * Refactor Ai21 service constructors and error handling for improved clarity and functionality * Add Ai21ApiConstants class for API endpoint definitions * Refactor Ai21ChatCompletionModel to use URIBuilder for endpoint URL construction * Make fromString method public for improved accessibility * Use Strings utility for null and blank check in ErrorResponse.fromString method * Fix typos --------- Co-authored-by: Jonathan Buttner <[email protected]>
1 parent 419eeb9 commit a7205ae

34 files changed

+2361
-188
lines changed

docs/changelog/131238.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 113757
2+
summary: Added AI21 Completion and Chat Completion support to the Inference Plugin
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ static TransportVersion def(int id) {
356356
public static final TransportVersion PIPELINE_TRACKING_INFO = def(9_131_0_00);
357357
public static final TransportVersion COMPONENT_TEMPLATE_TRACKING_INFO = def(9_132_0_00);
358358
public static final TransportVersion TO_CHILD_BLOCK_JOIN_QUERY = def(9_133_0_00);
359+
public static final TransportVersion ML_INFERENCE_AI21_COMPLETION_ADDED = def(9_134_0_00);
359360

360361
/*
361362
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings;
3232
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
3333
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
34+
import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionServiceSettings;
3435
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings;
3536
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettings;
3637
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettings;
@@ -178,6 +179,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
178179
addVoyageAINamedWriteables(namedWriteables);
179180
addCustomNamedWriteables(namedWriteables);
180181
addLlamaNamedWriteables(namedWriteables);
182+
addAi21NamedWriteables(namedWriteables);
181183

182184
addUnifiedNamedWriteables(namedWriteables);
183185

@@ -298,6 +300,17 @@ private static void addLlamaNamedWriteables(List<NamedWriteableRegistry.Entry> n
298300
// no task settings for Llama
299301
}
300302

303+
private static void addAi21NamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
304+
namedWriteables.add(
305+
new NamedWriteableRegistry.Entry(
306+
ServiceSettings.class,
307+
Ai21ChatCompletionServiceSettings.NAME,
308+
Ai21ChatCompletionServiceSettings::new
309+
)
310+
);
311+
// no task settings for AI21
312+
}
313+
301314
private static void addAzureAiStudioNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
302315
namedWriteables.add(
303316
new NamedWriteableRegistry.Entry(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@
113113
import org.elasticsearch.xpack.inference.rest.RestStreamInferenceAction;
114114
import org.elasticsearch.xpack.inference.rest.RestUpdateInferenceModelAction;
115115
import org.elasticsearch.xpack.inference.services.ServiceComponents;
116+
import org.elasticsearch.xpack.inference.services.ai21.Ai21Service;
116117
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchService;
117118
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockService;
118119
import org.elasticsearch.xpack.inference.services.amazonbedrock.client.AmazonBedrockRequestSender;
@@ -413,6 +414,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
413414
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get(), context),
414415
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get(), context),
415416
context -> new LlamaService(httpFactory.get(), serviceComponents.get(), context),
417+
context -> new Ai21Service(httpFactory.get(), serviceComponents.get(), context),
416418
ElasticsearchInternalService::new,
417419
context -> new CustomService(httpFactory.get(), serviceComponents.get(), context)
418420
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ErrorResponse.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77

88
package org.elasticsearch.xpack.inference.external.http.retry;
99

10+
import org.elasticsearch.common.Strings;
11+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
12+
13+
import java.nio.charset.StandardCharsets;
1014
import java.util.Objects;
1115

1216
public class ErrorResponse {
@@ -46,4 +50,39 @@ public boolean equals(Object o) {
4650
public int hashCode() {
4751
return Objects.hash(errorMessage, errorStructureFound);
4852
}
53+
54+
/**
55+
* Creates an ErrorResponse from the given HttpResult.
56+
* Attempts to read the body as a UTF-8 string and constructs an ErrorResponse.
57+
* If reading fails, returns a generic UNDEFINED_ERROR.
58+
*
59+
* @param response the HttpResult containing the error response
60+
* @return an ErrorResponse instance
61+
*/
62+
public static ErrorResponse fromResponse(HttpResult response) {
63+
try {
64+
String errorMessage = new String(response.body(), StandardCharsets.UTF_8);
65+
return new ErrorResponse(errorMessage);
66+
} catch (Exception e) {
67+
// swallow the error
68+
}
69+
70+
return ErrorResponse.UNDEFINED_ERROR;
71+
}
72+
73+
/**
74+
* Parses a string response into an ErrorResponse.
75+
* If the string is not blank, creates a new ErrorResponse with the string as the error message.
76+
* If the string is blank, returns UNDEFINED_ERROR.
77+
*
78+
* @param response the error response as a string
79+
* @return an ErrorResponse instance
80+
*/
81+
public static ErrorResponse fromString(String response) {
82+
if (Strings.isNullOrBlank(response) == false) {
83+
return new ErrorResponse(response);
84+
} else {
85+
return ErrorResponse.UNDEFINED_ERROR;
86+
}
87+
}
4988
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.ai21;
9+
10+
import org.elasticsearch.inference.ModelConfigurations;
11+
import org.elasticsearch.inference.ModelSecrets;
12+
import org.elasticsearch.inference.ServiceSettings;
13+
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
14+
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
15+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
16+
17+
import java.net.URI;
18+
import java.net.URISyntaxException;
19+
import java.util.Objects;
20+
21+
/**
22+
* Represents a AI21 model that can be used for inference tasks.
23+
* This class extends RateLimitGroupingModel to handle rate limiting based on model and API key.
24+
*/
25+
public abstract class Ai21Model extends RateLimitGroupingModel {
26+
protected URI uri;
27+
protected RateLimitSettings rateLimitSettings;
28+
29+
protected Ai21Model(ModelConfigurations configurations, ModelSecrets secrets) {
30+
super(configurations, secrets);
31+
}
32+
33+
protected Ai21Model(RateLimitGroupingModel model, ServiceSettings serviceSettings) {
34+
super(model, serviceSettings);
35+
}
36+
37+
public URI uri() {
38+
return this.uri;
39+
}
40+
41+
@Override
42+
public RateLimitSettings rateLimitSettings() {
43+
return this.rateLimitSettings;
44+
}
45+
46+
@Override
47+
public int rateLimitGroupingHash() {
48+
return Objects.hash(getServiceSettings().modelId(), getSecretSettings().apiKey());
49+
}
50+
51+
// Needed for testing only
52+
public void setURI(String newUri) {
53+
try {
54+
this.uri = new URI(newUri);
55+
} catch (URISyntaxException e) {
56+
// swallow any error
57+
}
58+
}
59+
60+
@Override
61+
public DefaultSecretSettings getSecretSettings() {
62+
return (DefaultSecretSettings) super.getSecretSettings();
63+
}
64+
}

0 commit comments

Comments
 (0)