Skip to content

Commit 701ed61

Browse files
Adding inference endpoint creation validation for MistralService, GoogleAiStudioService, and HuggingFaceService (#113492)
* Adding inference endpoint creation validation for MistralService, GoogleAiStudioService, and HuggingFaceService
* Moving invalid model type exception to shared ServiceUtils function
* Fixing naming inconsistency
* Updating HuggingFaceIT ELSER tests for inference endpoint validation
1 parent 1b67dab commit 701ed61

File tree

13 files changed

+213
-77
lines changed

13 files changed

+213
-77
lines changed

x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/HuggingFaceServiceMixedIT.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ public void testElser() throws IOException {
8484
final String inferenceId = "mixed-cluster-elser";
8585
final String upgradedClusterId = "upgraded-cluster-elser";
8686

87+
elserServer.enqueue(new MockResponse().setResponseCode(200).setBody(elserResponse()));
8788
put(inferenceId, elserConfig(getUrl(elserServer)), TaskType.SPARSE_EMBEDDING);
8889

8990
var configs = (List<Map<String, Object>>) get(TaskType.SPARSE_EMBEDDING, inferenceId).get("endpoints");

x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/HuggingFaceServiceUpgradeIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ public void testElser() throws IOException {
117117
var testTaskType = TaskType.SPARSE_EMBEDDING;
118118

119119
if (isOldCluster()) {
120+
elserServer.enqueue(new MockResponse().setResponseCode(200).setBody(elserResponse()));
120121
put(oldClusterId, elserConfig(getUrl(elserServer)), testTaskType);
121122
var configs = (List<Map<String, Object>>) get(testTaskType, oldClusterId).get(old_cluster_endpoint_identifier);
122123
assertThat(configs, hasSize(1));
@@ -136,6 +137,7 @@ public void testElser() throws IOException {
136137
assertElser(oldClusterId);
137138

138139
// New endpoint
140+
elserServer.enqueue(new MockResponse().setResponseCode(200).setBody(elserResponse()));
139141
put(upgradedClusterId, elserConfig(getUrl(elserServer)), testTaskType);
140142
configs = (List<Map<String, Object>>) get(upgradedClusterId).get("endpoints");
141143
assertThat(configs, hasSize(1));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,13 @@ public static ElasticsearchStatusException unknownSettingsError(Map<String, Obje
202202
);
203203
}
204204

205+
public static ElasticsearchStatusException invalidModelTypeForUpdateModelWithEmbeddingDetails(Class<? extends Model> invalidModelType) {
206+
throw new ElasticsearchStatusException(
207+
Strings.format("Can't update embedding details for model with unexpected type %s", invalidModelType),
208+
RestStatus.BAD_REQUEST
209+
);
210+
}
211+
205212
public static String missingSettingErrorMsg(String settingName, String scope) {
206213
return Strings.format("[%s] does not contain the required setting [%s]", scope, settingName);
207214
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel;
3636
import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsModel;
3737
import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings;
38+
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
3839

3940
import java.util.List;
4041
import java.util.Map;
@@ -187,30 +188,29 @@ public TransportVersion getMinimalSupportedVersion() {
187188

188189
@Override
189190
public void checkModelConfig(Model model, ActionListener<Model> listener) {
190-
if (model instanceof GoogleAiStudioEmbeddingsModel embeddingsModel) {
191-
ServiceUtils.getEmbeddingSize(
192-
model,
193-
this,
194-
listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size)))
195-
);
196-
} else {
197-
listener.onResponse(model);
198-
}
191+
// TODO: Remove this function once all services have been updated to use the new model validators
192+
ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
199193
}
200194

201-
private GoogleAiStudioEmbeddingsModel updateModelWithEmbeddingDetails(GoogleAiStudioEmbeddingsModel model, int embeddingSize) {
202-
var similarityFromModel = model.getServiceSettings().similarity();
203-
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
195+
@Override
196+
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
197+
if (model instanceof GoogleAiStudioEmbeddingsModel embeddingsModel) {
198+
var serviceSettings = embeddingsModel.getServiceSettings();
199+
var similarityFromModel = serviceSettings.similarity();
200+
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
204201

205-
GoogleAiStudioEmbeddingsServiceSettings serviceSettings = new GoogleAiStudioEmbeddingsServiceSettings(
206-
model.getServiceSettings().modelId(),
207-
model.getServiceSettings().maxInputTokens(),
208-
embeddingSize,
209-
similarityToUse,
210-
model.getServiceSettings().rateLimitSettings()
211-
);
202+
var updatedServiceSettings = new GoogleAiStudioEmbeddingsServiceSettings(
203+
serviceSettings.modelId(),
204+
serviceSettings.maxInputTokens(),
205+
embeddingSize,
206+
similarityToUse,
207+
serviceSettings.rateLimitSettings()
208+
);
212209

213-
return new GoogleAiStudioEmbeddingsModel(model, serviceSettings);
210+
return new GoogleAiStudioEmbeddingsModel(embeddingsModel, updatedServiceSettings);
211+
} else {
212+
throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
213+
}
214214
}
215215

216216
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.elasticsearch.xpack.inference.services.ServiceUtils;
3030
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
3131
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
32+
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
3233

3334
import java.util.List;
3435
import java.util.Map;
@@ -67,34 +68,31 @@ protected HuggingFaceModel createModel(
6768

6869
@Override
6970
public void checkModelConfig(Model model, ActionListener<Model> listener) {
71+
// TODO: Remove this function once all services have been updated to use the new model validators
72+
ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
73+
}
74+
75+
@Override
76+
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
7077
if (model instanceof HuggingFaceEmbeddingsModel embeddingsModel) {
71-
ServiceUtils.getEmbeddingSize(
72-
model,
73-
this,
74-
listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size)))
78+
var serviceSettings = embeddingsModel.getServiceSettings();
79+
var similarityFromModel = serviceSettings.similarity();
80+
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.COSINE : similarityFromModel;
81+
82+
var updatedServiceSettings = new HuggingFaceServiceSettings(
83+
serviceSettings.uri(),
84+
similarityToUse,
85+
embeddingSize,
86+
embeddingsModel.getTokenLimit(),
87+
serviceSettings.rateLimitSettings()
7588
);
89+
90+
return new HuggingFaceEmbeddingsModel(embeddingsModel, updatedServiceSettings);
7691
} else {
77-
listener.onResponse(model);
92+
throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
7893
}
7994
}
8095

81-
private static HuggingFaceEmbeddingsModel updateModelWithEmbeddingDetails(HuggingFaceEmbeddingsModel model, int embeddingSize) {
82-
// default to cosine similarity
83-
var similarity = model.getServiceSettings().similarity() == null
84-
? SimilarityMeasure.COSINE
85-
: model.getServiceSettings().similarity();
86-
87-
var serviceSettings = new HuggingFaceServiceSettings(
88-
model.getServiceSettings().uri(),
89-
similarity,
90-
embeddingSize,
91-
model.getTokenLimit(),
92-
model.getServiceSettings().rateLimitSettings()
93-
);
94-
95-
return new HuggingFaceEmbeddingsModel(model, serviceSettings);
96-
}
97-
9896
@Override
9997
protected void doChunkedInfer(
10098
Model model,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.elasticsearch.xpack.inference.services.ServiceUtils;
3434
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
3535
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
36+
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
3637

3738
import java.util.List;
3839
import java.util.Map;
@@ -214,32 +215,28 @@ private MistralEmbeddingsModel createModelFromPersistent(
214215

215216
@Override
216217
public void checkModelConfig(Model model, ActionListener<Model> listener) {
217-
if (model instanceof MistralEmbeddingsModel embeddingsModel) {
218-
ServiceUtils.getEmbeddingSize(
219-
model,
220-
this,
221-
listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateEmbeddingModelConfig(embeddingsModel, size)))
222-
);
223-
} else {
224-
listener.onResponse(model);
225-
}
218+
// TODO: Remove this function once all services have been updated to use the new model validators
219+
ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
226220
}
227221

228-
private MistralEmbeddingsModel updateEmbeddingModelConfig(MistralEmbeddingsModel embeddingsModel, int embeddingsSize) {
229-
var embeddingServiceSettings = embeddingsModel.getServiceSettings();
230-
231-
var similarityFromModel = embeddingsModel.getServiceSettings().similarity();
232-
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
222+
@Override
223+
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
224+
if (model instanceof MistralEmbeddingsModel embeddingsModel) {
225+
var serviceSettings = embeddingsModel.getServiceSettings();
233226

234-
MistralEmbeddingsServiceSettings serviceSettings = new MistralEmbeddingsServiceSettings(
235-
embeddingServiceSettings.modelId(),
236-
embeddingsSize,
237-
embeddingServiceSettings.maxInputTokens(),
238-
similarityToUse,
239-
embeddingServiceSettings.rateLimitSettings()
240-
);
227+
var similarityFromModel = embeddingsModel.getServiceSettings().similarity();
228+
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
241229

242-
return new MistralEmbeddingsModel(embeddingsModel, serviceSettings);
230+
MistralEmbeddingsServiceSettings updatedServiceSettings = new MistralEmbeddingsServiceSettings(
231+
serviceSettings.modelId(),
232+
embeddingSize,
233+
serviceSettings.maxInputTokens(),
234+
similarityToUse,
235+
serviceSettings.rateLimitSettings()
236+
);
237+
return new MistralEmbeddingsModel(embeddingsModel, updatedServiceSettings);
238+
} else {
239+
throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
240+
}
243241
}
244-
245242
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
1414
import org.elasticsearch.core.Nullable;
15-
import org.elasticsearch.core.Strings;
1615
import org.elasticsearch.core.TimeValue;
1716
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
1817
import org.elasticsearch.inference.ChunkingOptions;
@@ -35,6 +34,7 @@
3534
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
3635
import org.elasticsearch.xpack.inference.services.SenderService;
3736
import org.elasticsearch.xpack.inference.services.ServiceComponents;
37+
import org.elasticsearch.xpack.inference.services.ServiceUtils;
3838
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel;
3939
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel;
4040
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
@@ -307,10 +307,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
307307

308308
return new OpenAiEmbeddingsModel(embeddingsModel, updatedServiceSettings);
309309
} else {
310-
throw new ElasticsearchStatusException(
311-
Strings.format("Can't update embedding details for model with unexpected type %s", model.getClass()),
312-
RestStatus.BAD_REQUEST
313-
);
310+
throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
314311
}
315312
}
316313

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
/*
23
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
34
* or more contributor license agreements. Licensed under the Elastic License
@@ -33,14 +34,25 @@ public void validate(InferenceService service, Model model, ActionListener<Infer
3334
Map.of(),
3435
InputType.INGEST,
3536
InferenceAction.Request.DEFAULT_TIMEOUT,
36-
listener.delegateFailureAndWrap((delegate, r) -> {
37+
ActionListener.wrap(r -> {
3738
if (r != null) {
38-
delegate.onResponse(r);
39+
listener.onResponse(r);
3940
} else {
40-
delegate.onFailure(
41-
new ElasticsearchStatusException("Could not make a validation call to the selected service", RestStatus.BAD_REQUEST)
41+
listener.onFailure(
42+
new ElasticsearchStatusException(
43+
"Could not complete inference endpoint creation as validation call to service returned null response.",
44+
RestStatus.BAD_REQUEST
45+
)
4246
);
4347
}
48+
}, e -> {
49+
listener.onFailure(
50+
new ElasticsearchStatusException(
51+
"Could not complete inference endpoint creation as validation call to service threw an exception.",
52+
RestStatus.BAD_REQUEST,
53+
e
54+
)
55+
);
4456
})
4557
);
4658
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,45 @@ public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosi
914914
}
915915
}
916916

917+
public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
918+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
919+
try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
920+
var model = GoogleAiStudioCompletionModelTests.createModel(randomAlphaOfLength(10), randomAlphaOfLength(10));
921+
assertThrows(
922+
ElasticsearchStatusException.class,
923+
() -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); }
924+
);
925+
}
926+
}
927+
928+
public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException {
929+
testUpdateModelWithEmbeddingDetails_Successful(null);
930+
}
931+
932+
public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException {
933+
testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values()));
934+
}
935+
936+
private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException {
937+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
938+
try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
939+
var embeddingSize = randomNonNegativeInt();
940+
var model = GoogleAiStudioEmbeddingsModelTests.createModel(
941+
randomAlphaOfLength(10),
942+
randomAlphaOfLength(10),
943+
randomAlphaOfLength(10),
944+
randomNonNegativeInt(),
945+
similarityMeasure
946+
);
947+
948+
Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
949+
950+
SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? SimilarityMeasure.DOT_PRODUCT : similarityMeasure;
951+
assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
952+
assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
953+
}
954+
}
955+
917956
public static Map<String, Object> buildExpectationCompletions(List<String> completions) {
918957
return Map.of(
919958
ChatCompletionResults.COMPLETION,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,45 @@ public void testCheckModelConfig_DefaultsSimilarityToCosine() throws IOException
595595
}
596596
}
597597

598+
public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
599+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
600+
try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
601+
var model = HuggingFaceElserModelTests.createModel(randomAlphaOfLength(10), randomAlphaOfLength(10));
602+
assertThrows(
603+
ElasticsearchStatusException.class,
604+
() -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); }
605+
);
606+
}
607+
}
608+
609+
public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException {
610+
testUpdateModelWithEmbeddingDetails_Successful(null);
611+
}
612+
613+
public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException {
614+
testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values()));
615+
}
616+
617+
private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException {
618+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
619+
try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
620+
var embeddingSize = randomNonNegativeInt();
621+
var model = HuggingFaceEmbeddingsModelTests.createModel(
622+
randomAlphaOfLength(10),
623+
randomAlphaOfLength(10),
624+
randomNonNegativeInt(),
625+
randomNonNegativeInt(),
626+
similarityMeasure
627+
);
628+
629+
Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
630+
631+
SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? SimilarityMeasure.COSINE : similarityMeasure;
632+
assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
633+
assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
634+
}
635+
}
636+
598637
public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() throws IOException {
599638
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
600639

0 commit comments

Comments
 (0)