docs/changelog/116352.yaml (5 additions & 0 deletions)
@@ -0,0 +1,5 @@
pr: 116352
summary: Add endpoint creation validation for `ElasticsearchInternalService`
area: Machine Learning
type: enhancement
issues: []
@@ -147,10 +147,10 @@ void chunkedInfer(
/**
* Stop the model deployment.
* The default action does nothing except acknowledge the request (true).
* @param unparsedModel The unparsed model configuration
* @param model The model configuration
* @param listener The listener
*/
default void stop(UnparsedModel unparsedModel, ActionListener<Boolean> listener) {
default void stop(Model model, ActionListener<Boolean> listener) {
listener.onResponse(true);
}

@@ -115,7 +115,9 @@ private void doExecuteForked(

var service = serviceRegistry.getService(unparsedModel.service());
if (service.isPresent()) {
service.get().stop(unparsedModel, listener);
var model = service.get()
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
service.get().stop(model, listener);
} else {
listener.onFailure(
new ElasticsearchStatusException(
@@ -22,7 +22,6 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
@@ -118,9 +117,7 @@ public void start(Model model, TimeValue timeout, ActionListener<Boolean> finalL
}

@Override
public void stop(UnparsedModel unparsedModel, ActionListener<Boolean> listener) {

var model = parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
public void stop(Model model, ActionListener<Boolean> listener) {
if (model instanceof ElasticsearchInternalModel esModel) {

var serviceSettings = esModel.getServiceSettings();
@@ -59,6 +59,7 @@
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;

import java.util.ArrayList;
import java.util.Collections;
@@ -498,47 +499,40 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M

@Override
public void checkModelConfig(Model model, ActionListener<Model> listener) {
if (model instanceof CustomElandEmbeddingModel elandModel && elandModel.getTaskType() == TaskType.TEXT_EMBEDDING) {
// At this point the inference endpoint configuration has not been persisted yet, if we attempt to do inference using the
// inference id we'll get an error because the trained model code needs to use the persisted inference endpoint to retrieve the
// model id. To get around this we'll have the getEmbeddingSize() method use the model id instead of inference id. So we need
// to create a temporary model that overrides the inference id with the model id.
var temporaryModelWithModelId = new CustomElandEmbeddingModel(
elandModel.getServiceSettings().modelId(),
elandModel.getTaskType(),
elandModel.getConfigurations().getService(),
elandModel.getServiceSettings(),
elandModel.getConfigurations().getChunkingSettings()
);

ServiceUtils.getEmbeddingSize(
temporaryModelWithModelId,
this,
listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(elandModel, size)))
);
} else {
listener.onResponse(model);
}
ModelValidatorBuilder.buildModelValidator(model.getTaskType(), true).validate(this, model, listener);
}

private static CustomElandEmbeddingModel updateModelWithEmbeddingDetails(CustomElandEmbeddingModel model, int embeddingSize) {
CustomElandInternalTextEmbeddingServiceSettings serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
model.getServiceSettings().getNumAllocations(),
model.getServiceSettings().getNumThreads(),
model.getServiceSettings().modelId(),
model.getServiceSettings().getAdaptiveAllocationsSettings(),
embeddingSize,
model.getServiceSettings().similarity(),
model.getServiceSettings().elementType()
);
@Override
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
if (model instanceof ElasticsearchInternalModel) {
if (model instanceof CustomElandEmbeddingModel embeddingsModel) {
var serviceSettings = embeddingsModel.getServiceSettings();

var updatedServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
serviceSettings.getNumAllocations(),
serviceSettings.getNumThreads(),
serviceSettings.modelId(),
serviceSettings.getAdaptiveAllocationsSettings(),
embeddingSize,
serviceSettings.similarity(),
serviceSettings.elementType()
);

return new CustomElandEmbeddingModel(
model.getInferenceEntityId(),
model.getTaskType(),
model.getConfigurations().getService(),
serviceSettings,
model.getConfigurations().getChunkingSettings()
);
return new CustomElandEmbeddingModel(
model.getInferenceEntityId(),
model.getTaskType(),
model.getConfigurations().getService(),
updatedServiceSettings,
model.getConfigurations().getChunkingSettings()
);
} else {
// TODO: This is for the E5 case which is text embedding but we didn't previously update the dimensions. Figure out if we do
// need to update the dimensions?
return model;
}
} else {
throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
}
}

@Override
@@ -882,7 +876,10 @@ private List<Model> defaultConfigs(boolean useLinuxOptimizedModel) {

@Override
boolean isDefaultId(String inferenceId) {
return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId);
// return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId);
// TODO: This is a temporary override to ensure that we always deploy models on infer to run a validation call.
// Figure out if this is what we actually want to do?
return true;
}

static EmbeddingRequestChunker.EmbeddingType embeddingTypeFromTaskTypeAndSettings(
@@ -0,0 +1,29 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.validation;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.Model;

public class ElasticsearchInternalServiceModelValidator implements ModelValidator {

ModelValidator modelValidator;

public ElasticsearchInternalServiceModelValidator(ModelValidator modelValidator) {
this.modelValidator = modelValidator;
}

@Override
public void validate(InferenceService service, Model model, ActionListener<Model> listener) {
modelValidator.validate(service, model, listener.delegateResponse((l, exception) -> {
// TODO: Cleanup the below code
service.stop(model, ActionListener.wrap((v) -> listener.onFailure(exception), (e) -> listener.onFailure(exception)));
}));
}
}
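
A minimal sketch of how the failure path of this wrapper could be unit-tested, assuming Mockito and ESTestCase as used elsewhere in the inference tests; the test class, its name, and the stubbing below are illustrative and are not part of this change:

import java.util.concurrent.atomic.AtomicReference;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.Model;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

public class ElasticsearchInternalServiceModelValidatorTests extends ESTestCase {

    @SuppressWarnings("unchecked")
    public void testValidationFailure_StopsDeploymentAndKeepsOriginalException() {
        var delegate = mock(ModelValidator.class);
        var service = mock(InferenceService.class);
        var model = mock(Model.class);
        var original = new ElasticsearchStatusException("validation failed", RestStatus.BAD_REQUEST);

        // The wrapped validator reports a failure to the listener it was given.
        doAnswer(invocation -> {
            ActionListener<Model> listener = invocation.getArgument(2);
            listener.onFailure(original);
            return null;
        }).when(delegate).validate(eq(service), eq(model), any());

        // The service acknowledges the stop request.
        doAnswer(invocation -> {
            ActionListener<Boolean> listener = invocation.getArgument(1);
            listener.onResponse(true);
            return null;
        }).when(service).stop(eq(model), any());

        var caught = new AtomicReference<Exception>();
        var validator = new ElasticsearchInternalServiceModelValidator(delegate);
        validator.validate(service, model, ActionListener.wrap(m -> fail("expected a failure"), caught::set));

        // The deployment is stopped and the caller still sees the original validation error.
        verify(service).stop(eq(model), any());
        assertSame(original, caught.get());
    }
}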
@@ -11,6 +11,15 @@
import org.elasticsearch.inference.TaskType;

public class ModelValidatorBuilder {
public static ModelValidator buildModelValidator(TaskType taskType, boolean isElasticsearchInternalService) {
var modelValidator = buildModelValidator(taskType);
if (isElasticsearchInternalService) {
return new ElasticsearchInternalServiceModelValidator(modelValidator);
} else {
return modelValidator;
}
}

public static ModelValidator buildModelValidator(TaskType taskType) {
if (taskType == null) {
throw new IllegalArgumentException("Task type can't be null");
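
For orientation, a small usage sketch of the new overload, not part of the diff; the class below is illustrative and placed in the validation package only so the example compiles. ElasticsearchInternalService.checkModelConfig, shown earlier, calls the wrapped form with true.

package org.elasticsearch.xpack.inference.services.validation;

import org.elasticsearch.inference.TaskType;

// Sketch only: the two paths through the builder after this change.
public class ModelValidatorBuilderUsageSketch {
    public static void main(String[] args) {
        // Existing behavior: a plain validator for the task type.
        ModelValidator plain = ModelValidatorBuilder.buildModelValidator(TaskType.TEXT_EMBEDDING);

        // New overload: the same validator wrapped in ElasticsearchInternalServiceModelValidator,
        // so a validation failure also stops the model deployment.
        ModelValidator wrapped = ModelValidatorBuilder.buildModelValidator(TaskType.TEXT_EMBEDDING, true);

        System.out.println(plain.getClass().getSimpleName() + " vs " + wrapped.getClass().getSimpleName());
    }
}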
@@ -68,11 +68,13 @@
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
@@ -1440,7 +1442,7 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() {
);

var request = (InferModelAction.Request) invocationOnMock.getArguments()[1];
assertThat(request.getId(), is("custom-model"));
assertThat(request.getId(), is(randomInferenceEntityId));
return Void.TYPE;
}).when(client).execute(eq(InferModelAction.INSTANCE), any(), any());
when(client.threadPool()).thenReturn(threadPool);
@@ -1488,6 +1490,84 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() {
assertThat(model, is(expectedModel));
}

public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
var client = mock(Client.class);
try (var service = createService(client)) {
var model = OpenAiChatCompletionModelTests.createChatCompletionModel(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
randomAlphaOfLength(10),
randomAlphaOfLength(10),
randomAlphaOfLength(10)
);
assertThrows(
ElasticsearchStatusException.class,
() -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); }
);
}
}

public void testUpdateModelWithEmbeddingDetails_NonElandModelProvided() throws IOException {
var client = mock(Client.class);
try (var service = createService(client)) {
var originalModel = new MultilingualE5SmallModel(
randomAlphaOfLength(10),
TaskType.TEXT_EMBEDDING,
randomAlphaOfLength(10),
new MultilingualE5SmallInternalServiceSettings(
randomNonNegativeInt(),
randomNonNegativeInt(),
randomAlphaOfLength(10),
null
),
null
);

var updatedModel = service.updateModelWithEmbeddingDetails(originalModel, randomNonNegativeInt());
assertEquals(originalModel, updatedModel);
}
}

public void testUpdateModelWithEmbeddingDetails_ElandModelProvided() throws IOException {
var client = mock(Client.class);
try (var service = createService(client)) {
var originalServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
randomNonNegativeInt(),
randomNonNegativeInt(),
randomAlphaOfLength(10),
null
);
var originalModel = new CustomElandEmbeddingModel(
randomAlphaOfLength(10),
TaskType.TEXT_EMBEDDING,
randomAlphaOfLength(10),
originalServiceSettings,
ChunkingSettingsTests.createRandomChunkingSettings()
);

var embeddingSize = randomNonNegativeInt();
var expectedUpdatedServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
originalServiceSettings.getNumAllocations(),
originalServiceSettings.getNumThreads(),
originalServiceSettings.modelId(),
originalServiceSettings.getAdaptiveAllocationsSettings(),
embeddingSize,
originalServiceSettings.similarity(),
originalServiceSettings.elementType()
);
var expectedUpdatedModel = new CustomElandEmbeddingModel(
originalModel.getInferenceEntityId(),
originalModel.getTaskType(),
originalModel.getConfigurations().getService(),
expectedUpdatedServiceSettings,
originalModel.getConfigurations().getChunkingSettings()
);

var actualUpdatedModel = service.updateModelWithEmbeddingDetails(originalModel, embeddingSize);
assertEquals(expectedUpdatedModel, actualUpdatedModel);
}
}

public void testModelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic() {
{
assertFalse(