2424import org .elasticsearch .xcontent .XContentType ;
2525import org .elasticsearch .xpack .inference .external .http .HttpClient ;
2626import org .elasticsearch .xpack .inference .external .http .HttpClientManager ;
27+ import org .elasticsearch .xpack .inference .external .http .RequestExecutor ;
28+ import org .elasticsearch .xpack .inference .external .http .retry .RequestSender ;
2729import org .elasticsearch .xpack .inference .external .http .retry .ResponseHandler ;
2830import org .elasticsearch .xpack .inference .external .request .Request ;
2931import org .elasticsearch .xpack .inference .logging .ThrottlerManager ;
4042import java .util .EnumSet ;
4143import java .util .List ;
4244import java .util .Locale ;
45+ import java .util .concurrent .CountDownLatch ;
4346import java .util .concurrent .ExecutorService ;
4447import java .util .concurrent .TimeUnit ;
4548import java .util .concurrent .atomic .AtomicReference ;
6467import static org .mockito .Mockito .doAnswer ;
6568import static org .mockito .Mockito .doThrow ;
6669import static org .mockito .Mockito .mock ;
70+ import static org .mockito .Mockito .times ;
71+ import static org .mockito .Mockito .verify ;
6772import static org .mockito .Mockito .when ;
6873
6974public class HttpRequestSenderTests extends ESTestCase {
@@ -104,18 +109,28 @@ public void testCreateSender_ReturnsSameRequestExecutorInstance() {
104109 }
105110
106111 public void testCreateSender_CanCallStartMultipleTimes () throws Exception {
107- var senderFactory = new HttpRequestSender .Factory (createWithEmptySettings (threadPool ), clientManager , mockClusterServiceEmpty ());
112+ var mockManager = createMockHttpClientManager ();
113+
114+ var senderFactory = new HttpRequestSender .Factory (createWithEmptySettings (threadPool ), mockManager , mockClusterServiceEmpty ());
108115
109116 try (var sender = createSender (senderFactory )) {
110117 sender .startSynchronously ();
111118 sender .startSynchronously ();
112119 sender .startSynchronously ();
113120 }
121+
122+ verify (mockManager , times (1 )).start ();
114123 }
115124
116- public void testStart_ThrowsException_WhenAnErrorOccurs () throws IOException {
125+ private HttpClientManager createMockHttpClientManager () {
117126 var mockManager = mock (HttpClientManager .class );
118127 when (mockManager .getHttpClient ()).thenReturn (mock (HttpClient .class ));
128+
129+ return mockManager ;
130+ }
131+
132+ public void testStart_ThrowsExceptionWaitingForStartToComplete_WhenAnErrorOccurs () throws IOException {
133+ var mockManager = createMockHttpClientManager ();
119134 doThrow (new Error ("failed" )).when (mockManager ).start ();
120135
121136 var senderFactory = new HttpRequestSender .Factory (
@@ -125,41 +140,66 @@ public void testStart_ThrowsException_WhenAnErrorOccurs() throws IOException {
125140 );
126141
127142 try (var sender = senderFactory .createSender ()) {
128- // Checking for both exception types because there's a race condition between the Error being thrown on a separate thread
129- // and the startCompleted latch timing out waiting for the start to complete
130- var exception = expectThrowsAnyOf (List .of (Error .class , IllegalStateException .class ), sender ::startSynchronously );
131-
132- if (exception instanceof Error ) {
133- assertThat (exception .getMessage (), is ("failed" ));
134- } else {
135- // IllegalStateException can be thrown if the startCompleted latch times out waiting for the start to complete
136- assertThat (exception .getMessage (), is ("Http sender startup did not complete in time" ));
137- }
143+ var exception = expectThrows (Error .class , sender ::startSynchronously );
144+
145+ assertThat (exception .getMessage (), is ("failed" ));
138146 }
139147 }
140148
141- public void testStart_ThrowsExceptionWaitingForStartToComplete () throws IOException {
142- var mockManager = mock (HttpClientManager .class );
143- when (mockManager .getHttpClient ()).thenReturn (mock (HttpClient .class ));
144- // This won't get rethrown because it is not an Error
149+ public void testStart_ThrowsExceptionWaitingForStartToComplete () {
150+ var mockManager = createMockHttpClientManager ();
145151 doThrow (new IllegalArgumentException ("failed" )).when (mockManager ).start ();
146152
147- var senderFactory = new HttpRequestSender .Factory (
148- ServiceComponentsTests .createWithEmptySettings (threadPool ),
153+ // Force the startup to never complete
154+ var latch = new CountDownLatch (1 );
155+ var sender = new HttpRequestSender (
156+ threadPool ,
149157 mockManager ,
150- mockClusterServiceEmpty ()
158+ mock (RequestSender .class ),
159+ mock (RequestExecutor .class ),
160+ latch ,
161+ // Override the wait time so we don't block the test for too long
162+ TimeValue .timeValueMillis (1 )
151163 );
152164
153- try (var sender = senderFactory .createSender ()) {
154- var exception = expectThrows (IllegalStateException .class , sender ::startSynchronously );
165+ var exception = expectThrows (IllegalStateException .class , sender ::startSynchronously );
155166
156- assertThat (exception .getMessage (), is ("Http sender startup did not complete in time" ));
157- }
167+ assertThat (exception .getMessage (), is ("Http sender startup did not complete in time" ));
168+ }
169+
170+ public void testStartAsync_WaitsAsyncForStartToComplete_ThrowsWhenItTimesOut_ThenSucceeds () {
171+ var mockManager = createMockHttpClientManager ();
172+ var latch = new CountDownLatch (1 );
173+ var sender = new HttpRequestSender (
174+ threadPool ,
175+ mockManager ,
176+ mock (RequestSender .class ),
177+ mock (RequestExecutor .class ),
178+ latch ,
179+ // Override the wait time so we don't block the test for too long
180+ TimeValue .timeValueMillis (1 )
181+ );
182+
183+ var listener = new PlainActionFuture <Void >();
184+ sender .startAsynchronously (listener );
185+
186+ var exception = expectThrows (IllegalStateException .class , () -> listener .actionGet (TIMEOUT ));
187+ assertThat (exception .getMessage (), is ("Http sender startup did not complete in time" ));
188+
189+ // simulate the start completing
190+ latch .countDown ();
191+
192+ var listenerCompleted = new PlainActionFuture <Void >();
193+ sender .startAsynchronously (listenerCompleted );
194+ assertNull (listenerCompleted .actionGet (TIMEOUT ));
195+
196+ verify (mockManager , times (1 )).start ();
158197 }
159198
160199 public void testCreateSender_CanCallStartAsyncMultipleTimes () throws Exception {
200+ var mockManager = createMockHttpClientManager ();
161201 var asyncCalls = 3 ;
162- var senderFactory = new HttpRequestSender .Factory (createWithEmptySettings (threadPool ), clientManager , mockClusterServiceEmpty ());
202+ var senderFactory = new HttpRequestSender .Factory (createWithEmptySettings (threadPool ), mockManager , mockClusterServiceEmpty ());
163203
164204 try (var sender = createSender (senderFactory )) {
165205 var listenerList = new ArrayList <PlainActionFuture <Void >>();
@@ -175,11 +215,14 @@ public void testCreateSender_CanCallStartAsyncMultipleTimes() throws Exception {
175215 assertNull (listener .actionGet (TIMEOUT ));
176216 }
177217 }
218+
219+ verify (mockManager , times (1 )).start ();
178220 }
179221
180222 public void testCreateSender_CanCallStartAsyncAndSyncMultipleTimes () throws Exception {
223+ var mockManager = createMockHttpClientManager ();
181224 var asyncCalls = 3 ;
182- var senderFactory = new HttpRequestSender .Factory (createWithEmptySettings (threadPool ), clientManager , mockClusterServiceEmpty ());
225+ var senderFactory = new HttpRequestSender .Factory (createWithEmptySettings (threadPool ), mockManager , mockClusterServiceEmpty ());
183226
184227 try (var sender = createSender (senderFactory )) {
185228 var listenerList = new ArrayList <PlainActionFuture <Void >>();
@@ -196,6 +239,8 @@ public void testCreateSender_CanCallStartAsyncAndSyncMultipleTimes() throws Exce
196239 assertNull (listener .actionGet (TIMEOUT ));
197240 }
198241 }
242+
243+ verify (mockManager , times (1 )).start ();
199244 }
200245
201246 public void testCreateSender_SendsRequestAndReceivesResponse () throws Exception {
@@ -340,8 +385,7 @@ public void testHttpRequestSender_Throws_WhenATimeoutOccurs() throws Exception {
340385 }
341386
342387 public void testHttpRequestSenderWithTimeout_Throws_WhenATimeoutOccurs () throws Exception {
343- var mockManager = mock (HttpClientManager .class );
344- when (mockManager .getHttpClient ()).thenReturn (mock (HttpClient .class ));
388+ var mockManager = createMockHttpClientManager ();
345389
346390 var senderFactory = new HttpRequestSender .Factory (
347391 ServiceComponentsTests .createWithEmptySettings (threadPool ),
@@ -363,8 +407,7 @@ public void testHttpRequestSenderWithTimeout_Throws_WhenATimeoutOccurs() throws
363407 }
364408
365409 public void testSendWithoutQueuingWithTimeout_Throws_WhenATimeoutOccurs () throws Exception {
366- var mockManager = mock (HttpClientManager .class );
367- when (mockManager .getHttpClient ()).thenReturn (mock (HttpClient .class ));
410+ var mockManager = createMockHttpClientManager ();
368411
369412 var senderFactory = new HttpRequestSender .Factory (
370413 ServiceComponentsTests .createWithEmptySettings (threadPool ),
0 commit comments