Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/117176.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 117176
summary: Integrate IBM watsonx to Inference API for re-ranking task
area: Experiences
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ static TransportVersion def(int id) {
public static final TransportVersion LINEAR_RETRIEVER_SUPPORT = def(8_837_00_0);
public static final TransportVersion TIMEOUT_GET_PARAM_FOR_RESOLVE_CLUSTER = def(8_838_00_0);
public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING = def(8_839_00_0);
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_RERANK_ADDED = def(8_840_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
Expand Down Expand Up @@ -364,6 +366,17 @@ private static void addIbmWatsonxNamedWritables(List<NamedWriteableRegistry.Entr
IbmWatsonxEmbeddingsServiceSettings::new
)
);

namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
IbmWatsonxRerankServiceSettings.NAME,
IbmWatsonxRerankServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, IbmWatsonxRerankTaskSettings.NAME, IbmWatsonxRerankTaskSettings::new)
);
}

private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,18 @@
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.IbmWatsonxEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.IbmWatsonxRerankRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;

import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;

public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor {

private final Sender sender;
private final ServiceComponents serviceComponents;

Expand All @@ -41,6 +42,17 @@ public ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Obje
);
}

@Override
public ExecutableAction create(IbmWatsonxRerankModel model, Map<String, Object> taskSettings) {
var overriddenModel = IbmWatsonxRerankModel.of(model, taskSettings);
var requestCreator = IbmWatsonxRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
overriddenModel.getServiceSettings().uri(),
"Ibm Watsonx rerank"
);
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}

protected IbmWatsonxEmbeddingsRequestManager getEmbeddingsRequestManager(
IbmWatsonxEmbeddingsModel model,
Truncator truncator,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;

import java.util.Map;

public interface IbmWatsonxActionVisitor {
ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings);

ExecutableAction create(IbmWatsonxRerankModel model, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.ibmwatsonx.IbmWatsonxResponseHandler;
import org.elasticsearch.xpack.inference.external.request.ibmwatsonx.IbmWatsonxRerankRequest;
import org.elasticsearch.xpack.inference.external.response.ibmwatsonx.IbmWatsonxRankedResponseEntity;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

public class IbmWatsonxRerankRequestManager extends IbmWatsonxRequestManager {
private static final Logger logger = LogManager.getLogger(IbmWatsonxRerankRequestManager.class);
private static final ResponseHandler HANDLER = createIbmWatsonxResponseHandler();

private static ResponseHandler createIbmWatsonxResponseHandler() {
return new IbmWatsonxResponseHandler(
"ibm watsonx rerank",
(request, response) -> IbmWatsonxRankedResponseEntity.fromResponse(response)
);
}

public static IbmWatsonxRerankRequestManager of(IbmWatsonxRerankModel model, ThreadPool threadPool) {
return new IbmWatsonxRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
}

private final IbmWatsonxRerankModel model;

public IbmWatsonxRerankRequestManager(IbmWatsonxRerankModel model, ThreadPool threadPool) {
super(threadPool, model);
this.model = model;
}

@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);

execute(
new ExecutableInferenceRequest(
requestSender,
logger,
getRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model),
HANDLER,
hasRequestCompletedFunction,
listener
)
);
}

protected IbmWatsonxRerankRequest getRerankRequest(String query, List<String> chunks, IbmWatsonxRerankModel model) {
return new IbmWatsonxRerankRequest(query, chunks, model);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import static org.elasticsearch.core.Strings.format;

public class IbmWatsonxResponseHandler extends BaseResponseHandler {

public IbmWatsonxResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.ibmwatsonx;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;

import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;

public class IbmWatsonxRerankRequest implements IbmWatsonxRequest {

private final String query;
private final List<String> input;
private final IbmWatsonxRerankTaskSettings taskSettings;
private final IbmWatsonxRerankModel model;

public IbmWatsonxRerankRequest(String query, List<String> input, IbmWatsonxRerankModel model) {
Objects.requireNonNull(model);

this.input = Objects.requireNonNull(input);
this.query = Objects.requireNonNull(query);
taskSettings = model.getTaskSettings();
this.model = model;
}

@Override
public HttpRequest createHttpRequest() {
URI uri;

try {
uri = new URI(model.uri().toString());
} catch (URISyntaxException ex) {
throw new IllegalArgumentException("cannot parse URI patter");
}

HttpPost httpPost = new HttpPost(uri);

ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(
new IbmWatsonxRerankRequestEntity(
query,
input,
taskSettings,
model.getServiceSettings().modelId(),
model.getServiceSettings().projectId()
)
).getBytes(StandardCharsets.UTF_8)
);

httpPost.setEntity(byteEntity);
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());

decorateWithAuth(httpPost);

return new HttpRequest(httpPost, getInferenceEntityId());
}

public void decorateWithAuth(HttpPost httpPost) {
IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings(), model.getInferenceEntityId());
}

@Override
public String getInferenceEntityId() {
return model.getInferenceEntityId();
}

@Override
public URI getURI() {
return model.uri();
}

@Override
public Request truncate() {
return this;
}

public String getQuery() {
return query;
}

public List<String> getInput() {
return input;
}

public IbmWatsonxRerankModel getModel() {
return model;
}

@Override
public boolean[] getTruncationInfo() {
return null;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.ibmwatsonx;

import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

public record IbmWatsonxRerankRequestEntity(
String query,
List<String> inputs,
IbmWatsonxRerankTaskSettings taskSettings,
String modelId,
String projectId
) implements ToXContentObject {

private static final String INPUTS_FIELD = "inputs";
private static final String QUERY_FIELD = "query";
private static final String MODEL_ID_FIELD = "model_id";
private static final String PROJECT_ID_FIELD = "project_id";

public IbmWatsonxRerankRequestEntity {
Objects.requireNonNull(query);
Objects.requireNonNull(inputs);
Objects.requireNonNull(modelId);
Objects.requireNonNull(projectId);
Objects.requireNonNull(taskSettings);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();

builder.field(MODEL_ID_FIELD, modelId);
builder.field(QUERY_FIELD, query);
builder.startArray(INPUTS_FIELD);
for (String input : inputs) {
builder.startObject();
builder.field("text", input);
builder.endObject();
}
builder.endArray();
builder.field(PROJECT_ID_FIELD, projectId);

builder.startObject("parameters");
{
if (taskSettings.getTruncateInputTokens() != null) {
builder.field("truncate_input_tokens", taskSettings.getTruncateInputTokens());
}

builder.startObject("return_options");
{
if (taskSettings.getDoesReturnDocuments() != null) {
builder.field("inputs", taskSettings.getDoesReturnDocuments());
}
if (taskSettings.getTopNDocumentsOnly() != null) {
builder.field("top_n", taskSettings.getTopNDocumentsOnly());
}
}
builder.endObject();
}
builder.endObject();

builder.endObject();

return builder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ public class IbmWatsonxUtils {
public static final String V1 = "v1";
public static final String TEXT = "text";
public static final String EMBEDDINGS = "embeddings";
public static final String RERANKS = "reranks";

private IbmWatsonxUtils() {}

Expand Down
Loading