
Commit 1384959

dan-rubinstein authored, with co-authors elasticsearchmachine and elasticmachine
Adding endpoint creation validation to ElasticsearchInternalService (#123044) (#126472)
* Adding validation to ElasticsearchInternalService
* Update docs/changelog/123044.yaml
* [CI] Auto commit changes from spotless
* Removing checkModelConfig
* Fixing IT
* [CI] Auto commit changes from spotless
* Remove DeepSeek checkModelConfig and fix tests
* Cleaning up comments, updating validation input type, and moving model deployment starting to model validator

---------

Co-authored-by: elasticsearchmachine <[email protected]>
Co-authored-by: Elastic Machine <[email protected]>
1 parent 69fa6b3 commit 1384959
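
In short, this change removes the per-service checkModelConfig hook and routes endpoint creation through a shared validator, which for the Elasticsearch internal service also takes over starting the model deployment. A minimal sketch of the new call site, condensed from the TransportPutInferenceModelAction diff below (skipValidationAndStart, storeModelListener, and the other names come from that diff):

// Sketch only: condensed from TransportPutInferenceModelAction#parseAndStoreModel after this change.
if (skipValidationAndStart) {
    // The caller explicitly asked to skip validation, so the parsed model is stored as-is.
    storeModelListener.onResponse(model);
} else {
    // Validation (and, per the commit notes, deployment startup for ElasticsearchInternalService)
    // is delegated to a task-type specific validator instead of service.checkModelConfig(...).
    ModelValidatorBuilder.buildModelValidator(model.getTaskType(), service instanceof ElasticsearchInternalService)
        .validate(service, model, timeout, storeModelListener);
}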

File tree

57 files changed: +448 / -2700 lines

Note: large commits have some content hidden by default; only a subset of the 57 changed files is shown below.


docs/changelog/123044.yaml

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+pr: 123044
+summary: Adding validation to `ElasticsearchInternalService`
+area: Machine Learning
+type: enhancement
+issues: []

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 2 additions & 13 deletions
@@ -162,24 +162,13 @@ void chunkedInfer(
     /**
      * Stop the model deployment.
      * The default action does nothing except acknowledge the request (true).
-     * @param unparsedModel The unparsed model configuration
+     * @param model The model configuration
      * @param listener The listener
      */
-    default void stop(UnparsedModel unparsedModel, ActionListener<Boolean> listener) {
+    default void stop(Model model, ActionListener<Boolean> listener) {
         listener.onResponse(true);
     }

-    /**
-     * Optionally test the new model configuration in the inference service.
-     * This function should be called when the model is first created, the
-     * default action is to do nothing.
-     * @param model The new model
-     * @param listener The listener
-     */
-    default void checkModelConfig(Model model, ActionListener<Model> listener) {
-        listener.onResponse(model);
-    };
-
     /**
      * Update a text embedding model's dimensions based on a provided embedding
      * size and set the default similarity if required. The default behaviour is to just return the model.
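
With this signature change, InferenceService implementations that override stop now receive a parsed Model rather than an UnparsedModel. A minimal, hypothetical override showing the new shape (MyInferenceService and its teardown logic are illustrative, not part of this commit):

// Hypothetical service class; only the stop override is shown.
public class MyInferenceService implements InferenceService {
    // ... remaining InferenceService methods omitted for brevity ...

    @Override
    public void stop(Model model, ActionListener<Boolean> listener) {
        // A real service would tear down whatever deployment backs this model;
        // the interface default simply acknowledges the request with true.
        listener.onResponse(true);
    }
}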

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 11 additions & 4 deletions
@@ -65,7 +65,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
     @SuppressWarnings("unchecked")
     public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
         List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
-        assertThat(services.size(), equalTo(15));
+        assertThat(services.size(), equalTo(16));

         String[] providers = new String[services.size()];
         for (int i = 0; i < services.size(); i++) {
@@ -87,6 +87,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
             "jinaai",
             "mistral",
             "openai",
+            "test_service",
             "text_embedding_test_service",
             "voyageai",
             "watsonxai"
@@ -159,8 +160,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
     @SuppressWarnings("unchecked")
     public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
         List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
-
-        assertThat(services.size(), equalTo(5));
+        assertThat(services.size(), equalTo(6));

         String[] providers = new String[services.size()];
         for (int i = 0; i < services.size(); i++) {
@@ -169,7 +169,14 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
         }

         assertArrayEquals(
-            List.of("alibabacloud-ai-search", "elastic", "elasticsearch", "hugging_face", "test_service").toArray(),
+            List.of(
+                "alibabacloud-ai-search",
+                "elastic",
+                "elasticsearch",
+                "hugging_face",
+                "streaming_completion_test_service",
+                "test_service"
+            ).toArray(),
             providers
         );
     }

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java

Lines changed: 17 additions & 3 deletions
@@ -35,6 +35,7 @@
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
 import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.ml.search.WeightedToken;

 import java.io.IOException;
@@ -62,7 +63,7 @@ public TestSparseModel(String inferenceEntityId, TestServiceSettings serviceSett
     public static class TestInferenceService extends AbstractTestInferenceService {
         public static final String NAME = "test_service";

-        private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING);
+        private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING);

         public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {}

@@ -113,7 +114,8 @@ public void infer(
             ActionListener<InferenceServiceResults> listener
         ) {
             switch (model.getConfigurations().getTaskType()) {
-                case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeResults(input));
+                case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeSparseEmbeddingResults(input));
+                case TEXT_EMBEDDING -> listener.onResponse(makeTextEmbeddingResults(input));
                 default -> listener.onFailure(
                     new ElasticsearchStatusException(
                         TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
@@ -154,7 +156,7 @@ public void chunkedInfer(
             }
         }

-        private SparseEmbeddingResults makeResults(List<String> input) {
+        private SparseEmbeddingResults makeSparseEmbeddingResults(List<String> input) {
            var embeddings = new ArrayList<SparseEmbeddingResults.Embedding>();
            for (int i = 0; i < input.size(); i++) {
                var tokens = new ArrayList<WeightedToken>();
@@ -166,6 +168,18 @@ private SparseEmbeddingResults makeResults(List<String> input) {
             return new SparseEmbeddingResults(embeddings);
         }

+        private TextEmbeddingFloatResults makeTextEmbeddingResults(List<String> input) {
+            var embeddings = new ArrayList<TextEmbeddingFloatResults.Embedding>();
+            for (int i = 0; i < input.size(); i++) {
+                var values = new float[5];
+                for (int j = 0; j < 5; j++) {
+                    values[j] = random.nextFloat();
+                }
+                embeddings.add(new TextEmbeddingFloatResults.Embedding(values));
+            }
+            return new TextEmbeddingFloatResults(embeddings);
+        }
+
         private List<ChunkedInference> makeChunkedResults(List<String> input) {
             List<ChunkedInference> results = new ArrayList<>();
             for (int i = 0; i < input.size(); i++) {

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 35 additions & 3 deletions
@@ -35,8 +35,10 @@
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
 import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;

 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.Iterator;
@@ -58,7 +60,11 @@ public static class TestInferenceService extends AbstractTestInferenceService {
         private static final String NAME = "streaming_completion_test_service";
         private static final Set<TaskType> supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);

-        private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
+        private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
+            TaskType.COMPLETION,
+            TaskType.CHAT_COMPLETION,
+            TaskType.SPARSE_EMBEDDING
+        );

         public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {}

@@ -114,7 +120,21 @@ public void infer(
             ActionListener<InferenceServiceResults> listener
         ) {
             switch (model.getConfigurations().getTaskType()) {
-                case COMPLETION -> listener.onResponse(makeResults(input));
+                case COMPLETION -> listener.onResponse(makeChatCompletionResults(input));
+                case SPARSE_EMBEDDING -> {
+                    if (stream) {
+                        listener.onFailure(
+                            new ElasticsearchStatusException(
+                                TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
+                                RestStatus.BAD_REQUEST
+                            )
+                        );
+                    } else {
+                        // Return text embedding results when creating a sparse_embedding inference endpoint to allow creation validation to
+                        // pass. This is required to test that streaming fails for a sparse_embedding endpoint.
+                        listener.onResponse(makeTextEmbeddingResults(input));
+                    }
+                }
                 default -> listener.onFailure(
                     new ElasticsearchStatusException(
                         TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
@@ -142,7 +162,7 @@ public void unifiedCompletionInfer(
             }
         }

-        private StreamingChatCompletionResults makeResults(List<String> input) {
+        private StreamingChatCompletionResults makeChatCompletionResults(List<String> input) {
             var responseIter = input.stream().map(s -> s.toUpperCase(Locale.ROOT)).iterator();
             return new StreamingChatCompletionResults(subscriber -> {
                 subscriber.onSubscribe(new Flow.Subscription() {
@@ -161,6 +181,18 @@ public void cancel() {}
             });
         }

+        private TextEmbeddingFloatResults makeTextEmbeddingResults(List<String> input) {
+            var embeddings = new ArrayList<TextEmbeddingFloatResults.Embedding>();
+            for (int i = 0; i < input.size(); i++) {
+                var values = new float[5];
+                for (int j = 0; j < 5; j++) {
+                    values[j] = random.nextFloat();
+                }
+                embeddings.add(new TextEmbeddingFloatResults.Embedding(values));
+            }
+            return new TextEmbeddingFloatResults(embeddings);
+        }
+
         private InferenceServiceResults.Result completionChunk(String delta) {
             return new InferenceServiceResults.Result() {
                 @Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java

Lines changed: 3 additions & 1 deletion
@@ -125,7 +125,9 @@ private void doExecuteForked(

         var service = serviceRegistry.getService(unparsedModel.service());
         if (service.isPresent()) {
-            service.get().stop(unparsedModel, listener);
+            var model = service.get()
+                .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
+            service.get().stop(model, listener);
         } else {
             listener.onFailure(
                 new ElasticsearchStatusException(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java

Lines changed: 19 additions & 26 deletions
@@ -46,6 +46,7 @@
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
+import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;

 import java.io.IOException;
 import java.util.List;
@@ -194,19 +195,23 @@ private void parseAndStoreModel(
         ActionListener<Model> storeModelListener = listener.delegateFailureAndWrap(
             (delegate, verifiedModel) -> modelRegistry.storeModel(
                 verifiedModel,
-                ActionListener.wrap(r -> startInferenceEndpoint(service, timeout, verifiedModel, delegate), e -> {
-                    if (e.getCause() instanceof StrictDynamicMappingException && e.getCause().getMessage().contains("chunking_settings")) {
-                        delegate.onFailure(
-                            new ElasticsearchStatusException(
-                                "One or more nodes in your cluster does not support chunking_settings. "
-                                    + "Please update all nodes in your cluster to the latest version to use chunking_settings.",
-                                RestStatus.BAD_REQUEST
-                            )
-                        );
-                    } else {
-                        delegate.onFailure(e);
+                ActionListener.wrap(
+                    r -> listener.onResponse(new PutInferenceModelAction.Response(verifiedModel.getConfigurations())),
+                    e -> {
+                        if (e.getCause() instanceof StrictDynamicMappingException
+                            && e.getCause().getMessage().contains("chunking_settings")) {
+                            delegate.onFailure(
+                                new ElasticsearchStatusException(
+                                    "One or more nodes in your cluster does not support chunking_settings. "
+                                        + "Please update all nodes in your cluster to the latest version to use chunking_settings.",
+                                    RestStatus.BAD_REQUEST
+                                )
+                            );
+                        } else {
+                            delegate.onFailure(e);
+                        }
                     }
-                }),
+                ),
                 timeout
             )
         );
@@ -215,26 +220,14 @@ private void parseAndStoreModel(
             if (skipValidationAndStart) {
                 storeModelListener.onResponse(model);
             } else {
-                service.checkModelConfig(model, storeModelListener);
+                ModelValidatorBuilder.buildModelValidator(model.getTaskType(), service instanceof ElasticsearchInternalService)
+                    .validate(service, model, timeout, storeModelListener);
             }
         });

         service.parseRequestConfig(inferenceEntityId, taskType, config, parsedModelListener);
     }

-    private void startInferenceEndpoint(
-        InferenceService service,
-        TimeValue timeout,
-        Model model,
-        ActionListener<PutInferenceModelAction.Response> listener
-    ) {
-        if (skipValidationAndStart) {
-            listener.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()));
-        } else {
-            service.start(model, timeout, listener.map(started -> new PutInferenceModelAction.Response(model.getConfigurations())));
-        }
-    }
-
     private Map<String, Object> requestToMap(PutInferenceModelAction.Request request) throws IOException {
         try (
             XContentParser parser = XContentHelper.createParser(
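
The deleted startInferenceEndpoint helper shows what moved: starting the deployment and acknowledging the PUT request are now driven by the validator rather than by this action. Conceptually, an endpoint-creation validator exercises the new endpoint with a test inference call and only hands the model to the store listener if that call succeeds. The sketch below is illustrative only (the real classes live under org.elasticsearch.xpack.inference.services.validation and differ in detail); it mirrors the probe-style infer call that the removed ServiceUtils#getEmbeddingSize made:

// Illustrative sketch of what an endpoint-creation validator does; not the actual implementation.
void validate(InferenceService service, Model model, TimeValue timeout, ActionListener<Model> listener) {
    service.infer(
        model,
        null,                 // unused optional parameters, mirroring the removed getEmbeddingSize call
        null,
        null,
        List.of("how big"),   // same probe input the removed ServiceUtils helper used
        false,                // no streaming during validation
        Map.of(),
        InputType.INTERNAL_INGEST,
        timeout,
        // Any failure propagates to the PUT request; on success the model is passed through
        // (text embedding models may additionally have their dimensions updated from the result).
        listener.delegateFailureAndWrap((delegate, results) -> delegate.onResponse(model))
    );
}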

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 0 additions & 52 deletions
@@ -8,22 +8,17 @@
 package org.elasticsearch.xpack.inference.services;

 import org.elasticsearch.ElasticsearchStatusException;
-import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
 import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;

@@ -723,53 +718,6 @@ public static ElasticsearchStatusException createInvalidModelException(Model mod
         );
     }

-    /**
-     * Evaluate the model and return the text embedding size
-     * @param model Should be a text embedding model
-     * @param service The inference service
-     * @param listener Size listener
-     */
-    public static void getEmbeddingSize(Model model, InferenceService service, ActionListener<Integer> listener) {
-        assert model.getTaskType() == TaskType.TEXT_EMBEDDING;
-
-        service.infer(
-            model,
-            null,
-            null,
-            null,
-            List.of(TEST_EMBEDDING_INPUT),
-            false,
-            Map.of(),
-            InputType.INTERNAL_INGEST,
-            InferenceAction.Request.DEFAULT_TIMEOUT,
-            listener.delegateFailureAndWrap((delegate, r) -> {
-                if (r instanceof TextEmbeddingResults<?> embeddingResults) {
-                    try {
-                        delegate.onResponse(embeddingResults.getFirstEmbeddingSize());
-                    } catch (Exception e) {
-                        delegate.onFailure(
-                            new ElasticsearchStatusException("Could not determine embedding size", RestStatus.BAD_REQUEST, e)
-                        );
-                    }
-                } else {
-                    delegate.onFailure(
-                        new ElasticsearchStatusException(
-                            "Could not determine embedding size. "
-                                + "Expected a result of type ["
-                                + TextEmbeddingFloatResults.NAME
-                                + "] got ["
-                                + r.getWriteableName()
-                                + "]",
-                            RestStatus.BAD_REQUEST
-                        )
-                    );
-                }
-            })
-        );
-    }
-
-    private static final String TEST_EMBEDDING_INPUT = "how big";
-
     public static SecureString apiKey(@Nullable ApiKeySecrets secrets) {
         // To avoid a possible null pointer throughout the code we'll create a noop api key of an empty array
         return secrets == null ? new SecureString(new char[0]) : secrets.apiKey();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

Lines changed: 0 additions & 14 deletions
@@ -48,7 +48,6 @@
 import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
-import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;

 import java.util.EnumSet;
 import java.util.HashMap;
@@ -348,19 +347,6 @@ protected void doChunkedInfer(
         }
     }

-    /**
-     * For text embedding models get the embedding size and
-     * update the service settings.
-     *
-     * @param model The new model
-     * @param listener The listener
-     */
-    @Override
-    public void checkModelConfig(Model model, ActionListener<Model> listener) {
-        // TODO: Remove this function once all services have been updated to use the new model validators
-        ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
-    }
-
     @Override
     public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
         if (model instanceof AlibabaCloudSearchEmbeddingsModel embeddingsModel) {
