Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
7b1f1da
introducing timeout as cluster settings
Samiul-TheSoccerFan Jul 18, 2025
95066c7
forcing null to be sent instead of default value
Samiul-TheSoccerFan Jul 18, 2025
e60c409
applying timeout in infer level
Samiul-TheSoccerFan Jul 18, 2025
846f6c2
removing unused variable
Samiul-TheSoccerFan Jul 18, 2025
74dcc03
adding unit tests for cluster timeout values
Samiul-TheSoccerFan Jul 18, 2025
5be1e11
fix linting issues
Samiul-TheSoccerFan Jul 18, 2025
29d5b7c
Update docs/changelog/131551.yaml
Samiul-TheSoccerFan Jul 18, 2025
d7b8116
update changelog
Samiul-TheSoccerFan Jul 18, 2025
bc67010
fix ml core SparseVectorQueryBuilder unit test
Samiul-TheSoccerFan Jul 18, 2025
2fe3f60
adding comment and Nullable annotation
Samiul-TheSoccerFan Jul 21, 2025
6b7a7a5
adding restriction to make sure the cluster setting is only read duri…
Samiul-TheSoccerFan Jul 21, 2025
c857710
Refactored timeout logic per input type and added unit tests
Samiul-TheSoccerFan Jul 22, 2025
013faf4
Merge branch 'main' into inference-timeout-as-cluster-settings
elasticmachine Jul 22, 2025
7f51a91
fix unit test failure due to missing inferenceStat variable
Samiul-TheSoccerFan Jul 22, 2025
fdbb81f
update comment for timeout
Samiul-TheSoccerFan Jul 22, 2025
4b6cfac
remove the timeout util file
Samiul-TheSoccerFan Jul 23, 2025
e5e9c9a
resolve timeout from Service Utils and moved unit tests to service util
Samiul-TheSoccerFan Jul 23, 2025
dff7190
update comment for timeout
Samiul-TheSoccerFan Jul 23, 2025
62daced
removed duplicate setting
Samiul-TheSoccerFan Jul 23, 2025
f11f52a
update inference plugin and utils to streamline settings registration
Samiul-TheSoccerFan Jul 23, 2025
9b030ac
using mockClusterService in all services
Samiul-TheSoccerFan Jul 23, 2025
154aff6
adding min value
Samiul-TheSoccerFan Jul 23, 2025
2275b99
Merge branch 'main' into inference-timeout-as-cluster-settings
elasticmachine Jul 23, 2025
b9a907b
Adding tests for provided timeout to work as expected
Samiul-TheSoccerFan Jul 23, 2025
43eaf0d
simplify inference timeout settings
Samiul-TheSoccerFan Jul 23, 2025
0c80477
[CI] Auto commit changes from spotless
Jul 23, 2025
739b4fa
added better async handling in the test and simplify response
Samiul-TheSoccerFan Jul 24, 2025
e3d029a
revert back ingest timeout and simplify unit tests
Samiul-TheSoccerFan Jul 24, 2025
6c7b1fa
remove redundant code
Samiul-TheSoccerFan Jul 24, 2025
e54c7f6
Merge branch 'main' into inference-timeout-as-cluster-settings
elasticmachine Jul 24, 2025
1bb407c
fix unnecessary instance creation
Samiul-TheSoccerFan Jul 25, 2025
4f3d3ae
Merge branch 'main' into inference-timeout-as-cluster-settings
elasticmachine Jul 25, 2025
aa1240e
Merge branch 'main' into inference-timeout-as-cluster-settings
elasticmachine Jul 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -498,15 +499,10 @@ public static ExecutorBuilder<?> inferenceUtilityExecutor(Settings settings) {

@Override
public List<Setting<?>> getSettings() {
return List.copyOf(getInferenceSettingsInternal());
return List.copyOf(new ArrayList<>(getInferenceSettings()));
}

// only used in tests
public static Set<Setting<?>> getInferenceSettings() {
return Set.copyOf(getInferenceSettingsInternal());
}

private static Set<Setting<?>> getInferenceSettingsInternal() {
Set<Setting<?>> settings = new HashSet<>();
settings.addAll(HttpSettings.getSettingsDefinitions());
settings.addAll(HttpClientManager.getSettingsDefinitions());
Expand All @@ -518,7 +514,7 @@ private static Set<Setting<?>> getInferenceSettingsInternal() {
settings.add(INDICES_INFERENCE_BATCH_SIZE);
settings.add(INFERENCE_QUERY_TIMEOUT);
settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions());
return settings;
return Collections.unmodifiableSet(settings);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
Expand Down Expand Up @@ -1113,10 +1113,8 @@ public static TimeValue resolveInferenceTimeout(@Nullable TimeValue timeout, Inp
if (timeout == null) {
if (inputType == InputType.SEARCH || inputType == InputType.INTERNAL_SEARCH) {
return clusterService.getClusterSettings().get(InferencePlugin.INFERENCE_QUERY_TIMEOUT);
} else if (inputType == InputType.INGEST || inputType == InputType.INTERNAL_INGEST) {
return InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST;
} else {
return InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API;
return InferenceAction.Request.DEFAULT_TIMEOUT;
}
}
return timeout;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
import org.elasticsearch.xpack.inference.InferencePlugin;

Expand Down Expand Up @@ -1317,62 +1317,41 @@ public void testResolveInferenceTimeout_WithProvidedTimeout_ReturnsProvidedTimeo
var clusterService = mockClusterService(Settings.builder().put(InferencePlugin.INFERENCE_QUERY_TIMEOUT.getKey(), "10s").build());
var providedTimeout = TimeValue.timeValueSeconds(45);

InputType[] inputTypes = {
InputType.INGEST,
InputType.INTERNAL_INGEST,
InputType.SEARCH,
InputType.INTERNAL_SEARCH,
InputType.CLASSIFICATION,
InputType.CLUSTERING,
InputType.UNSPECIFIED };

for (InputType inputType : inputTypes) {
for (InputType inputType : InputType.values()) {
var result = ServiceUtils.resolveInferenceTimeout(providedTimeout, inputType, clusterService);
assertEquals("Input type " + inputType + " should return provided timeout", providedTimeout, result);
}
}

public void testResolveInferenceTimeout_WithNullTimeoutAndSearchInputType_ReturnsClusterSetting() {
public void testResolveInferenceTimeout_WithNullTimeout_ReturnsExpectedTimeoutByInputType() {
var configuredTimeout = TimeValue.timeValueSeconds(10);
var clusterService = mockClusterService(
Settings.builder().put(InferencePlugin.INFERENCE_QUERY_TIMEOUT.getKey(), configuredTimeout).build()
);

{
var result = ServiceUtils.resolveInferenceTimeout(null, InputType.SEARCH, clusterService);
assertEquals(configuredTimeout, result);
}
{
var result = ServiceUtils.resolveInferenceTimeout(null, InputType.INTERNAL_SEARCH, clusterService);
assertEquals(configuredTimeout, result);
}
}

public void testResolveInferenceTimeout_WithNullTimeoutAndIngestInputType_ReturnsMaxValue() {
var clusterService = mockClusterService(Settings.builder().put(InferencePlugin.INFERENCE_QUERY_TIMEOUT.getKey(), "10s").build());

{
var result = ServiceUtils.resolveInferenceTimeout(null, InputType.INGEST, clusterService);
assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST, result);
}
{
var result = ServiceUtils.resolveInferenceTimeout(null, InputType.INTERNAL_INGEST, clusterService);
assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST, result);
}
}

public void testResolveInferenceTimeout_WithNullTimeoutAndOtherInputTypes_ReturnsDefaultTimeout() {
var clusterService = mockClusterService(Settings.builder().put(InferencePlugin.INFERENCE_QUERY_TIMEOUT.getKey(), "10s").build());
Map<InputType, TimeValue> expectedTimeouts = Map.of(
InputType.SEARCH,
configuredTimeout,
InputType.INTERNAL_SEARCH,
configuredTimeout,
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
InputType.INTERNAL_INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
InputType.CLASSIFICATION,
InferenceAction.Request.DEFAULT_TIMEOUT,
InputType.CLUSTERING,
InferenceAction.Request.DEFAULT_TIMEOUT,
InputType.UNSPECIFIED,
InferenceAction.Request.DEFAULT_TIMEOUT
);

InputType[] otherTypes = { InputType.CLASSIFICATION, InputType.CLUSTERING, InputType.UNSPECIFIED };
for (Map.Entry<InputType, TimeValue> entry : expectedTimeouts.entrySet()) {
InputType inputType = entry.getKey();
TimeValue expectedTimeout = entry.getValue();

for (InputType inputType : otherTypes) {
var result = ServiceUtils.resolveInferenceTimeout(null, inputType, clusterService);
assertEquals(
"Input type " + inputType + " should return DEFAULT_TIMEOUT",
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API,
result
);
assertEquals("Input type " + inputType + " should return expected timeout", expectedTimeout, result);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1996,15 +1996,11 @@ private Client mockClientForStart(Consumer<ActionListener<CreateTrainedModelAssi

@SuppressWarnings("unchecked")
public void test_nullTimeoutUsesClusterSetting() throws InterruptedException {
var mlTrainedModelResults = new ArrayList<InferenceResults>();
mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true);

Client client = mock(Client.class);
when(client.threadPool()).thenReturn(threadPool);
doAnswer(invocationOnMock -> {
var listener = (ActionListener<InferModelAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(response);
listener.onResponse(null);
return null;
}).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class));

Expand All @@ -2030,19 +2026,12 @@ public void test_nullTimeoutUsesClusterSetting() throws InterruptedException {
null
);

var gotResults = new AtomicBoolean();
var resultsListener = ActionListener.<InferenceServiceResults>wrap(serviceResponse -> {
assertThat(serviceResponse, instanceOf(TextEmbeddingFloatResults.class));
gotResults.set(true);
}, ESTestCase::fail);

var latch = new CountDownLatch(1);
var latchedListener = new LatchedActionListener<>(resultsListener, latch);
var latchedListener = new LatchedActionListener<>(ActionListener.<InferenceServiceResults>noop(), latch);

service.infer(model, null, null, null, List.of("test input"), false, Map.of(), InputType.SEARCH, null, latchedListener);

latch.await();
assertTrue("Listener not called", gotResults.get());
assertTrue(latch.await(30, TimeUnit.SECONDS));

ArgumentCaptor<InferModelAction.Request> requestCaptor = ArgumentCaptor.forClass(InferModelAction.Request.class);
verify(client).execute(same(InferModelAction.INSTANCE), requestCaptor.capture(), any(ActionListener.class));
Expand All @@ -2051,15 +2040,11 @@ public void test_nullTimeoutUsesClusterSetting() throws InterruptedException {

@SuppressWarnings("unchecked")
public void test_providedTimeoutPropagateProperly() throws InterruptedException {
var mlTrainedModelResults = new ArrayList<InferenceResults>();
mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true);

Client client = mock(Client.class);
when(client.threadPool()).thenReturn(threadPool);
doAnswer(invocationOnMock -> {
var listener = (ActionListener<InferModelAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(response);
listener.onResponse(null);
return null;
}).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class));

Expand All @@ -2085,19 +2070,12 @@ public void test_providedTimeoutPropagateProperly() throws InterruptedException
null
);

var gotResults = new AtomicBoolean();
var resultsListener = ActionListener.<InferenceServiceResults>wrap(serviceResponse -> {
assertThat(serviceResponse, instanceOf(TextEmbeddingFloatResults.class));
gotResults.set(true);
}, ESTestCase::fail);

var latch = new CountDownLatch(1);
var latchedListener = new LatchedActionListener<>(resultsListener, latch);
var latchedListener = new LatchedActionListener<>(ActionListener.<InferenceServiceResults>noop(), latch);

service.infer(model, null, null, null, List.of("test input"), false, Map.of(), InputType.SEARCH, providedTimeout, latchedListener);

latch.await();
assertTrue("Listener not called", gotResults.get());
assertTrue(latch.await(30, TimeUnit.SECONDS));

ArgumentCaptor<InferModelAction.Request> requestCaptor = ArgumentCaptor.forClass(InferModelAction.Request.class);
verify(client).execute(same(InferModelAction.INSTANCE), requestCaptor.capture(), any(ActionListener.class));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.LatchedActionListener;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
Expand All @@ -42,6 +44,8 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Stream;

Expand Down Expand Up @@ -80,13 +84,14 @@ public class SageMakerServiceTests extends ESTestCase {
private SageMakerClient client;
private SageMakerSchemas schemas;
private SageMakerService sageMakerService;
private ThreadPool threadPool;

@Before
public void init() {
modelBuilder = mock();
client = mock();
schemas = mock();
ThreadPool threadPool = mock();
threadPool = mock();
when(threadPool.executor(anyString())).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
sageMakerService = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of, mockClusterServiceEmpty());
Expand Down Expand Up @@ -184,7 +189,7 @@ public void testInfer() {
}

@SuppressWarnings("unchecked")
public void test_nullTimeoutUsesClusterSetting() {
public void test_nullTimeoutUsesClusterSetting() throws InterruptedException {
var model = mockModel();
when(schemas.schemaFor(model)).thenReturn(mock());

Expand All @@ -193,22 +198,25 @@ public void test_nullTimeoutUsesClusterSetting() {
Settings.builder().put(InferencePlugin.INFERENCE_QUERY_TIMEOUT.getKey(), configuredTimeout).build()
);

var service = new SageMakerService(modelBuilder, client, schemas, mock(ThreadPool.class), Map::of, clusterService);
var service = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of, clusterService);

var capturedTimeout = new AtomicReference<TimeValue>();
doAnswer(ans -> {
capturedTimeout.set(ans.getArgument(2));
((ActionListener<InvokeEndpointResponse>) ans.getArgument(3)).onResponse(InvokeEndpointResponse.builder().build());
((ActionListener<InvokeEndpointResponse>) ans.getArgument(3)).onResponse(null);
return null;
}).when(client).invoke(any(), any(), any(), any());

service.infer(model, QUERY, null, null, INPUT, false, null, InputType.SEARCH, null, assertNoFailureListener(ignored -> {}));
var latch = new CountDownLatch(1);
var latchedListener = new LatchedActionListener<>(ActionListener.<InferenceServiceResults>noop(), latch);
service.infer(model, QUERY, null, null, INPUT, false, null, InputType.SEARCH, null, latchedListener);

assertTrue(latch.await(30, TimeUnit.SECONDS));
assertEquals(configuredTimeout, capturedTimeout.get());
}

@SuppressWarnings("unchecked")
public void test_providedTimeoutPropagateProperly() {
public void test_providedTimeoutPropagateProperly() throws InterruptedException {
var model = mockModel();
when(schemas.schemaFor(model)).thenReturn(mock());

Expand All @@ -217,28 +225,20 @@ public void test_providedTimeoutPropagateProperly() {
Settings.builder().put(InferencePlugin.INFERENCE_QUERY_TIMEOUT.getKey(), TimeValue.timeValueSeconds(15)).build()
);

var service = new SageMakerService(modelBuilder, client, schemas, mock(ThreadPool.class), Map::of, clusterService);
var service = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of, clusterService);

var capturedTimeout = new AtomicReference<TimeValue>();
doAnswer(ans -> {
capturedTimeout.set(ans.getArgument(2));
((ActionListener<InvokeEndpointResponse>) ans.getArgument(3)).onResponse(InvokeEndpointResponse.builder().build());
((ActionListener<InvokeEndpointResponse>) ans.getArgument(3)).onResponse(null);
return null;
}).when(client).invoke(any(), any(), any(), any());

service.infer(
model,
QUERY,
null,
null,
INPUT,
false,
null,
InputType.SEARCH,
providedTimeout,
assertNoFailureListener(ignored -> {})
);
var latch = new CountDownLatch(1);
var latchedListener = new LatchedActionListener<>(ActionListener.<InferenceServiceResults>noop(), latch);
service.infer(model, QUERY, null, null, INPUT, false, null, InputType.SEARCH, providedTimeout, latchedListener);

assertTrue(latch.await(30, TimeUnit.SECONDS));
assertEquals(providedTimeout, capturedTimeout.get());
}

Expand Down
Loading