Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
7b1f1da
introducing timeout as cluster settings
Samiul-TheSoccerFan Jul 18, 2025
95066c7
forcing null to be sent instead of default value
Samiul-TheSoccerFan Jul 18, 2025
e60c409
applying timeout in infer level
Samiul-TheSoccerFan Jul 18, 2025
846f6c2
removing unused variable
Samiul-TheSoccerFan Jul 18, 2025
74dcc03
adding unit tests for cluster timeout values
Samiul-TheSoccerFan Jul 18, 2025
5be1e11
fix linting issues
Samiul-TheSoccerFan Jul 18, 2025
29d5b7c
Update docs/changelog/131551.yaml
Samiul-TheSoccerFan Jul 18, 2025
d7b8116
update changelog
Samiul-TheSoccerFan Jul 18, 2025
bc67010
fix ml core SparseVectorQueryBuilder unit test
Samiul-TheSoccerFan Jul 18, 2025
2fe3f60
adding comment and Nullable annotation
Samiul-TheSoccerFan Jul 21, 2025
6b7a7a5
adding restriction to make sure the cluster setting is only read duri…
Samiul-TheSoccerFan Jul 21, 2025
c857710
Refactored timeout logic per input type and added unit tests
Samiul-TheSoccerFan Jul 22, 2025
013faf4
Merge branch 'main' into inference-timeout-as-cluster-settings
elasticmachine Jul 22, 2025
7f51a91
fix unit test failure due to missing inferenceStat variable
Samiul-TheSoccerFan Jul 22, 2025
fdbb81f
update comment for timeout
Samiul-TheSoccerFan Jul 22, 2025
4b6cfac
remove the timeout util file
Samiul-TheSoccerFan Jul 23, 2025
e5e9c9a
resolve timeout from Service Utils and moved unit tests to service util
Samiul-TheSoccerFan Jul 23, 2025
dff7190
update comment for timeout
Samiul-TheSoccerFan Jul 23, 2025
62daced
removed duplicate setting
Samiul-TheSoccerFan Jul 23, 2025
f11f52a
update inference plugin and utils to streamline settings registration
Samiul-TheSoccerFan Jul 23, 2025
9b030ac
using mockClusterService in all services
Samiul-TheSoccerFan Jul 23, 2025
154aff6
adding min value
Samiul-TheSoccerFan Jul 23, 2025
2275b99
Merge branch 'main' into inference-timeout-as-cluster-settings
elasticmachine Jul 23, 2025
b9a907b
Adding tests for provided timeout to work as expected
Samiul-TheSoccerFan Jul 23, 2025
43eaf0d
simplify inference timeout settings
Samiul-TheSoccerFan Jul 23, 2025
0c80477
[CI] Auto commit changes from spotless
Jul 23, 2025
739b4fa
added better async handling in the test and simplify response
Samiul-TheSoccerFan Jul 24, 2025
e3d029a
revert back ingest timeout and simplify unit tests
Samiul-TheSoccerFan Jul 24, 2025
6c7b1fa
remove redundant code
Samiul-TheSoccerFan Jul 24, 2025
e54c7f6
Merge branch 'main' into inference-timeout-as-cluster-settings
elasticmachine Jul 24, 2025
1bb407c
fix unnecessary instance creation
Samiul-TheSoccerFan Jul 25, 2025
4f3d3ae
Merge branch 'main' into inference-timeout-as-cluster-settings
elasticmachine Jul 25, 2025
aa1240e
Merge branch 'main' into inference-timeout-as-cluster-settings
elasticmachine Jul 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/131551.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 131551
summary: Added support to configure query timeout for inference
area: Inference
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ default boolean hideFromConfigurationApi() {
* @param stream Stream inference results
* @param taskSettings Settings in the request to override the model's defaults
* @param inputType For search, ingest etc
* @param timeout The timeout for the request
* @param timeout The timeout for the request. Callers should normally pass in a timeout.
* Passing in null is specifically for query builders who do not have access to the cluster settings
* to determine the appropriate timeout value set by the user within semantic_text.
* @param listener Inference result listener
*/
void infer(
Expand All @@ -120,7 +122,7 @@ void infer(
boolean stream,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
@Nullable TimeValue timeout,
ActionListener<InferenceServiceResults> listener
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
Expand Down Expand Up @@ -279,7 +278,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
List.of(query),
TextExpansionConfigUpdate.EMPTY_UPDATE,
false,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
null
);
inferRequest.setHighPriority(true);
inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
Expand Down Expand Up @@ -116,7 +115,7 @@ public void buildVector(Client client, ActionListener<float[]> listener) {
List.of(modelText),
TextEmbeddingConfigUpdate.EMPTY_INSTANCE,
false,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
null
);

inferRequest.setHighPriority(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ protected boolean canSimulateMethod(Method method, Object[] args) throws NoSuchM
@Override
protected Object simulateMethod(Method method, Object[] args) {
CoordinatedInferenceAction.Request request = (CoordinatedInferenceAction.Request) args[1];
assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API, request.getInferenceTimeout());
assertNull(request.getInferenceTimeout());
assertEquals(TrainedModelPrefixStrings.PrefixType.SEARCH, request.getPrefixType());
assertEquals(CoordinatedInferenceAction.Request.RequestModelType.NLP_MODEL, request.getRequestModelType());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;
import java.util.function.Supplier;

Expand Down Expand Up @@ -180,6 +181,12 @@ public class InferencePlugin extends Plugin
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
/**
 * Timeout used for search-time inference when the caller does not supply one.
 * Registered under {@code xpack.inference.query_timeout}; defaults to 10 seconds and is a
 * node-scoped, dynamically updatable setting.
 */
public static final Setting<TimeValue> INFERENCE_QUERY_TIMEOUT = Setting.timeSetting(
    "xpack.inference.query_timeout",
    // Built directly from the unit; the previous seconds-to-seconds conversion was a no-op.
    new TimeValue(10, TimeUnit.SECONDS),
    Setting.Property.NodeScope,
    Setting.Property.Dynamic
);

public static final LicensedFeature.Momentary INFERENCE_API_FEATURE = LicensedFeature.momentary(
"inference",
Expand Down Expand Up @@ -499,6 +506,7 @@ public List<Setting<?>> getSettings() {
settings.addAll(RequestExecutorServiceSettings.getSettingsDefinitions());
settings.add(SKIP_VALIDATE_AND_START);
settings.add(INDICES_INFERENCE_BATCH_SIZE);
settings.add(INFERENCE_QUERY_TIMEOUT);
settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions());

return settings;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
Expand Down Expand Up @@ -237,7 +236,7 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu
List.of(query),
Map.of(),
InputType.INTERNAL_SEARCH,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API,
null,
false
);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services;

import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.InferencePlugin;

/**
 * Static helpers for choosing the timeout an inference request runs with.
 */
public final class InferenceTimeoutUtils {

    private InferenceTimeoutUtils() {
        // Static helpers only; not meant to be instantiated.
    }

    /**
     * Resolves the effective timeout for an inference call.
     * <ul>
     *   <li>A non-null {@code timeout} is returned unchanged.</li>
     *   <li>A null {@code timeout} with {@link InputType#SEARCH} or {@link InputType#INTERNAL_SEARCH}
     *       resolves to the current value of {@link InferencePlugin#INFERENCE_QUERY_TIMEOUT}.</li>
     *   <li>A null {@code timeout} with any other input type resolves to
     *       {@link InferenceAction.Request#DEFAULT_TIMEOUT}.</li>
     * </ul>
     *
     * @param timeout        the timeout supplied by the caller, or {@code null} if none was given
     * @param inputType      the input type of the request, used to pick the fallback
     * @param clusterService source of the cluster settings; only consulted for search input types with a null timeout
     * @return the timeout to apply to the request
     */
    public static TimeValue resolveInferenceTimeout(@Nullable TimeValue timeout, InputType inputType, ClusterService clusterService) {
        if (timeout != null) {
            return timeout;
        }
        if (inputType == InputType.SEARCH || inputType == InputType.INTERNAL_SEARCH) {
            return clusterService.getClusterSettings().get(InferencePlugin.INFERENCE_QUERY_TIMEOUT);
        }
        return InferenceAction.Request.DEFAULT_TIMEOUT;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ public void infer(
boolean stream,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
@Nullable TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
timeout = InferenceTimeoutUtils.resolveInferenceTimeout(timeout, inputType, clusterService);
init();
var chunkInferenceInput = input.stream().map(i -> new ChunkInferenceInput(i, null)).toList();
var inferenceInput = createInput(this, model, chunkInferenceInput, inputType, query, returnDocuments, topN, stream);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,10 @@ private void preferredVariantFromPlatformArchitecture(ActionListener<PreferredMo
);
}

/**
 * Exposes the cluster service held by this instance to subclasses.
 *
 * @return the {@code clusterService} field
 */
protected ClusterService getClusterService() {
    return this.clusterService;
}

boolean isClusterInElasticCloud() {
// Use the ml lazy node count as a heuristic to determine if in Elastic cloud.
// A value > 0 means scaling should be available for ml nodes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.InferenceTimeoutUtils;
import org.elasticsearch.xpack.inference.services.ServiceUtils;

import java.util.ArrayList;
Expand Down Expand Up @@ -610,9 +611,10 @@ public void infer(
boolean stream,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
@Nullable TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
timeout = InferenceTimeoutUtils.resolveInferenceTimeout(timeout, inputType, getClusterService());
if (model instanceof ElasticsearchInternalModel esModel) {
var taskType = model.getConfigurations().getTaskType();
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.services.InferenceTimeoutUtils;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
Expand Down Expand Up @@ -160,7 +161,7 @@ public void infer(
listener.onFailure(createInvalidModelException(model));
return;
}

timeout = InferenceTimeoutUtils.resolveInferenceTimeout(timeout, inputType, clusterService);
var inferenceRequest = new SageMakerInferenceRequest(query, returnDocuments, topN, input, stream, inputType);

try {
Expand All @@ -173,7 +174,7 @@ public void infer(
client.invokeStream(
regionAndSecrets,
request,
timeout != null ? timeout : DEFAULT_TIMEOUT,
timeout,
ActionListener.wrap(
response -> listener.onResponse(schema.streamResponse(sageMakerModel, response)),
e -> listener.onFailure(schema.error(sageMakerModel, e))
Expand All @@ -185,7 +186,7 @@ public void infer(
client.invoke(
regionAndSecrets,
request,
timeout != null ? timeout : DEFAULT_TIMEOUT,
timeout,
ActionListener.wrap(
response -> listener.onResponse(schema.response(sageMakerModel, response, threadPool.getThreadContext())),
e -> listener.onFailure(schema.error(sageMakerModel, e))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ public static ClusterService mockClusterService(Settings settings) {
ElasticInferenceServiceSettings.getSettingsDefinitions()
).flatMap(Collection::stream).collect(Collectors.toSet());

registeredSettings.add(InferencePlugin.INFERENCE_QUERY_TIMEOUT);

var cSettings = new ClusterSettings(settings, registeredSettings);
when(clusterService.getClusterSettings()).thenReturn(cSettings);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services;

import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.junit.Before;

import static org.elasticsearch.xpack.inference.Utils.mockClusterService;

/**
 * Unit tests for {@link InferenceTimeoutUtils#resolveInferenceTimeout}.
 */
public class InferenceTimeoutUtilsTests extends ESTestCase {

    /**
     * Value written to the query-timeout setting for these tests. It is deliberately different from the
     * setting's 10-second default so the assertions fail if the configured value is not the one being read.
     */
    private static final TimeValue CONFIGURED_TIMEOUT = TimeValue.timeValueSeconds(25);

    private ClusterService clusterService;

    @Before
    public void setUp() throws Exception {
        super.setUp();
        var settings = Settings.builder().put(InferencePlugin.INFERENCE_QUERY_TIMEOUT.getKey(), CONFIGURED_TIMEOUT).build();
        clusterService = mockClusterService(settings);
    }

    /** An explicitly supplied timeout wins regardless of the input type. */
    public void testResolveInferenceTimeout_WithProvidedTimeout_ReturnsProvidedTimeout() {
        var providedTimeout = TimeValue.timeValueSeconds(45);
        for (var inputType : new InputType[] { InputType.SEARCH, InputType.INTERNAL_SEARCH, InputType.INGEST }) {
            var result = InferenceTimeoutUtils.resolveInferenceTimeout(providedTimeout, inputType, clusterService);
            assertEquals(providedTimeout, result);
        }
    }

    /** With no timeout supplied, search input types fall back to the configured cluster setting. */
    public void testResolveInferenceTimeout_WithNullTimeoutAndSearchInputType_ReturnsClusterSetting() {
        for (var inputType : new InputType[] { InputType.SEARCH, InputType.INTERNAL_SEARCH }) {
            var result = InferenceTimeoutUtils.resolveInferenceTimeout(null, inputType, clusterService);
            assertEquals(CONFIGURED_TIMEOUT, result);
        }
    }

    /** With no timeout supplied, ingest falls back to the inference action's default timeout. */
    public void testResolveInferenceTimeout_WithNullTimeoutAndIngestInputType_ReturnsDefaultTimeout() {
        var result = InferenceTimeoutUtils.resolveInferenceTimeout(null, InputType.INGEST, clusterService);
        assertEquals(InferenceAction.Request.DEFAULT_TIMEOUT, result);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceServiceConfiguration;
Expand All @@ -21,6 +23,7 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
Expand All @@ -34,7 +37,9 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
Expand Down Expand Up @@ -103,7 +108,49 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep
verifyNoMoreInteractions(sender);
}

private static final class TestSenderService extends SenderService {
/**
 * Calls {@code infer} with a null timeout and a SEARCH input type, then checks that the timeout
 * handed to {@code doInfer} is the value configured for the query-timeout cluster setting.
 */
public void test_nullTimeoutUsesClusterSetting() throws IOException {
    TimeValue expectedTimeout = TimeValue.timeValueSeconds(30);

    // Cluster service whose settings carry the configured query timeout.
    Settings nodeSettings = Settings.builder().put(InferencePlugin.INFERENCE_QUERY_TIMEOUT.getKey(), expectedTimeout).build();
    ClusterService mockedClusterService = mock(ClusterService.class);
    when(mockedClusterService.getClusterSettings()).thenReturn(
        new ClusterSettings(nodeSettings, Set.of(InferencePlugin.INFERENCE_QUERY_TIMEOUT))
    );

    // Sender plumbing required by the service constructor.
    Sender mockedSender = mock(Sender.class);
    HttpRequestSender.Factory senderFactory = mock(HttpRequestSender.Factory.class);
    when(senderFactory.createSender()).thenReturn(mockedSender);

    AtomicReference<TimeValue> observedTimeout = new AtomicReference<>();
    var service = new TestSenderService(senderFactory, createWithEmptySettings(threadPool), mockedClusterService) {
        // Record the timeout that reaches doInfer and complete the listener right away.
        @Override
        protected void doInfer(
            Model model,
            InferenceInputs inputs,
            Map<String, Object> taskSettings,
            TimeValue timeout,
            ActionListener<InferenceServiceResults> listener
        ) {
            observedTimeout.set(timeout);
            listener.onResponse(mock(InferenceServiceResults.class));
        }
    };

    try (service) {
        Model embeddingModel = mock(Model.class);
        when(embeddingModel.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);

        PlainActionFuture<InferenceServiceResults> future = new PlainActionFuture<>();
        service.infer(embeddingModel, null, null, null, List.of("test input"), false, Map.of(), InputType.SEARCH, null, future);
        future.actionGet(TIMEOUT);

        assertEquals(expectedTimeout, observedTimeout.get());
    }
}

private static class TestSenderService extends SenderService {
TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
super(factory, serviceComponents, clusterService);
}
Expand Down
Loading
Loading