Skip to content

Commit 0c65909

Browse files
Add endpoint creation validation for ElasticsearchInternalService
1 parent 6911227 commit 0c65909

File tree

4 files changed

+87
-45
lines changed

4 files changed

+87
-45
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ private void parseAndStoreModel(
172172
ActionListener<Model> storeModelListener = listener.delegateFailureAndWrap(
173173
(delegate, verifiedModel) -> modelRegistry.storeModel(
174174
verifiedModel,
175-
ActionListener.wrap(r -> startInferenceEndpoint(service, verifiedModel, delegate), e -> {
175+
ActionListener.wrap(r -> listener.onResponse(createResponse(verifiedModel.getConfigurations())), e -> {
176176
if (e.getCause() instanceof StrictDynamicMappingException && e.getCause().getMessage().contains("chunking_settings")) {
177177
delegate.onFailure(
178178
new ElasticsearchStatusException(
@@ -199,12 +199,8 @@ private void parseAndStoreModel(
199199
service.parseRequestConfig(inferenceEntityId, taskType, config, parsedModelListener);
200200
}
201201

202-
private void startInferenceEndpoint(InferenceService service, Model model, ActionListener<PutInferenceModelAction.Response> listener) {
203-
if (skipValidationAndStart) {
204-
listener.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()));
205-
} else {
206-
service.start(model, listener.map(started -> new PutInferenceModelAction.Response(model.getConfigurations())));
207-
}
202+
private PutInferenceModelAction.Response createResponse(ModelConfigurations configurations) {
203+
return new PutInferenceModelAction.Response(configurations);
208204
}
209205

210206
private Map<String, Object> requestToMap(PutInferenceModelAction.Request request) throws IOException {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 30 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
6060
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
6161
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
62-
import org.elasticsearch.xpack.inference.services.ServiceUtils;
62+
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
6363

6464
import java.util.ArrayList;
6565
import java.util.Collections;
@@ -499,49 +499,38 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
499499

500500
@Override
501501
public void checkModelConfig(Model model, ActionListener<Model> listener) {
502-
if (model instanceof CustomElandEmbeddingModel elandModel && elandModel.getTaskType() == TaskType.TEXT_EMBEDDING) {
503-
// At this point the inference endpoint configuration has not been persisted yet, if we attempt to do inference using the
504-
// inference id we'll get an error because the trained model code needs to use the persisted inference endpoint to retrieve the
505-
// model id. To get around this we'll have the getEmbeddingSize() method use the model id instead of inference id. So we need
506-
// to create a temporary model that overrides the inference id with the model id.
507-
var temporaryModelWithModelId = new CustomElandEmbeddingModel(
508-
elandModel.getServiceSettings().modelId(),
509-
elandModel.getTaskType(),
510-
elandModel.getConfigurations().getService(),
511-
elandModel.getServiceSettings(),
512-
elandModel.getConfigurations().getChunkingSettings()
502+
ModelValidatorBuilder.buildModelValidator(model.getTaskType(), true).validate(this, model, listener);
503+
}
504+
505+
@Override
506+
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
507+
if (model instanceof CustomElandEmbeddingModel embeddingsModel) {
508+
var serviceSettings = embeddingsModel.getServiceSettings();
509+
510+
var updatedServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
511+
serviceSettings.getNumAllocations(),
512+
serviceSettings.getNumThreads(),
513+
serviceSettings.modelId(),
514+
serviceSettings.getAdaptiveAllocationsSettings(),
515+
embeddingSize,
516+
serviceSettings.similarity(),
517+
serviceSettings.elementType()
513518
);
514519

515-
ServiceUtils.getEmbeddingSize(
516-
temporaryModelWithModelId,
517-
this,
518-
listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(elandModel, size)))
520+
return new CustomElandEmbeddingModel(
521+
model.getInferenceEntityId(),
522+
model.getTaskType(),
523+
model.getConfigurations().getService(),
524+
updatedServiceSettings,
525+
model.getConfigurations().getChunkingSettings()
519526
);
520527
} else {
521-
listener.onResponse(model);
528+
// TODO: This is for the E5 case which is text embedding but we didn't previously update the dimensions. Figure out if we do
529+
// need to update the dimensions?
530+
return model;
522531
}
523532
}
524533

525-
private static CustomElandEmbeddingModel updateModelWithEmbeddingDetails(CustomElandEmbeddingModel model, int embeddingSize) {
526-
CustomElandInternalTextEmbeddingServiceSettings serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
527-
model.getServiceSettings().getNumAllocations(),
528-
model.getServiceSettings().getNumThreads(),
529-
model.getServiceSettings().modelId(),
530-
model.getServiceSettings().getAdaptiveAllocationsSettings(),
531-
embeddingSize,
532-
model.getServiceSettings().similarity(),
533-
model.getServiceSettings().elementType()
534-
);
535-
536-
return new CustomElandEmbeddingModel(
537-
model.getInferenceEntityId(),
538-
model.getTaskType(),
539-
model.getConfigurations().getService(),
540-
serviceSettings,
541-
model.getConfigurations().getChunkingSettings()
542-
);
543-
}
544-
545534
@Override
546535
public void infer(
547536
Model model,
@@ -904,7 +893,10 @@ private List<Model> defaultConfigs(boolean useLinuxOptimizedModel) {
904893

905894
@Override
906895
boolean isDefaultId(String inferenceId) {
907-
return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId);
896+
// return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId);
897+
// TODO: This is a temporary override to ensure that we always deploy models on infer to run a validation call.
898+
// Figure out if this is what we actually want to do?
899+
return true;
908900
}
909901

910902
static EmbeddingRequestChunker.EmbeddingType embeddingTypeFromTaskTypeAndSettings(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.validation;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.inference.InferenceService;
12+
import org.elasticsearch.inference.Model;
13+
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandEmbeddingModel;
14+
15+
public class ElasticsearchInternalServiceModelValidator implements ModelValidator {
16+
17+
private final ModelValidator modelValidator;
18+
19+
public ElasticsearchInternalServiceModelValidator(ModelValidator modelValidator) {
20+
this.modelValidator = modelValidator;
21+
}
22+
23+
@Override
24+
public void validate(InferenceService service, Model model, ActionListener<Model> listener) {
25+
var modelToValidate = model;
26+
if (model instanceof CustomElandEmbeddingModel esModel) {
27+
modelToValidate = new CustomElandEmbeddingModel(
28+
esModel.getServiceSettings().modelId(),
29+
esModel.getTaskType(),
30+
esModel.getConfigurations().getService(),
31+
esModel.getServiceSettings(),
32+
esModel.getConfigurations().getChunkingSettings()
33+
);
34+
}
35+
36+
modelValidator.validate(
37+
service,
38+
modelToValidate,
39+
listener.delegateFailureAndWrap((delegate, r) -> { delegate.onResponse(model); })
40+
);
41+
}
42+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,18 @@
1111
import org.elasticsearch.inference.TaskType;
1212

1313
public class ModelValidatorBuilder {
14+
15+
// TODO: Once we merge all the other service validation code we can remove the checkModelConfig function
16+
// from each service, private the buildModelValidator function below this one, and call this directly from
17+
// TransportPutInferenceModelAction.java.
18+
public static ModelValidator buildModelValidator(TaskType taskType, boolean isElasticsearchInternalService) {
19+
if (isElasticsearchInternalService) {
20+
return new ElasticsearchInternalServiceModelValidator(buildModelValidator(taskType));
21+
} else {
22+
return buildModelValidator(taskType);
23+
}
24+
}
25+
1426
public static ModelValidator buildModelValidator(TaskType taskType) {
1527
if (taskType == null) {
1628
throw new IllegalArgumentException("Task type can't be null");

0 commit comments

Comments
 (0)