1111import org .elasticsearch .action .ActionResponse ;
1212import org .elasticsearch .action .LegacyActionRequest ;
1313import org .elasticsearch .action .support .ActionTestUtils ;
14+ import org .elasticsearch .action .support .PlainActionFuture ;
1415import org .elasticsearch .cluster .service .ClusterService ;
1516import org .elasticsearch .common .io .stream .StreamInput ;
1617import org .elasticsearch .common .io .stream .StreamOutput ;
4041
4142import static org .elasticsearch .xpack .esql .core .async .AsyncTaskManagementService .addCompletionListener ;
4243import static org .hamcrest .Matchers .equalTo ;
44+ import static org .hamcrest .Matchers .greaterThanOrEqualTo ;
4345import static org .hamcrest .Matchers .notNullValue ;
4446import static org .hamcrest .Matchers .nullValue ;
4547
@@ -52,9 +54,11 @@ public class AsyncTaskManagementServiceTests extends ESSingleNodeTestCase {
5254
5355 public static class TestRequest extends LegacyActionRequest {
5456 private final String string ;
57+ private final TimeValue keepAlive ;
5558
56- public TestRequest (String string ) {
59+ public TestRequest (String string , TimeValue keepAlive ) {
5760 this .string = string ;
61+ this .keepAlive = keepAlive ;
5862 }
5963
6064 @ Override
@@ -129,7 +133,7 @@ public TestTask createTask(
129133 headers ,
130134 originHeaders ,
131135 asyncExecutionId ,
132- TimeValue . timeValueDays ( 5 )
136+ request . keepAlive
133137 );
134138 }
135139
@@ -172,7 +176,7 @@ public void setup() {
172176 );
173177 results = new AsyncResultsService <>(
174178 store ,
175- true ,
179+ false ,
176180 TestTask .class ,
177181 (task , listener , timeout ) -> addCompletionListener (transportService .getThreadPool (), task , listener , timeout ),
178182 transportService .getTaskManager (),
@@ -212,23 +216,17 @@ public void testReturnBeforeTimeout() throws Exception {
212216 boolean success = randomBoolean ();
213217 boolean keepOnCompletion = randomBoolean ();
214218 CountDownLatch latch = new CountDownLatch (1 );
215- TestRequest request = new TestRequest (success ? randomAlphaOfLength (10 ) : "die" );
216- service .asyncExecute (
217- request ,
218- TimeValue .timeValueMinutes (1 ),
219- TimeValue .timeValueMinutes (10 ),
220- keepOnCompletion ,
221- ActionListener .wrap (r -> {
222- assertThat (success , equalTo (true ));
223- assertThat (r .string , equalTo ("response for [" + request .string + "]" ));
224- assertThat (r .id , notNullValue ());
225- latch .countDown ();
226- }, e -> {
227- assertThat (success , equalTo (false ));
228- assertThat (e .getMessage (), equalTo ("test exception" ));
229- latch .countDown ();
230- })
231- );
219+ TestRequest request = new TestRequest (success ? randomAlphaOfLength (10 ) : "die" , TimeValue .timeValueDays (1 ));
220+ service .asyncExecute (request , TimeValue .timeValueMinutes (1 ), keepOnCompletion , ActionListener .wrap (r -> {
221+ assertThat (success , equalTo (true ));
222+ assertThat (r .string , equalTo ("response for [" + request .string + "]" ));
223+ assertThat (r .id , notNullValue ());
224+ latch .countDown ();
225+ }, e -> {
226+ assertThat (success , equalTo (false ));
227+ assertThat (e .getMessage (), equalTo ("test exception" ));
228+ latch .countDown ();
229+ }));
232230 assertThat (latch .await (10 , TimeUnit .SECONDS ), equalTo (true ));
233231 }
234232
@@ -252,20 +250,14 @@ public void execute(TestRequest request, TestTask task, ActionListener<TestRespo
252250 boolean timeoutOnFirstAttempt = randomBoolean ();
253251 boolean waitForCompletion = randomBoolean ();
254252 CountDownLatch latch = new CountDownLatch (1 );
255- TestRequest request = new TestRequest (success ? randomAlphaOfLength (10 ) : "die" );
253+ TestRequest request = new TestRequest (success ? randomAlphaOfLength (10 ) : "die" , TimeValue . timeValueDays ( 1 ) );
256254 AtomicReference <TestResponse > responseHolder = new AtomicReference <>();
257- service .asyncExecute (
258- request ,
259- TimeValue .timeValueMillis (1 ),
260- TimeValue .timeValueMinutes (10 ),
261- keepOnCompletion ,
262- ActionTestUtils .assertNoFailureListener (r -> {
263- assertThat (r .string , nullValue ());
264- assertThat (r .id , notNullValue ());
265- assertThat (responseHolder .getAndSet (r ), nullValue ());
266- latch .countDown ();
267- })
268- );
255+ service .asyncExecute (request , TimeValue .timeValueMillis (1 ), keepOnCompletion , ActionTestUtils .assertNoFailureListener (r -> {
256+ assertThat (r .string , nullValue ());
257+ assertThat (r .id , notNullValue ());
258+ assertThat (responseHolder .getAndSet (r ), nullValue ());
259+ latch .countDown ();
260+ }));
269261 assertThat (latch .await (20 , TimeUnit .SECONDS ), equalTo (true ));
270262
271263 if (timeoutOnFirstAttempt ) {
@@ -281,17 +273,11 @@ public void execute(TestRequest request, TestTask task, ActionListener<TestRespo
281273 if (waitForCompletion ) {
282274 // now we are waiting for the task to finish
283275 logger .trace ("Waiting for response to complete" );
284- AtomicReference <StoredAsyncResponse <TestResponse >> responseRef = new AtomicReference <>();
285- CountDownLatch getResponseCountDown = getResponse (
286- responseHolder .get ().id ,
287- TimeValue .timeValueSeconds (5 ),
288- ActionTestUtils .assertNoFailureListener (responseRef ::set )
289- );
276+ var getFuture = getResponse (responseHolder .get ().id , TimeValue .timeValueSeconds (5 ), TimeValue .MINUS_ONE );
290277
291278 executionLatch .countDown ();
292- assertThat ( getResponseCountDown . await ( 10 , TimeUnit . SECONDS ), equalTo ( true ) );
279+ var response = safeGet ( getFuture );
293280
294- StoredAsyncResponse <TestResponse > response = responseRef .get ();
295281 if (success ) {
296282 assertThat (response .getException (), nullValue ());
297283 assertThat (response .getResponse (), notNullValue ());
@@ -326,26 +312,46 @@ public void execute(TestRequest request, TestTask task, ActionListener<TestRespo
326312 }
327313 }
328314
315+ public void testUpdateKeepAliveToTask () throws Exception {
316+ long now = System .currentTimeMillis ();
317+ CountDownLatch executionLatch = new CountDownLatch (1 );
318+ AsyncTaskManagementService <TestRequest , TestResponse , TestTask > service = createManagementService (new TestOperation () {
319+ @ Override
320+ public void execute (TestRequest request , TestTask task , ActionListener <TestResponse > listener ) {
321+ executorService .submit (() -> {
322+ try {
323+ assertThat (executionLatch .await (10 , TimeUnit .SECONDS ), equalTo (true ));
324+ } catch (InterruptedException ex ) {
325+ throw new AssertionError (ex );
326+ }
327+ super .execute (request , task , listener );
328+ });
329+ }
330+ });
331+ TestRequest request = new TestRequest (randomAlphaOfLength (10 ), TimeValue .timeValueHours (1 ));
332+ PlainActionFuture <TestResponse > submitResp = new PlainActionFuture <>();
333+ try {
334+ service .asyncExecute (request , TimeValue .timeValueMillis (1 ), true , submitResp );
335+ String id = submitResp .get ().id ;
336+ assertThat (id , notNullValue ());
337+ TimeValue keepAlive = TimeValue .timeValueDays (between (1 , 10 ));
338+ var resp1 = safeGet (getResponse (id , TimeValue .ZERO , keepAlive ));
339+ assertThat (resp1 .getExpirationTime (), greaterThanOrEqualTo (now + keepAlive .millis ()));
340+ } finally {
341+ executionLatch .countDown ();
342+ }
343+ }
344+
329345 private StoredAsyncResponse <TestResponse > getResponse (String id , TimeValue timeout ) throws InterruptedException {
330- AtomicReference <StoredAsyncResponse <TestResponse >> response = new AtomicReference <>();
331- assertThat (
332- getResponse (id , timeout , ActionTestUtils .assertNoFailureListener (response ::set )).await (10 , TimeUnit .SECONDS ),
333- equalTo (true )
334- );
335- return response .get ();
346+ return safeGet (getResponse (id , timeout , TimeValue .MINUS_ONE ));
336347 }
337348
338- private CountDownLatch getResponse (String id , TimeValue timeout , ActionListener < StoredAsyncResponse < TestResponse >> listener ) {
339- CountDownLatch responseLatch = new CountDownLatch ( 1 );
349+ private PlainActionFuture < StoredAsyncResponse < TestResponse >> getResponse (String id , TimeValue timeout , TimeValue keepAlive ) {
350+ PlainActionFuture < StoredAsyncResponse < TestResponse >> future = new PlainActionFuture <>( );
340351 GetAsyncResultRequest getResultsRequest = new GetAsyncResultRequest (id ).setWaitForCompletionTimeout (timeout );
341- results .retrieveResult (getResultsRequest , ActionListener .wrap (r -> {
342- listener .onResponse (r );
343- responseLatch .countDown ();
344- }, e -> {
345- listener .onFailure (e );
346- responseLatch .countDown ();
347- }));
348- return responseLatch ;
352+ getResultsRequest .setKeepAlive (keepAlive );
353+ results .retrieveResult (getResultsRequest , future );
354+ return future ;
349355 }
350356
351357}
0 commit comments