Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
78ab1da
Add Ibm Granite Completion and Chat Completion support
Evgenii-Kazannik May 28, 2025
f92f348
Apply suggestions
Evgenii-Kazannik Jun 10, 2025
510e3c5
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jun 13, 2025
d6d19be
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jun 17, 2025
a6eaec6
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jun 23, 2025
9faf6f6
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jun 30, 2025
b23bdfb
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
136416d
remove ibm watsonx transport version constant
Evgenii-Kazannik Jul 2, 2025
ff6ccf5
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
80537a4
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
b44bab6
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
b1a76c3
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
1bf81ed
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
e70752f
Merge remote-tracking branch 'origin/Add-IBM-Granite-support-for-comp…
Evgenii-Kazannik Jul 2, 2025
b219e72
update transport version
Evgenii-Kazannik Jul 2, 2025
c950380
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
bf882a0
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
f9b086f
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
8e08b9e
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
08ab2f6
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
4ed865c
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_REGEX_MATCH_WITH_CASE_INSENSITIVITY_8_19 = def(8_841_0_44);
public static final TransportVersion ESQL_QUERY_PLANNING_DURATION_8_19 = def(8_841_0_45);
public static final TransportVersion SEARCH_SOURCE_EXCLUDE_VECTORS_PARAM_8_19 = def(8_841_0_46);
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED_8_19 = def(8_841_0_47);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -288,6 +289,8 @@ static TransportVersion def(int id) {
public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED = def(9_090_0_00);
public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST = def(9_091_0_00);
public static final TransportVersion SEARCH_SOURCE_EXCLUDE_VECTORS_PARAM = def(9_092_0_00);
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED = def(9_093_0_00);

/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {

public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(14));
assertThat(services.size(), equalTo(15));

var providers = providers(services);

Expand All @@ -157,15 +157,16 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"completion_test_service",
"hugging_face",
"amazon_sagemaker",
"mistral"
"mistral",
"watsonxai"
).toArray()
)
);
}

public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(8));
assertThat(services.size(), equalTo(9));

var providers = providers(services);

Expand All @@ -180,7 +181,8 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
"hugging_face",
"amazon_sagemaker",
"googlevertexai",
"mistral"
"mistral",
"watsonxai"
).toArray()
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
Expand Down Expand Up @@ -472,6 +473,13 @@ private static void addIbmWatsonxNamedWritables(List<NamedWriteableRegistry.Entr
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, IbmWatsonxRerankTaskSettings.NAME, IbmWatsonxRerankTaskSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
IbmWatsonxChatCompletionServiceSettings.NAME,
IbmWatsonxChatCompletionServiceSettings::new
)
);
}

private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.ibmwatsonx;

import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity;
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;

import java.util.Locale;

/**
* Handles streaming chat completion responses and error parsing for Watsonx inference endpoints.
* Adapts the OpenAI handler to support Watsonx's error schema.
*/
public class IbmWatsonUnifiedChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {

private static final String WATSONX_ERROR = "watsonx_error";

public IbmWatsonUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
}

@Override
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
assert request.isStreaming() : "Only streaming requests support this format";
var responseStatusCode = result.response().getStatusLine().getStatusCode();
if (request.isStreaming()) {
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
var restStatus = toRestStatus(responseStatusCode);
return errorResponse instanceof IbmWatsonxErrorResponseEntity
? new UnifiedChatCompletionException(restStatus, errorMessage, WATSONX_ERROR, restStatus.name().toLowerCase(Locale.ROOT))
: new UnifiedChatCompletionException(
restStatus,
errorMessage,
createErrorType(errorResponse),
restStatus.name().toLowerCase(Locale.ROOT)
);
} else {
return super.buildError(message, request, result, errorResponse);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.ibmwatsonx;

import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity;
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;

public class IbmWatsonxCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {

/**
* Constructs a IbmWatsonxCompletionResponseHandler with the specified request type and response parser.
*
* @param requestType The type of request being handled (e.g., "Ibm WatsonX completions").
* @param parseFunction The function to parse the response.
*/
public IbmWatsonxCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,25 @@

package org.elasticsearch.xpack.inference.services.ibmwatsonx;

import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.net.URI;
import java.util.Map;
import java.util.Objects;

public abstract class IbmWatsonxModel extends Model {
public abstract class IbmWatsonxModel extends RateLimitGroupingModel {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you clarify why this needs to be a RateLimitGroupingModel?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This type needs to be used in GenericRequestManager
which I believe is also going to handle the requests for other tasks in the future


private final IbmWatsonxRateLimitServiceSettings rateLimitServiceSettings;

protected URI uri;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is URI only going to be used for completion/chat completion use cases? If yes, can it be in the completion model implementation instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed uri from the model.
Uri is to be set during an inference endpoint creation as part of the service settings.
It's not part of the IBM watsonx.ai API for completions.
Thanks


public IbmWatsonxModel(
ModelConfigurations configurations,
ModelSecrets secrets,
Expand Down Expand Up @@ -49,4 +53,14 @@ public IbmWatsonxModel(IbmWatsonxModel model, TaskSettings taskSettings) {
public IbmWatsonxRateLimitServiceSettings rateLimitServiceSettings() {
return rateLimitServiceSettings;
}

@Override
public int rateLimitGroupingHash() {
return Objects.hash(uri);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this not need to include the rateLimitServiceSettings?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Including both back then would have made hashing more accurate. Thank you.
Eventually I removed uri from the model as it's not needed there.
Therefore, only rateLimitServiceSettings is hashed now.
I made an update

}

@Override
public RateLimitSettings rateLimitSettings() {
return this.rateLimitServiceSettings().rateLimitSettings();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
Expand All @@ -40,14 +43,18 @@
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionCreator;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;

import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
Expand All @@ -56,7 +63,6 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL;
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.API_VERSION;
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.EMBEDDING_MAX_BATCH_SIZE;
Expand All @@ -67,7 +73,15 @@ public class IbmWatsonxService extends SenderService {
public static final String NAME = "watsonxai";

private static final String SERVICE_NAME = "IBM Watsonx";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING);
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
TaskType.TEXT_EMBEDDING,
TaskType.COMPLETION,
TaskType.CHAT_COMPLETION
);
private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new IbmWatsonUnifiedChatCompletionResponseHandler(
"ibm watsonx chat completions",
OpenAiChatCompletionResponseEntity::fromResponse
);

public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
Expand Down Expand Up @@ -148,6 +162,14 @@ private static IbmWatsonxModel createModel(
secretSettings,
context
);
case CHAT_COMPLETION, COMPLETION -> new IbmWatsonxChatCompletionModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
secretSettings,
context
);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
}
Expand Down Expand Up @@ -236,6 +258,11 @@ public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.V_8_16_0;
}

@Override
public Set<TaskType> supportedStreamingTasks() {
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
}

@Override
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
if (model instanceof IbmWatsonxEmbeddingsModel embeddingsModel) {
Expand Down Expand Up @@ -291,7 +318,24 @@ protected void doUnifiedCompletionInfer(
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
throwUnsupportedUnifiedCompletionOperation(NAME);
if (model instanceof IbmWatsonxChatCompletionModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}

IbmWatsonxChatCompletionModel ibmWatsonxChatCompletionModel = (IbmWatsonxChatCompletionModel) model;
var overriddenModel = IbmWatsonxChatCompletionModel.of(ibmWatsonxChatCompletionModel, inputs.getRequest());
var manager = new GenericRequestManager<>(
getServiceComponents().threadPool(),
overriddenModel,
UNIFIED_CHAT_COMPLETION_HANDLER,
unifiedChatInput -> new IbmWatsonxChatCompletionRequest(unifiedChatInput, overriddenModel),
UnifiedChatInput.class
);
var errorMessage = IbmWatsonxActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId());
var action = new SenderExecutableAction(getSender(), manager, errorMessage);

action.execute(inputs, timeout, listener);
}

@Override
Expand Down
Loading