Skip to content

Commit 78ab1da

Browse files
Add Ibm Granite Completion and Chat Completion support
1 parent 6e67fac commit 78ab1da

22 files changed

+1210
-15
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ static TransportVersion def(int id) {
192192
public static final TransportVersion ESQL_REGEX_MATCH_WITH_CASE_INSENSITIVITY_8_19 = def(8_841_0_44);
193193
public static final TransportVersion ESQL_QUERY_PLANNING_DURATION_8_19 = def(8_841_0_45);
194194
public static final TransportVersion SEARCH_SOURCE_EXCLUDE_VECTORS_PARAM_8_19 = def(8_841_0_46);
195+
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED_8_19 = def(8_841_0_47);
195196
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
196197
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
197198
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -288,6 +289,8 @@ static TransportVersion def(int id) {
288289
public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED = def(9_090_0_00);
289290
public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST = def(9_091_0_00);
290291
public static final TransportVersion SEARCH_SOURCE_EXCLUDE_VECTORS_PARAM = def(9_092_0_00);
292+
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED = def(9_093_0_00);
293+
291294
/*
292295
* STOP! READ THIS FIRST! No, really,
293296
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
135135

136136
public void testGetServicesWithCompletionTaskType() throws IOException {
137137
List<Object> services = getServices(TaskType.COMPLETION);
138-
assertThat(services.size(), equalTo(14));
138+
assertThat(services.size(), equalTo(15));
139139

140140
var providers = providers(services);
141141

@@ -157,15 +157,16 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
157157
"completion_test_service",
158158
"hugging_face",
159159
"amazon_sagemaker",
160-
"mistral"
160+
"mistral",
161+
"watsonxai"
161162
).toArray()
162163
)
163164
);
164165
}
165166

166167
public void testGetServicesWithChatCompletionTaskType() throws IOException {
167168
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
168-
assertThat(services.size(), equalTo(8));
169+
assertThat(services.size(), equalTo(9));
169170

170171
var providers = providers(services);
171172

@@ -180,7 +181,8 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
180181
"hugging_face",
181182
"amazon_sagemaker",
182183
"googlevertexai",
183-
"mistral"
184+
"mistral",
185+
"watsonxai"
184186
).toArray()
185187
)
186188
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
9393
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings;
9494
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
95+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionServiceSettings;
9596
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
9697
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
9798
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
@@ -472,6 +473,13 @@ private static void addIbmWatsonxNamedWritables(List<NamedWriteableRegistry.Entr
472473
namedWriteables.add(
473474
new NamedWriteableRegistry.Entry(TaskSettings.class, IbmWatsonxRerankTaskSettings.NAME, IbmWatsonxRerankTaskSettings::new)
474475
);
476+
namedWriteables.add(
477+
new NamedWriteableRegistry.Entry(
478+
ServiceSettings.class,
479+
IbmWatsonxChatCompletionServiceSettings.NAME,
480+
IbmWatsonxChatCompletionServiceSettings::new
481+
)
482+
);
475483
}
476484

477485
private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.ibmwatsonx;
9+
10+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
11+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
12+
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
13+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
14+
import org.elasticsearch.xpack.inference.external.request.Request;
15+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity;
16+
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
17+
18+
import java.util.Locale;
19+
20+
/**
21+
* Handles streaming chat completion responses and error parsing for Watsonx inference endpoints.
22+
* Adapts the OpenAI handler to support Watsonx's error schema.
23+
*/
24+
public class IbmWatsonUnifiedChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
25+
26+
private static final String WATSONX_ERROR = "watsonx_error";
27+
28+
public IbmWatsonUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
29+
super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
30+
}
31+
32+
@Override
33+
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
34+
assert request.isStreaming() : "Only streaming requests support this format";
35+
var responseStatusCode = result.response().getStatusLine().getStatusCode();
36+
if (request.isStreaming()) {
37+
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
38+
var restStatus = toRestStatus(responseStatusCode);
39+
return errorResponse instanceof IbmWatsonxErrorResponseEntity
40+
? new UnifiedChatCompletionException(restStatus, errorMessage, WATSONX_ERROR, restStatus.name().toLowerCase(Locale.ROOT))
41+
: new UnifiedChatCompletionException(
42+
restStatus,
43+
errorMessage,
44+
createErrorType(errorResponse),
45+
restStatus.name().toLowerCase(Locale.ROOT)
46+
);
47+
} else {
48+
return super.buildError(message, request, result, errorResponse);
49+
}
50+
}
51+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.ibmwatsonx;
9+
10+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
11+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity;
12+
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
13+
14+
public class IbmWatsonxCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {
15+
16+
/**
17+
* Constructs a IbmWatsonxCompletionResponseHandler with the specified request type and response parser.
18+
*
19+
* @param requestType The type of request being handled (e.g., "Ibm WatsonX completions").
20+
* @param parseFunction The function to parse the response.
21+
*/
22+
public IbmWatsonxCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
23+
super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
24+
}
25+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxModel.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,25 @@
77

88
package org.elasticsearch.xpack.inference.services.ibmwatsonx;
99

10-
import org.elasticsearch.inference.Model;
1110
import org.elasticsearch.inference.ModelConfigurations;
1211
import org.elasticsearch.inference.ModelSecrets;
1312
import org.elasticsearch.inference.ServiceSettings;
1413
import org.elasticsearch.inference.TaskSettings;
1514
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
15+
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
1616
import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionVisitor;
17+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
1718

19+
import java.net.URI;
1820
import java.util.Map;
1921
import java.util.Objects;
2022

21-
public abstract class IbmWatsonxModel extends Model {
23+
public abstract class IbmWatsonxModel extends RateLimitGroupingModel {
2224

2325
private final IbmWatsonxRateLimitServiceSettings rateLimitServiceSettings;
2426

27+
protected URI uri;
28+
2529
public IbmWatsonxModel(
2630
ModelConfigurations configurations,
2731
ModelSecrets secrets,
@@ -49,4 +53,14 @@ public IbmWatsonxModel(IbmWatsonxModel model, TaskSettings taskSettings) {
4953
public IbmWatsonxRateLimitServiceSettings rateLimitServiceSettings() {
5054
return rateLimitServiceSettings;
5155
}
56+
57+
@Override
58+
public int rateLimitGroupingHash() {
59+
return Objects.hash(uri);
60+
}
61+
62+
@Override
63+
public RateLimitSettings rateLimitSettings() {
64+
return this.rateLimitServiceSettings().rateLimitSettings();
65+
}
5266
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@
3030
import org.elasticsearch.rest.RestStatus;
3131
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
3232
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
33+
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
34+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
3335
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
36+
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
3437
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
3538
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
3639
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
@@ -40,14 +43,18 @@
4043
import org.elasticsearch.xpack.inference.services.ServiceComponents;
4144
import org.elasticsearch.xpack.inference.services.ServiceUtils;
4245
import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionCreator;
46+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
4347
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
4448
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
49+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxChatCompletionRequest;
4550
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
51+
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
4652

4753
import java.util.EnumSet;
4854
import java.util.HashMap;
4955
import java.util.List;
5056
import java.util.Map;
57+
import java.util.Set;
5158

5259
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
5360
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
@@ -56,7 +63,6 @@
5663
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
5764
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
5865
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
59-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
6066
import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL;
6167
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.API_VERSION;
6268
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.EMBEDDING_MAX_BATCH_SIZE;
@@ -67,7 +73,15 @@ public class IbmWatsonxService extends SenderService {
6773
public static final String NAME = "watsonxai";
6874

6975
private static final String SERVICE_NAME = "IBM Watsonx";
70-
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING);
76+
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
77+
TaskType.TEXT_EMBEDDING,
78+
TaskType.COMPLETION,
79+
TaskType.CHAT_COMPLETION
80+
);
81+
private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new IbmWatsonUnifiedChatCompletionResponseHandler(
82+
"ibm watsonx chat completions",
83+
OpenAiChatCompletionResponseEntity::fromResponse
84+
);
7185

7286
public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
7387
super(factory, serviceComponents);
@@ -148,6 +162,14 @@ private static IbmWatsonxModel createModel(
148162
secretSettings,
149163
context
150164
);
165+
case CHAT_COMPLETION, COMPLETION -> new IbmWatsonxChatCompletionModel(
166+
inferenceEntityId,
167+
taskType,
168+
NAME,
169+
serviceSettings,
170+
secretSettings,
171+
context
172+
);
151173
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
152174
};
153175
}
@@ -236,6 +258,11 @@ public TransportVersion getMinimalSupportedVersion() {
236258
return TransportVersions.V_8_16_0;
237259
}
238260

261+
@Override
262+
public Set<TaskType> supportedStreamingTasks() {
263+
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
264+
}
265+
239266
@Override
240267
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
241268
if (model instanceof IbmWatsonxEmbeddingsModel embeddingsModel) {
@@ -291,7 +318,24 @@ protected void doUnifiedCompletionInfer(
291318
TimeValue timeout,
292319
ActionListener<InferenceServiceResults> listener
293320
) {
294-
throwUnsupportedUnifiedCompletionOperation(NAME);
321+
if (model instanceof IbmWatsonxChatCompletionModel == false) {
322+
listener.onFailure(createInvalidModelException(model));
323+
return;
324+
}
325+
326+
IbmWatsonxChatCompletionModel ibmWatsonxChatCompletionModel = (IbmWatsonxChatCompletionModel) model;
327+
var overriddenModel = IbmWatsonxChatCompletionModel.of(ibmWatsonxChatCompletionModel, inputs.getRequest());
328+
var manager = new GenericRequestManager<>(
329+
getServiceComponents().threadPool(),
330+
overriddenModel,
331+
UNIFIED_CHAT_COMPLETION_HANDLER,
332+
unifiedChatInput -> new IbmWatsonxChatCompletionRequest(unifiedChatInput, overriddenModel),
333+
UnifiedChatInput.class
334+
);
335+
var errorMessage = IbmWatsonxActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId());
336+
var action = new SenderExecutableAction(getSender(), manager, errorMessage);
337+
338+
action.execute(inputs, timeout, listener);
295339
}
296340

297341
@Override

0 commit comments

Comments
 (0)