99import static org .wildfly .common .Assert .assertTrue ;
1010
1111import java .io .EOFException ;
12+ import java .io .IOException ;
1213import java .net .URI ;
1314import java .net .http .HttpClient ;
1415import java .net .http .HttpRequest ;
2021import java .util .concurrent .ExecutorService ;
2122import java .util .concurrent .Executors ;
2223import java .util .concurrent .TimeUnit ;
24+ import java .util .concurrent .atomic .AtomicInteger ;
2325import java .util .concurrent .atomic .AtomicReference ;
2426import java .util .stream .Stream ;
2527
2628import jakarta .ws .rs .core .MediaType ;
2729
2830import com .fasterxml .jackson .core .JsonProcessingException ;
29- import io .a2a .server .events .InMemoryQueueManager ;
3031import io .a2a .spec .AgentCard ;
3132import io .a2a .spec .Artifact ;
3233import io .a2a .spec .CancelTaskRequest ;
@@ -641,7 +642,8 @@ public void testResubscribeExistingTaskSuccess() throws Exception {
641642 TaskResubscriptionRequest taskResubscriptionRequest = new TaskResubscriptionRequest ("1" , new TaskIdParams (MINIMAL_TASK .getId ()));
642643
643644 // Count down the latch when the MultiSseSupport on the server has started subscribing
644- setStreamingSubscribedRunnable (taskResubscriptionRequestSent ::countDown );
645+ awaitStreamingSubscription ()
646+ .whenComplete ((unused , throwable ) -> taskResubscriptionRequestSent .countDown ());
645647
646648 CompletableFuture <HttpResponse <Stream <String >>> responseFuture = initialiseStreamingRequest (taskResubscriptionRequest , null );
647649
@@ -650,7 +652,6 @@ public void testResubscribeExistingTaskSuccess() throws Exception {
650652 responseFuture .thenAccept (response -> {
651653
652654 if (response .statusCode () != 200 ) {
653- //errorRef.set(new IllegalStateException("Status code was " + response.statusCode()));
654655 throw new IllegalStateException ("Status code was " + response .statusCode ());
655656 }
656657 try {
@@ -729,7 +730,7 @@ public void testResubscribeExistingTaskSuccess() throws Exception {
729730 assertEquals (TaskState .COMPLETED , taskStatusUpdateEvent .getStatus ().state ());
730731 assertNotNull (taskStatusUpdateEvent .getStatus ().timestamp ());
731732 } finally {
732- setStreamingSubscribedRunnable (null );
733+ // setStreamingSubscribedRunnable(null);
733734 deleteTaskInTaskStore (MINIMAL_TASK .getId ());
734735 executorService .shutdown ();
735736 if (!executorService .awaitTermination (10 , TimeUnit .SECONDS )) {
@@ -956,7 +957,7 @@ protected void ensureQueueForTask(String taskId) throws Exception {
956957 }
957958
958959 protected void enqueueEventOnServer (Event event ) throws Exception {
959- String path = null ;
960+ String path ;
960961 if (event instanceof TaskArtifactUpdateEvent e ) {
961962 path = "test/queue/enqueueTaskArtifactUpdateEvent/" + e .getTaskId ();
962963 } else if (event instanceof TaskStatusUpdateEvent e ) {
@@ -979,7 +980,42 @@ protected void enqueueEventOnServer(Event event) throws Exception {
979980 }
980981 }
981982
982- protected abstract void setStreamingSubscribedRunnable (Runnable runnable );
983+ private CompletableFuture <Void > awaitStreamingSubscription () {
984+ int cnt = getStreamingSubscribedCount ();
985+ AtomicInteger initialCount = new AtomicInteger (cnt );
986+
987+ return CompletableFuture .runAsync (() -> {
988+ try {
989+ while (true ) {
990+ int count = getStreamingSubscribedCount ();
991+ if (count > initialCount .get ()) {
992+ break ;
993+ }
994+ Thread .sleep (500 );
995+ }
996+ } catch (InterruptedException e ) {
997+ Thread .currentThread ().interrupt ();
998+ }
999+ });
1000+ }
1001+
1002+ private int getStreamingSubscribedCount () {
1003+ HttpClient client = HttpClient .newBuilder ()
1004+ .version (HttpClient .Version .HTTP_2 )
1005+ .build ();
1006+ HttpRequest request = HttpRequest .newBuilder ()
1007+ .uri (URI .create ("http://localhost:" + serverPort + "/test/streamingSubscribedCount" ))
1008+ .GET ()
1009+ .build ();
1010+ try {
1011+ HttpResponse <String > response = client .send (request , HttpResponse .BodyHandlers .ofString (StandardCharsets .UTF_8 ));
1012+ String body = response .body ().trim ();
1013+ System .out .println (body );
1014+ return Integer .valueOf (body );
1015+ } catch (IOException | InterruptedException e ) {
1016+ throw new RuntimeException (e );
1017+ }
1018+ }
9831019
9841020 private static class BreakException extends RuntimeException {
9851021
0 commit comments