From 72689bf01e3e13ec776c4a5ed5e2506b8e305631 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 28 Aug 2024 08:53:37 +0100 Subject: [PATCH 1/3] Support sparse model in elasticsearch service --- .../inference/service-elasticsearch.asciidoc | 3 +- .../xpack/inference/CustomElandModelIT.java | 135 +++++++++ .../xpack/inference/RerankingIT.java | 8 +- .../BaseElasticsearchInternalService.java | 6 +- .../ElasticsearchInternalService.java | 151 +++------- .../services/elser/ElserInternalService.java | 28 -- .../ElasticsearchInternalServiceTests.java | 280 +++++++++++------- 7 files changed, 360 insertions(+), 251 deletions(-) create mode 100644 x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java diff --git a/docs/reference/inference/service-elasticsearch.asciidoc b/docs/reference/inference/service-elasticsearch.asciidoc index 99fd41ee2db65..572cad591fba6 100644 --- a/docs/reference/inference/service-elasticsearch.asciidoc +++ b/docs/reference/inference/service-elasticsearch.asciidoc @@ -31,6 +31,7 @@ include::inference-shared.asciidoc[tag=task-type] Available task types: * `rerank`, +* `sparse_embedding`, * `text_embedding`. -- @@ -182,4 +183,4 @@ PUT _inference/text_embedding/my-e5-model } } ------------------------------------------------------------ -// TEST[skip:TBD] \ No newline at end of file +// TEST[skip:TBD] diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java new file mode 100644 index 0000000000000..53fa28e971774 --- /dev/null +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java @@ -0,0 +1,135 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.core.Strings; +import org.elasticsearch.inference.TaskType; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.stream.Collectors; + +public class CustomElandModelIT extends InferenceBaseRestTest { + + // The model definition is taken from org.elasticsearch.xpack.ml.integration.TextExpansionQueryIT + + static final String BASE_64_ENCODED_MODEL = "UEsDBAAACAgAAAAAAAAAAAAAAAAAA" + + "AAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwpUaW55VG" + + "V4dEV4cGFuc2lvbgpxACmBfShYCAAAAHRyYWluaW5ncQGJWBYAAABfaXNfZnVsbF9iYWNrd2FyZF9ob29" + + "rcQJOdWJxAy5QSwcIITmbsFgAAABYAAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAAAAAAAdAB0Ac2ltcGxl" + + "bW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQhkAWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWoWRT4+cMAzF7" + + "/spfASJomF3e0Ga3nrrn8vcELIyxAzRhAQlpjvbT19DWDrdquqBA/bvPT87nVUxwsm41xPd+PNtUi4a77" + + "KvXs+W8voBAHFSQY3EFCIiHKFp1+p57vs/ShyUccZdoIaz93aBTMR+thbPqru+qKBx8P4q/e8TyxRlmwVc" + + "tJp66H1YmCyS7WsZwD50A2L5V7pCBADGTTOj0bGGE7noQyqzv5JDfp0o9fZRCWqP37yjhE4+mqX5X3AdF" + + "ZHGM/2TzOHDpy1IvQWR+OWo3KwsRiKdpcqg4pBFDtm+QJ7nqwIPckrlnGfFJG0uNhOl38Sjut3pCqg26Qu" + + "Zy8BR9In7ScHHrKkKMW0TIucFrGQXCMpdaDO05O6DpOiy8e4kr0Ed/2YKOIhplW8gPr4ntygrd9ixpx3j9" + + "UZZVRagl2c6+imWUzBjuf5m+Ch7afphuvvW+r/0dsfn+2N9MZGb9+/SFtCYdhd83CMYp+mGy0LiKNs8y/e" + + "UuEA8B/d2z4dfUEsHCFSE3IaCAQAAIAMAAFBLAwQUAAgICAAAAAAAAAAAAAAAAAAAAAAAJwApAHNpbXBsZ" + + "W1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbEZCJQBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlp" + + "aWlpaWlpaWlpaWlpahZHLbtNAFIZtp03rSVIuLRKXjdk5ojitKJsiFq24lem0KKSqpRIZt55gE9/GM+lNL" + + "Fgx4i1Ys2aHhIBXgAVICNggHgNm6rqJN2BZGv36/v/MOWeea/Z5RVHurLfRUsfZXOnccx522itrd53O0vL" + + "qbaKYtsAKUe1pcege7hm9JNtzM8+kOOzNApIX0A3xBXE6YE7g0UWjg2OaZAJXbKvALOnj2GEHKc496ykLkt" + + "gNt3Jz17hprCUxFqExe7YIpQkNpO1/kfHhPUdtUAdH2/gfmeYiIFW7IkM6IBP2wrDNbMe3Mjf2ksiK3Hjg" + + "hg7F2DN9l/omZZl5Mmez2QRk0q4WUUB0+1oh9nDwxGdUXJdXPMRZQs352eGaRPV9s2lcMeZFGWBfKJJiw0Y" + + "gbCMLBaRmXyy4flx6a667Fch55q05QOq2Jg2ANOyZwplhNsjiohVApo7aa21QnNGW5+4GXv8gxK1beBeHSR" + + "rhmLXWVh+0aBhErZ7bx1ejxMOhlR6QU4ycNqGyk8/yNGCWkwY7/RCD7UEQek4QszCgDJAzZtfErA0VqHBy9" + + "ugQP9pUfUmgCjVYgWNwHFbhBJyEOgSwBuuwARWZmoI6J9PwLfzEocpRpPrT8DP8wqHG0b4UX+E3DiscvRgl" + + "XIoi81KKPwioHI5x9EooNKWiy0KOc/T6WF4SssrRuzJ9L2VNRXUhJzj6UKYfS4W/q/5wuh/l4M9R9qsU+y2" + + "dpoo2hJzkaEET8r6KRONicnRdK9EbUi6raFVIwNGjsrlbpk6ZPi7TbS3fv3LyNjPiEKzG0aG0tvNb6xw90/" + + "whe6ONjnJcUxobHDUqQ8bIOW79BVBLBwhfSmPKdAIAAE4EAABQSwMEAAAICAAAAAAAAAAAAAAAAAAAAAAAA" + + "BkABQBzaW1wbGVtb2RlbC9jb25zdGFudHMucGtsRkIBAFqAAikuUEsHCG0vCVcEAAAABAAAAFBLAwQAAAgI" + + "AAAAAAAAAAAAAAAAAAAAAAAAEwA7AHNpbXBsZW1vZGVsL3ZlcnNpb25GQjcAWlpaWlpaWlpaWlpaWlpaWlp" + + "aWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWjMKUEsHCNGeZ1UCAAAAAgAAAFBLAQIAAA" + + "AACAgAAAAAAAAhOZuwWAAAAFgAAAAUAAAAAAAAAAAAAAAAAAAAAABzaW1wbGVtb2RlbC9kYXRhLnBrbFBLA" + + "QIAABQACAgIAAAAAABUhNyGggEAACADAAAdAAAAAAAAAAAAAAAAAKgAAABzaW1wbGVtb2RlbC9jb2RlL19f" + + "dG9yY2hfXy5weVBLAQIAABQACAgIAAAAAABfSmPKdAIAAE4EAAAnAAAAAAAAAAAAAAAAAJICAABzaW1wbGVt" + + "b2RlbC9jb2RlL19fdG9yY2hfXy5weS5kZWJ1Z19wa2xQSwECAAAAAAgIAAAAAAAAbS8JVwQAAAAEAAAAGQAA" + + "AAAAAAAAAAAAAACEBQAAc2ltcGxlbW9kZWwvY29uc3RhbnRzLnBrbFBLAQIAAAAACAgAAAAAAADRnmdVAgAA" + + "AAIAAAATAAAAAAAAAAAAAAAAANQFAABzaW1wbGVtb2RlbC92ZXJzaW9uUEsGBiwAAAAAAAAAHgMtAAAAAAAA" + + 
"AAAABQAAAAAAAAAFAAAAAAAAAGoBAAAAAAAAUgYAAAAAAABQSwYHAAAAALwHAAAAAAAAAQAAAFBLBQYAAAAABQAFAGoBAABSBgAAAAA="; + + static final long RAW_MODEL_SIZE; // size of the model before base64 encoding + static { + RAW_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_MODEL).length; + } + + // Test a sparse embedding model deployed with the ml trained models APIs + public void testSparse() throws IOException { + String modelId = "custom-text-expansion-model"; + + createTextExpansionModel(modelId); + putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE); + putVocabulary( + List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"), + modelId + ); + + var inferenceConfig = """ + { + "service": "elasticsearch", + "service_settings": { + "model_id": "custom-text-expansion-model", + "num_allocations": 1, + "num_threads": 1 + } + } + """; + + var inferenceId = "sparse-inf"; + putModel(inferenceId, inferenceConfig, TaskType.SPARSE_EMBEDDING); + var results = inferOnMockService(inferenceId, List.of("washing", "machine")); + deleteModel(inferenceId); + assertNotNull(results.get("sparse_embedding")); + } + + protected void createTextExpansionModel(String modelId) throws IOException { + // with_special_tokens: false for this test with limited vocab + Request request = new Request("PUT", "/_ml/trained_models/" + modelId); + request.setJsonEntity(""" + { + "description": "a text expansion model", + "model_type": "pytorch", + "inference_config": { + "text_expansion": { + "tokenization": { + "bert": { + "with_special_tokens": false + } + } + } + } + }"""); + client().performRequest(request); + } + + protected void putVocabulary(List vocabulary, String modelId) throws IOException { + List vocabularyWithPad = new ArrayList<>(); + vocabularyWithPad.add("[PAD]"); + vocabularyWithPad.add("[UNK]"); + vocabularyWithPad.addAll(vocabulary); + String quotedWords = vocabularyWithPad.stream().map(s -> "\"" + s + "\"").collect(Collectors.joining(",")); + + Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/vocabulary"); + request.setJsonEntity(Strings.format(""" + { "vocabulary": [%s] } + """, quotedWords)); + client().performRequest(request); + } + + protected void putModelDefinition(String modelId, String base64EncodedModel, long unencodedModelSize) throws IOException { + Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/definition/0"); + String body = Strings.format(""" + {"total_definition_length":%s,"definition": "%s","total_parts": 1}""", unencodedModelSize, base64EncodedModel); + request.setJsonEntity(body); + client().performRequest(request); + } +} diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/RerankingIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/RerankingIT.java index 77251ada4c488..893d3fb3e9b80 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/RerankingIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/RerankingIT.java @@ -35,7 +35,7 @@ private String putCohereRerankEndpoint() throws IOException { "api_key": "" } } - """);// TODO remove key + """); return endpointID; } @@ -61,7 +61,7 @@ private String putCohereRerankEndpointWithDocuments() throws IOException { "return_documents": true } } - """);// TODO remove key 
+ """); return endpointID; } @@ -81,13 +81,13 @@ private String putCohereRerankEndpointWithTop2() throws IOException { "service": "cohere", "service_settings": { "model_id": "rerank-english-v2.0", - "api_key": "8TNPBvpBO7oN97009HQHzQbBhNrxmREbcJrZCwkK" + "api_key": "" }, "task_settings": { "top_n": 2 } } - """);// TODO remove key + """); return endpointID; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index 574ca77d4587e..457416370e559 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -154,10 +154,10 @@ public void isModelDownloaded(Model model, ActionListener listener) { executeAsyncWithOrigin(client, INFERENCE_ORIGIN, GetTrainedModelsAction.INSTANCE, getRequest, getModelsResponseListener); } else { listener.onFailure( - new IllegalArgumentException( - "Unable to determine supported model for [" + new IllegalStateException( + "Can not check the download status of the model used by [" + model.getConfigurations().getInferenceEntityId() - + "] please verify the request and submit a bug report if necessary." + + "] as the model_id cannot be found." ) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index c3a0111562319..cca8ae63e974c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -7,8 +7,6 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; @@ -27,19 +25,18 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; -import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; -import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; -import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import 
org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; @@ -53,8 +50,6 @@ import java.util.Set; import java.util.function.Function; -import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; -import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -71,15 +66,13 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 ); - private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class); - public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) { super(context); } @Override protected EnumSet supportedTaskTypes() { - return EnumSet.of(TaskType.RERANK, TaskType.TEXT_EMBEDDING); + return EnumSet.of(TaskType.RERANK, TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING); } @Override @@ -161,6 +154,12 @@ private static CustomElandModel createCustomElandModel( NAME, CustomElandInternalTextEmbeddingServiceSettings.fromMap(serviceSettings, context) ); + case SPARSE_EMBEDDING -> new CustomElandModel( + inferenceEntityId, + taskType, + NAME, + elandServiceSettings(serviceSettings, context) + ); case RERANK -> new CustomElandRerankModel( inferenceEntityId, taskType, @@ -334,6 +333,8 @@ public void infer( inferTextEmbedding(model, input, inputType, timeout, listener); } else if (TaskType.RERANK.equals(taskType)) { inferRerank(model, query, input, inputType, timeout, taskSettings, listener); + } else if (TaskType.SPARSE_EMBEDDING.equals(taskType)) { + inferSparseEmbedding(model, input, inputType, timeout, listener); } else { throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); } @@ -364,6 +365,31 @@ public void inferTextEmbedding( ); } + public void inferSparseEmbedding( + Model model, + List inputs, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + var request = buildInferenceRequest( + model.getConfigurations().getInferenceEntityId(), + TextExpansionConfigUpdate.EMPTY_UPDATE, + inputs, + inputType, + timeout, + false + ); + + client.execute( + InferModelAction.INSTANCE, + request, + listener.delegateFailureAndWrap( + (l, inferenceResult) -> l.onResponse(SparseEmbeddingResults.of(inferenceResult.getInferenceResults())) + ) + ); + } + public void inferRerank( Model model, String query, @@ -422,7 +448,7 @@ public void chunkedInfer( TimeValue timeout, ActionListener> listener ) { - if (TaskType.TEXT_EMBEDDING.isAnyOrSame(model.getTaskType()) == false) { + if 
((TaskType.TEXT_EMBEDDING.equals(model.getTaskType()) || TaskType.SPARSE_EMBEDDING.equals(model.getTaskType())) == false) { listener.onFailure( new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(model.getTaskType(), NAME), RestStatus.BAD_REQUEST) ); @@ -464,6 +490,8 @@ private static List translateToChunkedResults(Li private static ChunkedInferenceServiceResults translateToChunkedResult(InferenceResults inferenceResult) { if (inferenceResult instanceof MlChunkedTextEmbeddingFloatResults mlChunkedResult) { return InferenceChunkedTextEmbeddingFloatResults.ofMlResults(mlChunkedResult); + } else if (inferenceResult instanceof MlChunkedTextExpansionResults mlChunkedResult) { + return InferenceChunkedSparseEmbeddingResults.ofMlResult(mlChunkedResult); } else if (inferenceResult instanceof ErrorInferenceResults error) { return new ErrorChunkedInferenceResults(error.getException()); } else { @@ -471,103 +499,6 @@ private static ChunkedInferenceServiceResults translateToChunkedResult(Inference } } - @Override - public void start(Model model, ActionListener listener) { - if (model instanceof ElasticsearchInternalModel == false) { - listener.onFailure(notElasticsearchModelException(model)); - return; - } - - if (model.getTaskType() != TaskType.TEXT_EMBEDDING && model.getTaskType() != TaskType.RERANK) { - listener.onFailure( - new IllegalStateException(TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), NAME)) - ); - return; - } - - var startRequest = ((ElasticsearchInternalModel) model).getStartTrainedModelDeploymentActionRequest(); - var responseListener = ((ElasticsearchInternalModel) model).getCreateTrainedModelAssignmentActionListener(model, listener); - - client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener); - } - - @Override - public void stop(String inferenceEntityId, ActionListener listener) { - var request = new StopTrainedModelDeploymentAction.Request(inferenceEntityId); - request.setForce(true); - client.execute( - StopTrainedModelDeploymentAction.INSTANCE, - request, - listener.delegateFailureAndWrap((delegatedResponseListener, response) -> delegatedResponseListener.onResponse(Boolean.TRUE)) - ); - } - - @Override - public void putModel(Model model, ActionListener listener) { - if (model instanceof ElasticsearchInternalModel == false) { - listener.onFailure(notElasticsearchModelException(model)); - return; - } else if (model instanceof MultilingualE5SmallModel e5Model) { - String modelId = e5Model.getServiceSettings().modelId(); - var input = new TrainedModelInput(List.of("text_field")); // by convention text_field is used - var config = TrainedModelConfig.builder().setInput(input).setModelId(modelId).validate(true).build(); - PutTrainedModelAction.Request putRequest = new PutTrainedModelAction.Request(config, false, true); - executeAsyncWithOrigin( - client, - INFERENCE_ORIGIN, - PutTrainedModelAction.INSTANCE, - putRequest, - ActionListener.wrap(response -> listener.onResponse(Boolean.TRUE), e -> { - if (e instanceof ElasticsearchStatusException esException - && esException.getMessage().contains(PutTrainedModelAction.MODEL_ALREADY_EXISTS_ERROR_MESSAGE_FRAGMENT)) { - listener.onResponse(Boolean.TRUE); - } else { - listener.onFailure(e); - } - }) - ); - } else if (model instanceof CustomElandModel) { - logger.info("Custom eland model detected, model must have been already loaded into the cluster with eland."); - listener.onResponse(Boolean.TRUE); - } else { - listener.onFailure( - new 
IllegalArgumentException( - "Can not download model automatically for [" - + model.getConfigurations().getInferenceEntityId() - + "] you may need to download it through the trained models API or with eland." - ) - ); - return; - } - } - - @Override - public void isModelDownloaded(Model model, ActionListener listener) { - ActionListener getModelsResponseListener = listener.delegateFailure((delegate, response) -> { - if (response.getResources().count() < 1) { - delegate.onResponse(Boolean.FALSE); - } else { - delegate.onResponse(Boolean.TRUE); - } - }); - - if (model.getServiceSettings() instanceof ElasticsearchInternalServiceSettings internalServiceSettings) { - String modelId = internalServiceSettings.modelId(); - GetTrainedModelsAction.Request getRequest = new GetTrainedModelsAction.Request(modelId); - executeAsyncWithOrigin(client, INFERENCE_ORIGIN, GetTrainedModelsAction.INSTANCE, getRequest, getModelsResponseListener); - } else if (model instanceof ElasticsearchInternalModel == false) { - listener.onFailure(notElasticsearchModelException(model)); - } else { - listener.onFailure( - new IllegalArgumentException( - "Unable to determine supported model for [" - + model.getConfigurations().getInferenceEntityId() - + "] please verify the request and submit a bug report if necessary." - ) - ); - } - } - @Override public TransportVersion getMinimalSupportedVersion() { return TransportVersions.V_8_14_0; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java index 775ddca160463..948117954a63f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java @@ -28,7 +28,6 @@ import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; @@ -43,8 +42,6 @@ import java.util.Map; import java.util.Set; -import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; -import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; import static org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL; @@ -242,31 +239,6 @@ private void checkCompatibleTaskType(TaskType taskType) { } } - @Override - public void isModelDownloaded(Model model, ActionListener listener) { - ActionListener getModelsResponseListener = listener.delegateFailure((delegate, response) -> { - if (response.getResources().count() < 1) { - delegate.onResponse(Boolean.FALSE); - } else { - delegate.onResponse(Boolean.TRUE); - } - }); - - if (model instanceof ElserInternalModel elserModel) { - String modelId = elserModel.getServiceSettings().modelId(); - 
GetTrainedModelsAction.Request getRequest = new GetTrainedModelsAction.Request(modelId); - executeAsyncWithOrigin(client, INFERENCE_ORIGIN, GetTrainedModelsAction.INSTANCE, getRequest, getModelsResponseListener); - } else { - listener.onFailure( - new IllegalArgumentException( - "Can not download model automatically for [" - + model.getConfigurations().getInferenceEntityId() - + "] you may need to download it through the trained models API or with eland." - ) - ); - } - } - private static ElserMlNodeTaskSettings taskSettingsFromMap(TaskType taskType, Map config) { if (taskType != TaskType.SPARSE_EMBEDDING) { throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index e6fd725a50198..df5cb5b1698e6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InputType; @@ -31,6 +32,7 @@ import org.elasticsearch.xpack.core.action.util.QueryPage; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; @@ -39,8 +41,10 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceChunkedTextExpansionResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; @@ -52,12 +56,10 @@ import org.mockito.Mockito; import java.util.ArrayList; -import java.util.Arrays; import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Random; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ 
-76,7 +78,6 @@ public class ElasticsearchInternalServiceTests extends ESTestCase { - TaskType taskType = TaskType.TEXT_EMBEDDING; String randomInferenceEntityId = randomAlphaOfLength(10); private static ThreadPool threadPool; @@ -92,7 +93,25 @@ public void shutdownThreadPool() { } public void testParseRequestConfig() { + var service = createService(mock(Client.class)); + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of(ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, 1, ElasticsearchInternalServiceSettings.NUM_THREADS, 4) + ) + ); + ActionListener modelListener = ActionListener.wrap( + model -> fail("Model parsing should have failed"), + e -> assertThat(e, instanceOf(IllegalArgumentException.class)) + ); + + var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); + service.parseRequestConfig(randomInferenceEntityId, taskType, settings, Set.of(), modelListener); + } + + public void testParseRequestConfig_Misconfigured() { // Null model variant { var service = createService(mock(Client.class)); @@ -109,43 +128,10 @@ public void testParseRequestConfig() { e -> assertThat(e, instanceOf(IllegalArgumentException.class)) ); + var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); service.parseRequestConfig(randomInferenceEntityId, taskType, settings, Set.of(), modelListener); } - // Valid model variant - { - var service = createService(mock(Client.class)); - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, - 1, - ElasticsearchInternalServiceSettings.NUM_THREADS, - 4, - ElasticsearchInternalServiceSettings.MODEL_ID, - ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID - ) - ) - ); - - var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings( - 1, - 4, - ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID, - null - ); - - service.parseRequestConfig( - randomInferenceEntityId, - taskType, - settings, - Set.of(), - getModelVerificationActionListener(e5ServiceSettings) - ); - } - // Invalid config map { var service = createService(mock(Client.class)); @@ -163,10 +149,12 @@ public void testParseRequestConfig() { e -> assertThat(e, instanceOf(ElasticsearchStatusException.class)) ); + var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); service.parseRequestConfig(randomInferenceEntityId, taskType, settings, Set.of(), modelListener); } + } - // Invalid service settings + public void testParseRequestConfig_E5() { { var service = createService(mock(Client.class)); var settings = new HashMap(); @@ -179,52 +167,28 @@ public void testParseRequestConfig() { ElasticsearchInternalServiceSettings.NUM_THREADS, 4, ElasticsearchInternalServiceSettings.MODEL_ID, - ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID, // we can't directly test the eland case until we mock - // the threadpool within the client - "not_a_valid_service_setting", - randomAlphaOfLength(10) + ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID ) ) ); - ActionListener modelListener = ActionListener.wrap( - model -> fail("Model parsing should have failed"), - e -> assertThat(e, instanceOf(ElasticsearchStatusException.class)) - ); - - service.parseRequestConfig(randomInferenceEntityId, taskType, settings, Set.of(), modelListener); - } - - // Extra service settings - { - var service = 
createService(mock(Client.class)); - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, - 1, - ElasticsearchInternalServiceSettings.NUM_THREADS, - 4, - ElasticsearchInternalServiceSettings.MODEL_ID, - ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID, // we can't directly test the eland case until we mock - // the threadpool within the client - "extra_setting_that_should_not_be_here", - randomAlphaOfLength(10) - ) - ) + var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings( + 1, + 4, + ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID, + null ); - ActionListener modelListener = ActionListener.wrap( - model -> fail("Model parsing should have failed"), - e -> assertThat(e, instanceOf(ElasticsearchStatusException.class)) + service.parseRequestConfig( + randomInferenceEntityId, + TaskType.TEXT_EMBEDDING, + settings, + Set.of(), + getModelVerificationActionListener(e5ServiceSettings) ); - - service.parseRequestConfig(randomInferenceEntityId, taskType, settings, Set.of(), modelListener); } - // Extra settings + // Invalid service settings { var service = createService(mock(Client.class)); var settings = new HashMap(); @@ -237,19 +201,19 @@ public void testParseRequestConfig() { ElasticsearchInternalServiceSettings.NUM_THREADS, 4, ElasticsearchInternalServiceSettings.MODEL_ID, - ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID // we can't directly test the eland case until we mock - // the threadpool within the client + ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID, + "not_a_valid_service_setting", + randomAlphaOfLength(10) ) ) ); - settings.put("extra_setting_that_should_not_be_here", randomAlphaOfLength(10)); ActionListener modelListener = ActionListener.wrap( model -> fail("Model parsing should have failed"), e -> assertThat(e, instanceOf(ElasticsearchStatusException.class)) ); - service.parseRequestConfig(randomInferenceEntityId, taskType, settings, Set.of(), modelListener); + service.parseRequestConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings, Set.of(), modelListener); } } @@ -342,10 +306,53 @@ public void testParseRequestConfig_Rerank_DefaultTaskSettings() { } } + @SuppressWarnings("unchecked") + public void testParseRequestConfig_SparseEmbedding() { + var client = mock(Client.class); + doAnswer(invocation -> { + var listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse( + new GetTrainedModelsAction.Response(new QueryPage<>(List.of(mock(TrainedModelConfig.class)), 1, mock(ParseField.class))) + ); + return null; + }).when(client).execute(Mockito.same(GetTrainedModelsAction.INSTANCE), any(), any()); + + when(client.threadPool()).thenReturn(threadPool); + + var service = createService(client); + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, + 1, + ElasticsearchInternalServiceSettings.NUM_THREADS, + 4, + ElasticsearchInternalServiceSettings.MODEL_ID, + "foo" + ) + ) + ); + + ActionListener modelListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(CustomElandModel.class)); + assertThat(model.getTaskSettings(), instanceOf(EmptyTaskSettings.class)); + assertThat(model.getServiceSettings(), instanceOf(CustomElandInternalServiceSettings.class)); + }, e -> { fail("Model parsing failed " + e.getMessage()); }); + + 
service.parseRequestConfig(randomInferenceEntityId, TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelListener); + } + private ActionListener getModelVerificationActionListener(MultilingualE5SmallInternalServiceSettings e5ServiceSettings) { return ActionListener.wrap(model -> { assertEquals( - new MultilingualE5SmallModel(randomInferenceEntityId, taskType, ElasticsearchInternalService.NAME, e5ServiceSettings), + new MultilingualE5SmallModel( + randomInferenceEntityId, + TaskType.TEXT_EMBEDDING, + ElasticsearchInternalService.NAME, + e5ServiceSettings + ), model ); }, e -> { fail("Model parsing failed " + e.getMessage()); }); @@ -371,7 +378,10 @@ public void testParsePersistedConfig() { ) ); - expectThrows(IllegalArgumentException.class, () -> service.parsePersistedConfig(randomInferenceEntityId, taskType, settings)); + expectThrows( + IllegalArgumentException.class, + () -> service.parsePersistedConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings) + ); } @@ -397,12 +407,17 @@ public void testParsePersistedConfig() { CustomElandEmbeddingModel parsedModel = (CustomElandEmbeddingModel) service.parsePersistedConfig( randomInferenceEntityId, - taskType, + TaskType.TEXT_EMBEDDING, settings ); var elandServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(1, 4, "invalid", null); assertEquals( - new CustomElandEmbeddingModel(randomInferenceEntityId, taskType, ElasticsearchInternalService.NAME, elandServiceSettings), + new CustomElandEmbeddingModel( + randomInferenceEntityId, + TaskType.TEXT_EMBEDDING, + ElasticsearchInternalService.NAME, + elandServiceSettings + ), parsedModel ); } @@ -436,11 +451,16 @@ public void testParsePersistedConfig() { MultilingualE5SmallModel parsedModel = (MultilingualE5SmallModel) service.parsePersistedConfig( randomInferenceEntityId, - taskType, + TaskType.TEXT_EMBEDDING, settings ); assertEquals( - new MultilingualE5SmallModel(randomInferenceEntityId, taskType, ElasticsearchInternalService.NAME, e5ServiceSettings), + new MultilingualE5SmallModel( + randomInferenceEntityId, + TaskType.TEXT_EMBEDDING, + ElasticsearchInternalService.NAME, + e5ServiceSettings + ), parsedModel ); } @@ -456,6 +476,8 @@ public void testParsePersistedConfig() { ) ); settings.put("not_a_valid_config_setting", randomAlphaOfLength(10)); + + var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); expectThrows(IllegalArgumentException.class, () -> service.parsePersistedConfig(randomInferenceEntityId, taskType, settings)); } @@ -476,12 +498,13 @@ public void testParsePersistedConfig() { ) ) ); + var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); expectThrows(IllegalArgumentException.class, () -> service.parsePersistedConfig(randomInferenceEntityId, taskType, settings)); } } - + @SuppressWarnings("unchecked") - public void testChunkInfer() { + public void testChunkInfer_e5() { var mlTrainedModelResults = new ArrayList(); mlTrainedModelResults.add(MlChunkedTextEmbeddingFloatResultsTests.createRandomResults()); mlTrainedModelResults.add(MlChunkedTextEmbeddingFloatResultsTests.createRandomResults()); @@ -568,6 +591,63 @@ public void testChunkInfer() { assertTrue("Listener not called", gotResults.get()); } + @SuppressWarnings("unchecked") + public void testChunkInfer_Sparse() { + var mlTrainedModelResults = new ArrayList(); + mlTrainedModelResults.add(InferenceChunkedTextExpansionResultsTests.createRandomResults()); + 
mlTrainedModelResults.add(InferenceChunkedTextExpansionResultsTests.createRandomResults()); + mlTrainedModelResults.add(new ErrorInferenceResults(new RuntimeException("boom"))); + var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); + + ThreadPool threadpool = new TestThreadPool("test"); + Client client = mock(Client.class); + when(client.threadPool()).thenReturn(threadpool); + doAnswer(invocationOnMock -> { + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(response); + return null; + }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class)); + + var model = new CustomElandModel( + "foo", + TaskType.SPARSE_EMBEDDING, + "elasticsearch", + new ElasticsearchInternalServiceSettings(1, 1, "model-id", null) + ); + var service = createService(client); + + var gotResults = new AtomicBoolean(); + var resultsListener = ActionListener.>wrap(chunkedResponse -> { + assertThat(chunkedResponse, hasSize(3)); + assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); + var result1 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(0); + assertEquals(((MlChunkedTextExpansionResults) mlTrainedModelResults.get(0)).getChunks(), result1.getChunkedResults()); + assertThat(chunkedResponse.get(1), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); + var result2 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(1); + assertEquals(((MlChunkedTextExpansionResults) mlTrainedModelResults.get(1)).getChunks(), result2.getChunkedResults()); + var result3 = (ErrorChunkedInferenceResults) chunkedResponse.get(2); + assertThat(result3.getException(), instanceOf(RuntimeException.class)); + assertThat(result3.getException().getMessage(), containsString("boom")); + gotResults.set(true); + }, ESTestCase::fail); + + service.chunkedInfer( + model, + null, + List.of("foo", "bar"), + Map.of(), + InputType.SEARCH, + new ChunkingOptions(null, null), + InferenceAction.Request.DEFAULT_TIMEOUT, + ActionListener.runAfter(resultsListener, () -> terminate(threadpool)) + ); + + if (gotResults.get() == false) { + terminate(threadpool); + } + assertTrue("Listener not called", gotResults.get()); + } + @SuppressWarnings("unchecked") public void testChunkInferSetsTokenization() { var expectedSpan = new AtomicInteger(); @@ -711,7 +791,7 @@ public void testParseRequestConfigEland_PreservesTaskType() { ) ); - var taskType = randomFrom(EnumSet.of(TaskType.RERANK, TaskType.TEXT_EMBEDDING)); + var taskType = randomFrom(EnumSet.of(TaskType.RERANK, TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)); CustomElandModel expectedModel = getCustomElandModel(taskType); PlainActionFuture listener = new PlainActionFuture<>(); @@ -739,6 +819,13 @@ private CustomElandModel getCustomElandModel(TaskType taskType) { ElasticsearchInternalService.NAME, serviceSettings ); + } else if (taskType == TaskType.SPARSE_EMBEDDING) { + expectedModel = new CustomElandModel( + randomInferenceEntityId, + taskType, + ElasticsearchInternalService.NAME, + new CustomElandInternalServiceSettings(1, 4, "custom-model", null) + ); } return expectedModel; } @@ -867,21 +954,4 @@ private ElasticsearchInternalService createService(Client client) { var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client); return new ElasticsearchInternalService(context); } - - public static Model randomModelConfig(String inferenceEntityId) { - List givenList = 
Arrays.asList("MultilingualE5SmallModel"); - Random rand = org.elasticsearch.common.Randomness.get(); - String model = givenList.get(rand.nextInt(givenList.size())); - - return switch (model) { - case "MultilingualE5SmallModel" -> new MultilingualE5SmallModel( - inferenceEntityId, - TaskType.TEXT_EMBEDDING, - ElasticsearchInternalService.NAME, - MultilingualE5SmallInternalServiceSettingsTests.createRandom() - ); - default -> throw new IllegalArgumentException("model " + model + " is not supported for testing"); - }; - } - } From 22415fde1376ea06801c5cf27b84c232d8882403 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 28 Aug 2024 09:00:23 +0100 Subject: [PATCH 2/3] Update docs/changelog/112270.yaml --- docs/changelog/112270.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/112270.yaml diff --git a/docs/changelog/112270.yaml b/docs/changelog/112270.yaml new file mode 100644 index 0000000000000..1e6b9c7fc9290 --- /dev/null +++ b/docs/changelog/112270.yaml @@ -0,0 +1,5 @@ +pr: 112270 +summary: Support sparse embedding models in the elasticsearch inference service +area: Machine Learning +type: enhancement +issues: [] From 6525f7aaf61e06996e7956affecada57af9eaed7 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 28 Aug 2024 09:08:51 +0100 Subject: [PATCH 3/3] spotless --- .../org/elasticsearch/xpack/inference/CustomElandModelIT.java | 1 - .../elasticsearch/ElasticsearchInternalServiceTests.java | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java index 53fa28e971774..65b7a138e7e1e 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.inference; import org.elasticsearch.client.Request; -import org.elasticsearch.client.Response; import org.elasticsearch.core.Strings; import org.elasticsearch.inference.TaskType; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index df5cb5b1698e6..257616033f080 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -502,7 +502,7 @@ public void testParsePersistedConfig() { expectThrows(IllegalArgumentException.class, () -> service.parsePersistedConfig(randomInferenceEntityId, taskType, settings)); } } - + @SuppressWarnings("unchecked") public void testChunkInfer_e5() { var mlTrainedModelResults = new ArrayList();
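
Note: as a minimal sketch of what this patch series enables, the request below creates a `sparse_embedding` endpoint with the `elasticsearch` service, mirroring the service settings used in CustomElandModelIT and the docs examples touched above. The model id `my-sparse-model` and the endpoint name are hypothetical; they assume a custom text expansion model that has already been uploaded with Eland or the trained models APIs, which this patch does not provide.

[source,console]
------------------------------------------------------------
PUT _inference/sparse_embedding/my-sparse-endpoint
{
  "service": "elasticsearch",
  "service_settings": {
    "model_id": "my-sparse-model",
    "num_allocations": 1,
    "num_threads": 1
  }
}
------------------------------------------------------------
// TEST[skip:TBD]

Calling `POST _inference/sparse_embedding/my-sparse-endpoint` with an `"input"` array would then return results under the `sparse_embedding` key, as asserted in `CustomElandModelIT.testSparse`.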