From 01cdfbee11d19de4560fd8f7c6b74dea4a3271ee Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Wed, 27 Nov 2024 10:27:37 -0500 Subject: [PATCH 1/3] Adding endpoint creation validation to ElasticInferenceService --- .../elastic/ElasticInferenceService.java | 20 +++---------------- .../elastic/ElasticInferenceServiceTests.java | 13 ++++++++++++ 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 1f08c06edaa91..6b8e5d02656ab 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -42,6 +42,7 @@ import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import java.util.EnumSet; @@ -254,11 +255,8 @@ private ElasticInferenceServiceModel createModelFromPersistent( @Override public void checkModelConfig(Model model, ActionListener listener) { - if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel embeddingsModel) { - listener.onResponse(updateModelWithEmbeddingDetails(embeddingsModel)); - } else { - listener.onResponse(model); - } + // TODO: Remove this function once all services have been updated to use the new model validators + ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); } private static List translateToChunkedResults( @@ -275,18 +273,6 @@ private static List translateToChunkedResults( } } - private ElasticInferenceServiceSparseEmbeddingsModel updateModelWithEmbeddingDetails( - ElasticInferenceServiceSparseEmbeddingsModel model - ) { - ElasticInferenceServiceSparseEmbeddingsServiceSettings serviceSettings = new ElasticInferenceServiceSparseEmbeddingsServiceSettings( - model.getServiceSettings().modelId(), - model.getServiceSettings().maxInputTokens(), - model.getServiceSettings().rateLimitSettings() - ); - - return new ElasticInferenceServiceSparseEmbeddingsModel(model, serviceSettings); - } - private TraceContext getCurrentTraceInfo() { var threadPool = getServiceComponents().threadPool(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index d3101099d06c7..834081dc26fd3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -320,6 +320,19 @@ public void testCheckModelConfig_ReturnsNewModelReference() throws IOException { new ElasticInferenceServiceComponents(getUrl(webServer)) ) ) { + String responseJson = """ + { + "data": [ + { + "hello": 2.1259406, + "greet": 1.7073475 + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer)); PlainActionFuture listener = new PlainActionFuture<>(); service.checkModelConfig(model, listener); From b65f06e65e5210fa35a886828704b066298c6430 Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Tue, 18 Feb 2025 11:48:20 -0500 Subject: [PATCH 2/3] Fix unit tests --- .../elastic/ElasticInferenceServiceTests.java | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 2d9958b38a448..5d98a90ec2bf1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -319,6 +319,19 @@ public void testCheckModelConfig_ReturnsNewModelReference() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createService(senderFactory, getUrl(webServer))) { + String responseJson = """ + { + "data": [ + { + "hello": 2.1259406, + "greet": 1.7073475 + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id"); PlainActionFuture listener = new PlainActionFuture<>(); service.checkModelConfig(model, listener); From b53dfb0d613b71a47ffb285e2996d954eb4ea15b Mon Sep 17 00:00:00 2001 From: Dan Rubinstein Date: Tue, 18 Feb 2025 13:55:55 -0500 Subject: [PATCH 3/3] Update docs/changelog/117642.yaml --- docs/changelog/117642.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/117642.yaml diff --git a/docs/changelog/117642.yaml b/docs/changelog/117642.yaml new file mode 100644 index 0000000000000..dbddbbf5e64eb --- /dev/null +++ b/docs/changelog/117642.yaml @@ -0,0 +1,5 @@ +pr: 117642 +summary: Adding endpoint creation validation to `ElasticInferenceService` +area: Machine Learning +type: enhancement +issues: []