Skip to content

Commit 2e944f9

Browse files
authored
Merge pull request #114 from kabir/async-improvements
Async improvements
2 parents 954d41d + df163af commit 2e944f9

File tree

8 files changed

+390
-230
lines changed

8 files changed

+390
-230
lines changed

core/src/main/java/io/a2a/server/events/EventConsumer.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
public class EventConsumer {
1717
private final EventQueue queue;
1818
private Throwable error;
19-
private final Executor executor = Executors.newCachedThreadPool();
20-
2119

2220
private static final String ERROR_MSG = "Agent did not return any response";
2321
private static final int NO_WAIT = -1;

core/src/main/java/io/a2a/server/events/EventQueue.java

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import java.util.concurrent.CopyOnWriteArrayList;
77
import java.util.concurrent.CountDownLatch;
88
import java.util.concurrent.TimeUnit;
9+
import java.util.concurrent.atomic.AtomicBoolean;
910

1011
import io.a2a.util.TempLoggerWrapper;
1112
import org.slf4j.Logger;
@@ -35,7 +36,9 @@ public static EventQueue create() {
3536
return new MainQueue();
3637
}
3738

38-
abstract CountDownLatch getPollingStartedLatch();
39+
public abstract void awaitQueuePollerStart() throws InterruptedException ;
40+
41+
abstract void signalQueuePollerStarted();
3942

4043
public void enqueueEvent(Event event) {
4144
if (closed) {
@@ -71,21 +74,22 @@ public Event dequeueEvent(int waitMilliSeconds) throws EventQueueClosedException
7174
}
7275
return event;
7376
} catch (InterruptedException e) {
74-
log.debug("Interrupted {}", this);
77+
log.debug("Interrupted dequeue (waiting) {}", this);
7578
Thread.currentThread().interrupt();
7679
return null;
7780
}
7881
} finally {
79-
log.debug("Signalling that queue polling started {}", this);
80-
getPollingStartedLatch().countDown();
82+
signalQueuePollerStarted();
8183
}
8284
}
8385

8486
public void taskDone() {
8587
// TODO Not sure if needed yet. BlockingQueue.poll()/.take() remove the events.
8688
}
8789

88-
public void close() {
90+
public abstract void close();
91+
92+
public void doClose() {
8993
synchronized (this) {
9094
if (closed) {
9195
return;
@@ -103,7 +107,8 @@ public void close() {
103107

104108
static class MainQueue extends EventQueue {
105109
private final List<ChildQueue> children = new CopyOnWriteArrayList<>();
106-
private CountDownLatch pollingStartedLatch = new CountDownLatch(1);
110+
private final CountDownLatch pollingStartedLatch = new CountDownLatch(1);
111+
private final AtomicBoolean pollingStarted = new AtomicBoolean(false);
107112

108113
EventQueue tap() {
109114
ChildQueue child = new ChildQueue(this);
@@ -116,16 +121,27 @@ public void enqueueEvent(Event event) {
116121
children.forEach(eq -> eq.internalEnqueueEvent(event));
117122
}
118123

119-
CountDownLatch getPollingStartedLatch() {
120-
return pollingStartedLatch;
124+
@Override
125+
public void awaitQueuePollerStart() throws InterruptedException {
126+
log.debug("Waiting for queue poller to start on {}", this);
127+
pollingStartedLatch.await(10, TimeUnit.SECONDS);
128+
log.debug("Queue poller started on {}", this);
121129
}
122130

123-
131+
@Override
132+
void signalQueuePollerStarted() {
133+
if (pollingStarted.get()) {
134+
return;
135+
}
136+
log.debug("Signalling that queue polling started {}", this);
137+
pollingStartedLatch.countDown();
138+
pollingStarted.set(true);
139+
}
124140

125141
@Override
126142
public void close() {
127-
super.close();
128-
children.forEach(EventQueue::close);
143+
doClose();
144+
children.forEach(EventQueue::doClose);
129145
}
130146
}
131147

@@ -151,8 +167,18 @@ EventQueue tap() {
151167
}
152168

153169
@Override
154-
CountDownLatch getPollingStartedLatch() {
155-
return parent.getPollingStartedLatch();
170+
public void awaitQueuePollerStart() throws InterruptedException {
171+
parent.awaitQueuePollerStart();
172+
}
173+
174+
@Override
175+
void signalQueuePollerStarted() {
176+
parent.signalQueuePollerStarted();
177+
}
178+
179+
@Override
180+
public void close() {
181+
parent.close();
156182
}
157183
}
158184
}

core/src/main/java/io/a2a/server/events/InMemoryQueueManager.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77

88
import jakarta.enterprise.context.ApplicationScoped;
99

10+
import io.a2a.util.TempLoggerWrapper;
11+
import org.slf4j.Logger;
12+
import org.slf4j.LoggerFactory;
13+
1014
@ApplicationScoped
1115
public class InMemoryQueueManager implements QueueManager {
12-
13-
1416
private final Map<String, EventQueue> queues = Collections.synchronizedMap(new HashMap<>());
1517

1618
@Override
@@ -63,7 +65,7 @@ public EventQueue createOrTap(String taskId) {
6365
}
6466

6567
@Override
66-
public void signalPollingStarted(EventQueue eventQueue) throws InterruptedException {
67-
eventQueue.getPollingStartedLatch().await(10, TimeUnit.SECONDS);
68+
public void awaitQueuePollerStart(EventQueue eventQueue) throws InterruptedException {
69+
eventQueue.awaitQueuePollerStart();
6870
}
6971
}

core/src/main/java/io/a2a/server/events/QueueManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ public interface QueueManager {
1111

1212
EventQueue createOrTap(String taskId);
1313

14-
void signalPollingStarted(EventQueue eventQueue) throws InterruptedException;
14+
void awaitQueuePollerStart(EventQueue eventQueue) throws InterruptedException;
1515
}

core/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ private EnhancedRunnable registerAndExecuteAgentAsync(String taskId, RequestCont
355355
public void run() {
356356
agentExecutor.execute(requestContext, queue);
357357
try {
358-
queueManager.signalPollingStarted(queue);
358+
queueManager.awaitQueuePollerStart(queue);
359359
} catch (InterruptedException e) {
360360
Thread.currentThread().interrupt();
361361
}

core/src/main/java/io/a2a/server/requesthandlers/JSONRPCHandler.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package io.a2a.server.requesthandlers;
22

3-
import static io.a2a.util.AsyncUtils.convertingProcessor;
43
import static io.a2a.util.AsyncUtils.createTubeConfig;
54

65
import java.util.concurrent.Flow;

quarkus-sdk/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import static io.vertx.core.http.HttpHeaders.CONTENT_TYPE;
44
import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON;
55

6+
import java.util.concurrent.Executor;
7+
import java.util.concurrent.Executors;
68
import java.util.concurrent.Flow;
79
import java.util.concurrent.atomic.AtomicLong;
810
import java.util.function.Function;
@@ -66,6 +68,11 @@ public class A2AServerRoutes {
6668
@ExtendedAgentCard
6769
Instance<AgentCard> extendedAgentCard;
6870

71+
// Hook so testing can wait until the MultiSseSupport is subscribes.
72+
private static volatile Runnable streamingMultiSseSupportSubscribedRunnable;
73+
74+
private final Executor executor = Executors.newCachedThreadPool();
75+
6976
@Route(path = "/", methods = {Route.HttpMethod.POST}, consumes = {APPLICATION_JSON}, type = Route.HandlerType.BLOCKING)
7077
public void invokeJSONRPCHandler(@Body String body, RoutingContext rc) {
7178
boolean streaming = false;
@@ -93,7 +100,12 @@ public void invokeJSONRPCHandler(@Body String body, RoutingContext rc) {
93100
.putHeader(CONTENT_TYPE, APPLICATION_JSON)
94101
.end(Json.encodeToBuffer(error));
95102
} else if (streaming) {
96-
MultiSseSupport.subscribeObject(streamingResponse.map(i -> (Object)i), rc);
103+
final Multi<? extends JSONRPCResponse<?>> finalStreamingResponse = streamingResponse;
104+
executor.execute(() -> {
105+
MultiSseSupport.subscribeObject(
106+
finalStreamingResponse.map(i -> (Object)i), rc);
107+
});
108+
97109
} else {
98110
rc.response()
99111
.setStatusCode(200)
@@ -212,6 +224,10 @@ private static boolean isNonStreamingRequest(String requestBody) {
212224
requestBody.contains(A2A.GET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD);
213225
}
214226

227+
static void setStreamingMultiSseSupportSubscribedRunnable(Runnable runnable) {
228+
streamingMultiSseSupportSubscribedRunnable = runnable;
229+
}
230+
215231
// Port of import io.quarkus.vertx.web.runtime.MultiSseSupport, which is considered internal API
216232
private static class MultiSseSupport {
217233

@@ -246,6 +262,12 @@ public static void write(Multi<Buffer> multi, RoutingContext rc) {
246262
public void onSubscribe(Flow.Subscription subscription) {
247263
this.upstream = subscription;
248264
this.upstream.request(1);
265+
266+
// Notify tests that we are subscribed
267+
Runnable runnable = streamingMultiSseSupportSubscribedRunnable;
268+
if (runnable != null) {
269+
runnable.run();
270+
}
249271
}
250272

251273
@Override

0 commit comments

Comments
 (0)