Skip to content

Commit 0207817

Browse files
committed
Adding rerank common options
1 parent 5c76ffd commit 0207817

File tree

56 files changed

+900
-216
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

56 files changed

+900
-216
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,13 @@ public ExecutableAction create(VoyageAIRerankModel model, Map<String, Object> ta
7575
serviceComponents.threadPool(),
7676
overriddenModel,
7777
RERANK_HANDLER,
78-
(rerankInput) -> new VoyageAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model),
78+
(rerankInput) -> new VoyageAIRerankRequest(
79+
rerankInput.getQuery(),
80+
rerankInput.getChunks(),
81+
rerankInput.getReturnDocuments(),
82+
rerankInput.getTopN(),
83+
model
84+
),
7985
QueryAndDocsInputs.class
8086
);
8187

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRerankRequestManager.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ public void execute(
6969
account,
7070
rerankInput.getQuery(),
7171
rerankInput.getChunks(),
72+
rerankInput.getReturnDocuments(),
73+
rerankInput.getTopN(),
7274
model
7375
);
7476

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,13 @@ public void execute(
4949
ActionListener<InferenceServiceResults> listener
5050
) {
5151
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
52-
CohereRerankRequest request = new CohereRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);
52+
CohereRerankRequest request = new CohereRerankRequest(
53+
rerankInput.getQuery(),
54+
rerankInput.getChunks(),
55+
rerankInput.getReturnDocuments(),
56+
rerankInput.getTopN(),
57+
model
58+
);
5359

5460
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
5561
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRerankRequestManager.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,13 @@ public void execute(
6262
ActionListener<InferenceServiceResults> listener
6363
) {
6464
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
65-
GoogleVertexAiRerankRequest request = new GoogleVertexAiRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);
65+
GoogleVertexAiRerankRequest request = new GoogleVertexAiRerankRequest(
66+
rerankInput.getQuery(),
67+
rerankInput.getChunks(),
68+
rerankInput.getReturnDocuments(),
69+
rerankInput.getTopN(),
70+
model
71+
);
6672

6773
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
6874
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIRerankRequestManager.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,13 @@ public void execute(
4949
ActionListener<InferenceServiceResults> listener
5050
) {
5151
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
52-
JinaAIRerankRequest request = new JinaAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);
52+
JinaAIRerankRequest request = new JinaAIRerankRequest(
53+
rerankInput.getQuery(),
54+
rerankInput.getChunks(),
55+
rerankInput.getReturnDocuments(),
56+
rerankInput.getTopN(),
57+
model
58+
);
5359

5460
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
5561
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
package org.elasticsearch.xpack.inference.external.http.sender;
99

10+
import org.elasticsearch.core.Nullable;
11+
1012
import java.util.List;
1113
import java.util.Objects;
1214

@@ -22,15 +24,25 @@ public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) {
2224

2325
private final String query;
2426
private final List<String> chunks;
27+
private final Boolean returnDocuments;
28+
private final Integer topN;
2529

2630
public QueryAndDocsInputs(String query, List<String> chunks) {
27-
this(query, chunks, false);
31+
this(query, chunks, null, null, false);
2832
}
2933

30-
public QueryAndDocsInputs(String query, List<String> chunks, boolean stream) {
34+
public QueryAndDocsInputs(
35+
String query,
36+
List<String> chunks,
37+
@Nullable Boolean returnDocuments,
38+
@Nullable Integer topN,
39+
boolean stream
40+
) {
3141
super(stream);
3242
this.query = Objects.requireNonNull(query);
3343
this.chunks = Objects.requireNonNull(chunks);
44+
this.returnDocuments = returnDocuments;
45+
this.topN = topN;
3446
}
3547

3648
public String getQuery() {
@@ -41,6 +53,14 @@ public List<String> getChunks() {
4153
return chunks;
4254
}
4355

56+
public Boolean getReturnDocuments() {
57+
return returnDocuments;
58+
}
59+
60+
public Integer getTopN() {
61+
return topN;
62+
}
63+
4464
public int inputSize() {
4565
return chunks.size();
4666
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequest.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.apache.http.client.utils.URIBuilder;
1313
import org.apache.http.entity.ByteArrayEntity;
1414
import org.elasticsearch.common.Strings;
15+
import org.elasticsearch.core.Nullable;
1516
import org.elasticsearch.xcontent.XContentType;
1617
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
1718
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
@@ -32,6 +33,8 @@ public class AlibabaCloudSearchRerankRequest implements Request {
3233
private final AlibabaCloudSearchAccount account;
3334
private final String query;
3435
private final List<String> input;
36+
private final Boolean returnDocuments;
37+
private final Integer topN;
3538
private final URI uri;
3639
private final AlibabaCloudSearchRerankTaskSettings taskSettings;
3740
private final String model;
@@ -44,13 +47,17 @@ public AlibabaCloudSearchRerankRequest(
4447
AlibabaCloudSearchAccount account,
4548
String query,
4649
List<String> input,
50+
@Nullable Boolean returnDocuments,
51+
@Nullable Integer topN,
4752
AlibabaCloudSearchRerankModel rerankModel
4853
) {
4954
Objects.requireNonNull(rerankModel);
5055

5156
this.account = Objects.requireNonNull(account);
5257
this.query = Objects.requireNonNull(query);
5358
this.input = Objects.requireNonNull(input);
59+
this.returnDocuments = returnDocuments;
60+
this.topN = topN;
5461
taskSettings = rerankModel.getTaskSettings();
5562
model = rerankModel.getServiceSettings().getCommonSettings().modelId();
5663
host = rerankModel.getServiceSettings().getCommonSettings().getHost();
@@ -67,7 +74,8 @@ public HttpRequest createHttpRequest() {
6774
HttpPost httpPost = new HttpPost(uri);
6875

6976
ByteArrayEntity byteEntity = new ByteArrayEntity(
70-
Strings.toString(new AlibabaCloudSearchRerankRequestEntity(query, input, taskSettings)).getBytes(StandardCharsets.UTF_8)
77+
Strings.toString(new AlibabaCloudSearchRerankRequestEntity(query, input, returnDocuments, topN, taskSettings))
78+
.getBytes(StandardCharsets.UTF_8)
7179
);
7280
httpPost.setEntity(byteEntity);
7381

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntity.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
99

10+
import org.elasticsearch.core.Nullable;
1011
import org.elasticsearch.xcontent.ToXContentObject;
1112
import org.elasticsearch.xcontent.XContentBuilder;
1213
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankTaskSettings;
@@ -15,9 +16,13 @@
1516
import java.util.List;
1617
import java.util.Objects;
1718

18-
public record AlibabaCloudSearchRerankRequestEntity(String query, List<String> input, AlibabaCloudSearchRerankTaskSettings taskSettings)
19-
implements
20-
ToXContentObject {
19+
public record AlibabaCloudSearchRerankRequestEntity(
20+
String query,
21+
List<String> input,
22+
@Nullable Boolean returnDocuments,
23+
@Nullable Integer topN,
24+
AlibabaCloudSearchRerankTaskSettings taskSettings
25+
) implements ToXContentObject {
2126

2227
private static final String SEARCH_QUERY = "query";
2328
private static final String TEXTS_FIELD = "docs";

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequest.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.apache.http.client.utils.URIBuilder;
1212
import org.apache.http.entity.ByteArrayEntity;
1313
import org.elasticsearch.common.Strings;
14+
import org.elasticsearch.core.Nullable;
1415
import org.elasticsearch.xpack.inference.external.cohere.CohereAccount;
1516
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
1617
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -28,16 +29,26 @@ public class CohereRerankRequest extends CohereRequest {
2829
private final CohereAccount account;
2930
private final String query;
3031
private final List<String> input;
32+
private final Boolean returnDocuments;
33+
private final Integer topN;
3134
private final CohereRerankTaskSettings taskSettings;
3235
private final String model;
3336
private final String inferenceEntityId;
3437

35-
public CohereRerankRequest(String query, List<String> input, CohereRerankModel model) {
38+
public CohereRerankRequest(
39+
String query,
40+
List<String> input,
41+
@Nullable Boolean returnDocuments,
42+
@Nullable Integer topN,
43+
CohereRerankModel model
44+
) {
3645
Objects.requireNonNull(model);
3746

3847
this.account = CohereAccount.of(model, CohereRerankRequest::buildDefaultUri);
3948
this.input = Objects.requireNonNull(input);
4049
this.query = Objects.requireNonNull(query);
50+
this.returnDocuments = returnDocuments;
51+
this.topN = topN;
4152
taskSettings = model.getTaskSettings();
4253
this.model = model.getServiceSettings().modelId();
4354
inferenceEntityId = model.getInferenceEntityId();
@@ -48,7 +59,8 @@ public HttpRequest createHttpRequest() {
4859
HttpPost httpPost = new HttpPost(account.uri());
4960

5061
ByteArrayEntity byteEntity = new ByteArrayEntity(
51-
Strings.toString(new CohereRerankRequestEntity(query, input, taskSettings, model)).getBytes(StandardCharsets.UTF_8)
62+
Strings.toString(new CohereRerankRequestEntity(query, input, returnDocuments, topN, taskSettings, model))
63+
.getBytes(StandardCharsets.UTF_8)
5264
);
5365
httpPost.setEntity(byteEntity);
5466

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntity.java

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.inference.external.request.cohere;
99

10+
import org.elasticsearch.core.Nullable;
1011
import org.elasticsearch.xcontent.ToXContentObject;
1112
import org.elasticsearch.xcontent.XContentBuilder;
1213
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
@@ -15,9 +16,14 @@
1516
import java.util.List;
1617
import java.util.Objects;
1718

18-
public record CohereRerankRequestEntity(String model, String query, List<String> documents, CohereRerankTaskSettings taskSettings)
19-
implements
20-
ToXContentObject {
19+
public record CohereRerankRequestEntity(
20+
String model,
21+
String query,
22+
List<String> documents,
23+
@Nullable Boolean returnDocuments,
24+
@Nullable Integer topN,
25+
CohereRerankTaskSettings taskSettings
26+
) implements ToXContentObject {
2127

2228
private static final String DOCUMENTS_FIELD = "documents";
2329
private static final String QUERY_FIELD = "query";
@@ -29,8 +35,15 @@ public record CohereRerankRequestEntity(String model, String query, List<String>
2935
Objects.requireNonNull(taskSettings);
3036
}
3137

32-
public CohereRerankRequestEntity(String query, List<String> input, CohereRerankTaskSettings taskSettings, String model) {
33-
this(model, query, input, taskSettings);
38+
public CohereRerankRequestEntity(
39+
String query,
40+
List<String> input,
41+
@Nullable Boolean returnDocuments,
42+
@Nullable Integer topN,
43+
CohereRerankTaskSettings taskSettings,
44+
String model
45+
) {
46+
this(model, query, input, returnDocuments, topN, taskSettings);
3447
}
3548

3649
@Override
@@ -41,11 +54,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
4154
builder.field(QUERY_FIELD, query);
4255
builder.field(DOCUMENTS_FIELD, documents);
4356

44-
if (taskSettings.getDoesReturnDocuments() != null) {
57+
// prefer the root level return_documents over task settings
58+
if (returnDocuments != null) {
59+
builder.field(CohereRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments);
60+
} else if (taskSettings.getDoesReturnDocuments() != null) {
4561
builder.field(CohereRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments());
4662
}
4763

48-
if (taskSettings.getTopNDocumentsOnly() != null) {
64+
// prefer the root level top_n over task settings
65+
if (topN != null) {
66+
builder.field(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, topN);
67+
} else if (taskSettings.getTopNDocumentsOnly() != null) {
4968
builder.field(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly());
5069
}
5170

0 commit comments

Comments
 (0)