Skip to content

Commit 04b6cf1

Browse files
authored
[ML] Ensure queued AbstractRunnables are notified when executor stops (elastic#135966) (elastic#136122)
AbstractProcessWorkerExecutorService.notifyQueueRunnables() was making an incorrect assumption that all AbstractRunnables that were submitted for execution would be queued as AbstractRunnables. However, PriorityProcessWorkerExecutorService wraps AbstractRunnables in OrderedRunnable before queueing them, and since OrderedRunnable is not an AbstractRunnable, these were skipped when notifyQueueRunnables() drained the queue, leading to potential hangs. - Make OrderedRunnable extend AbstractRunnable and pass onFailure() and onRejection() calls to the AbstractRunnable it wraps - Ensure that notifyQueueRunnables() is called and the executor marked as shut down if an exception is thrown from start() - Add unit tests - Update docs/changelog/135966.yaml Closes elastic#134651
1 parent 52f85dc commit 04b6cf1

File tree

5 files changed

+256
-17
lines changed

5 files changed

+256
-17
lines changed

docs/changelog/135966.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 135966
2+
summary: Ensure queued `AbstractRunnables` are notified when executor stops
3+
area: Machine Learning
4+
type: bug
5+
issues:
6+
- 134651

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorService.java

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77

88
package org.elasticsearch.xpack.ml.inference.pytorch;
99

10+
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
1011
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
1112
import org.elasticsearch.common.util.concurrent.ThreadContext;
13+
import org.elasticsearch.common.util.concurrent.WrappedRunnable;
1214
import org.elasticsearch.core.SuppressForbidden;
1315
import org.elasticsearch.xpack.ml.job.process.AbstractInitializableRunnable;
1416
import org.elasticsearch.xpack.ml.job.process.AbstractProcessWorkerExecutorService;
1517

18+
import java.util.Objects;
1619
import java.util.concurrent.PriorityBlockingQueue;
1720

1821
/**
@@ -32,13 +35,20 @@ public enum RequestPriority {
3235
};
3336

3437
/**
35-
* A Runnable sorted first by RequestPriority then a tie breaker which in
36-
* most cases will be the insertion order
38+
* A wrapper around an {@link AbstractRunnable} which allows it to be sorted first by {@link RequestPriority}, then by a tiebreaker,
39+
* which in most cases will be the insertion order
3740
*/
38-
public record OrderedRunnable(RequestPriority priority, long tieBreaker, Runnable runnable)
39-
implements
40-
Comparable<OrderedRunnable>,
41-
Runnable {
41+
protected static final class OrderedRunnable extends AbstractRunnable implements Comparable<OrderedRunnable>, WrappedRunnable {
42+
private final RequestPriority priority;
43+
private final long tieBreaker;
44+
private final AbstractRunnable runnable;
45+
46+
protected OrderedRunnable(RequestPriority priority, long tieBreaker, AbstractRunnable runnable) {
47+
this.priority = priority;
48+
this.tieBreaker = tieBreaker;
49+
this.runnable = runnable;
50+
}
51+
4252
@Override
4353
public int compareTo(OrderedRunnable o) {
4454
int p = this.priority.compareTo(o.priority);
@@ -50,10 +60,55 @@ public int compareTo(OrderedRunnable o) {
5060
}
5161

5262
@Override
53-
public void run() {
63+
public void onFailure(Exception e) {
64+
runnable.onFailure(e);
65+
}
66+
67+
@Override
68+
public void onRejection(Exception e) {
69+
runnable.onRejection(e);
70+
}
71+
72+
@Override
73+
protected void doRun() throws Exception {
5474
runnable.run();
5575
}
56-
};
76+
77+
@Override
78+
public boolean isForceExecution() {
79+
return runnable.isForceExecution();
80+
}
81+
82+
@Override
83+
public void onAfter() {
84+
runnable.onAfter();
85+
}
86+
87+
@Override
88+
public Runnable unwrap() {
89+
return runnable;
90+
}
91+
92+
@Override
93+
public boolean equals(Object obj) {
94+
if (obj == this) return true;
95+
if (obj == null || obj.getClass() != this.getClass()) return false;
96+
var that = (OrderedRunnable) obj;
97+
return Objects.equals(this.priority, that.priority)
98+
&& this.tieBreaker == that.tieBreaker
99+
&& Objects.equals(this.runnable, that.runnable);
100+
}
101+
102+
@Override
103+
public int hashCode() {
104+
return Objects.hash(priority, tieBreaker, runnable);
105+
}
106+
107+
@Override
108+
public String toString() {
109+
return "OrderedRunnable[" + "priority=" + priority + ", " + "tieBreaker=" + tieBreaker + ", " + "runnable=" + runnable + ']';
110+
}
111+
}
57112

58113
private final int queueCapacity;
59114

@@ -93,7 +148,7 @@ public synchronized void executeWithPriority(AbstractInitializableRunnable comma
93148
}
94149

95150
// PriorityBlockingQueue::offer always returns true
96-
queue.offer(new OrderedRunnable(priority, tieBreaker, contextHolder.preserveContext(command)));
151+
queue.offer(new OrderedRunnable(priority, tieBreaker, (AbstractRunnable) contextHolder.preserveContext(command)));
97152
if (isShutdown()) {
98153
// the worker shutdown during this function
99154
notifyQueueRunnables();

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/AbstractProcessWorkerExecutorService.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,12 @@ public void start() {
125125
running.set(false);
126126
}
127127
}
128-
129-
notifyQueueRunnables();
130128
} catch (InterruptedException e) {
131129
Thread.currentThread().interrupt();
132130
} finally {
131+
// If we're throwing an exception, shutdown() may not have been called, so call it here
132+
shutdown();
133+
notifyQueueRunnables();
133134
Runnable onComplete = onCompletion.get();
134135
if (onComplete != null) {
135136
onComplete.run();
@@ -155,17 +156,18 @@ public synchronized void notifyQueueRunnables() {
155156
format("[%s] notifying [%d] queued requests that have not been processed before shutdown", processName, queue.size())
156157
);
157158

158-
List<Runnable> notExecuted = new ArrayList<>();
159+
List<T> notExecuted = new ArrayList<>();
159160
queue.drainTo(notExecuted);
160161

161-
String msg = "unable to process as " + processName + " worker service has shutdown";
162162
Exception ex = error.get();
163-
for (Runnable runnable : notExecuted) {
164-
if (runnable instanceof AbstractRunnable ar) {
163+
for (T runnable : notExecuted) {
164+
if (runnable instanceof AbstractRunnable abstractRunnable) {
165165
if (ex != null) {
166-
ar.onFailure(ex);
166+
abstractRunnable.onFailure(ex);
167167
} else {
168-
ar.onRejection(new EsRejectedExecutionException(msg, true));
168+
abstractRunnable.onRejection(
169+
new EsRejectedExecutionException("unable to process as " + processName + " worker service has shutdown", true)
170+
);
169171
}
170172
}
171173
}

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorServiceTests.java

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222

2323
import static org.elasticsearch.xpack.ml.inference.pytorch.PriorityProcessWorkerExecutorService.RequestPriority;
2424
import static org.hamcrest.Matchers.equalTo;
25+
import static org.hamcrest.Matchers.is;
2526
import static org.hamcrest.Matchers.lessThan;
27+
import static org.hamcrest.Matchers.not;
2628

2729
public class PriorityProcessWorkerExecutorServiceTests extends ESTestCase {
2830

@@ -177,6 +179,49 @@ public void testOrderedRunnables_MixedPriorities() {
177179
}
178180
}
179181

182+
public void testNotifyQueueRunnables_notifiesAllQueuedRunnables() throws InterruptedException {
183+
notifyQueueRunnables(false);
184+
}
185+
186+
public void testNotifyQueueRunnables_notifiesAllQueuedRunnables_withError() throws InterruptedException {
187+
notifyQueueRunnables(true);
188+
}
189+
190+
private void notifyQueueRunnables(boolean withError) {
191+
int queueSize = 10;
192+
var executor = createProcessWorkerExecutorService(queueSize);
193+
194+
List<QueueDrainingRunnable> runnables = new ArrayList<>(queueSize);
195+
// First fill the queue
196+
for (int i = 0; i < queueSize; ++i) {
197+
QueueDrainingRunnable runnable = new QueueDrainingRunnable();
198+
runnables.add(runnable);
199+
executor.executeWithPriority(runnable, RequestPriority.NORMAL, i);
200+
}
201+
202+
assertThat(executor.queueSize(), is(queueSize));
203+
204+
// Set the executor to be stopped
205+
if (withError) {
206+
executor.shutdownNowWithError(new Exception());
207+
} else {
208+
executor.shutdownNow();
209+
}
210+
211+
// Start the executor, which will cause notifyQueueRunnables() to be called immediately since the executor is already stopped
212+
executor.start();
213+
214+
// Confirm that all the runnables were notified
215+
for (QueueDrainingRunnable runnable : runnables) {
216+
assertThat(runnable.initialized, is(true));
217+
assertThat(runnable.hasBeenRun, is(false));
218+
assertThat(runnable.hasBeenRejected, not(withError));
219+
assertThat(runnable.hasBeenFailed, is(withError));
220+
}
221+
222+
assertThat(executor.queueSize(), is(0));
223+
}
224+
180225
private PriorityProcessWorkerExecutorService createProcessWorkerExecutorService(int queueSize) {
181226
return new PriorityProcessWorkerExecutorService(
182227
threadPool.getThreadContext(),
@@ -244,4 +289,32 @@ public void init() {
244289
// do nothing
245290
}
246291
}
292+
293+
private static class QueueDrainingRunnable extends AbstractInitializableRunnable {
294+
295+
private boolean initialized = false;
296+
private boolean hasBeenRun = false;
297+
private boolean hasBeenRejected = false;
298+
private boolean hasBeenFailed = false;
299+
300+
@Override
301+
public void init() {
302+
initialized = true;
303+
}
304+
305+
@Override
306+
public void onRejection(Exception e) {
307+
hasBeenRejected = true;
308+
}
309+
310+
@Override
311+
public void onFailure(Exception e) {
312+
hasBeenFailed = true;
313+
}
314+
315+
@Override
316+
protected void doRun() {
317+
hasBeenRun = true;
318+
}
319+
}
247320
}

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorServiceTests.java

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,20 @@
1414
import org.elasticsearch.threadpool.ThreadPool;
1515
import org.junit.After;
1616

17+
import java.util.ArrayList;
18+
import java.util.List;
1719
import java.util.concurrent.ConcurrentLinkedQueue;
1820
import java.util.concurrent.CountDownLatch;
1921
import java.util.concurrent.Future;
22+
import java.util.concurrent.FutureTask;
2023
import java.util.concurrent.atomic.AtomicBoolean;
2124
import java.util.concurrent.atomic.AtomicInteger;
2225

2326
import static org.hamcrest.Matchers.containsString;
2427
import static org.hamcrest.Matchers.hasSize;
28+
import static org.hamcrest.Matchers.is;
2529
import static org.hamcrest.Matchers.isA;
30+
import static org.hamcrest.Matchers.not;
2631

2732
public class ProcessWorkerExecutorServiceTests extends ESTestCase {
2833

@@ -137,7 +142,105 @@ public void testAutodetectWorkerExecutorServiceDoesNotSwallowErrors() {
137142
assertThat(e.getMessage(), containsString("future error"));
138143
}
139144

145+
public void testNotifyQueueRunnables_notifiesAllQueuedAbstractRunnables() throws InterruptedException {
146+
notifyQueueRunnables(false);
147+
}
148+
149+
public void testNotifyQueueRunnables_notifiesAllQueuedAbstractRunnables_withError() throws InterruptedException {
150+
notifyQueueRunnables(true);
151+
}
152+
153+
private void notifyQueueRunnables(boolean withError) {
154+
int entries = 10;
155+
var executor = createExecutorService();
156+
157+
List<QueueDrainingRunnable> abstractRunnables = new ArrayList<>();
158+
// First fill the queue with both AbstractRunnable and Runnable
159+
for (int i = 0; i < entries; ++i) {
160+
QueueDrainingRunnable abstractRunnable = new QueueDrainingRunnable();
161+
abstractRunnables.add(abstractRunnable);
162+
executor.execute(abstractRunnable);
163+
Runnable runnable = () -> fail("Should not be invoked");
164+
executor.execute(runnable);
165+
}
166+
167+
assertThat(executor.queueSize(), is(entries * 2));
168+
169+
// Set the executor to be stopped
170+
if (withError) {
171+
executor.shutdownNowWithError(new Exception());
172+
} else {
173+
executor.shutdownNow();
174+
}
175+
176+
// Start the executor, which will cause notifyQueueRunnables() to be called immediately since the executor is already stopped
177+
executor.start();
178+
179+
// Confirm that all the abstract runnables were notified
180+
for (QueueDrainingRunnable runnable : abstractRunnables) {
181+
assertThat(runnable.initialized, is(true));
182+
assertThat(runnable.hasBeenRun, is(false));
183+
assertThat(runnable.hasBeenRejected, not(withError));
184+
assertThat(runnable.hasBeenFailed, is(withError));
185+
}
186+
187+
assertThat(executor.queueSize(), is(0));
188+
}
189+
190+
public void testQueuedAbstractRunnablesAreNotified_whenRunnableFutureEncountersError() {
191+
var executor = createExecutorService();
192+
193+
// First queue a RunnableFuture that will stop the executor then throw an Exception wrapping an error when it completes
194+
Error expectedError = new Error("Expected");
195+
FutureTask<Void> runnableFuture = new FutureTask<>(() -> { throw new Exception(expectedError); });
196+
executor.execute(runnableFuture);
197+
198+
// Then queue an AbstractRunnable that should be notified if it's still in the queue when the start() method returns
199+
QueueDrainingRunnable abstractRunnable = new QueueDrainingRunnable();
200+
executor.execute(abstractRunnable);
201+
202+
// Start the executor and expect the error to be thrown and the executor to be marked as shut down
203+
Error error = assertThrows(Error.class, executor::start);
204+
assertThat(error, is(expectedError));
205+
assertThat(executor.isShutdown(), is(true));
206+
assertThat(executor.isTerminated(), is(true));
207+
208+
// Confirm that all the abstract runnable was notified
209+
assertThat(abstractRunnable.initialized, is(true));
210+
assertThat(abstractRunnable.hasBeenRun, is(false));
211+
assertThat(abstractRunnable.hasBeenRejected, is(true));
212+
assertThat(abstractRunnable.hasBeenFailed, is(false));
213+
}
214+
140215
private ProcessWorkerExecutorService createExecutorService() {
141216
return new ProcessWorkerExecutorService(threadPool.getThreadContext(), TEST_PROCESS, QUEUE_SIZE);
142217
}
218+
219+
private static class QueueDrainingRunnable extends AbstractInitializableRunnable {
220+
221+
private boolean initialized = false;
222+
private boolean hasBeenRun = false;
223+
private boolean hasBeenRejected = false;
224+
private boolean hasBeenFailed = false;
225+
226+
@Override
227+
public void init() {
228+
initialized = true;
229+
}
230+
231+
@Override
232+
public void onRejection(Exception e) {
233+
hasBeenRejected = true;
234+
}
235+
236+
@Override
237+
public void onFailure(Exception e) {
238+
hasBeenFailed = true;
239+
}
240+
241+
@Override
242+
protected void doRun() {
243+
hasBeenRun = true;
244+
}
245+
}
143246
}

0 commit comments

Comments
 (0)