
Commit 20f6a2a

dan-rubinstein authored, with elasticsearchmachine and Elastic Machine as co-authors
Adding endpoint creation validation to ElasticsearchInternalService (#123044)
* Adding validation to ElasticsearchInternalService
* Update docs/changelog/123044.yaml
* [CI] Auto commit changes from spotless
* Removing checkModelConfig
* Fixing IT
* [CI] Auto commit changes from spotless
* Remove DeepSeek checkModelConfig and fix tests
* Cleaning up comments, updating validation input type, and moving model deployment starting to model validator

---------

Co-authored-by: elasticsearchmachine <[email protected]>
Co-authored-by: Elastic Machine <[email protected]>
1 parent 060a9b7 commit 20f6a2a
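
For orientation, the core of the change appears in the TransportPutInferenceModelAction diff below: instead of each service overriding checkModelConfig, endpoint creation now builds a task-type-specific validator via ModelValidatorBuilder.buildModelValidator(taskType, isElasticsearchInternalService) and runs it before the model is stored. The following is a minimal, self-contained sketch of that dispatch idea; the class and member names here (ValidatorDispatchSketch, its nested TaskType and ModelValidator, and the Consumer-based callbacks) are illustrative stand-ins, not the actual Elasticsearch classes.

// Illustrative sketch only: this is NOT Elasticsearch code. It models the "build a validator per
// task type, then validate before storing the endpoint" flow with simplified stand-in types.
import java.util.function.Consumer;

public class ValidatorDispatchSketch {

    enum TaskType { TEXT_EMBEDDING, SPARSE_EMBEDDING, COMPLETION }

    interface ModelValidator {
        void validate(String endpointId, Consumer<String> onSuccess, Consumer<Exception> onFailure);
    }

    // Mirrors the shape of ModelValidatorBuilder.buildModelValidator(taskType, isElasticsearchInternalService):
    // the boolean is the new flag that lets the Elasticsearch internal service start its model
    // deployment as part of validation instead of in a separate startInferenceEndpoint step.
    static ModelValidator buildModelValidator(TaskType taskType, boolean isElasticsearchInternalService) {
        return (endpointId, onSuccess, onFailure) -> {
            try {
                if (isElasticsearchInternalService) {
                    System.out.println("starting model deployment for " + endpointId);
                }
                // Creation-time validation typically issues a small test inference call here; for
                // TEXT_EMBEDDING endpoints it would also record the returned embedding size.
                System.out.println("validating " + taskType + " endpoint " + endpointId);
                onSuccess.accept(endpointId);
            } catch (Exception e) {
                onFailure.accept(e);
            }
        };
    }

    public static void main(String[] args) {
        buildModelValidator(TaskType.TEXT_EMBEDDING, true)
            .validate("my-endpoint", id -> System.out.println("stored " + id), Throwable::printStackTrace);
    }
}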

File tree

57 files changed: +448 additions, -2698 deletions


docs/changelog/123044.yaml

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+pr: 123044
+summary: Adding validation to `ElasticsearchInternalService`
+area: Machine Learning
+type: enhancement
+issues: []

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 2 additions & 13 deletions
@@ -162,24 +162,13 @@ void chunkedInfer(
     /**
      * Stop the model deployment.
      * The default action does nothing except acknowledge the request (true).
-     * @param unparsedModel The unparsed model configuration
+     * @param model The model configuration
      * @param listener The listener
      */
-    default void stop(UnparsedModel unparsedModel, ActionListener<Boolean> listener) {
+    default void stop(Model model, ActionListener<Boolean> listener) {
         listener.onResponse(true);
     }
 
-    /**
-     * Optionally test the new model configuration in the inference service.
-     * This function should be called when the model is first created, the
-     * default action is to do nothing.
-     * @param model The new model
-     * @param listener The listener
-     */
-    default void checkModelConfig(Model model, ActionListener<Model> listener) {
-        listener.onResponse(model);
-    };
-
     /**
      * Update a text embedding model's dimensions based on a provided embedding
      * size and set the default similarity if required. The default behaviour is to just return the model.

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 11 additions & 3 deletions
@@ -64,7 +64,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
     @SuppressWarnings("unchecked")
     public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
         List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
-        assertThat(services.size(), equalTo(15));
+        assertThat(services.size(), equalTo(16));
 
         String[] providers = new String[services.size()];
         for (int i = 0; i < services.size(); i++) {
@@ -86,6 +86,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
             "jinaai",
             "mistral",
             "openai",
+            "test_service",
             "text_embedding_test_service",
             "voyageai",
             "watsonxai"
@@ -157,7 +158,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
     @SuppressWarnings("unchecked")
     public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
         List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
-        assertThat(services.size(), equalTo(5));
+        assertThat(services.size(), equalTo(6));
 
         String[] providers = new String[services.size()];
         for (int i = 0; i < services.size(); i++) {
@@ -166,7 +167,14 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
         }
 
         assertArrayEquals(
-            List.of("alibabacloud-ai-search", "elastic", "elasticsearch", "hugging_face", "test_service").toArray(),
+            List.of(
+                "alibabacloud-ai-search",
+                "elastic",
+                "elasticsearch",
+                "hugging_face",
+                "streaming_completion_test_service",
+                "test_service"
+            ).toArray(),
             providers
         );
     }

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java

Lines changed: 17 additions & 3 deletions
@@ -35,6 +35,7 @@
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.ml.search.WeightedToken;
 
 import java.io.IOException;
@@ -63,7 +64,7 @@ public TestSparseModel(String inferenceEntityId, TestServiceSettings serviceSett
     public static class TestInferenceService extends AbstractTestInferenceService {
         public static final String NAME = "test_service";
 
-        private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING);
+        private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING);
 
         public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {}
 
@@ -114,7 +115,8 @@ public void infer(
             ActionListener<InferenceServiceResults> listener
         ) {
             switch (model.getConfigurations().getTaskType()) {
-                case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeResults(input));
+                case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeSparseEmbeddingResults(input));
+                case TEXT_EMBEDDING -> listener.onResponse(makeTextEmbeddingResults(input));
                 default -> listener.onFailure(
                     new ElasticsearchStatusException(
                         TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
@@ -155,7 +157,7 @@ public void chunkedInfer(
             }
         }
 
-        private SparseEmbeddingResults makeResults(List<String> input) {
+        private SparseEmbeddingResults makeSparseEmbeddingResults(List<String> input) {
            var embeddings = new ArrayList<SparseEmbeddingResults.Embedding>();
            for (int i = 0; i < input.size(); i++) {
                var tokens = new ArrayList<WeightedToken>();
@@ -167,6 +169,18 @@ private SparseEmbeddingResults makeResults(List<String> input) {
             return new SparseEmbeddingResults(embeddings);
         }
 
+        private TextEmbeddingFloatResults makeTextEmbeddingResults(List<String> input) {
+            var embeddings = new ArrayList<TextEmbeddingFloatResults.Embedding>();
+            for (int i = 0; i < input.size(); i++) {
+                var values = new float[5];
+                for (int j = 0; j < 5; j++) {
+                    values[j] = random.nextFloat();
+                }
+                embeddings.add(new TextEmbeddingFloatResults.Embedding(values));
+            }
+            return new TextEmbeddingFloatResults(embeddings);
+        }
+
         private List<ChunkedInference> makeChunkedResults(List<ChunkInferenceInput> inputs) {
             List<ChunkedInference> results = new ArrayList<>();
             for (ChunkInferenceInput chunkInferenceInput : inputs) {

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 35 additions & 3 deletions
@@ -36,8 +36,10 @@
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
 import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.Iterator;
@@ -59,7 +61,11 @@ public static class TestInferenceService extends AbstractTestInferenceService {
         private static final String NAME = "streaming_completion_test_service";
         private static final Set<TaskType> supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
 
-        private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
+        private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
+            TaskType.COMPLETION,
+            TaskType.CHAT_COMPLETION,
+            TaskType.SPARSE_EMBEDDING
+        );
 
         public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {}
 
@@ -115,7 +121,21 @@ public void infer(
             ActionListener<InferenceServiceResults> listener
         ) {
             switch (model.getConfigurations().getTaskType()) {
-                case COMPLETION -> listener.onResponse(makeResults(input));
+                case COMPLETION -> listener.onResponse(makeChatCompletionResults(input));
+                case SPARSE_EMBEDDING -> {
+                    if (stream) {
+                        listener.onFailure(
+                            new ElasticsearchStatusException(
+                                TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
+                                RestStatus.BAD_REQUEST
+                            )
+                        );
+                    } else {
+                        // Return text embedding results when creating a sparse_embedding inference endpoint to allow creation validation to
+                        // pass. This is required to test that streaming fails for a sparse_embedding endpoint.
+                        listener.onResponse(makeTextEmbeddingResults(input));
+                    }
+                }
                 default -> listener.onFailure(
                     new ElasticsearchStatusException(
                         TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
@@ -143,7 +163,7 @@ public void unifiedCompletionInfer(
             }
         }
 
-        private StreamingChatCompletionResults makeResults(List<String> input) {
+        private StreamingChatCompletionResults makeChatCompletionResults(List<String> input) {
             var responseIter = input.stream().map(s -> s.toUpperCase(Locale.ROOT)).iterator();
             return new StreamingChatCompletionResults(subscriber -> {
                 subscriber.onSubscribe(new Flow.Subscription() {
@@ -162,6 +182,18 @@ public void cancel() {}
             });
         }
 
+        private TextEmbeddingFloatResults makeTextEmbeddingResults(List<String> input) {
+            var embeddings = new ArrayList<TextEmbeddingFloatResults.Embedding>();
+            for (int i = 0; i < input.size(); i++) {
+                var values = new float[5];
+                for (int j = 0; j < 5; j++) {
+                    values[j] = random.nextFloat();
+                }
+                embeddings.add(new TextEmbeddingFloatResults.Embedding(values));
+            }
+            return new TextEmbeddingFloatResults(embeddings);
+        }
+
         private InferenceServiceResults.Result completionChunk(String delta) {
             return new InferenceServiceResults.Result() {
                 @Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java

Lines changed: 3 additions & 1 deletion
@@ -125,7 +125,9 @@ private void doExecuteForked(
 
         var service = serviceRegistry.getService(unparsedModel.service());
         if (service.isPresent()) {
-            service.get().stop(unparsedModel, listener);
+            var model = service.get()
+                .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
+            service.get().stop(model, listener);
         } else {
             listener.onFailure(
                 new ElasticsearchStatusException(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java

Lines changed: 19 additions & 26 deletions
@@ -45,6 +45,7 @@
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
+import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
 
 import java.io.IOException;
 import java.util.List;
@@ -190,19 +191,23 @@ private void parseAndStoreModel(
         ActionListener<Model> storeModelListener = listener.delegateFailureAndWrap(
             (delegate, verifiedModel) -> modelRegistry.storeModel(
                 verifiedModel,
-                ActionListener.wrap(r -> startInferenceEndpoint(service, timeout, verifiedModel, delegate), e -> {
-                    if (e.getCause() instanceof StrictDynamicMappingException && e.getCause().getMessage().contains("chunking_settings")) {
-                        delegate.onFailure(
-                            new ElasticsearchStatusException(
-                                "One or more nodes in your cluster does not support chunking_settings. "
-                                    + "Please update all nodes in your cluster to the latest version to use chunking_settings.",
-                                RestStatus.BAD_REQUEST
-                            )
-                        );
-                    } else {
-                        delegate.onFailure(e);
+                ActionListener.wrap(
+                    r -> listener.onResponse(new PutInferenceModelAction.Response(verifiedModel.getConfigurations())),
+                    e -> {
+                        if (e.getCause() instanceof StrictDynamicMappingException
+                            && e.getCause().getMessage().contains("chunking_settings")) {
+                            delegate.onFailure(
+                                new ElasticsearchStatusException(
+                                    "One or more nodes in your cluster does not support chunking_settings. "
+                                        + "Please update all nodes in your cluster to the latest version to use chunking_settings.",
+                                    RestStatus.BAD_REQUEST
+                                )
+                            );
+                        } else {
+                            delegate.onFailure(e);
+                        }
                     }
-                }),
+                ),
                 timeout
             )
         );
@@ -211,26 +216,14 @@ private void parseAndStoreModel(
             if (skipValidationAndStart) {
                 storeModelListener.onResponse(model);
             } else {
-                service.checkModelConfig(model, storeModelListener);
+                ModelValidatorBuilder.buildModelValidator(model.getTaskType(), service instanceof ElasticsearchInternalService)
+                    .validate(service, model, timeout, storeModelListener);
             }
         });
 
         service.parseRequestConfig(inferenceEntityId, taskType, config, parsedModelListener);
     }
 
-    private void startInferenceEndpoint(
-        InferenceService service,
-        TimeValue timeout,
-        Model model,
-        ActionListener<PutInferenceModelAction.Response> listener
-    ) {
-        if (skipValidationAndStart) {
-            listener.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()));
-        } else {
-            service.start(model, timeout, listener.map(started -> new PutInferenceModelAction.Response(model.getConfigurations())));
-        }
-    }
-
     private Map<String, Object> requestToMap(PutInferenceModelAction.Request request) throws IOException {
         try (
             XContentParser parser = XContentHelper.createParser(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 0 additions & 52 deletions
@@ -8,22 +8,17 @@
 package org.elasticsearch.xpack.inference.services;
 
 import org.elasticsearch.ElasticsearchStatusException;
-import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
 import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
 
@@ -723,53 +718,6 @@ public static ElasticsearchStatusException createInvalidModelException(Model mod
         );
     }
 
-    /**
-     * Evaluate the model and return the text embedding size
-     * @param model Should be a text embedding model
-     * @param service The inference service
-     * @param listener Size listener
-     */
-    public static void getEmbeddingSize(Model model, InferenceService service, ActionListener<Integer> listener) {
-        assert model.getTaskType() == TaskType.TEXT_EMBEDDING;
-
-        service.infer(
-            model,
-            null,
-            null,
-            null,
-            List.of(TEST_EMBEDDING_INPUT),
-            false,
-            Map.of(),
-            InputType.INTERNAL_INGEST,
-            InferenceAction.Request.DEFAULT_TIMEOUT,
-            listener.delegateFailureAndWrap((delegate, r) -> {
-                if (r instanceof TextEmbeddingResults<?> embeddingResults) {
-                    try {
-                        delegate.onResponse(embeddingResults.getFirstEmbeddingSize());
-                    } catch (Exception e) {
-                        delegate.onFailure(
-                            new ElasticsearchStatusException("Could not determine embedding size", RestStatus.BAD_REQUEST, e)
-                        );
-                    }
-                } else {
-                    delegate.onFailure(
-                        new ElasticsearchStatusException(
-                            "Could not determine embedding size. "
-                                + "Expected a result of type ["
-                                + TextEmbeddingFloatResults.NAME
-                                + "] got ["
-                                + r.getWriteableName()
-                                + "]",
-                            RestStatus.BAD_REQUEST
-                        )
-                    );
-                }
-            })
-        );
-    }
-
-    private static final String TEST_EMBEDDING_INPUT = "how big";
-
     public static SecureString apiKey(@Nullable ApiKeySecrets secrets) {
         // To avoid a possible null pointer throughout the code we'll create a noop api key of an empty array
         return secrets == null ? new SecureString(new char[0]) : secrets.apiKey();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

Lines changed: 0 additions & 14 deletions
@@ -48,7 +48,6 @@
 import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
-import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
 
 import java.util.EnumSet;
 import java.util.HashMap;
@@ -348,19 +347,6 @@ protected void doChunkedInfer(
         }
     }
 
-    /**
-     * For text embedding models get the embedding size and
-     * update the service settings.
-     *
-     * @param model The new model
-     * @param listener The listener
-     */
-    @Override
-    public void checkModelConfig(Model model, ActionListener<Model> listener) {
-        // TODO: Remove this function once all services have been updated to use the new model validators
-        ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
-    }
-
     @Override
     public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
         if (model instanceof AlibabaCloudSearchEmbeddingsModel embeddingsModel) {
