Skip to content

Commit 12f749c

Browse files
committed
Rework the streaming subscribed check so it can work from a client
1 parent 26b771e commit 12f749c

File tree

4 files changed

+67
-18
lines changed

4 files changed

+67
-18
lines changed

reference-impl/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import java.util.concurrent.atomic.AtomicLong;
99
import java.util.function.Function;
1010

11-
import com.fasterxml.jackson.databind.JsonNode;
1211
import jakarta.enterprise.inject.Instance;
1312
import jakarta.inject.Inject;
1413
import jakarta.inject.Singleton;
@@ -17,8 +16,10 @@
1716
import com.fasterxml.jackson.core.JsonParseException;
1817
import com.fasterxml.jackson.core.JsonProcessingException;
1918
import com.fasterxml.jackson.core.io.JsonEOFException;
19+
import com.fasterxml.jackson.databind.JsonNode;
2020
import io.a2a.server.ExtendedAgentCard;
2121
import io.a2a.server.requesthandlers.JSONRPCHandler;
22+
import io.a2a.server.util.async.Internal;
2223
import io.a2a.spec.AgentCard;
2324
import io.a2a.spec.CancelTaskRequest;
2425
import io.a2a.spec.GetTaskPushNotificationConfigRequest;
@@ -44,7 +45,6 @@
4445
import io.a2a.spec.TaskResubscriptionRequest;
4546
import io.a2a.spec.UnsupportedOperationError;
4647
import io.a2a.util.Utils;
47-
import io.a2a.server.util.async.Internal;
4848
import io.quarkus.vertx.web.Body;
4949
import io.quarkus.vertx.web.ReactiveRoutes;
5050
import io.quarkus.vertx.web.Route;
@@ -69,6 +69,7 @@ public class A2AServerRoutes {
6969
Instance<AgentCard> extendedAgentCard;
7070

7171
// Hook so testing can wait until the MultiSseSupport is subscribed.
72+
// Without this we get intermittent failures
7273
private static volatile Runnable streamingMultiSseSupportSubscribedRunnable;
7374

7475
@Inject

reference-impl/src/test/java/io/a2a/server/apps/quarkus/A2ATestRoutes.java

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON;
55
import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN;
66

7+
import java.util.concurrent.atomic.AtomicInteger;
8+
9+
import jakarta.annotation.PostConstruct;
710
import jakarta.inject.Inject;
811
import jakarta.inject.Singleton;
912

@@ -25,6 +28,17 @@ public class A2ATestRoutes {
2528
@Inject
2629
TestUtilsBean testUtilsBean;
2730

31+
@Inject
32+
A2AServerRoutes a2AServerRoutes;
33+
34+
AtomicInteger streamingSubscribedCount = new AtomicInteger(0);
35+
36+
@PostConstruct
37+
public void init() {
38+
A2AServerRoutes.setStreamingMultiSseSupportSubscribedRunnable(() -> streamingSubscribedCount.incrementAndGet());
39+
}
40+
41+
2842
@Route(path = "/test/task", methods = {Route.HttpMethod.POST}, consumes = {APPLICATION_JSON}, type = Route.HandlerType.BLOCKING)
2943
public void saveTask(@Body String body, RoutingContext rc) {
3044
try {
@@ -33,7 +47,7 @@ public void saveTask(@Body String body, RoutingContext rc) {
3347
rc.response()
3448
.setStatusCode(200)
3549
.end();
36-
} catch (Throwable t) {
50+
} catch (Throwable t) {
3751
errorResponse(t, rc);
3852
}
3953
}
@@ -118,6 +132,13 @@ public void enqueueTaskArtifactUpdateEvent(@Param String taskId, @Body String bo
118132
}
119133
}
120134

135+
@Route(path = "test/streamingSubscribedCount", methods = {Route.HttpMethod.GET}, produces = {TEXT_PLAIN})
136+
public void getStreamingSubscribedCount(RoutingContext rc) {
137+
rc.response()
138+
.setStatusCode(200)
139+
.end(String.valueOf(streamingSubscribedCount.get()));
140+
}
141+
121142
private void errorResponse(Throwable t, RoutingContext rc) {
122143
rc.response()
123144
.setStatusCode(200)
Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
package io.a2a.server.apps.quarkus;
22

3-
import jakarta.inject.Inject;
4-
53
import io.a2a.server.apps.common.AbstractA2AServerTest;
6-
import io.a2a.server.events.InMemoryQueueManager;
7-
import io.a2a.server.tasks.TaskStore;
84
import io.quarkus.test.junit.QuarkusTest;
95

106
@QuarkusTest
@@ -13,9 +9,4 @@ public class QuarkusA2AServerTest extends AbstractA2AServerTest {
139
public QuarkusA2AServerTest() {
1410
super(8081);
1511
}
16-
17-
@Override
18-
protected void setStreamingSubscribedRunnable(Runnable runnable) {
19-
A2AServerRoutes.setStreamingMultiSseSupportSubscribedRunnable(runnable);
20-
}
2112
}

tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import static org.wildfly.common.Assert.assertTrue;
1010

1111
import java.io.EOFException;
12+
import java.io.IOException;
1213
import java.net.URI;
1314
import java.net.http.HttpClient;
1415
import java.net.http.HttpRequest;
@@ -20,13 +21,13 @@
2021
import java.util.concurrent.ExecutorService;
2122
import java.util.concurrent.Executors;
2223
import java.util.concurrent.TimeUnit;
24+
import java.util.concurrent.atomic.AtomicInteger;
2325
import java.util.concurrent.atomic.AtomicReference;
2426
import java.util.stream.Stream;
2527

2628
import jakarta.ws.rs.core.MediaType;
2729

2830
import com.fasterxml.jackson.core.JsonProcessingException;
29-
import io.a2a.server.events.InMemoryQueueManager;
3031
import io.a2a.spec.AgentCard;
3132
import io.a2a.spec.Artifact;
3233
import io.a2a.spec.CancelTaskRequest;
@@ -641,7 +642,8 @@ public void testResubscribeExistingTaskSuccess() throws Exception {
641642
TaskResubscriptionRequest taskResubscriptionRequest = new TaskResubscriptionRequest("1", new TaskIdParams(MINIMAL_TASK.getId()));
642643

643644
// Count down the latch when the MultiSseSupport on the server has started subscribing
644-
setStreamingSubscribedRunnable(taskResubscriptionRequestSent::countDown);
645+
awaitStreamingSubscription()
646+
.whenComplete((unused, throwable) -> taskResubscriptionRequestSent.countDown());
645647

646648
CompletableFuture<HttpResponse<Stream<String>>> responseFuture = initialiseStreamingRequest(taskResubscriptionRequest, null);
647649

@@ -650,7 +652,6 @@ public void testResubscribeExistingTaskSuccess() throws Exception {
650652
responseFuture.thenAccept(response -> {
651653

652654
if (response.statusCode() != 200) {
653-
//errorRef.set(new IllegalStateException("Status code was " + response.statusCode()));
654655
throw new IllegalStateException("Status code was " + response.statusCode());
655656
}
656657
try {
@@ -729,7 +730,7 @@ public void testResubscribeExistingTaskSuccess() throws Exception {
729730
assertEquals(TaskState.COMPLETED, taskStatusUpdateEvent.getStatus().state());
730731
assertNotNull(taskStatusUpdateEvent.getStatus().timestamp());
731732
} finally {
732-
setStreamingSubscribedRunnable(null);
733+
//setStreamingSubscribedRunnable(null);
733734
deleteTaskInTaskStore(MINIMAL_TASK.getId());
734735
executorService.shutdown();
735736
if (!executorService.awaitTermination(10, TimeUnit.SECONDS)) {
@@ -956,7 +957,7 @@ protected void ensureQueueForTask(String taskId) throws Exception {
956957
}
957958

958959
protected void enqueueEventOnServer(Event event) throws Exception {
959-
String path = null;
960+
String path;
960961
if (event instanceof TaskArtifactUpdateEvent e) {
961962
path = "test/queue/enqueueTaskArtifactUpdateEvent/" + e.getTaskId();
962963
} else if (event instanceof TaskStatusUpdateEvent e) {
@@ -979,7 +980,42 @@ protected void enqueueEventOnServer(Event event) throws Exception {
979980
}
980981
}
981982

982-
protected abstract void setStreamingSubscribedRunnable(Runnable runnable);
983+
private CompletableFuture<Void> awaitStreamingSubscription() {
984+
int cnt = getStreamingSubscribedCount();
985+
AtomicInteger initialCount = new AtomicInteger(cnt);
986+
987+
return CompletableFuture.runAsync(() -> {
988+
try {
989+
while (true) {
990+
int count = getStreamingSubscribedCount();
991+
if (count > initialCount.get()) {
992+
break;
993+
}
994+
Thread.sleep(500);
995+
}
996+
} catch (InterruptedException e) {
997+
Thread.currentThread().interrupt();
998+
}
999+
});
1000+
}
1001+
1002+
private int getStreamingSubscribedCount() {
1003+
HttpClient client = HttpClient.newBuilder()
1004+
.version(HttpClient.Version.HTTP_2)
1005+
.build();
1006+
HttpRequest request = HttpRequest.newBuilder()
1007+
.uri(URI.create("http://localhost:" + serverPort + "/test/streamingSubscribedCount"))
1008+
.GET()
1009+
.build();
1010+
try {
1011+
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8));
1012+
String body = response.body().trim();
1013+
System.out.println(body);
1014+
return Integer.valueOf(body);
1015+
} catch (IOException | InterruptedException e) {
1016+
throw new RuntimeException(e);
1017+
}
1018+
}
9831019

9841020
private static class BreakException extends RuntimeException {
9851021

0 commit comments

Comments
 (0)