Skip to content

Commit 9ec646e

Browse files
Adding endpoint creation validation to ElasticInferenceService (elastic#117642) (elastic#122956)
* Adding endpoint creation validation to ElasticInferenceService * Fix unit tests * Update docs/changelog/117642.yaml --------- Co-authored-by: Elastic Machine <[email protected]>
1 parent 8172440 commit 9ec646e

File tree

3 files changed

+22
-17
lines changed

3 files changed

+22
-17
lines changed

docs/changelog/117642.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 117642
2+
summary: Adding endpoint creation validation to `ElasticInferenceService`
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
5555
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
5656
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
57+
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
5758
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
5859

5960
import java.util.ArrayList;
@@ -574,11 +575,8 @@ private ElasticInferenceServiceModel createModelFromPersistent(
574575

575576
@Override
576577
public void checkModelConfig(Model model, ActionListener<Model> listener) {
577-
if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel embeddingsModel) {
578-
listener.onResponse(updateModelWithEmbeddingDetails(embeddingsModel));
579-
} else {
580-
listener.onResponse(model);
581-
}
578+
// TODO: Remove this function once all services have been updated to use the new model validators
579+
ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
582580
}
583581

584582
private static List<ChunkedInference> translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) {
@@ -593,18 +591,6 @@ private static List<ChunkedInference> translateToChunkedResults(InferenceInputs
593591
}
594592
}
595593

596-
private ElasticInferenceServiceSparseEmbeddingsModel updateModelWithEmbeddingDetails(
597-
ElasticInferenceServiceSparseEmbeddingsModel model
598-
) {
599-
ElasticInferenceServiceSparseEmbeddingsServiceSettings serviceSettings = new ElasticInferenceServiceSparseEmbeddingsServiceSettings(
600-
model.getServiceSettings().modelId(),
601-
model.getServiceSettings().maxInputTokens(),
602-
model.getServiceSettings().rateLimitSettings()
603-
);
604-
605-
return new ElasticInferenceServiceSparseEmbeddingsModel(model, serviceSettings);
606-
}
607-
608594
public static class Configuration {
609595

610596
private final EnumSet<TaskType> enabledTaskTypes;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,21 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists
317317

318318
public void testCheckModelConfig_ReturnsNewModelReference() throws IOException {
319319
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
320+
320321
try (var service = createService(senderFactory, getUrl(webServer))) {
322+
String responseJson = """
323+
{
324+
"data": [
325+
{
326+
"hello": 2.1259406,
327+
"greet": 1.7073475
328+
}
329+
]
330+
}
331+
""";
332+
333+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
334+
321335
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id");
322336
PlainActionFuture<Model> listener = new PlainActionFuture<>();
323337
service.checkModelConfig(model, listener);

0 commit comments

Comments
 (0)