[ML] Refactor inference API service tests base classes #135461
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1067,6 +1067,10 @@ public interface EnumConstructor<E extends Enum<E>> { | |
| E apply(String name) throws IllegalArgumentException; | ||
| } | ||
|
|
||
| /** | ||
| * @deprecated use {@link #parsePersistedConfigErrorMsg(String, String, TaskType)} instead | ||
| */ | ||
| @Deprecated | ||
| public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName) { | ||
| return format( | ||
| "Failed to parse stored model [%s] for [%s] service, please delete and add the service again", | ||
|
|
@@ -1075,6 +1079,15 @@ public static String parsePersistedConfigErrorMsg(String inferenceEntityId, Stri | |
| ); | ||
| } | ||
|
|
||
| public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName, TaskType taskType) { | ||
| return format( | ||
| "Failed to parse stored model [%s] for [%s] service, error: [%s]. Please delete and add the service again", | ||
|
Would deleting and adding the service again actually help if the task type was unsupported?

Yeah, I think deleting is probably the only solution here. I think this would only occur if the inference endpoint got corrupted somehow. Basically, this is saying that the persisted inference endpoint claims to use a particular task type that is not supported. The request context parsing should prevent getting into that scenario, but if we had a regression or the endpoint was corrupted somehow we could.
||
| inferenceEntityId, | ||
| serviceName, | ||
| TaskType.unsupportedTaskTypeErrorMsg(taskType, serviceName) | ||
| ); | ||
| } | ||
|
|
||
| public static ElasticsearchStatusException createInvalidModelException(Model model) { | ||
| return new ElasticsearchStatusException( | ||
| format( | ||
|
|
||
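For context, a minimal sketch of the difference a caller would see between the deprecated two-argument overload and the new three-argument one, using the format strings from the diff above. The endpoint ID, service name, and task type below are made-up example values (assuming static imports of the methods), and the bracketed inner text is whatever `TaskType.unsupportedTaskTypeErrorMsg` produces:

```java
// Illustrative only; "my-endpoint" and "watsonxai" are hypothetical values.
String oldMsg = parsePersistedConfigErrorMsg("my-endpoint", "watsonxai");
// "Failed to parse stored model [my-endpoint] for [watsonxai] service,
//  please delete and add the service again"

String newMsg = parsePersistedConfigErrorMsg("my-endpoint", "watsonxai", TaskType.CHAT_COMPLETION);
// "Failed to parse stored model [my-endpoint] for [watsonxai] service, error:
//  [<TaskType.unsupportedTaskTypeErrorMsg(CHAT_COMPLETION, "watsonxai")>].
//  Please delete and add the service again"
```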
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -213,7 +213,7 @@ public Ai21Model parsePersistedConfigWithSecrets( | |
| taskType, | ||
| serviceSettingsMap, | ||
| secretSettingsMap, | ||
| parsePersistedConfigErrorMsg(modelId, NAME) | ||
| parsePersistedConfigErrorMsg(modelId, NAME, taskType) | ||
|
||
| ); | ||
| } | ||
|
|
||
|
|
@@ -222,7 +222,13 @@ public Ai21Model parsePersistedConfig(String modelId, TaskType taskType, Map<Str | |
| Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); | ||
| removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); | ||
|
|
||
| return createModelFromPersistent(modelId, taskType, serviceSettingsMap, null, parsePersistedConfigErrorMsg(modelId, NAME)); | ||
| return createModelFromPersistent( | ||
| modelId, | ||
| taskType, | ||
| serviceSettingsMap, | ||
| null, | ||
| parsePersistedConfigErrorMsg(modelId, NAME, taskType) | ||
| ); | ||
| } | ||
|
|
||
| @Override | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -105,7 +105,10 @@ public void parseRequestConfig( | |
| Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); | ||
| Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); | ||
|
|
||
| var chunkingSettings = extractChunkingSettings(config, taskType); | ||
| ChunkingSettings chunkingSettings = null; | ||
| if (TaskType.TEXT_EMBEDDING.equals(taskType)) { | ||
| chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); | ||
|
||
| } | ||
|
|
||
| CustomModel model = createModel( | ||
| inferenceEntityId, | ||
|
|
@@ -156,14 +159,6 @@ private static RequestParameters createParameters(CustomModel model) { | |
| }; | ||
| } | ||
|
|
||
| private static ChunkingSettings extractChunkingSettings(Map<String, Object> config, TaskType taskType) { | ||
| if (TaskType.TEXT_EMBEDDING.equals(taskType)) { | ||
| return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); | ||
| } | ||
|
|
||
| return null; | ||
| } | ||
|
Comment on lines -160 to -166

Would it be better to keep using this method in the two places below, where the behaviour is unchanged, rather than duplicating the logic?
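A minimal sketch of that suggestion, assuming the original helper is kept for the two persisted-config paths (where the `removeFromMap` behaviour is unchanged) and only `parseRequestConfig` inlines the new `removeFromMapOrDefaultEmpty`-based variant:

```java
// Sketch only: retain the existing helper rather than duplicating the logic.
private static ChunkingSettings extractChunkingSettings(Map<String, Object> config, TaskType taskType) {
    if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
        return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
    }
    return null;
}

// parsePersistedConfigWithSecrets and parsePersistedConfig then keep calling:
//     var chunkingSettings = extractChunkingSettings(config, taskType);
// while parseRequestConfig inlines its removeFromMapOrDefaultEmpty-based logic.
```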
||
|
|
||
| @Override | ||
| public InferenceServiceConfiguration getConfiguration() { | ||
| return Configuration.get(); | ||
|
|
@@ -229,7 +224,10 @@ public CustomModel parsePersistedConfigWithSecrets( | |
| Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); | ||
| Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); | ||
|
|
||
| var chunkingSettings = extractChunkingSettings(config, taskType); | ||
| ChunkingSettings chunkingSettings = null; | ||
| if (TaskType.TEXT_EMBEDDING.equals(taskType)) { | ||
| chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); | ||
| } | ||
|
|
||
| return createModelWithoutLoggingDeprecations( | ||
| inferenceEntityId, | ||
|
|
@@ -246,7 +244,10 @@ public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskT | |
| Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); | ||
| Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); | ||
|
|
||
| var chunkingSettings = extractChunkingSettings(config, taskType); | ||
| ChunkingSettings chunkingSettings = null; | ||
| if (TaskType.TEXT_EMBEDDING.equals(taskType)) { | ||
| chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); | ||
| } | ||
|
|
||
| return createModelWithoutLoggingDeprecations( | ||
| inferenceEntityId, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,7 +47,7 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject | |
| // The rate limit for usage tier 1 is 500 request per minute for most of the completion models | ||
| // To find this information you need to access your account's limits https://platform.openai.com/account/limits | ||
| // 500 requests per minute | ||
| private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500); | ||
| public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500); | ||
|
Making various things public so they are accessible in tests that are outside of the package.
||
|
|
||
| public static OpenAiChatCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) { | ||
| ValidationException validationException = new ValidationException(); | ||
|
|
||
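As a hypothetical illustration of the out-of-package access this visibility change enables (the test method and assertion below are made up, not part of this PR):

```java
// Hypothetical test method in a class outside the settings package; this only
// compiles now that DEFAULT_RATE_LIMIT_SETTINGS is public.
public void testDefaultRateLimit() {
    assertThat(
        OpenAiChatCompletionServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS,
        equalTo(new RateLimitSettings(500))
    );
}
```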
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,173 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.services; | ||
|
|
||
| import org.elasticsearch.common.settings.Settings; | ||
| import org.elasticsearch.core.Nullable; | ||
| import org.elasticsearch.inference.InferenceService; | ||
| import org.elasticsearch.inference.Model; | ||
| import org.elasticsearch.inference.RerankingInferenceService; | ||
| import org.elasticsearch.inference.SimilarityMeasure; | ||
| import org.elasticsearch.inference.TaskType; | ||
| import org.elasticsearch.test.http.MockWebServer; | ||
| import org.elasticsearch.threadpool.ThreadPool; | ||
| import org.elasticsearch.xpack.inference.external.http.HttpClientManager; | ||
| import org.elasticsearch.xpack.inference.logging.ThrottlerManager; | ||
| import org.junit.After; | ||
| import org.junit.Before; | ||
|
|
||
| import java.util.EnumSet; | ||
| import java.util.Map; | ||
| import java.util.Objects; | ||
|
|
||
| import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; | ||
| import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; | ||
| import static org.mockito.Mockito.mock; | ||
|
|
||
| public abstract class AbstractInferenceServiceBaseTests extends InferenceServiceTestCase { | ||
| protected final TestConfiguration testConfiguration; | ||
|
|
||
| protected final MockWebServer webServer = new MockWebServer(); | ||
| protected ThreadPool threadPool; | ||
| protected HttpClientManager clientManager; | ||
| protected AbstractInferenceServiceParameterizedTests.TestCase testCase; | ||
|
|
||
| @Override | ||
| @Before | ||
| public void setUp() throws Exception { | ||
| super.setUp(); | ||
| webServer.start(); | ||
| threadPool = createThreadPool(inferenceUtilityExecutors()); | ||
| clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); | ||
| } | ||
|
|
||
| @Override | ||
| @After | ||
| public void tearDown() throws Exception { | ||
| super.tearDown(); | ||
| clientManager.close(); | ||
| terminate(threadPool); | ||
| webServer.close(); | ||
| } | ||
|
|
||
| public AbstractInferenceServiceBaseTests(TestConfiguration testConfiguration) { | ||
| this.testConfiguration = Objects.requireNonNull(testConfiguration); | ||
| } | ||
|
|
||
| /** | ||
| * Main configurations for the tests | ||
| */ | ||
| public record TestConfiguration(CommonConfig commonConfig, UpdateModelConfiguration updateModelConfiguration) { | ||
| public static class Builder { | ||
| private final CommonConfig commonConfig; | ||
| private UpdateModelConfiguration updateModelConfiguration = DISABLED_UPDATE_MODEL_TESTS; | ||
|
|
||
| public Builder(CommonConfig commonConfig) { | ||
| this.commonConfig = commonConfig; | ||
| } | ||
|
|
||
| public TestConfiguration.Builder enableUpdateModelTests(UpdateModelConfiguration updateModelConfiguration) { | ||
| this.updateModelConfiguration = updateModelConfiguration; | ||
| return this; | ||
| } | ||
|
|
||
| public TestConfiguration build() { | ||
| return new TestConfiguration(commonConfig, updateModelConfiguration); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Configurations that are useful for most tests | ||
| */ | ||
| public abstract static class CommonConfig { | ||
|
|
||
| private final TaskType targetTaskType; | ||
| private final TaskType unsupportedTaskType; | ||
| private final EnumSet<TaskType> supportedTaskTypes; | ||
|
|
||
| public CommonConfig(TaskType targetTaskType, @Nullable TaskType unsupportedTaskType, EnumSet<TaskType> supportedTaskTypes) { | ||
| this.targetTaskType = Objects.requireNonNull(targetTaskType); | ||
| this.unsupportedTaskType = unsupportedTaskType; | ||
| this.supportedTaskTypes = Objects.requireNonNull(supportedTaskTypes); | ||
| } | ||
|
|
||
| public TaskType targetTaskType() { | ||
| return targetTaskType; | ||
| } | ||
|
|
||
| public TaskType unsupportedTaskType() { | ||
| return unsupportedTaskType; | ||
| } | ||
|
|
||
| public EnumSet<TaskType> supportedTaskTypes() { | ||
| return supportedTaskTypes; | ||
| } | ||
|
|
||
| protected abstract SenderService createService(ThreadPool threadPool, HttpClientManager clientManager); | ||
|
|
||
| protected abstract Map<String, Object> createServiceSettingsMap(TaskType taskType); | ||
|
|
||
| protected Map<String, Object> createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) { | ||
| return createServiceSettingsMap(taskType); | ||
| } | ||
|
|
||
| protected abstract Map<String, Object> createTaskSettingsMap(); | ||
|
|
||
| protected abstract Map<String, Object> createSecretSettingsMap(); | ||
|
|
||
| protected abstract void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets); | ||
|
|
||
| protected void assertModel(Model model, TaskType taskType) { | ||
| assertModel(model, taskType, true); | ||
| } | ||
|
|
||
| protected abstract EnumSet<TaskType> supportedStreamingTasks(); | ||
|
|
||
| /** | ||
| * Override this method if the service supports reranking. This method won't be called if the service doesn't support reranking. | ||
| */ | ||
| protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { | ||
| fail("Reranking services should override this test method to verify window size"); | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Configurations specific to the {@link SenderService#updateModelWithEmbeddingDetails(Model, int)} tests | ||
| */ | ||
| public abstract static class UpdateModelConfiguration { | ||
|
|
||
| public boolean isEnabled() { | ||
| return true; | ||
| } | ||
|
|
||
| protected abstract Model createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure); | ||
| } | ||
|
|
||
| private static final UpdateModelConfiguration DISABLED_UPDATE_MODEL_TESTS = new UpdateModelConfiguration() { | ||
| @Override | ||
| public boolean isEnabled() { | ||
| return false; | ||
| } | ||
|
|
||
| @Override | ||
| protected Model createEmbeddingModel(SimilarityMeasure similarityMeasure) { | ||
| throw new UnsupportedOperationException("Update model tests are disabled"); | ||
| } | ||
| }; | ||
|
|
||
| @Override | ||
| public InferenceService createInferenceService() { | ||
| return testConfiguration.commonConfig.createService(threadPool, clientManager); | ||
| } | ||
|
|
||
| @Override | ||
| protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { | ||
| testConfiguration.commonConfig.assertRerankerWindowSize(rerankingInferenceService); | ||
| } | ||
| } |
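To show how this base class is meant to be consumed, here is a rough sketch of a concrete service test wired up through `TestConfiguration` and `CommonConfig`. `ExampleService`, the `createExampleService` factory, and all settings keys and values are invented placeholders, not code from this PR:

```java
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Map;

import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.services.SenderService;

// Hypothetical subclass for a made-up "ExampleService".
public class ExampleServiceTests extends AbstractInferenceServiceBaseTests {

    public ExampleServiceTests() {
        super(new TestConfiguration.Builder(new CommonConfig(
            TaskType.TEXT_EMBEDDING,            // target task type
            TaskType.COMPLETION,                // a task type the service does not support
            EnumSet.of(TaskType.TEXT_EMBEDDING) // supported task types
        ) {
            @Override
            protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
                // hypothetical factory for the service under test
                return createExampleService(threadPool, clientManager);
            }

            @Override
            protected Map<String, Object> createServiceSettingsMap(TaskType taskType) {
                return new HashMap<>(Map.of("model_id", "example-model", "url", "http://localhost:12345"));
            }

            @Override
            protected Map<String, Object> createTaskSettingsMap() {
                return new HashMap<>();
            }

            @Override
            protected Map<String, Object> createSecretSettingsMap() {
                return new HashMap<>(Map.of("api_key", "secret"));
            }

            @Override
            protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
                // service-specific assertions on the parsed model would go here
            }

            @Override
            protected EnumSet<TaskType> supportedStreamingTasks() {
                return EnumSet.noneOf(TaskType.class);
            }
        }).build());
    }
}
```

The settings maps are wrapped in `HashMap` because the parsing code removes entries from them, and update-model tests stay disabled unless `enableUpdateModelTests` is called on the builder with an `UpdateModelConfiguration`.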
Adding a new version of the error message to make the error clearer. Ideally all of the services will switch to use this one.

Would it be possible to switch the services to use the new method in this PR? There are a dozen tests that would need to be updated to reflect the new message, but it would be nice to have consistency across all services.

Yep, I can do that 👍