|
7 | 7 |
|
8 | 8 | package org.elasticsearch.xpack.inference.external.http; |
9 | 9 |
|
| 10 | +import org.elasticsearch.common.util.concurrent.EsExecutors; |
10 | 11 | import org.elasticsearch.test.ESTestCase; |
11 | 12 | import org.elasticsearch.threadpool.ThreadPool; |
12 | | -import org.junit.After; |
13 | 13 | import org.junit.Before; |
14 | 14 |
|
15 | | -import java.util.concurrent.CountDownLatch; |
16 | 15 | import java.util.concurrent.atomic.AtomicInteger; |
17 | | -import java.util.concurrent.locks.ReentrantLock; |
| 16 | +import java.util.concurrent.atomic.AtomicReference; |
18 | 17 |
|
19 | 18 | import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; |
20 | | -import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; |
21 | 19 | import static org.hamcrest.Matchers.equalTo; |
22 | | -import static org.mockito.Mockito.spy; |
| 20 | +import static org.mockito.ArgumentMatchers.eq; |
| 21 | +import static org.mockito.Mockito.mock; |
23 | 22 | import static org.mockito.Mockito.times; |
24 | 23 | import static org.mockito.Mockito.verify; |
| 24 | +import static org.mockito.Mockito.verifyNoInteractions; |
| 25 | +import static org.mockito.Mockito.verifyNoMoreInteractions; |
| 26 | +import static org.mockito.Mockito.when; |
25 | 27 |
|
26 | 28 | public class RequestBasedTaskRunnerTests extends ESTestCase { |
27 | 29 | private ThreadPool threadPool; |
28 | 30 |
|
29 | 31 | @Before |
30 | 32 | public void setUp() throws Exception { |
31 | 33 | super.setUp(); |
32 | | - threadPool = spy(createThreadPool(inferenceUtilityPool())); |
| 34 | + threadPool = mock(); |
| 35 | + when(threadPool.executor(UTILITY_THREAD_POOL_NAME)).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE); |
33 | 36 | } |
34 | 37 |
|
35 | | - @After |
36 | | - public void tearDown() throws Exception { |
37 | | - terminate(threadPool); |
38 | | - super.tearDown(); |
39 | | - } |
| 38 | + public void testRequestWhileLoopingWillRerunCommand() { |
| 39 | + var expectedTimesRerun = randomInt(5); |
| 40 | + AtomicInteger counter = new AtomicInteger(0); |
40 | 41 |
|
41 | | - public void testLoopOneAtATime() throws Exception { |
42 | | - // count the number of times the runnable is called |
43 | | - var counter = new AtomicInteger(0); |
44 | | - |
45 | | - // block the runnable and wait for the test thread to take an action |
46 | | - var lock = new ReentrantLock(); |
47 | | - var condition = lock.newCondition(); |
48 | | - Runnable block = () -> { |
49 | | - try { |
50 | | - try { |
51 | | - lock.lock(); |
52 | | - condition.await(); |
53 | | - } finally { |
54 | | - lock.unlock(); |
55 | | - } |
56 | | - } catch (InterruptedException e) { |
57 | | - fail(e, "did not unblock the thread in time, likely during threadpool terminate"); |
58 | | - } |
59 | | - }; |
60 | | - Runnable unblock = () -> { |
61 | | - try { |
62 | | - lock.lock(); |
63 | | - condition.signalAll(); |
64 | | - } finally { |
65 | | - lock.unlock(); |
| 42 | + var requestNextRun = new AtomicReference<Runnable>(); |
| 43 | + Runnable command = () -> { |
| 44 | + if (counter.getAndIncrement() < expectedTimesRerun) { |
| 45 | + requestNextRun.get().run(); |
66 | 46 | } |
67 | 47 | }; |
68 | | - |
69 | | - var runner = new RequestBasedTaskRunner(() -> { |
70 | | - counter.incrementAndGet(); |
71 | | - block.run(); |
72 | | - }, threadPool, UTILITY_THREAD_POOL_NAME); |
73 | | - |
74 | | - // given we have not called requestNextRun, then no thread should have started |
75 | | - assertThat(counter.get(), equalTo(0)); |
76 | | - verify(threadPool, times(0)).executor(UTILITY_THREAD_POOL_NAME); |
77 | | - |
| 48 | + var runner = new RequestBasedTaskRunner(command, threadPool, UTILITY_THREAD_POOL_NAME); |
| 49 | + requestNextRun.set(runner::requestNextRun); |
78 | 50 | runner.requestNextRun(); |
79 | 51 |
|
80 | | - // given that we have called requestNextRun, then 1 thread should run once |
81 | | - assertBusy(() -> { |
82 | | - verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME); |
83 | | - assertThat(counter.get(), equalTo(1)); |
84 | | - }); |
85 | | - |
86 | | - // given that we have called requestNextRun while a thread was running, and the thread was blocked |
87 | | - runner.requestNextRun(); |
88 | | - // then 1 thread should run once |
89 | | - verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME); |
90 | | - assertThat(counter.get(), equalTo(1)); |
| 52 | + verify(threadPool, times(1)).executor(eq(UTILITY_THREAD_POOL_NAME)); |
| 53 | + verifyNoMoreInteractions(threadPool); |
| 54 | + assertThat(counter.get(), equalTo(expectedTimesRerun + 1)); |
| 55 | + } |
91 | 56 |
|
92 | | - // given the thread is unblocked |
93 | | - unblock.run(); |
94 | | - // then 1 thread should run twice |
95 | | - verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME); |
96 | | - assertBusy(() -> assertThat(counter.get(), equalTo(2))); |
| 57 | + public void testRequestWhileNotLoopingWillQueueCommand() { |
| 58 | + AtomicInteger counter = new AtomicInteger(0); |
97 | 59 |
|
98 | | - // given the thread is unblocked again, but there were only two calls to requestNextRun |
99 | | - unblock.run(); |
100 | | - // then 1 thread should run twice |
101 | | - verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME); |
102 | | - assertBusy(() -> assertThat(counter.get(), equalTo(2))); |
| 60 | + var runner = new RequestBasedTaskRunner(counter::incrementAndGet, threadPool, UTILITY_THREAD_POOL_NAME); |
103 | 61 |
|
104 | | - // given no thread is running, when we call requestNextRun |
105 | | - runner.requestNextRun(); |
106 | | - // then a second thread should start for the third run |
107 | | - assertBusy(() -> { |
108 | | - verify(threadPool, times(2)).executor(UTILITY_THREAD_POOL_NAME); |
109 | | - assertThat(counter.get(), equalTo(3)); |
110 | | - }); |
111 | | - |
112 | | - // given the thread is unblocked, then it should exit and rejoin the threadpool |
113 | | - unblock.run(); |
114 | | - assertTrue("Test thread should unblock after all runs complete", terminate(threadPool)); |
115 | | - |
116 | | - // final check - we ran three times on two threads |
117 | | - verify(threadPool, times(2)).executor(UTILITY_THREAD_POOL_NAME); |
118 | | - assertThat(counter.get(), equalTo(3)); |
| 62 | + for (int i = 1; i < randomInt(10); i++) { |
| 63 | + runner.requestNextRun(); |
| 64 | + verify(threadPool, times(i)).executor(eq(UTILITY_THREAD_POOL_NAME)); |
| 65 | + assertThat(counter.get(), equalTo(i)); |
| 66 | + } |
| 67 | + ; |
119 | 68 | } |
120 | 69 |
|
121 | | - public void testCancel() throws Exception { |
122 | | - // count the number of times the runnable is called |
123 | | - var counter = new AtomicInteger(0); |
124 | | - var latch = new CountDownLatch(1); |
125 | | - var runner = new RequestBasedTaskRunner(() -> { |
126 | | - counter.incrementAndGet(); |
127 | | - try { |
128 | | - latch.await(); |
129 | | - } catch (InterruptedException e) { |
130 | | - fail(e, "did not unblock the thread in time, likely during threadpool terminate"); |
131 | | - } |
132 | | - }, threadPool, UTILITY_THREAD_POOL_NAME); |
| 70 | + public void testCancelBeforeRunning() { |
| 71 | + AtomicInteger counter = new AtomicInteger(0); |
133 | 72 |
|
134 | | - // given that we have called requestNextRun, then 1 thread should run once |
| 73 | + var runner = new RequestBasedTaskRunner(counter::incrementAndGet, threadPool, UTILITY_THREAD_POOL_NAME); |
| 74 | + runner.cancel(); |
135 | 75 | runner.requestNextRun(); |
136 | | - assertBusy(() -> { |
137 | | - verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME); |
138 | | - assertThat(counter.get(), equalTo(1)); |
139 | | - }); |
140 | 76 |
|
141 | | - // given that a thread is running, three more calls will be queued |
142 | | - runner.requestNextRun(); |
143 | | - runner.requestNextRun(); |
| 77 | + verifyNoInteractions(threadPool); |
| 78 | + assertThat(counter.get(), equalTo(0)); |
| 79 | + } |
| 80 | + |
| 81 | + public void testCancelWhileRunning() { |
| 82 | + var expectedTimesRerun = randomInt(5); |
| 83 | + AtomicInteger counter = new AtomicInteger(0); |
| 84 | + |
| 85 | + var runnerRef = new AtomicReference<RequestBasedTaskRunner>(); |
| 86 | + Runnable command = () -> { |
| 87 | + if (counter.getAndIncrement() < expectedTimesRerun) { |
| 88 | + runnerRef.get().requestNextRun(); |
| 89 | + } |
| 90 | + runnerRef.get().cancel(); |
| 91 | + }; |
| 92 | + var runner = new RequestBasedTaskRunner(command, threadPool, UTILITY_THREAD_POOL_NAME); |
| 93 | + runnerRef.set(runner); |
144 | 94 | runner.requestNextRun(); |
145 | 95 |
|
146 | | - // when we cancel the thread, then the thread should immediately exit and rejoin |
147 | | - runner.cancel(); |
148 | | - latch.countDown(); |
149 | | - assertTrue("Test thread should unblock after all runs complete", terminate(threadPool)); |
| 96 | + verify(threadPool, times(1)).executor(eq(UTILITY_THREAD_POOL_NAME)); |
| 97 | + verifyNoMoreInteractions(threadPool); |
| 98 | + assertThat(counter.get(), equalTo(1)); |
150 | 99 |
|
151 | | - // given that we called cancel, when we call requestNextRun then no thread should start |
152 | 100 | runner.requestNextRun(); |
153 | | - verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME); |
| 101 | + verifyNoMoreInteractions(threadPool); |
154 | 102 | assertThat(counter.get(), equalTo(1)); |
155 | 103 | } |
156 | 104 |
|
|
0 commit comments