Skip to content

Commit 09b1c6d

Browse files
saikatsarkar056elasticsearchmachine
andauthored
Integrate watsonx for re-ranking task (#117176)
* Integrate watsonx reranking to inference api * Add api_version to the watsonx api call * Fix the return_doc option * Add top_n parameter to task_settings * Add truncate_input_tokens parameter to task_settings * Add test for IbmWatonxRankedResponseEntity * Add test for IbmWatonxRankedRequestEntity * Add test for IbmWatonxRankedRequest * [CI] Auto commit changes from spotless * Add changelog * Fix transport version * Add test for IbmWatsonxService * Remove canHandleStreamingResponses * Add requireNonNull for modelId and projectId * Remove maxInputToken method * Convert all optionals to required * [CI] Auto commit changes from spotless * Set minimal_supported version to be ML_INFERENCE_IBM_WATSONX_RERANK_ADDED * Remove extraction of unused fields from IbmWatsonxRerankServiceSettings * Add space * Add space --------- Co-authored-by: elasticsearchmachine <[email protected]>
1 parent 1225b07 commit 09b1c6d

File tree

21 files changed

+1370
-2
lines changed

21 files changed

+1370
-2
lines changed

docs/changelog/117176.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 117176
2+
summary: Integrate IBM watsonx to Inference API for re-ranking task
3+
area: Experiences
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ static TransportVersion def(int id) {
171171
public static final TransportVersion LINEAR_RETRIEVER_SUPPORT = def(8_837_00_0);
172172
public static final TransportVersion TIMEOUT_GET_PARAM_FOR_RESOLVE_CLUSTER = def(8_838_00_0);
173173
public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING = def(8_839_00_0);
174+
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_RERANK_ADDED = def(8_840_00_0);
174175

175176
/*
176177
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@
7575
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
7676
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
7777
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
78+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
79+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
7880
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings;
7981
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings;
8082
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
@@ -364,6 +366,17 @@ private static void addIbmWatsonxNamedWritables(List<NamedWriteableRegistry.Entr
364366
IbmWatsonxEmbeddingsServiceSettings::new
365367
)
366368
);
369+
370+
namedWriteables.add(
371+
new NamedWriteableRegistry.Entry(
372+
ServiceSettings.class,
373+
IbmWatsonxRerankServiceSettings.NAME,
374+
IbmWatsonxRerankServiceSettings::new
375+
)
376+
);
377+
namedWriteables.add(
378+
new NamedWriteableRegistry.Entry(TaskSettings.class, IbmWatsonxRerankTaskSettings.NAME, IbmWatsonxRerankTaskSettings::new)
379+
);
367380
}
368381

369382
private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxActionCreator.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,18 @@
1212
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1313
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
1414
import org.elasticsearch.xpack.inference.external.http.sender.IbmWatsonxEmbeddingsRequestManager;
15+
import org.elasticsearch.xpack.inference.external.http.sender.IbmWatsonxRerankRequestManager;
1516
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
1617
import org.elasticsearch.xpack.inference.services.ServiceComponents;
1718
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
19+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
1820

1921
import java.util.Map;
2022
import java.util.Objects;
2123

2224
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
2325

2426
public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor {
25-
2627
private final Sender sender;
2728
private final ServiceComponents serviceComponents;
2829

@@ -41,6 +42,17 @@ public ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Obje
4142
);
4243
}
4344

45+
@Override
46+
public ExecutableAction create(IbmWatsonxRerankModel model, Map<String, Object> taskSettings) {
47+
var overriddenModel = IbmWatsonxRerankModel.of(model, taskSettings);
48+
var requestCreator = IbmWatsonxRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
49+
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
50+
overriddenModel.getServiceSettings().uri(),
51+
"Ibm Watsonx rerank"
52+
);
53+
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
54+
}
55+
4456
protected IbmWatsonxEmbeddingsRequestManager getEmbeddingsRequestManager(
4557
IbmWatsonxEmbeddingsModel model,
4658
Truncator truncator,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxActionVisitor.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99

1010
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1111
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
12+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
1213

1314
import java.util.Map;
1415

1516
public interface IbmWatsonxActionVisitor {
1617
ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings);
18+
19+
ExecutableAction create(IbmWatsonxRerankModel model, Map<String, Object> taskSettings);
1720
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
16+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
17+
import org.elasticsearch.xpack.inference.external.ibmwatsonx.IbmWatsonxResponseHandler;
18+
import org.elasticsearch.xpack.inference.external.request.ibmwatsonx.IbmWatsonxRerankRequest;
19+
import org.elasticsearch.xpack.inference.external.response.ibmwatsonx.IbmWatsonxRankedResponseEntity;
20+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
21+
22+
import java.util.List;
23+
import java.util.Objects;
24+
import java.util.function.Supplier;
25+
26+
public class IbmWatsonxRerankRequestManager extends IbmWatsonxRequestManager {
27+
private static final Logger logger = LogManager.getLogger(IbmWatsonxRerankRequestManager.class);
28+
private static final ResponseHandler HANDLER = createIbmWatsonxResponseHandler();
29+
30+
private static ResponseHandler createIbmWatsonxResponseHandler() {
31+
return new IbmWatsonxResponseHandler(
32+
"ibm watsonx rerank",
33+
(request, response) -> IbmWatsonxRankedResponseEntity.fromResponse(response)
34+
);
35+
}
36+
37+
public static IbmWatsonxRerankRequestManager of(IbmWatsonxRerankModel model, ThreadPool threadPool) {
38+
return new IbmWatsonxRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
39+
}
40+
41+
private final IbmWatsonxRerankModel model;
42+
43+
public IbmWatsonxRerankRequestManager(IbmWatsonxRerankModel model, ThreadPool threadPool) {
44+
super(threadPool, model);
45+
this.model = model;
46+
}
47+
48+
@Override
49+
public void execute(
50+
InferenceInputs inferenceInputs,
51+
RequestSender requestSender,
52+
Supplier<Boolean> hasRequestCompletedFunction,
53+
ActionListener<InferenceServiceResults> listener
54+
) {
55+
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
56+
57+
execute(
58+
new ExecutableInferenceRequest(
59+
requestSender,
60+
logger,
61+
getRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model),
62+
HANDLER,
63+
hasRequestCompletedFunction,
64+
listener
65+
)
66+
);
67+
}
68+
69+
protected IbmWatsonxRerankRequest getRerankRequest(String query, List<String> chunks, IbmWatsonxRerankModel model) {
70+
return new IbmWatsonxRerankRequest(query, chunks, model);
71+
}
72+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/ibmwatsonx/IbmWatsonxResponseHandler.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import static org.elasticsearch.core.Strings.format;
1818

1919
public class IbmWatsonxResponseHandler extends BaseResponseHandler {
20-
2120
public IbmWatsonxResponseHandler(String requestType, ResponseParser parseFunction) {
2221
super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
2322
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.ibmwatsonx;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.apache.http.client.methods.HttpPost;
12+
import org.apache.http.entity.ByteArrayEntity;
13+
import org.elasticsearch.common.Strings;
14+
import org.elasticsearch.xcontent.XContentType;
15+
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
16+
import org.elasticsearch.xpack.inference.external.request.Request;
17+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
18+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
19+
20+
import java.net.URI;
21+
import java.net.URISyntaxException;
22+
import java.nio.charset.StandardCharsets;
23+
import java.util.List;
24+
import java.util.Objects;
25+
26+
public class IbmWatsonxRerankRequest implements IbmWatsonxRequest {
27+
28+
private final String query;
29+
private final List<String> input;
30+
private final IbmWatsonxRerankTaskSettings taskSettings;
31+
private final IbmWatsonxRerankModel model;
32+
33+
public IbmWatsonxRerankRequest(String query, List<String> input, IbmWatsonxRerankModel model) {
34+
Objects.requireNonNull(model);
35+
36+
this.input = Objects.requireNonNull(input);
37+
this.query = Objects.requireNonNull(query);
38+
taskSettings = model.getTaskSettings();
39+
this.model = model;
40+
}
41+
42+
@Override
43+
public HttpRequest createHttpRequest() {
44+
URI uri;
45+
46+
try {
47+
uri = new URI(model.uri().toString());
48+
} catch (URISyntaxException ex) {
49+
throw new IllegalArgumentException("cannot parse URI patter");
50+
}
51+
52+
HttpPost httpPost = new HttpPost(uri);
53+
54+
ByteArrayEntity byteEntity = new ByteArrayEntity(
55+
Strings.toString(
56+
new IbmWatsonxRerankRequestEntity(
57+
query,
58+
input,
59+
taskSettings,
60+
model.getServiceSettings().modelId(),
61+
model.getServiceSettings().projectId()
62+
)
63+
).getBytes(StandardCharsets.UTF_8)
64+
);
65+
66+
httpPost.setEntity(byteEntity);
67+
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
68+
69+
decorateWithAuth(httpPost);
70+
71+
return new HttpRequest(httpPost, getInferenceEntityId());
72+
}
73+
74+
public void decorateWithAuth(HttpPost httpPost) {
75+
IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings(), model.getInferenceEntityId());
76+
}
77+
78+
@Override
79+
public String getInferenceEntityId() {
80+
return model.getInferenceEntityId();
81+
}
82+
83+
@Override
84+
public URI getURI() {
85+
return model.uri();
86+
}
87+
88+
@Override
89+
public Request truncate() {
90+
return this;
91+
}
92+
93+
public String getQuery() {
94+
return query;
95+
}
96+
97+
public List<String> getInput() {
98+
return input;
99+
}
100+
101+
public IbmWatsonxRerankModel getModel() {
102+
return model;
103+
}
104+
105+
@Override
106+
public boolean[] getTruncationInfo() {
107+
return null;
108+
}
109+
110+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.ibmwatsonx;
9+
10+
import org.elasticsearch.xcontent.ToXContentObject;
11+
import org.elasticsearch.xcontent.XContentBuilder;
12+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
13+
14+
import java.io.IOException;
15+
import java.util.List;
16+
import java.util.Objects;
17+
18+
public record IbmWatsonxRerankRequestEntity(
19+
String query,
20+
List<String> inputs,
21+
IbmWatsonxRerankTaskSettings taskSettings,
22+
String modelId,
23+
String projectId
24+
) implements ToXContentObject {
25+
26+
private static final String INPUTS_FIELD = "inputs";
27+
private static final String QUERY_FIELD = "query";
28+
private static final String MODEL_ID_FIELD = "model_id";
29+
private static final String PROJECT_ID_FIELD = "project_id";
30+
31+
public IbmWatsonxRerankRequestEntity {
32+
Objects.requireNonNull(query);
33+
Objects.requireNonNull(inputs);
34+
Objects.requireNonNull(modelId);
35+
Objects.requireNonNull(projectId);
36+
Objects.requireNonNull(taskSettings);
37+
}
38+
39+
@Override
40+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
41+
builder.startObject();
42+
43+
builder.field(MODEL_ID_FIELD, modelId);
44+
builder.field(QUERY_FIELD, query);
45+
builder.startArray(INPUTS_FIELD);
46+
for (String input : inputs) {
47+
builder.startObject();
48+
builder.field("text", input);
49+
builder.endObject();
50+
}
51+
builder.endArray();
52+
builder.field(PROJECT_ID_FIELD, projectId);
53+
54+
builder.startObject("parameters");
55+
{
56+
if (taskSettings.getTruncateInputTokens() != null) {
57+
builder.field("truncate_input_tokens", taskSettings.getTruncateInputTokens());
58+
}
59+
60+
builder.startObject("return_options");
61+
{
62+
if (taskSettings.getDoesReturnDocuments() != null) {
63+
builder.field("inputs", taskSettings.getDoesReturnDocuments());
64+
}
65+
if (taskSettings.getTopNDocumentsOnly() != null) {
66+
builder.field("top_n", taskSettings.getTopNDocumentsOnly());
67+
}
68+
}
69+
builder.endObject();
70+
}
71+
builder.endObject();
72+
73+
builder.endObject();
74+
75+
return builder;
76+
}
77+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxUtils.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ public class IbmWatsonxUtils {
1313
public static final String V1 = "v1";
1414
public static final String TEXT = "text";
1515
public static final String EMBEDDINGS = "embeddings";
16+
public static final String RERANKS = "reranks";
1617

1718
private IbmWatsonxUtils() {}
1819

0 commit comments

Comments
 (0)