Skip to content

Commit cbe1bc0

Browse files
Merge branch 'main' into fix-knn-scroll
2 parents f7514b5 + 7a0f63c commit cbe1bc0

File tree

8 files changed

+112
-75
lines changed

8 files changed

+112
-75
lines changed

docs/changelog/126858.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 126858
2+
summary: Leverage threadpool schedule for inference api to avoid long running thread
3+
area: Machine Learning
4+
type: bug
5+
issues:
6+
- 126853

docs/changelog/126930.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 126930
2+
summary: Adding missing `onFailure` call for Inference API start model request
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,6 @@
5757
*/
5858
public class RequestExecutorService implements RequestExecutor {
5959

60-
/**
61-
* Provides dependency injection mainly for testing
62-
*/
63-
interface Sleeper {
64-
void sleep(TimeValue sleepTime) throws InterruptedException;
65-
}
66-
67-
// default for tests
68-
static final Sleeper DEFAULT_SLEEPER = sleepTime -> sleepTime.timeUnit().sleep(sleepTime.duration());
6960
// default for tests
7061
static final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> DEFAULT_QUEUE_CREATOR =
7162
new AdjustableCapacityBlockingQueue.QueueCreator<>() {
@@ -118,7 +109,6 @@ interface RateLimiterCreator {
118109
private final Clock clock;
119110
private final AtomicBoolean shutdown = new AtomicBoolean(false);
120111
private final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> queueCreator;
121-
private final Sleeper sleeper;
122112
private final RateLimiterCreator rateLimiterCreator;
123113
private final AtomicReference<Scheduler.Cancellable> cancellableCleanupTask = new AtomicReference<>();
124114
private final AtomicBoolean started = new AtomicBoolean(false);
@@ -129,16 +119,7 @@ public RequestExecutorService(
129119
RequestExecutorServiceSettings settings,
130120
RequestSender requestSender
131121
) {
132-
this(
133-
threadPool,
134-
DEFAULT_QUEUE_CREATOR,
135-
startupLatch,
136-
settings,
137-
requestSender,
138-
Clock.systemUTC(),
139-
DEFAULT_SLEEPER,
140-
DEFAULT_RATE_LIMIT_CREATOR
141-
);
122+
this(threadPool, DEFAULT_QUEUE_CREATOR, startupLatch, settings, requestSender, Clock.systemUTC(), DEFAULT_RATE_LIMIT_CREATOR);
142123
}
143124

144125
public RequestExecutorService(
@@ -148,7 +129,6 @@ public RequestExecutorService(
148129
RequestExecutorServiceSettings settings,
149130
RequestSender requestSender,
150131
Clock clock,
151-
Sleeper sleeper,
152132
RateLimiterCreator rateLimiterCreator
153133
) {
154134
this.threadPool = Objects.requireNonNull(threadPool);
@@ -157,7 +137,6 @@ public RequestExecutorService(
157137
this.requestSender = Objects.requireNonNull(requestSender);
158138
this.settings = Objects.requireNonNull(settings);
159139
this.clock = Objects.requireNonNull(clock);
160-
this.sleeper = Objects.requireNonNull(sleeper);
161140
this.rateLimiterCreator = Objects.requireNonNull(rateLimiterCreator);
162141
}
163142

@@ -213,15 +192,10 @@ public void start() {
213192
startCleanupTask();
214193
signalStartInitiated();
215194

216-
while (isShutdown() == false) {
217-
handleTasks();
218-
}
219-
} catch (InterruptedException e) {
220-
Thread.currentThread().interrupt();
221-
} finally {
222-
shutdown();
223-
notifyRequestsOfShutdown();
224-
terminationLatch.countDown();
195+
handleTasks();
196+
} catch (Exception e) {
197+
logger.warn("Failed to start request executor", e);
198+
cleanup();
225199
}
226200
}
227201

@@ -256,13 +230,44 @@ void removeStaleGroupings() {
256230
}
257231
}
258232

259-
private void handleTasks() throws InterruptedException {
260-
var timeToWait = settings.getTaskPollFrequency();
261-
for (var endpoint : rateLimitGroupings.values()) {
262-
timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait);
233+
private void scheduleNextHandleTasks(TimeValue timeToWait) {
234+
if (shutdown.get()) {
235+
logger.debug("Shutdown requested while scheduling next handle task call, cleaning up");
236+
cleanup();
237+
return;
238+
}
239+
240+
threadPool.schedule(this::handleTasks, timeToWait, threadPool.executor(UTILITY_THREAD_POOL_NAME));
241+
}
242+
243+
private void cleanup() {
244+
try {
245+
shutdown();
246+
notifyRequestsOfShutdown();
247+
terminationLatch.countDown();
248+
} catch (Exception e) {
249+
logger.warn("Encountered an error while cleaning up", e);
263250
}
251+
}
264252

265-
sleeper.sleep(timeToWait);
253+
private void handleTasks() {
254+
try {
255+
if (shutdown.get()) {
256+
logger.debug("Shutdown requested while handling tasks, cleaning up");
257+
cleanup();
258+
return;
259+
}
260+
261+
var timeToWait = settings.getTaskPollFrequency();
262+
for (var endpoint : rateLimitGroupings.values()) {
263+
timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait);
264+
}
265+
266+
scheduleNextHandleTasks(timeToWait);
267+
} catch (Exception e) {
268+
logger.warn("Encountered an error while handling tasks", e);
269+
cleanup();
270+
}
266271
}
267272

268273
private void notifyRequestsOfShutdown() {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ public void start(Model model, TimeValue timeout, ActionListener<Boolean> finalL
106106
})
107107
.<Boolean>andThen((l2, modelDidPut) -> {
108108
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
109-
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, finalListener);
109+
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2);
110110
client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
111111
})
112112
.addListener(finalListener);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ public void onFailure(Exception e) {
105105
&& statusException.getRootCause() instanceof ResourceAlreadyExistsException) {
106106
// Deployment is already started
107107
listener.onResponse(Boolean.TRUE);
108+
} else {
109+
listener.onFailure(e);
108110
}
109111
return;
110112
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
5151
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
5252
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
53+
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
5354
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER;
5455
import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata;
5556
import static org.elasticsearch.xpack.inference.services.openai.OpenAiUtils.ORGANIZATION_HEADER;
@@ -90,7 +91,7 @@ public void shutdown() throws IOException, InterruptedException {
9091
}
9192

9293
public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
93-
var senderFactory = createSenderFactory(clientManager, threadRef);
94+
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
9495

9596
try (var sender = createSender(senderFactory)) {
9697
sender.start();

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java

Lines changed: 8 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
import static org.mockito.ArgumentMatchers.any;
5252
import static org.mockito.ArgumentMatchers.anyInt;
5353
import static org.mockito.Mockito.doAnswer;
54-
import static org.mockito.Mockito.doThrow;
5554
import static org.mockito.Mockito.mock;
5655
import static org.mockito.Mockito.times;
5756
import static org.mockito.Mockito.verify;
@@ -206,7 +205,7 @@ public void testExecute_Throws_WhenQueueIsFull() {
206205
assertFalse(thrownException.isExecutorShutdown());
207206
}
208207

209-
public void testTaskThrowsError_CallsOnFailure() {
208+
public void testTaskThrowsError_CallsOnFailure() throws InterruptedException {
210209
var requestSender = mock(RetryingHttpSender.class);
211210

212211
var service = createRequestExecutorService(null, requestSender);
@@ -229,6 +228,8 @@ public void testTaskThrowsError_CallsOnFailure() {
229228
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
230229
assertThat(thrownException.getMessage(), is(format("Failed to send request from inference entity id [%s]", "id")));
231230
assertThat(thrownException.getCause(), instanceOf(IllegalArgumentException.class));
231+
service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
232+
232233
assertTrue(service.isTerminated());
233234
}
234235

@@ -361,7 +362,6 @@ public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws I
361362
createRequestExecutorServiceSettingsEmpty(),
362363
requestSender,
363364
Clock.systemUTC(),
364-
RequestExecutorService.DEFAULT_SLEEPER,
365365
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
366366
);
367367

@@ -375,36 +375,7 @@ public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws I
375375
});
376376
service.start();
377377

378-
assertTrue(service.isTerminated());
379-
}
380-
381-
public void testSleep_ThrowingInterruptedException_TerminatesService() throws Exception {
382-
@SuppressWarnings("unchecked")
383-
BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class);
384-
var sleeper = mock(RequestExecutorService.Sleeper.class);
385-
doThrow(new InterruptedException("failed")).when(sleeper).sleep(any());
386-
387-
var service = new RequestExecutorService(
388-
threadPool,
389-
mockQueueCreator(queue),
390-
null,
391-
createRequestExecutorServiceSettingsEmpty(),
392-
mock(RetryingHttpSender.class),
393-
Clock.systemUTC(),
394-
sleeper,
395-
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
396-
);
397-
398-
Future<?> executorTermination = threadPool.generic().submit(() -> {
399-
try {
400-
service.start();
401-
} catch (Exception e) {
402-
fail(Strings.format("Failed to shutdown executor: %s", e));
403-
}
404-
});
405-
406-
executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
407-
378+
service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
408379
assertTrue(service.isTerminated());
409380
}
410381

@@ -581,7 +552,6 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens() {
581552
settings,
582553
requestSender,
583554
Clock.systemUTC(),
584-
RequestExecutorService.DEFAULT_SLEEPER,
585555
rateLimiterCreator
586556
);
587557
var requestManager = RequestManagerTests.createMock(requestSender);
@@ -614,7 +584,6 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_And
614584
settings,
615585
requestSender,
616586
Clock.systemUTC(),
617-
RequestExecutorService.DEFAULT_SLEEPER,
618587
rateLimiterCreator
619588
);
620589
var requestManager = RequestManagerTests.createMock(requestSender);
@@ -626,11 +595,15 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_And
626595

627596
doAnswer(invocation -> {
628597
service.shutdown();
598+
ActionListener<InferenceServiceResults> passedListener = invocation.getArgument(4);
599+
passedListener.onResponse(null);
600+
629601
return Void.TYPE;
630602
}).when(requestSender).send(any(), any(), any(), any(), any());
631603

632604
service.start();
633605

606+
listener.actionGet(TIMEOUT);
634607
verify(requestSender, times(1)).send(any(), any(), any(), any(), any());
635608
}
636609

@@ -648,7 +621,6 @@ public void testRemovesRateLimitGroup_AfterStaleDuration() {
648621
settings,
649622
requestSender,
650623
clock,
651-
RequestExecutorService.DEFAULT_SLEEPER,
652624
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
653625
);
654626
var requestManager = RequestManagerTests.createMock(requestSender, "id1");
@@ -682,7 +654,6 @@ public void testStartsCleanupThread() {
682654
settings,
683655
requestSender,
684656
Clock.systemUTC(),
685-
RequestExecutorService.DEFAULT_SLEEPER,
686657
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
687658
);
688659

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.elasticsearch.inference.ModelConfigurations;
3838
import org.elasticsearch.inference.SimilarityMeasure;
3939
import org.elasticsearch.inference.TaskType;
40+
import org.elasticsearch.rest.RestStatus;
4041
import org.elasticsearch.test.ESTestCase;
4142
import org.elasticsearch.threadpool.ThreadPool;
4243
import org.elasticsearch.xcontent.ParseField;
@@ -49,13 +50,16 @@
4950
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
5051
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
5152
import org.elasticsearch.xpack.core.ml.MachineLearningField;
53+
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
5254
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
5355
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
5456
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
5557
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
5658
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
59+
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
5760
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
5861
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
62+
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
5963
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
6064
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
6165
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
@@ -1858,6 +1862,49 @@ public void testUpdateWithMlEnabled() throws IOException, InterruptedException {
18581862
}
18591863
}
18601864

1865+
public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException {
1866+
var model = new ElserInternalModel(
1867+
"inference_id",
1868+
TaskType.SPARSE_EMBEDDING,
1869+
"elasticsearch",
1870+
new ElserInternalServiceSettings(
1871+
new ElasticsearchInternalServiceSettings(1, 1, "id", new AdaptiveAllocationsSettings(false, 0, 0), null)
1872+
),
1873+
new ElserMlNodeTaskSettings(),
1874+
null
1875+
);
1876+
1877+
var client = mock(Client.class);
1878+
when(client.threadPool()).thenReturn(threadPool);
1879+
1880+
doAnswer(invocationOnMock -> {
1881+
ActionListener<GetTrainedModelsAction.Response> listener = invocationOnMock.getArgument(2);
1882+
var builder = GetTrainedModelsAction.Response.builder();
1883+
builder.setModels(List.of(mock(TrainedModelConfig.class)));
1884+
builder.setTotalCount(1);
1885+
1886+
listener.onResponse(builder.build());
1887+
return Void.TYPE;
1888+
}).when(client).execute(eq(GetTrainedModelsAction.INSTANCE), any(), any());
1889+
1890+
doAnswer(invocationOnMock -> {
1891+
ActionListener<CreateTrainedModelAssignmentAction.Response> listener = invocationOnMock.getArgument(2);
1892+
listener.onFailure(new ElasticsearchStatusException("failed", RestStatus.GATEWAY_TIMEOUT));
1893+
return Void.TYPE;
1894+
}).when(client).execute(eq(StartTrainedModelDeploymentAction.INSTANCE), any(), any());
1895+
1896+
try (var service = createService(client)) {
1897+
var actionListener = new PlainActionFuture<Boolean>();
1898+
service.start(model, TimeValue.timeValueSeconds(30), actionListener);
1899+
var exception = expectThrows(
1900+
ElasticsearchStatusException.class,
1901+
() -> actionListener.actionGet(TimeValue.timeValueSeconds(30))
1902+
);
1903+
1904+
assertThat(exception.getMessage(), is("failed"));
1905+
}
1906+
}
1907+
18611908
private ElasticsearchInternalService createService(Client client) {
18621909
var cs = mock(ClusterService.class);
18631910
var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));

0 commit comments

Comments
 (0)