Skip to content
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
33b8110
Add "rerank" task type to "elastic" provider
timgrein Apr 1, 2025
efee615
Resolve merge conflicts
timgrein Apr 8, 2025
b4b1f05
Fix checkstyle violation
timgrein Apr 8, 2025
2d68ed7
Merge branch 'main' into elastic-provider-rerank-integration
timgrein May 23, 2025
8911f0a
Merge branch 'main' into elastic-provider-rerank-integration
timgrein May 28, 2025
595859f
Remove uri field and use getter
timgrein May 28, 2025
603b91b
Merge remote-tracking branch 'origin/elastic-provider-rerank-integrat…
timgrein May 28, 2025
504ed93
[CI] Auto commit changes from spotless
May 28, 2025
2b89009
Remove ElasticInferenceServiceRerankRequestManager
timgrein Jun 2, 2025
8389205
Merge remote-tracking branch 'origin/elastic-provider-rerank-integrat…
timgrein Jun 2, 2025
d5c2744
Spotless apply
timgrein Jun 2, 2025
f43417b
Use Strings.format(...)
timgrein Jun 2, 2025
3a447f1
Remove ElasticInferenceServiceRerankTaskSettings and override validat…
timgrein Jun 2, 2025
249d98a
Merge branch 'main' into elastic-provider-rerank-integration
timgrein Jun 2, 2025
2639a12
Merge branch 'main' into elastic-provider-rerank-integration
timgrein Jun 4, 2025
c0d8a30
Use Strings.format(...) for rerank action, too
timgrein Jun 4, 2025
d954f87
Merge remote-tracking branch 'origin/elastic-provider-rerank-integrat…
timgrein Jun 4, 2025
9e73b10
Add backport transport version
timgrein Jun 4, 2025
9fb2aa8
[CI] Auto commit changes from spotless
Jun 4, 2025
46fcaa9
Merge branch 'main' into elastic-provider-rerank-integration
timgrein Jun 4, 2025
a6a5d0a
Merge branch 'main' into elastic-provider-rerank-integration
timgrein Jun 5, 2025
2ea8cc5
Use Strings.format in ElasticInferenceServiceActionCreator
timgrein Jun 6, 2025
5271bf0
Merge remote-tracking branch 'origin/elastic-provider-rerank-integrat…
timgrein Jun 6, 2025
5d464ae
Merge branch 'main' into elastic-provider-rerank-integration
timgrein Jun 6, 2025
67b1c70
Merge branch 'main' into elastic-provider-rerank-integration
timgrein Jun 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ILM_ADD_SKIP_SETTING_8_19 = def(8_841_0_43);
public static final TransportVersion ESQL_REGEX_MATCH_WITH_CASE_INSENSITIVITY_8_19 = def(8_841_0_44);
public static final TransportVersion ESQL_QUERY_PLANNING_DURATION_8_19 = def(8_841_0_45);
public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK_ADDED_8_19 = def(8_841_0_46);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -286,6 +287,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ILM_ADD_SKIP_SETTING = def(9_089_0_00);
public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED = def(9_090_0_00);
public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST = def(9_091_0_00);
public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK = def(9_092_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
Expand Down Expand Up @@ -166,7 +167,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addAnthropicNamedWritables(namedWriteables);
addAmazonBedrockNamedWriteables(namedWriteables);
addAwsNamedWriteables(namedWriteables);
addEisNamedWriteables(namedWriteables);
addElasticNamedWriteables(namedWriteables);
addAlibabaCloudSearchNamedWriteables(namedWriteables);
addJinaAINamedWriteables(namedWriteables);
addVoyageAINamedWriteables(namedWriteables);
Expand Down Expand Up @@ -742,20 +743,32 @@ private static void addVoyageAINamedWriteables(List<NamedWriteableRegistry.Entry
);
}

private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
private static void addElasticNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
// Sparse Text Embeddings
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
ElasticInferenceServiceSparseEmbeddingsServiceSettings.NAME,
ElasticInferenceServiceSparseEmbeddingsServiceSettings::new
)
);

// Completion
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
ElasticInferenceServiceCompletionServiceSettings.NAME,
ElasticInferenceServiceCompletionServiceSettings::new
)
);

// Rerank
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
ElasticInferenceServiceRerankServiceSettings.NAME,
ElasticInferenceServiceRerankServiceSettings::new
)
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic.rerank;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.message.BasicHeader;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequest;
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestMetadata;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;

public class ElasticInferenceServiceRerankRequest extends ElasticInferenceServiceRequest {

private final String query;
private final List<String> documents;
private final Integer topN;
private final TraceContextHandler traceContextHandler;
private final ElasticInferenceServiceRerankModel model;

public ElasticInferenceServiceRerankRequest(
String query,
List<String> documents,
Integer topN,
ElasticInferenceServiceRerankModel model,
TraceContext traceContext,
ElasticInferenceServiceRequestMetadata metadata
) {
super(metadata);
this.query = query;
this.documents = documents;
this.topN = topN;
this.model = Objects.requireNonNull(model);
this.traceContextHandler = new TraceContextHandler(traceContext);
}

@Override
public HttpRequestBase createHttpRequestBase() {
var httpPost = new HttpPost(getURI());
var requestEntity = Strings.toString(
new ElasticInferenceServiceRerankRequestEntity(query, documents, model.getServiceSettings().modelId(), topN)
);

ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
httpPost.setEntity(byteEntity);

traceContextHandler.propagateTraceContext(httpPost);
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));

return httpPost;
}

public TraceContext getTraceContext() {
return traceContextHandler.traceContext();
}

@Override
public String getInferenceEntityId() {
return model.getInferenceEntityId();
}

@Override
public URI getURI() {
return model.uri();
}

@Override
public Request truncate() {
// no truncation
return this;
}

@Override
public boolean[] getTruncationInfo() {
// no truncation
return null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic.rerank;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

public record ElasticInferenceServiceRerankRequestEntity(
String query,
List<String> documents,
String modelId,
@Nullable Integer topNDocumentsOnly
) implements ToXContentObject {

private static final String QUERY_FIELD = "query";
private static final String MODEL_FIELD = "model";
private static final String TOP_N_DOCUMENTS_ONLY_FIELD = "top_n";
private static final String DOCUMENTS_FIELD = "documents";

public ElasticInferenceServiceRerankRequestEntity {
Objects.requireNonNull(query);
Objects.requireNonNull(documents);
Objects.requireNonNull(modelId);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();

builder.field(QUERY_FIELD, query);

builder.field(MODEL_FIELD, modelId);

if (Objects.nonNull(topNDocumentsOnly)) {
builder.field(TOP_N_DOCUMENTS_ONLY_FIELD, topNDocumentsOnly);
}

builder.startArray(DOCUMENTS_FIELD);
for (String document : documents) {
builder.value(document);
}

builder.endArray();

builder.endObject();

return builder;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.response.elastic;

import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;

import java.io.IOException;
import java.util.List;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;

public class ElasticInferenceServiceRerankResponseEntity {

record RerankResult(List<RerankResultEntry> entries) {

@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<RerankResult, Void> PARSER = new ConstructingObjectParser<>(
RerankResult.class.getSimpleName(),
true,
args -> new RerankResult((List<RerankResultEntry>) args[0])
);

static {
PARSER.declareObjectArray(constructorArg(), RerankResultEntry.PARSER::apply, new ParseField("results"));
}

record RerankResultEntry(Integer index, Float relevanceScore) {

public static final ConstructingObjectParser<RerankResultEntry, Void> PARSER = new ConstructingObjectParser<>(
RerankResultEntry.class.getSimpleName(),
args -> new RerankResultEntry((Integer) args[0], (Float) args[1])
);

static {
PARSER.declareInt(constructorArg(), new ParseField("index"));
PARSER.declareFloat(constructorArg(), new ParseField("relevance_score"));
}

public RankedDocsResults.RankedDoc toRankedDoc() {
return new RankedDocsResults.RankedDoc(index, relevanceScore, null);
}
}
}

public static InferenceServiceResults fromResponse(HttpResult response) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);

try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
var rerankResult = RerankResult.PARSER.apply(jsonParser, null);

return new RankedDocsResults(rerankResult.entries.stream().map(RerankResult.RerankResultEntry::toRankedDoc).toList());
}
}

private ElasticInferenceServiceRerankResponseEntity() {}
}
Loading