Skip to content

Commit 564418e

Browse files
committed
#490 WIP
Signed-off-by: Jeff Mesnil <[email protected]>
1 parent 686e40a commit 564418e

File tree

5 files changed

+92
-25
lines changed

5 files changed

+92
-25
lines changed

extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStoreIntegrationTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ public void testDirectNotificationTrigger() {
8888
mockPushNotificationSender.sendNotification(testTask);
8989

9090
// Verify it was captured
91-
Queue<Task> captured = mockPushNotificationSender.getCapturedTasks();
91+
Queue<Task> captured = mockPushNotificationSender.getCapturedEvents();
9292
assertEquals(1, captured.size());
9393
assertEquals("direct-test-task", captured.peek().getId());
9494
}
@@ -151,7 +151,7 @@ public void testJpaDatabasePushNotificationConfigStoreIntegration() throws Excep
151151
boolean notificationReceived = false;
152152

153153
while (System.currentTimeMillis() < end) {
154-
if (!mockPushNotificationSender.getCapturedTasks().isEmpty()) {
154+
if (!mockPushNotificationSender.getCapturedEvents().isEmpty()) {
155155
notificationReceived = true;
156156
break;
157157
}
@@ -161,7 +161,7 @@ public void testJpaDatabasePushNotificationConfigStoreIntegration() throws Excep
161161
assertTrue(notificationReceived, "Timeout waiting for push notification.");
162162

163163
// Step 6: Verify the captured notification
164-
Queue<Task> capturedTasks = mockPushNotificationSender.getCapturedTasks();
164+
Queue<Task> capturedTasks = mockPushNotificationSender.getCapturedEvents();
165165

166166
// Verify the notification contains the correct task with artifacts
167167
Task notifiedTaskWithArtifact = capturedTasks.stream()

extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/MockPushNotificationSender.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
import java.util.Queue;
44
import java.util.concurrent.ConcurrentLinkedQueue;
55

6+
import io.a2a.spec.StreamingEventKind;
67
import jakarta.annotation.Priority;
78
import jakarta.enterprise.context.ApplicationScoped;
89
import jakarta.enterprise.inject.Alternative;
910

1011
import io.a2a.server.tasks.PushNotificationSender;
11-
import io.a2a.spec.Task;
1212

1313
/**
1414
* Mock implementation of PushNotificationSender for integration testing.
@@ -19,18 +19,18 @@
1919
@Priority(100)
2020
public class MockPushNotificationSender implements PushNotificationSender {
2121

22-
private final Queue<Task> capturedTasks = new ConcurrentLinkedQueue<>();
22+
private final Queue<StreamingEventKind> capturedEvents = new ConcurrentLinkedQueue<>();
2323

2424
@Override
25-
public void sendNotification(Task task) {
26-
capturedTasks.add(task);
25+
public void sendNotification(StreamingEventKind kind) {
26+
capturedEvents.add(kind);
2727
}
2828

29-
public Queue<Task> getCapturedTasks() {
30-
return capturedTasks;
29+
public Queue<StreamingEventKind> getCapturedEvents() {
30+
return capturedEvents;
3131
}
3232

3333
public void clear() {
34-
capturedTasks.clear();
34+
capturedEvents.clear();
3535
}
3636
}

server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,15 @@
33
import static io.a2a.client.http.A2AHttpClient.APPLICATION_JSON;
44
import static io.a2a.client.http.A2AHttpClient.CONTENT_TYPE;
55
import static io.a2a.common.A2AHeaders.X_A2A_NOTIFICATION_TOKEN;
6+
import static io.a2a.spec.Message.MESSAGE;
7+
import static io.a2a.spec.Task.TASK;
8+
import static io.a2a.spec.TaskArtifactUpdateEvent.ARTIFACT_UPDATE;
9+
import static io.a2a.spec.TaskStatusUpdateEvent.STATUS_UPDATE;
10+
11+
import io.a2a.spec.Message;
12+
import io.a2a.spec.StreamingEventKind;
13+
import io.a2a.spec.TaskArtifactUpdateEvent;
14+
import io.a2a.spec.TaskStatusUpdateEvent;
615
import jakarta.enterprise.context.ApplicationScoped;
716
import jakarta.inject.Inject;
817

@@ -42,34 +51,45 @@ public BasePushNotificationSender(PushNotificationConfigStore configStore, A2AHt
4251
}
4352

4453
@Override
45-
public void sendNotification(Task task) {
46-
List<PushNotificationConfig> pushConfigs = configStore.getInfo(task.getId());
54+
public void sendNotification(StreamingEventKind kind) {
55+
String taskId = switch (kind.getKind()) {
56+
case TASK -> ((Task) kind).getId();
57+
case MESSAGE -> ((Message)kind).getTaskId();
58+
case STATUS_UPDATE -> ((TaskStatusUpdateEvent)kind).getTaskId();
59+
case ARTIFACT_UPDATE -> ((TaskArtifactUpdateEvent)kind).getTaskId();
60+
default -> null;
61+
};
62+
if (taskId == null) {
63+
return;
64+
}
65+
66+
List<PushNotificationConfig> pushConfigs = configStore.getInfo(taskId);
4767
if (pushConfigs == null || pushConfigs.isEmpty()) {
4868
return;
4969
}
5070

5171
List<CompletableFuture<Boolean>> dispatchResults = pushConfigs
5272
.stream()
53-
.map(pushConfig -> dispatch(task, pushConfig))
73+
.map(pushConfig -> dispatch(kind, pushConfig))
5474
.toList();
5575
CompletableFuture<Void> allFutures = CompletableFuture.allOf(dispatchResults.toArray(new CompletableFuture[0]));
5676
CompletableFuture<Boolean> dispatchResult = allFutures.thenApply(v -> dispatchResults.stream()
5777
.allMatch(CompletableFuture::join));
5878
try {
5979
boolean allSent = dispatchResult.get();
60-
if (! allSent) {
61-
LOGGER.warn("Some push notifications failed to send for taskId: " + task.getId());
80+
if (!allSent) {
81+
LOGGER.warn("Some push notifications failed to send for taskId: " + taskId);
6282
}
6383
} catch (InterruptedException | ExecutionException e) {
64-
LOGGER.warn("Some push notifications failed to send for taskId " + task.getId() + ": {}", e.getMessage(), e);
84+
LOGGER.warn("Some push notifications failed to send for taskId " + taskId + ": {}", e.getMessage(), e);
6585
}
6686
}
6787

68-
private CompletableFuture<Boolean> dispatch(Task task, PushNotificationConfig pushInfo) {
69-
return CompletableFuture.supplyAsync(() -> dispatchNotification(task, pushInfo));
88+
private CompletableFuture<Boolean> dispatch(StreamingEventKind kind, PushNotificationConfig pushInfo) {
89+
return CompletableFuture.supplyAsync(() -> dispatchNotification(kind, pushInfo));
7090
}
7191

72-
private boolean dispatchNotification(Task task, PushNotificationConfig pushInfo) {
92+
private boolean dispatchNotification(StreamingEventKind kind, PushNotificationConfig pushInfo) {
7393
String url = pushInfo.url();
7494
String token = pushInfo.token();
7595

@@ -80,7 +100,7 @@ private boolean dispatchNotification(Task task, PushNotificationConfig pushInfo)
80100

81101
String body;
82102
try {
83-
body = Utils.OBJECT_MAPPER.writeValueAsString(task);
103+
body = Utils.OBJECT_MAPPER.writeValueAsString(kind);
84104
} catch (JsonProcessingException e) {
85105
LOGGER.debug("Error writing value as string: {}", e.getMessage(), e);
86106
return false;
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
package io.a2a.server.tasks;
22

3-
import io.a2a.spec.Task;
3+
import io.a2a.spec.StreamingEventKind;
44

55
/**
66
* Interface for sending push notifications for tasks.
77
*/
88
public interface PushNotificationSender {
99

1010
/**
11-
* Sends a push notification containing the latest task state.
12-
* @param task the task
11+
* Sends a push notification with a payload related to the task.
12+
* @param kind the payload to push
1313
*/
14-
void sendNotification(Task task);
14+
void sendNotification(StreamingEventKind kind);
1515
}

server-common/src/test/java/io/a2a/server/tasks/PushNotificationSenderTest.java

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
import java.util.concurrent.TimeUnit;
1717
import java.util.function.Consumer;
1818

19+
import io.a2a.spec.Message;
20+
import io.a2a.spec.Part;
21+
import io.a2a.spec.StreamingEventKind;
22+
import io.a2a.spec.TextPart;
1923
import org.junit.jupiter.api.BeforeEach;
2024
import org.junit.jupiter.api.Test;
2125

@@ -67,6 +71,7 @@ class TestPostBuilder implements A2AHttpClient.PostBuilder {
6771
@Override
6872
public PostBuilder body(String body) {
6973
this.body = body;
74+
System.out.println("body = " + body);
7075
return this;
7176
}
7277

@@ -80,6 +85,7 @@ public A2AHttpResponse post() throws IOException, InterruptedException {
8085
Task task = Utils.OBJECT_MAPPER.readValue(body, Task.TYPE_REFERENCE);
8186
tasks.add(task);
8287
urls.add(url);
88+
System.out.println(requestHeaders);
8389
headers.add(new java.util.HashMap<>(requestHeaders));
8490

8591
return new A2AHttpResponse() {
@@ -95,7 +101,7 @@ public boolean success() {
95101

96102
@Override
97103
public String body() {
98-
return "";
104+
return body;
99105
}
100106
};
101107
} finally {
@@ -316,4 +322,45 @@ public void testSendNotificationHttpError() {
316322
// Verify no tasks were successfully processed due to the error
317323
assertEquals(0, testHttpClient.tasks.size());
318324
}
325+
326+
@Test
327+
public void testSendNotificationWithMessage() throws InterruptedException {
328+
String taskId = "task_send_notification_with_message";
329+
Task taskData = createSampleTask(taskId, TaskState.COMPLETED);
330+
PushNotificationConfig config = createSamplePushConfig("http://notify.me/here", "cfg1", "unique_token");
331+
332+
// Set up the configuration in the store
333+
configStore.setInfo(taskId, config);
334+
335+
// Set up latch to wait for async completion
336+
testHttpClient.latch = new CountDownLatch(1);
337+
338+
Message message = new Message.Builder()
339+
.taskId(taskId)
340+
.messageId("task_push_notification_message")
341+
.parts(Collections.singletonList(new TextPart("Message for task " + taskId)))
342+
.role(Message.Role.USER)
343+
.build();
344+
sender.sendNotification(message);
345+
346+
// Wait for the async operation to complete
347+
assertTrue(testHttpClient.latch.await(5, TimeUnit.SECONDS), "HTTP call should complete within 5 seconds");
348+
349+
// Verify the task was sent via HTTP
350+
assertEquals(1, testHttpClient.tasks.size());
351+
Task sentTask = testHttpClient.tasks.get(0);
352+
assertEquals(taskData.getId(), sentTask.getId());
353+
354+
// Verify that the X-A2A-Notification-Token header is sent with the correct token
355+
assertEquals(1, testHttpClient.headers.size());
356+
Map<String, String> sentHeaders = testHttpClient.headers.get(0);
357+
assertEquals(2, sentHeaders.size());
358+
assertTrue(sentHeaders.containsKey(A2AHeaders.X_A2A_NOTIFICATION_TOKEN));
359+
assertEquals(config.token(), sentHeaders.get(A2AHeaders.X_A2A_NOTIFICATION_TOKEN));
360+
// Content-Type header should always be present
361+
assertTrue(sentHeaders.containsKey(CONTENT_TYPE));
362+
assertEquals(APPLICATION_JSON, sentHeaders.get(CONTENT_TYPE));
363+
364+
}
365+
319366
}

0 commit comments

Comments
 (0)