diff --git a/docs/changelog/116962.yaml b/docs/changelog/116962.yaml new file mode 100644 index 0000000000000..8f16b00e3f9fc --- /dev/null +++ b/docs/changelog/116962.yaml @@ -0,0 +1,5 @@ +pr: 116962 +summary: "Add special case for elastic reranker in inference API" +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 02bddb6076d69..2320cca8295d1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -63,6 +63,7 @@ import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettings; @@ -415,7 +416,13 @@ private static void addInternalNamedWriteables(List namedWriteables) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index fc070965f29c2..f743b94df3810 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -156,6 +156,8 @@ protected void putModel(Model model, ActionListener listener) { putBuiltInModel(e5Model.getServiceSettings().modelId(), listener); } else if (model instanceof ElserInternalModel elserModel) { putBuiltInModel(elserModel.getServiceSettings().modelId(), listener); + } else if (model instanceof ElasticRerankerModel elasticRerankerModel) { + putBuiltInModel(elasticRerankerModel.getServiceSettings().modelId(), listener); } else if (model instanceof CustomElandModel) { logger.info("Custom eland model detected, model must have been already loaded into the cluster with eland."); listener.onResponse(Boolean.TRUE); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java new file mode 100644 index 0000000000000..115cc9f05599a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elasticsearch; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +public class ElasticRerankerModel extends ElasticsearchInternalModel { + + public ElasticRerankerModel( + String inferenceEntityId, + TaskType taskType, + String service, + ElasticRerankerServiceSettings serviceSettings, + ChunkingSettings chunkingSettings + ) { + super(inferenceEntityId, taskType, service, serviceSettings, chunkingSettings); + } + + @Override + public ElasticRerankerServiceSettings getServiceSettings() { + return (ElasticRerankerServiceSettings) super.getServiceSettings(); + } + + @Override + public ActionListener getCreateTrainedModelAssignmentActionListener( + Model model, + ActionListener listener + ) { + + return new ActionListener<>() { + @Override + public void onResponse(CreateTrainedModelAssignmentAction.Response response) { + listener.onResponse(Boolean.TRUE); + } + + @Override + public void onFailure(Exception e) { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { + listener.onFailure( + new ResourceNotFoundException("Could not start the Elastic Reranker Endpoint due to [{}]", e, e.getMessage()) + ); + return; + } + listener.onFailure(e); + } + }; + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java new file mode 100644 index 0000000000000..316dc092e03c7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elasticsearch; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; + +import java.io.IOException; +import java.util.Map; + +public class ElasticRerankerServiceSettings extends ElasticsearchInternalServiceSettings { + + public static final String NAME = "elastic_reranker_service_settings"; + + public ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings other) { + super(other); + } + + public ElasticRerankerServiceSettings( + Integer numAllocations, + int numThreads, + String modelId, + AdaptiveAllocationsSettings adaptiveAllocationsSettings + ) { + super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings); + } + + public ElasticRerankerServiceSettings(StreamInput in) throws IOException { + super(in); + } + + /** + * Parse the ElasticRerankerServiceSettings from map and validate the setting values. + * + * If required setting are missing or the values are invalid an + * {@link ValidationException} is thrown. + * + * @param map Source map containing the config + * @return The builder + */ + public static Builder fromRequestMap(Map map) { + ValidationException validationException = new ValidationException(); + var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return baseSettings; + } + + @Override + public String getWriteableName() { + return ElasticRerankerServiceSettings.NAME; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index fe83acc8574aa..718aeae979fe9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -97,6 +97,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 ); + public static final String RERANKER_ID = ".rerank-v1"; + public static final int EMBEDDING_MAX_BATCH_SIZE = 10; public static final String DEFAULT_ELSER_ID = ".elser-2-elasticsearch"; public static final String DEFAULT_E5_ID = ".multilingual-e5-small-elasticsearch"; @@ -223,6 +225,8 @@ public void parseRequestConfig( ) ) ); + } else if (RERANKER_ID.equals(modelId)) { + rerankerCase(inferenceEntityId, taskType, config, serviceSettingsMap, chunkingSettings, modelListener); } else { customElandCase(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, modelListener); } @@ -323,6 +327,31 @@ private static CustomElandInternalServiceSettings elandServiceSettings( }; } + private void rerankerCase( + String inferenceEntityId, + TaskType taskType, + Map config, + Map serviceSettingsMap, + ChunkingSettings chunkingSettings, + ActionListener modelListener + ) { + + var esServiceSettingsBuilder = ElasticsearchInternalServiceSettings.fromRequestMap(serviceSettingsMap); + + throwIfNotEmptyMap(config, name()); + throwIfNotEmptyMap(serviceSettingsMap, name()); + + modelListener.onResponse( + new ElasticRerankerModel( + inferenceEntityId, + taskType, + NAME, + new ElasticRerankerServiceSettings(esServiceSettingsBuilder.build()), + chunkingSettings + ) + ); + } + private void e5Case( String inferenceEntityId, TaskType taskType,