Skip to content

Commit 58a53f8

Browse files
Adding inference endpoint validation for AzureAiStudioService (#113713)
* Adding inference endpoint validation for AzureAiStudioService * Run spotlessApple * Update docs/changelog/113713.yaml * Remove isInClusterService from InferenceService * Run spotless apply --------- Co-authored-by: Elastic Machine <[email protected]>
1 parent e09d406 commit 58a53f8

File tree

9 files changed

+298
-50
lines changed

9 files changed

+298
-50
lines changed

docs/changelog/113713.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 113713
2+
summary: Adding inference endpoint validation for `AzureAiStudioService`
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,15 @@ default Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
178178
return model;
179179
}
180180

181+
/**
182+
* Update a chat completion model's max tokens if required. The default behaviour is to just return the model.
183+
* @param model The original model without updated embedding details
184+
* @return The model with updated chat completion details
185+
*/
186+
default Model updateModelWithChatCompletionDetails(Model model) {
187+
return model;
188+
}
189+
181190
/**
182191
* Defines the version required across all clusters to use this service
183192
* @return {@link TransportVersion} specifying the version

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,15 @@ public static ElasticsearchStatusException invalidModelTypeForUpdateModelWithEmb
209209
);
210210
}
211211

212+
public static ElasticsearchStatusException invalidModelTypeForUpdateModelWithChatCompletionDetails(
213+
Class<? extends Model> invalidModelType
214+
) {
215+
throw new ElasticsearchStatusException(
216+
Strings.format("Can't update chat completion details for model with unexpected type %s", invalidModelType),
217+
RestStatus.BAD_REQUEST
218+
);
219+
}
220+
212221
public static String missingSettingErrorMsg(String settingName, String scope) {
213222
return Strings.format("[%s] does not contain the required setting [%s]", scope, settingName);
214223
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java

Lines changed: 39 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
5050
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
5151
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
52+
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
5253

5354
import java.util.EnumSet;
5455
import java.util.HashMap;
@@ -315,62 +316,52 @@ private AzureAiStudioModel createModelFromPersistent(
315316

316317
@Override
317318
public void checkModelConfig(Model model, ActionListener<Model> listener) {
319+
// TODO: Remove this function once all services have been updated to use the new model validators
320+
ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
321+
}
322+
323+
@Override
324+
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
318325
if (model instanceof AzureAiStudioEmbeddingsModel embeddingsModel) {
319-
ServiceUtils.getEmbeddingSize(
320-
model,
321-
this,
322-
listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateEmbeddingModelConfig(embeddingsModel, size)))
326+
var serviceSettings = embeddingsModel.getServiceSettings();
327+
var similarityFromModel = serviceSettings.similarity();
328+
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
329+
330+
var updatedServiceSettings = new AzureAiStudioEmbeddingsServiceSettings(
331+
serviceSettings.target(),
332+
serviceSettings.provider(),
333+
serviceSettings.endpointType(),
334+
embeddingSize,
335+
serviceSettings.dimensionsSetByUser(),
336+
serviceSettings.maxInputTokens(),
337+
similarityToUse,
338+
serviceSettings.rateLimitSettings()
323339
);
324-
} else if (model instanceof AzureAiStudioChatCompletionModel chatCompletionModel) {
325-
listener.onResponse(updateChatCompletionModelConfig(chatCompletionModel));
340+
341+
return new AzureAiStudioEmbeddingsModel(embeddingsModel, updatedServiceSettings);
326342
} else {
327-
listener.onResponse(model);
343+
throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
328344
}
329345
}
330346

331-
private AzureAiStudioEmbeddingsModel updateEmbeddingModelConfig(AzureAiStudioEmbeddingsModel embeddingsModel, int embeddingsSize) {
332-
if (embeddingsModel.getServiceSettings().dimensionsSetByUser()
333-
&& embeddingsModel.getServiceSettings().dimensions() != null
334-
&& embeddingsModel.getServiceSettings().dimensions() != embeddingsSize) {
335-
throw new ElasticsearchStatusException(
336-
Strings.format(
337-
"The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. "
338-
+ "Please recreate the [%s] configuration with the correct dimensions",
339-
embeddingsSize,
340-
embeddingsModel.getServiceSettings().dimensions(),
341-
embeddingsModel.getConfigurations().getInferenceEntityId()
342-
),
343-
RestStatus.BAD_REQUEST
347+
@Override
348+
public Model updateModelWithChatCompletionDetails(Model model) {
349+
if (model instanceof AzureAiStudioChatCompletionModel chatCompletionModel) {
350+
var taskSettings = chatCompletionModel.getTaskSettings();
351+
var modelMaxNewTokens = taskSettings.maxNewTokens();
352+
var maxNewTokensToUse = modelMaxNewTokens == null ? DEFAULT_MAX_NEW_TOKENS : modelMaxNewTokens;
353+
354+
var updatedTaskSettings = new AzureAiStudioChatCompletionTaskSettings(
355+
taskSettings.temperature(),
356+
taskSettings.topP(),
357+
taskSettings.doSample(),
358+
maxNewTokensToUse
344359
);
345-
}
346-
347-
var similarityFromModel = embeddingsModel.getServiceSettings().similarity();
348-
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
349-
350-
AzureAiStudioEmbeddingsServiceSettings serviceSettings = new AzureAiStudioEmbeddingsServiceSettings(
351-
embeddingsModel.getServiceSettings().target(),
352-
embeddingsModel.getServiceSettings().provider(),
353-
embeddingsModel.getServiceSettings().endpointType(),
354-
embeddingsSize,
355-
embeddingsModel.getServiceSettings().dimensionsSetByUser(),
356-
embeddingsModel.getServiceSettings().maxInputTokens(),
357-
similarityToUse,
358-
embeddingsModel.getServiceSettings().rateLimitSettings()
359-
);
360-
361-
return new AzureAiStudioEmbeddingsModel(embeddingsModel, serviceSettings);
362-
}
363360

364-
private AzureAiStudioChatCompletionModel updateChatCompletionModelConfig(AzureAiStudioChatCompletionModel chatCompletionModel) {
365-
var modelMaxNewTokens = chatCompletionModel.getTaskSettings().maxNewTokens();
366-
var maxNewTokensToUse = modelMaxNewTokens == null ? DEFAULT_MAX_NEW_TOKENS : modelMaxNewTokens;
367-
var updatedTaskSettings = new AzureAiStudioChatCompletionTaskSettings(
368-
chatCompletionModel.getTaskSettings().temperature(),
369-
chatCompletionModel.getTaskSettings().topP(),
370-
chatCompletionModel.getTaskSettings().doSample(),
371-
maxNewTokensToUse
372-
);
373-
return new AzureAiStudioChatCompletionModel(chatCompletionModel, updatedTaskSettings);
361+
return new AzureAiStudioChatCompletionModel(chatCompletionModel, updatedTaskSettings);
362+
} else {
363+
throw ServiceUtils.invalidModelTypeForUpdateModelWithChatCompletionDetails(model.getClass());
364+
}
374365
}
375366

376367
private static void checkProviderAndEndpointTypeForTask(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.validation;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.inference.InferenceService;
12+
import org.elasticsearch.inference.Model;
13+
14+
public class ChatCompletionModelValidator implements ModelValidator {
15+
16+
private final ServiceIntegrationValidator serviceIntegrationValidator;
17+
18+
public ChatCompletionModelValidator(ServiceIntegrationValidator serviceIntegrationValidator) {
19+
this.serviceIntegrationValidator = serviceIntegrationValidator;
20+
}
21+
22+
@Override
23+
public void validate(InferenceService service, Model model, ActionListener<Model> listener) {
24+
serviceIntegrationValidator.validate(service, model, listener.delegateFailureAndWrap((delegate, r) -> {
25+
delegate.onResponse(postValidate(service, model));
26+
}));
27+
}
28+
29+
private Model postValidate(InferenceService service, Model model) {
30+
return service.updateModelWithChatCompletionDetails(model);
31+
}
32+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ public static ModelValidator buildModelValidator(TaskType taskType) {
2020
case TEXT_EMBEDDING -> {
2121
return new TextEmbeddingModelValidator(new SimpleServiceIntegrationValidator());
2222
}
23-
case SPARSE_EMBEDDING, RERANK, COMPLETION, ANY -> {
23+
case COMPLETION -> {
24+
return new ChatCompletionModelValidator(new SimpleServiceIntegrationValidator());
25+
}
26+
case SPARSE_EMBEDDING, RERANK, ANY -> {
2427
return new SimpleModelValidator(new SimpleServiceIntegrationValidator());
2528
}
2629
default -> throw new IllegalArgumentException(Strings.format("Can't validate inference model of for task type %s ", taskType));

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModelTests;
5454
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettingsTests;
5555
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettingsTests;
56+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
5657
import org.hamcrest.CoreMatchers;
5758
import org.hamcrest.MatcherAssert;
5859
import org.hamcrest.Matchers;
@@ -973,6 +974,112 @@ public void testCheckModelConfig_WorksForChatCompletionsModel() throws IOExcepti
973974
}
974975
}
975976

977+
public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
978+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
979+
try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
980+
var model = AzureAiStudioChatCompletionModelTests.createModel(
981+
randomAlphaOfLength(10),
982+
randomAlphaOfLength(10),
983+
randomFrom(AzureAiStudioProvider.values()),
984+
randomFrom(AzureAiStudioEndpointType.values()),
985+
randomAlphaOfLength(10)
986+
);
987+
assertThrows(
988+
ElasticsearchStatusException.class,
989+
() -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); }
990+
);
991+
}
992+
}
993+
994+
public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException {
995+
testUpdateModelWithEmbeddingDetails_Successful(null);
996+
}
997+
998+
public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException {
999+
testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values()));
1000+
}
1001+
1002+
private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException {
1003+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
1004+
try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
1005+
var embeddingSize = randomNonNegativeInt();
1006+
var model = AzureAiStudioEmbeddingsModelTests.createModel(
1007+
randomAlphaOfLength(10),
1008+
randomAlphaOfLength(10),
1009+
randomFrom(AzureAiStudioProvider.values()),
1010+
randomFrom(AzureAiStudioEndpointType.values()),
1011+
randomAlphaOfLength(10),
1012+
randomNonNegativeInt(),
1013+
randomBoolean(),
1014+
randomNonNegativeInt(),
1015+
similarityMeasure,
1016+
randomAlphaOfLength(10),
1017+
RateLimitSettingsTests.createRandom()
1018+
);
1019+
1020+
Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
1021+
1022+
SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? SimilarityMeasure.DOT_PRODUCT : similarityMeasure;
1023+
assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
1024+
assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
1025+
}
1026+
}
1027+
1028+
public void testUpdateModelWithChatCompletionDetails_InvalidModelProvided() throws IOException {
1029+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
1030+
try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
1031+
var model = AzureAiStudioEmbeddingsModelTests.createModel(
1032+
randomAlphaOfLength(10),
1033+
randomAlphaOfLength(10),
1034+
randomFrom(AzureAiStudioProvider.values()),
1035+
randomFrom(AzureAiStudioEndpointType.values()),
1036+
randomAlphaOfLength(10),
1037+
randomNonNegativeInt(),
1038+
randomBoolean(),
1039+
randomNonNegativeInt(),
1040+
randomFrom(SimilarityMeasure.values()),
1041+
randomAlphaOfLength(10),
1042+
RateLimitSettingsTests.createRandom()
1043+
);
1044+
assertThrows(ElasticsearchStatusException.class, () -> { service.updateModelWithChatCompletionDetails(model); });
1045+
}
1046+
}
1047+
1048+
public void testUpdateModelWithChatCompletionDetails_NullSimilarityInOriginalModel() throws IOException {
1049+
testUpdateModelWithChatCompletionDetails_Successful(null);
1050+
}
1051+
1052+
public void testUpdateModelWithChatCompletionDetails_NonNullSimilarityInOriginalModel() throws IOException {
1053+
testUpdateModelWithChatCompletionDetails_Successful(randomNonNegativeInt());
1054+
}
1055+
1056+
private void testUpdateModelWithChatCompletionDetails_Successful(Integer maxNewTokens) throws IOException {
1057+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
1058+
try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
1059+
var model = AzureAiStudioChatCompletionModelTests.createModel(
1060+
randomAlphaOfLength(10),
1061+
randomAlphaOfLength(10),
1062+
randomFrom(AzureAiStudioProvider.values()),
1063+
randomFrom(AzureAiStudioEndpointType.values()),
1064+
randomAlphaOfLength(10),
1065+
randomDouble(),
1066+
randomDouble(),
1067+
randomBoolean(),
1068+
maxNewTokens,
1069+
RateLimitSettingsTests.createRandom()
1070+
);
1071+
1072+
Model updatedModel = service.updateModelWithChatCompletionDetails(model);
1073+
assertThat(updatedModel, instanceOf(AzureAiStudioChatCompletionModel.class));
1074+
AzureAiStudioChatCompletionTaskSettings updatedTaskSettings = (AzureAiStudioChatCompletionTaskSettings) updatedModel
1075+
.getTaskSettings();
1076+
Integer expectedMaxNewTokens = maxNewTokens == null
1077+
? AzureAiStudioChatCompletionTaskSettings.DEFAULT_MAX_NEW_TOKENS
1078+
: maxNewTokens;
1079+
assertEquals(expectedMaxNewTokens, updatedTaskSettings.maxNewTokens());
1080+
}
1081+
}
1082+
9761083
public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws IOException {
9771084
var sender = mock(Sender.class);
9781085

0 commit comments

Comments
 (0)