diff --git a/reference-impl/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java b/reference-impl/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java index acc363868..1500dfe4c 100644 --- a/reference-impl/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java +++ b/reference-impl/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java @@ -8,7 +8,6 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.function.Function; -import com.fasterxml.jackson.databind.JsonNode; import jakarta.enterprise.inject.Instance; import jakarta.inject.Inject; import jakarta.inject.Singleton; @@ -17,8 +16,10 @@ import com.fasterxml.jackson.core.JsonParseException; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.io.JsonEOFException; +import com.fasterxml.jackson.databind.JsonNode; import io.a2a.server.ExtendedAgentCard; import io.a2a.server.requesthandlers.JSONRPCHandler; +import io.a2a.server.util.async.Internal; import io.a2a.spec.AgentCard; import io.a2a.spec.CancelTaskRequest; import io.a2a.spec.GetTaskPushNotificationConfigRequest; @@ -44,7 +45,6 @@ import io.a2a.spec.TaskResubscriptionRequest; import io.a2a.spec.UnsupportedOperationError; import io.a2a.util.Utils; -import io.a2a.server.util.async.Internal; import io.quarkus.vertx.web.Body; import io.quarkus.vertx.web.ReactiveRoutes; import io.quarkus.vertx.web.Route; @@ -69,6 +69,7 @@ public class A2AServerRoutes { Instance extendedAgentCard; // Hook so testing can wait until the MultiSseSupport is subscribed. + // Without this we get intermittent failures private static volatile Runnable streamingMultiSseSupportSubscribedRunnable; @Inject diff --git a/reference-impl/src/test/java/io/a2a/server/apps/quarkus/A2ATestRoutes.java b/reference-impl/src/test/java/io/a2a/server/apps/quarkus/A2ATestRoutes.java new file mode 100644 index 000000000..bc6a342a6 --- /dev/null +++ b/reference-impl/src/test/java/io/a2a/server/apps/quarkus/A2ATestRoutes.java @@ -0,0 +1,149 @@ +package io.a2a.server.apps.quarkus; + +import static io.vertx.core.http.HttpHeaders.CONTENT_TYPE; +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; +import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN; + +import java.util.concurrent.atomic.AtomicInteger; + +import jakarta.annotation.PostConstruct; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import io.a2a.server.apps.common.TestUtilsBean; +import io.a2a.spec.Task; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskStatusUpdateEvent; +import io.a2a.util.Utils; +import io.quarkus.vertx.web.Body; +import io.quarkus.vertx.web.Param; +import io.quarkus.vertx.web.Route; +import io.vertx.ext.web.RoutingContext; + +/** + * Exposes the {@link TestUtilsBean} via REST using Quarkus Reactive Routes + */ +@Singleton +public class A2ATestRoutes { + @Inject + TestUtilsBean testUtilsBean; + + @Inject + A2AServerRoutes a2AServerRoutes; + + AtomicInteger streamingSubscribedCount = new AtomicInteger(0); + + @PostConstruct + public void init() { + A2AServerRoutes.setStreamingMultiSseSupportSubscribedRunnable(() -> streamingSubscribedCount.incrementAndGet()); + } + + + @Route(path = "/test/task", methods = {Route.HttpMethod.POST}, consumes = {APPLICATION_JSON}, type = Route.HandlerType.BLOCKING) + public void saveTask(@Body String body, RoutingContext rc) { + try { + Task task = Utils.OBJECT_MAPPER.readValue(body, Task.class); + testUtilsBean.saveTask(task); + rc.response() + .setStatusCode(200) + .end(); + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/task/:taskId", methods = {Route.HttpMethod.GET}, produces = {APPLICATION_JSON}, type = Route.HandlerType.BLOCKING) + public void getTask(@Param String taskId, RoutingContext rc) { + try { + Task task = testUtilsBean.getTask(taskId); + if (task == null) { + rc.response() + .setStatusCode(404) + .end(); + return; + } + rc.response() + .setStatusCode(200) + .putHeader(CONTENT_TYPE, APPLICATION_JSON) + .end(Utils.OBJECT_MAPPER.writeValueAsString(task)); + + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/task/:taskId", methods = {Route.HttpMethod.DELETE}, type = Route.HandlerType.BLOCKING) + public void deleteTask(@Param String taskId, RoutingContext rc) { + try { + Task task = testUtilsBean.getTask(taskId); + if (task == null) { + rc.response() + .setStatusCode(404) + .end(); + return; + } + testUtilsBean.deleteTask(taskId); + rc.response() + .setStatusCode(200) + .end(); + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/queue/ensure/:taskId", methods = {Route.HttpMethod.POST}) + public void ensureTaskQueue(@Param String taskId, RoutingContext rc) { + try { + testUtilsBean.ensureQueue(taskId); + rc.response() + .setStatusCode(200) + .end(); + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/queue/enqueueTaskStatusUpdateEvent/:taskId", methods = {Route.HttpMethod.POST}) + public void enqueueTaskStatusUpdateEvent(@Param String taskId, @Body String body, RoutingContext rc) { + + try { + TaskStatusUpdateEvent event = Utils.OBJECT_MAPPER.readValue(body, TaskStatusUpdateEvent.class); + testUtilsBean.enqueueEvent(taskId, event); + rc.response() + .setStatusCode(200) + .end(); + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/queue/enqueueTaskArtifactUpdateEvent/:taskId", methods = {Route.HttpMethod.POST}) + public void enqueueTaskArtifactUpdateEvent(@Param String taskId, @Body String body, RoutingContext rc) { + + try { + TaskArtifactUpdateEvent event = Utils.OBJECT_MAPPER.readValue(body, TaskArtifactUpdateEvent.class); + testUtilsBean.enqueueEvent(taskId, event); + rc.response() + .setStatusCode(200) + .end(); + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/streamingSubscribedCount", methods = {Route.HttpMethod.GET}, produces = {TEXT_PLAIN}) + public void getStreamingSubscribedCount(RoutingContext rc) { + rc.response() + .setStatusCode(200) + .end(String.valueOf(streamingSubscribedCount.get())); + } + + private void errorResponse(Throwable t, RoutingContext rc) { + t.printStackTrace(); + rc.response() + .setStatusCode(500) + .putHeader(CONTENT_TYPE, TEXT_PLAIN) + .end(); + } + +} diff --git a/reference-impl/src/test/java/io/a2a/server/apps/quarkus/QuarkusA2AServerTest.java b/reference-impl/src/test/java/io/a2a/server/apps/quarkus/QuarkusA2AServerTest.java index dab1954ba..f9ed48643 100644 --- a/reference-impl/src/test/java/io/a2a/server/apps/quarkus/QuarkusA2AServerTest.java +++ b/reference-impl/src/test/java/io/a2a/server/apps/quarkus/QuarkusA2AServerTest.java @@ -1,37 +1,12 @@ package io.a2a.server.apps.quarkus; -import jakarta.inject.Inject; - import io.a2a.server.apps.common.AbstractA2AServerTest; -import io.a2a.server.events.InMemoryQueueManager; -import io.a2a.server.tasks.TaskStore; import io.quarkus.test.junit.QuarkusTest; @QuarkusTest public class QuarkusA2AServerTest extends AbstractA2AServerTest { - @Inject - TaskStore taskStore; - - @Inject - InMemoryQueueManager queueManager; - public QuarkusA2AServerTest() { super(8081); } - - @Override - protected TaskStore getTaskStore() { - return taskStore; - } - - @Override - protected InMemoryQueueManager getQueueManager() { - return queueManager; - } - - @Override - protected void setStreamingSubscribedRunnable(Runnable runnable) { - A2AServerRoutes.setStreamingMultiSseSupportSubscribedRunnable(runnable); - } } diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java index 0992a157f..3a20660e8 100644 --- a/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java @@ -9,24 +9,25 @@ import static org.wildfly.common.Assert.assertTrue; import java.io.EOFException; +import java.io.IOException; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Stream; import jakarta.ws.rs.core.MediaType; import com.fasterxml.jackson.core.JsonProcessingException; -import io.a2a.server.events.InMemoryQueueManager; -import io.a2a.server.tasks.TaskStore; import io.a2a.spec.AgentCard; import io.a2a.spec.Artifact; import io.a2a.spec.CancelTaskRequest; @@ -71,6 +72,10 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +/** + * This test requires doing some work on the server to add/get/delete tasks, and enqueue events. This is exposed via REST, + * which delegates to {@link TestUtilsBean}. + */ public abstract class AbstractA2AServerTest { private static final Task MINIMAL_TASK = new Task.Builder() @@ -102,6 +107,7 @@ public abstract class AbstractA2AServerTest { .role(Message.Role.AGENT) .parts(new TextPart("test message")) .build(); + public static final String APPLICATION_JSON = "application/json"; private final int serverPort; @@ -110,16 +116,30 @@ protected AbstractA2AServerTest(int serverPort) { } @Test - public void testGetTaskSuccess() { + public void testTaskStoreMethodsSanityTest() throws Exception { + Task task = new Task.Builder(MINIMAL_TASK).id("abcde").build(); + saveTaskInTaskStore(task); + Task saved = getTaskFromTaskStore(task.getId()); + assertEquals(task.getId(), saved.getId()); + assertEquals(task.getContextId(), saved.getContextId()); + assertEquals(task.getStatus().state(), saved.getStatus().state()); + + deleteTaskInTaskStore(task.getId()); + Task saved2 = getTaskFromTaskStore(task.getId()); + assertNull(saved2); + } + + @Test + public void testGetTaskSuccess() throws Exception { testGetTask(); } - private void testGetTask() { + private void testGetTask() throws Exception { testGetTask(null); } - private void testGetTask(String mediaType) { - getTaskStore().save(MINIMAL_TASK); + private void testGetTask(String mediaType) throws Exception { + saveTaskInTaskStore(MINIMAL_TASK); try { GetTaskRequest request = new GetTaskRequest("1", new TaskQueryParams(MINIMAL_TASK.getId())); RequestSpecification requestSpecification = RestAssured.given() @@ -140,15 +160,14 @@ private void testGetTask(String mediaType) { assertEquals("session-xyz", response.getResult().getContextId()); assertEquals(TaskState.SUBMITTED, response.getResult().getStatus().state()); assertNull(response.getError()); - } catch (Exception e) { } finally { - getTaskStore().delete(MINIMAL_TASK.getId()); + deleteTaskInTaskStore(MINIMAL_TASK.getId()); } } @Test - public void testGetTaskNotFound() { - assertTrue(getTaskStore().get("non-existent-task") == null); + public void testGetTaskNotFound() throws Exception { + assertTrue(getTaskFromTaskStore("non-existent-task") == null); GetTaskRequest request = new GetTaskRequest("1", new TaskQueryParams("non-existent-task")); GetTaskResponse response = given() .contentType(MediaType.APPLICATION_JSON) @@ -167,8 +186,8 @@ public void testGetTaskNotFound() { } @Test - public void testCancelTaskSuccess() { - getTaskStore().save(CANCEL_TASK); + public void testCancelTaskSuccess() throws Exception { + saveTaskInTaskStore(CANCEL_TASK); try { CancelTaskRequest request = new CancelTaskRequest("1", new TaskIdParams(CANCEL_TASK.getId())); CancelTaskResponse response = given() @@ -188,13 +207,13 @@ public void testCancelTaskSuccess() { assertEquals(TaskState.CANCELED, task.getStatus().state()); } catch (Exception e) { } finally { - getTaskStore().delete(CANCEL_TASK.getId()); + deleteTaskInTaskStore(CANCEL_TASK.getId()); } } @Test - public void testCancelTaskNotSupported() { - getTaskStore().save(CANCEL_TASK_NOT_SUPPORTED); + public void testCancelTaskNotSupported() throws Exception { + saveTaskInTaskStore(CANCEL_TASK_NOT_SUPPORTED); try { CancelTaskRequest request = new CancelTaskRequest("1", new TaskIdParams(CANCEL_TASK_NOT_SUPPORTED.getId())); CancelTaskResponse response = given() @@ -213,7 +232,7 @@ public void testCancelTaskNotSupported() { assertEquals(new UnsupportedOperationError().getCode(), response.getError().getCode()); } catch (Exception e) { } finally { - getTaskStore().delete(CANCEL_TASK_NOT_SUPPORTED.getId()); + deleteTaskInTaskStore(CANCEL_TASK_NOT_SUPPORTED.getId()); } } @@ -238,8 +257,8 @@ public void testCancelTaskNotFound() { } @Test - public void testSendMessageNewMessageSuccess() { - assertTrue(getTaskStore().get(MINIMAL_TASK.getId()) == null); + public void testSendMessageNewMessageSuccess() throws Exception { + assertTrue(getTaskFromTaskStore(MINIMAL_TASK.getId()) == null); Message message = new Message.Builder(MESSAGE) .taskId(MINIMAL_TASK.getId()) .contextId(MINIMAL_TASK.getContextId()) @@ -264,8 +283,8 @@ public void testSendMessageNewMessageSuccess() { } @Test - public void testSendMessageExistingTaskSuccess() { - getTaskStore().save(MINIMAL_TASK); + public void testSendMessageExistingTaskSuccess() throws Exception { + saveTaskInTaskStore(MINIMAL_TASK); try { Message message = new Message.Builder(MESSAGE) .taskId(MINIMAL_TASK.getId()) @@ -290,13 +309,13 @@ public void testSendMessageExistingTaskSuccess() { assertEquals("test message", ((TextPart) part).getText()); } catch (Exception e) { } finally { - getTaskStore().delete(MINIMAL_TASK.getId()); + deleteTaskInTaskStore(MINIMAL_TASK.getId()); } } @Test - public void testSetPushNotificationSuccess() { - getTaskStore().save(MINIMAL_TASK); + public void testSetPushNotificationSuccess() throws Exception { + saveTaskInTaskStore(MINIMAL_TASK); try { TaskPushNotificationConfig taskPushConfig = new TaskPushNotificationConfig( @@ -318,13 +337,13 @@ public void testSetPushNotificationSuccess() { assertEquals("http://example.com", config.pushNotificationConfig().url()); } catch (Exception e) { } finally { - getTaskStore().delete(MINIMAL_TASK.getId()); + deleteTaskInTaskStore(MINIMAL_TASK.getId()); } } @Test - public void testGetPushNotificationSuccess() { - getTaskStore().save(MINIMAL_TASK); + public void testGetPushNotificationSuccess() throws Exception { + saveTaskInTaskStore(MINIMAL_TASK); try { TaskPushNotificationConfig taskPushConfig = new TaskPushNotificationConfig( @@ -360,7 +379,7 @@ public void testGetPushNotificationSuccess() { assertEquals("http://example.com", config.pushNotificationConfig().url()); } catch (Exception e) { } finally { - getTaskStore().delete(MINIMAL_TASK.getId()); + deleteTaskInTaskStore(MINIMAL_TASK.getId()); } } @@ -542,14 +561,14 @@ public void testInvalidJSONRPCRequestNonExistentMethod() { } @Test - public void testNonStreamingMethodWithAcceptHeader() { + public void testNonStreamingMethodWithAcceptHeader() throws Exception { testGetTask(MediaType.APPLICATION_JSON); } @Test - public void testSendMessageStreamExistingTaskSuccess() { - getTaskStore().save(MINIMAL_TASK); + public void testSendMessageStreamExistingTaskSuccess() throws Exception { + saveTaskInTaskStore(MINIMAL_TASK); try { Message message = new Message.Builder(MESSAGE) .taskId(MINIMAL_TASK.getId()) @@ -598,20 +617,20 @@ public void testSendMessageStreamExistingTaskSuccess() { Assertions.assertNull(errorRef.get()); } catch (Exception e) { } finally { - getTaskStore().delete(MINIMAL_TASK.getId()); + deleteTaskInTaskStore(MINIMAL_TASK.getId()); } } @Test public void testResubscribeExistingTaskSuccess() throws Exception { ExecutorService executorService = Executors.newSingleThreadExecutor(); - getTaskStore().save(MINIMAL_TASK); + saveTaskInTaskStore(MINIMAL_TASK); try { // attempting to send a streaming message instead of explicitly calling queueManager#createOrTap // does not work because after the message is sent, the queue becomes null but task resubscription // requires the queue to still be active - getQueueManager().createOrTap(MINIMAL_TASK.getId()); + ensureQueueForTask(MINIMAL_TASK.getId()); CountDownLatch taskResubscriptionRequestSent = new CountDownLatch(1); CountDownLatch taskResubscriptionResponseReceived = new CountDownLatch(2); @@ -622,7 +641,8 @@ public void testResubscribeExistingTaskSuccess() throws Exception { TaskResubscriptionRequest taskResubscriptionRequest = new TaskResubscriptionRequest("1", new TaskIdParams(MINIMAL_TASK.getId())); // Count down the latch when the MultiSseSupport on the server has started subscribing - setStreamingSubscribedRunnable(taskResubscriptionRequestSent::countDown); + awaitStreamingSubscription() + .whenComplete((unused, throwable) -> taskResubscriptionRequestSent.countDown()); CompletableFuture>> responseFuture = initialiseStreamingRequest(taskResubscriptionRequest, null); @@ -631,7 +651,6 @@ public void testResubscribeExistingTaskSuccess() throws Exception { responseFuture.thenAccept(response -> { if (response.statusCode() != 200) { - //errorRef.set(new IllegalStateException("Status code was " + response.statusCode())); throw new IllegalStateException("Status code was " + response.statusCode()); } try { @@ -682,7 +701,7 @@ public void testResubscribeExistingTaskSuccess() throws Exception { .build()); for (Event event : events) { - getQueueManager().get(MINIMAL_TASK.getId()).enqueueEvent(event); + enqueueEventOnServer(event); } } catch (InterruptedException e) { Thread.currentThread().interrupt(); @@ -710,8 +729,8 @@ public void testResubscribeExistingTaskSuccess() throws Exception { assertEquals(TaskState.COMPLETED, taskStatusUpdateEvent.getStatus().state()); assertNotNull(taskStatusUpdateEvent.getStatus().timestamp()); } finally { - setStreamingSubscribedRunnable(null); - getTaskStore().delete(MINIMAL_TASK.getId()); + //setStreamingSubscribedRunnable(null); + deleteTaskInTaskStore(MINIMAL_TASK.getId()); executorService.shutdown(); if (!executorService.awaitTermination(10, TimeUnit.SECONDS)) { executorService.shutdownNow(); @@ -860,10 +879,9 @@ private CompletableFuture>> initialiseStreamingReque // Create the request HttpRequest.Builder builder = HttpRequest.newBuilder() - .uri(URI.create("http://localhost:" + serverPort + - "/")) + .uri(URI.create("http://localhost:" + serverPort + "/")) .POST(HttpRequest.BodyPublishers.ofString(Utils.OBJECT_MAPPER.writeValueAsString(request))) - .header("Content-Type", "application/json"); + .header("Content-Type", APPLICATION_JSON); if (mediaType != null) { builder.header("Accept", mediaType); } @@ -874,11 +892,136 @@ private CompletableFuture>> initialiseStreamingReque return client.sendAsync(httpRequest, HttpResponse.BodyHandlers.ofLines()); } - protected abstract TaskStore getTaskStore(); + protected void saveTaskInTaskStore(Task task) throws Exception { + HttpClient client = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .build(); + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + serverPort + "/test/task")) + .POST(HttpRequest.BodyPublishers.ofString(Utils.OBJECT_MAPPER.writeValueAsString(task))) + .header("Content-Type", APPLICATION_JSON) + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + if (response.statusCode() != 200) { + throw new RuntimeException(response.statusCode() + ": Creating task failed! " + response.body()); + } + } + + protected Task getTaskFromTaskStore(String taskId) throws Exception { + HttpClient client = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .build(); + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + serverPort + "/test/task/" + taskId)) + .GET() + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + if (response.statusCode() == 404) { + return null; + } + if (response.statusCode() != 200) { + throw new RuntimeException(response.statusCode() + ": Creating task failed! " + response.body()); + } + return Utils.OBJECT_MAPPER.readValue(response.body(), Task.TYPE_REFERENCE); + } + + protected void deleteTaskInTaskStore(String taskId) throws Exception { + HttpClient client = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .build(); + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create(("http://localhost:" + serverPort + "/test/task/" + taskId))) + .DELETE() + .build(); + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + if (response.statusCode() != 200) { + throw new RuntimeException(response.statusCode() + ": Deleting task failed!" + response.body()); + } + } + + protected void ensureQueueForTask(String taskId) throws Exception { + HttpClient client = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .build(); + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + serverPort + "/test/queue/ensure/" + taskId)) + .POST(HttpRequest.BodyPublishers.noBody()) + .build(); + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + if (response.statusCode() != 200) { + throw new RuntimeException(response.statusCode() + ": Deleting task failed!" + response.body()); + } + } + + protected void enqueueEventOnServer(Event event) throws Exception { + String path; + if (event instanceof TaskArtifactUpdateEvent e) { + path = "test/queue/enqueueTaskArtifactUpdateEvent/" + e.getTaskId(); + } else if (event instanceof TaskStatusUpdateEvent e) { + path = "test/queue/enqueueTaskStatusUpdateEvent/" + e.getTaskId(); + } else { + throw new RuntimeException("Unknown event type " + event.getClass() + ". If you need the ability to" + + " handle more types, please add the REST endpoints."); + } + HttpClient client = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .build(); + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + serverPort + "/" + path)) + .header("Content-Type", APPLICATION_JSON) + .POST(HttpRequest.BodyPublishers.ofString(Utils.OBJECT_MAPPER.writeValueAsString(event))) + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + if (response.statusCode() != 200) { + throw new RuntimeException(response.statusCode() + ": Queueing event failed!" + response.body()); + } + } - protected abstract InMemoryQueueManager getQueueManager(); + private CompletableFuture awaitStreamingSubscription() { + int cnt = getStreamingSubscribedCount(); + AtomicInteger initialCount = new AtomicInteger(cnt); - protected abstract void setStreamingSubscribedRunnable(Runnable runnable); + return CompletableFuture.runAsync(() -> { + try { + boolean done = false; + long end = System.currentTimeMillis() + 15000; + while (System.currentTimeMillis() < end) { + int count = getStreamingSubscribedCount(); + if (count > initialCount.get()) { + done = true; + break; + } + Thread.sleep(500); + } + if (!done) { + throw new RuntimeException("Timed out waiting for subscription"); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted"); + } + }); + } + + private int getStreamingSubscribedCount() { + HttpClient client = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .build(); + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + serverPort + "/test/streamingSubscribedCount")) + .GET() + .build(); + try { + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + String body = response.body().trim(); + return Integer.parseInt(body); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + } private static class BreakException extends RuntimeException { diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/TestUtilsBean.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/TestUtilsBean.java new file mode 100644 index 000000000..8cd333a55 --- /dev/null +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/TestUtilsBean.java @@ -0,0 +1,50 @@ +package io.a2a.server.apps.common; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.core.Response; + +import io.a2a.server.events.QueueManager; +import io.a2a.server.tasks.TaskStore; +import io.a2a.spec.Event; +import io.a2a.spec.Task; +import io.quarkus.arc.profile.IfBuildProfile; + +/** + * Contains utilities to interact with the server side for the tests. + * The intent for this bean is to be exposed via REST. + * + *

There is a Quarkus implementation in {@code A2ATestRoutes} which shows the contract for how to + * expose it via REST. For other REST frameworks, you will need to provide an implementation that works in a similar + * way to {@code A2ATestRoutes}.

+ */ +@ApplicationScoped +public class TestUtilsBean { + + @Inject + TaskStore taskStore; + + @Inject + QueueManager queueManager; + + public void saveTask(Task task) { + taskStore.save(task); + } + + public Task getTask(String taskId) { + return taskStore.get(taskId); + } + + public void deleteTask(String taskId) { + taskStore.delete(taskId); + } + + public void ensureQueue(String taskId) { + queueManager.createOrTap(taskId); + } + + public void enqueueEvent(String taskId, Event event) { + queueManager.get(taskId).enqueueEvent(event); + } +}