From 14a53832ce15ccc3579507179dd034f394184f2d Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 8 May 2025 15:54:35 -0400 Subject: [PATCH 01/12] Inference changes --- x-pack/plugin/inference/build.gradle | 1 - .../inference/InferenceGetServicesIT.java | 15 +- ...stStreamingCompletionServiceExtension.java | 50 +- .../ShardBulkInferenceActionFilterIT.java | 30 +- .../InferenceNamedWriteablesProvider.java | 42 + .../xpack/inference/InferencePlugin.java | 11 +- .../TransportPutInferenceModelAction.java | 2 +- .../external/http/HttpClientManager.java | 18 +- .../external/http/StreamingHttpResult.java | 5 +- .../http/retry/BaseResponseHandler.java | 2 +- .../http/retry/RetryingHttpSender.java | 7 +- .../mapper/SemanticTextFieldMapper.java | 5 +- .../rest/RestPutInferenceModelAction.java | 10 +- .../inference/services/ServiceUtils.java | 259 ++++++ .../services/custom/CustomModel.java | 98 +++ .../CustomRateLimitServiceSettings.java | 14 + .../services/custom/CustomRequestManager.java | 101 +++ .../custom/CustomResponseHandler.java | 73 ++ .../services/custom/CustomSecretSettings.java | 113 +++ .../services/custom/CustomService.java | 279 +++++++ .../custom/CustomServiceSettings.java | 409 +++++++++- .../services/custom/CustomTaskSettings.java | 134 ++++ .../services/custom/QueryParameters.java | 104 +++ .../custom/request/CustomRequest.java | 165 ++++ .../response/CompletionResponseParser.java | 13 +- .../custom/response/CustomResponseEntity.java | 34 + .../custom/response/ErrorResponseParser.java | 39 +- .../custom/response/RerankResponseParser.java | 15 +- .../SparseEmbeddingResponseParser.java | 12 +- .../response/TextEmbeddingResponseParser.java | 13 +- .../ElasticInferenceServiceSettings.java | 18 - .../OpenAiUnifiedStreamingProcessor.java | 28 +- .../settings/SerializableSecureString.java | 62 ++ .../plugin-metadata/plugin-security.policy | 27 + .../elasticsearch/xpack/inference/Utils.java | 4 + .../action/PutInferenceModelRequestTests.java | 46 +- 
.../inference/common/JsonUtilsTests.java | 2 + .../common/MapPathExtractorTests.java | 69 ++ .../ErrorMessageResponseEntityTests.java | 6 +- .../xpack/inference/model/TestModel.java | 8 +- .../services/AbstractServiceTests.java | 538 +++++++++++++ .../inference/services/ServiceUtilsTests.java | 333 +++++++- .../services/custom/CustomModelTests.java | 132 ++++ .../custom/CustomRequestManagerTests.java | 88 +++ .../custom/CustomSecretSettingsTests.java | 141 ++++ .../custom/CustomServiceSettingsTests.java | 734 ++++++++++++++++++ .../services/custom/CustomServiceTests.java | 550 +++++++++++++ .../custom/CustomTaskSettingsTests.java | 155 ++++ .../services/custom/QueryParametersTests.java | 114 +++ .../custom/request/CustomRequestTests.java | 310 ++++++++ .../CompletionResponseParserTests.java | 11 +- .../response/CustomResponseEntityTests.java | 211 +++++ .../response/ErrorResponseParserTests.java | 37 +- .../response/RerankResponseParserTests.java | 11 +- .../SparseEmbeddingResponseParserTests.java | 7 +- .../TextEmbeddingResponseParserTests.java | 5 +- 56 files changed, 5547 insertions(+), 173 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRateLimitServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java create mode 100644 
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/QueryParameters.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/SerializableSecureString.java create mode 100644 x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractServiceTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/QueryParametersTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java create mode 100644 
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index 58aa9b29f8565..fba8d9e61f0c4 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -8,7 +8,6 @@ apply plugin: 'elasticsearch.internal-es-plugin' apply plugin: 'elasticsearch.internal-cluster-test' apply plugin: 'elasticsearch.internal-yaml-rest-test' -apply plugin: 'elasticsearch.internal-test-artifact' restResources { restApi { diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 682eebd0fa69b..71a4d20d3ad4c 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest { public void testGetServicesWithoutTaskType() throws IOException { List services = getAllServices(); - assertThat(services.size(), equalTo(22)); + assertThat(services.size(), equalTo(23)); var providers = providers(services); @@ -39,6 +39,7 @@ public void testGetServicesWithoutTaskType() throws IOException { "azureaistudio", "azureopenai", "cohere", + "custom", "deepseek", "elastic", "elasticsearch", @@ -70,7 +71,7 @@ private Iterable providers(List services) { public void testGetServicesWithTextEmbeddingTaskType() throws IOException { List services = getServices(TaskType.TEXT_EMBEDDING); - assertThat(services.size(), equalTo(16)); + assertThat(services.size(), equalTo(17)); var providers = 
providers(services); @@ -83,6 +84,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { "azureaistudio", "azureopenai", "cohere", + "custom", "elasticsearch", "googleaistudio", "googlevertexai", @@ -101,7 +103,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { public void testGetServicesWithRerankTaskType() throws IOException { List services = getServices(TaskType.RERANK); - assertThat(services.size(), equalTo(7)); + assertThat(services.size(), equalTo(8)); var providers = providers(services); @@ -111,6 +113,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { List.of( "alibabacloud-ai-search", "cohere", + "custom", "elasticsearch", "googlevertexai", "jinaai", @@ -123,7 +126,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { public void testGetServicesWithCompletionTaskType() throws IOException { List services = getServices(TaskType.COMPLETION); - assertThat(services.size(), equalTo(10)); + assertThat(services.size(), equalTo(11)); var providers = providers(services); @@ -137,6 +140,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException { "azureaistudio", "azureopenai", "cohere", + "custom", "deepseek", "googleaistudio", "openai", @@ -157,7 +161,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException { public void testGetServicesWithSparseEmbeddingTaskType() throws IOException { List services = getServices(TaskType.SPARSE_EMBEDDING); - assertThat(services.size(), equalTo(6)); + assertThat(services.size(), equalTo(7)); var providers = providers(services); @@ -166,6 +170,7 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException { containsInAnyOrder( List.of( "alibabacloud-ai-search", + "custom", "elastic", "elasticsearch", "hugging_face", diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index b2f8ba5475eb8..e34018c5b8df1 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -34,7 +34,6 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.DequeUtils; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; @@ -257,24 +256,37 @@ public void cancel() {} "object": "chat.completion.chunk" } */ - private StreamingUnifiedChatCompletionResults.Results unifiedCompletionChunk(String delta) { - return new StreamingUnifiedChatCompletionResults.Results( - DequeUtils.of( - new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( - "id", - List.of( - new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( - new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(delta, null, null, null), - null, - 0 - ) - ), - "gpt-4o-2024-08-06", - "chat.completion.chunk", - null - ) - ) - ); + private InferenceServiceResults.Result unifiedCompletionChunk(String delta) { + return new InferenceServiceResults.Result() { + @Override + public String getWriteableName() { + return "test_unifiedCompletionChunk"; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(delta); + } + + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + return 
ChunkedToXContentHelper.chunk( + (b, p) -> b.startObject() + .field("id", "id") + .startArray("choices") + .startObject() + .startObject("delta") + .field("content", delta) + .endObject() + .field("index", 0) + .endObject() + .endArray() + .field("model", "gpt-4o-2024-08-06") + .field("object", "chat.completion.chunk") + .endObject() + ); + } + }; } @Override diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java index 8405fba22460f..074678bbea095 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java @@ -9,6 +9,7 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; @@ -16,7 +17,6 @@ import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.update.UpdateRequestBuilder; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; @@ -242,10 +242,12 @@ public void testRestart() throws Exception { private void assertRandomBulkOperations(String indexName, Function> sourceSupplier) throws Exception { int numHits = numHits(indexName); - int totalBulkReqs = randomIntBetween(2, 10); + int totalBulkReqs = 
randomIntBetween(2, 100); + long totalDocs = numHits; Set ids = new HashSet<>(); - for (int bulkReqs = 0; bulkReqs < totalBulkReqs; bulkReqs++) { - BulkRequestBuilder bulkReqBuilder = client().prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + for (int bulkReqs = numHits; bulkReqs < totalBulkReqs; bulkReqs++) { + BulkRequestBuilder bulkReqBuilder = client().prepareBulk(); int totalBulkSize = randomIntBetween(1, 100); for (int bulkSize = 0; bulkSize < totalBulkSize; bulkSize++) { if (ids.size() > 0 && rarely(random())) { @@ -255,15 +257,24 @@ private void assertRandomBulkOperations(String indexName, Function source = sourceSupplier.apply(isIndexRequest); if (isIndexRequest) { - String id = randomAlphaOfLength(20); bulkReqBuilder.add(new IndexRequestBuilder(client()).setIndex(indexName).setId(id).setSource(source)); ids.add(id); } else { - String id = randomFrom(ids); - bulkReqBuilder.add(new UpdateRequestBuilder(client()).setIndex(indexName).setId(id).setDoc(source)); + boolean isUpsert = randomBoolean(); + UpdateRequestBuilder request = new UpdateRequestBuilder(client()).setIndex(indexName).setDoc(source); + if (isUpsert || ids.size() == 0) { + request.setDocAsUpsert(true); + } else { + // Update already existing document + id = randomFrom(ids); + } + request.setId(id); + bulkReqBuilder.add(request); + ids.add(id); } } BulkResponse bulkResponse = bulkReqBuilder.get(); @@ -282,7 +293,8 @@ private void assertRandomBulkOperations(String indexName, Function getNamedWriteables() { addAlibabaCloudSearchNamedWriteables(namedWriteables); addJinaAINamedWriteables(namedWriteables); addVoyageAINamedWriteables(namedWriteables); + addCustomNamedWriteables(namedWriteables); addUnifiedNamedWriteables(namedWriteables); @@ -165,6 +175,38 @@ public static List getNamedWriteables() { return namedWriteables; } + private static void addCustomNamedWriteables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry(ServiceSettings.class, 
CustomServiceSettings.NAME, CustomServiceSettings::new) + ); + + namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, CustomTaskSettings.NAME, CustomTaskSettings::new)); + + namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, CustomSecretSettings.NAME, CustomSecretSettings::new)); + + namedWriteables.add( + new NamedWriteableRegistry.Entry(CustomResponseParser.class, TextEmbeddingResponseParser.NAME, TextEmbeddingResponseParser::new) + ); + + namedWriteables.add( + new NamedWriteableRegistry.Entry( + CustomResponseParser.class, + SparseEmbeddingResponseParser.NAME, + SparseEmbeddingResponseParser::new + ) + ); + + namedWriteables.add( + new NamedWriteableRegistry.Entry(CustomResponseParser.class, RerankResponseParser.NAME, RerankResponseParser::new) + ); + + namedWriteables.add(new NamedWriteableRegistry.Entry(CustomResponseParser.class, NoopResponseParser.NAME, NoopResponseParser::new)); + + namedWriteables.add( + new NamedWriteableRegistry.Entry(CustomResponseParser.class, CompletionResponseParser.NAME, CompletionResponseParser::new) + ); + } + private static void addUnifiedNamedWriteables(List namedWriteables) { var writeables = UnifiedCompletionRequest.getNamedWriteables(); namedWriteables.addAll(writeables); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 87256494a60e0..a8d783eacbca0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -119,6 +119,7 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService; import org.elasticsearch.xpack.inference.services.cohere.CohereService; +import 
org.elasticsearch.xpack.inference.services.custom.CustomService; import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; @@ -276,17 +277,13 @@ public Collection createComponents(PluginServices services) { var inferenceServices = new ArrayList<>(inferenceServiceExtensions); inferenceServices.add(this::getInferenceServiceFactories); - var inferenceServiceSettings = new ElasticInferenceServiceSettings(settings); - inferenceServiceSettings.init(services.clusterService()); - // Create a separate instance of HTTPClientManager with its own SSL configuration (`xpack.inference.elastic.http.ssl.*`). var elasticInferenceServiceHttpClientManager = HttpClientManager.create( settings, services.threadPool(), services.clusterService(), throttlerManager, - getSslService(), - inferenceServiceSettings.getConnectionTtl() + getSslService() ); var elasticInferenceServiceRequestSenderFactory = new HttpRequestSender.Factory( @@ -296,6 +293,9 @@ public Collection createComponents(PluginServices services) { ); elasicInferenceServiceFactory.set(elasticInferenceServiceRequestSenderFactory); + var inferenceServiceSettings = new ElasticInferenceServiceSettings(settings); + inferenceServiceSettings.init(services.clusterService()); + var authorizationHandler = new ElasticInferenceServiceAuthorizationRequestHandler( inferenceServiceSettings.getElasticInferenceServiceUrl(), services.threadPool() @@ -396,6 +396,7 @@ public List getInferenceServiceFactories() { context -> new JinaAIService(httpFactory.get(), serviceComponents.get()), context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()), context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()), + context -> new CustomService(httpFactory.get(), serviceComponents.get()), ElasticsearchInternalService::new ); } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index bc9d87f43ada0..eeea8a28df486 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -177,7 +177,7 @@ protected void masterOperation( return; } - parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.getTimeout(), listener); + parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.ackTimeout(), listener); } private void parseAndStoreModel( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java index ddf19ff0dc96f..6d09c9e67b363 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java @@ -32,7 +32,6 @@ import java.io.Closeable; import java.io.IOException; import java.util.List; -import java.util.concurrent.TimeUnit; import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_SSL_CONFIGURATION_PREFIX; @@ -113,15 +112,14 @@ public static HttpClientManager create( ThreadPool threadPool, ClusterService clusterService, ThrottlerManager throttlerManager, - SSLService sslService, - TimeValue connectionTtl + SSLService sslService ) { // Set the sslStrategy to ensure an 
encrypted connection, as Elastic Inference Service requires it. SSLIOSessionStrategy sslioSessionStrategy = sslService.sslIOSessionStrategy( sslService.getSSLConfiguration(ELASTIC_INFERENCE_SERVICE_SSL_CONFIGURATION_PREFIX) ); - PoolingNHttpClientConnectionManager connectionManager = createConnectionManager(sslioSessionStrategy, connectionTtl); + PoolingNHttpClientConnectionManager connectionManager = createConnectionManager(sslioSessionStrategy); return new HttpClientManager(settings, connectionManager, threadPool, clusterService, throttlerManager); } @@ -148,7 +146,7 @@ public static HttpClientManager create( this.addSettingsUpdateConsumers(clusterService); } - private static PoolingNHttpClientConnectionManager createConnectionManager(SSLIOSessionStrategy sslStrategy, TimeValue connectionTtl) { + private static PoolingNHttpClientConnectionManager createConnectionManager(SSLIOSessionStrategy sslStrategy) { ConnectingIOReactor ioReactor; try { var configBuilder = IOReactorConfig.custom().setSoKeepAlive(true); @@ -164,15 +162,7 @@ private static PoolingNHttpClientConnectionManager createConnectionManager(SSLIO .register("https", sslStrategy) .build(); - return new PoolingNHttpClientConnectionManager( - ioReactor, - null, - registry, - null, - null, - Math.toIntExact(connectionTtl.getMillis()), - TimeUnit.MILLISECONDS - ); + return new PoolingNHttpClientConnectionManager(ioReactor, registry); } private static PoolingNHttpClientConnectionManager createConnectionManager() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResult.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResult.java index 1786ee98fcd87..f384d79adae3e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResult.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResult.java @@ -11,6 +11,7 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; import java.io.ByteArrayOutputStream; import java.util.concurrent.Flow; @@ -21,7 +22,7 @@ public boolean isSuccessfulResponse() { return RestStatus.isSuccessful(response.getStatusLine().getStatusCode()); } - public Flow.Publisher toHttpResult() { + public Flow.Publisher toHttpResult(HttpRequest httpRequest) { return subscriber -> body().subscribe(new Flow.Subscriber<>() { @Override public void onSubscribe(Flow.Subscription subscription) { @@ -45,7 +46,7 @@ public void onComplete() { }); } - public void readFullResponse(ActionListener fullResponse) { + public void readFullResponse(HttpRequest httpRequest, ActionListener fullResponse) { var stream = new ByteArrayOutputStream(); AtomicReference upstream = new AtomicReference<>(null); body.subscribe(new Flow.Subscriber<>() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java index 3dac8d849ba6f..56e994be86eb4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java @@ -36,7 +36,7 @@ public abstract class BaseResponseHandler implements ResponseHandler { public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code"; protected final String requestType; - private final ResponseParser parseFunction; + protected final ResponseParser parseFunction; private final Function 
errorParseFunction; private final boolean canHandleStreamingResponses; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java index d009ec87d5776..e8cb5d3ad16d9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java @@ -115,11 +115,12 @@ public void tryAction(ActionListener listener) { try { if (request.isStreaming() && responseHandler.canHandleStreamingResponses()) { - httpClient.stream(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> { + var httpRequest = request.createHttpRequest(); + httpClient.stream(httpRequest, context, retryableListener.delegateFailure((l, r) -> { if (r.isSuccessfulResponse()) { - l.onResponse(responseHandler.parseResult(request, r.toHttpResult())); + l.onResponse(responseHandler.parseResult(request, r.toHttpResult(httpRequest))); } else { - r.readFullResponse(l.delegateFailureAndWrap((ll, httpResult) -> { + r.readFullResponse(httpRequest, l.delegateFailureAndWrap((ll, httpResult) -> { try { responseHandler.validateResponse(throttlerManager, logger, request, httpResult, true); InferenceServiceResults inferenceResults = responseHandler.parseResult(request, httpResult); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index d15414e34aef1..548f65d4f93fa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -32,7 +32,6 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.BlockLoader; @@ -99,7 +98,6 @@ import java.util.function.Supplier; import static org.elasticsearch.index.IndexVersions.SEMANTIC_TEXT_DEFAULTS_TO_BBQ; -import static org.elasticsearch.index.IndexVersions.SEMANTIC_TEXT_DEFAULTS_TO_BBQ_BACKPORT_8_X; import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; @@ -1079,8 +1077,7 @@ private static Mapper.Builder createEmbeddingsField( denseVectorMapperBuilder.elementType(modelSettings.elementType()); DenseVectorFieldMapper.IndexOptions defaultIndexOptions = null; - if (indexVersionCreated.onOrAfter(SEMANTIC_TEXT_DEFAULTS_TO_BBQ) - || indexVersionCreated.between(SEMANTIC_TEXT_DEFAULTS_TO_BBQ_BACKPORT_8_X, IndexVersions.UPGRADE_TO_LUCENE_10_0_0)) { + if (indexVersionCreated.onOrAfter(SEMANTIC_TEXT_DEFAULTS_TO_BBQ)) { defaultIndexOptions = defaultSemanticDenseIndexOptions(); } if (defaultIndexOptions != null diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java index 838e6512d805f..655e11996d522 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java @@ -20,7 +20,6 @@ 
import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; -import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout; import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID; import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH; import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH; @@ -50,15 +49,8 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient taskType = TaskType.ANY; // task type must be defined in the body } - var inferTimeout = parseTimeout(restRequest); var content = restRequest.requiredContent(); - var request = new PutInferenceModelAction.Request( - taskType, - inferenceEntityId, - content, - restRequest.getXContentType(), - inferTimeout - ); + var request = new PutInferenceModelAction.Request(taskType, inferenceEntityId, content, restRequest.getXContentType()); return channel -> client.execute( PutInferenceModelAction.INSTANCE, request, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index bdcadb2277c2b..428c266379f65 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -14,6 +14,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.core.Tuple; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; @@ -21,9 +22,11 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import 
org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; +import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString; import java.net.URI; import java.net.URISyntaxException; +import java.util.ArrayList; import java.util.Arrays; import java.util.EnumSet; import java.util.HashMap; @@ -31,6 +34,7 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.function.Function; import java.util.stream.Collectors; import static org.elasticsearch.core.Strings.format; @@ -80,6 +84,11 @@ public static T removeAsType(Map sourceMap, String key, Clas */ @SuppressWarnings("unchecked") public static T removeAsType(Map sourceMap, String key, Class type, ValidationException validationException) { + if (sourceMap == null) { + validationException.addValidationError(Strings.format("Encountered a null input map while parsing field [%s]", key)); + return null; + } + Object o = sourceMap.remove(key); if (o == null) { return null; @@ -188,6 +197,12 @@ public static void throwIfNotEmptyMap(Map settingsMap, String se } } + public static void throwIfNotEmptyMap(Map settingsMap, String field, String scope) { + if (settingsMap != null && settingsMap.isEmpty() == false) { + throw ServiceUtils.unknownSettingsError(settingsMap, field, scope); + } + } + public static ElasticsearchStatusException unknownSettingsError(Map config, String serviceName) { // TODO map as JSON return new ElasticsearchStatusException( @@ -198,6 +213,16 @@ public static ElasticsearchStatusException unknownSettingsError(Map config, String field, String scope) { + return new ElasticsearchStatusException( + "Model configuration contains unknown settings [{}] while parsing field [{}] for settings [{}]", + RestStatus.BAD_REQUEST, + config, + field, + scope + ); + } + public static ElasticsearchStatusException invalidModelTypeForUpdateModelWithEmbeddingDetails(Class invalidModelType) { throw new ElasticsearchStatusException( Strings.format("Can't update 
embedding details for model with unexpected type %s", invalidModelType), @@ -249,6 +274,10 @@ public static String mustBeNonEmptyString(String settingName, String scope) { return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName); } + public static String mustBeNonEmptyMap(String settingName, String scope) { + return Strings.format("[%s] Invalid value empty map. [%s] must be a non-empty map", scope, settingName); + } + public static String invalidTimeValueMsg(String timeValueStr, String settingName, String scope, String exceptionMsg) { return Strings.format( "[%s] Invalid time value [%s]. [%s] must be a valid time value string: %s", @@ -422,6 +451,236 @@ public static Integer extractRequiredPositiveInteger( return field; } + @SuppressWarnings("unchecked") + public static Map extractRequiredMap( + Map map, + String settingName, + String scope, + ValidationException validationException + ) { + int initialValidationErrorCount = validationException.validationErrors().size(); + Map requiredField = ServiceUtils.removeAsType(map, settingName, Map.class, validationException); + + if (validationException.validationErrors().size() > initialValidationErrorCount) { + return null; + } + + if (requiredField == null) { + validationException.addValidationError(ServiceUtils.missingSettingErrorMsg(settingName, scope)); + } else if (requiredField.isEmpty()) { + validationException.addValidationError(ServiceUtils.mustBeNonEmptyMap(settingName, scope)); + } + + if (validationException.validationErrors().size() > initialValidationErrorCount) { + return null; + } + + return requiredField; + } + + @SuppressWarnings("unchecked") + public static Map extractOptionalMap( + Map map, + String settingName, + String scope, + ValidationException validationException + ) { + int initialValidationErrorCount = validationException.validationErrors().size(); + Map optionalField = ServiceUtils.removeAsType(map, settingName, Map.class, 
validationException); + + if (validationException.validationErrors().size() > initialValidationErrorCount) { + return null; + } + + return optionalField; + } + + public static List> extractOptionalListOfStringTuples( + Map map, + String settingName, + String scope, + ValidationException validationException + ) { + int initialValidationErrorCount = validationException.validationErrors().size(); + List optionalField = ServiceUtils.removeAsType(map, settingName, List.class, validationException); + + if (validationException.validationErrors().size() > initialValidationErrorCount) { + return null; + } + + if (optionalField == null) { + return null; + } + + var tuples = new ArrayList>(); + for (int tuplesIndex = 0; tuplesIndex < optionalField.size(); tuplesIndex++) { + + var tupleEntry = optionalField.get(tuplesIndex); + if (tupleEntry instanceof List == false) { + validationException.addValidationError( + Strings.format( + "[%s] failed to parse tuple list entry [%d] for setting [%s], expected a list but the entry is [%s]", + scope, + tuplesIndex, + settingName, + tupleEntry.getClass().getSimpleName() + ) + ); + throw validationException; + } + + var listEntry = (List) tupleEntry; + if (listEntry.size() != 2) { + validationException.addValidationError( + Strings.format( + "[%s] failed to parse tuple list entry [%d] for setting [%s], the tuple list size must be two, but was [%d]", + scope, + tuplesIndex, + settingName, + listEntry.size() + ) + ); + throw validationException; + } + + var firstElement = listEntry.get(0); + var secondElement = listEntry.get(1); + validateString(firstElement, settingName, scope, "the first element", tuplesIndex, validationException); + validateString(secondElement, settingName, scope, "the second element", tuplesIndex, validationException); + tuples.add(new Tuple<>((String) firstElement, (String) secondElement)); + } + + return tuples; + } + + private static void validateString( + Object tupleValue, + String settingName, + String scope, + 
String elementDescription, + int index, + ValidationException validationException + ) { + if (tupleValue instanceof String == false) { + validationException.addValidationError( + Strings.format( + "[%s] failed to parse tuple list entry [%d] for setting [%s], %s must be a string but was [%s]", + scope, + index, + settingName, + elementDescription, + tupleValue.getClass().getSimpleName() + ) + ); + throw validationException; + } + } + + /** + * Validates that each value in the map is a {@link String} and returns a new map with the same keys whose values are cast to {@code String}. + */ + public static Map validateMapStringValues( + Map map, + String settingName, + ValidationException validationException, + boolean censorValue + ) { + if (map == null) { + return Map.of(); + } + + validateMapValues(map, List.of(String.class), settingName, validationException, censorValue); + + return map.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> (String) e.getValue())); + } + + /** + * Ensures the values of the map match one of the supplied types. + * @param map Map to validate + * @param allowedTypes List of {@link Class} to accept + * @param settingName the setting name for the field + * @param validationException exception that the error is added to and that is thrown if one of the values is invalid + * @param censorValue if true the key and value are omitted from the exception message; if false they are included + */ + public static void validateMapValues( + Map map, + List> allowedTypes, + String settingName, + ValidationException validationException, + boolean censorValue + ) { + if (map == null) { + return; + } + + for (var entry : map.entrySet()) { + var isAllowed = false; + + for (Class allowedType : allowedTypes) { + if (allowedType.isInstance(entry.getValue())) { + isAllowed = true; + break; + } + } + + Function errorMessage = (String[] validTypesAsStrings) -> { + if (censorValue) { + return Strings.format( + "Map field [%s] has an entry that is not valid.
Value type is not one of [%s].", + settingName, + String.join(", ", validTypesAsStrings) + ); + } else { + return Strings.format( + "Map field [%s] has an entry that is not valid, [%s => %s]. Value type of [%s] is not one of [%s].", + settingName, + entry.getKey(), + entry.getValue(), + entry.getValue(), + String.join(", ", validTypesAsStrings) + ); + } + }; + + if (isAllowed == false) { + var validTypesAsStrings = allowedTypes.stream().map(Class::getSimpleName).toArray(String[]::new); + Arrays.sort(validTypesAsStrings); + + validationException.addValidationError(errorMessage.apply(validTypesAsStrings)); + throw validationException; + } + } + } + + public static Map convertMapStringsToSecureString( + Map map, + String settingName, + ValidationException validationException + ) { + if (map == null) { + return Map.of(); + } + + validateMapStringValues(map, settingName, validationException, true); + + return map.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> new SerializableSecureString((String) e.getValue()))); + } + + /** + * Removes null values. + */ + public static Map removeNullValues(Map map) { + if (map == null) { + return map; + } + + map.values().removeIf(Objects::isNull); + + return map; + } + public static Integer extractRequiredPositiveIntegerLessThanOrEqualToMax( Map map, String settingName, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java new file mode 100644 index 0000000000000..7c00b0a242f94 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java @@ -0,0 +1,98 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; + +import java.util.Map; +import java.util.Objects; + +public class CustomModel extends Model { + private final CustomRateLimitServiceSettings rateLimitServiceSettings; + + public CustomModel(ModelConfigurations configurations, ModelSecrets secrets, CustomRateLimitServiceSettings rateLimitServiceSettings) { + super(configurations, secrets); + this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings); + } + + public static CustomModel of(CustomModel model, Map taskSettings) { + var requestTaskSettings = CustomTaskSettings.fromMap(taskSettings); + return new CustomModel(model, CustomTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + } + + public CustomModel( + String inferenceId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceId, + taskType, + service, + CustomServiceSettings.fromMap(serviceSettings, context, taskType, inferenceId), + CustomTaskSettings.fromMap(taskSettings), + CustomSecretSettings.fromMap(secrets) + ); + } + + // should only be used for testing + CustomModel( + String inferenceId, + TaskType taskType, + String service, + CustomServiceSettings serviceSettings, + CustomTaskSettings taskSettings, + @Nullable CustomSecretSettings secretSettings + ) { + this( + new ModelConfigurations(inferenceId, 
taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secretSettings), + serviceSettings + ); + } + + protected CustomModel(CustomModel model, TaskSettings taskSettings) { + super(model, taskSettings); + rateLimitServiceSettings = model.rateLimitServiceSettings(); + } + + protected CustomModel(CustomModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + rateLimitServiceSettings = model.rateLimitServiceSettings(); + } + + @Override + public CustomServiceSettings getServiceSettings() { + return (CustomServiceSettings) super.getServiceSettings(); + } + + @Override + public CustomTaskSettings getTaskSettings() { + return (CustomTaskSettings) super.getTaskSettings(); + } + + @Override + public CustomSecretSettings getSecretSettings() { + return (CustomSecretSettings) super.getSecretSettings(); + } + + public CustomRateLimitServiceSettings rateLimitServiceSettings() { + return rateLimitServiceSettings; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRateLimitServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRateLimitServiceSettings.java new file mode 100644 index 0000000000000..55641bad7ccaa --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRateLimitServiceSettings.java @@ -0,0 +1,14 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +public interface CustomRateLimitServiceSettings { + RateLimitSettings rateLimitSettings(); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java new file mode 100644 index 0000000000000..a112e7db26fe3 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.BaseRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import 
org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; +import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest; +import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseEntity; + +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +public class CustomRequestManager extends BaseRequestManager { + private static final Logger logger = LogManager.getLogger(CustomRequestManager.class); + + record RateLimitGrouping(int apiKeyHash) { + public static RateLimitGrouping of(CustomModel model) { + Objects.requireNonNull(model); + + return new RateLimitGrouping(model.rateLimitServiceSettings().hashCode()); + } + } + + private static ResponseHandler createCustomHandler(CustomModel model) { + return new CustomResponseHandler("custom model", CustomResponseEntity::fromResponse, model.getServiceSettings().getErrorParser()); + } + + public static CustomRequestManager of(CustomModel model, ThreadPool threadPool) { + return new CustomRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); + } + + private final CustomModel model; + private final ResponseHandler handler; + + private CustomRequestManager(CustomModel model, ThreadPool threadPool) { + super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); + this.model = model; + this.handler = createCustomHandler(model); + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + String query; + List input; + if (inferenceInputs instanceof QueryAndDocsInputs) { + QueryAndDocsInputs queryAndDocsInputs = QueryAndDocsInputs.of(inferenceInputs); + query = queryAndDocsInputs.getQuery(); + input = queryAndDocsInputs.getChunks(); + } else if (inferenceInputs instanceof ChatCompletionInput chatInputs) { + query = null; + 
input = chatInputs.getInputs(); + } else if (inferenceInputs instanceof EmbeddingsInput) { + EmbeddingsInput embeddingsInput = EmbeddingsInput.of(inferenceInputs); + query = null; + input = embeddingsInput.getStringInputs(); + } else { + listener.onFailure( + new ElasticsearchStatusException( + Strings.format("Invalid input received from custom service %s", inferenceInputs.getClass().getSimpleName()), + RestStatus.BAD_REQUEST + ) + ); + return; + } + + try { + var request = new CustomRequest(query, input, model); + execute(new ExecutableInferenceRequest(requestSender, logger, request, handler, hasRequestCompletedFunction, listener)); + } catch (Exception e) { + // Intentionally not logging this exception because it could contain sensitive information from the CustomRequest construction + listener.onFailure( + new ElasticsearchStatusException("Failed to construct the custom service request", RestStatus.BAD_REQUEST, e) + ); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java new file mode 100644 index 0000000000000..14a962b112ccd --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; + +/** + * Defines how to handle various response types returned from the custom integration. + */ +public class CustomResponseHandler extends BaseResponseHandler { + public CustomResponseHandler(String requestType, ResponseParser parseFunction, ErrorResponseParser errorParser) { + super(requestType, parseFunction, errorParser); + } + + @Override + public InferenceServiceResults parseResult(Request request, HttpResult result) throws RetryException { + try { + return parseFunction.apply(request, result); + } catch (Exception e) { + // if we get a parse failure it's probably an incorrect configuration of the service so report the error back to the user + // immediately without retrying + throw new RetryException( + false, + new ElasticsearchStatusException( + "Failed to parse custom model response, please check that the response parser path matches the response format.", + RestStatus.BAD_REQUEST, + e + ) + ); + } + } + + /** + * Validates the status code and throws a RetryException if it is not in the range [200, 300).
+ * + * @param request The http request + * @param result The http response and body + * @throws RetryException Throws if status code is {@code >= 300 or < 200 } + */ + @Override + protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException { + int statusCode = result.response().getStatusLine().getStatusCode(); + if (statusCode >= 200 && statusCode < 300) { + return; + } + + // handle error codes + if (statusCode >= 500) { + throw new RetryException(false, buildError(SERVER_ERROR, request, result)); + } else if (statusCode == 429) { + throw new RetryException(true, buildError(RATE_LIMIT, request, result)); + } else if (statusCode == 401) { + throw new RetryException(false, buildError(AUTHENTICATION, request, result)); + } else if (statusCode >= 300 && statusCode < 400) { + throw new RetryException(false, buildError(REDIRECTION, request, result)); + } else { + throw new RetryException(false, buildError(UNSUCCESSFUL, request, result)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettings.java new file mode 100644 index 0000000000000..e74d56e1b4fd1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettings.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertMapStringsToSecureString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeNullValues; + +public class CustomSecretSettings implements SecretSettings { + public static final String NAME = "custom_secret_settings"; + public static final String SECRET_PARAMETERS = "secret_parameters"; + + public static CustomSecretSettings fromMap(@Nullable Map map) { + if (map == null) { + return null; + } + + ValidationException validationException = new ValidationException(); + + Map requestSecretParamsMap = extractOptionalMap(map, SECRET_PARAMETERS, NAME, validationException); + removeNullValues(requestSecretParamsMap); + var secureStringMap = convertMapStringsToSecureString(requestSecretParamsMap, SECRET_PARAMETERS, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new CustomSecretSettings(secureStringMap); + } + + private final Map secretParameters; + + @Override + public SecretSettings newSecretSettings(Map newSecrets) { + return fromMap(new HashMap<>(newSecrets)); + } + + public CustomSecretSettings(@Nullable Map secretParameters) { + 
this.secretParameters = Objects.requireNonNullElse(secretParameters, Map.of()); + } + + public CustomSecretSettings(StreamInput in) throws IOException { + secretParameters = in.readImmutableMap(SerializableSecureString::new); + } + + public Map getSecretParameters() { + return secretParameters; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (secretParameters.isEmpty() == false) { + builder.startObject(SECRET_PARAMETERS); + { + for (var entry : secretParameters.entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + } + builder.endObject(); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ADD_INFERENCE_CUSTOM_MODEL; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(secretParameters, (streamOutput, v) -> { v.writeTo(streamOutput); }); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CustomSecretSettings that = (CustomSecretSettings) o; + return Objects.equals(secretParameters, that.secretParameters); + } + + @Override + public int hashCode() { + return Objects.hash(secretParameters); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java new file mode 100644 index 0000000000000..e30a0ab564026 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -0,0 +1,279 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + 
+import static org.elasticsearch.inference.TaskType.unsupportedTaskTypeErrorMsg; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; + +public class CustomService extends SenderService { + public static final String NAME = "custom"; + private static final String SERVICE_NAME = "Custom"; + + private static final EnumSet supportedTaskTypes = EnumSet.of( + TaskType.TEXT_EMBEDDING, + TaskType.SPARSE_EMBEDDING, + TaskType.RERANK, + TaskType.COMPLETION + ); + + public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + super(factory, serviceComponents); + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String inferenceEntityId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + CustomModel model = createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + serviceSettingsMap, + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + 
+ @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public EnumSet supportedTaskTypes() { + return supportedTaskTypes; + } + + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + + private static CustomModel createModelWithoutLoggingDeprecations( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + taskSettings, + secretSettings, + ConfigurationParseContext.PERSISTENT + ); + } + + private static CustomModel createModel( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + ConfigurationParseContext context + ) { + if (supportedTaskTypes.contains(taskType) == false) { + throw new ElasticsearchStatusException(unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); + } + return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context); + } + + @Override + public CustomModel parsePersistedConfigWithSecrets( + String inferenceEntityId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); + + return createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, secretSettingsMap); + } + + @Override + public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { + Map serviceSettingsMap = 
removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + + return createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null); + } + + @Override + public void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof CustomModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + CustomModel customModel = (CustomModel) model; + + var overriddenModel = CustomModel.of(customModel, taskSettings); + + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(SERVICE_NAME); + var manager = CustomRequestManager.of(overriddenModel, getServiceComponents().threadPool()); + var action = new SenderExecutableAction(getSender(), manager, failedToSendRequestErrorMessage); + + action.execute(inputs, timeout, listener); + } + + @Override + protected void validateInputType(InputType inputType, Model model, ValidationException validationException) { + ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException); + } + + @Override + protected void doChunkedInfer( + Model model, + EmbeddingsInput inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + listener.onFailure(new ElasticsearchStatusException("Chunking not supported by the {} service", RestStatus.BAD_REQUEST, NAME)); + } + + @Override + public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + if (model instanceof CustomModel customModel && customModel.getTaskType() == TaskType.TEXT_EMBEDDING) { + var newServiceSettings = getCustomServiceSettings(customModel, embeddingSize); + + return new CustomModel(customModel, newServiceSettings); + } else { + throw new ElasticsearchStatusException( + Strings.format( + "Can't update embedding 
details for model of type: [%s], task type: [%s]", + model.getClass().getSimpleName(), + model.getTaskType() + ), + RestStatus.BAD_REQUEST + ); + } + } + + private static CustomServiceSettings getCustomServiceSettings(CustomModel customModel, int embeddingSize) { + var serviceSettings = customModel.getServiceSettings(); + var similarityFromModel = serviceSettings.similarity(); + var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel; + + return new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings( + similarityToUse, + embeddingSize, + serviceSettings.getMaxInputTokens(), + serviceSettings.elementType() + ), + serviceSettings.getUrl(), + serviceSettings.getHeaders(), + serviceSettings.getQueryParameters(), + serviceSettings.getRequestContentString(), + serviceSettings.getResponseJsonParser(), + serviceSettings.rateLimitSettings(), + serviceSettings.getErrorParser() + ); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ADD_INFERENCE_CUSTOM_MODEL; + } + + public static class Configuration { + public static InferenceServiceConfiguration get() { + return configuration.getOrCompute(); + } + + private static final LazyInitializable configuration = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + // TODO revisit this + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(SERVICE_NAME) + .setTaskTypes(supportedTaskTypes) + .setConfigurations(configurationMap) + .build(); + } + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java index d0f9faf283aef..d40a265b6ef19 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java @@ -7,7 +7,52 @@ package org.elasticsearch.xpack.inference.services.custom; -public class CustomServiceSettings { +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.ToXContentFragment; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import 
java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeNullValues; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues; + +public class CustomServiceSettings extends FilteredXContentObject implements ServiceSettings, CustomRateLimitServiceSettings { public static final String NAME = "custom_service_settings"; public static final String URL = "url"; public static final String HEADERS = "headers"; @@ -16,4 +61,366 @@ public class CustomServiceSettings { public static final String RESPONSE = "response"; public static final String JSON_PARSER = "json_parser"; public static final String ERROR_PARSER = "error_parser"; + + private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000); + private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE); + + public static CustomServiceSettings fromMap( + Map map, + ConfigurationParseContext context, + TaskType taskType, + String inferenceId + ) { + ValidationException validationException = new 
ValidationException(); + + var textEmbeddingSettings = TextEmbeddingSettings.fromMap(map, taskType, validationException); + + String url = extractRequiredString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); + + var queryParams = QueryParameters.fromMap(map, validationException); + + Map headers = extractOptionalMap(map, HEADERS, ModelConfigurations.SERVICE_SETTINGS, validationException); + removeNullValues(headers); + var stringHeaders = validateMapStringValues(headers, HEADERS, validationException, false); + + Map requestBodyMap = extractRequiredMap(map, REQUEST, ModelConfigurations.SERVICE_SETTINGS, validationException); + + String requestContentString = extractRequiredString( + Objects.requireNonNullElse(requestBodyMap, new HashMap<>()), + REQUEST_CONTENT, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + + Map responseParserMap = extractRequiredMap( + map, + RESPONSE, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + + Map jsonParserMap = extractRequiredMap( + Objects.requireNonNullElse(responseParserMap, new HashMap<>()), + JSON_PARSER, + RESPONSE_SCOPE, + validationException + ); + + var responseJsonParser = extractResponseParser(taskType, jsonParserMap, validationException); + + Map errorParserMap = extractRequiredMap( + Objects.requireNonNullElse(responseParserMap, new HashMap<>()), + ERROR_PARSER, + RESPONSE_SCOPE, + validationException + ); + + var errorParser = ErrorResponseParser.fromMap(errorParserMap, RESPONSE_SCOPE, inferenceId, validationException); + + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + CustomService.NAME, + context + ); + + if (requestBodyMap == null || responseParserMap == null || jsonParserMap == null || errorParserMap == null) { + throw validationException; + } + + throwIfNotEmptyMap(requestBodyMap, REQUEST, NAME); + throwIfNotEmptyMap(jsonParserMap, JSON_PARSER, NAME); + 
throwIfNotEmptyMap(responseParserMap, RESPONSE, NAME); + throwIfNotEmptyMap(errorParserMap, ERROR_PARSER, NAME); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new CustomServiceSettings( + textEmbeddingSettings, + url, + stringHeaders, + queryParams, + requestContentString, + responseJsonParser, + rateLimitSettings, + errorParser + ); + } + + public record TextEmbeddingSettings( + @Nullable SimilarityMeasure similarityMeasure, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + @Nullable DenseVectorFieldMapper.ElementType elementType + ) implements ToXContentFragment, Writeable { + + // This specifies float for the element type but null for all other settings + public static final TextEmbeddingSettings DEFAULT_FLOAT = new TextEmbeddingSettings( + null, + null, + null, + DenseVectorFieldMapper.ElementType.FLOAT + ); + + // This refers to settings that are not related to the text embedding task type (all the settings should be null) + public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null, null); + + public static TextEmbeddingSettings fromMap(Map map, TaskType taskType, ValidationException validationException) { + if (taskType != TaskType.TEXT_EMBEDDING) { + return NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS; + } + + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + return new TextEmbeddingSettings(similarity, dims, maxInputTokens, DenseVectorFieldMapper.ElementType.FLOAT); + } + + public TextEmbeddingSettings(StreamInput in) throws IOException { + this( + in.readOptionalEnum(SimilarityMeasure.class), + in.readOptionalVInt(), + in.readOptionalVInt(), + in.readOptionalEnum(DenseVectorFieldMapper.ElementType.class) + 
); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalEnum(similarityMeasure); + out.writeOptionalVInt(dimensions); + out.writeOptionalVInt(maxInputTokens); + out.writeOptionalEnum(elementType); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + if (similarityMeasure != null) { + builder.field(SIMILARITY, similarityMeasure); + } + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + return builder; + } + } + + private final TextEmbeddingSettings textEmbeddingSettings; + private final String url; + private final Map headers; + private final QueryParameters queryParameters; + private final String requestContentString; + private final CustomResponseParser responseJsonParser; + private final RateLimitSettings rateLimitSettings; + private final ErrorResponseParser errorParser; + + public CustomServiceSettings( + TextEmbeddingSettings textEmbeddingSettings, + String url, + @Nullable Map headers, + @Nullable QueryParameters queryParameters, + String requestContentString, + CustomResponseParser responseJsonParser, + @Nullable RateLimitSettings rateLimitSettings, + ErrorResponseParser errorParser + ) { + this.textEmbeddingSettings = Objects.requireNonNull(textEmbeddingSettings); + this.url = Objects.requireNonNull(url); + this.headers = Collections.unmodifiableMap(Objects.requireNonNullElse(headers, Map.of())); + this.queryParameters = Objects.requireNonNullElse(queryParameters, QueryParameters.EMPTY); + this.requestContentString = Objects.requireNonNull(requestContentString); + this.responseJsonParser = Objects.requireNonNull(responseJsonParser); + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + this.errorParser = Objects.requireNonNull(errorParser); + } + + public CustomServiceSettings(StreamInput in) 
throws IOException { + textEmbeddingSettings = new TextEmbeddingSettings(in); + url = in.readString(); + headers = in.readImmutableMap(StreamInput::readString); + queryParameters = new QueryParameters(in); + requestContentString = in.readString(); + responseJsonParser = in.readNamedWriteable(CustomResponseParser.class); + rateLimitSettings = new RateLimitSettings(in); + errorParser = new ErrorResponseParser(in); + } + + @Override + public String modelId() { + // returning null because the model id is embedded in the url or the request body + return null; + } + + @Override + public SimilarityMeasure similarity() { + return textEmbeddingSettings.similarityMeasure; + } + + @Override + public Integer dimensions() { + return textEmbeddingSettings.dimensions; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return textEmbeddingSettings.elementType; + } + + public Integer getMaxInputTokens() { + return textEmbeddingSettings.maxInputTokens; + } + + public String getUrl() { + return url; + } + + public Map getHeaders() { + return headers; + } + + public QueryParameters getQueryParameters() { + return queryParameters; + } + + public String getRequestContentString() { + return requestContentString; + } + + public CustomResponseParser getResponseJsonParser() { + return responseJsonParser; + } + + public ErrorResponseParser getErrorParser() { + return errorParser; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + toXContentFragment(builder, params); + + builder.endObject(); + return builder; + } + + public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException { + return toXContentFragmentOfExposedFields(builder, params); + } + + @Override + public 
XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + textEmbeddingSettings.toXContent(builder, params); + builder.field(URL, url); + + if (headers.isEmpty() == false) { + builder.field(HEADERS, headers); + } + + queryParameters.toXContent(builder, params); + + builder.startObject(REQUEST); + { + builder.field(REQUEST_CONTENT, requestContentString); + } + builder.endObject(); + + builder.startObject(RESPONSE); + { + responseJsonParser.toXContent(builder, params); + errorParser.toXContent(builder, params); + } + builder.endObject(); + + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public ToXContentObject getFilteredXContentObject() { + return this; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ADD_INFERENCE_CUSTOM_MODEL; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + textEmbeddingSettings.writeTo(out); + out.writeString(url); + out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString); + queryParameters.writeTo(out); + out.writeString(requestContentString); + out.writeNamedWriteable(responseJsonParser); + rateLimitSettings.writeTo(out); + errorParser.writeTo(out); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CustomServiceSettings that = (CustomServiceSettings) o; + return Objects.equals(textEmbeddingSettings, that.textEmbeddingSettings) + && Objects.equals(url, that.url) + && Objects.equals(headers, that.headers) + && Objects.equals(queryParameters, that.queryParameters) + && Objects.equals(requestContentString, that.requestContentString) + && Objects.equals(responseJsonParser, that.responseJsonParser) + && Objects.equals(rateLimitSettings, that.rateLimitSettings) + && Objects.equals(errorParser, that.errorParser); + } + + @Override + public int hashCode() { + 
return Objects.hash( + textEmbeddingSettings, + url, + headers, + queryParameters, + requestContentString, + responseJsonParser, + rateLimitSettings, + errorParser + ); + } + + private static CustomResponseParser extractResponseParser( + TaskType taskType, + Map responseParserMap, + ValidationException validationException + ) { + if (responseParserMap == null) { + return NoopResponseParser.INSTANCE; + } + + return switch (taskType) { + case TEXT_EMBEDDING -> TextEmbeddingResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException); + case SPARSE_EMBEDDING -> SparseEmbeddingResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException); + case RERANK -> RerankResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException); + case COMPLETION -> CompletionResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException); + default -> throw new IllegalArgumentException( + Strings.format("Invalid task type received [%s] while constructing response parser", taskType) + ); + }; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettings.java new file mode 100644 index 0000000000000..1ca07ae0caf19 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettings.java @@ -0,0 +1,134 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeNullValues; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapValues; + +public class CustomTaskSettings implements TaskSettings { + public static final String NAME = "custom_task_settings"; + + public static final String PARAMETERS = "parameters"; + + static final CustomTaskSettings EMPTY_SETTINGS = new CustomTaskSettings(new HashMap<>()); + + public static CustomTaskSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + if (map == null || map.isEmpty()) { + return EMPTY_SETTINGS; + } + + Map parameters = extractOptionalMap(map, PARAMETERS, ModelConfigurations.TASK_SETTINGS, validationException); + removeNullValues(parameters); + validateMapValues( + parameters, + List.of(String.class, Integer.class, Double.class, Float.class, Boolean.class), + PARAMETERS, + validationException, + false + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new CustomTaskSettings(Objects.requireNonNullElse(parameters, new HashMap<>())); + } + + /** + * Creates a new {@link CustomTaskSettings} + * by preferring non-null fields from the request settings over the original 
settings. + * @param originalSettings the settings stored as part of the inference entity configuration + * @param requestTaskSettings the settings passed in within the task_settings field of the request + * @return a constructed {@link CustomTaskSettings} + */ + public static CustomTaskSettings of(CustomTaskSettings originalSettings, CustomTaskSettings requestTaskSettings) { + var copy = new HashMap<>(originalSettings.parameters); + requestTaskSettings.parameters.forEach((key, value) -> copy.merge(key, value, (originalValue, requestValue) -> requestValue)); + return new CustomTaskSettings(copy); + } + + private final Map parameters; + + public CustomTaskSettings(StreamInput in) throws IOException { + parameters = in.readGenericMap(); + } + + public CustomTaskSettings(Map parameters) { + this.parameters = Objects.requireNonNull(parameters); + } + + public Map getParameters() { + return parameters; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (parameters.isEmpty() == false) { + builder.field(PARAMETERS, parameters); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ADD_INFERENCE_CUSTOM_MODEL; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeGenericMap(parameters); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CustomTaskSettings that = (CustomTaskSettings) o; + return Objects.equals(parameters, that.parameters); + } + + @Override + public int hashCode() { + return Objects.hash(parameters); + } + + @Override + public boolean isEmpty() { + return parameters.isEmpty(); + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + CustomTaskSettings 
updatedSettings = CustomTaskSettings.fromMap(new HashMap<>(newSettings)); + return of(this, updatedSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/QueryParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/QueryParameters.java new file mode 100644 index 0000000000000..2b5bc2fe964b3 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/QueryParameters.java @@ -0,0 +1,104 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xcontent.ToXContentFragment; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalListOfStringTuples; + +public record QueryParameters(List parameters) implements ToXContentFragment, Writeable { + + public static final QueryParameters EMPTY = new QueryParameters(List.of()); + public static final String QUERY_PARAMETERS = "query_parameters"; + + public static QueryParameters fromMap(Map map, ValidationException validationException) { + List> queryParams = extractOptionalListOfStringTuples( + map, + QUERY_PARAMETERS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + 
+ if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return QueryParameters.fromTuples(queryParams); + } + + private static QueryParameters fromTuples(List> queryParams) { + if (queryParams == null) { + return QueryParameters.EMPTY; + } + + return new QueryParameters(queryParams.stream().map((tuple) -> new Parameter(tuple.v1(), tuple.v2())).toList()); + } + + public record Parameter(String key, String value) implements ToXContentFragment, Writeable { + public Parameter { + Objects.requireNonNull(key); + Objects.requireNonNull(value); + } + + public Parameter(StreamInput in) throws IOException { + this(in.readString(), in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(key); + out.writeString(value); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startArray(); + builder.value(key); + builder.value(value); + builder.endArray(); + return builder; + } + } + + public QueryParameters { + Objects.requireNonNull(parameters); + } + + public QueryParameters(StreamInput in) throws IOException { + this(in.readCollectionAsImmutableList(Parameter::new)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(parameters); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + if (parameters.isEmpty() == false) { + builder.startArray(QUERY_PARAMETERS); + for (var parameter : parameters) { + parameter.toXContent(builder, params); + } + builder.endArray(); + } + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java new file mode 100644 index 0000000000000..0a50b08163260 --- 
/dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java
@@ -0,0 +1,185 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.request;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.client.methods.HttpRequestBase;
+import org.apache.http.client.utils.URIBuilder;
+import org.apache.http.entity.StringEntity;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.common.ValidatingSubstitutor;
+import org.elasticsearch.xpack.inference.external.request.HttpRequest;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.custom.CustomModel;
+import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
+import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.common.JsonUtils.toJson;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.REQUEST_CONTENT;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.URL;
+
+/**
+ * HTTP POST request sent to a user-defined (custom) inference endpoint.
+ * <p>
+ * The URL, query parameters, headers and request body are taken from the model's service settings. Their
+ * {@code ${...}} placeholders are resolved from the model's secret parameters and task settings parameters; the
+ * request body can additionally reference the {@code query} and {@code input} of the inference call.
+ */
+public class CustomRequest implements Request {
+    private static final String QUERY = "query";
+    private static final String INPUT = "input";
+
+    private final URI uri;
+    private final ValidatingSubstitutor jsonPlaceholderReplacer;
+    private final ValidatingSubstitutor stringPlaceholderReplacer;
+    private final CustomModel model;
+
+    /**
+     * @param query the query text, made available to the request body as {@code ${query}} when it is not null
+     * @param input the input texts, made available to the request body as {@code ${input}}
+     * @param model the model whose settings describe the request; must not be null
+     */
+    public CustomRequest(String query, List<String> input, CustomModel model) {
+        this.model = Objects.requireNonNull(model);
+
+        var stringOnlyParams = new HashMap<String, String>();
+        addStringParams(stringOnlyParams, model.getSecretSettings().getSecretParameters());
+        addStringParams(stringOnlyParams, model.getTaskSettings().getParameters());
+
+        var jsonParams = new HashMap<String, String>();
+        addJsonStringParams(jsonParams, model.getSecretSettings().getSecretParameters());
+        addJsonStringParams(jsonParams, model.getTaskSettings().getParameters());
+
+        if (query != null) {
+            jsonParams.put(QUERY, toJson(query, QUERY));
+        }
+
+        addInputJsonParam(jsonParams, input, model.getTaskType());
+
+        jsonPlaceholderReplacer = new ValidatingSubstitutor(jsonParams, "${", "}");
+        stringPlaceholderReplacer = new ValidatingSubstitutor(stringOnlyParams, "${", "}");
+        uri = buildUri();
+    }
+
+    /**
+     * Copies the string-like values ({@link String}, {@link SerializableSecureString}, {@link SecureString}) as raw
+     * strings; values of any other type are skipped.
+     */
+    private static void addStringParams(Map<String, String> stringParams, Map<String, ?> paramsToAdd) {
+        for (var entry : paramsToAdd.entrySet()) {
+            if (entry.getValue() instanceof String str) {
+                stringParams.put(entry.getKey(), str);
+            } else if (entry.getValue() instanceof SerializableSecureString serializableSecureString) {
+                stringParams.put(entry.getKey(), serializableSecureString.getSecureString().toString());
+            } else if (entry.getValue() instanceof SecureString secureString) {
+                stringParams.put(entry.getKey(), secureString.toString());
+            }
+        }
+    }
+
+    /** Stores every value in its JSON-encoded form so that it can be spliced into the JSON request body. */
+    private static void addJsonStringParams(Map<String, String> jsonStringParams, Map<String, ?> params) {
+        for (var entry : params.entrySet()) {
+            jsonStringParams.put(entry.getKey(), toJson(entry.getValue(), entry.getKey()));
+        }
+    }
+
+    /** Completion requests with input expose only the first input as a single JSON value; every other case exposes the whole list. */
+    private static void addInputJsonParam(Map<String, String> jsonParams, List<String> input, TaskType taskType) {
+        if (taskType == TaskType.COMPLETION && input.isEmpty() == false) {
+            jsonParams.put(INPUT, toJson(input.get(0), INPUT));
+        } else {
+            jsonParams.put(INPUT, toJson(input, INPUT));
+        }
+    }
+
+    /** Resolves the placeholders of the configured URL and query parameter values and assembles the final URI. */
+    private URI buildUri() {
+        var replacedUrl = stringPlaceholderReplacer.replace(model.getServiceSettings().getUrl(), URL);
+
+        try {
+            var builder = new URIBuilder(replacedUrl);
+            for (var queryParam : model.getServiceSettings().getQueryParameters().parameters()) {
+                builder.addParameter(
+                    queryParam.key(),
+                    stringPlaceholderReplacer.replace(queryParam.value(), Strings.format("query parameters: [%s]", queryParam.key()))
+                );
+            }
+            return builder.build();
+        } catch (URISyntaxException e) {
+            throw new IllegalStateException(Strings.format("Failed to build URI, error: %s", e.getMessage()), e);
+        }
+    }
+
+    @Override
+    public HttpRequest createHttpRequest() {
+        HttpPost httpRequest = new HttpPost(uri);
+
+        setHeaders(httpRequest);
+        setRequestContent(httpRequest);
+
+        return new HttpRequest(httpRequest, getInferenceEntityId());
+    }
+
+    private void setHeaders(HttpRequestBase httpRequest) {
+        // Default the Content-Type to JSON; a Content-Type header configured by the user overrides it in the loop below
+        httpRequest.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
+
+        for (var entry : model.getServiceSettings().getHeaders().entrySet()) {
+            String replacedHeadersValue = stringPlaceholderReplacer.replace(entry.getValue(), Strings.format("header.%s", entry.getKey()));
+            httpRequest.setHeader(entry.getKey(), replacedHeadersValue);
+        }
+    }
+
+    private void setRequestContent(HttpPost httpRequest) {
+        String replacedRequestContentString = jsonPlaceholderReplacer.replace(
+            model.getServiceSettings().getRequestContentString(),
+            REQUEST_CONTENT
+        );
+        StringEntity stringEntity = new StringEntity(replacedRequestContentString, StandardCharsets.UTF_8);
+        httpRequest.setEntity(stringEntity);
+    }
+
+    @Override
+    public String getInferenceEntityId() {
+        return model.getInferenceEntityId();
+    }
+
+    @Override
+    public URI getURI() {
+        return uri;
+    }
+
+    public CustomServiceSettings getServiceSettings() {
+        return model.getServiceSettings();
+    }
+
+    /** Custom requests are never truncated, so the request itself is returned unchanged. */
+    @Override
+    public Request truncate() {
+        return this;
+    }
+
+    /** Always {@code null}: no truncation information is tracked for custom requests. */
+    @Override
+    public boolean[] getTruncationInfo() {
+        return null;
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java
index 762556fb381ed..ecd3125e228c9 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java
@@ -30,8 +30,17 @@ public class CompletionResponseParser extends BaseCustomResponseParser responseParserMap, ValidationException validationException) {
-        var path = extractRequiredString(responseParserMap, COMPLETION_PARSER_RESULT, JSON_PARSER, validationException);
+    public static CompletionResponseParser fromMap(
+        Map responseParserMap,
+        String scope,
+        ValidationException validationException
+    ) {
+        var path = extractRequiredString(
+            responseParserMap,
+            COMPLETION_PARSER_RESULT,
+            String.join(".", scope, JSON_PARSER),
+            validationException
+        );
 
         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntity.java
new file mode 100644
index 0000000000000..b52670b99e6a9
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntity.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.response;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest;
+
+import java.io.IOException;
+
+/**
+ * Converts the HTTP response of a custom service call into inference results, using the response parser configured
+ * in the service settings of the originating {@link CustomRequest}.
+ */
+public class CustomResponseEntity {
+    /**
+     * @param request the request that produced the response; must be a {@link CustomRequest}
+     * @param response the HTTP result to parse
+     * @return the results produced by the response parser of the request's service settings
+     * @throws IllegalArgumentException if {@code request} is null or not a {@link CustomRequest}
+     * @throws IOException propagated from the response parser
+     */
+    public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
+        if (request instanceof CustomRequest customRequest) {
+            return customRequest.getServiceSettings().getResponseJsonParser().parse(response);
+        }
+
+        // Guard against null so that the caller gets the descriptive exception instead of a NullPointerException
+        var actualType = request == null ? "null" : request.getClass().getSimpleName();
+        throw new IllegalArgumentException(
+            Strings.format("Original request is an invalid type [%s], expected [%s]", actualType, CustomRequest.class.getSimpleName())
+        );
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java
index d05fa68595b3a..51fb8b1486a82 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java
@@ -7,6 +7,9 @@
 
 package org.elasticsearch.xpack.inference.services.custom.response;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -21,6 +24,7 @@ import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.Map; import java.util.Objects; import java.util.function.Function; @@ -31,30 +35,40 @@ public class ErrorResponseParser implements ToXContentFragment, Function { + private static final Logger logger = LogManager.getLogger(ErrorResponseParser.class); public static final String MESSAGE_PATH = "path"; private final String messagePath; + private final String inferenceId; - public static ErrorResponseParser fromMap(Map responseParserMap, ValidationException validationException) { - var path = extractRequiredString(responseParserMap, MESSAGE_PATH, ERROR_PARSER, validationException); + public static ErrorResponseParser fromMap( + Map responseParserMap, + String scope, + String inferenceId, + ValidationException validationException + ) { + var path = extractRequiredString(responseParserMap, MESSAGE_PATH, String.join(".", scope, ERROR_PARSER), validationException); if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new ErrorResponseParser(path); + return new ErrorResponseParser(path, inferenceId); } - public ErrorResponseParser(String messagePath) { + public ErrorResponseParser(String messagePath, String inferenceId) { this.messagePath = Objects.requireNonNull(messagePath); + this.inferenceId = Objects.requireNonNull(inferenceId); } public ErrorResponseParser(StreamInput in) throws IOException { this.messagePath = in.readString(); + this.inferenceId = in.readString(); } public void writeTo(StreamOutput out) throws IOException { out.writeString(messagePath); + out.writeString(inferenceId); } public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { @@ -86,7 +100,6 @@ public ErrorResponse 
apply(HttpResult httpResult) { .createParser(XContentParserConfiguration.EMPTY, httpResult.body()) ) { var map = jsonParser.map(); - // NOTE: This deviates from what we've done in the past. In the ErrorMessageResponseEntity logic // if we find the top level error field we'll return a response with an empty message but indicate // that we found the structure of the error object. Here if we're missing the final field we will return @@ -97,9 +110,19 @@ public ErrorResponse apply(HttpResult httpResult) { var errorText = toType(MapPathExtractor.extract(map, messagePath).extractedObject(), String.class, messagePath); return new ErrorResponse(errorText); } catch (Exception e) { - // swallow the error + var resultAsString = new String(httpResult.body(), StandardCharsets.UTF_8); + + logger.info( + Strings.format( + "Failed to parse error object for custom service inference id [%s], message path: [%s], result as string: [%s]", + inferenceId, + messagePath, + resultAsString + ), + e + ); + + return new ErrorResponse(Strings.format("Unable to parse the error, response body: [%s]", resultAsString)); } - - return ErrorResponse.UNDEFINED_ERROR; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java index 18d3cbbad051b..0a4c2c42b8c79 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java @@ -37,11 +37,16 @@ public class RerankResponseParser extends BaseCustomResponseParser responseParserMap, ValidationException validationException) { - - var relevanceScore = extractRequiredString(responseParserMap, RERANK_PARSER_SCORE, JSON_PARSER, validationException); - var 
rerankIndex = extractOptionalString(responseParserMap, RERANK_PARSER_INDEX, JSON_PARSER, validationException); - var documentText = extractOptionalString(responseParserMap, RERANK_PARSER_DOCUMENT_TEXT, JSON_PARSER, validationException); + public static RerankResponseParser fromMap( + Map responseParserMap, + String scope, + ValidationException validationException + ) { + var fullScope = String.join(".", scope, JSON_PARSER); + + var relevanceScore = extractRequiredString(responseParserMap, RERANK_PARSER_SCORE, fullScope, validationException); + var rerankIndex = extractOptionalString(responseParserMap, RERANK_PARSER_INDEX, fullScope, validationException); + var documentText = extractOptionalString(responseParserMap, RERANK_PARSER_DOCUMENT_TEXT, fullScope, validationException); if (validationException.validationErrors().isEmpty() == false) { throw validationException; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java index b6c83fd7fbfc6..7d54e90865122 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java @@ -35,10 +35,15 @@ public class SparseEmbeddingResponseParser extends BaseCustomResponseParser responseParserMap, ValidationException validationException) { - var tokenPath = extractRequiredString(responseParserMap, SPARSE_EMBEDDING_TOKEN_PATH, JSON_PARSER, validationException); + public static SparseEmbeddingResponseParser fromMap( + Map responseParserMap, + String scope, + ValidationException validationException + ) { + var fullScope = String.join(".", scope, JSON_PARSER); + var tokenPath = 
extractRequiredString(responseParserMap, SPARSE_EMBEDDING_TOKEN_PATH, fullScope, validationException); - var weightPath = extractRequiredString(responseParserMap, SPARSE_EMBEDDING_WEIGHT_PATH, JSON_PARSER, validationException); + var weightPath = extractRequiredString(responseParserMap, SPARSE_EMBEDDING_WEIGHT_PATH, fullScope, validationException); if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -149,6 +154,7 @@ private static SparseEmbeddingResults.Embedding createEmbedding( // Alibaba can return a token id which is an integer and needs to be converted to a string var tokenIdAsString = token.toString(); + try { var weightAsFloat = toFloat(weight, weightFieldName); weightedTokens.add(new WeightedToken(tokenIdAsString, weightAsFloat)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java index fe5b4ec236282..b5b0a191f3c4e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java @@ -30,8 +30,17 @@ public class TextEmbeddingResponseParser extends BaseCustomResponseParser responseParserMap, ValidationException validationException) { - var path = extractRequiredString(responseParserMap, TEXT_EMBEDDING_PARSER_EMBEDDINGS, JSON_PARSER, validationException); + public static TextEmbeddingResponseParser fromMap( + Map responseParserMap, + String scope, + ValidationException validationException + ) { + var path = extractRequiredString( + responseParserMap, + TEXT_EMBEDDING_PARSER_EMBEDDINGS, + String.join(".", scope, JSON_PARSER), + validationException + ); if 
(validationException.validationErrors().isEmpty() == false) { throw validationException; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java index 0d8bef246b35d..fe6ebb6cfb625 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java @@ -70,17 +70,6 @@ public class ElasticInferenceServiceSettings { Setting.Property.NodeScope ); - /** - * Total time to live (TTL) defines maximum life span of persistent connections regardless of their - * expiration setting. No persistent connection will be re-used past its TTL value. - * Using a TTL of -1 will disable the expiration of persistent connections (the idle connection evictor will still apply). 
- */ - public static final Setting CONNECTION_TTL_SETTING = Setting.timeSetting( - "xpack.inference.elastic.http.connection_ttl", - TimeValue.timeValueSeconds(60), - Setting.Property.NodeScope - ); - @Deprecated private final String eisGatewayUrl; @@ -88,7 +77,6 @@ public class ElasticInferenceServiceSettings { private final boolean periodicAuthorizationEnabled; private volatile TimeValue authRequestInterval; private volatile TimeValue maxAuthorizationRequestJitter; - private final TimeValue connectionTtl; public ElasticInferenceServiceSettings(Settings settings) { eisGatewayUrl = EIS_GATEWAY_URL.get(settings); @@ -96,7 +84,6 @@ public ElasticInferenceServiceSettings(Settings settings) { periodicAuthorizationEnabled = PERIODIC_AUTHORIZATION_ENABLED.get(settings); authRequestInterval = AUTHORIZATION_REQUEST_INTERVAL.get(settings); maxAuthorizationRequestJitter = MAX_AUTHORIZATION_REQUEST_JITTER.get(settings); - connectionTtl = CONNECTION_TTL_SETTING.get(settings); } /** @@ -128,10 +115,6 @@ public TimeValue getMaxAuthorizationRequestJitter() { return maxAuthorizationRequestJitter; } - public TimeValue getConnectionTtl() { - return connectionTtl; - } - public static List> getSettingsDefinitions() { ArrayList> settings = new ArrayList<>(); settings.add(EIS_GATEWAY_URL); @@ -141,7 +124,6 @@ public static List> getSettingsDefinitions() { settings.add(PERIODIC_AUTHORIZATION_ENABLED); settings.add(AUTHORIZATION_REQUEST_INTERVAL); settings.add(MAX_AUTHORIZATION_REQUEST_JITTER); - settings.add(CONNECTION_TTL_SETTING); return settings; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java index 5ab743c3d4cc0..10c8d8928ea65 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java @@ -26,6 +26,7 @@ import java.util.Deque; import java.util.Iterator; import java.util.List; +import java.util.concurrent.LinkedBlockingDeque; import java.util.function.BiFunction; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; @@ -59,11 +60,21 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor< public static final String TOTAL_TOKENS_FIELD = "total_tokens"; private final BiFunction errorParser; + private final Deque buffer = new LinkedBlockingDeque<>(); public OpenAiUnifiedStreamingProcessor(BiFunction errorParser) { this.errorParser = errorParser; } + @Override + protected void upstreamRequest(long n) { + if (buffer.isEmpty()) { + super.upstreamRequest(n); + } else { + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(singleItem(buffer.poll()))); + } + } + @Override protected void next(Deque item) throws Exception { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); @@ -85,8 +96,15 @@ protected void next(Deque item) throws Exception { if (results.isEmpty()) { upstream().request(1); - } else { + } else if (results.size() == 1) { downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results)); + } else { + // results > 1, but openai spec only wants 1 chunk per SSE event + var firstItem = singleItem(results.poll()); + while (results.isEmpty() == false) { + buffer.offer(results.poll()); + } + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(firstItem)); } } @@ -279,4 +297,12 @@ public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage pa } } } + + private Deque singleItem( + StreamingUnifiedChatCompletionResults.ChatCompletionChunk result + ) { + var deque = new ArrayDeque(1); + deque.offer(result); + return deque; + } } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/SerializableSecureString.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/SerializableSecureString.java
new file mode 100644
index 0000000000000..0ebd4cc0cad81
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/SerializableSecureString.java
@@ -0,0 +1,70 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.settings;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.xcontent.ToXContentFragment;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Wraps a {@link SecureString} so that it can be sent over the wire ({@link Writeable}) and rendered as an XContent
+ * value ({@link ToXContentFragment}).
+ * <p>
+ * Note that {@link #toXContent} writes the secret as a plain string value.
+ */
+public class SerializableSecureString implements ToXContentFragment, Writeable {
+
+    private final SecureString secureString;
+
+    public SerializableSecureString(StreamInput in) throws IOException {
+        secureString = in.readSecureString();
+    }
+
+    public SerializableSecureString(SecureString secureString) {
+        this.secureString = Objects.requireNonNull(secureString);
+    }
+
+    public SerializableSecureString(String value) {
+        secureString = new SecureString(Objects.requireNonNull(value).toCharArray());
+    }
+
+    public SecureString getSecureString() {
+        return secureString;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.value(secureString.toString());
+        // Hand the builder back, as the toXContent contract expects, so that callers can keep chaining on the result
+        return builder;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeSecureString(secureString);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        SerializableSecureString that = (SerializableSecureString) o;
+        return Objects.equals(secureString, that.secureString);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hashCode(secureString);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy
new file mode 100644
index 0000000000000..e36b553d2def2
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy
@@ -0,0 +1,27 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+grant {
+  // required by: com.google.api.client.json.JsonParser#parseValue
+  // also required by AWS SDK for client configuration
+  permission java.lang.RuntimePermission "accessDeclaredMembers";
+  permission java.lang.RuntimePermission "getClassLoader";
+
+  // required by: com.google.api.client.json.GenericJson#<clinit>
+  // also by AWS SDK for Jackson's ObjectMapper
+  permission java.lang.reflect.ReflectPermission "suppressAccessChecks";
+
+  // required to add google certs to the gcs client truststore
+  permission java.lang.RuntimePermission "setFactory";
+
+  // gcs client opens socket connections to access the repository
+  // also, AWS Bedrock client opens socket connections and needs resolve to access resources
+  permission java.net.SocketPermission "*", "connect,resolve";
+
+  // AWS Clients always try to check the http.proxyHost system property
+  permission java.util.PropertyPermission "http.proxyHost", "read";
+};
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java index bd18058277d9c..b8648936956ae 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java @@ -235,4 +235,8 @@ public static void assertJsonEquals(String actual, String expected) throws IOExc assertThat(actualParser.mapOrdered(), equalTo(expectedParser.mapOrdered())); } } + + public static Map modifiableMap(Map aMap) { + return new HashMap<>(aMap); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java index e514867780669..f61398fcacacf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java @@ -7,16 +7,13 @@ package org.elasticsearch.xpack.inference.action; -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; -import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; -public class PutInferenceModelRequestTests extends AbstractBWCWireSerializationTestCase { +public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCase { @Override protected Writeable.Reader instanceReader() { return PutInferenceModelAction.Request::new; @@ -28,29 +25,38 @@ protected 
PutInferenceModelAction.Request createTestInstance() { randomFrom(TaskType.values()), randomAlphaOfLength(6), randomBytesReference(50), - randomFrom(XContentType.values()), - randomTimeValue() + randomFrom(XContentType.values()) ); } @Override protected PutInferenceModelAction.Request mutateInstance(PutInferenceModelAction.Request instance) { - return randomValueOtherThan(instance, this::createTestInstance); - } - - @Override - protected PutInferenceModelAction.Request mutateInstanceForVersion(PutInferenceModelAction.Request instance, TransportVersion version) { - if (version.onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT) - || version.isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) { - return instance; - } else { - return new PutInferenceModelAction.Request( + return switch (randomIntBetween(0, 3)) { + case 0 -> new PutInferenceModelAction.Request( + TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length], + instance.getInferenceEntityId(), + instance.getContent(), + instance.getContentType() + ); + case 1 -> new PutInferenceModelAction.Request( + instance.getTaskType(), + instance.getInferenceEntityId() + "foo", + instance.getContent(), + instance.getContentType() + ); + case 2 -> new PutInferenceModelAction.Request( + instance.getTaskType(), + instance.getInferenceEntityId(), + randomBytesReference(instance.getContent().length() + 1), + instance.getContentType() + ); + case 3 -> new PutInferenceModelAction.Request( instance.getTaskType(), instance.getInferenceEntityId(), instance.getContent(), - instance.getContentType(), - InferenceAction.Request.DEFAULT_TIMEOUT + XContentType.values()[(instance.getContentType().ordinal() + 1) % XContentType.values().length] ); - } + default -> throw new IllegalStateException(); + }; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/JsonUtilsTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/JsonUtilsTests.java index b49a819a3a698..7f1003c6723a0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/JsonUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/JsonUtilsTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString; import java.io.IOException; import java.util.List; @@ -52,6 +53,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws assertThat(toJson(1.1f, "field"), is("1.1")); assertThat(toJson(true, "field"), is("true")); assertThat(toJson(false, "field"), is("false")); + assertThat(toJson(new SerializableSecureString("api_key"), "field"), is("\"api_key\"")); } public void testToJson_ThrowsException_WhenUnableToSerialize() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/MapPathExtractorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/MapPathExtractorTests.java index 047c0c8d647fb..a22bf12d29eb6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/MapPathExtractorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/MapPathExtractorTests.java @@ -86,6 +86,75 @@ public void testExtract_IteratesListOfMapsToListOfMapsOfStringToDoubles() { ); } + public void testExtract_IteratesSparseEmbeddingStyleMap_ExtractsMaps() { + Map input = Map.of( + "result", + Map.of( + "sparse_embeddings", + List.of( + Map.of( + "index", + 0, + "embedding", + List.of(Map.of("tokenId", 6, "weight", 0.123d), Map.of("tokenId", 100, "weight", -123d)) + ), + Map.of( + "index", + 1, + "embedding", + 
List.of(Map.of("tokenId", 7, "weight", 0.456d), Map.of("tokenId", 200, "weight", -456d)) + ) + ) + ) + ); + + assertThat( + MapPathExtractor.extract(input, "$.result.sparse_embeddings[*].embedding[*]"), + is( + new MapPathExtractor.Result( + List.of( + List.of(Map.of("tokenId", 6, "weight", 0.123d), Map.of("tokenId", 100, "weight", -123d)), + List.of(Map.of("tokenId", 7, "weight", 0.456d), Map.of("tokenId", 200, "weight", -456d)) + ), + List.of("result.sparse_embeddings", "result.sparse_embeddings.embedding") + ) + ) + ); + } + + public void testExtract_IteratesSparseEmbeddingStyleMap_ExtractsFieldFromMap() { + Map input = Map.of( + "result", + Map.of( + "sparse_embeddings", + List.of( + Map.of( + "index", + 0, + "embedding", + List.of(Map.of("tokenId", 6, "weight", 0.123d), Map.of("tokenId", 100, "weight", -123d)) + ), + Map.of( + "index", + 1, + "embedding", + List.of(Map.of("tokenId", 7, "weight", 0.456d), Map.of("tokenId", 200, "weight", -456d)) + ) + ) + ) + ); + + assertThat( + MapPathExtractor.extract(input, "$.result.sparse_embeddings[*].embedding[*].tokenId"), + is( + new MapPathExtractor.Result( + List.of(List.of(6, 100), List.of(7, 200)), + List.of("result.sparse_embeddings", "result.sparse_embeddings.embedding", "result.sparse_embeddings.embedding.tokenId") + ) + ) + ); + } + public void testExtract_ReturnsNullForEmptyList() { Map input = Map.of(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntityTests.java index 9ad9b9f3ca0a5..5024cf53dffa9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntityTests.java @@ -32,7 +32,7 @@ public void 
testErrorResponse_ExtractsError() { var error = ErrorMessageResponseEntity.fromResponse(result); assertNotNull(error); - assertThat(error.getErrorMessage(), is("test_error_message")); + assertThat(error, is(new ErrorMessageResponseEntity("test_error_message"))); } public void testFromResponse_WithOtherFieldsPresent() { @@ -50,7 +50,7 @@ public void testFromResponse_WithOtherFieldsPresent() { ErrorResponse errorMessage = ErrorMessageResponseEntity.fromResponse( new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertEquals("You didn't provide an API key", errorMessage.getErrorMessage()); + assertThat(errorMessage, is(new ErrorMessageResponseEntity("You didn't provide an API key"))); } public void testFromResponse_noMessage() { @@ -65,7 +65,7 @@ public void testFromResponse_noMessage() { var errorMessage = ErrorMessageResponseEntity.fromResponse( new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(errorMessage.getErrorMessage(), is("")); + assertThat(errorMessage, is(new ErrorMessageResponseEntity(""))); assertTrue(errorMessage.errorStructureFound()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java index c3b50cdb4a670..a00f8e55a4e27 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java @@ -31,7 +31,6 @@ import java.util.List; import java.util.Map; -import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS; import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength; import static org.elasticsearch.test.ESTestCase.randomFrom; import static org.elasticsearch.test.ESTestCase.randomInt; @@ -47,14 +46,9 @@ public static TestModel 
createRandomInstance(TaskType taskType) { } public static TestModel createRandomInstance(TaskType taskType, List excludedSimilarities) { - // Use a max dimension count that has a reasonable probability of being compatible with BBQ - return createRandomInstance(taskType, excludedSimilarities, BBQ_MIN_DIMS * 2); - } - - public static TestModel createRandomInstance(TaskType taskType, List excludedSimilarities, int maxDimensions) { var elementType = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(DenseVectorFieldMapper.ElementType.values()) : null; var dimensions = taskType == TaskType.TEXT_EMBEDDING - ? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, maxDimensions) + ? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 64) : null; SimilarityMeasure similarity = null; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractServiceTests.java new file mode 100644 index 0000000000000..071c4caa90a9f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractServiceTests.java @@ -0,0 +1,538 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.custom.CustomModel; +import org.junit.After; +import org.junit.Assume; +import org.junit.Before; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.Utils.TIMEOUT; +import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; +import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; +import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +/** + * Base class for testing inference services. + *

+ * This class provides common unit tests for inference services, such as testing model creation and calling the infer method. + * + * To use this class, extend it and pass a TestConfiguration to the constructor. + *

+ */ +public abstract class AbstractServiceTests extends ESTestCase { + + protected final MockWebServer webServer = new MockWebServer(); + protected ThreadPool threadPool; + protected HttpClientManager clientManager; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @Override + @After + public void tearDown() throws Exception { + super.tearDown(); + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + private final TestConfiguration testConfiguration; + + public AbstractServiceTests(TestConfiguration testConfiguration) { + this.testConfiguration = Objects.requireNonNull(testConfiguration); + } + + /** + * Main configurations for the tests + */ + public record TestConfiguration(CommonConfig commonConfig, UpdateModelConfiguration updateModelConfiguration) { + public static class Builder { + private final CommonConfig commonConfig; + private UpdateModelConfiguration updateModelConfiguration = DISABLED_UPDATE_MODEL_TESTS; + + public Builder(CommonConfig commonConfig) { + this.commonConfig = commonConfig; + } + + public Builder enableUpdateModelTests(UpdateModelConfiguration updateModelConfiguration) { + this.updateModelConfiguration = updateModelConfiguration; + return this; + } + + public TestConfiguration build() { + return new TestConfiguration(commonConfig, updateModelConfiguration); + } + } + } + + /** + * Configurations that useful for most tests + */ + public abstract static class CommonConfig { + + private final TaskType taskType; + private final TaskType unsupportedTaskType; + + public CommonConfig(TaskType taskType, @Nullable TaskType unsupportedTaskType) { + this.taskType = Objects.requireNonNull(taskType); + this.unsupportedTaskType = unsupportedTaskType; + } + + protected abstract 
SenderService createService(ThreadPool threadPool, HttpClientManager clientManager); + + protected abstract Map createServiceSettingsMap(TaskType taskType); + + protected abstract Map createTaskSettingsMap(); + + protected abstract Map createSecretSettingsMap(); + + protected abstract void assertModel(Model model, TaskType taskType); + + protected abstract EnumSet supportedStreamingTasks(); + } + + /** + * Configurations specific to the {@link SenderService#updateModelWithEmbeddingDetails(Model, int)} tests + */ + public abstract static class UpdateModelConfiguration { + + public boolean isEnabled() { + return true; + } + + protected abstract CustomModel createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure); + } + + private static final UpdateModelConfiguration DISABLED_UPDATE_MODEL_TESTS = new UpdateModelConfiguration() { + @Override + public boolean isEnabled() { + return false; + } + + @Override + protected CustomModel createEmbeddingModel(SimilarityMeasure similarityMeasure) { + throw new UnsupportedOperationException("Update model tests are disabled"); + } + }; + + public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws Exception { + var parseRequestConfigTestConfig = testConfiguration.commonConfig; + + try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { + var config = getRequestConfigMap( + parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING), + parseRequestConfigTestConfig.createTaskSettingsMap(), + parseRequestConfigTestConfig.createSecretSettingsMap() + ); + + var listener = new PlainActionFuture(); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener); + + parseRequestConfigTestConfig.assertModel(listener.actionGet(TIMEOUT), TaskType.TEXT_EMBEDDING); + } + } + + public void testParseRequestConfig_CreatesACompletionModel() throws Exception { + var parseRequestConfigTestConfig = testConfiguration.commonConfig; + + try (var service = 
parseRequestConfigTestConfig.createService(threadPool, clientManager)) { + var config = getRequestConfigMap( + parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION), + parseRequestConfigTestConfig.createTaskSettingsMap(), + parseRequestConfigTestConfig.createSecretSettingsMap() + ); + + var listener = new PlainActionFuture(); + service.parseRequestConfig("id", TaskType.COMPLETION, config, listener); + + parseRequestConfigTestConfig.assertModel(listener.actionGet(TIMEOUT), TaskType.COMPLETION); + } + } + + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws Exception { + var parseRequestConfigTestConfig = testConfiguration.commonConfig; + + try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { + var config = getRequestConfigMap( + parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType), + parseRequestConfigTestConfig.createTaskSettingsMap(), + parseRequestConfigTestConfig.createSecretSettingsMap() + ); + + var listener = new PlainActionFuture(); + service.parseRequestConfig("id", parseRequestConfigTestConfig.unsupportedTaskType, config, listener); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + exception.getMessage(), + containsString(Strings.format("service does not support task type [%s]", parseRequestConfigTestConfig.unsupportedTaskType)) + ); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { + var parseRequestConfigTestConfig = testConfiguration.commonConfig; + + try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { + var config = getRequestConfigMap( + parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType), + parseRequestConfigTestConfig.createTaskSettingsMap(), + parseRequestConfigTestConfig.createSecretSettingsMap() + ); + config.put("extra_key", 
"value"); + + var listener = new PlainActionFuture(); + service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings [{extra_key=value}]")); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { + var parseRequestConfigTestConfig = testConfiguration.commonConfig; + try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { + var serviceSettings = parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType); + serviceSettings.put("extra_key", "value"); + var config = getRequestConfigMap( + serviceSettings, + parseRequestConfigTestConfig.createTaskSettingsMap(), + parseRequestConfigTestConfig.createSecretSettingsMap() + ); + + var listener = new PlainActionFuture(); + service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings [{extra_key=value}]")); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { + var parseRequestConfigTestConfig = testConfiguration.commonConfig; + try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { + var taskSettings = parseRequestConfigTestConfig.createTaskSettingsMap(); + taskSettings.put("extra_key", "value"); + var config = getRequestConfigMap( + parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType), + taskSettings, + parseRequestConfigTestConfig.createSecretSettingsMap() + ); + + var listener = new PlainActionFuture(); + 
service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings [{extra_key=value}]")); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { + var parseRequestConfigTestConfig = testConfiguration.commonConfig; + try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { + var secretSettingsMap = parseRequestConfigTestConfig.createSecretSettingsMap(); + secretSettingsMap.put("extra_key", "value"); + var config = getRequestConfigMap( + parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType), + parseRequestConfigTestConfig.createTaskSettingsMap(), + secretSettingsMap + ); + + var listener = new PlainActionFuture(); + service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings [{extra_key=value}]")); + } + } + + // parsePersistedConfigWithSecrets + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() throws Exception { + var parseConfigTestConfig = testConfiguration.commonConfig; + + try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { + var persistedConfigMap = getPersistedConfigMap( + parseConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING), + parseConfigTestConfig.createTaskSettingsMap(), + parseConfigTestConfig.createSecretSettingsMap() + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfigMap.config(), + persistedConfigMap.secrets() + ); + 
parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws Exception { + var parseConfigTestConfig = testConfiguration.commonConfig; + + try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { + var persistedConfigMap = getPersistedConfigMap( + parseConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION), + parseConfigTestConfig.createTaskSettingsMap(), + parseConfigTestConfig.createSecretSettingsMap() + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfigMap.config(), + persistedConfigMap.secrets() + ); + parseConfigTestConfig.assertModel(model, TaskType.COMPLETION); + } + } + + public void testParsePersistedConfigWithSecrets_ThrowsUnsupportedModelType() throws Exception { + var parseConfigTestConfig = testConfiguration.commonConfig; + + try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { + var persistedConfigMap = getPersistedConfigMap( + parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType), + parseConfigTestConfig.createTaskSettingsMap(), + parseConfigTestConfig.createSecretSettingsMap() + ); + + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> service.parsePersistedConfigWithSecrets( + "id", + parseConfigTestConfig.unsupportedTaskType, + persistedConfigMap.config(), + persistedConfigMap.secrets() + ) + ); + + assertThat( + exception.getMessage(), + containsString(Strings.format("service does not support task type [%s]", parseConfigTestConfig.unsupportedTaskType)) + ); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + var parseConfigTestConfig = testConfiguration.commonConfig; + + try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { + var persistedConfigMap = getPersistedConfigMap( 
+ parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType), + parseConfigTestConfig.createTaskSettingsMap(), + parseConfigTestConfig.createSecretSettingsMap() + ); + persistedConfigMap.config().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + parseConfigTestConfig.taskType, + persistedConfigMap.config(), + persistedConfigMap.secrets() + ); + + parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + } + } + + public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { + var parseConfigTestConfig = testConfiguration.commonConfig; + try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { + var serviceSettings = parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType); + serviceSettings.put("extra_key", "value"); + var persistedConfigMap = getPersistedConfigMap( + serviceSettings, + parseConfigTestConfig.createTaskSettingsMap(), + parseConfigTestConfig.createSecretSettingsMap() + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + parseConfigTestConfig.taskType, + persistedConfigMap.config(), + persistedConfigMap.secrets() + ); + + parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + } + } + + public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { + var parseConfigTestConfig = testConfiguration.commonConfig; + try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { + var taskSettings = parseConfigTestConfig.createTaskSettingsMap(); + taskSettings.put("extra_key", "value"); + var config = getPersistedConfigMap( + parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType), + taskSettings, + parseConfigTestConfig.createSecretSettingsMap() + ); + + var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, 
config.config(), config.secrets()); + + parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + } + } + + public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { + var parseConfigTestConfig = testConfiguration.commonConfig; + try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { + var secretSettingsMap = parseConfigTestConfig.createSecretSettingsMap(); + secretSettingsMap.put("extra_key", "value"); + var config = getPersistedConfigMap( + parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType), + parseConfigTestConfig.createTaskSettingsMap(), + secretSettingsMap + ); + + var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets()); + + parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + } + } + + public void testInfer_ThrowsErrorWhenModelIsNotValid() throws IOException { + try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { + var listener = new PlainActionFuture(); + + service.infer( + getInvalidModel("id", "service"), + null, + null, + null, + List.of(""), + false, + new HashMap<>(), + InputType.INTERNAL_SEARCH, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + exception.getMessage(), + is("The internal model was invalid, please delete the service [service] with id [id] and add it again.") + ); + } + } + + public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { + try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { + var listener = new PlainActionFuture(); + + var exception = expectThrows( + ValidationException.class, + () -> service.infer( + getInvalidModel("id", "service"), + null, + null, + null, + List.of(""), + false, + 
new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ) + ); + + assertThat( + exception.getMessage(), + is("Validation Failed: 1: Invalid input_type [ingest]. The input_type option is not supported by this service;") + ); + } + } + + public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { + Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled()); + + try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> service.updateModelWithEmbeddingDetails(getInvalidModel("id", "service"), randomNonNegativeInt()) + ); + + assertThat(exception.getMessage(), containsString("Can't update embedding details for model of type:")); + } + } + + public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException { + Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled()); + + try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { + var embeddingSize = randomNonNegativeInt(); + var model = testConfiguration.updateModelConfiguration.createEmbeddingModel(null); + + Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); + + assertEquals(SimilarityMeasure.DOT_PRODUCT, updatedModel.getServiceSettings().similarity()); + assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue()); + } + } + + public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException { + Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled()); + + try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { + var embeddingSize = randomNonNegativeInt(); + var model = testConfiguration.updateModelConfiguration.createEmbeddingModel(SimilarityMeasure.COSINE); + + Model updatedModel = 
service.updateModelWithEmbeddingDetails(model, embeddingSize); + + assertEquals(SimilarityMeasure.COSINE, updatedModel.getServiceSettings().similarity()); + assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue()); + } + } + + // streaming tests + public void testSupportedStreamingTasks() throws Exception { + try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { + assertThat(service.supportedStreamingTasks(), is(testConfiguration.commonConfig.supportedStreamingTasks())); + assertFalse(service.canStream(TaskType.ANY)); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index 6b2731bb313b5..770f85c866ba7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -11,26 +11,36 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.Booleans; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.core.Tuple; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; +import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString; import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; +import static org.elasticsearch.xpack.inference.Utils.modifiableMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertMapStringsToSecureString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; import static 
org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalListOfStringTuples; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveLong; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalTimeValue; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveIntegerLessThanOrEqualToMax; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeNullValues; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapValues; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; @@ -920,7 +930,326 @@ public void testValidateInputType_ValidationErrorsWhenInputTypeIsSpecified() { assertThat(validationException.validationErrors().size(), is(4)); } - private static Map modifiableMap(Map aMap) { - return new HashMap<>(aMap); + public void testExtractRequiredMap() { + var validation = new ValidationException(); + var extractedMap = extractRequiredMap(modifiableMap(Map.of("setting", 
Map.of("key", "value"))), "setting", "scope", validation); + + assertTrue(validation.validationErrors().isEmpty()); + assertThat(extractedMap, is(Map.of("key", "value"))); + } + + public void testExtractRequiredMap_ReturnsNull_WhenTypeIsInvalid() { + var validation = new ValidationException(); + var extractedMap = extractRequiredMap(modifiableMap(Map.of("setting", 123)), "setting", "scope", validation); + + assertNull(extractedMap); + assertThat( + validation.getMessage(), + is("Validation Failed: 1: field [setting] is not of the expected type. The value [123] cannot be converted to a [Map];") + ); + } + + public void testExtractRequiredMap_ReturnsNull_WhenMissingSetting() { + var validation = new ValidationException(); + var extractedMap = extractRequiredMap(modifiableMap(Map.of("not_setting", Map.of("key", "value"))), "setting", "scope", validation); + + assertNull(extractedMap); + assertThat(validation.getMessage(), is("Validation Failed: 1: [scope] does not contain the required setting [setting];")); + } + + public void testExtractRequiredMap_ReturnsNull_WhenMapIsEmpty() { + var validation = new ValidationException(); + var extractedMap = extractRequiredMap(modifiableMap(Map.of("setting", Map.of())), "setting", "scope", validation); + + assertNull(extractedMap); + assertThat( + validation.getMessage(), + is("Validation Failed: 1: [scope] Invalid value empty map. 
[setting] must be a non-empty map;") + ); + } + + public void testExtractOptionalMap() { + var validation = new ValidationException(); + var extractedMap = extractOptionalMap(modifiableMap(Map.of("setting", Map.of("key", "value"))), "setting", "scope", validation); + + assertTrue(validation.validationErrors().isEmpty()); + assertThat(extractedMap, is(Map.of("key", "value"))); + } + + public void testExtractOptionalMap_ReturnsNull_WhenTypeIsInvalid() { + var validation = new ValidationException(); + var extractedMap = extractOptionalMap(modifiableMap(Map.of("setting", 123)), "setting", "scope", validation); + + assertNull(extractedMap); + assertThat( + validation.getMessage(), + is("Validation Failed: 1: field [setting] is not of the expected type. The value [123] cannot be converted to a [Map];") + ); + } + + public void testExtractOptionalMap_ReturnsNull_WhenMissingSetting() { + var validation = new ValidationException(); + var extractedMap = extractOptionalMap(modifiableMap(Map.of("not_setting", Map.of("key", "value"))), "setting", "scope", validation); + + assertNull(extractedMap); + assertTrue(validation.validationErrors().isEmpty()); + } + + public void testExtractOptionalMap_ReturnsEmptyMap_WhenEmpty() { + var validation = new ValidationException(); + var extractedMap = extractOptionalMap(modifiableMap(Map.of("setting", Map.of())), "setting", "scope", validation); + + assertThat(extractedMap, is(Map.of())); + } + + public void testValidateMapValues() { + var validation = new ValidationException(); + validateMapValues( + Map.of("string_key", "abc", "num_key", Integer.valueOf(1)), + List.of(String.class, Integer.class), + "setting", + validation, + false + ); + } + + public void testValidateMapValues_IgnoresNullMap() { + var validation = new ValidationException(); + validateMapValues(null, List.of(String.class, Integer.class), "setting", validation, false); + } + + public void testValidateMapValues_ThrowsException_WhenMapContainsInvalidTypes() { + // Includes 
the invalid key and value in the exception message + { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> validateMapValues( + Map.of("string_key", "abc", "num_key", Integer.valueOf(1)), + List.of(String.class), + "setting", + validation, + false + ) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Map field [setting] has an entry that is not valid, " + + "[num_key => 1]. Value type of [1] is not one of [String].;" + ) + ); + } + + // Does not include the invalid key and value in the exception message + { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> validateMapValues( + Map.of("string_key", "abc", "num_key", Integer.valueOf(1)), + List.of(String.class, List.class), + "setting", + validation, + true + ) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Map field [setting] has an entry that is not valid. 
" + + "Value type is not one of [List, String].;" + ) + ); + } + } + + public void testValidateMapStringValues() { + var validation = new ValidationException(); + assertThat( + validateMapStringValues(Map.of("string_key", "abc", "string_key2", new String("awesome")), "setting", validation, false), + is(Map.of("string_key", "abc", "string_key2", "awesome")) + ); + } + + public void testValidateMapStringValues_ReturnsEmptyMap_WhenMapIsNull() { + var validation = new ValidationException(); + assertThat(validateMapStringValues(null, "setting", validation, false), is(Map.of())); + } + + public void testValidateMapStringValues_ThrowsException_WhenMapContainsInvalidTypes() { + // Includes the invalid key and value in the exception message + { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> validateMapStringValues(Map.of("string_key", "abc", "num_key", Integer.valueOf(1)), "setting", validation, false) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Map field [setting] has an entry that is not valid, " + + "[num_key => 1]. Value type of [1] is not one of [String].;" + ) + ); + } + + // Does not include the invalid key and value in the exception message + { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> validateMapStringValues(Map.of("string_key", "abc", "num_key", Integer.valueOf(1)), "setting", validation, true) + ); + + assertThat( + exception.getMessage(), + is("Validation Failed: 1: Map field [setting] has an entry that is not valid. 
Value type is not one of [String].;") + ); + } + } + + public void testConvertMapStringsToSecureString() { + var validation = new ValidationException(); + assertThat( + convertMapStringsToSecureString(Map.of("key", "value", "key2", "abc"), "setting", validation), + is(Map.of("key", new SerializableSecureString("value"), "key2", new SerializableSecureString("abc"))) + ); + } + + public void testConvertMapStringsToSecureString_ReturnsAnEmptyMap_WhenMapIsNull() { + var validation = new ValidationException(); + assertThat(convertMapStringsToSecureString(null, "setting", validation), is(Map.of())); + } + + public void testConvertMapStringsToSecureString_ThrowsException_WhenMapContainsInvalidTypes() { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> convertMapStringsToSecureString(Map.of("key", "value", "key2", 123), "setting", validation) + ); + + assertThat( + exception.getMessage(), + is("Validation Failed: 1: Map field [setting] has an entry that is not valid. 
Value type is not one of [String].;") + ); + } + + public void testRemoveNullValues() { + var map = new HashMap(); + map.put("key1", null); + map.put("key2", "awesome"); + map.put("key3", null); + + assertThat(removeNullValues(map), is(Map.of("key2", "awesome"))); + } + + public void testRemoveNullValues_ReturnsNull_WhenMapIsNull() { + assertNull(removeNullValues(null)); + } + + public void testExtractOptionalListOfStringTuples() { + var validation = new ValidationException(); + assertThat( + extractOptionalListOfStringTuples( + modifiableMap(Map.of("params", List.of(List.of("key", "value"), List.of("key2", "value2")))), + "params", + "scope", + validation + ), + is(List.of(new Tuple<>("key", "value"), new Tuple<>("key2", "value2"))) + ); + } + + public void testExtractOptionalListOfStringTuples_ReturnsNull_WhenFieldIsNotAList() { + var validation = new ValidationException(); + assertNull(extractOptionalListOfStringTuples(modifiableMap(Map.of("params", Map.of())), "params", "scope", validation)); + + assertThat( + validation.getMessage(), + is("Validation Failed: 1: field [params] is not of the expected type. 
The value [{}] cannot be converted to a [List];") + ); + } + + public void testExtractOptionalListOfStringTuples_Exception_WhenTupleIsNotAList() { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> extractOptionalListOfStringTuples(modifiableMap(Map.of("params", List.of("string"))), "params", "scope", validation) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: [scope] failed to parse tuple list entry [0] for setting " + + "[params], expected a list but the entry is [String];" + ) + ); + } + + public void testExtractOptionalListOfStringTuples_Exception_WhenTupleIsListSize2() { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> extractOptionalListOfStringTuples( + modifiableMap(Map.of("params", List.of(List.of("string")))), + "params", + "scope", + validation + ) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: [scope] failed to parse tuple list entry " + + "[0] for setting [params], the tuple list size must be two, but was [1];" + ) + ); + } + + public void testExtractOptionalListOfStringTuples_Exception_WhenTupleFirstElement_IsNotAString() { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> extractOptionalListOfStringTuples( + modifiableMap(Map.of("params", List.of(List.of(1, "value")))), + "params", + "scope", + validation + ) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: [scope] failed to parse tuple list entry [0] for setting [params], " + + "the first element must be a string but was [Integer];" + ) + ); + } + + public void testExtractOptionalListOfStringTuples_Exception_WhenTupleSecondElement_IsNotAString() { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> extractOptionalListOfStringTuples( + 
modifiableMap(Map.of("params", List.of(List.of("key", 2)))), + "params", + "scope", + validation + ) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: [scope] failed to parse tuple list entry [0] for setting [params], " + + "the second element must be a string but was [Integer];" + ) + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java new file mode 100644 index 0000000000000..c3c4a44bcab07 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.apache.http.HttpHeaders;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;
+import org.hamcrest.MatcherAssert;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Tests for {@link CustomModel#of}, plus static factory helpers that build {@link CustomModel}
+ * fixtures.
+ */
+public class CustomModelTests extends ESTestCase {
+    private static final String taskSettingsKey = "test_taskSettings_key";
+    private static final String taskSettingsValue = "test_taskSettings_value";
+
+    private static final String secretSettingsKey = "test_secret_key";
+    private static final SerializableSecureString secretSettingsValue = new SerializableSecureString("test_secret_value");
+    private static final String url = "http://www.abc.com";
+
+    /**
+     * Overriding a model with an empty task-settings map is expected to yield a model equal to the original.
+     */
+    public void testOverride_DoesNotModifiedFields_TaskSettingsIsEmpty() {
+        var model = createModel(
+            "service",
+            TaskType.TEXT_EMBEDDING,
+            CustomServiceSettingsTests.createRandom(),
+            CustomTaskSettingsTests.createRandom(),
+            CustomSecretSettingsTests.createRandom()
+        );
+
+        var overriddenModel = CustomModel.of(model, Map.of());
+        MatcherAssert.assertThat(overriddenModel, is(model));
+    }
+
+    /**
+     * Overriding the {@link CustomTaskSettings#PARAMETERS} entry is expected to replace the task settings
+     * while the service settings and secret settings are carried over from the original model.
+     */
+    public void testOverride() {
+        var model = createModel(
+            "service",
+            TaskType.TEXT_EMBEDDING,
+            CustomServiceSettingsTests.createRandom(),
+            new CustomTaskSettings(Map.of("key", "value")),
+            CustomSecretSettingsTests.createRandom()
+        );
+
+        // Mutable maps are passed here; NOTE(review): presumably because the parsing code removes entries as it reads them.
+        var overriddenModel = CustomModel.of(
+            model,
+            new HashMap<>(Map.of(CustomTaskSettings.PARAMETERS, new HashMap<>(Map.of("key", "different_value"))))
+        );
+        MatcherAssert.assertThat(
+            overriddenModel,
+            is(
+                createModel(
+                    "service",
+                    TaskType.TEXT_EMBEDDING,
+                    model.getServiceSettings(),
+                    new CustomTaskSettings(Map.of("key", "different_value")),
+                    model.getSecretSettings()
+                )
+            )
+        );
+    }
+
+    /**
+     * Builds a {@link CustomModel} for the {@link CustomService#NAME} service from unparsed settings maps.
+     * The final constructor argument is passed as {@code null}.
+     *
+     * @param inferenceId     the inference endpoint id
+     * @param taskType        the task type of the model
+     * @param serviceSettings unparsed service settings
+     * @param taskSettings    unparsed task settings
+     * @param secrets         unparsed secret settings, may be {@code null}
+     * @return the constructed model
+     */
+    public static CustomModel createModel(
+        String inferenceId,
+        TaskType taskType,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        @Nullable Map<String, Object> secrets
+    ) {
+        return new CustomModel(inferenceId, taskType, CustomService.NAME, serviceSettings, taskSettings, secrets, null);
+    }
+
+    /**
+     * Builds a {@link CustomModel} for the {@link CustomService#NAME} service from already-parsed settings objects.
+     *
+     * @param inferenceId     the inference endpoint id
+     * @param taskType        the task type of the model
+     * @param serviceSettings parsed service settings
+     * @param taskSettings    parsed task settings
+     * @param secretSettings  parsed secret settings, may be {@code null}
+     * @return the constructed model
+     */
+    public static CustomModel createModel(
+        String inferenceId,
+        TaskType taskType,
+        CustomServiceSettings serviceSettings,
+        CustomTaskSettings taskSettings,
+        @Nullable CustomSecretSettings secretSettings
+    ) {
+        return new CustomModel(inferenceId, taskType, CustomService.NAME, serviceSettings, taskSettings, secretSettings);
+    }
+
+    /**
+     * Returns the fixed text-embedding fixture, using a {@link TextEmbeddingResponseParser} and the default test URL.
+     */
+    public static CustomModel getTestModel() {
+        return getTestModel(TaskType.TEXT_EMBEDDING, new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"));
+    }
+
+    /**
+     * Returns the fixed fixture for the given task type and response parser, using the default test URL.
+     */
+    public static CustomModel getTestModel(TaskType taskType, CustomResponseParser responseParser) {
+        return getTestModel(taskType, responseParser, url);
+    }
+
+    /**
+     * Builds the fixed fixture: dot-product similarity, 1536 dimensions, 512 max input tokens, float elements,
+     * an {@code Authorization} header holding a {@code ${test_secret_key}} placeholder, empty query parameters,
+     * a rate limit of 10,000, one task-setting entry and one secret entry.
+     *
+     * @param taskType       the task type of the model
+     * @param responseParser the parser placed in the service settings
+     * @param url            the URL placed in the service settings (shadows the class-level default)
+     * @return the constructed model
+     */
+    public static CustomModel getTestModel(TaskType taskType, CustomResponseParser responseParser, String url) {
+        var inferenceId = "inference_id";
+        Integer dims = 1536;
+        Integer maxInputTokens = 512;
+        Map<String, String> headers = Map.of(HttpHeaders.AUTHORIZATION, "${" + secretSettingsKey + "}");
+        String requestContentString = "\"input\":\"${input}\"";
+
+        CustomServiceSettings serviceSettings = new CustomServiceSettings(
+            new CustomServiceSettings.TextEmbeddingSettings(
+                SimilarityMeasure.DOT_PRODUCT,
+                dims,
+                maxInputTokens,
+                DenseVectorFieldMapper.ElementType.FLOAT
+            ),
+            url,
+            headers,
+            QueryParameters.EMPTY,
+            requestContentString,
+            responseParser,
+            new RateLimitSettings(10_000),
+            new ErrorResponseParser("$.error.message", inferenceId)
+        );
+
+        CustomTaskSettings taskSettings = new CustomTaskSettings(Map.of(taskSettingsKey, taskSettingsValue));
+        CustomSecretSettings secretSettings = new CustomSecretSettings(Map.of(secretSettingsKey, secretSettingsValue));
+
+        return CustomModelTests.createModel(inferenceId, taskType, serviceSettings, taskSettings, secretSettings);
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java
new file mode 100644
index 0000000000000..16c058b3e0115
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java
@@ -0,0 +1,88 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
+import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
+import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;
+import org.junit.After;
+import org.junit.Before;
+
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests for {@link CustomRequestManager}.
+ */
+public class CustomRequestManagerTests extends ESTestCase {
+
+    private ThreadPool threadPool;
+
+    /**
+     * Creates the thread pool handed to the request manager under test.
+     */
+    @Before
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+        threadPool = createThreadPool(inferenceUtilityPool());
+    }
+
+    /**
+     * Runs the parent tear-down and then terminates the thread pool created in {@link #setUp()}.
+     */
+    @After
+    @Override
+    public void tearDown() throws Exception {
+        super.tearDown();
+        terminate(threadPool);
+    }
+
+    /**
+     * Configures the service URL as the placeholder {@code ${url}} and supplies {@code "^"} for {@code url}
+     * in the task settings. The failure to build a URI is expected to reach the listener as an
+     * {@link ElasticsearchStatusException} whose cause carries the URI error.
+     */
+    public void testCreateRequest_ThrowsException_ForInvalidUrl() {
+        var inferenceId = "inference_id";
+
+        var requestContentString = """
+            {
+                "input": ${input}
+            }
+            """;
+
+        var serviceSettings = new CustomServiceSettings(
+            CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS,
+            "${url}",
+            null,
+            null,
+            requestContentString,
+            new RerankResponseParser("$.result.score"),
+            new RateLimitSettings(10_000),
+            new ErrorResponseParser("$.error.message", inferenceId)
+        );
+
+        var model = CustomModelTests.createModel(
+            inferenceId,
+            TaskType.RERANK,
+            serviceSettings,
+            new CustomTaskSettings(Map.of("url", "^")),
+            new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key")))
+        );
+
+        // Typed future so the listener passed to execute() is checked against the result type.
+        var listener = new PlainActionFuture<InferenceServiceResults>();
+        var manager = CustomRequestManager.of(model, threadPool);
+        manager.execute(new EmbeddingsInput(List.of("abc", "123"), null, null), mock(RequestSender.class), () -> false, listener);
+
+        var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TimeValue.timeValueSeconds(30)));
+
+        assertThat(exception.getMessage(), is("Failed to construct the custom service request"));
+        assertThat(exception.getCause().getMessage(), is("Failed to build URI, error: Illegal character in path at index 0: ^"));
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettingsTests.java
new file mode 100644
index 0000000000000..a29992cd7f9fd
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettingsTests.java
@@ -0,0 +1,141 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
+import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.elasticsearch.core.Tuple.tuple;
+import static org.elasticsearch.xpack.inference.Utils.modifiableMap;
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Tests for {@link CustomSecretSettings}: parsing from a map, XContent rendering, and wire serialization
+ * through the {@link AbstractBWCWireSerializationTestCase} hooks.
+ */
+public class CustomSecretSettingsTests extends AbstractBWCWireSerializationTestCase<CustomSecretSettings> {
+    /**
+     * Builds a {@link CustomSecretSettings} from a random map (size arguments 0 and 5) of 5-character keys
+     * to 5-character secret values.
+     */
+    public static CustomSecretSettings createRandom() {
+        Map<String, SerializableSecureString> secretParameters = randomMap(
+            0,
+            5,
+            () -> tuple(randomAlphaOfLength(5), new SerializableSecureString(randomAlphaOfLength(5)))
+        );
+
+        return new CustomSecretSettings(secretParameters);
+    }
+
+    /**
+     * String values under {@link CustomSecretSettings#SECRET_PARAMETERS} are expected to be wrapped as secure strings.
+     */
+    public void testFromMap() {
+        Map<String, Object> secretParameters = new HashMap<>(
+            Map.of(CustomSecretSettings.SECRET_PARAMETERS, new HashMap<>(Map.of("test_key", "test_value")))
+        );
+
+        assertThat(
+            CustomSecretSettings.fromMap(secretParameters),
+            is(new CustomSecretSettings(Map.of("test_key", new SerializableSecureString("test_value"))))
+        );
+    }
+
+    /**
+     * A {@code null} input map is expected to produce {@code null}.
+     */
+    public void testFromMap_PassedNull_ReturnsNull() {
+        assertNull(CustomSecretSettings.fromMap(null));
+    }
+
+    /**
+     * Entries whose value is {@code null} are expected to be dropped from the parsed settings.
+     */
+    public void testFromMap_RemovesNullValues() {
+        // HashMap rather than Map.of because Map.of rejects null values.
+        var mapWithNulls = new HashMap<String, Object>();
+        mapWithNulls.put("value", "abc");
+        mapWithNulls.put("null", null);
+
+        assertThat(
+            CustomSecretSettings.fromMap(modifiableMap(Map.of(CustomSecretSettings.SECRET_PARAMETERS, mapWithNulls))),
+            is(new CustomSecretSettings(Map.of("value", new SerializableSecureString("abc"))))
+        );
+    }
+
+    /**
+     * A non-string secret value is expected to fail validation with a message that does not echo the
+     * offending key or value.
+     */
+    public void testFromMap_Throws_IfValueIsInvalid() {
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> CustomSecretSettings.fromMap(
+                modifiableMap(Map.of(CustomSecretSettings.SECRET_PARAMETERS, modifiableMap(Map.of("key", Map.of("another_key", "value")))))
+            )
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Validation Failed: 1: Map field [secret_parameters] has an entry that is not valid. "
+                    + "Value type is not one of [String].;"
+            )
+        );
+    }
+
+    /**
+     * When the {@link CustomSecretSettings#SECRET_PARAMETERS} field is absent, the parsed settings are
+     * expected to hold an empty map.
+     */
+    public void testFromMap_DefaultsToEmptyMap_WhenSecretParametersField_DoesNotExist() {
+        var map = new HashMap<String, Object>(Map.of("key", new HashMap<>(Map.of("test_key", "test_value"))));
+
+        assertThat(CustomSecretSettings.fromMap(map), is(new CustomSecretSettings(Map.of())));
+    }
+
+    /**
+     * The secret parameters are expected to be rendered under the {@code secret_parameters} field.
+     */
+    public void testXContent() throws IOException {
+        var entity = new CustomSecretSettings(Map.of("test_key", new SerializableSecureString("test_value")));
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        var expected = XContentHelper.stripWhitespace("""
+            {
+                "secret_parameters": {
+                    "test_key": "test_value"
+                }
+            }
+            """);
+
+        assertThat(xContentResult, is(expected));
+    }
+
+    /**
+     * With no secret parameters the rendered object is expected to be empty.
+     */
+    public void testXContent_EmptyParameters() throws IOException {
+        var entity = new CustomSecretSettings(Map.of());
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        var expected = XContentHelper.stripWhitespace("""
+            {
+            }
+            """);
+
+        assertThat(xContentResult, is(expected));
+    }
+
+    /** Supplies the stream constructor used to read instances back from the wire. */
+    @Override
+    protected Writeable.Reader<CustomSecretSettings> instanceReader() {
+        return CustomSecretSettings::new;
+    }
+
+    /** Supplies a random instance for the serialization round-trip. */
+    @Override
+    protected CustomSecretSettings createTestInstance() {
+        return createRandom();
+    }
+
+    /** Returns a random instance that is not equal to {@code instance}. */
+    @Override
+    protected CustomSecretSettings mutateInstance(CustomSecretSettings instance) {
+        return randomValueOtherThan(instance, CustomSecretSettingsTests::createRandom);
+    }
+
+    /** Returns {@code instance} unchanged; no version-specific adjustment is applied. */
+    @Override
+    protected CustomSecretSettings mutateInstanceForVersion(CustomSecretSettings instance, TransportVersion version) {
+        return instance;
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java
new file mode 100644
index 0000000000000..71eec73df5375
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java
@@ -0,0 +1,734 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class CustomServiceSettingsTests extends AbstractBWCWireSerializationTestCase { + public static 
CustomServiceSettings createRandom(String inputUrl) { + var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING, TaskType.COMPLETION); + + SimilarityMeasure similarityMeasure = null; + Integer dims = null; + var isTextEmbeddingModel = taskType.equals(TaskType.TEXT_EMBEDDING); + if (isTextEmbeddingModel) { + similarityMeasure = SimilarityMeasure.DOT_PRODUCT; + dims = 1536; + } + var maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256); + var url = inputUrl != null ? inputUrl : randomAlphaOfLength(15); + Map headers = randomBoolean() ? Map.of() : Map.of("key", "value"); + var queryParameters = randomBoolean() + ? QueryParameters.EMPTY + : new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"))); + var requestContentString = randomAlphaOfLength(10); + + var responseJsonParser = switch (taskType) { + case TEXT_EMBEDDING -> new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"); + case SPARSE_EMBEDDING -> new SparseEmbeddingResponseParser( + "$.result.sparse_embeddings[*].embedding[*].token_id", + "$.result.sparse_embeddings[*].embedding[*].weights" + ); + case RERANK -> new RerankResponseParser( + "$.result.reranked_results[*].index", + "$.result.reranked_results[*].relevance_score", + "$.result.reranked_results[*].document_text" + ); + case COMPLETION -> new CompletionResponseParser("$.result.text"); + default -> new NoopResponseParser(); + }; + + var errorParser = new ErrorResponseParser("$.error.message", randomAlphaOfLength(5)); + + RateLimitSettings rateLimitSettings = new RateLimitSettings(randomLongBetween(1, 1000000)); + + return new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings( + similarityMeasure, + dims, + maxInputTokens, + DenseVectorFieldMapper.ElementType.FLOAT + ), + url, + headers, + queryParameters, + requestContentString, + responseJsonParser, + rateLimitSettings, + errorParser + ); + } + + public static CustomServiceSettings 
createRandom() { + return createRandom(randomAlphaOfLength(5)); + } + + public void testFromMap() { + String similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + Integer dims = 1536; + Integer maxInputTokens = 512; + String url = "http://www.abc.com"; + Map headers = Map.of("key", "value"); + var queryParameters = List.of(List.of("key", "value")); + String requestContentString = "request body"; + + var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"); + + var settings = CustomServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.SIMILARITY, + similarity, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + headers, + QueryParameters.QUERY_PARAMETERS, + queryParameters, + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ), + ConfigurationParseContext.REQUEST, + TaskType.TEXT_EMBEDDING, + "inference_id" + ); + + assertThat( + settings, + is( + new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings( + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens, + DenseVectorFieldMapper.ElementType.FLOAT + ), + url, + headers, + new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"))), + requestContentString, + responseParser, + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", "inference_id") + ) + ) + ); + } + + public void testFromMap_WithOptionalsNotSpecified() { + String url = "http://www.abc.com"; + String requestContentString = 
"request body"; + + var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"); + + var settings = CustomServiceSettings.fromMap( + new HashMap<>( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ), + ConfigurationParseContext.REQUEST, + TaskType.TEXT_EMBEDDING, + "inference_id" + ); + + MatcherAssert.assertThat( + settings, + is( + new CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.DEFAULT_FLOAT, + url, + Map.of(), + null, + requestContentString, + responseParser, + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", "inference_id") + ) + ) + ); + } + + public void testFromMap_RemovesNullValues_FromMaps() { + String similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + Integer dims = 1536; + Integer maxInputTokens = 512; + String url = "http://www.abc.com"; + + var headersWithNulls = new HashMap(); + headersWithNulls.put("value", "abc"); + headersWithNulls.put("null", null); + + String requestContentString = "request body"; + + var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"); + + var settings = CustomServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.SIMILARITY, + similarity, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + headersWithNulls, + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), 
+ CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ), + ConfigurationParseContext.REQUEST, + TaskType.TEXT_EMBEDDING, + "inference_id" + ); + + MatcherAssert.assertThat( + settings, + is( + new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings( + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens, + DenseVectorFieldMapper.ElementType.FLOAT + ), + url, + Map.of("value", "abc"), + null, + requestContentString, + responseParser, + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", "inference_id") + ) + ) + ); + } + + public void testFromMap_ReturnsError_IfHeadersContainsNonStringValues() { + String similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + Integer dims = 1536; + Integer maxInputTokens = 512; + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var mapSettings = new HashMap( + Map.of( + ServiceFields.SIMILARITY, + similarity, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + new HashMap<>(Map.of("key", 1)), + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ); + + var exception = expectThrows( + ValidationException.class, + () -> 
CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Map field [headers] has an entry that is not valid, [key => 1]. " + + "Value type of [1] is not one of [String].;" + ) + ); + } + + public void testFromMap_ReturnsError_IfQueryParamsContainsNonStringValues() { + String similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + Integer dims = 1536; + Integer maxInputTokens = 512; + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var mapSettings = new HashMap<>( + Map.of( + ServiceFields.SIMILARITY, + similarity, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + CustomServiceSettings.URL, + url, + QueryParameters.QUERY_PARAMETERS, + List.of(List.of("key", 1)), + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ); + + var exception = expectThrows( + ValidationException.class, + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: [service_settings] failed to parse tuple list entry [0] " + + "for setting [query_parameters], the second element must be a string but was [Integer];" + ) + ); + } + + public void testFromMap_ReturnsError_IfRequestMapIsMissing() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var mapSettings = new HashMap( + Map.of( + 
CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + new HashMap<>(Map.of("key", "value")), + "invalid_request", + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ); + + var exception = expectThrows( + ValidationException.class, + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: [service_settings] does not contain the required setting [request];" + + "2: [service_settings] does not contain the required setting [content];" + ) + ); + } + + public void testFromMap_ReturnsError_IfResponseMapIsMissing() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var mapSettings = new HashMap( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + new HashMap<>(Map.of("key", "value")), + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + "invalid_response", + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ); + + var exception = expectThrows( + ValidationException.class, + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + ); + + assertThat( + 
exception.getMessage(), + is( + "Validation Failed: 1: [service_settings] does not contain the required setting [response];" + + "2: [service_settings.response] does not contain the required setting [json_parser];" + + "3: [service_settings.response] does not contain the required setting [error_parser];" + + "4: Encountered a null input map while parsing field [path];" + ) + ); + } + + public void testFromMap_ReturnsError_IfRequestMapIsNotEmptyAfterParsing() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var mapSettings = new HashMap( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + new HashMap<>(Map.of("key", "value")), + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString, "key", "value")), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ); + + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + ); + + assertThat( + exception.getMessage(), + is( + "Model configuration contains unknown settings [{key=value}] while parsing field [request]" + + " for settings [custom_service_settings]" + ) + ); + } + + public void testFromMap_ReturnsError_IfJsonParserMapIsNotEmptyAfterParsing() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var mapSettings = new HashMap( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + new HashMap<>(Map.of("key", "value")), + CustomServiceSettings.REQUEST, + new 
HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of( + TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, + "$.result.embeddings[*].embedding", + "key", + "value" + ) + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ); + + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + ); + + assertThat( + exception.getMessage(), + is( + "Model configuration contains unknown settings [{key=value}] while parsing field [json_parser]" + + " for settings [custom_service_settings]" + ) + ); + } + + public void testFromMap_ReturnsError_IfResponseMapIsNotEmptyAfterParsing() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var mapSettings = new HashMap( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + new HashMap<>(Map.of("key", "value")), + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")), + "key", + "value" + ) + ) + ) + ); + + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + ); + + assertThat( + exception.getMessage(), + is( + "Model configuration contains unknown settings 
[{key=value}] while parsing field [response]" + + " for settings [custom_service_settings]" + ) + ); + } + + public void testFromMap_ReturnsError_IfErrorParserMapIsNotEmptyAfterParsing() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var mapSettings = new HashMap( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + new HashMap<>(Map.of("key", "value")), + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message", "key", "value")) + ) + ) + ) + ); + + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + ); + + assertThat( + exception.getMessage(), + is( + "Model configuration contains unknown settings [{key=value}] while parsing field [error_parser]" + + " for settings [custom_service_settings]" + ) + ); + } + + public void testFromMap_ReturnsError_IfTaskTypeIsInvalid() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var mapSettings = new HashMap( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + new HashMap<>(Map.of("key", "value")), + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + 
CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message", "key", "value")) + ) + ) + ) + ); + + var exception = expectThrows( + IllegalArgumentException.class, + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.CHAT_COMPLETION, "inference_id") + ); + + assertThat(exception.getMessage(), is("Invalid task type received [chat_completion] while constructing response parser")); + } + + public void testXContent() throws IOException { + var entity = new CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, + "http://www.abc.com", + Map.of("key", "value"), + null, + "string", + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), + null, + new ErrorResponseParser("$.error.message", "inference_id") + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "url": "http://www.abc.com", + "headers": { + "key": "value" + }, + "request": { + "content": "string" + }, + "response": { + "json_parser": { + "text_embeddings": "$.result.embeddings[*].embedding" + }, + "error_parser": { + "path": "$.error.message" + } + }, + "rate_limit": { + "requests_per_minute": 10000 + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(InferenceNamedWriteablesProvider.getNamedWriteables()); + } + + @Override + protected Writeable.Reader instanceReader() { + return CustomServiceSettings::new; + } + + @Override + protected CustomServiceSettings createTestInstance() { + return createRandom(randomAlphaOfLength(5)); + } + + @Override + protected CustomServiceSettings mutateInstance(CustomServiceSettings instance) { + return 
randomValueOtherThan(instance, CustomServiceSettingsTests::createRandom); + } + + @Override + protected CustomServiceSettings mutateInstanceForVersion(CustomServiceSettings instance, TransportVersion version) { + return instance; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java new file mode 100644 index 0000000000000..6ce181b0487ad --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -0,0 +1,550 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; +import 
org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.services.AbstractServiceTests; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.Utils.TIMEOUT; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_DOCUMENT_TEXT; +import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_INDEX; +import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_SCORE; +import static org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_TOKEN_PATH; +import static 
org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_WEIGHT_PATH; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class CustomServiceTests extends AbstractServiceTests { + + public CustomServiceTests() { + super(createTestConfiguration()); + } + + private static TestConfiguration createTestConfiguration() { + return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION) { + @Override + protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + return CustomServiceTests.createService(threadPool, clientManager); + } + + @Override + protected Map createServiceSettingsMap(TaskType taskType) { + return CustomServiceTests.createServiceSettingsMap(taskType); + } + + @Override + protected Map createTaskSettingsMap() { + return CustomServiceTests.createTaskSettingsMap(); + } + + @Override + protected Map createSecretSettingsMap() { + return CustomServiceTests.createSecretSettingsMap(); + } + + @Override + protected void assertModel(Model model, TaskType taskType) { + CustomServiceTests.assertModel(model, taskType); + } + + @Override + protected EnumSet supportedStreamingTasks() { + return EnumSet.noneOf(TaskType.class); + } + }).enableUpdateModelTests(new UpdateModelConfiguration() { + @Override + protected CustomModel createEmbeddingModel(SimilarityMeasure similarityMeasure) { + return createInternalEmbeddingModel(similarityMeasure); + } + }).build(); + } + + private static void assertModel(Model model, TaskType taskType) { + switch (taskType) { + case TEXT_EMBEDDING -> assertTextEmbeddingModel(model); + case COMPLETION -> assertCompletionModel(model); + default -> fail("unexpected task type [" + taskType + "]"); + } + } + + private static void assertTextEmbeddingModel(Model model) { + var customModel = assertCommonModelFields(model); + + assertThat(customModel.getTaskType(), 
is(TaskType.TEXT_EMBEDDING)); + assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(TextEmbeddingResponseParser.class)); + } + + private static CustomModel assertCommonModelFields(Model model) { + assertThat(model, instanceOf(CustomModel.class)); + + var customModel = (CustomModel) model; + + assertThat(customModel.getServiceSettings().getUrl(), is("http://www.abc.com")); + assertThat(customModel.getTaskSettings().getParameters(), is(Map.of("test_key", "test_value"))); + assertThat( + customModel.getSecretSettings().getSecretParameters(), + is(Map.of("test_key", new SerializableSecureString("test_value"))) + ); + + return customModel; + } + + private static void assertCompletionModel(Model model) { + var customModel = assertCommonModelFields(model); + assertThat(customModel.getTaskType(), is(TaskType.COMPLETION)); + assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(CompletionResponseParser.class)); + } + + private static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + return new CustomService(senderFactory, createWithEmptySettings(threadPool)); + } + + private static Map createServiceSettingsMap(TaskType taskType) { + var settingsMap = new HashMap<>( + Map.of( + CustomServiceSettings.URL, + "http://www.abc.com", + CustomServiceSettings.HEADERS, + Map.of("key", "value"), + QueryParameters.QUERY_PARAMETERS, + List.of(List.of("key", "value")), + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, "request body")), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + createResponseParserMap(taskType), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ); + + if (taskType == TaskType.TEXT_EMBEDDING) { + 
settingsMap.putAll(Map.of(ServiceFields.SIMILARITY, + SimilarityMeasure.DOT_PRODUCT.toString(), + ServiceFields.DIMENSIONS, + 1536, + ServiceFields.MAX_INPUT_TOKENS, + 512)); + } + + return settingsMap; + } + + private static Map createResponseParserMap(TaskType taskType) { + return switch (taskType) { + case TEXT_EMBEDDING -> new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ); + case COMPLETION -> new HashMap<>(Map.of(CompletionResponseParser.COMPLETION_PARSER_RESULT, "$.result.text")); + case SPARSE_EMBEDDING -> new HashMap<>( + Map.of( + SPARSE_EMBEDDING_TOKEN_PATH, + "$.result[*].embeddings[*].token", + SPARSE_EMBEDDING_WEIGHT_PATH, + "$.result[*].embeddings[*].weight" + ) + ); + case RERANK -> new HashMap<>( + Map.of( + RERANK_PARSER_SCORE, + "$.result.scores[*].score", + RERANK_PARSER_INDEX, + "$.result.scores[*].index", + RERANK_PARSER_DOCUMENT_TEXT, + "$.result.scores[*].document_text" + ) + ); + default -> throw new IllegalArgumentException("unexpected task type [" + taskType + "]"); + }; + } + + private static Map createTaskSettingsMap() { + return new HashMap<>(Map.of(CustomTaskSettings.PARAMETERS, new HashMap<>(Map.of("test_key", "test_value")))); + } + + private static Map createSecretSettingsMap() { + return new HashMap<>(Map.of(CustomSecretSettings.SECRET_PARAMETERS, new HashMap<>(Map.of("test_key", "test_value")))); + } + + private static CustomModel createInternalEmbeddingModel(SimilarityMeasure similarityMeasure) { + return createInternalEmbeddingModel( + similarityMeasure, + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), + "http://www.abc.com" + ); + } + + private static CustomModel createInternalEmbeddingModel(TextEmbeddingResponseParser parser, String url) { + return createInternalEmbeddingModel(SimilarityMeasure.DOT_PRODUCT, parser, url); + } + + private static CustomModel createInternalEmbeddingModel( + @Nullable SimilarityMeasure 
similarityMeasure, + TextEmbeddingResponseParser parser, + String url + ) { + var inferenceId = "inference_id"; + + return new CustomModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + CustomService.NAME, + new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings( + similarityMeasure, + 123, + 456, + DenseVectorFieldMapper.ElementType.FLOAT + ), + url, + Map.of("key", "value"), + QueryParameters.EMPTY, + "\"input\":\"${input}\"", + parser, + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId) + ), + new CustomTaskSettings(Map.of("key", "test_value")), + new CustomSecretSettings(Map.of("test_key", new SerializableSecureString("test_value"))) + ); + } + + private static CustomModel createCustomModel(TaskType taskType, CustomResponseParser customResponseParser, String url) { + var inferenceId = "inference_id"; + + return new CustomModel( + "model_id", + taskType, + CustomService.NAME, + new CustomServiceSettings( + getDefaultTextEmbeddingSettings(taskType), + url, + Map.of("key", "value"), + QueryParameters.EMPTY, + "\"input\":\"${input}\"", + customResponseParser, + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId) + ), + new CustomTaskSettings(Map.of("key", "test_value")), + new CustomSecretSettings(Map.of("test_key", new SerializableSecureString("test_value"))) + ); + } + + private static CustomServiceSettings.TextEmbeddingSettings getDefaultTextEmbeddingSettings(TaskType taskType) { + return taskType == TaskType.TEXT_EMBEDDING + ? 
CustomServiceSettings.TextEmbeddingSettings.DEFAULT_FLOAT + : CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS; + } + + public void testInfer_HandlesTextEmbeddingRequest_OpenAI_Format() throws IOException { + try (var service = createService(threadPool, clientManager)) { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createInternalEmbeddingModel(new TextEmbeddingResponseParser("$.data[*].embedding"), getUrl(webServer)); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + null, + null, + List.of("test input"), + false, + new HashMap<>(), + InputType.INTERNAL_SEARCH, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + InferenceServiceResults results = listener.actionGet(TIMEOUT); + assertThat(results, instanceOf(TextEmbeddingFloatResults.class)); + + var embeddingResults = (TextEmbeddingFloatResults) results; + assertThat( + embeddingResults.embeddings(), + is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }))) + ); + } + } + + public void testInfer_HandlesRerankRequest_Cohere_Format() throws IOException { + try (var service = createService(threadPool, clientManager)) { + String responseJson = """ + { + "index": "44873262-1315-4c06-8433-fdc90c9790d0", + "results": [ + { + "document": { + "text": "Washington, D.C.." + }, + "index": 2, + "relevance_score": 0.98005307 + }, + { + "document": { + "text": "Capital punishment has existed in the United States since beforethe United States was a country." 
+ }, + "index": 3, + "relevance_score": 0.27904198 + }, + { + "document": { + "text": "Carson City is the capital city of the American state of Nevada." + }, + "index": 0, + "relevance_score": 0.10194652 + } + ], + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "search_units": 1 + } + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createCustomModel( + TaskType.RERANK, + new RerankResponseParser("$.results[*].relevance_score", "$.results[*].index", "$.results[*].document.text"), + getUrl(webServer) + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + "query", + null, + null, + List.of("test input"), + false, + new HashMap<>(), + InputType.INTERNAL_SEARCH, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + InferenceServiceResults results = listener.actionGet(TIMEOUT); + assertThat(results, instanceOf(RankedDocsResults.class)); + + var rerankResults = (RankedDocsResults) results; + assertThat( + rerankResults.getRankedDocs(), + is( + List.of( + new RankedDocsResults.RankedDoc(2, 0.98005307f, "Washington, D.C.."), + new RankedDocsResults.RankedDoc( + 3, + 0.27904198f, + "Capital punishment has existed in the United States since beforethe United States was a country." + ), + new RankedDocsResults.RankedDoc(0, 0.10194652f, "Carson City is the capital city of the American state of Nevada.") + ) + ) + ); + } + } + + public void testInfer_HandlesCompletionRequest_OpenAI_Format() throws IOException { + try (var service = createService(threadPool, clientManager)) { + String responseJson = """ + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-3.5-turbo-0613", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello there, how may I assist you today?" 
+ }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createCustomModel( + TaskType.COMPLETION, + new CompletionResponseParser("$.choices[*].message.content"), + getUrl(webServer) + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + null, + null, + List.of("test input"), + false, + new HashMap<>(), + InputType.INTERNAL_SEARCH, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + InferenceServiceResults results = listener.actionGet(TIMEOUT); + assertThat(results, instanceOf(ChatCompletionResults.class)); + + var completionResults = (ChatCompletionResults) results; + assertThat( + completionResults.getResults(), + is(List.of(new ChatCompletionResults.Result("Hello there, how may I assist you today?"))) + ); + } + } + + public void testInfer_HandlesSparseEmbeddingRequest_Alibaba_Format() throws IOException { + try (var service = createService(threadPool, clientManager)) { + String responseJson = """ + { + "request_id": "75C50B5B-E79E-4930-****-F48DBB392231", + "latency": 22, + "usage": { + "token_count": 11 + }, + "result": { + "sparse_embeddings": [ + { + "index": 0, + "embedding": [ + { + "tokenId": 6, + "weight": 0.101 + }, + { + "tokenId": 163040, + "weight": 0.28417 + } + ] + } + ] + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createCustomModel( + TaskType.SPARSE_EMBEDDING, + new SparseEmbeddingResponseParser( + "$.result.sparse_embeddings[*].embedding[*].tokenId", + "$.result.sparse_embeddings[*].embedding[*].weight" + ), + getUrl(webServer) + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + null, + null, + List.of("test input"), + false, + new HashMap<>(), + 
InputType.INTERNAL_SEARCH, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + InferenceServiceResults results = listener.actionGet(TIMEOUT); + assertThat(results, instanceOf(SparseEmbeddingResults.class)); + + var sparseEmbeddingResults = (SparseEmbeddingResults) results; + assertThat( + sparseEmbeddingResults.embeddings(), + is( + List.of( + new SparseEmbeddingResults.Embedding( + List.of(new WeightedToken("6", 0.101f), new WeightedToken("163040", 0.28417f)), + false + ) + ) + ) + ); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettingsTests.java new file mode 100644 index 0000000000000..01d09af0b7a27 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettingsTests.java @@ -0,0 +1,155 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.core.Tuple.tuple; +import static org.elasticsearch.xpack.inference.Utils.modifiableMap; +import static org.hamcrest.Matchers.is; + +public class CustomTaskSettingsTests extends AbstractBWCWireSerializationTestCase { + public static CustomTaskSettings createRandom() { + Map parameters = randomBoolean() + ? 
randomMap(0, 5, () -> tuple(randomAlphaOfLength(5), (Object) randomAlphaOfLength(5))) + : Map.of(); + return new CustomTaskSettings(parameters); + } + + public void testFromMap() { + var taskSettingsMap = new HashMap( + Map.of(CustomTaskSettings.PARAMETERS, new HashMap<>(Map.of("test_key", "test_value"))) + ); + + assertThat(CustomTaskSettings.fromMap(taskSettingsMap), is(new CustomTaskSettings(Map.of("test_key", "test_value")))); + } + + public void testFromMap_Null_EmptyMap_Returns_EmptySettings() { + assertThat(CustomTaskSettings.fromMap(Map.of()), is(CustomTaskSettings.EMPTY_SETTINGS)); + assertThat(CustomTaskSettings.fromMap(null), is(CustomTaskSettings.EMPTY_SETTINGS)); + } + + public void testFromMap_RemovesNullValues() { + var mapWithNulls = new HashMap(); + mapWithNulls.put("value", "abc"); + mapWithNulls.put("null", null); + + assertThat( + CustomTaskSettings.fromMap(modifiableMap(Map.of(CustomTaskSettings.PARAMETERS, mapWithNulls))), + is(new CustomTaskSettings(Map.of("value", "abc"))) + ); + } + + public void testFromMap_Throws_IfValueIsInvalid() { + var exception = expectThrows( + ValidationException.class, + () -> CustomTaskSettings.fromMap( + modifiableMap(Map.of(CustomTaskSettings.PARAMETERS, modifiableMap(Map.of("key", Map.of("another_key", "value"))))) + ) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Map field [parameters] has an entry that is not valid, [key => {another_key=value}]. 
" + + "Value type of [{another_key=value}] is not one of [Boolean, Double, Float, Integer, String].;" + ) + ); + } + + public void testFromMap_DefaultsToEmptyMap_WhenParametersField_DoesNotExist() { + var taskSettingsMap = new HashMap(Map.of("key", new HashMap<>(Map.of("test_key", "test_value")))); + + assertThat(CustomTaskSettings.fromMap(taskSettingsMap), is(new CustomTaskSettings(Map.of()))); + } + + public void testOf_PrefersSettingsFromRequest() { + assertThat( + CustomTaskSettings.of( + new CustomTaskSettings(Map.of("a", "a_value", "b", "b_value")), + new CustomTaskSettings(Map.of("b", "b_value_overwritten")) + ), + is(new CustomTaskSettings(Map.of("a", "a_value", "b", "b_value_overwritten"))) + ); + } + + public void testXContent() throws IOException { + var entity = new CustomTaskSettings(Map.of("test_key", "test_value")); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "parameters": { + "test_key": "test_value" + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testXContent_EmptyParameters() throws IOException { + var entity = new CustomTaskSettings(Map.of()); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + } + """); + + assertThat(xContentResult, is(expected)); + } + + @Override + protected Writeable.Reader instanceReader() { + return CustomTaskSettings::new; + } + + @Override + protected CustomTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected CustomTaskSettings mutateInstance(CustomTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, CustomTaskSettingsTests::createRandom); + } + + public static 
Map getTaskSettingsMap(@Nullable Map parameters) { + var map = new HashMap(); + if (parameters != null) { + map.put(CustomTaskSettings.PARAMETERS, parameters); + } + + return map; + } + + @Override + protected CustomTaskSettings mutateInstanceForVersion(CustomTaskSettings instance, TransportVersion version) { + return instance; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/QueryParametersTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/QueryParametersTests.java new file mode 100644 index 0000000000000..d6fac6709cc32 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/QueryParametersTests.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.Utils.modifiableMap; +import static org.hamcrest.Matchers.is; + +public class QueryParametersTests extends AbstractBWCWireSerializationTestCase { + public static QueryParameters createRandom() { + var parameters = randomList(5, () -> new QueryParameters.Parameter(randomAlphaOfLength(5), randomAlphaOfLength(5))); + return new QueryParameters(parameters); + } + + public void testFromMap() { + Map params = new HashMap<>(Map.of(QueryParameters.QUERY_PARAMETERS, List.of(List.of("test_key", "test_value")))); + + assertThat( + QueryParameters.fromMap(params, new ValidationException()), + is(new QueryParameters(List.of(new QueryParameters.Parameter("test_key", "test_value")))) + ); + } + + public void testFromMap_ReturnsEmpty_IfFieldDoesNotExist() { + assertThat(QueryParameters.fromMap(modifiableMap(Map.of()), new ValidationException()), is(QueryParameters.EMPTY)); + } + + public void testFromMap_Throws_IfFieldIsInvalid() { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> QueryParameters.fromMap(modifiableMap(Map.of(QueryParameters.QUERY_PARAMETERS, "string")), validation) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: field [query_parameters] is not of the expected type. 
" + + "The value [string] cannot be converted to a [List];" + ) + ); + } + + public void testXContent() throws IOException { + var entity = new QueryParameters(List.of(new QueryParameters.Parameter("test_key", "test_value"))); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + { + builder.startObject(); + entity.toXContent(builder, null); + builder.endObject(); + } + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "query_parameters": [ + ["test_key", "test_value"] + ] + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testXContent_EmptyParameters() throws IOException { + var entity = QueryParameters.EMPTY; + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is("")); + } + + @Override + protected Writeable.Reader instanceReader() { + return QueryParameters::new; + } + + @Override + protected QueryParameters createTestInstance() { + return createRandom(); + } + + @Override + protected QueryParameters mutateInstance(QueryParameters instance) { + return randomValueOtherThan(instance, QueryParametersTests::createRandom); + } + + @Override + protected QueryParameters mutateInstanceForVersion(QueryParameters instance, TransportVersion version) { + return instance; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java new file mode 100644 index 0000000000000..06bfc0b1f6956 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java @@ -0,0 +1,310 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.Streams; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.custom.CustomModelTests; +import org.elasticsearch.xpack.inference.services.custom.CustomSecretSettings; +import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings; +import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings; +import org.elasticsearch.xpack.inference.services.custom.QueryParameters; +import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString; + +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class CustomRequestTests extends ESTestCase { + + public void testCreateRequest() throws IOException { + var inferenceId = 
"inference_id"; + var dims = 1536; + var maxInputTokens = 512; + Map headers = Map.of(HttpHeaders.AUTHORIZATION, Strings.format("${api_key}")); + var requestContentString = """ + { + "input": ${input} + } + """; + + var serviceSettings = new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings( + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens, + DenseVectorFieldMapper.ElementType.FLOAT + ), + "${url}", + headers, + new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))), + requestContentString, + new TextEmbeddingResponseParser("$.result.embeddings"), + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId) + ); + + var model = CustomModelTests.createModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + serviceSettings, + new CustomTaskSettings(Map.of("url", "https://www.elastic.com")), + new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) + ); + + var request = new CustomRequest(null, List.of("abc", "123"), model); + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + assertThat(httpPost.getURI().toString(), is("https://www.elastic.com?key=value&key=value2")); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("my-secret-key")); + + var expectedBody = XContentHelper.stripWhitespace(""" + { + "input": ["abc", "123"] + } + """); + + assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody)); + } + + public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() { + var inferenceId = "inferenceId"; + var requestContentString = """ + { + "input": ${input} + } + """; + + var serviceSettings = new 
CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, + "http://www.elastic.co", + null, + // escaped characters retrieved from here: https://docs.microfocus.com/OMi/10.62/Content/OMi/ExtGuide/ExtApps/URL_encoding.htm + new QueryParameters( + List.of( + new QueryParameters.Parameter("key", " <>#%+{}|\\^~[]`;/?:@=&$"), + // unicode is a 😀 + // Note: In the current version of the apache library (4.x) being used to do the encoding, spaces are converted to + + // There's a bug fix here explaining that: https://issues.apache.org/jira/browse/HTTPCORE-628 + new QueryParameters.Parameter("key", "Σ \uD83D\uDE00") + ) + ), + requestContentString, + new TextEmbeddingResponseParser("$.result.embeddings"), + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId) + ); + + var model = CustomModelTests.createModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + serviceSettings, + new CustomTaskSettings(Map.of("url", "https://www.elastic.com")), + new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) + ); + + var request = new CustomRequest(null, List.of("abc", "123"), model); + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + assertThat( + httpPost.getURI().toString(), + // To visually verify that this is correct, input the query parameters into here: https://www.urldecoder.org/ + is("http://www.elastic.co?key=+%3C%3E%23%25%2B%7B%7D%7C%5C%5E%7E%5B%5D%60%3B%2F%3F%3A%40%3D%26%24&key=%CE%A3+%F0%9F%98%80") + ); + } + + public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws IOException { + var inferenceId = "inference_id"; + var dims = 1536; + var maxInputTokens = 512; + Map headers = Map.of(HttpHeaders.AUTHORIZATION, Strings.format("${api_key}")); + var requestContentString = """ + { + "input": ${input}, + 
"secret": ${api_key} + } + """; + + var serviceSettings = new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings( + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens, + DenseVectorFieldMapper.ElementType.FLOAT + ), + "${url}", + headers, + new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))), + requestContentString, + new TextEmbeddingResponseParser("$.result.embeddings"), + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId) + ); + + var model = CustomModelTests.createModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + serviceSettings, + new CustomTaskSettings(Map.of("url", "https://www.elastic.com")), + new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) + ); + + var request = new CustomRequest(null, List.of("abc", "123"), model); + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + assertThat(httpPost.getURI().toString(), is("https://www.elastic.com?key=value&key=value2")); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("my-secret-key")); + + // secret is encoded in json format (with quotes) + var expectedBody = XContentHelper.stripWhitespace(""" + { + "input": ["abc", "123"], + "secret": "my-secret-key" + } + """); + + assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody)); + } + + public void testCreateRequest_HandlesQuery() throws IOException { + var inferenceId = "inference_id"; + var requestContentString = """ + { + "input": ${input}, + "query": ${query} + } + """; + + var serviceSettings = new CustomServiceSettings( + 
CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, + "http://www.elastic.co", + null, + null, + requestContentString, + new RerankResponseParser("$.result.score"), + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId) + ); + + var model = CustomModelTests.createModel( + inferenceId, + TaskType.RERANK, + serviceSettings, + new CustomTaskSettings(Map.of()), + new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) + ); + + var request = new CustomRequest("query string", List.of("abc", "123"), model); + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var expectedBody = XContentHelper.stripWhitespace(""" + { + "input": ["abc", "123"], + "query": "query string" + } + """); + + assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody)); + } + + public void testCreateRequest_IgnoresNonStringFields_ForStringParams() throws IOException { + var inferenceId = "inference_id"; + var requestContentString = """ + { + "input": ${input} + } + """; + + var serviceSettings = new CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, + "http://www.elastic.co", + Map.of(HttpHeaders.ACCEPT, Strings.format("${task.key}")), + null, + requestContentString, + new RerankResponseParser("$.result.score"), + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId) + ); + + var model = CustomModelTests.createModel( + inferenceId, + TaskType.RERANK, + serviceSettings, + new CustomTaskSettings(Map.of("task.key", 100)), + new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) + ); + + var request = new CustomRequest(null, List.of("abc", "123"), model); + var exception = expectThrows(IllegalStateException.class, 
request::createHttpRequest); + assertThat(exception.getMessage(), is("Found placeholder [${task.key}] in field [header.Accept] after replacement call")); + } + + public void testCreateRequest_ThrowsException_ForInvalidUrl() { + var inferenceId = "inference_id"; + var requestContentString = """ + { + "input": ${input} + } + """; + + var serviceSettings = new CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, + "${url}", + Map.of(HttpHeaders.ACCEPT, Strings.format("${task.key}")), + null, + requestContentString, + new RerankResponseParser("$.result.score"), + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId) + ); + + var model = CustomModelTests.createModel( + inferenceId, + TaskType.RERANK, + serviceSettings, + new CustomTaskSettings(Map.of("url", "^")), + new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) + ); + + var exception = expectThrows(IllegalStateException.class, () -> new CustomRequest(null, List.of("abc", "123"), model)); + assertThat(exception.getMessage(), is("Failed to build URI, error: Illegal character in path at index 0: ^")); + } + + private static String convertToString(InputStream inputStream) throws IOException { + return XContentHelper.stripWhitespace(Streams.copyToString(new InputStreamReader(inputStream, StandardCharsets.UTF_8))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParserTests.java index 46cb23a4ceaa5..1e8dbb41e4d9d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParserTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParserTests.java @@ -38,7 +38,12 @@ public static CompletionResponseParser createRandom() { public void testFromMap() { var validation = new ValidationException(); - var parser = CompletionResponseParser.fromMap(new HashMap<>(Map.of(COMPLETION_PARSER_RESULT, "$.result[*].text")), validation); + + var parser = CompletionResponseParser.fromMap( + new HashMap<>(Map.of(COMPLETION_PARSER_RESULT, "$.result[*].text")), + "scope", + validation + ); assertThat(parser, is(new CompletionResponseParser("$.result[*].text"))); } @@ -47,12 +52,12 @@ public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() { var validation = new ValidationException(); var exception = expectThrows( ValidationException.class, - () -> CompletionResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].text")), validation) + () -> CompletionResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].text")), "scope", validation) ); assertThat( exception.getMessage(), - is("Validation Failed: 1: [json_parser] does not contain the required setting [completion_result];") + is("Validation Failed: 1: [scope.json_parser] does not contain the required setting [completion_result];") ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java new file mode 100644 index 0000000000000..e7f6a47e7c9c7 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java @@ -0,0 +1,211 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.response; + +import org.apache.http.HttpResponse; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.services.custom.CustomModelTests; +import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class CustomResponseEntityTests extends ESTestCase { + + public void testFromTextEmbeddingResponse() throws IOException { + String responseJson = """ + { + "request_id": "B4AB89C8-B135-xxxx-A6F8-2BAB801A2CE4", + "latency": 38, + "usage": { + "token_count": 3072 + }, + "result": { + "embeddings": [ + { + "index": 0, + "embedding": [ + -0.02868066355586052, + 0.022033605724573135 + ] + } + ] + } + } + """; + + var request = new CustomRequest( + null, + List.of("abc"), + CustomModelTests.getTestModel(TaskType.TEXT_EMBEDDING, new TextEmbeddingResponseParser("$.result.embeddings[*].embedding")) + ); + InferenceServiceResults results = CustomResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), 
responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(results, instanceOf(TextEmbeddingFloatResults.class)); + assertThat( + ((TextEmbeddingFloatResults) results).embeddings(), + is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.02868066355586052f, 0.022033605724573135f }))) + ); + } + + public void testFromSparseEmbeddingResponse() throws IOException { + String responseJson = """ + { + "request_id": "75C50B5B-E79E-4930-****-F48DBB392231", + "latency": 22, + "usage": { + "token_count": 11 + }, + "result": { + "sparse_embeddings": [ + { + "index": 0, + "embedding": [ + { + "tokenId": 6, + "weight": 0.10137939453125 + }, + { + "tokenId": 163040, + "weight": 0.2841796875 + } + ] + } + ] + } + } + """; + + var request = new CustomRequest( + null, + List.of("abc"), + CustomModelTests.getTestModel( + TaskType.SPARSE_EMBEDDING, + new SparseEmbeddingResponseParser( + "$.result.sparse_embeddings[*].embedding[*].tokenId", + "$.result.sparse_embeddings[*].embedding[*].weight" + ) + ) + ); + + InferenceServiceResults results = CustomResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + assertThat(results, instanceOf(SparseEmbeddingResults.class)); + + SparseEmbeddingResults sparseEmbeddingResults = (SparseEmbeddingResults) results; + + List embeddingList = new ArrayList<>(); + List weightedTokens = new ArrayList<>(); + weightedTokens.add(new WeightedToken("6", 0.10137939453125f)); + weightedTokens.add(new WeightedToken("163040", 0.2841796875f)); + embeddingList.add(new SparseEmbeddingResults.Embedding(weightedTokens, false)); + + for (int i = 0; i < embeddingList.size(); i++) { + assertThat(sparseEmbeddingResults.embeddings().get(i), is(embeddingList.get(i))); + } + } + + public void testFromRerankResponse() throws IOException { + String responseJson = """ + { + "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f", + "latency": 564.903929, + "usage": { + 
"doc_count": 2 + }, + "result": { + "scores":[ + { + "index":1, + "score": 1.37 + }, + { + "index":0, + "score": -0.3 + } + ] + } + } + """; + + var request = new CustomRequest( + null, + List.of("abc"), + CustomModelTests.getTestModel( + TaskType.RERANK, + new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", null) + ) + ); + + InferenceServiceResults results = CustomResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(results, instanceOf(RankedDocsResults.class)); + var expected = new ArrayList(); + expected.add(new RankedDocsResults.RankedDoc(1, 1.37F, null)); + expected.add(new RankedDocsResults.RankedDoc(0, -0.3F, null)); + + for (int i = 0; i < ((RankedDocsResults) results).getRankedDocs().size(); i++) { + assertThat(((RankedDocsResults) results).getRankedDocs().get(i).index(), is(expected.get(i).index())); + } + } + + public void testFromCompletionResponse() throws IOException { + String responseJson = """ + { + "request_id": "450fcb80-f796-****-8d69-e1e86d29aa9f", + "latency": 564.903929, + "result": { + "text":"completion results" + }, + "usage": { + "output_tokens": 6320, + "input_tokens": 35, + "total_tokens": 6355 + } + } + """; + + var request = new CustomRequest( + null, + List.of("abc"), + CustomModelTests.getTestModel(TaskType.COMPLETION, new CompletionResponseParser("$.result.text")) + ); + + InferenceServiceResults results = CustomResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(results, instanceOf(ChatCompletionResults.class)); + ChatCompletionResults chatCompletionResults = (ChatCompletionResults) results; + assertThat(chatCompletionResults.getResults().size(), is(1)); + assertThat(chatCompletionResults.getResults().get(0).content(), is("completion results")); + } +} diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java index 56987407e02ac..e52d7d9d0ff69 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java @@ -24,34 +24,38 @@ import static org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser.MESSAGE_PATH; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.sameInstance; import static org.mockito.Mockito.mock; public class ErrorResponseParserTests extends ESTestCase { public static ErrorResponseParser createRandom() { - return new ErrorResponseParser("$." + randomAlphaOfLength(5)); + return new ErrorResponseParser("$." 
+ randomAlphaOfLength(5), randomAlphaOfLength(5)); } public void testFromMap() { var validation = new ValidationException(); - var parser = ErrorResponseParser.fromMap(new HashMap<>(Map.of(MESSAGE_PATH, "$.error.message")), validation); + var parser = ErrorResponseParser.fromMap( + new HashMap<>(Map.of(MESSAGE_PATH, "$.error.message")), + "scope", + "inference_id", + validation + ); - assertThat(parser, is(new ErrorResponseParser("$.error.message"))); + assertThat(parser, is(new ErrorResponseParser("$.error.message", "inference_id"))); } public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() { var validation = new ValidationException(); var exception = expectThrows( ValidationException.class, - () -> ErrorResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.error.message")), validation) + () -> ErrorResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.error.message")), "scope", "inference_id", validation) ); - assertThat(exception.getMessage(), is("Validation Failed: 1: [error_parser] does not contain the required setting [path];")); + assertThat(exception.getMessage(), is("Validation Failed: 1: [scope.error_parser] does not contain the required setting [path];")); } public void testToXContent() throws IOException { - var entity = new ErrorResponseParser("$.error.message"); + var entity = new ErrorResponseParser("$.error.message", "inference_id"); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); { @@ -80,7 +84,7 @@ public void testErrorResponse_ExtractsError() throws IOException { } }"""); - var parser = new ErrorResponseParser("$.error.message"); + var parser = new ErrorResponseParser("$.error.message", "inference_id"); var error = parser.apply(result); assertThat(error, is(new ErrorResponse("test_error_message"))); } @@ -97,7 +101,7 @@ public void testFromResponse_WithOtherFieldsPresent() throws IOException { } """; - var parser = new ErrorResponseParser("$.error.message"); + var parser = new 
ErrorResponseParser("$.error.message", "inference_id"); var error = parser.apply(getMockResult(responseJson)); assertThat(error, is(new ErrorResponse("You didn't provide an API key"))); @@ -112,30 +116,29 @@ public void testFromResponse_noMessage() throws IOException { } """; - var parser = new ErrorResponseParser("$.error.message"); + var parser = new ErrorResponseParser("$.error.message", "inference_id"); var error = parser.apply(getMockResult(responseJson)); - assertThat(error, sameInstance(ErrorResponse.UNDEFINED_ERROR)); - assertThat(error.getErrorMessage(), is("")); - assertFalse(error.errorStructureFound()); + assertThat(error.getErrorMessage(), is("Unable to parse the error, response body: [{\"error\":{\"type\":\"not_found_error\"}}]")); + assertTrue(error.errorStructureFound()); } public void testErrorResponse_ReturnsUndefinedObjectIfNoError() throws IOException { var mockResult = getMockResult(""" {"noerror":true}"""); - var parser = new ErrorResponseParser("$.error.message"); + var parser = new ErrorResponseParser("$.error.message", "inference_id"); var error = parser.apply(mockResult); - assertThat(error, sameInstance(ErrorResponse.UNDEFINED_ERROR)); + assertThat(error.getErrorMessage(), is("Unable to parse the error, response body: [{\"noerror\":true}]")); } public void testErrorResponse_ReturnsUndefinedObjectIfNotJson() { var result = new HttpResult(mock(HttpResponse.class), Strings.toUTF8Bytes("not a json string")); - var parser = new ErrorResponseParser("$.error.message"); + var parser = new ErrorResponseParser("$.error.message", "inference_id"); var error = parser.apply(result); - assertThat(error, sameInstance(ErrorResponse.UNDEFINED_ERROR)); + assertThat(error.getErrorMessage(), is("Unable to parse the error, response body: [not a json string]")); } private static HttpResult getMockResult(String jsonString) throws IOException { diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParserTests.java index 523d15ec2a805..0c88d1f93bc73 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParserTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParserTests.java @@ -53,6 +53,7 @@ public void testFromMap() { "$.result.scores[*].document_text" ) ), + "scope", validation ); @@ -64,7 +65,11 @@ public void testFromMap() { public void testFromMap_WithoutOptionalFields() { var validation = new ValidationException(); - var parser = RerankResponseParser.fromMap(new HashMap<>(Map.of(RERANK_PARSER_SCORE, "$.result.scores[*].score")), validation); + var parser = RerankResponseParser.fromMap( + new HashMap<>(Map.of(RERANK_PARSER_SCORE, "$.result.scores[*].score")), + "scope", + validation + ); assertThat(parser, is(new RerankResponseParser("$.result.scores[*].score", null, null))); } @@ -73,12 +78,12 @@ public void testFromMap_ThrowsException_WhenRequiredFieldsAreNotPresent() { var validation = new ValidationException(); var exception = expectThrows( ValidationException.class, - () -> RerankResponseParser.fromMap(new HashMap<>(Map.of("not_path", "$.result[*].embeddings")), validation) + () -> RerankResponseParser.fromMap(new HashMap<>(Map.of("not_path", "$.result[*].embeddings")), "scope", validation) ); assertThat( exception.getMessage(), - is("Validation Failed: 1: [json_parser] does not contain the required setting [relevance_score];") + is("Validation Failed: 1: [scope.json_parser] does not contain the required setting [relevance_score];") ); } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParserTests.java index c4b69ae8c8b19..7e54f95ef0fc1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParserTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParserTests.java @@ -49,6 +49,7 @@ public void testFromMap() { "$.result[*].embeddings[*].weight" ) ), + "scope", validation ); @@ -59,14 +60,14 @@ public void testFromMap_ThrowsException_WhenRequiredFieldsAreNotPresent() { var validation = new ValidationException(); var exception = expectThrows( ValidationException.class, - () -> SparseEmbeddingResponseParser.fromMap(new HashMap<>(Map.of("not_path", "$.result[*].embeddings")), validation) + () -> SparseEmbeddingResponseParser.fromMap(new HashMap<>(Map.of("not_path", "$.result[*].embeddings")), "scope", validation) ); assertThat( exception.getMessage(), is( - "Validation Failed: 1: [json_parser] does not contain the required setting [token_path];" - + "2: [json_parser] does not contain the required setting [weight_path];" + "Validation Failed: 1: [scope.json_parser] does not contain the required setting [token_path];" + + "2: [scope.json_parser] does not contain the required setting [weight_path];" ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java index b240e07a66336..82ddfa618d3b7 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java @@ -40,6 +40,7 @@ public void testFromMap() { var validation = new ValidationException(); var parser = TextEmbeddingResponseParser.fromMap( new HashMap<>(Map.of(TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result[*].embeddings")), + "scope", validation ); @@ -50,12 +51,12 @@ public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() { var validation = new ValidationException(); var exception = expectThrows( ValidationException.class, - () -> TextEmbeddingResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].embeddings")), validation) + () -> TextEmbeddingResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].embeddings")), "scope", validation) ); assertThat( exception.getMessage(), - is("Validation Failed: 1: [json_parser] does not contain " + "the required setting [text_embeddings];") + is("Validation Failed: 1: [scope.json_parser] does not contain the required setting [text_embeddings];") ); } From eba5fce3cd7b43a43f8f7ecbb547b8c81beb45f2 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 8 May 2025 16:06:04 -0400 Subject: [PATCH 02/12] Custom service fixes --- .../org/elasticsearch/TransportVersions.java | 3 +- ...stStreamingCompletionServiceExtension.java | 50 +++++++------------ .../rest/RestPutInferenceModelAction.java | 10 +++- .../action/PutInferenceModelRequestTests.java | 46 ++++++++--------- 4 files changed, 50 insertions(+), 59 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 82e70c2fd69f5..6f59e193006db 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ 
b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -174,6 +174,7 @@ static TransportVersion def(int id) { public static final TransportVersion RESCORE_VECTOR_ALLOW_ZERO_BACKPORT_8_19 = def(8_841_0_27); public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28); public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_29); + public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL_8_19 = def(8_841_0_30); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11); @@ -252,7 +253,7 @@ static TransportVersion def(int id) { public static final TransportVersion FIELD_CAPS_ADD_CLUSTER_ALIAS = def(9_073_0_00); public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT = def(9_074_0_00); public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00); - + public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL = def(9_076_0_00); /* * STOP! READ THIS FIRST! 
No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index e34018c5b8df1..b2f8ba5475eb8 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -34,6 +34,7 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.DequeUtils; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; @@ -256,37 +257,24 @@ public void cancel() {} "object": "chat.completion.chunk" } */ - private InferenceServiceResults.Result unifiedCompletionChunk(String delta) { - return new InferenceServiceResults.Result() { - @Override - public String getWriteableName() { - return "test_unifiedCompletionChunk"; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(delta); - } - - @Override - public Iterator toXContentChunked(ToXContent.Params params) { - return ChunkedToXContentHelper.chunk( - (b, p) -> b.startObject() - .field("id", "id") - .startArray("choices") - .startObject() - .startObject("delta") - .field("content", delta) - .endObject() - .field("index", 0) - .endObject() - .endArray() - .field("model", 
"gpt-4o-2024-08-06") - .field("object", "chat.completion.chunk") - .endObject() - ); - } - }; + private StreamingUnifiedChatCompletionResults.Results unifiedCompletionChunk(String delta) { + return new StreamingUnifiedChatCompletionResults.Results( + DequeUtils.of( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( + "id", + List.of( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(delta, null, null, null), + null, + 0 + ) + ), + "gpt-4o-2024-08-06", + "chat.completion.chunk", + null + ) + ) + ); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java index 655e11996d522..838e6512d805f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java @@ -20,6 +20,7 @@ import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; +import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout; import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID; import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH; import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH; @@ -49,8 +50,15 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient taskType = TaskType.ANY; // task type must be defined in the body } + var inferTimeout = parseTimeout(restRequest); var content = restRequest.requiredContent(); - var request = new PutInferenceModelAction.Request(taskType, inferenceEntityId, content, restRequest.getXContentType()); + var request = new 
PutInferenceModelAction.Request( + taskType, + inferenceEntityId, + content, + restRequest.getXContentType(), + inferTimeout + ); return channel -> client.execute( PutInferenceModelAction.INSTANCE, request, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java index f61398fcacacf..e514867780669 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java @@ -7,13 +7,16 @@ package org.elasticsearch.xpack.inference.action; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; -public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCase { +public class PutInferenceModelRequestTests extends AbstractBWCWireSerializationTestCase { @Override protected Writeable.Reader instanceReader() { return PutInferenceModelAction.Request::new; @@ -25,38 +28,29 @@ protected PutInferenceModelAction.Request createTestInstance() { randomFrom(TaskType.values()), randomAlphaOfLength(6), randomBytesReference(50), - randomFrom(XContentType.values()) + randomFrom(XContentType.values()), + randomTimeValue() ); } @Override protected PutInferenceModelAction.Request mutateInstance(PutInferenceModelAction.Request instance) { - return switch 
(randomIntBetween(0, 3)) { - case 0 -> new PutInferenceModelAction.Request( - TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length], - instance.getInferenceEntityId(), - instance.getContent(), - instance.getContentType() - ); - case 1 -> new PutInferenceModelAction.Request( - instance.getTaskType(), - instance.getInferenceEntityId() + "foo", - instance.getContent(), - instance.getContentType() - ); - case 2 -> new PutInferenceModelAction.Request( - instance.getTaskType(), - instance.getInferenceEntityId(), - randomBytesReference(instance.getContent().length() + 1), - instance.getContentType() - ); - case 3 -> new PutInferenceModelAction.Request( + return randomValueOtherThan(instance, this::createTestInstance); + } + + @Override + protected PutInferenceModelAction.Request mutateInstanceForVersion(PutInferenceModelAction.Request instance, TransportVersion version) { + if (version.onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT) + || version.isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) { + return instance; + } else { + return new PutInferenceModelAction.Request( instance.getTaskType(), instance.getInferenceEntityId(), instance.getContent(), - XContentType.values()[(instance.getContentType().ordinal() + 1) % XContentType.values().length] + instance.getContentType(), + InferenceAction.Request.DEFAULT_TIMEOUT ); - default -> throw new IllegalStateException(); - }; + } } } From 9af98bee9e556c85c4523082bf26dc96c91d3530 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> Date: Thu, 8 May 2025 16:08:47 -0400 Subject: [PATCH 03/12] Update docs/changelog/127939.yaml --- docs/changelog/127939.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/127939.yaml diff --git a/docs/changelog/127939.yaml b/docs/changelog/127939.yaml new file mode 100644 index 0000000000000..b369363052679 --- /dev/null +++ b/docs/changelog/127939.yaml @@ 
-0,0 +1,5 @@ +pr: 127939 +summary: Custom inference service jon +area: Machine Learning +type: enhancement +issues: [] From cb09e3082aed65304a10f73603959ff8ede37917 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 8 May 2025 16:16:48 -0400 Subject: [PATCH 04/12] Cleaning up from failed merge --- x-pack/plugin/inference/build.gradle | 1 + .../ShardBulkInferenceActionFilterIT.java | 30 ++++++------------- .../xpack/inference/InferencePlugin.java | 9 +++--- .../TransportPutInferenceModelAction.java | 2 +- .../external/http/HttpClientManager.java | 18 ++++++++--- .../external/http/StreamingHttpResult.java | 5 ++-- .../http/retry/RetryingHttpSender.java | 7 ++--- .../mapper/SemanticTextFieldMapper.java | 5 +++- .../ElasticInferenceServiceSettings.java | 18 +++++++++++ .../OpenAiUnifiedStreamingProcessor.java | 28 +---------------- .../plugin-metadata/plugin-security.policy | 27 ----------------- .../xpack/inference/model/TestModel.java | 8 ++++- 12 files changed, 65 insertions(+), 93 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index fba8d9e61f0c4..58aa9b29f8565 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -8,6 +8,7 @@ apply plugin: 'elasticsearch.internal-es-plugin' apply plugin: 'elasticsearch.internal-cluster-test' apply plugin: 'elasticsearch.internal-yaml-rest-test' +apply plugin: 'elasticsearch.internal-test-artifact' restResources { restApi { diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java index 074678bbea095..8405fba22460f 100644 --- 
a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java @@ -9,7 +9,6 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; -import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; @@ -17,6 +16,7 @@ import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.update.UpdateRequestBuilder; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; @@ -242,12 +242,10 @@ public void testRestart() throws Exception { private void assertRandomBulkOperations(String indexName, Function> sourceSupplier) throws Exception { int numHits = numHits(indexName); - int totalBulkReqs = randomIntBetween(2, 100); - long totalDocs = numHits; + int totalBulkReqs = randomIntBetween(2, 10); Set ids = new HashSet<>(); - - for (int bulkReqs = numHits; bulkReqs < totalBulkReqs; bulkReqs++) { - BulkRequestBuilder bulkReqBuilder = client().prepareBulk(); + for (int bulkReqs = 0; bulkReqs < totalBulkReqs; bulkReqs++) { + BulkRequestBuilder bulkReqBuilder = client().prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); int totalBulkSize = randomIntBetween(1, 100); for (int bulkSize = 0; bulkSize < totalBulkSize; bulkSize++) { if (ids.size() > 0 && rarely(random())) { @@ -257,24 +255,15 @@ private void assertRandomBulkOperations(String indexName, Function source = sourceSupplier.apply(isIndexRequest); if (isIndexRequest) { 
+ String id = randomAlphaOfLength(20); bulkReqBuilder.add(new IndexRequestBuilder(client()).setIndex(indexName).setId(id).setSource(source)); ids.add(id); } else { - boolean isUpsert = randomBoolean(); - UpdateRequestBuilder request = new UpdateRequestBuilder(client()).setIndex(indexName).setDoc(source); - if (isUpsert || ids.size() == 0) { - request.setDocAsUpsert(true); - } else { - // Update already existing document - id = randomFrom(ids); - } - request.setId(id); - bulkReqBuilder.add(request); - ids.add(id); + String id = randomFrom(ids); + bulkReqBuilder.add(new UpdateRequestBuilder(client()).setIndex(indexName).setId(id).setDoc(source)); } } BulkResponse bulkResponse = bulkReqBuilder.get(); @@ -293,8 +282,7 @@ private void assertRandomBulkOperations(String indexName, Function createComponents(PluginServices services) { var inferenceServices = new ArrayList<>(inferenceServiceExtensions); inferenceServices.add(this::getInferenceServiceFactories); + var inferenceServiceSettings = new ElasticInferenceServiceSettings(settings); + inferenceServiceSettings.init(services.clusterService()); + // Create a separate instance of HTTPClientManager with its own SSL configuration (`xpack.inference.elastic.http.ssl.*`). 
var elasticInferenceServiceHttpClientManager = HttpClientManager.create( settings, services.threadPool(), services.clusterService(), throttlerManager, - getSslService() + getSslService(), + inferenceServiceSettings.getConnectionTtl() ); var elasticInferenceServiceRequestSenderFactory = new HttpRequestSender.Factory( @@ -293,9 +297,6 @@ public Collection createComponents(PluginServices services) { ); elasicInferenceServiceFactory.set(elasticInferenceServiceRequestSenderFactory); - var inferenceServiceSettings = new ElasticInferenceServiceSettings(settings); - inferenceServiceSettings.init(services.clusterService()); - var authorizationHandler = new ElasticInferenceServiceAuthorizationRequestHandler( inferenceServiceSettings.getElasticInferenceServiceUrl(), services.threadPool() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index eeea8a28df486..bc9d87f43ada0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -177,7 +177,7 @@ protected void masterOperation( return; } - parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.ackTimeout(), listener); + parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.getTimeout(), listener); } private void parseAndStoreModel( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java index 6d09c9e67b363..ddf19ff0dc96f 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java @@ -32,6 +32,7 @@ import java.io.Closeable; import java.io.IOException; import java.util.List; +import java.util.concurrent.TimeUnit; import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_SSL_CONFIGURATION_PREFIX; @@ -112,14 +113,15 @@ public static HttpClientManager create( ThreadPool threadPool, ClusterService clusterService, ThrottlerManager throttlerManager, - SSLService sslService + SSLService sslService, + TimeValue connectionTtl ) { // Set the sslStrategy to ensure an encrypted connection, as Elastic Inference Service requires it. SSLIOSessionStrategy sslioSessionStrategy = sslService.sslIOSessionStrategy( sslService.getSSLConfiguration(ELASTIC_INFERENCE_SERVICE_SSL_CONFIGURATION_PREFIX) ); - PoolingNHttpClientConnectionManager connectionManager = createConnectionManager(sslioSessionStrategy); + PoolingNHttpClientConnectionManager connectionManager = createConnectionManager(sslioSessionStrategy, connectionTtl); return new HttpClientManager(settings, connectionManager, threadPool, clusterService, throttlerManager); } @@ -146,7 +148,7 @@ public static HttpClientManager create( this.addSettingsUpdateConsumers(clusterService); } - private static PoolingNHttpClientConnectionManager createConnectionManager(SSLIOSessionStrategy sslStrategy) { + private static PoolingNHttpClientConnectionManager createConnectionManager(SSLIOSessionStrategy sslStrategy, TimeValue connectionTtl) { ConnectingIOReactor ioReactor; try { var configBuilder = IOReactorConfig.custom().setSoKeepAlive(true); @@ -162,7 +164,15 @@ private static PoolingNHttpClientConnectionManager createConnectionManager(SSLIO .register("https", sslStrategy) .build(); - 
return new PoolingNHttpClientConnectionManager(ioReactor, registry); + return new PoolingNHttpClientConnectionManager( + ioReactor, + null, + registry, + null, + null, + Math.toIntExact(connectionTtl.getMillis()), + TimeUnit.MILLISECONDS + ); } private static PoolingNHttpClientConnectionManager createConnectionManager() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResult.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResult.java index f384d79adae3e..1786ee98fcd87 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResult.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResult.java @@ -11,7 +11,6 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.inference.external.request.HttpRequest; import java.io.ByteArrayOutputStream; import java.util.concurrent.Flow; @@ -22,7 +21,7 @@ public boolean isSuccessfulResponse() { return RestStatus.isSuccessful(response.getStatusLine().getStatusCode()); } - public Flow.Publisher toHttpResult(HttpRequest httpRequest) { + public Flow.Publisher toHttpResult() { return subscriber -> body().subscribe(new Flow.Subscriber<>() { @Override public void onSubscribe(Flow.Subscription subscription) { @@ -46,7 +45,7 @@ public void onComplete() { }); } - public void readFullResponse(HttpRequest httpRequest, ActionListener fullResponse) { + public void readFullResponse(ActionListener fullResponse) { var stream = new ByteArrayOutputStream(); AtomicReference upstream = new AtomicReference<>(null); body.subscribe(new Flow.Subscriber<>() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java index e8cb5d3ad16d9..d009ec87d5776 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java @@ -115,12 +115,11 @@ public void tryAction(ActionListener listener) { try { if (request.isStreaming() && responseHandler.canHandleStreamingResponses()) { - var httpRequest = request.createHttpRequest(); - httpClient.stream(httpRequest, context, retryableListener.delegateFailure((l, r) -> { + httpClient.stream(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> { if (r.isSuccessfulResponse()) { - l.onResponse(responseHandler.parseResult(request, r.toHttpResult(httpRequest))); + l.onResponse(responseHandler.parseResult(request, r.toHttpResult())); } else { - r.readFullResponse(httpRequest, l.delegateFailureAndWrap((ll, httpResult) -> { + r.readFullResponse(l.delegateFailureAndWrap((ll, httpResult) -> { try { responseHandler.validateResponse(throttlerManager, logger, request, httpResult, true); InferenceServiceResults inferenceResults = responseHandler.parseResult(request, httpResult); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 548f65d4f93fa..d15414e34aef1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -32,6 +32,7 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexVersion; +import 
org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.BlockLoader; @@ -98,6 +99,7 @@ import java.util.function.Supplier; import static org.elasticsearch.index.IndexVersions.SEMANTIC_TEXT_DEFAULTS_TO_BBQ; +import static org.elasticsearch.index.IndexVersions.SEMANTIC_TEXT_DEFAULTS_TO_BBQ_BACKPORT_8_X; import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; @@ -1077,7 +1079,8 @@ private static Mapper.Builder createEmbeddingsField( denseVectorMapperBuilder.elementType(modelSettings.elementType()); DenseVectorFieldMapper.IndexOptions defaultIndexOptions = null; - if (indexVersionCreated.onOrAfter(SEMANTIC_TEXT_DEFAULTS_TO_BBQ)) { + if (indexVersionCreated.onOrAfter(SEMANTIC_TEXT_DEFAULTS_TO_BBQ) + || indexVersionCreated.between(SEMANTIC_TEXT_DEFAULTS_TO_BBQ_BACKPORT_8_X, IndexVersions.UPGRADE_TO_LUCENE_10_0_0)) { defaultIndexOptions = defaultSemanticDenseIndexOptions(); } if (defaultIndexOptions != null diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java index fe6ebb6cfb625..0d8bef246b35d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java @@ -70,6 +70,17 @@ public class ElasticInferenceServiceSettings { Setting.Property.NodeScope ); + /** + * Total time to live (TTL) defines maximum life span of persistent connections regardless of their + * expiration setting. 
No persistent connection will be re-used past its TTL value. + * Using a TTL of -1 will disable the expiration of persistent connections (the idle connection evictor will still apply). + */ + public static final Setting CONNECTION_TTL_SETTING = Setting.timeSetting( + "xpack.inference.elastic.http.connection_ttl", + TimeValue.timeValueSeconds(60), + Setting.Property.NodeScope + ); + @Deprecated private final String eisGatewayUrl; @@ -77,6 +88,7 @@ public class ElasticInferenceServiceSettings { private final boolean periodicAuthorizationEnabled; private volatile TimeValue authRequestInterval; private volatile TimeValue maxAuthorizationRequestJitter; + private final TimeValue connectionTtl; public ElasticInferenceServiceSettings(Settings settings) { eisGatewayUrl = EIS_GATEWAY_URL.get(settings); @@ -84,6 +96,7 @@ public ElasticInferenceServiceSettings(Settings settings) { periodicAuthorizationEnabled = PERIODIC_AUTHORIZATION_ENABLED.get(settings); authRequestInterval = AUTHORIZATION_REQUEST_INTERVAL.get(settings); maxAuthorizationRequestJitter = MAX_AUTHORIZATION_REQUEST_JITTER.get(settings); + connectionTtl = CONNECTION_TTL_SETTING.get(settings); } /** @@ -115,6 +128,10 @@ public TimeValue getMaxAuthorizationRequestJitter() { return maxAuthorizationRequestJitter; } + public TimeValue getConnectionTtl() { + return connectionTtl; + } + public static List> getSettingsDefinitions() { ArrayList> settings = new ArrayList<>(); settings.add(EIS_GATEWAY_URL); @@ -124,6 +141,7 @@ public static List> getSettingsDefinitions() { settings.add(PERIODIC_AUTHORIZATION_ENABLED); settings.add(AUTHORIZATION_REQUEST_INTERVAL); settings.add(MAX_AUTHORIZATION_REQUEST_JITTER); + settings.add(CONNECTION_TTL_SETTING); return settings; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java index 10c8d8928ea65..5ab743c3d4cc0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java @@ -26,7 +26,6 @@ import java.util.Deque; import java.util.Iterator; import java.util.List; -import java.util.concurrent.LinkedBlockingDeque; import java.util.function.BiFunction; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; @@ -60,21 +59,11 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor< public static final String TOTAL_TOKENS_FIELD = "total_tokens"; private final BiFunction errorParser; - private final Deque buffer = new LinkedBlockingDeque<>(); public OpenAiUnifiedStreamingProcessor(BiFunction errorParser) { this.errorParser = errorParser; } - @Override - protected void upstreamRequest(long n) { - if (buffer.isEmpty()) { - super.upstreamRequest(n); - } else { - downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(singleItem(buffer.poll()))); - } - } - @Override protected void next(Deque item) throws Exception { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); @@ -96,15 +85,8 @@ protected void next(Deque item) throws Exception { if (results.isEmpty()) { upstream().request(1); - } else if (results.size() == 1) { - downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results)); } else { - // results > 1, but openai spec only wants 1 chunk per SSE event - var firstItem = singleItem(results.poll()); - while (results.isEmpty() == false) { - buffer.offer(results.poll()); - } - downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(firstItem)); + 
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results)); } } @@ -297,12 +279,4 @@ public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage pa } } } - - private Deque singleItem( - StreamingUnifiedChatCompletionResults.ChatCompletionChunk result - ) { - var deque = new ArrayDeque(1); - deque.offer(result); - return deque; - } } diff --git a/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy deleted file mode 100644 index e36b553d2def2..0000000000000 --- a/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -grant { - // required by: com.google.api.client.json.JsonParser#parseValue - // also required by AWS SDK for client configuration - permission java.lang.RuntimePermission "accessDeclaredMembers"; - permission java.lang.RuntimePermission "getClassLoader"; - - // required by: com.google.api.client.json.GenericJson# - // also by AWS SDK for Jackson's ObjectMapper - permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; - - // required to add google certs to the gcs client trustore - permission java.lang.RuntimePermission "setFactory"; - - // gcs client opens socket connections for to access repository - // also, AWS Bedrock client opens socket connections and needs resolve for to access to resources - permission java.net.SocketPermission "*", "connect,resolve"; - - // AWS Clients always try to check the http.proxyHost system property - permission java.util.PropertyPermission "http.proxyHost", "read"; -}; diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java index a00f8e55a4e27..c3b50cdb4a670 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java @@ -31,6 +31,7 @@ import java.util.List; import java.util.Map; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS; import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength; import static org.elasticsearch.test.ESTestCase.randomFrom; import static org.elasticsearch.test.ESTestCase.randomInt; @@ -46,9 +47,14 @@ public static TestModel createRandomInstance(TaskType taskType) { } public static TestModel createRandomInstance(TaskType taskType, List excludedSimilarities) { + // Use a max dimension count that has a reasonable probability of being compatible with BBQ + return createRandomInstance(taskType, excludedSimilarities, BBQ_MIN_DIMS * 2); + } + + public static TestModel createRandomInstance(TaskType taskType, List excludedSimilarities, int maxDimensions) { var elementType = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(DenseVectorFieldMapper.ElementType.values()) : null; var dimensions = taskType == TaskType.TEXT_EMBEDDING - ? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 64) + ? 
DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, maxDimensions) : null; SimilarityMeasure similarity = null; From e7c62d8ec7a369d15e6cc5ed62b04f9068f83938 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 8 May 2025 16:18:45 -0400 Subject: [PATCH 05/12] Fixing changelog --- docs/changelog/127939.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changelog/127939.yaml b/docs/changelog/127939.yaml index b369363052679..7cc67e6207a85 100644 --- a/docs/changelog/127939.yaml +++ b/docs/changelog/127939.yaml @@ -1,5 +1,5 @@ pr: 127939 -summary: Custom inference service jon +summary: Add Custom inference service area: Machine Learning type: enhancement issues: [] From 6bb2a95c9f0c6daecedac90dd7ecb6e9456b6bcd Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 8 May 2025 20:27:16 +0000 Subject: [PATCH 06/12] [CI] Auto commit changes from spotless --- .../custom/CustomServiceSettingsTests.java | 2 +- .../services/custom/CustomServiceTests.java | 23 +++++++++---------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java index 71eec73df5375..04830add1bb17 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java @@ -259,7 +259,7 @@ public void testFromMap_RemovesNullValues_FromMaps() { ), ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, - "inference_id" + "inference_id" ); MatcherAssert.assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index 6ce181b0487ad..fb6c50f1bd9c4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -164,12 +164,16 @@ private static Map createServiceSettingsMap(TaskType taskType) { ); if (taskType == TaskType.TEXT_EMBEDDING) { - settingsMap.putAll(Map.of(ServiceFields.SIMILARITY, - SimilarityMeasure.DOT_PRODUCT.toString(), - ServiceFields.DIMENSIONS, - 1536, - ServiceFields.MAX_INPUT_TOKENS, - 512)); + settingsMap.putAll( + Map.of( + ServiceFields.SIMILARITY, + SimilarityMeasure.DOT_PRODUCT.toString(), + ServiceFields.DIMENSIONS, + 1536, + ServiceFields.MAX_INPUT_TOKENS, + 512 + ) + ); } return settingsMap; @@ -235,12 +239,7 @@ private static CustomModel createInternalEmbeddingModel( TaskType.TEXT_EMBEDDING, CustomService.NAME, new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings( - similarityMeasure, - 123, - 456, - DenseVectorFieldMapper.ElementType.FLOAT - ), + new CustomServiceSettings.TextEmbeddingSettings(similarityMeasure, 123, 456, DenseVectorFieldMapper.ElementType.FLOAT), url, Map.of("key", "value"), QueryParameters.EMPTY, From 6be22b58aeeddb6b06487a40f9cf8d29729d78da Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 20 May 2025 14:10:26 -0400 Subject: [PATCH 07/12] Fixing test --- .../elasticsearch/xpack/inference/InferenceGetServicesIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 85a9aafae44bc..d4ba5dbfdab59 100644 --- 
a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -126,7 +126,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { public void testGetServicesWithCompletionTaskType() throws IOException { List services = getServices(TaskType.COMPLETION); - assertThat(services.size(), equalTo(11)); + assertThat(services.size(), equalTo(12)); var providers = providers(services); From 84c16cee5aea729058b07003fccb2a9e1f2adfaa Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 29 May 2025 10:32:58 -0400 Subject: [PATCH 08/12] Adding feature flag --- .../test/cluster/FeatureFlag.java | 3 ++- .../inference/BaseMockEISAuthServerTest.java | 2 ++ .../inference/InferenceBaseRestTest.java | 2 ++ .../inference/CustomServiceFeatureFlag.java | 21 +++++++++++++++++++ .../InferenceNamedWriteablesProvider.java | 6 ++++++ .../xpack/inference/InferencePlugin.java | 12 +++++++++-- .../inference/services/ServiceUtils.java | 4 ++-- 7 files changed, 45 insertions(+), 5 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/CustomServiceFeatureFlag.java diff --git a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java index eea25b85b8548..53e4a971add7d 100644 --- a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java +++ b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java @@ -19,7 +19,8 @@ public enum FeatureFlag { TIME_SERIES_MODE("es.index_mode_feature_flag_registered=true", Version.fromString("8.0.0"), null), SUB_OBJECTS_AUTO_ENABLED("es.sub_objects_auto_feature_flag_enabled=true", Version.fromString("8.16.0"), null), 
DOC_VALUES_SKIPPER("es.doc_values_skipper_feature_flag_enabled=true", Version.fromString("8.18.1"), null), - USE_LUCENE101_POSTINGS_FORMAT("es.use_lucene101_postings_format_feature_flag_enabled=true", Version.fromString("9.1.0"), null); + USE_LUCENE101_POSTINGS_FORMAT("es.use_lucene101_postings_format_feature_flag_enabled=true", Version.fromString("9.1.0"), null), + INFERENCE_CUSTOM_SERVICE_ENABLED("es.inference_custom_service_feature_flag_enabled=true", Version.fromString("8.19.0"), null); public final String systemProperty; public final Version from; diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java index 81c1a8dc7a5ba..1fb2d5a7463f2 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.TimeValue; import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.FeatureFlag; import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.ESRestTestCase; import org.junit.ClassRule; @@ -46,6 +47,7 @@ public class BaseMockEISAuthServerTest extends ESRestTestCase { // This plugin is located in the inference/qa/test-service-plugin package, look for TestInferenceServicePlugin .plugin("inference-service-test") .user("x_pack_rest_user", "x-pack-test-password") + .feature(FeatureFlag.INFERENCE_CUSTOM_SERVICE_ENABLED) .build(); // The reason we're doing this is to make sure the mock server is initialized first so 
we can get the address before communicating diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 54d0aca772061..c3d573c1d6af5 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -20,6 +20,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.FeatureFlag; import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.ESRestTestCase; import org.elasticsearch.xcontent.XContentBuilder; @@ -50,6 +51,7 @@ public class InferenceBaseRestTest extends ESRestTestCase { .setting("xpack.security.enabled", "true") .plugin("inference-service-test") .user("x_pack_rest_user", "x-pack-test-password") + .feature(FeatureFlag.INFERENCE_CUSTOM_SERVICE_ENABLED) .build(); @ClassRule diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/CustomServiceFeatureFlag.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/CustomServiceFeatureFlag.java new file mode 100644 index 0000000000000..203947e782250 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/CustomServiceFeatureFlag.java @@ -0,0 +1,21 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.common.util.FeatureFlag; + +public class CustomServiceFeatureFlag { + /** + * {@link org.elasticsearch.xpack.inference.services.custom.CustomService} feature flag. When the feature is complete, + * this flag will be removed. + * Enable feature via JVM option: `-Des.inference_custom_service_feature_flag_enabled=true`. + */ + public static final FeatureFlag CUSTOM_SERVICE_FEATURE_FLAG = new FeatureFlag("inference_custom_service"); + + private CustomServiceFeatureFlag() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index e85b618431eb6..ba7de99532a3b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -114,6 +114,8 @@ import java.util.ArrayList; import java.util.List; +import static org.elasticsearch.xpack.inference.CustomServiceFeatureFlag.CUSTOM_SERVICE_FEATURE_FLAG; + public class InferenceNamedWriteablesProvider { private InferenceNamedWriteablesProvider() {} @@ -177,6 +179,10 @@ public static List getNamedWriteables() { } private static void addCustomNamedWriteables(List namedWriteables) { + if (CUSTOM_SERVICE_FEATURE_FLAG.isEnabled() == false) { + return; + } + namedWriteables.add( new NamedWriteableRegistry.Entry(ServiceSettings.class, CustomServiceSettings.NAME, CustomServiceSettings::new) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index b2db235005f9a..f74d7630b275a 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -149,8 +149,10 @@ import java.util.Set; import java.util.function.Predicate; import java.util.function.Supplier; +import java.util.stream.Stream; import static java.util.Collections.singletonList; +import static org.elasticsearch.xpack.inference.CustomServiceFeatureFlag.CUSTOM_SERVICE_FEATURE_FLAG; import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE; import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG; @@ -380,7 +382,11 @@ public void loadExtensions(ExtensionLoader loader) { } public List getInferenceServiceFactories() { - return List.of( + List conditionalServices = CUSTOM_SERVICE_FEATURE_FLAG.isEnabled() + ? List.of(context -> new CustomService(httpFactory.get(), serviceComponents.get())) + : List.of(); + + List availableServices = List.of( context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()), context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()), context -> new OpenAiService(httpFactory.get(), serviceComponents.get()), @@ -397,9 +403,11 @@ public List getInferenceServiceFactories() { context -> new JinaAIService(httpFactory.get(), serviceComponents.get()), context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()), context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()), - context -> new CustomService(httpFactory.get(), serviceComponents.get()), ElasticsearchInternalService::new ); + + return Stream.concat(availableServices.stream(), conditionalServices.stream()) + .toList(); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 428c266379f65..b12f5989e55ce 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -206,7 +206,7 @@ public static void throwIfNotEmptyMap(Map settingsMap, String fi public static ElasticsearchStatusException unknownSettingsError(Map config, String serviceName) { // TODO map as JSON return new ElasticsearchStatusException( - "Model configuration contains settings [{}] unknown to the [{}] service", + "Configuration contains settings [{}] unknown to the [{}] service", RestStatus.BAD_REQUEST, config, serviceName @@ -215,7 +215,7 @@ public static ElasticsearchStatusException unknownSettingsError(Map config, String field, String scope) { return new ElasticsearchStatusException( - "Model configuration contains unknown settings [{}] while parsing field [{}] for settings [{}]", + "Configuration contains unknown settings [{}] while parsing field [{}] for settings [{}]", RestStatus.BAD_REQUEST, config, field, From 280d4dd942bb9643acb5abca7ad1006cea8458c8 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 29 May 2025 14:45:01 +0000 Subject: [PATCH 09/12] [CI] Auto commit changes from spotless --- .../org/elasticsearch/xpack/inference/InferencePlugin.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index f74d7630b275a..915a4d3f7af9b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -386,7 +386,7 @@ public List 
getInferenceServiceFactories() { ? List.of(context -> new CustomService(httpFactory.get(), serviceComponents.get())) : List.of(); - List availableServices = List.of( + List availableServices = List.of( context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()), context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()), context -> new OpenAiService(httpFactory.get(), serviceComponents.get()), @@ -406,8 +406,7 @@ public List getInferenceServiceFactories() { ElasticsearchInternalService::new ); - return Stream.concat(availableServices.stream(), conditionalServices.stream()) - .toList(); + return Stream.concat(availableServices.stream(), conditionalServices.stream()).toList(); } @Override From d1137b69f8421ffa9efe460a72ea10150f00c05b Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 29 May 2025 11:47:05 -0400 Subject: [PATCH 10/12] Fixing test issue --- .../inference/services/AbstractServiceTests.java | 8 ++++---- .../amazonbedrock/AmazonBedrockServiceTests.java | 8 ++++---- .../services/anthropic/AnthropicServiceTests.java | 8 ++++---- .../azureaistudio/AzureAiStudioServiceTests.java | 14 +++++++------- .../azureopenai/AzureOpenAiServiceTests.java | 8 ++++---- .../services/cohere/CohereServiceTests.java | 8 ++++---- .../custom/CustomServiceSettingsTests.java | 8 ++++---- .../services/deepseek/DeepSeekServiceTests.java | 2 +- .../elastic/ElasticInferenceServiceTests.java | 8 ++++---- .../googleaistudio/GoogleAiStudioServiceTests.java | 8 ++++---- .../googlevertexai/GoogleVertexAiServiceTests.java | 8 ++++---- .../huggingface/HuggingFaceServiceTests.java | 6 +++--- .../ibmwatsonx/IbmWatsonxServiceTests.java | 2 +- .../services/jinaai/JinaAIServiceTests.java | 8 ++++---- .../services/mistral/MistralServiceTests.java | 6 +++--- .../services/openai/OpenAiServiceTests.java | 8 ++++---- .../model/SageMakerModelBuilderTests.java | 4 ++-- .../services/settings/RateLimitSettingsTests.java | 2 +- 
.../services/voyageai/VoyageAIServiceTests.java | 8 ++++---- 19 files changed, 66 insertions(+), 66 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractServiceTests.java index 071c4caa90a9f..24e0c7cadb73f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractServiceTests.java @@ -227,7 +227,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener); var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - assertThat(exception.getMessage(), containsString("Model configuration contains settings [{extra_key=value}]")); + assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]")); } } @@ -246,7 +246,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener); var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - assertThat(exception.getMessage(), containsString("Model configuration contains settings [{extra_key=value}]")); + assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]")); } } @@ -265,7 +265,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener); var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - assertThat(exception.getMessage(), 
containsString("Model configuration contains settings [{extra_key=value}]")); + assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]")); } } @@ -284,7 +284,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener); var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - assertThat(exception.getMessage(), containsString("Model configuration contains settings [{extra_key=value}]")); + assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 07ee7397504a0..a014f27e7f0cc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -349,7 +349,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I assertThat(exception, instanceOf(ElasticsearchStatusException.class)); assertThat( exception.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") ); }); @@ -368,7 +368,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa assertThat(e, instanceOf(ElasticsearchStatusException.class)); assertThat( e.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] 
service") + is("Configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") ); }); @@ -390,7 +390,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() assertThat(e, instanceOf(ElasticsearchStatusException.class)); assertThat( e.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") ); }); @@ -412,7 +412,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap assertThat(e, instanceOf(ElasticsearchStatusException.class)); assertThat( e.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") ); }); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java index 856ec3ed419ea..75ce59b16a763 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -153,7 +153,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [anthropic] service" + "Configuration contains settings [{extra_key=value}] unknown to the [anthropic] service" ); service.parseRequestConfig("id", TaskType.COMPLETION, config, failureListener); } @@ -172,7 +172,7 @@ public void 
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [anthropic] service" + "Configuration contains settings [{extra_key=value}] unknown to the [anthropic] service" ); service.parseRequestConfig("id", TaskType.COMPLETION, config, failureListener); } @@ -191,7 +191,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [anthropic] service" + "Configuration contains settings [{extra_key=value}] unknown to the [anthropic] service" ); service.parseRequestConfig("id", TaskType.COMPLETION, config, failureListener); } @@ -210,7 +210,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [anthropic] service" + "Configuration contains settings [{extra_key=value}] unknown to the [anthropic] service" ); service.parseRequestConfig("id", TaskType.COMPLETION, config, failureListener); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 20d28c3068ed5..3d7ba7f7436fb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -257,7 +257,7 @@ public void 
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I assertThat(exception, instanceOf(ElasticsearchStatusException.class)); assertThat( exception.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service") ); } ); @@ -279,7 +279,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingServiceS assertThat(exception, instanceOf(ElasticsearchStatusException.class)); assertThat( exception.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service") ); } ); @@ -328,7 +328,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingTaskSett assertThat(exception, instanceOf(ElasticsearchStatusException.class)); assertThat( exception.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service") ); } ); @@ -354,7 +354,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSe assertThat(exception, instanceOf(ElasticsearchStatusException.class)); assertThat( exception.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service") ); } ); @@ -380,7 +380,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInChatCompletionSer assertThat(exception, instanceOf(ElasticsearchStatusException.class)); assertThat( exception.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service") + is("Configuration contains settings 
[{extra_key=value}] unknown to the [azureaistudio] service") ); } ); @@ -406,7 +406,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInChatCompletionTas assertThat(exception, instanceOf(ElasticsearchStatusException.class)); assertThat( exception.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service") ); } ); @@ -432,7 +432,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInChatCompletionSec assertThat(exception, instanceOf(ElasticsearchStatusException.class)); assertThat( exception.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service") ); } ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index f4ab75ebee48f..61be006c28223 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -228,7 +228,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I assertThat(exception, instanceOf(ElasticsearchStatusException.class)); assertThat( exception.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service") ); } ); @@ -254,7 +254,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa assertThat(e, 
instanceOf(ElasticsearchStatusException.class)); assertThat( e.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service") ); }); @@ -279,7 +279,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() assertThat(e, instanceOf(ElasticsearchStatusException.class)); assertThat( e.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service") ); }); @@ -304,7 +304,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap assertThat(e, instanceOf(ElasticsearchStatusException.class)); assertThat( e.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service") ); }); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 469e9e55c695f..fabf87151644b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -272,7 +272,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [cohere] service" + "Configuration contains settings [{extra_key=value}] unknown to the [cohere] service" ); 
service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } @@ -287,7 +287,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [cohere] service" + "Configuration contains settings [{extra_key=value}] unknown to the [cohere] service" ); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } @@ -306,7 +306,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [cohere] service" + "Configuration contains settings [{extra_key=value}] unknown to the [cohere] service" ); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); @@ -326,7 +326,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [cohere] service" + "Configuration contains settings [{extra_key=value}] unknown to the [cohere] service" ); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java index 04830add1bb17..1bb3d44b897c8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java @@ -498,7 +498,7 @@ public void testFromMap_ReturnsError_IfRequestMapIsNotEmptyAfterParsing() { assertThat( exception.getMessage(), is( - "Model configuration contains unknown settings [{key=value}] while parsing field [request]" + "Configuration contains unknown settings [{key=value}] while parsing field [request]" + " for settings [custom_service_settings]" ) ); @@ -543,7 +543,7 @@ public void testFromMap_ReturnsError_IfJsonParserMapIsNotEmptyAfterParsing() { assertThat( exception.getMessage(), is( - "Model configuration contains unknown settings [{key=value}] while parsing field [json_parser]" + "Configuration contains unknown settings [{key=value}] while parsing field [json_parser]" + " for settings [custom_service_settings]" ) ); @@ -585,7 +585,7 @@ public void testFromMap_ReturnsError_IfResponseMapIsNotEmptyAfterParsing() { assertThat( exception.getMessage(), is( - "Model configuration contains unknown settings [{key=value}] while parsing field [response]" + "Configuration contains unknown settings [{key=value}] while parsing field [response]" + " for settings [custom_service_settings]" ) ); @@ -625,7 +625,7 @@ public void testFromMap_ReturnsError_IfErrorParserMapIsNotEmptyAfterParsing() { assertThat( exception.getMessage(), is( - "Model configuration contains unknown settings [{key=value}] while parsing field [error_parser]" + "Configuration contains unknown settings [{key=value}] while parsing field [error_parser]" + " for settings [custom_service_settings]" ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java index 88a2fc76aadcf..6204c2588d3f5 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -148,7 +148,7 @@ public void testParseRequestConfigWithExtraSettings() throws IOException { assertNoSuccessListener( e -> assertThat( e.getMessage(), - equalTo("Model configuration contains settings [{so=extra}] unknown to the [deepseek] service") + equalTo("Configuration contains settings [{so=extra}] unknown to the [deepseek] service") ) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index ba92bf399f99c..7e19c7407af6d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -157,7 +157,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [elastic] service" + "Configuration contains settings [{extra_key=value}] unknown to the [elastic] service" ); service.parseRequestConfig("id", TaskType.SPARSE_EMBEDDING, config, failureListener); } @@ -172,7 +172,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [elastic] service" + "Configuration contains settings [{extra_key=value}] unknown to the [elastic] service" ); 
service.parseRequestConfig("id", TaskType.SPARSE_EMBEDDING, config, failureListener); } @@ -186,7 +186,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [elastic] service" + "Configuration contains settings [{extra_key=value}] unknown to the [elastic] service" ); service.parseRequestConfig("id", TaskType.SPARSE_EMBEDDING, config, failureListener); } @@ -200,7 +200,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [elastic] service" + "Configuration contains settings [{extra_key=value}] unknown to the [elastic] service" ); service.parseRequestConfig("id", TaskType.SPARSE_EMBEDDING, config, failureListener); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 4581c23563e0b..41175581df1cf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -251,7 +251,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" + "Configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" ); 
service.parseRequestConfig("id", TaskType.COMPLETION, config, failureListener); } @@ -266,7 +266,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" + "Configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" ); service.parseRequestConfig("id", TaskType.COMPLETION, config, failureListener); } @@ -285,7 +285,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" + "Configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" ); service.parseRequestConfig("id", TaskType.COMPLETION, config, failureListener); } @@ -304,7 +304,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" + "Configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" ); service.parseRequestConfig("id", TaskType.COMPLETION, config, failureListener); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index 8ea1c12ea9e4a..3f09f610a8e96 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -293,7 +293,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [googlevertexai] service" + "Configuration contains settings [{extra_key=value}] unknown to the [googlevertexai] service" ); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } @@ -317,7 +317,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [googlevertexai] service" + "Configuration contains settings [{extra_key=value}] unknown to the [googlevertexai] service" ); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } @@ -345,7 +345,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [googlevertexai] service" + "Configuration contains settings [{extra_key=value}] unknown to the [googlevertexai] service" ); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } @@ -373,7 +373,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [googlevertexai] service" + "Configuration contains settings [{extra_key=value}] unknown to the [googlevertexai] service" ); service.parseRequestConfig("id", 
TaskType.TEXT_EMBEDDING, config, failureListener); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index fcc3c5bfb98fb..0d76311d81fa5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -638,7 +638,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I assertThat(e, instanceOf(ElasticsearchStatusException.class)); assertThat( e.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [hugging_face] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [hugging_face] service") ); } ); @@ -660,7 +660,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa assertThat(e, instanceOf(ElasticsearchStatusException.class)); assertThat( e.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [hugging_face] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [hugging_face] service") ); } ); @@ -682,7 +682,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap assertThat(e, instanceOf(ElasticsearchStatusException.class)); assertThat( e.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [hugging_face] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [hugging_face] service") ); } ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 4013009a086a7..35dbcdd6aa99f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -276,7 +276,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [watsonxai] service" + "Configuration contains settings [{extra_key=value}] unknown to the [watsonxai] service" ); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index b4a39be58b245..eca76bc1a702a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -256,7 +256,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" + "Configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" ); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } @@ -275,7 +275,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa var failureListener = 
getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" + "Configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" ); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } @@ -294,7 +294,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" + "Configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" ); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); @@ -314,7 +314,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" + "Configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" ); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 1b9bb447b2e60..04a7c38229292 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -217,7 +217,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I assertThat(exception, instanceOf(ElasticsearchStatusException.class)); assertThat( exception.getMessage(), - is("Model 
configuration contains settings [{extra_key=value}] unknown to the [mistral] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [mistral] service") ); } ); @@ -243,7 +243,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingTaskSett assertThat(exception, instanceOf(ElasticsearchStatusException.class)); assertThat( exception.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [mistral] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [mistral] service") ); } ); @@ -269,7 +269,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSe assertThat(exception, instanceOf(ElasticsearchStatusException.class)); assertThat( exception.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [mistral] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [mistral] service") ); } ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 36833667db96d..26a0a5b6ef770 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -218,7 +218,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I assertThat(exception, instanceOf(ElasticsearchStatusException.class)); assertThat( exception.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [openai] service") + is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service") ); } ); @@ -238,7 +238,7 @@ public void 
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa fail("Expected exception, but got model: " + model); }, e -> { assertThat(e, instanceOf(ElasticsearchStatusException.class)); - assertThat(e.getMessage(), is("Model configuration contains settings [{extra_key=value}] unknown to the [openai] service")); + assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service")); }); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); @@ -256,7 +256,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() fail("Expected exception, but got model: " + model); }, e -> { assertThat(e, instanceOf(ElasticsearchStatusException.class)); - assertThat(e.getMessage(), is("Model configuration contains settings [{extra_key=value}] unknown to the [openai] service")); + assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service")); }); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); @@ -274,7 +274,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap fail("Expected exception, but got model: " + model); }, e -> { assertThat(e, instanceOf(ElasticsearchStatusException.class)); - assertThat(e.getMessage(), is("Model configuration contains settings [{extra_key=value}] unknown to the [openai] service")); + assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service")); }); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModelBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModelBuilderTests.java index 4228cb781a49a..523deedac8667 
100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModelBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModelBuilderTests.java @@ -201,7 +201,7 @@ public void testFromRequestWithExtraServiceKeys() { } """, ElasticsearchStatusException.class, - "Model configuration contains settings [{hello=there}] unknown to the [service] service" + "Configuration contains settings [{hello=there}] unknown to the [service] service" ); } @@ -222,7 +222,7 @@ public void testFromRequestWithExtraTaskKeys() { } """, ElasticsearchStatusException.class, - "Model configuration contains settings [{hello=there}] unknown to the [service] service" + "Configuration contains settings [{hello=there}] unknown to the [service] service" ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettingsTests.java index 7e3bdd6b8e5dc..4a808087b6363 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettingsTests.java @@ -86,7 +86,7 @@ public void testOf_ThrowsException_WithUnknownField_InRequestContext() { () -> RateLimitSettings.of(settings, new RateLimitSettings(1), validation, "test", ConfigurationParseContext.REQUEST) ); - assertThat(exception.getMessage(), is("Model configuration contains settings [{abc=100}] unknown to the [test] service")); + assertThat(exception.getMessage(), is("Configuration contains settings [{abc=100}] unknown to the [test] service")); } public void testToXContent() throws IOException { diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 7b53dc959d0ea..8602621e9eb78 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -247,7 +247,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [voyageai] service" + "Configuration contains settings [{extra_key=value}] unknown to the [voyageai] service" ); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } @@ -266,7 +266,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [voyageai] service" + "Configuration contains settings [{extra_key=value}] unknown to the [voyageai] service" ); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } @@ -285,7 +285,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [voyageai] service" + "Configuration contains settings [{extra_key=value}] unknown to the [voyageai] service" ); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); @@ -305,7 +305,7 @@ public void 
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap var failureListener = getModelListenerForException( ElasticsearchStatusException.class, - "Model configuration contains settings [{extra_key=value}] unknown to the [voyageai] service" + "Configuration contains settings [{extra_key=value}] unknown to the [voyageai] service" ); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } From 7d2c11227aa89285b4cf280eb731d64ac2be7da2 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 29 May 2025 15:55:03 +0000 Subject: [PATCH 11/12] [CI] Auto commit changes from spotless --- .../azureopenai/AzureOpenAiServiceTests.java | 15 +---- .../deepseek/DeepSeekServiceTests.java | 5 +- .../model/SageMakerModelBuilderTests.java | 56 ++++++++----------- 3 files changed, 28 insertions(+), 48 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 61be006c28223..de2e9ae9a21b8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -252,10 +252,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa fail("Expected exception, but got model: " + model); }, e -> { assertThat(e, instanceOf(ElasticsearchStatusException.class)); - assertThat( - e.getMessage(), - is("Configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service") - ); + assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service")); }); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); @@ 
-277,10 +274,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() fail("Expected exception, but got model: " + model); }, e -> { assertThat(e, instanceOf(ElasticsearchStatusException.class)); - assertThat( - e.getMessage(), - is("Configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service") - ); + assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service")); }); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); @@ -302,10 +296,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap fail("Expected exception, but got model: " + model); }, e -> { assertThat(e, instanceOf(ElasticsearchStatusException.class)); - assertThat( - e.getMessage(), - is("Configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service") - ); + assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service")); }); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java index 6204c2588d3f5..2f4eed8df7812 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -146,10 +146,7 @@ public void testParseRequestConfigWithExtraSettings() throws IOException { } """, assertNoSuccessListener( - e -> assertThat( - e.getMessage(), - equalTo("Configuration contains settings [{so=extra}] unknown to the [deepseek] service") - ) + e -> assertThat(e.getMessage(), 
equalTo("Configuration contains settings [{so=extra}] unknown to the [deepseek] service")) ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModelBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModelBuilderTests.java index 523deedac8667..9dd278535850a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModelBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModelBuilderTests.java @@ -187,43 +187,35 @@ public void testFromRequestWithoutEndpointName() { } public void testFromRequestWithExtraServiceKeys() { - testExceptionFromRequest( - """ - { - "service_settings": { - "access_key": "test-access-key", - "secret_key": "test-secret-key", - "region": "us-east-1", - "api": "test-api", - "endpoint_name": "test-endpoint", - "hello": "there" - } + testExceptionFromRequest(""" + { + "service_settings": { + "access_key": "test-access-key", + "secret_key": "test-secret-key", + "region": "us-east-1", + "api": "test-api", + "endpoint_name": "test-endpoint", + "hello": "there" } - """, - ElasticsearchStatusException.class, - "Configuration contains settings [{hello=there}] unknown to the [service] service" - ); + } + """, ElasticsearchStatusException.class, "Configuration contains settings [{hello=there}] unknown to the [service] service"); } public void testFromRequestWithExtraTaskKeys() { - testExceptionFromRequest( - """ - { - "service_settings": { - "access_key": "test-access-key", - "secret_key": "test-secret-key", - "region": "us-east-1", - "api": "test-api", - "endpoint_name": "test-endpoint" - }, - "task_settings": { - "hello": "there" - } + testExceptionFromRequest(""" + { + "service_settings": { + "access_key": "test-access-key", + "secret_key": "test-secret-key", + "region": 
"us-east-1", + "api": "test-api", + "endpoint_name": "test-endpoint" + }, + "task_settings": { + "hello": "there" } - """, - ElasticsearchStatusException.class, - "Configuration contains settings [{hello=there}] unknown to the [service] service" - ); + } + """, ElasticsearchStatusException.class, "Configuration contains settings [{hello=there}] unknown to the [service] service"); } public void testRoundTrip() throws IOException { From 63fdaed95033c75ae3aa81f4bf11c514b205469c Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 29 May 2025 16:00:28 -0400 Subject: [PATCH 12/12] Fixing the expected values --- .../elasticsearch/xpack/inference/InferenceGetServicesIT.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 529b35cfaeb50..0ac7e3336fb42 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -103,7 +103,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { public void testGetServicesWithRerankTaskType() throws IOException { List services = getServices(TaskType.RERANK); - assertThat(services.size(), equalTo(8)); + assertThat(services.size(), equalTo(9)); var providers = providers(services); @@ -127,7 +127,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { public void testGetServicesWithCompletionTaskType() throws IOException { List services = getServices(TaskType.COMPLETION); - assertThat(services.size(), equalTo(12)); + assertThat(services.size(), equalTo(13)); var providers = 
providers(services);