Skip to content

Commit 9981a1d

Browse files
Adding endpoint creation validation for all task types to remaining services (#115020)
* Adding endpoint creation validation for all task types to remaining services * Update Cohere IT tests for rerank validation * Adding missing import * Update docs/changelog/115020.yaml * Fixing GoogleVertex tests after merge from upstream --------- Co-authored-by: Elastic Machine <[email protected]>
1 parent fa9f2bf commit 9981a1d

File tree

17 files changed

+473
-218
lines changed

17 files changed

+473
-218
lines changed

docs/changelog/115020.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 115020
2+
summary: Adding endpoint creation validation for all task types to remaining services
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/CohereServiceMixedIT.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ public void testRerank() throws IOException {
135135

136136
final String inferenceId = "mixed-cluster-rerank";
137137

138+
cohereRerankServer.enqueue(new MockResponse().setResponseCode(200).setBody(rerankResponse()));
138139
put(inferenceId, rerankConfig(getUrl(cohereRerankServer)), TaskType.RERANK);
139140
assertRerank(inferenceId);
140141

x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ public void testRerank() throws IOException {
201201
var testTaskType = TaskType.RERANK;
202202

203203
if (isOldCluster()) {
204+
cohereRerankServer.enqueue(new MockResponse().setResponseCode(200).setBody(rerankResponse()));
204205
put(oldClusterId, rerankConfig(getUrl(cohereRerankServer)), testTaskType);
205206
var configs = (List<Map<String, Object>>) get(testTaskType, oldClusterId).get(old_cluster_endpoint_identifier);
206207
assertThat(configs, hasSize(1));
@@ -229,6 +230,7 @@ public void testRerank() throws IOException {
229230
assertRerank(oldClusterId);
230231

231232
// New endpoint
233+
cohereRerankServer.enqueue(new MockResponse().setResponseCode(200).setBody(rerankResponse()));
232234
put(upgradedClusterId, rerankConfig(getUrl(cohereRerankServer)), testTaskType);
233235
configs = (List<Map<String, Object>>) get(upgradedClusterId).get("endpoints");
234236
assertThat(configs, hasSize(1));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

Lines changed: 23 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import org.elasticsearch.inference.ChunkingOptions;
1919
import org.elasticsearch.inference.ChunkingSettings;
2020
import org.elasticsearch.inference.EmptySettingsConfiguration;
21-
import org.elasticsearch.inference.InferenceService;
2221
import org.elasticsearch.inference.InferenceServiceConfiguration;
2322
import org.elasticsearch.inference.InferenceServiceResults;
2423
import org.elasticsearch.inference.InputType;
@@ -51,6 +50,7 @@
5150
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
5251
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
5352
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
53+
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
5454

5555
import java.util.EnumSet;
5656
import java.util.HashMap;
@@ -60,7 +60,6 @@
6060

6161
import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING;
6262
import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
63-
import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.DEFAULT_TIMEOUT;
6463
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
6564
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
6665
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
@@ -332,68 +331,39 @@ private EmbeddingRequestChunker.EmbeddingType getEmbeddingTypeFromTaskType(TaskT
332331
*/
333332
@Override
334333
public void checkModelConfig(Model model, ActionListener<Model> listener) {
334+
// TODO: Remove this function once all services have been updated to use the new model validators
335+
ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
336+
}
337+
338+
@Override
339+
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
335340
if (model instanceof AlibabaCloudSearchEmbeddingsModel embeddingsModel) {
336-
ServiceUtils.getEmbeddingSize(
337-
model,
338-
this,
339-
listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size)))
341+
var serviceSettings = embeddingsModel.getServiceSettings();
342+
343+
var updatedServiceSettings = new AlibabaCloudSearchEmbeddingsServiceSettings(
344+
new AlibabaCloudSearchServiceSettings(
345+
serviceSettings.getCommonSettings().modelId(),
346+
serviceSettings.getCommonSettings().getHost(),
347+
serviceSettings.getCommonSettings().getWorkspaceName(),
348+
serviceSettings.getCommonSettings().getHttpSchema(),
349+
serviceSettings.getCommonSettings().rateLimitSettings()
350+
),
351+
SimilarityMeasure.DOT_PRODUCT,
352+
embeddingSize,
353+
serviceSettings.getMaxInputTokens()
340354
);
355+
356+
return new AlibabaCloudSearchEmbeddingsModel(embeddingsModel, updatedServiceSettings);
341357
} else {
342-
checkAlibabaCloudSearchServiceConfig(model, this, listener);
358+
throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
343359
}
344360
}
345361

346-
private AlibabaCloudSearchEmbeddingsModel updateModelWithEmbeddingDetails(AlibabaCloudSearchEmbeddingsModel model, int embeddingSize) {
347-
AlibabaCloudSearchEmbeddingsServiceSettings serviceSettings = new AlibabaCloudSearchEmbeddingsServiceSettings(
348-
new AlibabaCloudSearchServiceSettings(
349-
model.getServiceSettings().getCommonSettings().modelId(),
350-
model.getServiceSettings().getCommonSettings().getHost(),
351-
model.getServiceSettings().getCommonSettings().getWorkspaceName(),
352-
model.getServiceSettings().getCommonSettings().getHttpSchema(),
353-
model.getServiceSettings().getCommonSettings().rateLimitSettings()
354-
),
355-
SimilarityMeasure.DOT_PRODUCT,
356-
embeddingSize,
357-
model.getServiceSettings().getMaxInputTokens()
358-
);
359-
360-
return new AlibabaCloudSearchEmbeddingsModel(model, serviceSettings);
361-
}
362-
363362
@Override
364363
public TransportVersion getMinimalSupportedVersion() {
365364
return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED;
366365
}
367366

368-
/**
369-
* For other models except of text embedding
370-
* check the model's service settings and task settings
371-
*
372-
* @param model The new model
373-
* @param service The inferenceService
374-
* @param listener The listener
375-
*/
376-
private void checkAlibabaCloudSearchServiceConfig(Model model, InferenceService service, ActionListener<Model> listener) {
377-
String input = ALIBABA_CLOUD_SEARCH_SERVICE_CONFIG_INPUT;
378-
String query = model.getTaskType().equals(TaskType.RERANK) ? ALIBABA_CLOUD_SEARCH_SERVICE_CONFIG_QUERY : null;
379-
380-
service.infer(
381-
model,
382-
query,
383-
List.of(input),
384-
false,
385-
Map.of(),
386-
InputType.INGEST,
387-
DEFAULT_TIMEOUT,
388-
listener.delegateFailureAndWrap((delegate, r) -> {
389-
listener.onResponse(model);
390-
})
391-
);
392-
}
393-
394-
private static final String ALIBABA_CLOUD_SEARCH_SERVICE_CONFIG_INPUT = "input";
395-
private static final String ALIBABA_CLOUD_SEARCH_SERVICE_CONFIG_QUERY = "query";
396-
397367
public static class Configuration {
398368
public static InferenceServiceConfiguration get() {
399369
return configuration.getOrCompute();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 25 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel;
5050
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings;
5151
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
52+
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
5253

5354
import java.io.IOException;
5455
import java.util.EnumSet;
@@ -303,49 +304,34 @@ public Set<TaskType> supportedStreamingTasks() {
303304
*/
304305
@Override
305306
public void checkModelConfig(Model model, ActionListener<Model> listener) {
306-
if (model instanceof AmazonBedrockEmbeddingsModel embeddingsModel) {
307-
ServiceUtils.getEmbeddingSize(
308-
model,
309-
this,
310-
listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size)))
311-
);
312-
} else {
313-
listener.onResponse(model);
314-
}
307+
// TODO: Remove this function once all services have been updated to use the new model validators
308+
ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
315309
}
316310

317-
private AmazonBedrockEmbeddingsModel updateModelWithEmbeddingDetails(AmazonBedrockEmbeddingsModel model, int embeddingSize) {
318-
AmazonBedrockEmbeddingsServiceSettings serviceSettings = model.getServiceSettings();
319-
if (serviceSettings.dimensionsSetByUser()
320-
&& serviceSettings.dimensions() != null
321-
&& serviceSettings.dimensions() != embeddingSize) {
322-
throw new ElasticsearchStatusException(
323-
Strings.format(
324-
"The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. "
325-
+ "Please recreate the [%s] configuration with the correct dimensions",
326-
embeddingSize,
327-
serviceSettings.dimensions(),
328-
model.getConfigurations().getInferenceEntityId()
329-
),
330-
RestStatus.BAD_REQUEST
311+
@Override
312+
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
313+
if (model instanceof AmazonBedrockEmbeddingsModel embeddingsModel) {
314+
var serviceSettings = embeddingsModel.getServiceSettings();
315+
var similarityFromModel = serviceSettings.similarity();
316+
var similarityToUse = similarityFromModel == null
317+
? getProviderDefaultSimilarityMeasure(embeddingsModel.provider())
318+
: similarityFromModel;
319+
320+
var updatedServiceSettings = new AmazonBedrockEmbeddingsServiceSettings(
321+
serviceSettings.region(),
322+
serviceSettings.modelId(),
323+
serviceSettings.provider(),
324+
embeddingSize,
325+
serviceSettings.dimensionsSetByUser(),
326+
serviceSettings.maxInputTokens(),
327+
similarityToUse,
328+
serviceSettings.rateLimitSettings()
331329
);
332-
}
333-
334-
var similarityFromModel = serviceSettings.similarity();
335-
var similarityToUse = similarityFromModel == null ? getProviderDefaultSimilarityMeasure(model.provider()) : similarityFromModel;
336-
337-
AmazonBedrockEmbeddingsServiceSettings settingsToUse = new AmazonBedrockEmbeddingsServiceSettings(
338-
serviceSettings.region(),
339-
serviceSettings.modelId(),
340-
serviceSettings.provider(),
341-
embeddingSize,
342-
serviceSettings.dimensionsSetByUser(),
343-
serviceSettings.maxInputTokens(),
344-
similarityToUse,
345-
serviceSettings.rateLimitSettings()
346-
);
347330

348-
return new AmazonBedrockEmbeddingsModel(model, settingsToUse);
331+
return new AmazonBedrockEmbeddingsModel(embeddingsModel, updatedServiceSettings);
332+
} else {
333+
throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
334+
}
349335
}
350336

351337
private static void checkProviderForTask(TaskType taskType, AmazonBedrockProvider provider) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModel;
4040
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
4141
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
42+
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
4243

4344
import java.util.EnumSet;
4445
import java.util.HashMap;
@@ -176,6 +177,12 @@ public AnthropicModel parsePersistedConfig(String inferenceEntityId, TaskType ta
176177
);
177178
}
178179

180+
@Override
181+
public void checkModelConfig(Model model, ActionListener<Model> listener) {
182+
// TODO: Remove this function once all services have been updated to use the new model validators
183+
ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
184+
}
185+
179186
@Override
180187
public InferenceServiceConfiguration getConfiguration() {
181188
return Configuration.get();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java

Lines changed: 23 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14-
import org.elasticsearch.common.Strings;
1514
import org.elasticsearch.common.util.LazyInitializable;
1615
import org.elasticsearch.core.Nullable;
1716
import org.elasticsearch.core.TimeValue;
@@ -46,6 +45,7 @@
4645
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
4746
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings;
4847
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
48+
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
4949

5050
import java.util.EnumSet;
5151
import java.util.HashMap;
@@ -294,48 +294,32 @@ protected void doChunkedInfer(
294294
*/
295295
@Override
296296
public void checkModelConfig(Model model, ActionListener<Model> listener) {
297-
if (model instanceof AzureOpenAiEmbeddingsModel embeddingsModel) {
298-
ServiceUtils.getEmbeddingSize(
299-
model,
300-
this,
301-
listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size)))
302-
);
303-
} else {
304-
listener.onResponse(model);
305-
}
297+
// TODO: Remove this function once all services have been updated to use the new model validators
298+
ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
306299
}
307300

308-
private AzureOpenAiEmbeddingsModel updateModelWithEmbeddingDetails(AzureOpenAiEmbeddingsModel model, int embeddingSize) {
309-
if (model.getServiceSettings().dimensionsSetByUser()
310-
&& model.getServiceSettings().dimensions() != null
311-
&& model.getServiceSettings().dimensions() != embeddingSize) {
312-
throw new ElasticsearchStatusException(
313-
Strings.format(
314-
"The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. "
315-
+ "Please recreate the [%s] configuration with the correct dimensions",
316-
embeddingSize,
317-
model.getServiceSettings().dimensions(),
318-
model.getConfigurations().getInferenceEntityId()
319-
),
320-
RestStatus.BAD_REQUEST
301+
@Override
302+
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
303+
if (model instanceof AzureOpenAiEmbeddingsModel embeddingsModel) {
304+
var serviceSettings = embeddingsModel.getServiceSettings();
305+
var similarityFromModel = serviceSettings.similarity();
306+
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
307+
308+
var updatedServiceSettings = new AzureOpenAiEmbeddingsServiceSettings(
309+
serviceSettings.resourceName(),
310+
serviceSettings.deploymentId(),
311+
serviceSettings.apiVersion(),
312+
embeddingSize,
313+
serviceSettings.dimensionsSetByUser(),
314+
serviceSettings.maxInputTokens(),
315+
similarityToUse,
316+
serviceSettings.rateLimitSettings()
321317
);
322-
}
323-
324-
var similarityFromModel = model.getServiceSettings().similarity();
325-
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
326-
327-
AzureOpenAiEmbeddingsServiceSettings serviceSettings = new AzureOpenAiEmbeddingsServiceSettings(
328-
model.getServiceSettings().resourceName(),
329-
model.getServiceSettings().deploymentId(),
330-
model.getServiceSettings().apiVersion(),
331-
embeddingSize,
332-
model.getServiceSettings().dimensionsSetByUser(),
333-
model.getServiceSettings().maxInputTokens(),
334-
similarityToUse,
335-
model.getServiceSettings().rateLimitSettings()
336-
);
337318

338-
return new AzureOpenAiEmbeddingsModel(model, serviceSettings);
319+
return new AzureOpenAiEmbeddingsModel(embeddingsModel, updatedServiceSettings);
320+
} else {
321+
throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
322+
}
339323
}
340324

341325
@Override

0 commit comments

Comments
 (0)