Skip to content

Commit c254a4a

Browse files
committed
feat: Support multiple payload types in push notifications
Expand PushNotificationSender to support all StreamingEventKind payload types as defined in the A2A specification, not just Task objects. Fixes: #490 Signed-off-by: Jeff Mesnil <[email protected]>
1 parent aa41de5 commit c254a4a

File tree

9 files changed

+268
-37
lines changed

9 files changed

+268
-37
lines changed

extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStoreIntegrationTest.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import io.a2a.spec.GetTaskPushNotificationConfigParams;
2626
import io.a2a.spec.Message;
2727
import io.a2a.spec.PushNotificationConfig;
28+
import io.a2a.spec.StreamingEventKind;
2829
import io.a2a.spec.Task;
2930
import io.a2a.spec.TaskPushNotificationConfig;
3031
import io.a2a.spec.TextPart;
@@ -88,9 +89,9 @@ public void testDirectNotificationTrigger() {
8889
mockPushNotificationSender.sendNotification(testTask);
8990

9091
// Verify it was captured
91-
Queue<Task> captured = mockPushNotificationSender.getCapturedTasks();
92+
Queue<StreamingEventKind> captured = mockPushNotificationSender.getCapturedEvents();
9293
assertEquals(1, captured.size());
93-
assertEquals("direct-test-task", captured.peek().getId());
94+
assertEquals("direct-test-task", ((Task)captured.peek()).getId());
9495
}
9596

9697
@Test
@@ -151,7 +152,7 @@ public void testJpaDatabasePushNotificationConfigStoreIntegration() throws Excep
151152
boolean notificationReceived = false;
152153

153154
while (System.currentTimeMillis() < end) {
154-
if (!mockPushNotificationSender.getCapturedTasks().isEmpty()) {
155+
if (!mockPushNotificationSender.getCapturedEvents().isEmpty()) {
155156
notificationReceived = true;
156157
break;
157158
}
@@ -161,10 +162,12 @@ public void testJpaDatabasePushNotificationConfigStoreIntegration() throws Excep
161162
assertTrue(notificationReceived, "Timeout waiting for push notification.");
162163

163164
// Step 6: Verify the captured notification
164-
Queue<Task> capturedTasks = mockPushNotificationSender.getCapturedTasks();
165+
Queue<StreamingEventKind> capturedTasks = mockPushNotificationSender.getCapturedEvents();
165166

166167
// Verify the notification contains the correct task with artifacts
167168
Task notifiedTaskWithArtifact = capturedTasks.stream()
169+
.filter(e -> Task.TASK.equals(e.getKind()))
170+
.map(e -> (Task)e)
168171
.filter(t -> taskId.equals(t.getId()) && t.getArtifacts() != null && t.getArtifacts().size() > 0)
169172
.findFirst()
170173
.orElse(null);

extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/MockPushNotificationSender.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import jakarta.enterprise.inject.Alternative;
99

1010
import io.a2a.server.tasks.PushNotificationSender;
11-
import io.a2a.spec.Task;
11+
import io.a2a.spec.StreamingEventKind;
1212

1313
/**
1414
* Mock implementation of PushNotificationSender for integration testing.
@@ -19,18 +19,18 @@
1919
@Priority(100)
2020
public class MockPushNotificationSender implements PushNotificationSender {
2121

22-
private final Queue<Task> capturedTasks = new ConcurrentLinkedQueue<>();
22+
private final Queue<StreamingEventKind> capturedEvents = new ConcurrentLinkedQueue<>();
2323

2424
@Override
25-
public void sendNotification(Task task) {
26-
capturedTasks.add(task);
25+
public void sendNotification(StreamingEventKind kind) {
26+
capturedEvents.add(kind);
2727
}
2828

29-
public Queue<Task> getCapturedTasks() {
30-
return capturedTasks;
29+
public Queue<StreamingEventKind> getCapturedEvents() {
30+
return capturedEvents;
3131
}
3232

3333
public void clear() {
34-
capturedTasks.clear();
34+
capturedEvents.clear();
3535
}
3636
}

server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
import static io.a2a.client.http.A2AHttpClient.APPLICATION_JSON;
44
import static io.a2a.client.http.A2AHttpClient.CONTENT_TYPE;
55
import static io.a2a.common.A2AHeaders.X_A2A_NOTIFICATION_TOKEN;
6+
import static io.a2a.spec.Message.MESSAGE;
7+
import static io.a2a.spec.Task.TASK;
8+
import static io.a2a.spec.TaskArtifactUpdateEvent.ARTIFACT_UPDATE;
9+
import static io.a2a.spec.TaskStatusUpdateEvent.STATUS_UPDATE;
610

711
import java.io.IOException;
812
import java.util.List;
@@ -18,8 +22,12 @@
1822

1923
import io.a2a.client.http.A2AHttpClient;
2024
import io.a2a.client.http.JdkA2AHttpClient;
25+
import io.a2a.spec.Message;
2126
import io.a2a.spec.PushNotificationConfig;
27+
import io.a2a.spec.StreamingEventKind;
2228
import io.a2a.spec.Task;
29+
import io.a2a.spec.TaskArtifactUpdateEvent;
30+
import io.a2a.spec.TaskStatusUpdateEvent;
2331
import io.a2a.util.Utils;
2432

2533
@ApplicationScoped
@@ -42,34 +50,44 @@ public BasePushNotificationSender(PushNotificationConfigStore configStore, A2AHt
4250
}
4351

4452
@Override
45-
public void sendNotification(Task task) {
46-
List<PushNotificationConfig> pushConfigs = configStore.getInfo(task.getId());
53+
public void sendNotification(StreamingEventKind kind) {
54+
String taskId = switch (kind.getKind()) {
55+
case TASK -> ((Task)kind).getId();
56+
case MESSAGE -> ((Message)kind).getTaskId();
57+
case STATUS_UPDATE -> ((TaskStatusUpdateEvent)kind).getTaskId();
58+
case ARTIFACT_UPDATE -> ((TaskArtifactUpdateEvent)kind).getTaskId();
59+
default -> null;
60+
};
61+
if (taskId == null) {
62+
return;
63+
}
64+
List<PushNotificationConfig> pushConfigs = configStore.getInfo(taskId);
4765
if (pushConfigs == null || pushConfigs.isEmpty()) {
4866
return;
4967
}
5068

5169
List<CompletableFuture<Boolean>> dispatchResults = pushConfigs
5270
.stream()
53-
.map(pushConfig -> dispatch(task, pushConfig))
71+
.map(pushConfig -> dispatch(kind, pushConfig))
5472
.toList();
5573
CompletableFuture<Void> allFutures = CompletableFuture.allOf(dispatchResults.toArray(new CompletableFuture[0]));
5674
CompletableFuture<Boolean> dispatchResult = allFutures.thenApply(v -> dispatchResults.stream()
5775
.allMatch(CompletableFuture::join));
5876
try {
5977
boolean allSent = dispatchResult.get();
6078
if (! allSent) {
61-
LOGGER.warn("Some push notifications failed to send for taskId: " + task.getId());
79+
LOGGER.warn("Some push notifications failed to send for taskId: " + taskId);
6280
}
6381
} catch (InterruptedException | ExecutionException e) {
64-
LOGGER.warn("Some push notifications failed to send for taskId " + task.getId() + ": {}", e.getMessage(), e);
82+
LOGGER.warn("Some push notifications failed to send for taskId " + taskId + ": {}", e.getMessage(), e);
6583
}
6684
}
6785

68-
private CompletableFuture<Boolean> dispatch(Task task, PushNotificationConfig pushInfo) {
69-
return CompletableFuture.supplyAsync(() -> dispatchNotification(task, pushInfo));
86+
private CompletableFuture<Boolean> dispatch(StreamingEventKind kind, PushNotificationConfig pushInfo) {
87+
return CompletableFuture.supplyAsync(() -> dispatchNotification(kind, pushInfo));
7088
}
7189

72-
private boolean dispatchNotification(Task task, PushNotificationConfig pushInfo) {
90+
private boolean dispatchNotification(StreamingEventKind kind, PushNotificationConfig pushInfo) {
7391
String url = pushInfo.url();
7492
String token = pushInfo.token();
7593

@@ -80,7 +98,7 @@ private boolean dispatchNotification(Task task, PushNotificationConfig pushInfo)
8098

8199
String body;
82100
try {
83-
body = Utils.marshalFrom(task);
101+
body = Utils.marshalFrom(kind);
84102
} catch (JsonProcessingException e) {
85103
LOGGER.debug("Error writing value as string: {}", e.getMessage(), e);
86104
return false;
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
package io.a2a.server.tasks;
22

3-
import io.a2a.spec.Task;
3+
import io.a2a.spec.StreamingEventKind;
44

55
/**
66
* Interface for sending push notifications for tasks.
77
*/
88
public interface PushNotificationSender {
99

1010
/**
11-
* Sends a push notification containing the latest task state.
12-
* @param task the task
11+
* Sends a push notification containing payload about a task.
12+
* @param kind the payload to push
1313
*/
14-
void sendNotification(Task task);
14+
void sendNotification(StreamingEventKind kind);
1515
}

server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import jakarta.enterprise.context.Dependent;
1818

19-
import com.fasterxml.jackson.databind.JsonNode;
2019
import io.quarkus.arc.profile.IfBuildProfile;
2120
import org.junit.jupiter.api.AfterEach;
2221
import org.junit.jupiter.api.Assertions;
@@ -200,10 +199,7 @@ public PostBuilder body(String body) {
200199

201200
@Override
202201
public A2AHttpResponse post() throws IOException, InterruptedException {
203-
JsonNode root = Utils.OBJECT_MAPPER.readTree(body);
204-
// This will need to be updated for #490 to unmarshall based on the kind of payload
205-
JsonNode taskNode = root.elements().next();
206-
Task task = Utils.OBJECT_MAPPER.treeToValue(taskNode, Task.TYPE_REFERENCE);
202+
Task task = Utils.unmarshalStreamingEventKindFrom(body);
207203
tasks.add(task);
208204
try {
209205
return new A2AHttpResponse() {

0 commit comments

Comments
 (0)