Skip to content

Commit 5d0c5e0

Browse files
Add Ibm Granite Completion and Chat Completion support (#129146)
* Add Ibm Granite Completion and Chat Completion support * Apply suggestions * remove ibm watsonx transport version constant * update transport version
1 parent 82b6e45 commit 5d0c5e0

File tree

36 files changed

+1395
-189
lines changed

36 files changed

+1395
-189
lines changed

docs/changelog/129146.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 129146
2+
summary: "[ML] Add IBM watsonx Completion and Chat Completion support to the Inference Plugin"
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ static TransportVersion def(int id) {
328328
public static final TransportVersion MAPPINGS_IN_DATA_STREAMS = def(9_112_0_00);
329329
public static final TransportVersion PROJECT_STATE_REGISTRY_RECORDS_DELETIONS = def(9_113_0_00);
330330
public static final TransportVersion ESQL_SERIALIZE_TIMESERIES_FIELD_TYPE = def(9_114_0_00);
331-
331+
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED = def(9_115_0_00);
332332
/*
333333
* STOP! READ THIS FIRST! No, really,
334334
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
151151
"completion_test_service",
152152
"hugging_face",
153153
"amazon_sagemaker",
154-
"mistral"
154+
"mistral",
155+
"watsonxai"
155156
).toArray()
156157
)
157158
);
@@ -169,7 +170,8 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
169170
"hugging_face",
170171
"amazon_sagemaker",
171172
"googlevertexai",
172-
"mistral"
173+
"mistral",
174+
"watsonxai"
173175
).toArray()
174176
)
175177
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
9696
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings;
9797
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
98+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionServiceSettings;
9899
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
99100
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
100101
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
@@ -469,6 +470,13 @@ private static void addIbmWatsonxNamedWritables(List<NamedWriteableRegistry.Entr
469470
namedWriteables.add(
470471
new NamedWriteableRegistry.Entry(TaskSettings.class, IbmWatsonxRerankTaskSettings.NAME, IbmWatsonxRerankTaskSettings::new)
471472
);
473+
namedWriteables.add(
474+
new NamedWriteableRegistry.Entry(
475+
ServiceSettings.class,
476+
IbmWatsonxChatCompletionServiceSettings.NAME,
477+
IbmWatsonxChatCompletionServiceSettings::new
478+
)
479+
);
472480
}
473481

474482
private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ public DefaultSecretSettings getSecretSettings() {
8080

8181
/**
8282
* Accepts a visitor to create an executable action. The returned action will not return documents in the response.
83-
* @param visitor _
84-
* @param taskSettings _
83+
* @param visitor Interface for creating {@link ExecutableAction} instances for Cohere models.
84+
* @param taskSettings Settings in the request to override the model's defaults
8585
* @return the rerank action
8686
*/
8787
@Override
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.ibmwatsonx;
9+
10+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
11+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
12+
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
13+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
14+
import org.elasticsearch.xpack.inference.external.request.Request;
15+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity;
16+
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
17+
18+
import java.util.Locale;
19+
20+
/**
21+
* Handles streaming chat completion responses and error parsing for Watsonx inference endpoints.
22+
* Adapts the OpenAI handler to support Watsonx's error schema.
23+
*/
24+
public class IbmWatsonUnifiedChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
25+
26+
private static final String WATSONX_ERROR = "watsonx_error";
27+
28+
public IbmWatsonUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
29+
super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
30+
}
31+
32+
@Override
33+
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
34+
assert request.isStreaming() : "Only streaming requests support this format";
35+
var responseStatusCode = result.response().getStatusLine().getStatusCode();
36+
if (request.isStreaming()) {
37+
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
38+
var restStatus = toRestStatus(responseStatusCode);
39+
return errorResponse instanceof IbmWatsonxErrorResponseEntity
40+
? new UnifiedChatCompletionException(restStatus, errorMessage, WATSONX_ERROR, restStatus.name().toLowerCase(Locale.ROOT))
41+
: new UnifiedChatCompletionException(
42+
restStatus,
43+
errorMessage,
44+
createErrorType(errorResponse),
45+
restStatus.name().toLowerCase(Locale.ROOT)
46+
);
47+
} else {
48+
return super.buildError(message, request, result, errorResponse);
49+
}
50+
}
51+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.ibmwatsonx;
9+
10+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
11+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity;
12+
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
13+
14+
public class IbmWatsonxCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {
15+
16+
/**
17+
* Constructs a IbmWatsonxCompletionResponseHandler with the specified request type and response parser.
18+
*
19+
* @param requestType The type of request being handled (e.g., "IBM watsonx completions").
20+
* @param parseFunction The function to parse the response.
21+
*/
22+
public IbmWatsonxCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
23+
super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
24+
}
25+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public class IbmWatsonxEmbeddingsRequestManager extends IbmWatsonxRequestManager
3535
private static final ResponseHandler HANDLER = createEmbeddingsHandler();
3636

3737
private static ResponseHandler createEmbeddingsHandler() {
38-
return new IbmWatsonxResponseHandler("ibm watsonx embeddings", IbmWatsonxEmbeddingsResponseEntity::fromResponse);
38+
return new IbmWatsonxResponseHandler("IBM watsonx embeddings", IbmWatsonxEmbeddingsResponseEntity::fromResponse);
3939
}
4040

4141
private final IbmWatsonxEmbeddingsModel model;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxModel.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,19 @@
77

88
package org.elasticsearch.xpack.inference.services.ibmwatsonx;
99

10-
import org.elasticsearch.inference.Model;
1110
import org.elasticsearch.inference.ModelConfigurations;
1211
import org.elasticsearch.inference.ModelSecrets;
1312
import org.elasticsearch.inference.ServiceSettings;
1413
import org.elasticsearch.inference.TaskSettings;
1514
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
15+
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
1616
import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionVisitor;
17+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
1718

1819
import java.util.Map;
1920
import java.util.Objects;
2021

21-
public abstract class IbmWatsonxModel extends Model {
22+
public abstract class IbmWatsonxModel extends RateLimitGroupingModel {
2223

2324
private final IbmWatsonxRateLimitServiceSettings rateLimitServiceSettings;
2425

@@ -49,4 +50,14 @@ public IbmWatsonxModel(IbmWatsonxModel model, TaskSettings taskSettings) {
4950
public IbmWatsonxRateLimitServiceSettings rateLimitServiceSettings() {
5051
return rateLimitServiceSettings;
5152
}
53+
54+
@Override
55+
public int rateLimitGroupingHash() {
56+
return Objects.hash(this.rateLimitServiceSettings);
57+
}
58+
59+
@Override
60+
public RateLimitSettings rateLimitSettings() {
61+
return this.rateLimitServiceSettings().rateLimitSettings();
62+
}
5263
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxRerankRequestManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public class IbmWatsonxRerankRequestManager extends IbmWatsonxRequestManager {
3131

3232
private static ResponseHandler createIbmWatsonxResponseHandler() {
3333
return new IbmWatsonxResponseHandler(
34-
"ibm watsonx rerank",
34+
"IBM watsonx rerank",
3535
(request, response) -> IbmWatsonxRankedResponseEntity.fromResponse(response)
3636
);
3737
}

0 commit comments

Comments
 (0)