Skip to content

Commit e4a174d

Browse files
Inference endpoint validation for OpenAIService (elastic#113137) (elastic#113546)
* Adding service integration and model validators for inference services. * Adding ModelValidators to OpenAiService * Cleaning up tests * Cleaning up tests * Adding mock response for completion model OpenAIService integration tests --------- Co-authored-by: Elastic Machine <[email protected]>
1 parent fd77531 commit e4a174d

File tree

20 files changed

+733
-41
lines changed

20 files changed

+733
-41
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,17 @@ default void checkModelConfig(Model model, ActionListener<Model> listener) {
175175
listener.onResponse(model);
176176
};
177177

178+
/**
179+
* Update a text embedding model's dimensions based on a provided embedding
180+
* size and set the default similarity if required. The default behaviour is to just return the model.
181+
* @param model The original model without updated embedding details
182+
* @param embeddingSize The embedding size to update the model with
183+
* @return The model with updated embedding details
184+
*/
185+
default Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
186+
return model;
187+
}
188+
178189
/**
179190
* Return true if this model is hosted in the local Elasticsearch cluster
180191
* @return True if in cluster

server/src/main/java/org/elasticsearch/inference/ServiceSettings.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ default Integer dimensions() {
3434
return null;
3535
}
3636

37+
/**
38+
* Boolean signifying whether the dimensions were set by the user
39+
*
40+
* @return boolean signifying whether the dimensions were set by the user
41+
*/
42+
default Boolean dimensionsSetByUser() {
43+
return null;
44+
}
45+
3746
/**
3847
* The data type for the embeddings this service works with. Defaults to null,
3948
* Text Embedding models should return a non-null value

x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/OpenAIServiceMixedIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ public void testOpenAiCompletions() throws IOException {
9595
final String inferenceId = "mixed-cluster-completions";
9696
final String upgradedClusterId = "upgraded-cluster-completions";
9797

98+
// queue a response as PUT will call the service
99+
openAiChatCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionsResponse()));
98100
put(inferenceId, chatCompletionsConfig(getUrl(openAiChatCompletionsServer)), TaskType.COMPLETION);
99101

100102
var configsMap = get(TaskType.COMPLETION, inferenceId);

x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/OpenAiServiceUpgradeIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ public void testOpenAiCompletions() throws IOException {
128128
var testTaskType = TaskType.COMPLETION;
129129

130130
if (isOldCluster()) {
131+
openAiChatCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionsResponse()));
131132
put(oldClusterId, chatCompletionsConfig(getUrl(openAiChatCompletionsServer)), testTaskType);
132133

133134
var configs = (List<Map<String, Object>>) get(testTaskType, oldClusterId).get(old_cluster_endpoint_identifier);
@@ -157,6 +158,7 @@ public void testOpenAiCompletions() throws IOException {
157158

158159
assertCompletionInference(oldClusterId);
159160

161+
openAiChatCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionsResponse()));
160162
put(upgradedClusterId, chatCompletionsConfig(getUrl(openAiChatCompletionsServer)), testTaskType);
161163
configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterId).get("endpoints");
162164
assertThat(configs, hasSize(1));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettings.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ public Integer dimensions() {
183183
return dimensions;
184184
}
185185

186-
public boolean dimensionsSetByUser() {
186+
@Override
187+
public Boolean dimensionsSetByUser() {
187188
return this.dimensionsSetByUser;
188189
}
189190

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettings.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ public SimilarityMeasure similarity() {
146146
return similarity;
147147
}
148148

149-
public boolean dimensionsSetByUser() {
149+
@Override
150+
public Boolean dimensionsSetByUser() {
150151
return this.dimensionsSetByUser;
151152
}
152153

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java

Lines changed: 24 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14-
import org.elasticsearch.common.Strings;
1514
import org.elasticsearch.core.Nullable;
15+
import org.elasticsearch.core.Strings;
1616
import org.elasticsearch.core.TimeValue;
1717
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
1818
import org.elasticsearch.inference.ChunkingOptions;
@@ -31,10 +31,10 @@
3131
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
3232
import org.elasticsearch.xpack.inference.services.SenderService;
3333
import org.elasticsearch.xpack.inference.services.ServiceComponents;
34-
import org.elasticsearch.xpack.inference.services.ServiceUtils;
3534
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel;
3635
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel;
3736
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
37+
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
3838

3939
import java.util.List;
4040
import java.util.Map;
@@ -255,48 +255,35 @@ protected void doChunkedInfer(
255255
*/
256256
@Override
257257
public void checkModelConfig(Model model, ActionListener<Model> listener) {
258+
// TODO: Remove this function once all services have been updated to use the new model validators
259+
ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
260+
}
261+
262+
@Override
263+
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
258264
if (model instanceof OpenAiEmbeddingsModel embeddingsModel) {
259-
ServiceUtils.getEmbeddingSize(
260-
model,
261-
this,
262-
listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size)))
265+
var serviceSettings = embeddingsModel.getServiceSettings();
266+
var similarityFromModel = serviceSettings.similarity();
267+
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
268+
269+
var updatedServiceSettings = new OpenAiEmbeddingsServiceSettings(
270+
serviceSettings.modelId(),
271+
serviceSettings.uri(),
272+
serviceSettings.organizationId(),
273+
similarityToUse,
274+
embeddingSize,
275+
serviceSettings.maxInputTokens(),
276+
serviceSettings.dimensionsSetByUser(),
277+
serviceSettings.rateLimitSettings()
263278
);
264-
} else {
265-
listener.onResponse(model);
266-
}
267-
}
268279

269-
private OpenAiEmbeddingsModel updateModelWithEmbeddingDetails(OpenAiEmbeddingsModel model, int embeddingSize) {
270-
if (model.getServiceSettings().dimensionsSetByUser()
271-
&& model.getServiceSettings().dimensions() != null
272-
&& model.getServiceSettings().dimensions() != embeddingSize) {
280+
return new OpenAiEmbeddingsModel(embeddingsModel, updatedServiceSettings);
281+
} else {
273282
throw new ElasticsearchStatusException(
274-
Strings.format(
275-
"The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. "
276-
+ "Please recreate the [%s] configuration with the correct dimensions",
277-
embeddingSize,
278-
model.getServiceSettings().dimensions(),
279-
model.getConfigurations().getInferenceEntityId()
280-
),
283+
Strings.format("Can't update embedding details for model with unexpected type %s", model.getClass()),
281284
RestStatus.BAD_REQUEST
282285
);
283286
}
284-
285-
var similarityFromModel = model.getServiceSettings().similarity();
286-
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
287-
288-
OpenAiEmbeddingsServiceSettings serviceSettings = new OpenAiEmbeddingsServiceSettings(
289-
model.getServiceSettings().modelId(),
290-
model.getServiceSettings().uri(),
291-
model.getServiceSettings().organizationId(),
292-
similarityToUse,
293-
embeddingSize,
294-
model.getServiceSettings().maxInputTokens(),
295-
model.getServiceSettings().dimensionsSetByUser(),
296-
model.getServiceSettings().rateLimitSettings()
297-
);
298-
299-
return new OpenAiEmbeddingsModel(model, serviceSettings);
300287
}
301288

302289
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ public Integer dimensions() {
247247
return dimensions;
248248
}
249249

250+
@Override
250251
public Boolean dimensionsSetByUser() {
251252
return dimensionsSetByUser;
252253
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.validation;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.inference.InferenceService;
12+
import org.elasticsearch.inference.Model;
13+
14+
public interface ModelValidator {
15+
void validate(InferenceService service, Model model, ActionListener<Model> listener);
16+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.validation;
9+
10+
import org.elasticsearch.core.Strings;
11+
import org.elasticsearch.inference.TaskType;
12+
13+
public class ModelValidatorBuilder {
14+
public static ModelValidator buildModelValidator(TaskType taskType) {
15+
if (taskType == null) {
16+
throw new IllegalArgumentException("Task type can't be null");
17+
}
18+
19+
switch (taskType) {
20+
case TEXT_EMBEDDING -> {
21+
return new TextEmbeddingModelValidator(new SimpleServiceIntegrationValidator());
22+
}
23+
case SPARSE_EMBEDDING, RERANK, COMPLETION, ANY -> {
24+
return new SimpleModelValidator(new SimpleServiceIntegrationValidator());
25+
}
26+
default -> throw new IllegalArgumentException(Strings.format("Can't validate inference model of for task type %s ", taskType));
27+
}
28+
}
29+
}

0 commit comments

Comments
 (0)