Skip to content

Commit 24619f9

Browse files
Adding ElasticsearchInternalServiceModelValidator to stop model deployment on failed validation
1 parent f80f054 commit 24619f9

File tree

7 files changed

+46
-9
lines changed

7 files changed

+46
-9
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,10 @@ void chunkedInfer(
138138
/**
139139
* Stop the model deployment.
140140
* The default action does nothing except acknowledge the request (true).
141-
* @param unparsedModel The unparsed model configuration
141+
* @param model The model configuration
142142
* @param listener The listener
143143
*/
144-
default void stop(UnparsedModel unparsedModel, ActionListener<Boolean> listener) {
144+
default void stop(Model model, ActionListener<Boolean> listener) {
145145
listener.onResponse(true);
146146
}
147147

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,9 @@ private void doExecuteForked(
115115

116116
var service = serviceRegistry.getService(unparsedModel.service());
117117
if (service.isPresent()) {
118-
service.get().stop(unparsedModel, listener);
118+
var model = service.get()
119+
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
120+
service.get().stop(model, listener);
119121
} else {
120122
listener.onFailure(
121123
new ElasticsearchStatusException(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import org.elasticsearch.inference.InputType;
2323
import org.elasticsearch.inference.Model;
2424
import org.elasticsearch.inference.TaskType;
25-
import org.elasticsearch.inference.UnparsedModel;
2625
import org.elasticsearch.xpack.core.ClientHelper;
2726
import org.elasticsearch.xpack.core.ml.MachineLearningField;
2827
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
@@ -119,9 +118,7 @@ public void start(Model model, ActionListener<Boolean> finalListener) {
119118
}
120119

121120
@Override
122-
public void stop(UnparsedModel unparsedModel, ActionListener<Boolean> listener) {
123-
124-
var model = parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
121+
public void stop(Model model, ActionListener<Boolean> listener) {
125122
if (model instanceof ElasticsearchInternalModel esModel) {
126123

127124
var serviceSettings = esModel.getServiceSettings();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
500500

501501
@Override
502502
public void checkModelConfig(Model model, ActionListener<Model> listener) {
503-
ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
503+
ModelValidatorBuilder.buildModelValidator(model.getTaskType(), true).validate(this, model, listener);
504504
}
505505

506506
@Override
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.validation;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.inference.InferenceService;
12+
import org.elasticsearch.inference.Model;
13+
14+
public class ElasticsearchInternalServiceModelValidator implements ModelValidator {
15+
16+
ModelValidator modelValidator;
17+
18+
public ElasticsearchInternalServiceModelValidator(ModelValidator modelValidator) {
19+
this.modelValidator = modelValidator;
20+
}
21+
22+
@Override
23+
public void validate(InferenceService service, Model model, ActionListener<Model> listener) {
24+
modelValidator.validate(service, model, listener.delegateResponse((l, exception) -> {
25+
// TODO: Cleanup the below code
26+
service.stop(model, ActionListener.wrap((v) -> listener.onFailure(exception), (e) -> listener.onFailure(exception)));
27+
}));
28+
}
29+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@
1111
import org.elasticsearch.inference.TaskType;
1212

1313
public class ModelValidatorBuilder {
14+
public static ModelValidator buildModelValidator(TaskType taskType, boolean isElasticsearchInternalService) {
15+
var modelValidator = buildModelValidator(taskType);
16+
if (isElasticsearchInternalService) {
17+
return new ElasticsearchInternalServiceModelValidator(modelValidator);
18+
} else {
19+
return modelValidator;
20+
}
21+
}
22+
1423
public static ModelValidator buildModelValidator(TaskType taskType) {
1524
if (taskType == null) {
1625
throw new IllegalArgumentException("Task type can't be null");

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1463,7 +1463,7 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() {
14631463
);
14641464

14651465
var request = (InferModelAction.Request) invocationOnMock.getArguments()[1];
1466-
assertThat(request.getId(), is("custom-model"));
1466+
assertThat(request.getId(), is(randomInferenceEntityId));
14671467
return Void.TYPE;
14681468
}).when(client).execute(eq(InferModelAction.INSTANCE), any(), any());
14691469
when(client.threadPool()).thenReturn(threadPool);

0 commit comments

Comments
 (0)