Skip to content

Commit 33b8110

Browse files
committed
Add "rerank" task type to "elastic" provider
1 parent ae16016 commit 33b8110

21 files changed

+1542
-19
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ static TransportVersion def(int id) {
193193
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE = def(9_033_0_00);
194194
public static final TransportVersion INDEXING_STATS_INCLUDES_RECENT_WRITE_LOAD = def(9_034_0_00);
195195
public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL = def(9_035_0_00);
196+
public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK = def(9_036_0_00);
196197

197198
/*
198199
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,27 @@ protected RankedDocsResults mutateInstanceForVersion(RankedDocsResults instance,
6666
}
6767
}
6868

69+
public record RerankExpectation(int index, float relevanceScore) {}
70+
71+
public static Map<String, Object> buildExpectationRankedDocResults(List<RerankExpectation> rerankExpectations) {
72+
return Map.of(
73+
RankedDocsResults.RERANK,
74+
rerankExpectations.stream()
75+
.map(
76+
rerankExpectation -> Map.of(
77+
RankedDocsResults.RankedDoc.NAME,
78+
Map.of(
79+
RankedDocsResults.RankedDoc.INDEX,
80+
rerankExpectation.index,
81+
RankedDocsResults.RankedDoc.RELEVANCE_SCORE,
82+
rerankExpectation.relevanceScore
83+
)
84+
)
85+
)
86+
.toList()
87+
);
88+
}
89+
6990
private List<RankedDocsResults.RankedDoc> rankedDocsNullStringToEmpty(List<RankedDocsResults.RankedDoc> rankedDocs) {
7091
var result = new ArrayList<RankedDocsResults.RankedDoc>(rankedDocs.size());
7192
for (var doc : rankedDocs) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@
6161
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
6262
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
6363
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
64+
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
65+
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankTaskSettings;
6466
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
6567
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
6668
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings;
@@ -146,7 +148,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
146148
addCustomElandWriteables(namedWriteables);
147149
addAnthropicNamedWritables(namedWriteables);
148150
addAmazonBedrockNamedWriteables(namedWriteables);
149-
addEisNamedWriteables(namedWriteables);
151+
addElasticNamedWriteables(namedWriteables);
150152
addAlibabaCloudSearchNamedWriteables(namedWriteables);
151153
addJinaAINamedWriteables(namedWriteables);
152154
addVoyageAINamedWriteables(namedWriteables);
@@ -649,20 +651,40 @@ private static void addVoyageAINamedWriteables(List<NamedWriteableRegistry.Entry
649651
);
650652
}
651653

652-
private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
654+
private static void addElasticNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
655+
// Sparse Text Embeddings
653656
namedWriteables.add(
654657
new NamedWriteableRegistry.Entry(
655658
ServiceSettings.class,
656659
ElasticInferenceServiceSparseEmbeddingsServiceSettings.NAME,
657660
ElasticInferenceServiceSparseEmbeddingsServiceSettings::new
658661
)
659662
);
663+
664+
// Completion
660665
namedWriteables.add(
661666
new NamedWriteableRegistry.Entry(
662667
ServiceSettings.class,
663668
ElasticInferenceServiceCompletionServiceSettings.NAME,
664669
ElasticInferenceServiceCompletionServiceSettings::new
665670
)
666671
);
672+
673+
// Rerank
674+
namedWriteables.add(
675+
new NamedWriteableRegistry.Entry(
676+
ServiceSettings.class,
677+
ElasticInferenceServiceRerankServiceSettings.NAME,
678+
ElasticInferenceServiceRerankServiceSettings::new
679+
)
680+
);
681+
namedWriteables.add(
682+
new NamedWriteableRegistry.Entry(
683+
TaskSettings.class,
684+
ElasticInferenceServiceRerankTaskSettings.NAME,
685+
ElasticInferenceServiceRerankTaskSettings::new
686+
)
687+
);
667688
}
689+
668690
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreator.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99

1010
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1111
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
12+
import org.elasticsearch.xpack.inference.external.http.sender.ElasticInferenceServiceRerankRequestManager;
1213
import org.elasticsearch.xpack.inference.external.http.sender.ElasticInferenceServiceSparseEmbeddingsRequestManager;
1314
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
1415
import org.elasticsearch.xpack.inference.services.ServiceComponents;
1516
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
17+
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
1618
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
1719

1820
import java.util.Locale;
@@ -43,4 +45,13 @@ public ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel mode
4345
);
4446
return new SenderExecutableAction(sender, requestManager, errorMessage);
4547
}
48+
49+
@Override
50+
public ExecutableAction create(ElasticInferenceServiceRerankModel model) {
51+
var requestManager = new ElasticInferenceServiceRerankRequestManager(model, serviceComponents, traceContext);
52+
var errorMessage = constructFailedToSendRequestMessage(
53+
String.format(Locale.ROOT, "%s rerank", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)
54+
);
55+
return new SenderExecutableAction(sender, requestManager, errorMessage);
56+
}
4657
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionVisitor.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99

1010
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1111
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
12+
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
1213

1314
public interface ElasticInferenceServiceActionVisitor {
1415

1516
ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model);
1617

18+
ExecutableAction create(ElasticInferenceServiceRerankModel model);
19+
1720
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceResponseHandler;
15+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
16+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
17+
import org.elasticsearch.xpack.inference.external.request.elastic.rerank.ElasticInferenceServiceRerankRequest;
18+
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceRerankResponseEntity;
19+
import org.elasticsearch.xpack.inference.services.ServiceComponents;
20+
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
21+
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
22+
23+
import java.util.List;
24+
import java.util.Locale;
25+
import java.util.function.Supplier;
26+
27+
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER;
28+
29+
public class ElasticInferenceServiceRerankRequestManager extends ElasticInferenceServiceRequestManager {
30+
31+
private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceRerankRequestManager.class);
32+
33+
private static final ResponseHandler HANDLER = createRerankHandler();
34+
35+
private final ElasticInferenceServiceRerankModel model;
36+
37+
private final TraceContext traceContext;
38+
39+
private static ResponseHandler createRerankHandler() {
40+
return new ElasticInferenceServiceResponseHandler(
41+
String.format(Locale.ROOT, "%s rerank", ELASTIC_INFERENCE_SERVICE_IDENTIFIER),
42+
(request, response) -> ElasticInferenceServiceRerankResponseEntity.fromResponse(response)
43+
);
44+
}
45+
46+
public ElasticInferenceServiceRerankRequestManager(
47+
ElasticInferenceServiceRerankModel model,
48+
ServiceComponents serviceComponents,
49+
TraceContext traceContext
50+
) {
51+
super(serviceComponents.threadPool(), model);
52+
this.model = model;
53+
this.traceContext = traceContext;
54+
}
55+
56+
@Override
57+
public void execute(
58+
InferenceInputs inferenceInputs,
59+
RequestSender requestSender,
60+
Supplier<Boolean> hasRequestCompletedFunction,
61+
ActionListener<InferenceServiceResults> listener
62+
) {
63+
QueryAndDocsInputs input = QueryAndDocsInputs.of(inferenceInputs);
64+
List<String> docs = input.getChunks();
65+
String query = input.getQuery();
66+
67+
ElasticInferenceServiceRerankRequest request = new ElasticInferenceServiceRerankRequest(
68+
query,
69+
docs,
70+
model,
71+
traceContext,
72+
requestMetadata()
73+
);
74+
75+
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
76+
}
77+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.elastic.rerank;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.apache.http.client.methods.HttpPost;
12+
import org.apache.http.client.methods.HttpRequestBase;
13+
import org.apache.http.entity.ByteArrayEntity;
14+
import org.apache.http.message.BasicHeader;
15+
import org.elasticsearch.common.Strings;
16+
import org.elasticsearch.xcontent.XContentType;
17+
import org.elasticsearch.xpack.inference.external.request.Request;
18+
import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceRequest;
19+
import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceRequestMetadata;
20+
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
21+
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
22+
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;
23+
24+
import java.net.URI;
25+
import java.nio.charset.StandardCharsets;
26+
import java.util.List;
27+
import java.util.Objects;
28+
29+
public class ElasticInferenceServiceRerankRequest extends ElasticInferenceServiceRequest {
30+
31+
private final URI uri;
32+
private final String query;
33+
private final List<String> documents;
34+
private final TraceContextHandler traceContextHandler;
35+
private final ElasticInferenceServiceRerankModel model;
36+
37+
public ElasticInferenceServiceRerankRequest(
38+
String query,
39+
List<String> documents,
40+
ElasticInferenceServiceRerankModel model,
41+
TraceContext traceContext,
42+
ElasticInferenceServiceRequestMetadata metadata
43+
) {
44+
super(metadata);
45+
this.query = query;
46+
this.documents = documents;
47+
this.model = Objects.requireNonNull(model);
48+
this.uri = model.uri();
49+
this.traceContextHandler = new TraceContextHandler(traceContext);
50+
}
51+
52+
@Override
53+
public HttpRequestBase createHttpRequestBase() {
54+
var httpPost = new HttpPost(uri);
55+
var requestEntity = Strings.toString(
56+
new ElasticInferenceServiceRerankRequestEntity(
57+
query,
58+
documents,
59+
model.getServiceSettings().modelId(),
60+
model.getTaskSettings().getTopNDocumentsOnly()
61+
)
62+
);
63+
64+
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
65+
httpPost.setEntity(byteEntity);
66+
67+
traceContextHandler.propagateTraceContext(httpPost);
68+
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
69+
70+
return httpPost;
71+
}
72+
73+
public TraceContext getTraceContext() {
74+
return traceContextHandler.traceContext();
75+
}
76+
77+
@Override
78+
public String getInferenceEntityId() {
79+
return model.getInferenceEntityId();
80+
}
81+
82+
@Override
83+
public URI getURI() {
84+
return uri;
85+
}
86+
87+
@Override
88+
public Request truncate() {
89+
// no truncation
90+
return this;
91+
}
92+
93+
@Override
94+
public boolean[] getTruncationInfo() {
95+
// no truncation
96+
return null;
97+
}
98+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.elastic.rerank;
9+
10+
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.xcontent.ToXContentObject;
12+
import org.elasticsearch.xcontent.XContentBuilder;
13+
14+
import java.io.IOException;
15+
import java.util.List;
16+
import java.util.Objects;
17+
18+
public record ElasticInferenceServiceRerankRequestEntity(
19+
String query,
20+
List<String> documents,
21+
String modelId,
22+
@Nullable Integer topNDocumentsOnly
23+
) implements ToXContentObject {
24+
25+
private static final String QUERY_FIELD = "query";
26+
private static final String MODEL_FIELD = "model";
27+
private static final String TOP_N_DOCUMENTS_ONLY_FIELD = "top_n";
28+
private static final String DOCUMENTS_FIELD = "documents";
29+
30+
public ElasticInferenceServiceRerankRequestEntity {
31+
Objects.requireNonNull(query);
32+
Objects.requireNonNull(documents);
33+
Objects.requireNonNull(modelId);
34+
}
35+
36+
@Override
37+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
38+
builder.startObject();
39+
40+
builder.field(QUERY_FIELD, query);
41+
42+
builder.field(MODEL_FIELD, modelId);
43+
44+
if (Objects.nonNull(topNDocumentsOnly)) {
45+
builder.field(TOP_N_DOCUMENTS_ONLY_FIELD, topNDocumentsOnly);
46+
}
47+
48+
builder.startArray(DOCUMENTS_FIELD);
49+
for (String document : documents) {
50+
builder.value(document);
51+
}
52+
53+
builder.endArray();
54+
55+
builder.endObject();
56+
57+
return builder;
58+
}
59+
}

0 commit comments

Comments
 (0)