Skip to content

Commit 8d40454

Browse files
Refactoring async start
1 parent ef179b3 commit 8d40454

File tree

2 files changed

+105
-42
lines changed

2 files changed

+105
-42
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
12-
import org.elasticsearch.ExceptionsHelper;
1312
import org.elasticsearch.action.ActionListener;
1413
import org.elasticsearch.action.support.ContextPreservingActionListener;
1514
import org.elasticsearch.cluster.service.ClusterService;
@@ -44,6 +43,7 @@ public class HttpRequestSender implements Sender {
4443
*/
4544
public static class Factory {
4645
private final HttpRequestSender httpRequestSender;
46+
private static final TimeValue START_COMPLETED_WAIT_TIME = TimeValue.timeValueSeconds(5);
4747

4848
public Factory(ServiceComponents serviceComponents, HttpClientManager httpClientManager, ClusterService clusterService) {
4949
Objects.requireNonNull(serviceComponents);
@@ -70,7 +70,8 @@ public Factory(ServiceComponents serviceComponents, HttpClientManager httpClient
7070
httpClientManager,
7171
requestSender,
7272
service,
73-
startCompleted
73+
startCompleted,
74+
START_COMPLETED_WAIT_TIME
7475
);
7576
}
7677

@@ -80,27 +81,31 @@ public Sender createSender() {
8081
}
8182

8283
private static final Logger logger = LogManager.getLogger(HttpRequestSender.class);
83-
private static final TimeValue START_COMPLETED_WAIT_TIME = TimeValue.timeValueSeconds(5);
8484

8585
private final ThreadPool threadPool;
8686
private final HttpClientManager manager;
8787
private final AtomicBoolean started = new AtomicBoolean(false);
88+
private final AtomicBoolean startedCompleted = new AtomicBoolean(false);
8889
private final RequestSender requestSender;
8990
private final RequestExecutor service;
90-
private final CountDownLatch startCompleted;
91+
private final CountDownLatch startCompletedLatch;
92+
private final TimeValue startCompletedWaitTime;
9193

92-
private HttpRequestSender(
94+
// Visible for testing
95+
protected HttpRequestSender(
9396
ThreadPool threadPool,
9497
HttpClientManager httpClientManager,
9598
RequestSender requestSender,
9699
RequestExecutor service,
97-
CountDownLatch startCompleted
100+
CountDownLatch startCompletedLatch,
101+
TimeValue startCompletedWaitTime
98102
) {
99103
this.threadPool = Objects.requireNonNull(threadPool);
100104
this.manager = Objects.requireNonNull(httpClientManager);
101105
this.requestSender = Objects.requireNonNull(requestSender);
102106
this.service = Objects.requireNonNull(service);
103-
this.startCompleted = Objects.requireNonNull(startCompleted);
107+
this.startCompletedLatch = Objects.requireNonNull(startCompletedLatch);
108+
this.startCompletedWaitTime = Objects.requireNonNull(startCompletedWaitTime);
104109
}
105110

106111
/**
@@ -109,7 +114,12 @@ private HttpRequestSender(
109114
@Override
110115
public void startAsynchronously(ActionListener<Void> listener) {
111116
if (started.compareAndSet(false, true)) {
112-
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> startInternal(listener));
117+
var preservedListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext());
118+
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> startInternal(preservedListener));
119+
} else if (startedCompleted.get() == false) {
120+
var preservedListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext());
121+
// wait on another thread so we don't potential block a transport thread
122+
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> waitForStartToCompleteWithListener(preservedListener));
113123
} else {
114124
listener.onResponse(null);
115125
}
@@ -122,12 +132,22 @@ private void startInternal(ActionListener<Void> listener) {
122132
manager.start();
123133
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(service::start);
124134
waitForStartToComplete();
135+
startedCompleted.set(true);
125136
listener.onResponse(null);
126137
} catch (Exception ex) {
127138
listener.onFailure(ex);
128139
}
129140
}
130141

142+
private void waitForStartToCompleteWithListener(ActionListener<Void> listener) {
143+
try {
144+
waitForStartToComplete();
145+
listener.onResponse(null);
146+
} catch (Exception e) {
147+
listener.onFailure(e);
148+
}
149+
}
150+
131151
/**
132152
* Start various internal services. This is required before sending requests.
133153
*
@@ -136,10 +156,10 @@ private void startInternal(ActionListener<Void> listener) {
136156
@Override
137157
public void startSynchronously() {
138158
if (started.compareAndSet(false, true)) {
139-
ActionListener<Void> listener = ActionListener.wrap(unused -> {}, exception -> {
140-
logger.error("Http sender failed to start", exception);
141-
ExceptionsHelper.maybeDieOnAnotherThread(exception);
142-
});
159+
ActionListener<Void> listener = ActionListener.wrap(
160+
unused -> {},
161+
exception -> logger.error("Http sender failed to start", exception)
162+
);
143163
startInternal(listener);
144164
}
145165
// Handle the case where start*() was already called and this would return immediately because the started flag is already true
@@ -148,7 +168,7 @@ public void startSynchronously() {
148168

149169
private void waitForStartToComplete() {
150170
try {
151-
if (startCompleted.await(START_COMPLETED_WAIT_TIME.getSeconds(), TimeUnit.SECONDS) == false) {
171+
if (startCompletedLatch.await(startCompletedWaitTime.getMillis(), TimeUnit.MILLISECONDS) == false) {
152172
throw new IllegalStateException("Http sender startup did not complete in time");
153173
}
154174
} catch (InterruptedException e) {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java

Lines changed: 72 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import org.elasticsearch.xcontent.XContentType;
2525
import org.elasticsearch.xpack.inference.external.http.HttpClient;
2626
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
27+
import org.elasticsearch.xpack.inference.external.http.RequestExecutor;
28+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
2729
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
2830
import org.elasticsearch.xpack.inference.external.request.Request;
2931
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
@@ -40,6 +42,7 @@
4042
import java.util.EnumSet;
4143
import java.util.List;
4244
import java.util.Locale;
45+
import java.util.concurrent.CountDownLatch;
4346
import java.util.concurrent.ExecutorService;
4447
import java.util.concurrent.TimeUnit;
4548
import java.util.concurrent.atomic.AtomicReference;
@@ -64,6 +67,8 @@
6467
import static org.mockito.Mockito.doAnswer;
6568
import static org.mockito.Mockito.doThrow;
6669
import static org.mockito.Mockito.mock;
70+
import static org.mockito.Mockito.times;
71+
import static org.mockito.Mockito.verify;
6772
import static org.mockito.Mockito.when;
6873

6974
public class HttpRequestSenderTests extends ESTestCase {
@@ -104,18 +109,28 @@ public void testCreateSender_ReturnsSameRequestExecutorInstance() {
104109
}
105110

106111
public void testCreateSender_CanCallStartMultipleTimes() throws Exception {
107-
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
112+
var mockManager = createMockHttpClientManager();
113+
114+
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), mockManager, mockClusterServiceEmpty());
108115

109116
try (var sender = createSender(senderFactory)) {
110117
sender.startSynchronously();
111118
sender.startSynchronously();
112119
sender.startSynchronously();
113120
}
121+
122+
verify(mockManager, times(1)).start();
114123
}
115124

116-
public void testStart_ThrowsException_WhenAnErrorOccurs() throws IOException {
125+
private HttpClientManager createMockHttpClientManager() {
117126
var mockManager = mock(HttpClientManager.class);
118127
when(mockManager.getHttpClient()).thenReturn(mock(HttpClient.class));
128+
129+
return mockManager;
130+
}
131+
132+
public void testStart_ThrowsExceptionWaitingForStartToComplete_WhenAnErrorOccurs() throws IOException {
133+
var mockManager = createMockHttpClientManager();
119134
doThrow(new Error("failed")).when(mockManager).start();
120135

121136
var senderFactory = new HttpRequestSender.Factory(
@@ -125,41 +140,66 @@ public void testStart_ThrowsException_WhenAnErrorOccurs() throws IOException {
125140
);
126141

127142
try (var sender = senderFactory.createSender()) {
128-
// Checking for both exception types because there's a race condition between the Error being thrown on a separate thread
129-
// and the startCompleted latch timing out waiting for the start to complete
130-
var exception = expectThrowsAnyOf(List.of(Error.class, IllegalStateException.class), sender::startSynchronously);
131-
132-
if (exception instanceof Error) {
133-
assertThat(exception.getMessage(), is("failed"));
134-
} else {
135-
// IllegalStateException can be thrown if the startCompleted latch times out waiting for the start to complete
136-
assertThat(exception.getMessage(), is("Http sender startup did not complete in time"));
137-
}
143+
var exception = expectThrows(Error.class, sender::startSynchronously);
144+
145+
assertThat(exception.getMessage(), is("failed"));
138146
}
139147
}
140148

141-
public void testStart_ThrowsExceptionWaitingForStartToComplete() throws IOException {
142-
var mockManager = mock(HttpClientManager.class);
143-
when(mockManager.getHttpClient()).thenReturn(mock(HttpClient.class));
144-
// This won't get rethrown because it is not an Error
149+
public void testStart_ThrowsExceptionWaitingForStartToComplete() {
150+
var mockManager = createMockHttpClientManager();
145151
doThrow(new IllegalArgumentException("failed")).when(mockManager).start();
146152

147-
var senderFactory = new HttpRequestSender.Factory(
148-
ServiceComponentsTests.createWithEmptySettings(threadPool),
153+
// Force the startup to never complete
154+
var latch = new CountDownLatch(1);
155+
var sender = new HttpRequestSender(
156+
threadPool,
149157
mockManager,
150-
mockClusterServiceEmpty()
158+
mock(RequestSender.class),
159+
mock(RequestExecutor.class),
160+
latch,
161+
// Override the wait time so we don't block the test for too long
162+
TimeValue.timeValueMillis(1)
151163
);
152164

153-
try (var sender = senderFactory.createSender()) {
154-
var exception = expectThrows(IllegalStateException.class, sender::startSynchronously);
165+
var exception = expectThrows(IllegalStateException.class, sender::startSynchronously);
155166

156-
assertThat(exception.getMessage(), is("Http sender startup did not complete in time"));
157-
}
167+
assertThat(exception.getMessage(), is("Http sender startup did not complete in time"));
168+
}
169+
170+
public void testStartAsync_WaitsAsyncForStartToComplete_ThrowsWhenItTimesOut_ThenSucceeds() {
171+
var mockManager = createMockHttpClientManager();
172+
var latch = new CountDownLatch(1);
173+
var sender = new HttpRequestSender(
174+
threadPool,
175+
mockManager,
176+
mock(RequestSender.class),
177+
mock(RequestExecutor.class),
178+
latch,
179+
// Override the wait time so we don't block the test for too long
180+
TimeValue.timeValueMillis(1)
181+
);
182+
183+
var listener = new PlainActionFuture<Void>();
184+
sender.startAsynchronously(listener);
185+
186+
var exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT));
187+
assertThat(exception.getMessage(), is("Http sender startup did not complete in time"));
188+
189+
// simulate the start completing
190+
latch.countDown();
191+
192+
var listenerCompleted = new PlainActionFuture<Void>();
193+
sender.startAsynchronously(listenerCompleted);
194+
assertNull(listenerCompleted.actionGet(TIMEOUT));
195+
196+
verify(mockManager, times(1)).start();
158197
}
159198

160199
public void testCreateSender_CanCallStartAsyncMultipleTimes() throws Exception {
200+
var mockManager = createMockHttpClientManager();
161201
var asyncCalls = 3;
162-
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
202+
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), mockManager, mockClusterServiceEmpty());
163203

164204
try (var sender = createSender(senderFactory)) {
165205
var listenerList = new ArrayList<PlainActionFuture<Void>>();
@@ -175,11 +215,14 @@ public void testCreateSender_CanCallStartAsyncMultipleTimes() throws Exception {
175215
assertNull(listener.actionGet(TIMEOUT));
176216
}
177217
}
218+
219+
verify(mockManager, times(1)).start();
178220
}
179221

180222
public void testCreateSender_CanCallStartAsyncAndSyncMultipleTimes() throws Exception {
223+
var mockManager = createMockHttpClientManager();
181224
var asyncCalls = 3;
182-
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
225+
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), mockManager, mockClusterServiceEmpty());
183226

184227
try (var sender = createSender(senderFactory)) {
185228
var listenerList = new ArrayList<PlainActionFuture<Void>>();
@@ -196,6 +239,8 @@ public void testCreateSender_CanCallStartAsyncAndSyncMultipleTimes() throws Exce
196239
assertNull(listener.actionGet(TIMEOUT));
197240
}
198241
}
242+
243+
verify(mockManager, times(1)).start();
199244
}
200245

201246
public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
@@ -340,8 +385,7 @@ public void testHttpRequestSender_Throws_WhenATimeoutOccurs() throws Exception {
340385
}
341386

342387
public void testHttpRequestSenderWithTimeout_Throws_WhenATimeoutOccurs() throws Exception {
343-
var mockManager = mock(HttpClientManager.class);
344-
when(mockManager.getHttpClient()).thenReturn(mock(HttpClient.class));
388+
var mockManager = createMockHttpClientManager();
345389

346390
var senderFactory = new HttpRequestSender.Factory(
347391
ServiceComponentsTests.createWithEmptySettings(threadPool),
@@ -363,8 +407,7 @@ public void testHttpRequestSenderWithTimeout_Throws_WhenATimeoutOccurs() throws
363407
}
364408

365409
public void testSendWithoutQueuingWithTimeout_Throws_WhenATimeoutOccurs() throws Exception {
366-
var mockManager = mock(HttpClientManager.class);
367-
when(mockManager.getHttpClient()).thenReturn(mock(HttpClient.class));
410+
var mockManager = createMockHttpClientManager();
368411

369412
var senderFactory = new HttpRequestSender.Factory(
370413
ServiceComponentsTests.createWithEmptySettings(threadPool),

0 commit comments

Comments
 (0)