Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
7b1f1da
introducing timeout as cluster settings
Samiul-TheSoccerFan Jul 18, 2025
95066c7
forcing null to be send instead of default value
Samiul-TheSoccerFan Jul 18, 2025
e60c409
applying timeout in infer level
Samiul-TheSoccerFan Jul 18, 2025
846f6c2
removing unused variable
Samiul-TheSoccerFan Jul 18, 2025
74dcc03
adding unit tests for cluster timeout values
Samiul-TheSoccerFan Jul 18, 2025
5be1e11
fix linting issues
Samiul-TheSoccerFan Jul 18, 2025
29d5b7c
Update docs/changelog/131551.yaml
Samiul-TheSoccerFan Jul 18, 2025
d7b8116
update changelog
Samiul-TheSoccerFan Jul 18, 2025
bc67010
fix ml core SparseVectorQueryBuilder unit test
Samiul-TheSoccerFan Jul 18, 2025
2fe3f60
adding comment and Nullable annotation
Samiul-TheSoccerFan Jul 21, 2025
6b7a7a5
adding restriction to make sure the cluster setting is only read duri…
Samiul-TheSoccerFan Jul 21, 2025
c857710
Refactored timeout logic per input type and added unit tests
Samiul-TheSoccerFan Jul 22, 2025
013faf4
Merge branch 'main' into inference-timeout-as-cluster-settings
elasticmachine Jul 22, 2025
7f51a91
fix unit test failure due to missing inferenceStat varaible
Samiul-TheSoccerFan Jul 22, 2025
fdbb81f
update comment for timeout
Samiul-TheSoccerFan Jul 22, 2025
4b6cfac
remove the timeout util file
Samiul-TheSoccerFan Jul 23, 2025
e5e9c9a
resolve timeout from Service Utils and moved unit tests to service util
Samiul-TheSoccerFan Jul 23, 2025
dff7190
update comment for timeout
Samiul-TheSoccerFan Jul 23, 2025
62daced
removed duplicate setting
Samiul-TheSoccerFan Jul 23, 2025
f11f52a
update infernece plugin and utils streamline settings registration
Samiul-TheSoccerFan Jul 23, 2025
9b030ac
using mockClusterService in all services
Samiul-TheSoccerFan Jul 23, 2025
154aff6
adding min value
Samiul-TheSoccerFan Jul 23, 2025
2275b99
Merge branch 'main' into inference-timeout-as-cluster-settings
elasticmachine Jul 23, 2025
b9a907b
Adding tests for provided timeout to work as expected
Samiul-TheSoccerFan Jul 23, 2025
43eaf0d
simplify inference timeout settings
Samiul-TheSoccerFan Jul 23, 2025
0c80477
[CI] Auto commit changes from spotless
Jul 23, 2025
739b4fa
added better async handling in the test and simplify response
Samiul-TheSoccerFan Jul 24, 2025
e3d029a
revert back ingest timeout and simplify unit tests
Samiul-TheSoccerFan Jul 24, 2025
6c7b1fa
remove redundant code
Samiul-TheSoccerFan Jul 24, 2025
e54c7f6
Merge branch 'main' into inference-timeout-as-cluster-settings
elasticmachine Jul 24, 2025
1bb407c
fix unnecessary instance creation
Samiul-TheSoccerFan Jul 25, 2025
4f3d3ae
Merge branch 'main' into inference-timeout-as-cluster-settings
elasticmachine Jul 25, 2025
aa1240e
Merge branch 'main' into inference-timeout-as-cluster-settings
elasticmachine Jul 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/131551.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 131551
summary: Added support to configure query timeout for inference
area: Inference
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
Expand Down Expand Up @@ -279,7 +278,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
List.of(query),
TextExpansionConfigUpdate.EMPTY_UPDATE,
false,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
null
);
inferRequest.setHighPriority(true);
inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
Expand Down Expand Up @@ -116,7 +115,7 @@ public void buildVector(Client client, ActionListener<float[]> listener) {
List.of(modelText),
TextEmbeddingConfigUpdate.EMPTY_INSTANCE,
false,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
null
);

inferRequest.setHighPriority(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ protected boolean canSimulateMethod(Method method, Object[] args) throws NoSuchM
@Override
protected Object simulateMethod(Method method, Object[] args) {
CoordinatedInferenceAction.Request request = (CoordinatedInferenceAction.Request) args[1];
assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API, request.getInferenceTimeout());
assertNull(request.getInferenceTimeout());
assertEquals(TrainedModelPrefixStrings.PrefixType.SEARCH, request.getPrefixType());
assertEquals(CoordinatedInferenceAction.Request.RequestModelType.NLP_MODEL, request.getRequestModelType());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;
import java.util.function.Supplier;

Expand Down Expand Up @@ -179,6 +180,12 @@ public class InferencePlugin extends Plugin
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
public static final Setting<TimeValue> INFERENCE_QUERY_TIMEOUT = Setting.timeSetting(
"xpack.inference.query_timeout",
TimeValue.timeValueSeconds(TimeUnit.SECONDS.toSeconds(10)),
Setting.Property.NodeScope,
Setting.Property.Dynamic
);

public static final LicensedFeature.Momentary INFERENCE_API_FEATURE = LicensedFeature.momentary(
"inference",
Expand Down Expand Up @@ -496,6 +503,7 @@ public List<Setting<?>> getSettings() {
settings.addAll(RequestExecutorServiceSettings.getSettingsDefinitions());
settings.add(SKIP_VALIDATE_AND_START);
settings.add(INDICES_INFERENCE_BATCH_SIZE);
settings.add(INFERENCE_QUERY_TIMEOUT);
settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions());

return settings;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
Expand Down Expand Up @@ -237,7 +236,7 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu
List.of(query),
Map.of(),
InputType.INTERNAL_SEARCH,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API,
null,
false
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
Expand Down Expand Up @@ -73,6 +74,9 @@ public void infer(
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
if (timeout == null) {
timeout = clusterService.getClusterSettings().get(InferencePlugin.INFERENCE_QUERY_TIMEOUT);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only want to apply this timeout if the input type is SEARCH or INTERNAL_SEARCH. Which brings up another edge case: If we allow timeout to be null now, we need to set default timeouts for the other input types as well.

init();
var chunkInferenceInput = input.stream().map(i -> new ChunkInferenceInput(i, null)).toList();
var inferenceInput = createInput(this, model, chunkInferenceInput, inputType, query, returnDocuments, topN, stream);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,10 @@ private void preferredVariantFromPlatformArchitecture(ActionListener<PreferredMo
);
}

protected TimeValue getConfiguredInferenceTimeout() {
return clusterService.getClusterSettings().get(InferencePlugin.INFERENCE_QUERY_TIMEOUT);
}

boolean isClusterInElasticCloud() {
// Use the ml lazy node count as a heuristic to determine if in Elastic cloud.
// A value > 0 means scaling should be available for ml nodes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,9 @@ public void infer(
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
if (timeout == null) {
timeout = getConfiguredInferenceTimeout();
}
if (model instanceof ElasticsearchInternalModel esModel) {
var taskType = model.getConfigurations().getTaskType();
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
Expand Down Expand Up @@ -161,6 +162,10 @@ public void infer(
return;
}

if (timeout == null) {
timeout = clusterService.getClusterSettings().get(InferencePlugin.INFERENCE_QUERY_TIMEOUT);
}

var inferenceRequest = new SageMakerInferenceRequest(query, returnDocuments, topN, input, stream, inputType);

try {
Expand All @@ -173,7 +178,7 @@ public void infer(
client.invokeStream(
regionAndSecrets,
request,
timeout != null ? timeout : DEFAULT_TIMEOUT,
timeout,
ActionListener.wrap(
response -> listener.onResponse(schema.streamResponse(sageMakerModel, response)),
e -> listener.onFailure(schema.error(sageMakerModel, e))
Expand All @@ -185,7 +190,7 @@ public void infer(
client.invoke(
regionAndSecrets,
request,
timeout != null ? timeout : DEFAULT_TIMEOUT,
timeout,
ActionListener.wrap(
response -> listener.onResponse(schema.response(sageMakerModel, response, threadPool.getThreadContext())),
e -> listener.onFailure(schema.error(sageMakerModel, e))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceServiceConfiguration;
Expand All @@ -21,6 +23,7 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
Expand All @@ -34,7 +37,9 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
Expand Down Expand Up @@ -103,7 +108,49 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep
verifyNoMoreInteractions(sender);
}

private static final class TestSenderService extends SenderService {
public void test_nullTimeoutUsesClusterSetting() throws IOException {
var sender = mock(Sender.class);
var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender()).thenReturn(sender);

var configuredTimeout = TimeValue.timeValueSeconds(30);
var clusterSettings = new ClusterSettings(
Settings.builder().put(InferencePlugin.INFERENCE_QUERY_TIMEOUT.getKey(), configuredTimeout).build(),
Set.of(InferencePlugin.INFERENCE_QUERY_TIMEOUT)
);
var clusterService = mock(ClusterService.class);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

var capturedTimeout = new AtomicReference<TimeValue>();
var testService = new TestSenderService(factory, createWithEmptySettings(threadPool), clusterService) {
// Override doInfer to capture the timeout value and return a mock response
@Override
protected void doInfer(
Model model,
InferenceInputs inputs,
Map<String, Object> taskSettings,
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
capturedTimeout.set(timeout);
listener.onResponse(mock(InferenceServiceResults.class));
}
};

try (testService) {
var model = mock(Model.class);
when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();

testService.infer(model, null, null, null, List.of("test input"), false, Map.of(), InputType.SEARCH, null, listener);

listener.actionGet(TIMEOUT);
assertEquals(configuredTimeout, capturedTimeout.get());
}
}

private static class TestSenderService extends SenderService {
TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
super(factory, serviceComponents, clusterService);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
Expand Down Expand Up @@ -1911,6 +1912,58 @@ public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException {
}
}

@SuppressWarnings("unchecked")
public void test_nullTimeoutUsesClusterSetting() throws InterruptedException {
var mlTrainedModelResults = new ArrayList<InferenceResults>();
mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true);

Client client = mock(Client.class);
when(client.threadPool()).thenReturn(threadPool);
doAnswer(invocationOnMock -> {
var listener = (ActionListener<InferModelAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(response);
return null;
}).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class));

var configuredTimeout = TimeValue.timeValueSeconds(30);
var clusterSettings = new ClusterSettings(
Settings.builder().put(InferencePlugin.INFERENCE_QUERY_TIMEOUT.getKey(), configuredTimeout).build(),
Set.of(InferencePlugin.INFERENCE_QUERY_TIMEOUT)
);
var clusterService = mock(ClusterService.class);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool, clusterService, Settings.EMPTY);
var service = new ElasticsearchInternalService(context);

var model = new MultilingualE5SmallModel(
"foo",
TaskType.TEXT_EMBEDDING,
"e5",
new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null),
null
);

var gotResults = new AtomicBoolean();
var resultsListener = ActionListener.<InferenceServiceResults>wrap(serviceResponse -> {
assertThat(serviceResponse, instanceOf(TextEmbeddingFloatResults.class));
gotResults.set(true);
}, ESTestCase::fail);

var latch = new CountDownLatch(1);
var latchedListener = new LatchedActionListener<>(resultsListener, latch);

service.infer(model, null, null, null, List.of("test input"), false, Map.of(), InputType.SEARCH, null, latchedListener);

latch.await();
assertTrue("Listener not called", gotResults.get());

ArgumentCaptor<InferModelAction.Request> requestCaptor = ArgumentCaptor.forClass(InferModelAction.Request.class);
verify(client).execute(same(InferModelAction.INSTANCE), requestCaptor.capture(), any(ActionListener.class));
assertEquals(configuredTimeout, requestCaptor.getValue().getInferenceTimeout());
}

private ElasticsearchInternalService createService(Client client) {
var cs = mock(ClusterService.class);
var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand All @@ -26,6 +29,7 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
Expand All @@ -40,6 +44,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Stream;

import static org.elasticsearch.action.ActionListener.assertOnce;
Expand Down Expand Up @@ -179,6 +184,33 @@ public void testInfer() {
verifyNoMoreInteractions(client, schemas, schema);
}

@SuppressWarnings("unchecked")
public void test_nullTimeoutUsesClusterSetting() {
var model = mockModel();
when(schemas.schemaFor(model)).thenReturn(mock());

var configuredTimeout = TimeValue.timeValueSeconds(30);
var clusterSettings = new ClusterSettings(
Settings.builder().put(InferencePlugin.INFERENCE_QUERY_TIMEOUT.getKey(), configuredTimeout).build(),
Set.of(InferencePlugin.INFERENCE_QUERY_TIMEOUT)
);
var clusterService = mock(ClusterService.class);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

var service = new SageMakerService(modelBuilder, client, schemas, mock(ThreadPool.class), Map::of, clusterService);

var capturedTimeout = new AtomicReference<TimeValue>();
doAnswer(ans -> {
capturedTimeout.set(ans.getArgument(2));
((ActionListener<InvokeEndpointResponse>) ans.getArgument(3)).onResponse(InvokeEndpointResponse.builder().build());
return null;
}).when(client).invoke(any(), any(), any(), any());

service.infer(model, QUERY, null, null, INPUT, false, null, INPUT_TYPE, null, assertNoFailureListener(ignored -> {}));

assertEquals(configuredTimeout, capturedTimeout.get());
}

private SageMakerModel mockModel() {
SageMakerModel model = mock();
when(model.override(null)).thenReturn(model);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ protected void doAssertClientRequest(ActionRequest request, TextEmbeddingQueryVe
assertThat(inferRequest.getInputs(), hasSize(1));
assertEquals(builder.getModelText(), inferRequest.getInputs().get(0));
assertEquals(builder.getModelId(), inferRequest.getModelId());
assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API, inferRequest.getInferenceTimeout());
assertNull(inferRequest.getInferenceTimeout());
assertEquals(TrainedModelPrefixStrings.PrefixType.SEARCH, inferRequest.getPrefixType());
assertEquals(CoordinatedInferenceAction.Request.RequestModelType.NLP_MODEL, inferRequest.getRequestModelType());
}
Expand Down
Loading