Skip to content

Commit e6b0bfd

Browse files
Integrate watsonx reranking to inference api
1 parent a620e7c commit e6b0bfd

15 files changed

+1172
-5
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@
8080
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
8181
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings;
8282
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings;
83+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
84+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
8385
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
8486
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
8587
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
@@ -364,6 +366,15 @@ private static void addIbmWatsonxNamedWritables(List<NamedWriteableRegistry.Entr
364366
IbmWatsonxEmbeddingsServiceSettings::new
365367
)
366368
);
369+
370+
namedWriteables.add(
371+
new NamedWriteableRegistry.Entry(ServiceSettings.class,
372+
IbmWatsonxRerankServiceSettings.NAME,
373+
IbmWatsonxRerankServiceSettings::new)
374+
);
375+
namedWriteables.add(
376+
new NamedWriteableRegistry.Entry(TaskSettings.class, IbmWatsonxRerankTaskSettings.NAME, IbmWatsonxRerankTaskSettings::new)
377+
);
367378
}
368379

369380
private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxActionCreator.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,27 @@
66
*/
77

88
package org.elasticsearch.xpack.inference.external.action.ibmwatsonx;
9-
109
import org.elasticsearch.threadpool.ThreadPool;
1110
import org.elasticsearch.xpack.inference.common.Truncator;
1211
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1312
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
1413
import org.elasticsearch.xpack.inference.external.http.sender.IbmWatsonxEmbeddingsRequestManager;
14+
import org.elasticsearch.xpack.inference.external.http.sender.IbmWatsonxRerankRequestManager;
1515
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
1616
import org.elasticsearch.xpack.inference.services.ServiceComponents;
1717
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
18-
18+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
1919
import java.util.Map;
2020
import java.util.Objects;
2121

2222
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
2323

2424
public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor {
25-
2625
private final Sender sender;
2726
private final ServiceComponents serviceComponents;
2827

2928
/**
 * Creates an action creator that builds IBM watsonx actions using the given sender and service components.
 *
 * @param sender            the sender handed to the actions this creator builds; must not be {@code null}
 * @param serviceComponents the shared service components (its thread pool is read when building rerank actions);
 *                          must not be {@code null}
 * @throws NullPointerException if either argument is {@code null}
 */
public IbmWatsonxActionCreator(Sender sender, ServiceComponents serviceComponents) {
29+
// TODO Batching - accept a class that can handle batching
3030
this.sender = Objects.requireNonNull(sender);
3131
this.serviceComponents = Objects.requireNonNull(serviceComponents);
3232
}
@@ -41,6 +41,17 @@ public ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Obje
4141
);
4242
}
4343

44+
@Override
45+
public ExecutableAction create(IbmWatsonxRerankModel model, Map<String, Object> taskSettings) {
46+
var overriddenModel = IbmWatsonxRerankModel.of(model, taskSettings);
47+
var requestCreator = IbmWatsonxRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
48+
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
49+
overriddenModel.getServiceSettings().uri(),
50+
"Ibm Watsonx rerank"
51+
);
52+
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
53+
}
54+
4455
protected IbmWatsonxEmbeddingsRequestManager getEmbeddingsRequestManager(
4556
IbmWatsonxEmbeddingsModel model,
4657
Truncator truncator,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxActionVisitor.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99

1010
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1111
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
12+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
1213

1314
import java.util.Map;
1415

1516
/**
 * Visitor with one {@code create} overload per IBM watsonx model type; each overload
 * returns the {@link ExecutableAction} for that model.
 */
public interface IbmWatsonxActionVisitor {
1617
/**
 * Creates the action for an embeddings model.
 *
 * @param model        the embeddings model
 * @param taskSettings the task settings supplied for the request
 * @return the action to execute
 */
ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings);
18+
19+
/**
 * Creates the action for a rerank model.
 *
 * @param model        the rerank model
 * @param taskSettings the task settings supplied for the request
 * @return the action to execute
 */
ExecutableAction create(IbmWatsonxRerankModel model, Map<String, Object> taskSettings);
1720
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/IbmWatsonxEmbeddingsRequestManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public class IbmWatsonxEmbeddingsRequestManager extends IbmWatsonxRequestManager
3333
private static final ResponseHandler HANDLER = createEmbeddingsHandler();
3434

3535
private static ResponseHandler createEmbeddingsHandler() {
36-
return new IbmWatsonxResponseHandler("ibm watsonx embeddings", IbmWatsonxEmbeddingsResponseEntity::fromResponse);
36+
return new IbmWatsonxResponseHandler("ibm watsonx embeddings", IbmWatsonxEmbeddingsResponseEntity::fromResponse, false);
3737
}
3838

3939
private final IbmWatsonxEmbeddingsModel model;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
16+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
17+
import org.elasticsearch.xpack.inference.external.ibmwatsonx.IbmWatsonxResponseHandler;
18+
import org.elasticsearch.xpack.inference.external.request.ibmwatsonx.IbmWatsonxRerankRequest;
19+
import org.elasticsearch.xpack.inference.external.response.ibmwatsonx.IbmWatsonxRankedResponseEntity;
20+
21+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
22+
23+
import java.util.Objects;
24+
import java.util.function.Supplier;
25+
26+
27+
public class IbmWatsonxRerankRequestManager extends IbmWatsonxRequestManager {
28+
private static final Logger logger = LogManager.getLogger(IbmWatsonxRerankRequestManager.class);
29+
private static final ResponseHandler HANDLER = createIbmWatsonxResponseHandler();
30+
31+
private static ResponseHandler createIbmWatsonxResponseHandler() {
32+
return new IbmWatsonxResponseHandler("ibm watsonx rerank",
33+
(request, response) -> IbmWatsonxRankedResponseEntity.fromResponse(response), false);
34+
}
35+
36+
public static IbmWatsonxRerankRequestManager of(IbmWatsonxRerankModel model, ThreadPool threadPool) {
37+
return new IbmWatsonxRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
38+
}
39+
40+
private final IbmWatsonxRerankModel model;
41+
42+
private IbmWatsonxRerankRequestManager(IbmWatsonxRerankModel model, ThreadPool threadPool) {
43+
super(threadPool, model);
44+
this.model = model;
45+
}
46+
47+
@Override
48+
public void execute(
49+
InferenceInputs inferenceInputs,
50+
RequestSender requestSender,
51+
Supplier<Boolean> hasRequestCompletedFunction,
52+
ActionListener<InferenceServiceResults> listener
53+
) {
54+
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
55+
56+
IbmWatsonxRerankRequest request = new IbmWatsonxRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);
57+
58+
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
59+
}
60+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/ibmwatsonx/IbmWatsonxResponseHandler.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,15 @@
1818

1919
public class IbmWatsonxResponseHandler extends BaseResponseHandler {
2020

21-
public IbmWatsonxResponseHandler(String requestType, ResponseParser parseFunction) {
21+
private final boolean canHandleStreamingResponse;
22+
23+
/**
 * Creates a response handler for IBM watsonx requests.
 *
 * @param requestType                forwarded to the base response handler
 * @param parseFunction              forwarded to the base response handler as the response parser
 * @param canHandleStreamingResponse value later reported by {@code canHandleStreamingResponses()}
 */
public IbmWatsonxResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponse) {
2224
// Error responses are always parsed with IbmWatsonxErrorResponseEntity.
super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
25+
this.canHandleStreamingResponse = canHandleStreamingResponse;
26+
}
27+
28+
/**
 * Returns the streaming flag this handler was constructed with.
 *
 * @return the {@code canHandleStreamingResponse} value passed to the constructor
 */
public boolean canHandleStreamingResponses() {
29+
return canHandleStreamingResponse;
2330
}
2431

2532
/**
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.ibmwatsonx;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.apache.http.client.methods.HttpPost;
12+
import org.apache.http.client.utils.URIBuilder;
13+
import org.apache.http.entity.ByteArrayEntity;
14+
import org.elasticsearch.common.Strings;
15+
import org.elasticsearch.xcontent.XContentType;
16+
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
17+
import org.elasticsearch.xpack.inference.external.request.Request;
18+
import org.elasticsearch.xpack.inference.external.request.cohere.CohereUtils;
19+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
20+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
21+
22+
import java.net.URI;
23+
import java.net.URISyntaxException;
24+
import java.nio.charset.StandardCharsets;
25+
import java.util.List;
26+
import java.util.Objects;
27+
28+
public class IbmWatsonxRerankRequest implements IbmWatsonxRequest {
29+
30+
private final String query;
31+
private final List<String> input;
32+
private final IbmWatsonxRerankTaskSettings taskSettings;
33+
private final String modelId;
34+
private final IbmWatsonxRerankModel model;
35+
private final String inferenceEntityId;
36+
37+
public IbmWatsonxRerankRequest(String query, List<String> input, IbmWatsonxRerankModel model) {
38+
Objects.requireNonNull(model);
39+
40+
this.input = Objects.requireNonNull(input);
41+
this.query = Objects.requireNonNull(query);
42+
taskSettings = model.getTaskSettings();
43+
this.model = model;
44+
this.modelId = model.getServiceSettings().modelId();
45+
inferenceEntityId = model.getInferenceEntityId();
46+
}
47+
48+
@Override
49+
public HttpRequest createHttpRequest() {
50+
// HttpPost httpPost = new HttpPost(model.uri());
51+
52+
53+
URI uri;
54+
55+
try {
56+
uri = new URI("https://us-south.ml.cloud.ibm.com/ml/v1/text/reranks?version=2024-05-02");
57+
} catch (URISyntaxException ex) {
58+
throw new IllegalArgumentException("cannot parse URI patter");
59+
}
60+
61+
HttpPost httpPost = new HttpPost(uri);
62+
63+
ByteArrayEntity byteEntity = new ByteArrayEntity(
64+
Strings.toString(
65+
new IbmWatsonxRerankRequestEntity(query, input, taskSettings, modelId)).getBytes(StandardCharsets.UTF_8)
66+
);
67+
68+
httpPost.setEntity(byteEntity);
69+
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
70+
71+
decorateWithAuth(httpPost);
72+
73+
return new HttpRequest(httpPost, getInferenceEntityId());
74+
}
75+
76+
public void decorateWithAuth(HttpPost httpPost) {
77+
IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings(), model.getInferenceEntityId());
78+
}
79+
80+
@Override
81+
public String getInferenceEntityId() {
82+
return inferenceEntityId;
83+
}
84+
85+
@Override
86+
public URI getURI() {
87+
return model.uri();
88+
}
89+
90+
@Override
91+
public Request truncate() {
92+
return this; // TODO?
93+
}
94+
95+
@Override
96+
public boolean[] getTruncationInfo() {
97+
return null;
98+
}
99+
100+
public static URI buildDefaultUri() throws URISyntaxException {
101+
return new URIBuilder().setScheme("https")
102+
.setHost(CohereUtils.HOST)
103+
.setPathSegments(CohereUtils.VERSION_1, CohereUtils.RERANK_PATH)
104+
.build();
105+
}
106+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.ibmwatsonx;
9+
10+
import org.elasticsearch.xcontent.ToXContentObject;
11+
import org.elasticsearch.xcontent.XContentBuilder;
12+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
13+
14+
import java.io.IOException;
15+
import java.util.List;
16+
import java.util.Objects;
17+
18+
public record IbmWatsonxRerankRequestEntity(String model, String query, List<String> documents, IbmWatsonxRerankTaskSettings taskSettings)
19+
implements ToXContentObject {
20+
21+
private static final String INPUTS_FIELD = "inputs";
22+
private static final String QUERY_FIELD = "query";
23+
private static final String MODEL_ID_FIELD = "model_id";
24+
private static final String PROJECT_ID_FIELD = "project_id";
25+
26+
public IbmWatsonxRerankRequestEntity {
27+
Objects.requireNonNull(query);
28+
Objects.requireNonNull(documents);
29+
Objects.requireNonNull(taskSettings);
30+
}
31+
32+
public IbmWatsonxRerankRequestEntity(String query, List<String> input, IbmWatsonxRerankTaskSettings taskSettings, String model) {
33+
this(model, query, input, taskSettings);
34+
}
35+
36+
@Override
37+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
38+
builder.startObject();
39+
40+
builder.field(MODEL_ID_FIELD, model);
41+
builder.field(QUERY_FIELD, query);
42+
builder.startArray(INPUTS_FIELD);
43+
for (String document : documents) {
44+
builder.startObject();
45+
builder.field("text", document);
46+
builder.endObject();
47+
}
48+
builder.endArray();
49+
builder.field(PROJECT_ID_FIELD, "e2706421-ecbb-41b1-906e-c4e32c58a3a8");
50+
51+
builder.endObject();
52+
53+
return builder;
54+
}
55+
}

0 commit comments

Comments
 (0)