@@ -1067,6 +1067,10 @@ public interface EnumConstructor<E extends Enum<E>> {
E apply(String name) throws IllegalArgumentException;
}

/**
* @deprecated use {@link #parsePersistedConfigErrorMsg(String, String, TaskType)} instead
*/
@Deprecated
public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName) {
return format(
"Failed to parse stored model [%s] for [%s] service, please delete and add the service again",
inferenceEntityId,
serviceName
);
}

public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName, TaskType taskType) {
Contributor Author:
Adding a new version of the error message to make the error clearer. Ideally all of the services will switch to use this one.
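
For illustration, for endpoint `my-endpoint` on the `openai` service with an unsupported task type, the old and new messages would render roughly as follows (the exact wording produced by `TaskType.unsupportedTaskTypeErrorMsg` is an assumption here):

```
Old: Failed to parse stored model [my-endpoint] for [openai] service, please delete and add the service again
New: Failed to parse stored model [my-endpoint] for [openai] service, error: [The [openai] service does not support task type [chat_completion]]. Please delete and add the service again
```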

Contributor:
Would it be possible to switch the services to use the new method in this PR? There are a dozen tests that would need to be updated to reflect the new message, but it would be nice to have consistency across all services.

Contributor Author:
Yep I can do that 👍

return format(
"Failed to parse stored model [%s] for [%s] service, error: [%s]. Please delete and add the service again",
Contributor:
Would deleting and adding the service again actually help if the task type was unsupported?

Contributor Author:
Yeah I think deleting is probably the only solution here. I think this would only occur if the inference endpoint got corrupted somehow. Basically this is saying that the persisted inference endpoint claims to leverage a task type that is not supported. The request context parsing should prevent getting into that scenario, but if we had a regression or the endpoint was corrupted somehow, we could end up here.

inferenceEntityId,
serviceName,
TaskType.unsupportedTaskTypeErrorMsg(taskType, serviceName)
);
}

public static ElasticsearchStatusException createInvalidModelException(Model model) {
return new ElasticsearchStatusException(
format(
@@ -213,7 +213,7 @@ public Ai21Model parsePersistedConfigWithSecrets(
taskType,
serviceSettingsMap,
secretSettingsMap,
parsePersistedConfigErrorMsg(modelId, NAME)
parsePersistedConfigErrorMsg(modelId, NAME, taskType)
Contributor:
To reduce some code duplication, and to avoid constructing a String we might never use on every parseRequestConfig() call, I think it should be possible to move the parsePersistedConfigErrorMsg() call down into createModel(), alongside the unsupportedTaskTypeErrorMsg() call used for the REQUEST context, and only call one of them (depending on the context passed into createModel()) when throwing an exception that needs the message.

I haven't checked every Service to see if this would work for all of them, but I assume that the structure is pretty similar between implementations.
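
A minimal sketch of that idea (names and structure are illustrative, not the PR's actual code):

```java
private static Ai21Model createModel(
    String inferenceEntityId,
    TaskType taskType,
    Map<String, Object> serviceSettings,
    ConfigurationParseContext context
) {
    if (taskType == TaskType.COMPLETION) {
        return buildCompletionModel(inferenceEntityId, serviceSettings); // hypothetical helper
    }
    // Only build the error message on the failure path, choosing the
    // wording based on the parse context.
    String message = switch (context) {
        case REQUEST -> TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME);
        case PERSISTENT -> parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType);
    };
    throw new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST);
}
```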

);
}

@@ -222,7 +222,13 @@ public Ai21Model parsePersistedConfig(String modelId, TaskType taskType, Map<Str
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

return createModelFromPersistent(modelId, taskType, serviceSettingsMap, null, parsePersistedConfigErrorMsg(modelId, NAME));
return createModelFromPersistent(
modelId,
taskType,
serviceSettingsMap,
null,
parsePersistedConfigErrorMsg(modelId, NAME, taskType)
);
}

@Override
@@ -105,7 +105,10 @@ public void parseRequestConfig(
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

var chunkingSettings = extractChunkingSettings(config, taskType);
ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
Contributor Author:
This is the bug fix. extractChunkingSettings was using removeFromMap, which passes null to ChunkingSettingsBuilder.fromMap() when the settings don't exist. We intentionally do this when parsing from persistent state, to handle backwards compatibility I think.

I don't think the change here in parseRequestConfig will cause any backwards compatibility issues. For new endpoints being created we'll use the newer default chunking settings instead though.
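
Roughly, the two helpers differ like this (simplified sketches; the real implementations live in ServiceUtils):

```java
@SuppressWarnings("unchecked")
static Map<String, Object> removeFromMap(Map<String, Object> source, String field) {
    // Returns null when the field is absent.
    return (Map<String, Object>) source.remove(field);
}

@SuppressWarnings("unchecked")
static Map<String, Object> removeFromMapOrDefaultEmpty(Map<String, Object> source, String field) {
    Map<String, Object> value = (Map<String, Object>) source.remove(field);
    // An absent field becomes an empty map, so ChunkingSettingsBuilder.fromMap()
    // applies the new defaults instead of the legacy null handling.
    return value == null ? new HashMap<>() : value;
}
```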

}

CustomModel model = createModel(
inferenceEntityId,
@@ -156,14 +159,6 @@ private static RequestParameters createParameters(CustomModel model) {
};
}

private static ChunkingSettings extractChunkingSettings(Map<String, Object> config, TaskType taskType) {
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return null;
}
Comment on lines -160 to -166
Contributor:
Would it be better to keep using this method in the two places below, where the behaviour is unchanged, rather than duplicating the logic?
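
For illustration, that could look something like this, keeping the existing helper for the two persisted-config call sites and adding a request-path variant (the extractRequestChunkingSettings name is made up for this sketch):

```java
// Persisted-config paths keep the null-returning behavior for backwards compatibility.
private static ChunkingSettings extractChunkingSettings(Map<String, Object> config, TaskType taskType) {
    if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
        return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
    }
    return null;
}

// Request path defaults absent settings to an empty map, picking up the new defaults.
private static ChunkingSettings extractRequestChunkingSettings(Map<String, Object> config, TaskType taskType) {
    if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
        return ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
    }
    return null;
}
```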


@Override
public InferenceServiceConfiguration getConfiguration() {
return Configuration.get();
@@ -229,7 +224,10 @@ public CustomModel parsePersistedConfigWithSecrets(
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);

var chunkingSettings = extractChunkingSettings(config, taskType);
ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModelWithoutLoggingDeprecations(
inferenceEntityId,
@@ -246,7 +244,10 @@ public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskT
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);

var chunkingSettings = extractChunkingSettings(config, taskType);
ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModelWithoutLoggingDeprecations(
inferenceEntityId,
@@ -318,7 +318,7 @@ public Model parsePersistedConfigWithSecrets(
serviceSettingsMap,
chunkingSettings,
secretSettingsMap,
parsePersistedConfigErrorMsg(modelId, NAME)
parsePersistedConfigErrorMsg(modelId, NAME, taskType)
);
}

@@ -357,7 +357,7 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String,
serviceSettingsMap,
chunkingSettings,
null,
parsePersistedConfigErrorMsg(modelId, NAME)
parsePersistedConfigErrorMsg(modelId, NAME, taskType)
);
}

@@ -232,7 +232,7 @@ public OpenAiModel parsePersistedConfigWithSecrets(
taskSettingsMap,
chunkingSettings,
secretSettingsMap,
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
);
}

@@ -255,7 +255,7 @@ public OpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType taskT
taskSettingsMap,
chunkingSettings,
null,
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
);
}

@@ -47,7 +47,7 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject
// The rate limit for usage tier 1 is 500 requests per minute for most of the completion models
// To find this information you need to access your account's limits https://platform.openai.com/account/limits
// 500 requests per minute
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500);
public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500);
Contributor Author:
Making various things public so they can be accessible in tests that are outside of the package.
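
For example, a test outside the package could then assert against the constant directly (hypothetical snippet; the rateLimitSettings() accessor is assumed):

```java
var settings = OpenAiChatCompletionServiceSettings.fromMap(map, ConfigurationParseContext.REQUEST);
assertThat(settings.rateLimitSettings(), equalTo(OpenAiChatCompletionServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS));
```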


public static OpenAiChatCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
@@ -50,11 +50,11 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl

public static final String NAME = "openai_service_settings";

static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";
public static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";
// The rate limit for usage tier 1 is 3000 requests per minute for the text embedding models
// To find this information you need to access your account's limits https://platform.openai.com/account/limits
// 3000 requests per minute
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);

public static OpenAiEmbeddingsServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
return switch (context) {
@@ -0,0 +1,173 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services;

import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.RerankingInferenceService;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.junit.After;
import org.junit.Before;

import java.util.EnumSet;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.mockito.Mockito.mock;

public abstract class AbstractInferenceServiceBaseTests extends InferenceServiceTestCase {
protected final TestConfiguration testConfiguration;

protected final MockWebServer webServer = new MockWebServer();
protected ThreadPool threadPool;
protected HttpClientManager clientManager;
protected AbstractInferenceServiceParameterizedTests.TestCase testCase;

@Override
@Before
public void setUp() throws Exception {
super.setUp();
webServer.start();
threadPool = createThreadPool(inferenceUtilityExecutors());
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
}

@Override
@After
public void tearDown() throws Exception {
super.tearDown();
clientManager.close();
terminate(threadPool);
webServer.close();
}

public AbstractInferenceServiceBaseTests(TestConfiguration testConfiguration) {
this.testConfiguration = Objects.requireNonNull(testConfiguration);
}

/**
* Main configurations for the tests
*/
public record TestConfiguration(CommonConfig commonConfig, UpdateModelConfiguration updateModelConfiguration) {
public static class Builder {
private final CommonConfig commonConfig;
private UpdateModelConfiguration updateModelConfiguration = DISABLED_UPDATE_MODEL_TESTS;

public Builder(CommonConfig commonConfig) {
this.commonConfig = commonConfig;
}

public TestConfiguration.Builder enableUpdateModelTests(UpdateModelConfiguration updateModelConfiguration) {
this.updateModelConfiguration = updateModelConfiguration;
return this;
}

public TestConfiguration build() {
return new TestConfiguration(commonConfig, updateModelConfiguration);
}
}
}

/**
* Configurations that are useful for most tests
*/
public abstract static class CommonConfig {

private final TaskType targetTaskType;
private final TaskType unsupportedTaskType;
private final EnumSet<TaskType> supportedTaskTypes;

public CommonConfig(TaskType targetTaskType, @Nullable TaskType unsupportedTaskType, EnumSet<TaskType> supportedTaskTypes) {
this.targetTaskType = Objects.requireNonNull(targetTaskType);
this.unsupportedTaskType = unsupportedTaskType;
this.supportedTaskTypes = Objects.requireNonNull(supportedTaskTypes);
}

public TaskType targetTaskType() {
return targetTaskType;
}

public TaskType unsupportedTaskType() {
return unsupportedTaskType;
}

public EnumSet<TaskType> supportedTaskTypes() {
return supportedTaskTypes;
}

protected abstract SenderService createService(ThreadPool threadPool, HttpClientManager clientManager);

protected abstract Map<String, Object> createServiceSettingsMap(TaskType taskType);

protected Map<String, Object> createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) {
return createServiceSettingsMap(taskType);
}

protected abstract Map<String, Object> createTaskSettingsMap();

protected abstract Map<String, Object> createSecretSettingsMap();

protected abstract void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets);

protected void assertModel(Model model, TaskType taskType) {
assertModel(model, taskType, true);
}

protected abstract EnumSet<TaskType> supportedStreamingTasks();

/**
* Override this method if the service supports reranking. This method won't be called if the service doesn't support reranking.
*/
protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
fail("Reranking services should override this test method to verify window size");
}
}

/**
* Configurations specific to the {@link SenderService#updateModelWithEmbeddingDetails(Model, int)} tests
*/
public abstract static class UpdateModelConfiguration {

public boolean isEnabled() {
return true;
}

protected abstract Model createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure);
}

private static final UpdateModelConfiguration DISABLED_UPDATE_MODEL_TESTS = new UpdateModelConfiguration() {
@Override
public boolean isEnabled() {
return false;
}

@Override
protected Model createEmbeddingModel(SimilarityMeasure similarityMeasure) {
throw new UnsupportedOperationException("Update model tests are disabled");
}
};

@Override
public InferenceService createInferenceService() {
return testConfiguration.commonConfig.createService(threadPool, clientManager);
}

@Override
protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
testConfiguration.commonConfig.assertRerankerWindowSize(rerankingInferenceService);
}
}
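
As a usage sketch, a concrete service test might wire the base class up roughly like this (MyService, MyServiceModel, and the settings keys are hypothetical; a real subclass would mirror an actual service):

```java
public class MyServiceTests extends AbstractInferenceServiceBaseTests {

    public MyServiceTests() {
        super(new TestConfiguration.Builder(new CommonConfig(
            TaskType.TEXT_EMBEDDING,                // target task type
            TaskType.RERANK,                        // a task type the service rejects
            EnumSet.of(TaskType.TEXT_EMBEDDING)     // everything the service supports
        ) {
            @Override
            protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
                return new MyService(threadPool, clientManager); // hypothetical service
            }

            @Override
            protected Map<String, Object> createServiceSettingsMap(TaskType taskType) {
                return new HashMap<>(Map.of("model_id", "my-model", "url", "http://localhost:9200"));
            }

            @Override
            protected Map<String, Object> createTaskSettingsMap() {
                return new HashMap<>();
            }

            @Override
            protected Map<String, Object> createSecretSettingsMap() {
                return new HashMap<>(Map.of("api_key", "secret"));
            }

            @Override
            protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
                assertThat(model, instanceOf(MyServiceModel.class)); // hypothetical model class
            }

            @Override
            protected EnumSet<TaskType> supportedStreamingTasks() {
                return EnumSet.noneOf(TaskType.class);
            }
        }).build());
    }
}
```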