diff --git a/docs/changelog/131551.yaml b/docs/changelog/131551.yaml new file mode 100644 index 0000000000000..c9f9f070f4e41 --- /dev/null +++ b/docs/changelog/131551.yaml @@ -0,0 +1,5 @@ +pr: 131551 +summary: Added support to configure query timeout for inference +area: Inference +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 03d2806005788..930988bf18c38 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -108,7 +108,9 @@ default boolean hideFromConfigurationApi() { * @param stream Stream inference results * @param taskSettings Settings in the request to override the model's defaults * @param inputType For search, ingest etc - * @param timeout The timeout for the request + * @param timeout The timeout for the request. Callers should normally pass in a timeout. + * Passing in null is specifically for query-time inference, when the timeout is managed by the + * xpack.inference.query_timeout cluster setting. * @param listener Inference result listener */ void infer( @@ -120,7 +122,7 @@ void infer( boolean stream, Map taskSettings, InputType inputType, - TimeValue timeout, + @Nullable TimeValue timeout, ActionListener listener ); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java index cc67bc28d675e..a36f87e939673 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java @@ -31,7 +31,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; -import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; @@ -279,7 +278,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { List.of(query), TextExpansionConfigUpdate.EMPTY_UPDATE, false, - InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API + null ); inferRequest.setHighPriority(true); inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java index b859f59fc3e0e..b702eb7fc8ff9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java @@ -19,7 +19,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; -import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; @@ -116,7 +115,7 @@ public void buildVector(Client client, ActionListener listener) { List.of(modelText), TextEmbeddingConfigUpdate.EMPTY_INSTANCE, false, - InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API + null ); inferRequest.setHighPriority(true); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java index 8b321a6978359..b855cc4ca0a03 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java @@ -124,7 +124,7 @@ protected boolean canSimulateMethod(Method method, Object[] args) throws NoSuchM @Override protected Object simulateMethod(Method method, Object[] args) { CoordinatedInferenceAction.Request request = (CoordinatedInferenceAction.Request) args[1]; - assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API, request.getInferenceTimeout()); + assertNull(request.getInferenceTimeout()); assertEquals(TrainedModelPrefixStrings.PrefixType.SEARCH, request.getPrefixType()); assertEquals(CoordinatedInferenceAction.Request.RequestModelType.NLP_MODEL, request.getRequestModelType()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index bbb1bd1a2fec2..374770ad25eb1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -145,6 +145,8 @@ import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -180,6 +182,13 @@ public class InferencePlugin extends Plugin Setting.Property.NodeScope, Setting.Property.Dynamic ); + public static final Setting INFERENCE_QUERY_TIMEOUT = Setting.timeSetting( + "xpack.inference.query_timeout", + TimeValue.timeValueSeconds(10), + TimeValue.timeValueMillis(1), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); public static final LicensedFeature.Momentary INFERENCE_API_FEATURE = LicensedFeature.momentary( "inference", @@ -490,7 +499,11 @@ public static ExecutorBuilder inferenceUtilityExecutor(Settings settings) { @Override public List> getSettings() { - ArrayList> settings = new ArrayList<>(); + return List.copyOf(getInferenceSettings()); + } + + public static Set> getInferenceSettings() { + Set> settings = new HashSet<>(); settings.addAll(HttpSettings.getSettingsDefinitions()); settings.addAll(HttpClientManager.getSettingsDefinitions()); settings.addAll(ThrottlerManager.getSettingsDefinitions()); @@ -499,9 +512,9 @@ public List> getSettings() { settings.addAll(RequestExecutorServiceSettings.getSettingsDefinitions()); settings.add(SKIP_VALIDATE_AND_START); settings.add(INDICES_INFERENCE_BATCH_SIZE); + settings.add(INFERENCE_QUERY_TIMEOUT); settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions()); - - return settings; + return Collections.unmodifiableSet(settings); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 182c083ef1c26..f12e50674e5f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -31,7 +31,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; @@ -237,7 +236,7 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu List.of(query), Map.of(), InputType.INTERNAL_SEARCH, - InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API, + null, false ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 5074749c1cd9f..f483eaac6b496 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -70,9 +70,10 @@ public void infer( boolean stream, Map taskSettings, InputType inputType, - TimeValue timeout, + @Nullable TimeValue timeout, ActionListener listener ) { + timeout = ServiceUtils.resolveInferenceTimeout(timeout, inputType, clusterService); init(); var chunkInferenceInput = input.stream().map(i -> new ChunkInferenceInput(i, null)).toList(); var inferenceInput = createInput(this, model, chunkInferenceInput, inputType, query, returnDocuments, topN, stream); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index f5f1074bfbb86..ac80e1a7dd145 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; @@ -21,7 +22,9 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; +import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; import java.net.URI; @@ -1098,5 +1101,24 @@ public static void checkByteBounds(short value) { } } + /** + * Resolves the inference timeout based on input type and cluster settings. + * + * @param timeout The provided timeout value, may be null + * @param inputType The input type for the inference request + * @param clusterService The cluster service to get timeout settings from + * @return The resolved timeout value + */ + public static TimeValue resolveInferenceTimeout(@Nullable TimeValue timeout, InputType inputType, ClusterService clusterService) { + if (timeout == null) { + if (inputType == InputType.SEARCH || inputType == InputType.INTERNAL_SEARCH) { + return clusterService.getClusterSettings().get(InferencePlugin.INFERENCE_QUERY_TIMEOUT); + } else { + return InferenceAction.Request.DEFAULT_TIMEOUT; + } + } + return timeout; + } + private ServiceUtils() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index 4aaf3c2db2e61..d62be1ec171bd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -299,6 +299,10 @@ private void preferredVariantFromPlatformArchitecture(ActionListener 0 means scaling should be available for ml nodes diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index b17392311629f..8ad41e3567c04 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -610,9 +610,10 @@ public void infer( boolean stream, Map taskSettings, InputType inputType, - TimeValue timeout, + @Nullable TimeValue timeout, ActionListener listener ) { + timeout = ServiceUtils.resolveInferenceTimeout(timeout, inputType, getClusterService()); if (model instanceof ElasticsearchInternalModel esModel) { var taskType = model.getConfigurations().getTaskType(); if (TaskType.TEXT_EMBEDDING.equals(taskType)) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java index 653c4288263f9..393e3503b5cfc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java @@ -31,6 +31,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder; import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas; @@ -160,7 +161,7 @@ public void infer( listener.onFailure(createInvalidModelException(model)); return; } - + timeout = ServiceUtils.resolveInferenceTimeout(timeout, inputType, clusterService); var inferenceRequest = new SageMakerInferenceRequest(query, returnDocuments, topN, input, stream, inputType); try { @@ -173,7 +174,7 @@ public void infer( client.invokeStream( regionAndSecrets, request, - timeout != null ? timeout : DEFAULT_TIMEOUT, + timeout, ActionListener.wrap( response -> listener.onResponse(schema.streamResponse(sageMakerModel, response)), e -> listener.onFailure(schema.error(sageMakerModel, e)) @@ -185,7 +186,7 @@ public void infer( client.invoke( regionAndSecrets, request, - timeout != null ? timeout : DEFAULT_TIMEOUT, + timeout, ActionListener.wrap( response -> listener.onResponse(schema.response(sageMakerModel, response, threadPool.getThreadContext())), e -> listener.onFailure(schema.error(sageMakerModel, e)) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java index 8f865312c3b23..53ada98b69cfe 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java @@ -26,25 +26,16 @@ import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; -import org.elasticsearch.xpack.inference.common.Truncator; -import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.HttpSettings; -import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; -import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; -import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.hamcrest.Matchers; import java.io.IOException; -import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; -import java.util.stream.Stream; import static org.elasticsearch.test.ESTestCase.randomFrom; import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; @@ -71,15 +62,7 @@ public static ClusterService mockClusterServiceEmpty() { public static ClusterService mockClusterService(Settings settings) { var clusterService = mock(ClusterService.class); - var registeredSettings = Stream.of( - HttpSettings.getSettingsDefinitions(), - HttpClientManager.getSettingsDefinitions(), - ThrottlerManager.getSettingsDefinitions(), - RetrySettings.getSettingsDefinitions(), - Truncator.getSettingsDefinitions(), - RequestExecutorServiceSettings.getSettingsDefinitions(), - ElasticInferenceServiceSettings.getSettingsDefinitions() - ).flatMap(Collection::stream).collect(Collectors.toSet()); + var registeredSettings = InferencePlugin.getInferenceSettings(); var cSettings = new ClusterSettings(settings, registeredSettings); when(clusterService.getClusterSettings()).thenReturn(cSettings); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 7457859a64603..c4977639dae81 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -21,6 +22,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; @@ -35,8 +37,10 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterService; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.mockito.Mockito.mock; @@ -103,7 +107,85 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep verifyNoMoreInteractions(sender); } - private static final class TestSenderService extends SenderService { + public void test_nullTimeoutUsesClusterSetting() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var configuredTimeout = TimeValue.timeValueSeconds(15); + var clusterService = mockClusterService( + Settings.builder().put(InferencePlugin.INFERENCE_QUERY_TIMEOUT.getKey(), configuredTimeout).build() + ); + + var capturedTimeout = new AtomicReference(); + var testService = new TestSenderService(factory, createWithEmptySettings(threadPool), clusterService) { + // Override doInfer to capture the timeout value and return a mock response + @Override + protected void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + TimeValue timeout, + ActionListener listener + ) { + capturedTimeout.set(timeout); + listener.onResponse(mock(InferenceServiceResults.class)); + } + }; + + try (testService) { + var model = mock(Model.class); + when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); + + PlainActionFuture listener = new PlainActionFuture<>(); + + testService.infer(model, null, null, null, List.of("test input"), false, Map.of(), InputType.SEARCH, null, listener); + + listener.actionGet(TIMEOUT); + assertEquals(configuredTimeout, capturedTimeout.get()); + } + } + + public void test_providedTimeoutPropagateProperly() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var providedTimeout = TimeValue.timeValueSeconds(45); + var clusterService = mockClusterService( + Settings.builder().put(InferencePlugin.INFERENCE_QUERY_TIMEOUT.getKey(), TimeValue.timeValueSeconds(15)).build() + ); + + var capturedTimeout = new AtomicReference(); + var testService = new TestSenderService(factory, createWithEmptySettings(threadPool), clusterService) { + // Override doInfer to capture the timeout value and return a mock response + @Override + protected void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + TimeValue timeout, + ActionListener listener + ) { + capturedTimeout.set(timeout); + listener.onResponse(mock(InferenceServiceResults.class)); + } + }; + + try (testService) { + var model = mock(Model.class); + when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); + + PlainActionFuture listener = new PlainActionFuture<>(); + + testService.infer(model, null, null, null, List.of("test input"), false, Map.of(), InputType.SEARCH, providedTimeout, listener); + + listener.actionGet(TIMEOUT); + assertEquals(providedTimeout, capturedTimeout.get()); + } + } + + private static class TestSenderService extends SenderService { TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { super(factory, serviceComponents, clusterService); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index d90cb638709db..b0dbf5b67b811 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -10,18 +10,22 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Booleans; import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.Tuple; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; +import org.elasticsearch.xpack.inference.InferencePlugin; import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; +import static org.elasticsearch.xpack.inference.Utils.mockClusterService; import static org.elasticsearch.xpack.inference.Utils.modifiableMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertMapStringsToSecureString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri; @@ -1308,4 +1312,46 @@ public void testExtractOptionalListOfStringTuples_Exception_WhenTupleSecondEleme ) ); } + + public void testResolveInferenceTimeout_WithProvidedTimeout_ReturnsProvidedTimeout() { + var clusterService = mockClusterService(Settings.builder().put(InferencePlugin.INFERENCE_QUERY_TIMEOUT.getKey(), "10s").build()); + var providedTimeout = TimeValue.timeValueSeconds(45); + + for (InputType inputType : InputType.values()) { + var result = ServiceUtils.resolveInferenceTimeout(providedTimeout, inputType, clusterService); + assertEquals("Input type " + inputType + " should return provided timeout", providedTimeout, result); + } + } + + public void testResolveInferenceTimeout_WithNullTimeout_ReturnsExpectedTimeoutByInputType() { + var configuredTimeout = TimeValue.timeValueSeconds(10); + var clusterService = mockClusterService( + Settings.builder().put(InferencePlugin.INFERENCE_QUERY_TIMEOUT.getKey(), configuredTimeout).build() + ); + + Map expectedTimeouts = Map.of( + InputType.SEARCH, + configuredTimeout, + InputType.INTERNAL_SEARCH, + configuredTimeout, + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + InputType.CLASSIFICATION, + InferenceAction.Request.DEFAULT_TIMEOUT, + InputType.CLUSTERING, + InferenceAction.Request.DEFAULT_TIMEOUT, + InputType.UNSPECIFIED, + InferenceAction.Request.DEFAULT_TIMEOUT + ); + + for (Map.Entry entry : expectedTimeouts.entrySet()) { + InputType inputType = entry.getKey(); + TimeValue expectedTimeout = entry.getValue(); + + var result = ServiceUtils.resolveInferenceTimeout(null, inputType, clusterService); + assertEquals("Input type " + inputType + " should return expected timeout", expectedTimeout, result); + } + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 6c01145701d92..4111cab05b7c2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -33,6 +33,7 @@ import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; +import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; @@ -108,6 +109,7 @@ import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response.RESULTS_FIELD; +import static org.elasticsearch.xpack.inference.Utils.mockClusterService; import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86; @@ -1992,6 +1994,94 @@ private Client mockClientForStart(Consumer { + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(null); + return null; + }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class)); + + var configuredTimeout = TimeValue.timeValueSeconds(15); + var clusterService = mockClusterService( + Settings.builder().put(InferencePlugin.INFERENCE_QUERY_TIMEOUT.getKey(), configuredTimeout).build() + ); + + var context = new InferenceServiceExtension.InferenceServiceFactoryContext( + client, + threadPool, + clusterService, + Settings.EMPTY, + inferenceStats + ); + var service = new ElasticsearchInternalService(context); + + var model = new MultilingualE5SmallModel( + "foo", + TaskType.TEXT_EMBEDDING, + "e5", + new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null), + null + ); + + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener<>(ActionListener.noop(), latch); + + service.infer(model, null, null, null, List.of("test input"), false, Map.of(), InputType.SEARCH, null, latchedListener); + + assertTrue(latch.await(30, TimeUnit.SECONDS)); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(InferModelAction.Request.class); + verify(client).execute(same(InferModelAction.INSTANCE), requestCaptor.capture(), any(ActionListener.class)); + assertEquals(configuredTimeout, requestCaptor.getValue().getInferenceTimeout()); + } + + @SuppressWarnings("unchecked") + public void test_providedTimeoutPropagateProperly() throws InterruptedException { + Client client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + doAnswer(invocationOnMock -> { + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(null); + return null; + }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class)); + + var providedTimeout = TimeValue.timeValueSeconds(45); + var clusterService = mockClusterService( + Settings.builder().put(InferencePlugin.INFERENCE_QUERY_TIMEOUT.getKey(), TimeValue.timeValueSeconds(15)).build() + ); + + var context = new InferenceServiceExtension.InferenceServiceFactoryContext( + client, + threadPool, + clusterService, + Settings.EMPTY, + inferenceStats + ); + var service = new ElasticsearchInternalService(context); + + var model = new MultilingualE5SmallModel( + "foo", + TaskType.TEXT_EMBEDDING, + "e5", + new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null), + null + ); + + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener<>(ActionListener.noop(), latch); + + service.infer(model, null, null, null, List.of("test input"), false, Map.of(), InputType.SEARCH, providedTimeout, latchedListener); + + assertTrue(latch.await(30, TimeUnit.SECONDS)); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(InferModelAction.Request.class); + verify(client).execute(same(InferModelAction.INSTANCE), requestCaptor.capture(), any(ActionListener.class)); + assertEquals(providedTimeout, requestCaptor.getValue().getInferenceTimeout()); + } + private ElasticsearchInternalService createService(Client client) { var cs = mock(ClusterService.class); var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java index bf883a6345398..18ac37e54c321 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java @@ -12,11 +12,14 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.LatchedActionListener; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; @@ -26,6 +29,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests; +import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; @@ -40,6 +44,9 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Stream; import static org.elasticsearch.action.ActionListener.assertOnce; @@ -47,6 +54,7 @@ import static org.elasticsearch.action.support.ActionTestUtils.assertNoSuccessListener; import static org.elasticsearch.core.TimeValue.THIRTY_SECONDS; import static org.elasticsearch.xpack.core.inference.action.UnifiedCompletionRequestTests.randomUnifiedCompletionRequest; +import static org.elasticsearch.xpack.inference.Utils.mockClusterService; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; @@ -76,13 +84,14 @@ public class SageMakerServiceTests extends ESTestCase { private SageMakerClient client; private SageMakerSchemas schemas; private SageMakerService sageMakerService; + private ThreadPool threadPool; @Before public void init() { modelBuilder = mock(); client = mock(); schemas = mock(); - ThreadPool threadPool = mock(); + threadPool = mock(); when(threadPool.executor(anyString())).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE); when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); sageMakerService = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of, mockClusterServiceEmpty()); @@ -179,6 +188,60 @@ public void testInfer() { verifyNoMoreInteractions(client, schemas, schema); } + @SuppressWarnings("unchecked") + public void test_nullTimeoutUsesClusterSetting() throws InterruptedException { + var model = mockModel(); + when(schemas.schemaFor(model)).thenReturn(mock()); + + var configuredTimeout = TimeValue.timeValueSeconds(15); + var clusterService = mockClusterService( + Settings.builder().put(InferencePlugin.INFERENCE_QUERY_TIMEOUT.getKey(), configuredTimeout).build() + ); + + var service = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of, clusterService); + + var capturedTimeout = new AtomicReference(); + doAnswer(ans -> { + capturedTimeout.set(ans.getArgument(2)); + ((ActionListener) ans.getArgument(3)).onResponse(null); + return null; + }).when(client).invoke(any(), any(), any(), any()); + + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener<>(ActionListener.noop(), latch); + service.infer(model, QUERY, null, null, INPUT, false, null, InputType.SEARCH, null, latchedListener); + + assertTrue(latch.await(30, TimeUnit.SECONDS)); + assertEquals(configuredTimeout, capturedTimeout.get()); + } + + @SuppressWarnings("unchecked") + public void test_providedTimeoutPropagateProperly() throws InterruptedException { + var model = mockModel(); + when(schemas.schemaFor(model)).thenReturn(mock()); + + var providedTimeout = TimeValue.timeValueSeconds(45); + var clusterService = mockClusterService( + Settings.builder().put(InferencePlugin.INFERENCE_QUERY_TIMEOUT.getKey(), TimeValue.timeValueSeconds(15)).build() + ); + + var service = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of, clusterService); + + var capturedTimeout = new AtomicReference(); + doAnswer(ans -> { + capturedTimeout.set(ans.getArgument(2)); + ((ActionListener) ans.getArgument(3)).onResponse(null); + return null; + }).when(client).invoke(any(), any(), any(), any()); + + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener<>(ActionListener.noop(), latch); + service.infer(model, QUERY, null, null, INPUT, false, null, InputType.SEARCH, providedTimeout, latchedListener); + + assertTrue(latch.await(30, TimeUnit.SECONDS)); + assertEquals(providedTimeout, capturedTimeout.get()); + } + private SageMakerModel mockModel() { SageMakerModel model = mock(); when(model.override(null)).thenReturn(model); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java index 70f361e94c85b..0f70ea8e0b212 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java @@ -41,7 +41,7 @@ protected void doAssertClientRequest(ActionRequest request, TextEmbeddingQueryVe assertThat(inferRequest.getInputs(), hasSize(1)); assertEquals(builder.getModelText(), inferRequest.getInputs().get(0)); assertEquals(builder.getModelId(), inferRequest.getModelId()); - assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API, inferRequest.getInferenceTimeout()); + assertNull(inferRequest.getInferenceTimeout()); assertEquals(TrainedModelPrefixStrings.PrefixType.SEARCH, inferRequest.getPrefixType()); assertEquals(CoordinatedInferenceAction.Request.RequestModelType.NLP_MODEL, inferRequest.getRequestModelType()); }