Skip to content

Commit f1b48de

Browse files
authored
[ML] fixes inference timeout handling bug that throws unexpected NullPointerException (#87533)
The timeout thread was being scheduled to be executed before all the requisite fields were populated. This would (usually only in extreme circumstances), would throw a null pointer exception that looked like: ``` 2022-06-06T23:40:42,534][ERROR][o.e.b.ElasticsearchUncaughtExceptionHandler] [javaRestTest-2] uncaught exception in thread [elasticsearch[javaRestTest-2][ml_utility][T#4]] java.lang.NullPointerException: Cannot invoke "org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager$ProcessContext.getTimeoutCount()" because "this.processContext" is null at org.elasticsearch.xpack.ml.inference.deployment.AbstractPyTorchAction.onTimeout(AbstractPyTorchAction.java:58) ~[?:?] at org.elasticsearch.common.util.concurrent.ThreadContext$ContextPreservingRunnable.run(ThreadContext.java:709) ~[elasticsearch-8.4.0-SNAPSHOT.jar:?] at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136) ~[?:?] at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635) ~[?:?] at java.lang.Thread.run(Thread.java:833) [?:?] ``` This PR fixes this bug and an additional (more minor) bug where inference call rejections were not accurately counted. closes: #87457
1 parent af9c0da commit f1b48de

File tree

10 files changed

+113
-35
lines changed

10 files changed

+113
-35
lines changed

docs/changelog/87533.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 87533
2+
summary: Fixes inference timeout handling bug that throws unexpected `NullPointerException`
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/AbstractPyTorchAction.java

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,28 @@
1010
import org.apache.logging.log4j.Logger;
1111
import org.elasticsearch.ElasticsearchStatusException;
1212
import org.elasticsearch.action.ActionListener;
13-
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
1413
import org.elasticsearch.core.TimeValue;
1514
import org.elasticsearch.rest.RestStatus;
1615
import org.elasticsearch.threadpool.Scheduler;
1716
import org.elasticsearch.threadpool.ThreadPool;
1817
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1918
import org.elasticsearch.xpack.ml.MachineLearning;
19+
import org.elasticsearch.xpack.ml.job.process.AbstractInitializableRunnable;
2020

2121
import java.util.concurrent.atomic.AtomicBoolean;
2222

2323
import static org.elasticsearch.core.Strings.format;
2424

25-
abstract class AbstractPyTorchAction<T> extends AbstractRunnable {
25+
abstract class AbstractPyTorchAction<T> extends AbstractInitializableRunnable {
2626

2727
private final String modelId;
2828
private final long requestId;
2929
private final TimeValue timeout;
30-
private final Scheduler.Cancellable timeoutHandler;
30+
private Scheduler.Cancellable timeoutHandler;
3131
private final DeploymentManager.ProcessContext processContext;
3232
private final AtomicBoolean notified = new AtomicBoolean();
33-
3433
private final ActionListener<T> listener;
34+
private final ThreadPool threadPool;
3535

3636
protected AbstractPyTorchAction(
3737
String modelId,
@@ -41,16 +41,23 @@ protected AbstractPyTorchAction(
4141
ThreadPool threadPool,
4242
ActionListener<T> listener
4343
) {
44-
this.modelId = modelId;
44+
this.modelId = ExceptionsHelper.requireNonNull(modelId, "modelId");
4545
this.requestId = requestId;
46-
this.timeout = timeout;
47-
this.timeoutHandler = threadPool.schedule(
48-
this::onTimeout,
49-
ExceptionsHelper.requireNonNull(timeout, "timeout"),
50-
MachineLearning.UTILITY_THREAD_POOL_NAME
51-
);
52-
this.processContext = processContext;
53-
this.listener = listener;
46+
this.timeout = ExceptionsHelper.requireNonNull(timeout, "timeout");
47+
this.processContext = ExceptionsHelper.requireNonNull(processContext, "processContext");
48+
this.listener = ExceptionsHelper.requireNonNull(listener, "listener");
49+
this.threadPool = ExceptionsHelper.requireNonNull(threadPool, "threadPool");
50+
}
51+
52+
/**
53+
* Needs to be called after construction. This init starts the timeout handler and needs to be called before added to the executor for
54+
* scheduled work.
55+
*/
56+
@Override
57+
public final void init() {
58+
if (this.timeoutHandler == null) {
59+
this.timeoutHandler = threadPool.schedule(this::onTimeout, timeout, MachineLearning.UTILITY_THREAD_POOL_NAME);
60+
}
5461
}
5562

5663
void onTimeout() {
@@ -66,17 +73,31 @@ void onTimeout() {
6673
}
6774

6875
void onSuccess(T result) {
69-
timeoutHandler.cancel();
76+
if (timeoutHandler != null) {
77+
timeoutHandler.cancel();
78+
} else {
79+
assert false : "init() not called, timeout handler unexpectedly null";
80+
}
7081
if (notified.compareAndSet(false, true)) {
7182
listener.onResponse(result);
7283
return;
7384
}
7485
getLogger().debug("[{}] request [{}] received inference response but listener already notified", modelId, requestId);
7586
}
7687

88+
@Override
89+
public void onRejection(Exception e) {
90+
super.onRejection(e);
91+
processContext.getRejectedExecutionCount().incrementAndGet();
92+
}
93+
7794
@Override
7895
public void onFailure(Exception e) {
79-
timeoutHandler.cancel();
96+
if (timeoutHandler != null) {
97+
timeoutHandler.cancel();
98+
} else {
99+
assert false : "init() not called, timeout handler unexpectedly null";
100+
}
80101
if (notified.compareAndSet(false, true)) {
81102
processContext.getResultProcessor().ignoreResponseWithoutNotifying(String.valueOf(requestId));
82103
listener.onFailure(e);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorService.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77

88
package org.elasticsearch.xpack.ml.inference.pytorch;
99

10-
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
1110
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
1211
import org.elasticsearch.common.util.concurrent.ThreadContext;
1312
import org.elasticsearch.core.SuppressForbidden;
13+
import org.elasticsearch.xpack.ml.job.process.AbstractInitializableRunnable;
1414
import org.elasticsearch.xpack.ml.job.process.AbstractProcessWorkerExecutorService;
1515

1616
import java.util.concurrent.PriorityBlockingQueue;
@@ -35,7 +35,7 @@ public enum RequestPriority {
3535
* A Runnable sorted first by RequestPriority then a tie breaker which in
3636
* most cases will be the insertion order
3737
*/
38-
public static record OrderedRunnable(RequestPriority priority, long tieBreaker, Runnable runnable)
38+
public record OrderedRunnable(RequestPriority priority, long tieBreaker, Runnable runnable)
3939
implements
4040
Comparable<OrderedRunnable>,
4141
Runnable {
@@ -76,7 +76,8 @@ public PriorityProcessWorkerExecutorService(ThreadContext contextHolder, String
7676
* @param priority Request priority
7777
* @param tieBreaker For sorting requests of equal priority
7878
*/
79-
public synchronized void executeWithPriority(AbstractRunnable command, RequestPriority priority, long tieBreaker) {
79+
public synchronized void executeWithPriority(AbstractInitializableRunnable command, RequestPriority priority, long tieBreaker) {
80+
command.init();
8081
if (isShutdown()) {
8182
EsRejectedExecutionException rejected = new EsRejectedExecutionException(processName + " worker service has shutdown", true);
8283
command.onRejection(rejected);
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.ml.job.process;
9+
10+
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
11+
12+
/**
13+
* Abstract runnable that has an `init` function that can be called when passed to an executor
14+
*/
15+
public abstract class AbstractInitializableRunnable extends AbstractRunnable {
16+
public abstract void init();
17+
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorService.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ public ProcessWorkerExecutorService(ThreadContext contextHolder, String processN
3333

3434
@Override
3535
public synchronized void execute(Runnable command) {
36+
if (command instanceof AbstractInitializableRunnable initializableRunnable) {
37+
initializableRunnable.init();
38+
}
3639
if (isShutdown()) {
3740
EsRejectedExecutionException rejected = new EsRejectedExecutionException(processName + " worker service has shutdown", true);
3841
if (command instanceof AbstractRunnable runnable) {

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/ControlMessagePyTorchActionTests.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ public void testRunNotCalledAfterNotified() {
6565
tp,
6666
listener
6767
);
68-
68+
action.init();
6969
action.onTimeout();
7070
action.run();
7171
verify(resultProcessor, times(1)).ignoreResponseWithoutNotifying("1");
@@ -84,6 +84,7 @@ public void testRunNotCalledAfterNotified() {
8484
tp,
8585
listener
8686
);
87+
action.init();
8788

8889
action.onFailure(new IllegalStateException());
8990
action.run();
@@ -122,6 +123,7 @@ public void testDoRun() throws IOException {
122123
tp,
123124
listener
124125
);
126+
action.init();
125127

126128
action.run();
127129

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import org.elasticsearch.action.ActionListener;
1111
import org.elasticsearch.client.internal.Client;
12-
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
1312
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
1413
import org.elasticsearch.core.TimeValue;
1514
import org.elasticsearch.test.ESTestCase;
@@ -31,9 +30,6 @@
3130
import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
3231
import static org.hamcrest.Matchers.equalTo;
3332
import static org.hamcrest.Matchers.instanceOf;
34-
import static org.mockito.ArgumentMatchers.any;
35-
import static org.mockito.ArgumentMatchers.anyLong;
36-
import static org.mockito.Mockito.doThrow;
3733
import static org.mockito.Mockito.mock;
3834
import static org.mockito.Mockito.when;
3935

@@ -74,6 +70,7 @@ public void testRejectedExecution() {
7470
Long taskId = 1L;
7571
when(task.getId()).thenReturn(taskId);
7672
when(task.isStopped()).thenReturn(Boolean.FALSE);
73+
when(task.getModelId()).thenReturn("test-rejected");
7774

7875
DeploymentManager deploymentManager = new DeploymentManager(
7976
mock(Client.class),
@@ -82,9 +79,12 @@ public void testRejectedExecution() {
8279
mock(PyTorchProcessFactory.class)
8380
);
8481

85-
PriorityProcessWorkerExecutorService executorService = mock(PriorityProcessWorkerExecutorService.class);
86-
doThrow(new EsRejectedExecutionException("mock executor rejection")).when(executorService)
87-
.executeWithPriority(any(AbstractRunnable.class), any(), anyLong());
82+
PriorityProcessWorkerExecutorService executorService = new PriorityProcessWorkerExecutorService(
83+
tp.getThreadContext(),
84+
"test reject",
85+
10
86+
);
87+
executorService.shutdown();
8888

8989
AtomicInteger rejectedCount = new AtomicInteger();
9090

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchActionTests.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public void testInferListenerOnlyCalledOnce() {
7575
tp,
7676
listener
7777
);
78-
78+
action.init();
7979
action.onSuccess(new WarningInferenceResults("foo"));
8080
for (int i = 0; i < 10; i++) {
8181
action.onSuccess(new WarningInferenceResults("foo"));
@@ -95,7 +95,7 @@ public void testInferListenerOnlyCalledOnce() {
9595
tp,
9696
listener
9797
);
98-
98+
action.init();
9999
action.onTimeout();
100100
for (int i = 0; i < 10; i++) {
101101
action.onSuccess(new WarningInferenceResults("foo"));
@@ -116,7 +116,7 @@ public void testInferListenerOnlyCalledOnce() {
116116
tp,
117117
listener
118118
);
119-
119+
action.init();
120120
action.onFailure(new Exception("bar"));
121121
for (int i = 0; i < 10; i++) {
122122
action.onSuccess(new WarningInferenceResults("foo"));
@@ -146,7 +146,7 @@ public void testRunNotCalledAfterNotified() {
146146
tp,
147147
listener
148148
);
149-
149+
action.init();
150150
action.onTimeout();
151151
action.run();
152152
verify(resultProcessor, times(1)).ignoreResponseWithoutNotifying("1");
@@ -163,7 +163,7 @@ public void testRunNotCalledAfterNotified() {
163163
tp,
164164
listener
165165
);
166-
166+
action.init();
167167
action.onFailure(new IllegalStateException());
168168
action.run();
169169
verify(resultProcessor, never()).registerRequest(anyString(), any());

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorServiceTests.java

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77

88
package org.elasticsearch.xpack.ml.inference.pytorch;
99

10-
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
1110
import org.elasticsearch.test.ESTestCase;
1211
import org.elasticsearch.threadpool.TestThreadPool;
1312
import org.elasticsearch.threadpool.ThreadPool;
13+
import org.elasticsearch.xpack.ml.job.process.AbstractInitializableRunnable;
1414
import org.junit.After;
1515

1616
import java.util.concurrent.atomic.AtomicInteger;
@@ -39,7 +39,12 @@ public void testQueueCapacityReached() {
3939
var r3 = new RunOrderValidator(3, counter);
4040
executor.executeWithPriority(r3, RequestPriority.NORMAL, 101L);
4141

42+
assertTrue(r1.initialized);
43+
assertTrue(r2.initialized);
44+
assertTrue(r3.initialized);
4245
assertTrue(r3.hasBeenRejected);
46+
assertFalse(r1.hasBeenRejected);
47+
assertFalse(r2.hasBeenRejected);
4348
}
4449

4550
public void testQueueCapacityReached_HighestPriority() {
@@ -56,8 +61,11 @@ public void testQueueCapacityReached_HighestPriority() {
5661
var r5 = new RunOrderValidator(5, counter);
5762
executor.executeWithPriority(r5, RequestPriority.NORMAL, 105L);
5863

64+
assertTrue(r3.initialized);
5965
assertTrue(r3.hasBeenRejected);
66+
assertTrue(highestPriorityAlwaysAccepted.initialized);
6067
assertFalse(highestPriorityAlwaysAccepted.hasBeenRejected);
68+
assertTrue(r5.initialized);
6169
assertTrue(r5.hasBeenRejected);
6270
}
6371

@@ -78,6 +86,10 @@ public void testOrderedRunnables_NormalPriority() {
7886

7987
executor.start();
8088

89+
assertTrue(r1.initialized);
90+
assertTrue(r2.initialized);
91+
assertTrue(r3.initialized);
92+
8193
assertTrue(r1.hasBeenRun);
8294
assertTrue(r2.hasBeenRun);
8395
assertTrue(r3.hasBeenRun);
@@ -114,18 +126,24 @@ private PriorityProcessWorkerExecutorService createProcessWorkerExecutorService(
114126
);
115127
}
116128

117-
private static class RunOrderValidator extends AbstractRunnable {
129+
private static class RunOrderValidator extends AbstractInitializableRunnable {
118130

119131
private boolean hasBeenRun = false;
120132
private boolean hasBeenRejected = false;
121133
private final int expectedOrder;
122134
private final AtomicInteger counter;
135+
private boolean initialized = false;
123136

124137
RunOrderValidator(int expectedOrder, AtomicInteger counter) {
125138
this.expectedOrder = expectedOrder;
126139
this.counter = counter;
127140
}
128141

142+
@Override
143+
public void init() {
144+
initialized = true;
145+
}
146+
129147
@Override
130148
public void onRejection(Exception e) {
131149
hasBeenRejected = true;
@@ -143,7 +161,7 @@ protected void doRun() {
143161
}
144162
}
145163

146-
private static class ShutdownExecutorRunnable extends AbstractRunnable {
164+
private static class ShutdownExecutorRunnable extends AbstractInitializableRunnable {
147165

148166
PriorityProcessWorkerExecutorService executor;
149167

@@ -162,5 +180,9 @@ protected void doRun() {
162180
executor.shutdown();
163181
}
164182

183+
@Override
184+
public void init() {
185+
// do nothing
186+
}
165187
}
166188
}

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorServiceTests.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,19 @@ public void testAutodetectWorkerExecutorService_SubmitAfterShutdown() {
4242
threadPool.generic().execute(executor::start);
4343
executor.shutdown();
4444
AtomicBoolean rejected = new AtomicBoolean(false);
45-
executor.execute(new AbstractRunnable() {
45+
AtomicBoolean initialized = new AtomicBoolean(false);
46+
executor.execute(new AbstractInitializableRunnable() {
4647
@Override
4748
public void onRejection(Exception e) {
4849
assertThat(e, isA(EsRejectedExecutionException.class));
4950
rejected.set(true);
5051
}
5152

53+
@Override
54+
public void init() {
55+
initialized.set(true);
56+
}
57+
5258
@Override
5359
public void onFailure(Exception e) {
5460
fail("onFailure should not be called after the worker is shutdown");
@@ -61,6 +67,7 @@ protected void doRun() throws Exception {
6167
});
6268

6369
assertTrue(rejected.get());
70+
assertTrue(initialized.get());
6471
}
6572

6673
public void testAutodetectWorkerExecutorService_TasksNotExecutedCallHandlerOnShutdown() throws Exception {

0 commit comments

Comments
 (0)