Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
33b8110
Add "rerank" task type to "elastic" provider
timgrein Apr 1, 2025
efee615
Resolve merge conflicts
timgrein Apr 8, 2025
b4b1f05
Fix checkstyle violation
timgrein Apr 8, 2025
2d68ed7
Merge branch 'main' into elastic-provider-rerank-integration
timgrein May 23, 2025
8911f0a
Merge branch 'main' into elastic-provider-rerank-integration
timgrein May 28, 2025
595859f
Remove uri field and use getter
timgrein May 28, 2025
603b91b
Merge remote-tracking branch 'origin/elastic-provider-rerank-integrat…
timgrein May 28, 2025
504ed93
[CI] Auto commit changes from spotless
May 28, 2025
2b89009
Remove ElasticInferenceServiceRerankRequestManager
timgrein Jun 2, 2025
8389205
Merge remote-tracking branch 'origin/elastic-provider-rerank-integrat…
timgrein Jun 2, 2025
d5c2744
Spotless apply
timgrein Jun 2, 2025
f43417b
Use Strings.format(...)
timgrein Jun 2, 2025
3a447f1
Remove ElasticInferenceServiceRerankTaskSettings and override validat…
timgrein Jun 2, 2025
249d98a
Merge branch 'main' into elastic-provider-rerank-integration
timgrein Jun 2, 2025
2639a12
Merge branch 'main' into elastic-provider-rerank-integration
timgrein Jun 4, 2025
c0d8a30
Use Strings.format(...) for rerank action, too
timgrein Jun 4, 2025
d954f87
Merge remote-tracking branch 'origin/elastic-provider-rerank-integrat…
timgrein Jun 4, 2025
9e73b10
Add backport transport version
timgrein Jun 4, 2025
9fb2aa8
[CI] Auto commit changes from spotless
Jun 4, 2025
46fcaa9
Merge branch 'main' into elastic-provider-rerank-integration
timgrein Jun 4, 2025
a6a5d0a
Merge branch 'main' into elastic-provider-rerank-integration
timgrein Jun 5, 2025
2ea8cc5
Use Strings.format in ElasticInferenceServiceActionCreator
timgrein Jun 6, 2025
5271bf0
Merge remote-tracking branch 'origin/elastic-provider-rerank-integrat…
timgrein Jun 6, 2025
5d464ae
Merge branch 'main' into elastic-provider-rerank-integration
timgrein Jun 6, 2025
67b1c70
Merge branch 'main' into elastic-provider-rerank-integration
timgrein Jun 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_REMOVE_AGGREGATE_TYPE = def(9_045_0_00);
public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00);
public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_00_0);
public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK = def(9_048_0_00);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like the PR is only targeting 9.1. Did we also want to support 8.19? If so, we'll need to add another transport version and do the backport dance.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking with product

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It needs to go into 9.1 & 8.19


/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,27 @@ protected RankedDocsResults mutateInstanceForVersion(RankedDocsResults instance,
}
}

/**
 * Expected values for a single ranked document, consumed by
 * {@code buildExpectationRankedDocResults}: {@code index} is stored under
 * {@code RankedDocsResults.RankedDoc.INDEX} and {@code relevanceScore} under
 * {@code RankedDocsResults.RankedDoc.RELEVANCE_SCORE}.
 */
public record RerankExpectation(int index, float relevanceScore) {}

/**
 * Builds the map representation expected for a set of ranked documents.
 * <p>
 * The returned map has a single entry keyed by {@code RankedDocsResults.RERANK} whose value is a
 * list with one element per expectation, in input order. Each element is a map keyed by
 * {@code RankedDocsResults.RankedDoc.NAME} that wraps the index and relevance score of the expectation.
 *
 * @param rerankExpectations the expected index/score pairs
 * @return an immutable map mirroring the expected ranked-docs structure
 */
public static Map<String, Object> buildExpectationRankedDocResults(List<RerankExpectation> rerankExpectations) {
    var expectedDocs = rerankExpectations.stream()
        .map(
            expectation -> Map.of(
                RankedDocsResults.RankedDoc.NAME,
                Map.of(
                    RankedDocsResults.RankedDoc.INDEX,
                    expectation.index(),
                    RankedDocsResults.RankedDoc.RELEVANCE_SCORE,
                    expectation.relevanceScore()
                )
            )
        )
        .toList();

    return Map.of(RankedDocsResults.RERANK, expectedDocs);
}

private List<RankedDocsResults.RankedDoc> rankedDocsNullStringToEmpty(List<RankedDocsResults.RankedDoc> rankedDocs) {
var result = new ArrayList<RankedDocsResults.RankedDoc>(rankedDocs.size());
for (var doc : rankedDocs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
Expand Down Expand Up @@ -147,7 +149,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addAnthropicNamedWritables(namedWriteables);
addAmazonBedrockNamedWriteables(namedWriteables);
addAwsNamedWriteables(namedWriteables);
addEisNamedWriteables(namedWriteables);
addElasticNamedWriteables(namedWriteables);
addAlibabaCloudSearchNamedWriteables(namedWriteables);
addJinaAINamedWriteables(namedWriteables);
addVoyageAINamedWriteables(namedWriteables);
Expand Down Expand Up @@ -646,20 +648,39 @@ private static void addVoyageAINamedWriteables(List<NamedWriteableRegistry.Entry
);
}

private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
/**
 * Registers the named writeables of the "elastic" provider: the service settings for sparse text
 * embeddings, completion and rerank, followed by the rerank task settings.
 *
 * @param namedWriteables the registry entry list the entries are appended to
 */
private static void addElasticNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
    // Sparse Text Embeddings
    NamedWriteableRegistry.Entry sparseEmbeddingsServiceSettings = new NamedWriteableRegistry.Entry(
        ServiceSettings.class,
        ElasticInferenceServiceSparseEmbeddingsServiceSettings.NAME,
        ElasticInferenceServiceSparseEmbeddingsServiceSettings::new
    );

    // Completion
    NamedWriteableRegistry.Entry completionServiceSettings = new NamedWriteableRegistry.Entry(
        ServiceSettings.class,
        ElasticInferenceServiceCompletionServiceSettings.NAME,
        ElasticInferenceServiceCompletionServiceSettings::new
    );

    // Rerank
    NamedWriteableRegistry.Entry rerankServiceSettings = new NamedWriteableRegistry.Entry(
        ServiceSettings.class,
        ElasticInferenceServiceRerankServiceSettings.NAME,
        ElasticInferenceServiceRerankServiceSettings::new
    );
    NamedWriteableRegistry.Entry rerankTaskSettings = new NamedWriteableRegistry.Entry(
        TaskSettings.class,
        ElasticInferenceServiceRerankTaskSettings.NAME,
        ElasticInferenceServiceRerankTaskSettings::new
    );

    // Appended one by one, in the same order as before.
    namedWriteables.add(sparseEmbeddingsServiceSettings);
    namedWriteables.add(completionServiceSettings);
    namedWriteables.add(rerankServiceSettings);
    namedWriteables.add(rerankTaskSettings);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic.rerank;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.message.BasicHeader;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceRequest;
import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceRequestMetadata;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;

public class ElasticInferenceServiceRerankRequest extends ElasticInferenceServiceRequest {

private final URI uri;
private final String query;
private final List<String> documents;
private final TraceContextHandler traceContextHandler;
private final ElasticInferenceServiceRerankModel model;

public ElasticInferenceServiceRerankRequest(
String query,
List<String> documents,
ElasticInferenceServiceRerankModel model,
TraceContext traceContext,
ElasticInferenceServiceRequestMetadata metadata
) {
super(metadata);
this.query = query;
this.documents = documents;
this.model = Objects.requireNonNull(model);
this.uri = model.uri();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: We probably don't need a reference to the uri since we have a reference to the model.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this.traceContextHandler = new TraceContextHandler(traceContext);
}

@Override
public HttpRequestBase createHttpRequestBase() {
var httpPost = new HttpPost(uri);
var requestEntity = Strings.toString(
new ElasticInferenceServiceRerankRequestEntity(
query,
documents,
model.getServiceSettings().modelId(),
model.getTaskSettings().getTopNDocumentsOnly()
)
);

ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
httpPost.setEntity(byteEntity);

traceContextHandler.propagateTraceContext(httpPost);
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));

return httpPost;
}

public TraceContext getTraceContext() {
return traceContextHandler.traceContext();
}

@Override
public String getInferenceEntityId() {
return model.getInferenceEntityId();
}

@Override
public URI getURI() {
return uri;
}

@Override
public Request truncate() {
// no truncation
return this;
}

@Override
public boolean[] getTruncationInfo() {
// no truncation
return null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic.rerank;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

/**
 * JSON body of a rerank request.
 * <p>
 * Serializes to an object with the fields {@code query}, {@code model}, an optional {@code top_n}
 * (written only when {@code topNDocumentsOnly} is non-null) and a {@code documents} array, in that order.
 *
 * @param query             the query text; must not be {@code null}
 * @param documents         the documents to rerank; must not be {@code null}
 * @param modelId           the model identifier; must not be {@code null}
 * @param topNDocumentsOnly optional value written as {@code top_n}; omitted when {@code null}
 */
public record ElasticInferenceServiceRerankRequestEntity(
    String query,
    List<String> documents,
    String modelId,
    @Nullable Integer topNDocumentsOnly
) implements ToXContentObject {

    private static final String QUERY_FIELD = "query";
    private static final String MODEL_FIELD = "model";
    private static final String TOP_N_DOCUMENTS_ONLY_FIELD = "top_n";
    private static final String DOCUMENTS_FIELD = "documents";

    /** Rejects {@code null} for every mandatory component. */
    public ElasticInferenceServiceRerankRequestEntity {
        Objects.requireNonNull(query);
        Objects.requireNonNull(documents);
        Objects.requireNonNull(modelId);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject().field(QUERY_FIELD, query).field(MODEL_FIELD, modelId);

        // top_n is optional and left out of the body entirely when not configured
        if (topNDocumentsOnly != null) {
            builder.field(TOP_N_DOCUMENTS_ONLY_FIELD, topNDocumentsOnly);
        }

        builder.startArray(DOCUMENTS_FIELD);
        for (String doc : documents) {
            builder.value(doc);
        }

        return builder.endArray().endObject();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.response.elastic;

import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;

import java.io.IOException;
import java.util.List;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;

/**
 * Parses a rerank HTTP response into {@link RankedDocsResults}.
 * <p>
 * The response body is read as a JSON object with a {@code results} array whose entries each carry
 * an {@code index} and a {@code relevance_score}.
 */
public class ElasticInferenceServiceRerankResponseEntity {

    private ElasticInferenceServiceRerankResponseEntity() {}

    /**
     * Converts the body of {@code response} into ranked-doc results, one ranked doc per parsed entry.
     *
     * @param response the HTTP result whose body holds the JSON rerank response
     * @return the parsed ranked documents
     * @throws IOException if the parser cannot be created or closed
     */
    public static InferenceServiceResults fromResponse(HttpResult response) throws IOException {
        var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);

        try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
            RerankResult parsed = RerankResult.PARSER.apply(parser, null);
            List<RankedDocsResults.RankedDoc> rankedDocs = parsed.entries().stream().map(entry -> entry.toRankedDoc()).toList();

            return new RankedDocsResults(rankedDocs);
        }
    }

    /** Top-level response object; its parser is created with the unknown-field flag set to {@code true}. */
    record RerankResult(List<RerankResultEntry> entries) {

        @SuppressWarnings("unchecked")
        public static final ConstructingObjectParser<RerankResult, Void> PARSER = new ConstructingObjectParser<>(
            RerankResult.class.getSimpleName(),
            true,
            constructorArgs -> new RerankResult((List<RerankResultEntry>) constructorArgs[0])
        );

        static {
            PARSER.declareObjectArray(constructorArg(), RerankResultEntry.PARSER::apply, new ParseField("results"));
        }

        /** A single element of the {@code results} array. */
        record RerankResultEntry(Integer index, Float relevanceScore) {

            public static final ConstructingObjectParser<RerankResultEntry, Void> PARSER = new ConstructingObjectParser<>(
                RerankResultEntry.class.getSimpleName(),
                constructorArgs -> new RerankResultEntry((Integer) constructorArgs[0], (Float) constructorArgs[1])
            );

            static {
                PARSER.declareInt(constructorArg(), new ParseField("index"));
                PARSER.declareFloat(constructorArg(), new ParseField("relevance_score"));
            }

            /** Maps this entry to a ranked doc; the third constructor argument is passed as {@code null}. */
            public RankedDocsResults.RankedDoc toRankedDoc() {
                return new RankedDocsResults.RankedDoc(index, relevanceScore, null);
            }
        }
    }
}
Loading