Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
33b8110
Add "rerank" task type to "elastic" provider
timgrein Apr 1, 2025
efee615
Resolve merge conflicts
timgrein Apr 8, 2025
b4b1f05
Fix checkstyle violation
timgrein Apr 8, 2025
2d68ed7
Merge branch 'main' into elastic-provider-rerank-integration
timgrein May 23, 2025
8911f0a
Merge branch 'main' into elastic-provider-rerank-integration
timgrein May 28, 2025
595859f
Remove uri field and use getter
timgrein May 28, 2025
603b91b
Merge remote-tracking branch 'origin/elastic-provider-rerank-integrat…
timgrein May 28, 2025
504ed93
[CI] Auto commit changes from spotless
May 28, 2025
2b89009
Remove ElasticInferenceServiceRerankRequestManager
timgrein Jun 2, 2025
8389205
Merge remote-tracking branch 'origin/elastic-provider-rerank-integrat…
timgrein Jun 2, 2025
d5c2744
Spotless apply
timgrein Jun 2, 2025
f43417b
Use Strings.format(...)
timgrein Jun 2, 2025
3a447f1
Remove ElasticInferenceServiceRerankTaskSettings and override validat…
timgrein Jun 2, 2025
249d98a
Merge branch 'main' into elastic-provider-rerank-integration
timgrein Jun 2, 2025
2639a12
Merge branch 'main' into elastic-provider-rerank-integration
timgrein Jun 4, 2025
c0d8a30
Use Strings.format(...) for rerank action, too
timgrein Jun 4, 2025
d954f87
Merge remote-tracking branch 'origin/elastic-provider-rerank-integrat…
timgrein Jun 4, 2025
9e73b10
Add backport transport version
timgrein Jun 4, 2025
9fb2aa8
[CI] Auto commit changes from spotless
Jun 4, 2025
46fcaa9
Merge branch 'main' into elastic-provider-rerank-integration
timgrein Jun 4, 2025
a6a5d0a
Merge branch 'main' into elastic-provider-rerank-integration
timgrein Jun 5, 2025
2ea8cc5
Use Strings.format in ElasticInferenceServiceActionCreator
timgrein Jun 6, 2025
5271bf0
Merge remote-tracking branch 'origin/elastic-provider-rerank-integrat…
timgrein Jun 6, 2025
5d464ae
Merge branch 'main' into elastic-provider-rerank-integration
timgrein Jun 6, 2025
67b1c70
Merge branch 'main' into elastic-provider-rerank-integration
timgrein Jun 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ static TransportVersion def(int id) {
public static final TransportVersion INFERENCE_CUSTOM_SERVICE_ADDED = def(9_084_0_00);
public static final TransportVersion ESQL_LIMIT_ROW_SIZE = def(9_085_0_00);
public static final TransportVersion ESQL_REGEX_MATCH_WITH_CASE_INSENSITIVITY = def(9_086_0_00);
public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK = def(9_087_0_00);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're also targeting 8.19, we need another transport version place holder for the 8.19 branch (but in this PR).

For example I recently added one public static final TransportVersion INFERENCE_CUSTOM_SERVICE_ADDED_8_19 = def(8_841_0_39); in the main branch even though it isn't used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
Expand Down Expand Up @@ -165,7 +166,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addAnthropicNamedWritables(namedWriteables);
addAmazonBedrockNamedWriteables(namedWriteables);
addAwsNamedWriteables(namedWriteables);
addEisNamedWriteables(namedWriteables);
addElasticNamedWriteables(namedWriteables);
addAlibabaCloudSearchNamedWriteables(namedWriteables);
addJinaAINamedWriteables(namedWriteables);
addVoyageAINamedWriteables(namedWriteables);
Expand Down Expand Up @@ -734,20 +735,32 @@ private static void addVoyageAINamedWriteables(List<NamedWriteableRegistry.Entry
);
}

private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
private static void addElasticNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
// Sparse Text Embeddings
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
ElasticInferenceServiceSparseEmbeddingsServiceSettings.NAME,
ElasticInferenceServiceSparseEmbeddingsServiceSettings::new
)
);

// Completion
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
ElasticInferenceServiceCompletionServiceSettings.NAME,
ElasticInferenceServiceCompletionServiceSettings::new
)
);

// Rerank
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
ElasticInferenceServiceRerankServiceSettings.NAME,
ElasticInferenceServiceRerankServiceSettings::new
)
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic.rerank;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.message.BasicHeader;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequest;
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestMetadata;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;

public class ElasticInferenceServiceRerankRequest extends ElasticInferenceServiceRequest {

private final String query;
private final List<String> documents;
private final Integer topN;
private final TraceContextHandler traceContextHandler;
private final ElasticInferenceServiceRerankModel model;

public ElasticInferenceServiceRerankRequest(
String query,
List<String> documents,
Integer topN,
ElasticInferenceServiceRerankModel model,
TraceContext traceContext,
ElasticInferenceServiceRequestMetadata metadata
) {
super(metadata);
this.query = query;
this.documents = documents;
this.topN = topN;
this.model = Objects.requireNonNull(model);
this.traceContextHandler = new TraceContextHandler(traceContext);
}

@Override
public HttpRequestBase createHttpRequestBase() {
var httpPost = new HttpPost(getURI());
var requestEntity = Strings.toString(
new ElasticInferenceServiceRerankRequestEntity(query, documents, model.getServiceSettings().modelId(), topN)
);

ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
httpPost.setEntity(byteEntity);

traceContextHandler.propagateTraceContext(httpPost);
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));

return httpPost;
}

public TraceContext getTraceContext() {
return traceContextHandler.traceContext();
}

@Override
public String getInferenceEntityId() {
return model.getInferenceEntityId();
}

@Override
public URI getURI() {
return model.uri();
}

@Override
public Request truncate() {
// no truncation
return this;
}

@Override
public boolean[] getTruncationInfo() {
// no truncation
return null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic.rerank;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

public record ElasticInferenceServiceRerankRequestEntity(
String query,
List<String> documents,
String modelId,
@Nullable Integer topNDocumentsOnly
) implements ToXContentObject {

private static final String QUERY_FIELD = "query";
private static final String MODEL_FIELD = "model";
private static final String TOP_N_DOCUMENTS_ONLY_FIELD = "top_n";
private static final String DOCUMENTS_FIELD = "documents";

public ElasticInferenceServiceRerankRequestEntity {
Objects.requireNonNull(query);
Objects.requireNonNull(documents);
Objects.requireNonNull(modelId);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();

builder.field(QUERY_FIELD, query);

builder.field(MODEL_FIELD, modelId);

if (Objects.nonNull(topNDocumentsOnly)) {
builder.field(TOP_N_DOCUMENTS_ONLY_FIELD, topNDocumentsOnly);
}

builder.startArray(DOCUMENTS_FIELD);
for (String document : documents) {
builder.value(document);
}

builder.endArray();

builder.endObject();

return builder;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.response.elastic;

import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;

import java.io.IOException;
import java.util.List;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;

public class ElasticInferenceServiceRerankResponseEntity {

record RerankResult(List<RerankResultEntry> entries) {

@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<RerankResult, Void> PARSER = new ConstructingObjectParser<>(
RerankResult.class.getSimpleName(),
true,
args -> new RerankResult((List<RerankResultEntry>) args[0])
);

static {
PARSER.declareObjectArray(constructorArg(), RerankResultEntry.PARSER::apply, new ParseField("results"));
}

record RerankResultEntry(Integer index, Float relevanceScore) {

public static final ConstructingObjectParser<RerankResultEntry, Void> PARSER = new ConstructingObjectParser<>(
RerankResultEntry.class.getSimpleName(),
args -> new RerankResultEntry((Integer) args[0], (Float) args[1])
);

static {
PARSER.declareInt(constructorArg(), new ParseField("index"));
PARSER.declareFloat(constructorArg(), new ParseField("relevance_score"));
}

public RankedDocsResults.RankedDoc toRankedDoc() {
return new RankedDocsResults.RankedDoc(index, relevanceScore, null);
}
}
}

public static InferenceServiceResults fromResponse(HttpResult response) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);

try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
var rerankResult = RerankResult.PARSER.apply(jsonParser, null);

return new RankedDocsResults(rerankResult.entries.stream().map(RerankResult.RerankResultEntry::toRankedDoc).toList());
}
}

private ElasticInferenceServiceRerankResponseEntity() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
Expand Down Expand Up @@ -79,7 +80,11 @@ public class ElasticInferenceService extends SenderService {
public static final String NAME = "elastic";
public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service";

private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION);
private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(
TaskType.SPARSE_EMBEDDING,
TaskType.CHAT_COMPLETION,
TaskType.RERANK
);
private static final String SERVICE_NAME = "Elastic";

// rainbow-sprinkles
Expand All @@ -93,7 +98,7 @@ public class ElasticInferenceService extends SenderService {
/**
* The task types that the {@link InferenceAction.Request} can accept.
*/
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING);
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.RERANK);

public static String defaultEndpointId(String modelId) {
return Strings.format(".%s-elastic", modelId);
Expand Down Expand Up @@ -163,6 +168,18 @@ public void onNodeStarted() {
authorizationHandler.init();
}

@Override
protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) {
if (returnDocuments != null) {
validationException.addValidationError(
org.elasticsearch.core.Strings.format(
"Invalid return_documents [%s]. The return_documents option is not supported by this service",
returnDocuments
)
);
}
}

/**
* Only use this in tests.
*
Expand Down Expand Up @@ -335,7 +352,7 @@ private static ElasticInferenceServiceModel createModel(
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
@Nullable Map<String, Object> secretSettings,
ElasticInferenceServiceComponents eisServiceComponents,
ElasticInferenceServiceComponents elasticInferenceServiceComponents,
String failureMessage,
ConfigurationParseContext context
) {
Expand All @@ -347,7 +364,7 @@ private static ElasticInferenceServiceModel createModel(
serviceSettings,
taskSettings,
secretSettings,
eisServiceComponents,
elasticInferenceServiceComponents,
context
);
case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel(
Expand All @@ -357,7 +374,17 @@ private static ElasticInferenceServiceModel createModel(
serviceSettings,
taskSettings,
secretSettings,
eisServiceComponents,
elasticInferenceServiceComponents,
context
);
case RERANK -> new ElasticInferenceServiceRerankModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
taskSettings,
secretSettings,
elasticInferenceServiceComponents,
context
);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
Expand Down Expand Up @@ -462,9 +489,8 @@ private LazyInitializable<InferenceServiceConfiguration, RuntimeException> initC

configurationMap.put(
MODEL_ID,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription(
"The name of the model to use for the inference task."
)
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK))
.setDescription("The name of the model to use for the inference task.")
.setLabel("Model ID")
.setRequired(true)
.setSensitive(false)
Expand All @@ -487,7 +513,9 @@ private LazyInitializable<InferenceServiceConfiguration, RuntimeException> initC
);

configurationMap.putAll(
RateLimitSettings.toSettingsConfiguration(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))
RateLimitSettings.toSettingsConfiguration(
EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK)
)
);

return new InferenceServiceConfiguration.Builder().setService(NAME)
Expand Down
Loading