Skip to content

Commit 6eaacdc

Browse files
committed
Make sure the ZeroPublisher subscription happens on a separate thread and test
1 parent 069adb9 commit 6eaacdc

File tree

2 files changed

+129
-36
lines changed

2 files changed

+129
-36
lines changed

transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import jakarta.inject.Inject;
1010

1111
import java.util.List;
12+
import java.util.concurrent.CompletableFuture;
1213
import java.util.concurrent.Flow;
1314
import java.util.regex.Matcher;
1415
import java.util.regex.Pattern;
@@ -143,7 +144,7 @@ private HTTPRestResponse handlePostRequest(String path, String body, ServerCallC
143144
public HTTPRestResponse sendMessage(String body, ServerCallContext context) {
144145
io.a2a.grpc.SendMessageRequest.Builder request = io.a2a.grpc.SendMessageRequest.newBuilder();
145146
parseRequestBody(body, request);
146-
EventKind result = requestHandler.onMessageSend(ProtoUtils.FromProto.messageSendParams(request), context);
147+
EventKind result = requestHandler.onMessageSend(ProtoUtils.FromProto.messageSendParams(request.build()), context);
147148
return createSuccessResponse(200, io.a2a.grpc.SendMessageResponse.newBuilder(ProtoUtils.ToProto.taskOrMessage(result)));
148149
}
149150

@@ -154,7 +155,7 @@ public HTTPRestStreamingResponse sendStreamingMessage(String body, ServerCallCon
154155
try {
155156
io.a2a.grpc.SendMessageRequest.Builder request = io.a2a.grpc.SendMessageRequest.newBuilder();
156157
parseRequestBody(body, request);
157-
Flow.Publisher<StreamingEventKind> publisher = requestHandler.onMessageSendStream(ProtoUtils.FromProto.messageSendParams(request), context);
158+
Flow.Publisher<StreamingEventKind> publisher = requestHandler.onMessageSendStream(ProtoUtils.FromProto.messageSendParams(request.build()), context);
158159
return createStreamingResponse(publisher);
159160
} catch (JSONRPCError e) {
160161
return new HTTPRestStreamingResponse(ZeroPublisher.fromItems(new HTTPRestErrorResponse(e).toJson()));
@@ -221,7 +222,7 @@ public HTTPRestResponse listTaskPushNotificationConfigurations(String taskId, Se
221222
}
222223
ListTaskPushNotificationConfigParams params = new ListTaskPushNotificationConfigParams(taskId);
223224
List<TaskPushNotificationConfig> configs = requestHandler.onListTaskPushNotificationConfig(params, context);
224-
return createSuccessResponse(200, io.a2a.grpc.ListTaskPushNotificationConfigResponse.newBuilder(ProtoUtils.ToProto.listTaskPushNotificationConfigResponse(configs)));
225+
return createSuccessResponse(200, ProtoUtils.ToProto.listTaskPushNotificationConfigResponse(configs).toBuilder());
225226
}
226227

227228
public HTTPRestResponse deleteTaskPushNotificationConfiguration(String taskId, String configId, ServerCallContext context) {
@@ -289,43 +290,45 @@ private Flow.Publisher<String> convertToSendStreamingMessageResponse(
289290
// We can't use the normal convertingProcessor since that propagates any errors as an error handled
290291
// via Subscriber.onError() rather than as part of the SendStreamingResponse payload
291292
return ZeroPublisher.create(createTubeConfig(), tube -> {
292-
publisher.subscribe(new Flow.Subscriber<StreamingEventKind>() {
293-
Flow.Subscription subscription;
294-
295-
@Override
296-
public void onSubscribe(Flow.Subscription subscription) {
297-
this.subscription = subscription;
298-
subscription.request(1);
299-
}
300-
301-
@Override
302-
public void onNext(StreamingEventKind item) {
303-
try {
304-
String payload = JsonFormat.printer().omittingInsignificantWhitespace().print(ProtoUtils.ToProto.taskOrMessageStream(item));
305-
System.out.println("############## Sending event " + payload);
306-
tube.send(payload);
293+
CompletableFuture.runAsync(() -> {
294+
publisher.subscribe(new Flow.Subscriber<StreamingEventKind>() {
295+
Flow.Subscription subscription;
296+
297+
@Override
298+
public void onSubscribe(Flow.Subscription subscription) {
299+
this.subscription = subscription;
307300
subscription.request(1);
308-
} catch (InvalidProtocolBufferException ex) {
309-
onError(ex);
310301
}
311-
}
312-
313-
@Override
314-
public void onError(Throwable throwable) {
315-
System.out.println("############## Sending error " + throwable);
316-
throwable.printStackTrace();
317-
if (throwable instanceof JSONRPCError jsonrpcError) {
318-
tube.send(new HTTPRestErrorResponse(jsonrpcError).toJson());
319-
} else {
320-
tube.send(new HTTPRestErrorResponse(new InternalError(throwable.getMessage())).toJson());
302+
303+
@Override
304+
public void onNext(StreamingEventKind item) {
305+
try {
306+
String payload = JsonFormat.printer().omittingInsignificantWhitespace().print(ProtoUtils.ToProto.taskOrMessageStream(item));
307+
System.out.println("############## Sending event " + payload);
308+
tube.send(payload);
309+
subscription.request(1);
310+
} catch (InvalidProtocolBufferException ex) {
311+
onError(ex);
312+
}
321313
}
322-
onComplete();
323-
}
324314

325-
@Override
326-
public void onComplete() {
327-
tube.complete();
328-
}
315+
@Override
316+
public void onError(Throwable throwable) {
317+
System.out.println("############## Sending error " + throwable);
318+
throwable.printStackTrace();
319+
if (throwable instanceof JSONRPCError jsonrpcError) {
320+
tube.send(new HTTPRestErrorResponse(jsonrpcError).toJson());
321+
} else {
322+
tube.send(new HTTPRestErrorResponse(new InternalError(throwable.getMessage())).toJson());
323+
}
324+
onComplete();
325+
}
326+
327+
@Override
328+
public void onComplete() {
329+
tube.complete();
330+
}
331+
});
329332
});
330333
});
331334
}

transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
import com.google.protobuf.InvalidProtocolBufferException;
44
import java.util.Map;
5+
import java.util.concurrent.CountDownLatch;
6+
import java.util.concurrent.Flow;
7+
import java.util.concurrent.TimeUnit;
8+
import java.util.concurrent.atomic.AtomicBoolean;
59

610
import io.a2a.server.ServerCallContext;
711
import io.a2a.server.auth.UnauthenticatedUser;
@@ -315,4 +319,90 @@ public void testHttpStatusCodeMapping() {
315319
response = handler.handleRequest("PATCH", "/v1/card", null, callContext);
316320
Assertions.assertEquals(405, response.getStatusCode());
317321
}
322+
323+
@Test
324+
public void testStreamingDoesNotBlockMainThread() throws Exception {
325+
RestHandler handler = new RestHandler(CARD, requestHandler);
326+
327+
// Track if the main thread gets blocked during streaming
328+
AtomicBoolean eventReceived = new AtomicBoolean(false);
329+
CountDownLatch streamStarted = new CountDownLatch(1);
330+
CountDownLatch eventProcessed = new CountDownLatch(1);
331+
agentExecutorExecute = (context, eventQueue) -> {
332+
// Wait a bit to ensure the main thread continues
333+
try {
334+
Thread.sleep(100);
335+
} catch (InterruptedException e) {
336+
Thread.currentThread().interrupt();
337+
}
338+
eventQueue.enqueueEvent(context.getMessage());
339+
};
340+
341+
String requestBody = """
342+
{
343+
"message": {
344+
"role": "ROLE_USER",
345+
"content": [
346+
{
347+
"text": "tell me some jokes"
348+
}
349+
],
350+
"messageId": "message-1234",
351+
"contextId": "context-1234"
352+
},
353+
"configuration": {
354+
"acceptedOutputModes": ["text"]
355+
}
356+
}""";
357+
358+
// Start streaming
359+
RestHandler.HTTPRestResponse response = handler.handleRequest("POST", "/v1/message:stream", requestBody, callContext);
360+
361+
Assertions.assertEquals(200, response.getStatusCode());
362+
Assertions.assertInstanceOf(RestHandler.HTTPRestStreamingResponse.class, response);
363+
364+
RestHandler.HTTPRestStreamingResponse streamingResponse = (RestHandler.HTTPRestStreamingResponse) response;
365+
Flow.Publisher<String> publisher = streamingResponse.getPublisher();
366+
publisher.subscribe(new Flow.Subscriber<String>() {
367+
@Override
368+
public void onSubscribe(Flow.Subscription subscription) {
369+
streamStarted.countDown();
370+
subscription.request(1);
371+
}
372+
373+
@Override
374+
public void onNext(String item) {
375+
eventReceived.set(true);
376+
eventProcessed.countDown();
377+
}
378+
379+
@Override
380+
public void onError(Throwable throwable) {
381+
eventProcessed.countDown();
382+
}
383+
384+
@Override
385+
public void onComplete() {
386+
eventProcessed.countDown();
387+
}
388+
});
389+
390+
// The main thread should not be blocked - we should be able to continue immediately
391+
Assertions.assertTrue(streamStarted.await(100, TimeUnit.MILLISECONDS),
392+
"Streaming subscription should start quickly without blocking main thread");
393+
394+
// This proves the main thread is not blocked - we can do other work
395+
long startTime = System.currentTimeMillis();
396+
while (System.currentTimeMillis() - startTime < 50) {
397+
// Simulate main thread doing other work
398+
Thread.sleep(1);
399+
}
400+
401+
// Wait for the actual event processing to complete
402+
Assertions.assertTrue(eventProcessed.await(2, TimeUnit.SECONDS),
403+
"Event should be processed within reasonable time");
404+
405+
// Verify we received the event
406+
Assertions.assertTrue(eventReceived.get(), "Should have received streaming event");
407+
}
318408
}

0 commit comments

Comments
 (0)