docs/changelog/123044.yaml (5 additions, 0 deletions)

@@ -0,0 +1,5 @@
+pr: 123044
+summary: Adding validation to `ElasticsearchInternalService`
+area: Machine Learning
+type: enhancement
+issues: []
@@ -162,24 +162,13 @@ void chunkedInfer(
     /**
      * Stop the model deployment.
      * The default action does nothing except acknowledge the request (true).
-     * @param unparsedModel The unparsed model configuration
+     * @param model The model configuration
      * @param listener The listener
      */
-    default void stop(UnparsedModel unparsedModel, ActionListener<Boolean> listener) {
+    default void stop(Model model, ActionListener<Boolean> listener) {
         listener.onResponse(true);
     }

-    /**
-     * Optionally test the new model configuration in the inference service.
-     * This function should be called when the model is first created, the
-     * default action is to do nothing.
-     * @param model The new model
-     * @param listener The listener
-     */
-    default void checkModelConfig(Model model, ActionListener<Model> listener) {
-        listener.onResponse(model);
-    };
-
     /**
      * Update a text embedding model's dimensions based on a provided embedding
      * size and set the default similarity if required. The default behaviour is to just return the model.
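Note on the `stop` signature change above: callers that used to hand the raw `UnparsedModel` straight to `stop(...)` must now resolve it to a typed `Model` first. A minimal sketch of the new call pattern, mirroring the transport-action hunk later in this diff (`service` and `unparsedModel` are assumed to be in scope):

    // Sketch only: materialize the persisted configuration into a Model, then
    // call the new stop(Model, ActionListener<Boolean>) overload.
    var model = service.parsePersistedConfig(
        unparsedModel.inferenceEntityId(), // inference endpoint id
        unparsedModel.taskType(),          // e.g. SPARSE_EMBEDDING
        unparsedModel.settings()           // raw persisted settings map
    );
    service.stop(model, ActionListener.wrap(
        acknowledged -> { /* deployment stopped, or default no-op acknowledgement */ },
        e -> { /* surface the failure to the caller */ }
    ));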
@@ -65,7 +65,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
     @SuppressWarnings("unchecked")
     public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
         List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
-        assertThat(services.size(), equalTo(15));
+        assertThat(services.size(), equalTo(16));

         String[] providers = new String[services.size()];
         for (int i = 0; i < services.size(); i++) {
@@ -87,6 +87,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
             "jinaai",
             "mistral",
             "openai",
+            "test_service",
             "text_embedding_test_service",
             "voyageai",
             "watsonxai"
@@ -159,8 +160,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
     @SuppressWarnings("unchecked")
     public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
         List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
-
-        assertThat(services.size(), equalTo(5));
+        assertThat(services.size(), equalTo(6));

         String[] providers = new String[services.size()];
         for (int i = 0; i < services.size(); i++) {
@@ -169,7 +169,14 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
         }

         assertArrayEquals(
-            List.of("alibabacloud-ai-search", "elastic", "elasticsearch", "hugging_face", "test_service").toArray(),
+            List.of(
+                "alibabacloud-ai-search",
+                "elastic",
+                "elasticsearch",
+                "hugging_face",
+                "streaming_completion_test_service",
+                "test_service"
+            ).toArray(),
            providers
        );
    }
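For context, the counts asserted here come from the services listing API; a helper like `getServices(TaskType.SPARSE_EMBEDDING)` in an `ESRestTestCase` is roughly equivalent to the sketch below (endpoint path and helper names assumed from the GET-services API and the standard REST test utilities, not shown in this diff):

    // Hedged sketch: list all services registered for a task type.
    Request request = new Request("GET", "/_inference/_services/sparse_embedding");
    Response response = client().performRequest(request);
    List<Object> services = entityAsList(response); // one map per service, including its name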
@@ -35,6 +35,7 @@
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
 import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.ml.search.WeightedToken;

 import java.io.IOException;
@@ -62,7 +63,7 @@ public TestSparseModel(String inferenceEntityId, TestServiceSettings serviceSettings
     public static class TestInferenceService extends AbstractTestInferenceService {
         public static final String NAME = "test_service";

-        private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING);
+        private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING);

         public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {}

@@ -113,7 +114,8 @@ public void infer(
             ActionListener<InferenceServiceResults> listener
         ) {
             switch (model.getConfigurations().getTaskType()) {
-                case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeResults(input));
+                case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeSparseEmbeddingResults(input));
+                case TEXT_EMBEDDING -> listener.onResponse(makeTextEmbeddingResults(input));
                 default -> listener.onFailure(
                     new ElasticsearchStatusException(
                         TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
@@ -154,7 +156,7 @@ public void chunkedInfer(
             }
         }

-        private SparseEmbeddingResults makeResults(List<String> input) {
+        private SparseEmbeddingResults makeSparseEmbeddingResults(List<String> input) {
             var embeddings = new ArrayList<SparseEmbeddingResults.Embedding>();
             for (int i = 0; i < input.size(); i++) {
                 var tokens = new ArrayList<WeightedToken>();
@@ -166,6 +168,18 @@ private SparseEmbeddingResults makeResults(List<String> input) {
             return new SparseEmbeddingResults(embeddings);
         }

+        private TextEmbeddingFloatResults makeTextEmbeddingResults(List<String> input) {
+            var embeddings = new ArrayList<TextEmbeddingFloatResults.Embedding>();
+            for (int i = 0; i < input.size(); i++) {
+                var values = new float[5];
+                for (int j = 0; j < 5; j++) {
+                    values[j] = random.nextFloat();
+                }
+                embeddings.add(new TextEmbeddingFloatResults.Embedding(values));
+            }
+            return new TextEmbeddingFloatResults(embeddings);
+        }
+
         private List<ChunkedInference> makeChunkedResults(List<String> input) {
             List<ChunkedInference> results = new ArrayList<>();
             for (int i = 0; i < input.size(); i++) {
@@ -35,8 +35,10 @@
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
 import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;

 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.Iterator;
@@ -58,7 +60,11 @@ public static class TestInferenceService extends AbstractTestInferenceService {
         private static final String NAME = "streaming_completion_test_service";
         private static final Set<TaskType> supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);

-        private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
+        private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
+            TaskType.COMPLETION,
+            TaskType.CHAT_COMPLETION,
+            TaskType.SPARSE_EMBEDDING
+        );

         public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {}

@@ -114,7 +120,21 @@ public void infer(
             ActionListener<InferenceServiceResults> listener
         ) {
             switch (model.getConfigurations().getTaskType()) {
-                case COMPLETION -> listener.onResponse(makeResults(input));
+                case COMPLETION -> listener.onResponse(makeChatCompletionResults(input));
+                case SPARSE_EMBEDDING -> {
+                    if (stream) {
+                        listener.onFailure(
+                            new ElasticsearchStatusException(
+                                TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
+                                RestStatus.BAD_REQUEST
+                            )
+                        );
+                    } else {
+                        // Return text embedding results when creating a sparse_embedding inference endpoint to allow creation validation to
+                        // pass. This is required to test that streaming fails for a sparse_embedding endpoint.
+                        listener.onResponse(makeTextEmbeddingResults(input));
+                    }
+                }
                 default -> listener.onFailure(
                     new ElasticsearchStatusException(
                         TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
@@ -142,7 +162,7 @@ public void unifiedCompletionInfer(
             }
         }

-        private StreamingChatCompletionResults makeResults(List<String> input) {
+        private StreamingChatCompletionResults makeChatCompletionResults(List<String> input) {
             var responseIter = input.stream().map(s -> s.toUpperCase(Locale.ROOT)).iterator();
             return new StreamingChatCompletionResults(subscriber -> {
                 subscriber.onSubscribe(new Flow.Subscription() {
@@ -161,6 +181,18 @@ public void cancel() {}
             });
         }

+        private TextEmbeddingFloatResults makeTextEmbeddingResults(List<String> input) {
+            var embeddings = new ArrayList<TextEmbeddingFloatResults.Embedding>();
+            for (int i = 0; i < input.size(); i++) {
+                var values = new float[5];
+                for (int j = 0; j < 5; j++) {
+                    values[j] = random.nextFloat();
+                }
+                embeddings.add(new TextEmbeddingFloatResults.Embedding(values));
+            }
+            return new TextEmbeddingFloatResults(embeddings);
+        }
+
         private InferenceServiceResults.Result completionChunk(String delta) {
             return new InferenceServiceResults.Result() {
                 @Override
@@ -125,7 +125,9 @@ private void doExecuteForked(

         var service = serviceRegistry.getService(unparsedModel.service());
         if (service.isPresent()) {
-            service.get().stop(unparsedModel, listener);
+            var model = service.get()
+                .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
+            service.get().stop(model, listener);
         } else {
             listener.onFailure(
                 new ElasticsearchStatusException(
@@ -46,6 +46,7 @@
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
+import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;

 import java.io.IOException;
 import java.util.List;
@@ -194,19 +195,23 @@ private void parseAndStoreModel(
         ActionListener<Model> storeModelListener = listener.delegateFailureAndWrap(
             (delegate, verifiedModel) -> modelRegistry.storeModel(
                 verifiedModel,
-                ActionListener.wrap(r -> startInferenceEndpoint(service, timeout, verifiedModel, delegate), e -> {
-                    if (e.getCause() instanceof StrictDynamicMappingException && e.getCause().getMessage().contains("chunking_settings")) {
-                        delegate.onFailure(
-                            new ElasticsearchStatusException(
-                                "One or more nodes in your cluster does not support chunking_settings. "
-                                    + "Please update all nodes in your cluster to the latest version to use chunking_settings.",
-                                RestStatus.BAD_REQUEST
-                            )
-                        );
-                    } else {
-                        delegate.onFailure(e);
-                    }
-                }),
+                ActionListener.wrap(
+                    r -> listener.onResponse(new PutInferenceModelAction.Response(verifiedModel.getConfigurations())),
+                    e -> {
+                        if (e.getCause() instanceof StrictDynamicMappingException
+                            && e.getCause().getMessage().contains("chunking_settings")) {
+                            delegate.onFailure(
+                                new ElasticsearchStatusException(
+                                    "One or more nodes in your cluster does not support chunking_settings. "
+                                        + "Please update all nodes in your cluster to the latest version to use chunking_settings.",
+                                    RestStatus.BAD_REQUEST
+                                )
+                            );
+                        } else {
+                            delegate.onFailure(e);
+                        }
+                    }
+                ),
                 timeout
             )
         );
@@ -215,26 +220,14 @@ private void parseAndStoreModel(
             if (skipValidationAndStart) {
                 storeModelListener.onResponse(model);
             } else {
-                service.checkModelConfig(model, storeModelListener);
+                ModelValidatorBuilder.buildModelValidator(model.getTaskType(), service instanceof ElasticsearchInternalService)
+                    .validate(service, model, timeout, storeModelListener);
             }
         });

         service.parseRequestConfig(inferenceEntityId, taskType, config, parsedModelListener);
     }

-    private void startInferenceEndpoint(
-        InferenceService service,
-        TimeValue timeout,
-        Model model,
-        ActionListener<PutInferenceModelAction.Response> listener
-    ) {
-        if (skipValidationAndStart) {
-            listener.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()));
-        } else {
-            service.start(model, timeout, listener.map(started -> new PutInferenceModelAction.Response(model.getConfigurations())));
-        }
-    }
-
     private Map<String, Object> requestToMap(PutInferenceModelAction.Request request) throws IOException {
         try (
             XContentParser parser = XContentHelper.createParser(
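Net effect of the two hunks above: the separate `checkModelConfig` check and the `startInferenceEndpoint` step collapse into a single validator-driven path. A condensed before/after sketch, using only signatures that appear in this diff:

    // Before: per-service config check, then an explicit start step after storing.
    service.checkModelConfig(model, storeModelListener);
    // ...startInferenceEndpoint(...) later called service.start(model, timeout, ...)

    // After: one validator chosen by task type; the boolean flags the internal
    // Elasticsearch service so the builder can special-case it.
    ModelValidatorBuilder.buildModelValidator(model.getTaskType(), service instanceof ElasticsearchInternalService)
        .validate(service, model, timeout, storeModelListener);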
@@ -8,22 +8,17 @@
 package org.elasticsearch.xpack.inference.services;

 import org.elasticsearch.ElasticsearchStatusException;
-import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
 import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;

@@ -723,53 +718,6 @@ public static ElasticsearchStatusException createInvalidModelException(Model mod
         );
     }

-    /**
-     * Evaluate the model and return the text embedding size
-     * @param model Should be a text embedding model
-     * @param service The inference service
-     * @param listener Size listener
-     */
-    public static void getEmbeddingSize(Model model, InferenceService service, ActionListener<Integer> listener) {
-        assert model.getTaskType() == TaskType.TEXT_EMBEDDING;
-
-        service.infer(
-            model,
-            null,
-            null,
-            null,
-            List.of(TEST_EMBEDDING_INPUT),
-            false,
-            Map.of(),
-            InputType.INTERNAL_INGEST,
-            InferenceAction.Request.DEFAULT_TIMEOUT,
-            listener.delegateFailureAndWrap((delegate, r) -> {
-                if (r instanceof TextEmbeddingResults<?> embeddingResults) {
-                    try {
-                        delegate.onResponse(embeddingResults.getFirstEmbeddingSize());
-                    } catch (Exception e) {
-                        delegate.onFailure(
-                            new ElasticsearchStatusException("Could not determine embedding size", RestStatus.BAD_REQUEST, e)
-                        );
-                    }
-                } else {
-                    delegate.onFailure(
-                        new ElasticsearchStatusException(
-                            "Could not determine embedding size. "
-                                + "Expected a result of type ["
-                                + TextEmbeddingFloatResults.NAME
-                                + "] got ["
-                                + r.getWriteableName()
-                                + "]",
-                            RestStatus.BAD_REQUEST
-                        )
-                    );
-                }
-            })
-        );
-    }
-
-    private static final String TEST_EMBEDDING_INPUT = "how big";
-
     public static SecureString apiKey(@Nullable ApiKeySecrets secrets) {
         // To avoid a possible null pointer throughout the code we'll create a noop api key of an empty array
         return secrets == null ? new SecureString(new char[0]) : secrets.apiKey();
@@ -48,7 +48,6 @@
 import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
-import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;

 import java.util.EnumSet;
 import java.util.HashMap;
@@ -348,19 +347,6 @@ protected void doChunkedInfer(
         }
     }

-    /**
-     * For text embedding models get the embedding size and
-     * update the service settings.
-     *
-     * @param model The new model
-     * @param listener The listener
-     */
-    @Override
-    public void checkModelConfig(Model model, ActionListener<Model> listener) {
-        // TODO: Remove this function once all services have been updated to use the new model validators
-        ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
-    }
-
     @Override
     public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
         if (model instanceof AlibabaCloudSearchEmbeddingsModel embeddingsModel) {