Skip to content

Commit b852721

Browse files
committed
Move push notification sending into processor. Hook for tests to wait
1 parent 32886d3 commit b852721

File tree

8 files changed

+173
-44
lines changed

8 files changed

+173
-44
lines changed

server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,12 @@ public void awaitQueuePollerStart(EventQueue eventQueue) throws InterruptedExcep
112112
eventQueue.awaitQueuePollerStart();
113113
}
114114

115+
@Override
116+
public EventQueue.EventQueueBuilder getEventQueueBuilder(String taskId) {
117+
// Use the factory to ensure proper configuration (MainEventBus, callbacks, etc.)
118+
return factory.builder(taskId);
119+
}
120+
115121
@Override
116122
public int getActiveChildQueueCount(String taskId) {
117123
EventQueue queue = queues.get(taskId);

server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java

Lines changed: 79 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
package io.a2a.server.events;
22

3+
import jakarta.annotation.PostConstruct;
4+
import jakarta.annotation.PreDestroy;
5+
import jakarta.enterprise.context.ApplicationScoped;
6+
import jakarta.inject.Inject;
7+
8+
import io.a2a.server.tasks.PushNotificationSender;
39
import io.a2a.server.tasks.TaskManager;
410
import io.a2a.server.tasks.TaskStore;
511
import io.a2a.spec.A2AServerException;
612
import io.a2a.spec.Event;
713
import io.a2a.spec.Task;
814
import io.a2a.spec.TaskArtifactUpdateEvent;
915
import io.a2a.spec.TaskStatusUpdateEvent;
10-
import jakarta.annotation.PostConstruct;
11-
import jakarta.annotation.PreDestroy;
12-
import jakarta.enterprise.context.ApplicationScoped;
13-
import jakarta.inject.Inject;
1416
import org.slf4j.Logger;
1517
import org.slf4j.LoggerFactory;
1618

@@ -33,18 +35,40 @@
3335
public class MainEventBusProcessor implements Runnable {
3436
private static final Logger LOGGER = LoggerFactory.getLogger(MainEventBusProcessor.class);
3537

38+
/**
39+
* Callback for testing synchronization with async event processing.
40+
* Default is NOOP to avoid null checks in production code.
41+
* Tests can inject their own callback via setCallback().
42+
*/
43+
private static volatile MainEventBusProcessorCallback callback = MainEventBusProcessorCallback.NOOP;
3644

3745
private final MainEventBus eventBus;
3846

3947
private final TaskStore taskStore;
4048

49+
private final PushNotificationSender pushSender;
50+
4151
private volatile boolean running = true;
4252
private Thread processorThread;
4353

4454
@Inject
45-
public MainEventBusProcessor(MainEventBus eventBus, TaskStore taskStore) {
55+
public MainEventBusProcessor(MainEventBus eventBus, TaskStore taskStore, PushNotificationSender pushSender) {
4656
this.eventBus = eventBus;
4757
this.taskStore = taskStore;
58+
this.pushSender = pushSender;
59+
}
60+
61+
/**
62+
* Set a callback for testing synchronization with async event processing.
63+
* <p>
64+
* This is primarily intended for tests that need to wait for event processing to complete.
65+
* Pass null to reset to the default NOOP callback.
66+
* </p>
67+
*
68+
* @param callback the callback to invoke during event processing, or null for NOOP
69+
*/
70+
public static void setCallback(MainEventBusProcessorCallback callback) {
71+
MainEventBusProcessor.callback = callback != null ? callback : MainEventBusProcessorCallback.NOOP;
4872
}
4973

5074
@PostConstruct
@@ -99,7 +123,10 @@ private void processEvent(MainEventBusContext context) {
99123
// Step 1: Update TaskStore FIRST (persistence before clients see it)
100124
updateTaskStore(taskId, event);
101125

102-
// Step 2: Then distribute to ChildQueues (clients see it AFTER persistence)
126+
// Step 2: Send push notification AFTER persistence (ensures notification sees latest state)
127+
sendPushNotification(taskId);
128+
129+
// Step 3: Then distribute to ChildQueues (clients see it AFTER persistence + notification)
103130
if (eventQueue instanceof EventQueue.MainQueue mainQueue) {
104131
mainQueue.distributeToChildren(context.eventQueueItem());
105132
LOGGER.debug("Distributed event to children for task {}", taskId);
@@ -109,6 +136,14 @@ private void processEvent(MainEventBusContext context) {
109136
}
110137

111138
LOGGER.debug("Completed processing event for task {}", taskId);
139+
140+
// Step 4: Notify callback after all processing is complete
141+
callback.onEventProcessed(taskId, event);
142+
143+
// Step 5: If this is a final event, notify task finalization
144+
if (isFinalEvent(event)) {
145+
callback.onTaskFinalized(taskId);
146+
}
112147
}
113148

114149
/**
@@ -140,6 +175,28 @@ private void updateTaskStore(String taskId, Event event) {
140175
}
141176
}
142177

178+
/**
179+
* Sends push notification for the task AFTER persistence.
180+
* <p>
181+
* This is called after updateTaskStore() to ensure the notification contains
182+
* the latest persisted state, avoiding race conditions.
183+
* </p>
184+
*/
185+
private void sendPushNotification(String taskId) {
186+
try {
187+
Task task = taskStore.get(taskId);
188+
if (task != null) {
189+
LOGGER.debug("Sending push notification for task {}", taskId);
190+
pushSender.sendNotification(task);
191+
} else {
192+
LOGGER.debug("Skipping push notification - task {} not found in TaskStore", taskId);
193+
}
194+
} catch (Exception e) {
195+
LOGGER.error("Error sending push notification for task {}", taskId, e);
196+
// Don't rethrow - we still want to distribute to ChildQueues
197+
}
198+
}
199+
143200
/**
144201
* Extracts contextId from an event.
145202
* Returns null if the event type doesn't have a contextId (e.g., Message).
@@ -155,4 +212,20 @@ private String extractContextId(Event event) {
155212
// Message and other events don't have contextId
156213
return null;
157214
}
215+
216+
/**
217+
* Checks if an event represents a final task state.
218+
*
219+
* @param event the event to check
220+
* @return true if the event represents a final state (COMPLETED, FAILED, CANCELED, REJECTED, UNKNOWN)
221+
*/
222+
private boolean isFinalEvent(Event event) {
223+
if (event instanceof Task task) {
224+
return task.getStatus() != null && task.getStatus().state() != null
225+
&& task.getStatus().state().isFinal();
226+
} else if (event instanceof TaskStatusUpdateEvent statusUpdate) {
227+
return statusUpdate.isFinal();
228+
}
229+
return false;
230+
}
158231
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package io.a2a.server.events;
2+
3+
import io.a2a.spec.Event;
4+
5+
/**
6+
* Callback interface for MainEventBusProcessor events.
7+
* <p>
8+
* This interface is primarily intended for testing, allowing tests to synchronize
9+
* with the asynchronous MainEventBusProcessor. Production code should not rely on this.
10+
* </p>
11+
* <p>
12+
* Usage in tests:
13+
* <pre>
14+
* {@code
15+
* @BeforeEach
16+
* void setUp() {
17+
* CountDownLatch latch = new CountDownLatch(3);
18+
* MainEventBusProcessor.setCallback(new MainEventBusProcessorCallback() {
19+
* public void onEventProcessed(String taskId, Event event) {
20+
* latch.countDown();
21+
* }
22+
* });
23+
* }
24+
*
25+
* @AfterEach
26+
* void tearDown() {
27+
* MainEventBusProcessor.setCallback(null); // Reset to NOOP
28+
* }
29+
* }
30+
* </pre>
31+
* </p>
32+
*/
33+
public interface MainEventBusProcessorCallback {
34+
35+
/**
36+
* Called after an event has been fully processed (persisted, notification sent, distributed to children).
37+
*
38+
* @param taskId the task ID
39+
* @param event the event that was processed
40+
*/
41+
void onEventProcessed(String taskId, Event event);
42+
43+
/**
44+
* Called when a task reaches a final state (COMPLETED, FAILED, CANCELED, REJECTED).
45+
*
46+
* @param taskId the task ID that was finalized
47+
*/
48+
void onTaskFinalized(String taskId);
49+
50+
/**
51+
* No-op implementation that does nothing.
52+
* Used as the default callback to avoid null checks.
53+
*/
54+
MainEventBusProcessorCallback NOOP = new MainEventBusProcessorCallback() {
55+
@Override
56+
public void onEventProcessed(String taskId, Event event) {
57+
// No-op
58+
}
59+
60+
@Override
61+
public void onTaskFinalized(String taskId) {
62+
// No-op
63+
}
64+
};
65+
}

server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ public class DefaultRequestHandler implements RequestHandler {
105105
private final TaskStore taskStore;
106106
private final QueueManager queueManager;
107107
private final PushNotificationConfigStore pushConfigStore;
108-
private final PushNotificationSender pushSender;
109108
private final Supplier<RequestContext.Builder> requestContextBuilder;
110109

111110
private final ConcurrentMap<String, CompletableFuture<Void>> runningAgents = new ConcurrentHashMap<>();
@@ -116,12 +115,11 @@ public class DefaultRequestHandler implements RequestHandler {
116115
@Inject
117116
public DefaultRequestHandler(AgentExecutor agentExecutor, TaskStore taskStore,
118117
QueueManager queueManager, PushNotificationConfigStore pushConfigStore,
119-
PushNotificationSender pushSender, @Internal Executor executor) {
118+
@Internal Executor executor) {
120119
this.agentExecutor = agentExecutor;
121120
this.taskStore = taskStore;
122121
this.queueManager = queueManager;
123122
this.pushConfigStore = pushConfigStore;
124-
this.pushSender = pushSender;
125123
this.executor = executor;
126124
// TODO In Python this is also a constructor parameter defaulting to this SimpleRequestContextBuilder
127125
// implementation if the parameter is null. Skip that for now, since otherwise I get CDI errors, and
@@ -143,9 +141,9 @@ void initConfig() {
143141
*/
144142
public static DefaultRequestHandler create(AgentExecutor agentExecutor, TaskStore taskStore,
145143
QueueManager queueManager, PushNotificationConfigStore pushConfigStore,
146-
PushNotificationSender pushSender, Executor executor) {
144+
Executor executor) {
147145
DefaultRequestHandler handler =
148-
new DefaultRequestHandler(agentExecutor, taskStore, queueManager, pushConfigStore, pushSender, executor);
146+
new DefaultRequestHandler(agentExecutor, taskStore, queueManager, pushConfigStore, executor);
149147
handler.agentCompletionTimeoutSeconds = 5;
150148
handler.consumptionCompletionTimeoutSeconds = 2;
151149
return handler;
@@ -280,9 +278,6 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte
280278
ResultAggregator.EventTypeAndInterrupt etai = null;
281279
EventKind kind = null; // Declare outside try block so it's in scope for return
282280
try {
283-
// Create callback for push notifications during background event processing
284-
Runnable pushNotificationCallback = () -> sendPushNotification(taskId, resultAggregator);
285-
286281
EventConsumer consumer = new EventConsumer(queue);
287282

288283
// This callback must be added before we start consuming. Otherwise,
@@ -374,9 +369,6 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte
374369
if (kind instanceof Task taskResult && !taskId.equals(taskResult.getId())) {
375370
throw new InternalError("Task ID mismatch in agent response");
376371
}
377-
378-
// Send push notification after initial return (for both blocking and non-blocking)
379-
pushNotificationCallback.run();
380372
} finally {
381373
// Remove agent from map immediately to prevent accumulation
382374
CompletableFuture<Void> agentFuture = runningAgents.remove(taskId);
@@ -442,12 +434,7 @@ public Flow.Publisher<StreamingEventKind> onMessageSendStream(
442434
}
443435

444436
}
445-
if (pushSender != null && taskId.get() != null) {
446-
EventKind latest = resultAggregator.getCurrentResult();
447-
if (latest instanceof Task latestTask) {
448-
pushSender.sendNotification(latestTask);
449-
}
450-
}
437+
// Push notifications now sent by MainEventBusProcessor after persistence
451438

452439
return true;
453440
}));
@@ -820,15 +807,6 @@ private MessageSendSetup initMessageSend(MessageSendParams params, ServerCallCon
820807
return new MessageSendSetup(taskManager, task, requestContext);
821808
}
822809

823-
private void sendPushNotification(String taskId, ResultAggregator resultAggregator) {
824-
if (pushSender != null && taskId != null) {
825-
EventKind latest = resultAggregator.getCurrentResult();
826-
if (latest instanceof Task latestTask) {
827-
pushSender.sendNotification(latestTask);
828-
}
829-
}
830-
}
831-
832810
/**
833811
* Log current thread and resource statistics for debugging.
834812
* Only logs when DEBUG level is enabled. Call this from debugger or add strategic

server-common/src/test/java/io/a2a/server/events/EventQueueTest.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import java.util.List;
1313

14+
import io.a2a.server.tasks.PushNotificationSender;
1415
import io.a2a.spec.Artifact;
1516
import io.a2a.spec.Event;
1617
import io.a2a.spec.JSONRPCError;
@@ -52,13 +53,14 @@ public class EventQueueTest {
5253
}
5354
""";
5455

56+
private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {};
5557

5658
@BeforeEach
5759
public void init() {
5860
// Set up MainEventBus and processor for production-like test environment
5961
InMemoryTaskStore taskStore = new InMemoryTaskStore();
6062
mainEventBus = new MainEventBus();
61-
mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore);
63+
mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER);
6264
EventQueueUtil.start(mainEventBusProcessor);
6365

6466
eventQueue = EventQueue.builder()

server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import io.a2a.server.tasks.InMemoryTaskStore;
1818
import io.a2a.server.tasks.MockTaskStateProvider;
19+
import io.a2a.server.tasks.PushNotificationSender;
1920
import org.junit.jupiter.api.AfterEach;
2021
import org.junit.jupiter.api.BeforeEach;
2122
import org.junit.jupiter.api.Test;
@@ -27,13 +28,14 @@ public class InMemoryQueueManagerTest {
2728
private InMemoryTaskStore taskStore;
2829
private MainEventBus mainEventBus;
2930
private MainEventBusProcessor mainEventBusProcessor;
31+
private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {};
3032

3133
@BeforeEach
3234
public void setUp() {
3335
taskStateProvider = new MockTaskStateProvider();
3436
taskStore = new InMemoryTaskStore();
3537
mainEventBus = new MainEventBus();
36-
mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore);
38+
mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER);
3739
EventQueueUtil.start(mainEventBusProcessor);
3840

3941
queueManager = new InMemoryQueueManager(taskStateProvider, mainEventBus);

server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ public class AbstractA2ARequestHandlerTest {
6767
private static final String PREFERRED_TRANSPORT = "preferred-transport";
6868
private static final String A2A_REQUESTHANDLER_TEST_PROPERTIES = "/a2a-requesthandler-test.properties";
6969

70+
private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {};
71+
7072
protected AgentExecutor executor;
7173
protected TaskStore taskStore;
7274
protected RequestHandler requestHandler;
@@ -100,19 +102,20 @@ public void cancel(RequestContext context, EventQueue eventQueue) throws JSONRPC
100102
InMemoryTaskStore inMemoryTaskStore = new InMemoryTaskStore();
101103
taskStore = inMemoryTaskStore;
102104

105+
// Create push notification components BEFORE MainEventBusProcessor
106+
httpClient = new TestHttpClient();
107+
PushNotificationConfigStore pushConfigStore = new InMemoryPushNotificationConfigStore();
108+
PushNotificationSender pushSender = new BasePushNotificationSender(pushConfigStore, httpClient);
109+
103110
// Create MainEventBus and MainEventBusProcessor (production code path)
104111
mainEventBus = new MainEventBus();
105-
mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore);
112+
mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, pushSender);
106113
EventQueueUtil.start(mainEventBusProcessor);
107114

108115
queueManager = new InMemoryQueueManager(inMemoryTaskStore, mainEventBus);
109116

110-
httpClient = new TestHttpClient();
111-
PushNotificationConfigStore pushConfigStore = new InMemoryPushNotificationConfigStore();
112-
PushNotificationSender pushSender = new BasePushNotificationSender(pushConfigStore, httpClient);
113-
114117
requestHandler = DefaultRequestHandler.create(
115-
executor, taskStore, queueManager, pushConfigStore, pushSender, internalExecutor);
118+
executor, taskStore, queueManager, pushConfigStore, internalExecutor);
116119
}
117120

118121
@AfterEach

0 commit comments

Comments
 (0)