diff --git a/.github/workflows/run-tck.yml b/.github/workflows/run-tck.yml index a36f5dc6c..545c1b495 100644 --- a/.github/workflows/run-tck.yml +++ b/.github/workflows/run-tck.yml @@ -17,6 +17,8 @@ env: UV_SYSTEM_PYTHON: 1 # SUT_JSONRPC_URL to use for the TCK and the server agent SUT_JSONRPC_URL: http://localhost:9999 + # Slow system on CI + TCK_STREAMING_TIMEOUT: 5.0 # Only run the latest job concurrency: @@ -55,7 +57,7 @@ jobs: - name: Build with Maven, skipping tests run: mvn -B install -DskipTests - name: Start SUT - run: SUT_GRPC_URL=${{ env.SUT_JSONRPC_URL }} mvn -B quarkus:dev & #SUT_JSONRPC_URL already set + run: SUT_GRPC_URL=${{ env.SUT_JSONRPC_URL }} SUT_REST_URL=${{ env.SUT_JSONRPC_URL }} mvn -B quarkus:dev & #SUT_JSONRPC_URL already set working-directory: tck - name: Wait for SUT to start run: | @@ -93,5 +95,5 @@ jobs: - name: Run TCK run: | - ./run_tck.py --sut-url ${{ env.SUT_JSONRPC_URL }} --category all --transports jsonrpc,grpc --compliance-report report.json + ./run_tck.py --sut-url ${{ env.SUT_JSONRPC_URL }} --category all --transports jsonrpc,grpc,rest --compliance-report report.json working-directory: tck/a2a-tck diff --git a/.gitignore b/.gitignore index d8908d914..04650afea 100644 --- a/.gitignore +++ b/.gitignore @@ -44,4 +44,3 @@ nb-configuration.xml # TLS Certificates .certs/ nbproject/ - diff --git a/README.md b/README.md index 0f0d7b48e..bb50eba9c 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,18 @@ To use the reference implementation with the gRPC protocol, add the following de Note that you can add more than one of the above dependencies to your project depending on the transports you'd like to support. -Support for the HTTP+JSON/REST transport will be coming soon. +To use the reference implementation with the HTTP+JSON/REST protocol, add the following dependency to your project: + +> *⚠️ The `io.github.a2asdk` `groupId` below is temporary and will likely change for future releases.* + +```xml + + io.github.a2asdk + a2a-java-sdk-reference-rest + + ${io.a2a.sdk.version} + +``` ### 2. Add a class that creates an A2A Agent Card @@ -117,7 +128,7 @@ public class WeatherAgentCardProducer { .tags(Collections.singletonList("weather")) .examples(List.of("weather in LA, CA")) .build())) - .protocolVersion("0.2.5") + .protocolVersion("0.3.0") .build(); } } @@ -247,7 +258,7 @@ By default, the sdk-client is coming with the JSONRPC transport dependency. Desp dependency is included by default, you still need to add the transport to the Client as described in [JSON-RPC Transport section](#json-rpc-transport-configuration). -If you want to use another transport (such as GRPC or HTTP+JSON), you'll need to add a relevant dependency: +If you want to use the gRPC transport, you'll need to add a relevant dependency: ---- > *⚠️ The `io.github.a2asdk` `groupId` below is temporary and will likely change for future releases.* @@ -262,7 +273,21 @@ If you want to use another transport (such as GRPC or HTTP+JSON), you'll need to ``` -Support for the HTTP+JSON/REST transport will be coming soon. + +If you want to use the HTTP+JSON/REST transport, you'll need to add a relevant dependency: + +---- +> *⚠️ The `io.github.a2asdk` `groupId` below is temporary and will likely change for future releases.* +---- + +```xml + + io.github.a2asdk + a2a-java-sdk-client-transport-rest + + ${io.a2a.sdk.version} + +``` ### Sample Usage @@ -360,6 +385,29 @@ Client client = Client .build(); ``` + +##### HTTP+JSON/REST Transport Configuration + +For the HTTP+JSON/REST transport, if you'd like to use the default `JdkA2AHttpClient`, provide a `RestTransportConfig` created with its default constructor. + +To use a custom HTTP client implementation, simply create a `RestTransportConfig` as follows: + +```java +// Create a custom HTTP client +A2AHttpClient customHttpClient = ... + +// Configure the client settings +ClientConfig clientConfig = new ClientConfig.Builder() + .setAcceptedOutputModes(List.of("text")) + .build(); + +Client client = Client + .builder(agentCard) + .clientConfig(clientConfig) + .withTransport(RestTransport.class, new RestTransportConfig(customHttpClient)) + .build(); +``` + ##### Multiple Transport Configurations You can specify configuration for multiple transports, the appropriate configuration @@ -371,6 +419,7 @@ Client client = Client .builder(agentCard) .withTransport(GrpcTransport.class, new GrpcTransportConfig(channelFactory)) .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfig()) + .withTransport(RestTransport.class, new RestTransportConfig()) .build(); ``` diff --git a/client/base/pom.xml b/client/base/pom.xml index 6df47c1b7..c5b45b583 100644 --- a/client/base/pom.xml +++ b/client/base/pom.xml @@ -35,6 +35,11 @@ a2a-java-sdk-client-transport-grpc test + + ${project.groupId} + a2a-java-sdk-client-transport-rest + test + ${project.groupId} a2a-java-sdk-common @@ -54,6 +59,11 @@ mockserver-netty test + + org.slf4j + slf4j-jdk14 + test + \ No newline at end of file diff --git a/client/transport/grpc/pom.xml b/client/transport/grpc/pom.xml index b910d6ac7..158ae7ea8 100644 --- a/client/transport/grpc/pom.xml +++ b/client/transport/grpc/pom.xml @@ -33,6 +33,14 @@ ${project.groupId} a2a-java-sdk-client-transport-spi + + io.grpc + grpc-protobuf + + + io.grpc + grpc-stub + org.junit.jupiter junit-jupiter-api diff --git a/client/transport/jsonrpc/src/test/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportTest.java b/client/transport/jsonrpc/src/test/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportTest.java index 99e5ef151..25de32947 100644 --- a/client/transport/jsonrpc/src/test/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportTest.java +++ b/client/transport/jsonrpc/src/test/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportTest.java @@ -2,7 +2,6 @@ import static io.a2a.client.transport.jsonrpc.JsonMessages.AGENT_CARD; import static io.a2a.client.transport.jsonrpc.JsonMessages.AGENT_CARD_SUPPORTS_EXTENDED; -import static io.a2a.client.transport.jsonrpc.JsonMessages.AUTHENTICATION_EXTENDED_AGENT_CARD; import static io.a2a.client.transport.jsonrpc.JsonMessages.CANCEL_TASK_TEST_REQUEST; import static io.a2a.client.transport.jsonrpc.JsonMessages.CANCEL_TASK_TEST_RESPONSE; import static io.a2a.client.transport.jsonrpc.JsonMessages.GET_AUTHENTICATED_EXTENDED_AGENT_CARD_REQUEST; @@ -50,7 +49,6 @@ import io.a2a.spec.FilePart; import io.a2a.spec.FileWithBytes; import io.a2a.spec.FileWithUri; -import io.a2a.spec.GetAuthenticatedExtendedCardResponse; import io.a2a.spec.GetTaskPushNotificationConfigParams; import io.a2a.spec.Message; import io.a2a.spec.MessageSendConfiguration; diff --git a/client/transport/rest/pom.xml b/client/transport/rest/pom.xml new file mode 100644 index 000000000..5d3be70ae --- /dev/null +++ b/client/transport/rest/pom.xml @@ -0,0 +1,62 @@ + + + 4.0.0 + + + io.github.a2asdk + a2a-java-sdk-parent + 0.3.0.Beta1-SNAPSHOT + ../../../pom.xml + + a2a-java-sdk-client-transport-rest + jar + + Java SDK A2A Client Transport: JSON+HTTP/REST + Java SDK for the Agent2Agent Protocol (A2A) - JSON+HTTP/REST Client Transport + + + + ${project.groupId} + a2a-java-sdk-common + + + ${project.groupId} + a2a-java-sdk-spec + + + ${project.groupId} + a2a-java-sdk-spec-grpc + + + ${project.groupId} + a2a-java-sdk-client-transport-spi + + + io.github.a2asdk + a2a-java-sdk-http-client + + + com.google.protobuf + protobuf-java-util + + + org.junit.jupiter + junit-jupiter-api + test + + + + org.mock-server + mockserver-netty + test + + + org.slf4j + slf4j-jdk14 + test + + + + \ No newline at end of file diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestErrorMapper.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestErrorMapper.java new file mode 100644 index 000000000..103f7bb0c --- /dev/null +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestErrorMapper.java @@ -0,0 +1,80 @@ +package io.a2a.client.transport.rest; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import io.a2a.client.http.A2AHttpResponse; +import io.a2a.spec.A2AClientException; +import io.a2a.spec.AuthenticatedExtendedCardNotConfiguredError; +import io.a2a.spec.ContentTypeNotSupportedError; +import io.a2a.spec.InternalError; +import io.a2a.spec.InvalidAgentResponseError; +import io.a2a.spec.InvalidParamsError; +import io.a2a.spec.InvalidRequestError; +import io.a2a.spec.JSONParseError; +import io.a2a.spec.MethodNotFoundError; +import io.a2a.spec.PushNotificationNotSupportedError; +import io.a2a.spec.TaskNotCancelableError; +import io.a2a.spec.TaskNotFoundError; +import io.a2a.spec.UnsupportedOperationError; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Utility class to A2AHttpResponse to appropriate A2A error types + */ +public class RestErrorMapper { + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper().registerModule(new JavaTimeModule()); + + public static A2AClientException mapRestError(A2AHttpResponse response) { + return RestErrorMapper.mapRestError(response.body(), response.status()); + } + + public static A2AClientException mapRestError(String body, int code) { + try { + if (body != null && !body.isBlank()) { + JsonNode node = OBJECT_MAPPER.readTree(body); + String className = node.findValue("error").asText(); + String errorMessage = node.findValue("message").asText(); + return mapRestError(className, errorMessage, code); + } + return mapRestError("", "", code); + } catch (JsonProcessingException ex) { + Logger.getLogger(RestErrorMapper.class.getName()).log(Level.SEVERE, null, ex); + return new A2AClientException("Failed to parse error response: " + ex.getMessage()); + } + } + + public static A2AClientException mapRestError(String className, String errorMessage, int code) { + switch (className) { + case "io.a2a.spec.TaskNotFoundError": + return new A2AClientException(errorMessage, new TaskNotFoundError()); + case "io.a2a.spec.AuthenticatedExtendedCardNotConfiguredError": + return new A2AClientException(errorMessage, new AuthenticatedExtendedCardNotConfiguredError()); + case "io.a2a.spec.ContentTypeNotSupportedError": + return new A2AClientException(errorMessage, new ContentTypeNotSupportedError(null, null, errorMessage)); + case "io.a2a.spec.InternalError": + return new A2AClientException(errorMessage, new InternalError(errorMessage)); + case "io.a2a.spec.InvalidAgentResponseError": + return new A2AClientException(errorMessage, new InvalidAgentResponseError(null, null, errorMessage)); + case "io.a2a.spec.InvalidParamsError": + return new A2AClientException(errorMessage, new InvalidParamsError()); + case "io.a2a.spec.InvalidRequestError": + return new A2AClientException(errorMessage, new InvalidRequestError()); + case "io.a2a.spec.JSONParseError": + return new A2AClientException(errorMessage, new JSONParseError()); + case "io.a2a.spec.MethodNotFoundError": + return new A2AClientException(errorMessage, new MethodNotFoundError()); + case "io.a2a.spec.PushNotificationNotSupportedError": + return new A2AClientException(errorMessage, new PushNotificationNotSupportedError()); + case "io.a2a.spec.TaskNotCancelableError": + return new A2AClientException(errorMessage, new TaskNotCancelableError()); + case "io.a2a.spec.UnsupportedOperationError": + return new A2AClientException(errorMessage, new UnsupportedOperationError()); + default: + return new A2AClientException(errorMessage); + } + } +} diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransport.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransport.java new file mode 100644 index 000000000..5420600d4 --- /dev/null +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransport.java @@ -0,0 +1,389 @@ +package io.a2a.client.transport.rest; + +import static io.a2a.util.Assert.checkNotNullParam; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.MessageOrBuilder; +import com.google.protobuf.util.JsonFormat; +import io.a2a.client.http.A2ACardResolver; +import io.a2a.client.http.A2AHttpClient; +import io.a2a.client.http.A2AHttpResponse; +import io.a2a.client.http.JdkA2AHttpClient; +import io.a2a.client.transport.rest.sse.RestSSEEventListener; +import io.a2a.client.transport.spi.ClientTransport; +import io.a2a.client.transport.spi.interceptors.ClientCallContext; +import io.a2a.client.transport.spi.interceptors.ClientCallInterceptor; +import io.a2a.client.transport.spi.interceptors.PayloadAndHeaders; +import io.a2a.grpc.CancelTaskRequest; +import io.a2a.grpc.CreateTaskPushNotificationConfigRequest; +import io.a2a.grpc.GetTaskPushNotificationConfigRequest; +import io.a2a.grpc.GetTaskRequest; +import io.a2a.grpc.ListTaskPushNotificationConfigRequest; +import io.a2a.spec.TaskPushNotificationConfig; +import io.a2a.spec.A2AClientException; +import io.a2a.spec.AgentCard; +import io.a2a.spec.DeleteTaskPushNotificationConfigParams; +import io.a2a.spec.EventKind; +import io.a2a.spec.GetTaskPushNotificationConfigParams; +import io.a2a.spec.ListTaskPushNotificationConfigParams; +import io.a2a.spec.MessageSendParams; +import io.a2a.spec.StreamingEventKind; +import io.a2a.spec.Task; +import io.a2a.spec.TaskIdParams; +import io.a2a.spec.TaskQueryParams; +import io.a2a.grpc.utils.ProtoUtils; +import io.a2a.spec.A2AClientError; +import io.a2a.spec.SendStreamingMessageRequest; +import io.a2a.spec.SetTaskPushNotificationConfigRequest; +import io.a2a.util.Utils; +import java.io.IOException; +import java.util.List; +import java.util.logging.Logger; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +public class RestTransport implements ClientTransport { + + private static final Logger log = Logger.getLogger(RestTransport.class.getName()); + private final A2AHttpClient httpClient; + private final String agentUrl; + private final List interceptors; + private AgentCard agentCard; + private boolean needsExtendedCard = false; + + public RestTransport(String agentUrl) { + this(null, null, agentUrl, null); + } + + public RestTransport(AgentCard agentCard) { + this(null, agentCard, agentCard.url(), null); + } + + public RestTransport(A2AHttpClient httpClient, AgentCard agentCard, + String agentUrl, List interceptors) { + this.httpClient = httpClient == null ? new JdkA2AHttpClient() : httpClient; + this.agentCard = agentCard; + this.agentUrl = agentUrl.endsWith("/") ? agentUrl.substring(0, agentUrl.length() - 1) : agentUrl; + this.interceptors = interceptors; + } + + @Override + public EventKind sendMessage(MessageSendParams messageSendParams, ClientCallContext context) throws A2AClientException { + checkNotNullParam("messageSendParams", messageSendParams); + io.a2a.grpc.SendMessageRequest.Builder builder = io.a2a.grpc.SendMessageRequest.newBuilder(ProtoUtils.ToProto.sendMessageRequest(messageSendParams)); + PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.SendMessageRequest.METHOD, builder, agentCard, context); + try { + String httpResponseBody = sendPostRequest(agentUrl + "/v1/message:send", payloadAndHeaders); + io.a2a.grpc.SendMessageResponse.Builder responseBuilder = io.a2a.grpc.SendMessageResponse.newBuilder(); + JsonFormat.parser().merge(httpResponseBody, responseBuilder); + if (responseBuilder.hasMsg()) { + return ProtoUtils.FromProto.message(responseBuilder.getMsg()); + } + if (responseBuilder.hasTask()) { + return ProtoUtils.FromProto.task(responseBuilder.getTask()); + } + throw new A2AClientException("Failed to send message, wrong response:" + httpResponseBody); + } catch (A2AClientException e) { + throw e; + } catch (IOException | InterruptedException e) { + throw new A2AClientException("Failed to send message: " + e, e); + } + } + + @Override + public void sendMessageStreaming(MessageSendParams messageSendParams, Consumer eventConsumer, Consumer errorConsumer, ClientCallContext context) throws A2AClientException { + checkNotNullParam("request", messageSendParams); + checkNotNullParam("eventConsumer", eventConsumer); + checkNotNullParam("messageSendParams", messageSendParams); + io.a2a.grpc.SendMessageRequest.Builder builder = io.a2a.grpc.SendMessageRequest.newBuilder(ProtoUtils.ToProto.sendMessageRequest(messageSendParams)); + PayloadAndHeaders payloadAndHeaders = applyInterceptors(SendStreamingMessageRequest.METHOD, + builder, agentCard, context); + AtomicReference> ref = new AtomicReference<>(); + RestSSEEventListener sseEventListener = new RestSSEEventListener(eventConsumer, errorConsumer); + try { + A2AHttpClient.PostBuilder postBuilder = createPostBuilder(agentUrl + "/v1/message:stream", payloadAndHeaders); + ref.set(postBuilder.postAsyncSSE( + msg -> sseEventListener.onMessage(msg, ref.get()), + throwable -> sseEventListener.onError(throwable, ref.get()), + () -> { + // We don't need to do anything special on completion + })); + } catch (IOException e) { + throw new A2AClientException("Failed to send streaming message request: " + e, e); + } catch (InterruptedException e) { + throw new A2AClientException("Send streaming message request timed out: " + e, e); + } + } + + @Override + public Task getTask(TaskQueryParams taskQueryParams, ClientCallContext context) throws A2AClientException { + checkNotNullParam("taskQueryParams", taskQueryParams); + GetTaskRequest.Builder builder = GetTaskRequest.newBuilder(); + builder.setName("tasks/" + taskQueryParams.id()); + PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.GetTaskRequest.METHOD, builder, + agentCard, context); + try { + String url; + if (taskQueryParams.historyLength() != null) { + url = agentUrl + String.format("/v1/tasks/%1s?historyLength=%2d", taskQueryParams.id(), taskQueryParams.historyLength()); + } else { + url = agentUrl + String.format("/v1/tasks/%1s", taskQueryParams.id()); + } + A2AHttpClient.GetBuilder getBuilder = httpClient.createGet().url(url); + if (payloadAndHeaders.getHeaders() != null) { + for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { + getBuilder.addHeader(entry.getKey(), entry.getValue()); + } + } + A2AHttpResponse response = getBuilder.get(); + if (!response.success()) { + throw RestErrorMapper.mapRestError(response); + } + String httpResponseBody = response.body(); + io.a2a.grpc.Task.Builder responseBuilder = io.a2a.grpc.Task.newBuilder(); + JsonFormat.parser().merge(httpResponseBody, responseBuilder); + return ProtoUtils.FromProto.task(responseBuilder); + } catch (A2AClientException e) { + throw e; + } catch (IOException | InterruptedException e) { + throw new A2AClientException("Failed to get task: " + e, e); + } + } + + @Override + public Task cancelTask(TaskIdParams taskIdParams, ClientCallContext context) throws A2AClientException { + checkNotNullParam("taskIdParams", taskIdParams); + CancelTaskRequest.Builder builder = CancelTaskRequest.newBuilder(); + builder.setName("tasks/" + taskIdParams.id()); + PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.CancelTaskRequest.METHOD, builder, + agentCard, context); + try { + String httpResponseBody = sendPostRequest(agentUrl + String.format("/v1/tasks/%1s:cancel", taskIdParams.id()), payloadAndHeaders); + io.a2a.grpc.Task.Builder responseBuilder = io.a2a.grpc.Task.newBuilder(); + JsonFormat.parser().merge(httpResponseBody, responseBuilder); + return ProtoUtils.FromProto.task(responseBuilder); + } catch (A2AClientException e) { + throw e; + } catch (IOException | InterruptedException e) { + throw new A2AClientException("Failed to cancel task: " + e, e); + } + } + + @Override + public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushNotificationConfig request, ClientCallContext context) throws A2AClientException { + checkNotNullParam("request", request); + CreateTaskPushNotificationConfigRequest.Builder builder = CreateTaskPushNotificationConfigRequest.newBuilder(); + builder.setConfig(ProtoUtils.ToProto.taskPushNotificationConfig(request)) + .setParent("tasks/" + request.taskId()); + if (request.pushNotificationConfig().id() != null) { + builder.setConfigId(request.pushNotificationConfig().id()); + } + PayloadAndHeaders payloadAndHeaders = applyInterceptors(SetTaskPushNotificationConfigRequest.METHOD, builder, agentCard, context); + try { + String httpResponseBody = sendPostRequest(agentUrl + String.format("/v1/tasks/%1s/pushNotificationConfigs", request.taskId()), payloadAndHeaders); + io.a2a.grpc.TaskPushNotificationConfig.Builder responseBuilder = io.a2a.grpc.TaskPushNotificationConfig.newBuilder(); + JsonFormat.parser().merge(httpResponseBody, responseBuilder); + return ProtoUtils.FromProto.taskPushNotificationConfig(responseBuilder); + } catch (A2AClientException e) { + throw e; + } catch (IOException | InterruptedException e) { + throw new A2AClientException("Failed to set task push notification config: " + e, e); + } + } + + @Override + public TaskPushNotificationConfig getTaskPushNotificationConfiguration(GetTaskPushNotificationConfigParams request, ClientCallContext context) throws A2AClientException { + checkNotNullParam("request", request); + GetTaskPushNotificationConfigRequest.Builder builder = GetTaskPushNotificationConfigRequest.newBuilder(); + builder.setName(String.format("/tasks/%1s/pushNotificationConfigs/%2s", request.id(), request.pushNotificationConfigId())); + PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.GetTaskPushNotificationConfigRequest.METHOD, builder, + agentCard, context); + try { + String url = agentUrl + String.format("/v1/tasks/%1s/pushNotificationConfigs/%2s", request.id(), request.pushNotificationConfigId()); + A2AHttpClient.GetBuilder getBuilder = httpClient.createGet().url(url); + if (payloadAndHeaders.getHeaders() != null) { + for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { + getBuilder.addHeader(entry.getKey(), entry.getValue()); + } + } + A2AHttpResponse response = getBuilder.get(); + if (!response.success()) { + throw RestErrorMapper.mapRestError(response); + } + String httpResponseBody = response.body(); + io.a2a.grpc.TaskPushNotificationConfig.Builder responseBuilder = io.a2a.grpc.TaskPushNotificationConfig.newBuilder(); + JsonFormat.parser().merge(httpResponseBody, responseBuilder); + return ProtoUtils.FromProto.taskPushNotificationConfig(responseBuilder); + } catch (A2AClientException e) { + throw e; + } catch (IOException | InterruptedException e) { + throw new A2AClientException("Failed to get push notifications: " + e, e); + } + } + + @Override + public List listTaskPushNotificationConfigurations(ListTaskPushNotificationConfigParams request, ClientCallContext context) throws A2AClientException { + checkNotNullParam("request", request); + ListTaskPushNotificationConfigRequest.Builder builder = ListTaskPushNotificationConfigRequest.newBuilder(); + builder.setParent(String.format("/tasks/%1s/pushNotificationConfigs", request.id())); + PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.ListTaskPushNotificationConfigRequest.METHOD, builder, + agentCard, context); + try { + String url = agentUrl + String.format("/v1/tasks/%1s/pushNotificationConfigs", request.id()); + A2AHttpClient.GetBuilder getBuilder = httpClient.createGet().url(url); + if (payloadAndHeaders.getHeaders() != null) { + for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { + getBuilder.addHeader(entry.getKey(), entry.getValue()); + } + } + A2AHttpResponse response = getBuilder.get(); + if (!response.success()) { + throw RestErrorMapper.mapRestError(response); + } + String httpResponseBody = response.body(); + io.a2a.grpc.ListTaskPushNotificationConfigResponse.Builder responseBuilder = io.a2a.grpc.ListTaskPushNotificationConfigResponse.newBuilder(); + JsonFormat.parser().merge(httpResponseBody, responseBuilder); + return ProtoUtils.FromProto.listTaskPushNotificationConfigParams(responseBuilder); + } catch (A2AClientException e) { + throw e; + } catch (IOException | InterruptedException e) { + throw new A2AClientException("Failed to list push notifications: " + e, e); + } + } + + @Override + public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationConfigParams request, ClientCallContext context) throws A2AClientException { + checkNotNullParam("request", request); + io.a2a.grpc.DeleteTaskPushNotificationConfigRequestOrBuilder builder = io.a2a.grpc.DeleteTaskPushNotificationConfigRequest.newBuilder(); + PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.DeleteTaskPushNotificationConfigRequest.METHOD, builder, + agentCard, context); + try { + String url = agentUrl + String.format("/v1/tasks/%1s/pushNotificationConfigs/%2s", request.id(), request.pushNotificationConfigId()); + A2AHttpClient.DeleteBuilder deleteBuilder = httpClient.createDelete().url(url); + if (payloadAndHeaders.getHeaders() != null) { + for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { + deleteBuilder.addHeader(entry.getKey(), entry.getValue()); + } + } + A2AHttpResponse response = deleteBuilder.delete(); + if (!response.success()) { + throw RestErrorMapper.mapRestError(response); + } + } catch (A2AClientException e) { + throw e; + } catch (IOException | InterruptedException e) { + throw new A2AClientException("Failed to delete push notification config: " + e, e); + } + } + + @Override + public void resubscribe(TaskIdParams request, Consumer eventConsumer, + Consumer errorConsumer, ClientCallContext context) throws A2AClientException { + checkNotNullParam("request", request); + io.a2a.grpc.TaskSubscriptionRequest.Builder builder = io.a2a.grpc.TaskSubscriptionRequest.newBuilder(); + builder.setName("tasks/" + request.id()); + PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.TaskResubscriptionRequest.METHOD, builder, + agentCard, context); + AtomicReference> ref = new AtomicReference<>(); + RestSSEEventListener sseEventListener = new RestSSEEventListener(eventConsumer, errorConsumer); + try { + String url = agentUrl + String.format("/v1/tasks/%1s:subscribe", request.id()); + A2AHttpClient.PostBuilder postBuilder = createPostBuilder(url, payloadAndHeaders); + ref.set(postBuilder.postAsyncSSE( + msg -> sseEventListener.onMessage(msg, ref.get()), + throwable -> sseEventListener.onError(throwable, ref.get()), + () -> { + // We don't need to do anything special on completion + })); + } catch (IOException e) { + throw new A2AClientException("Failed to send streaming message request: " + e, e); + } catch (InterruptedException e) { + throw new A2AClientException("Send streaming message request timed out: " + e, e); + } + } + + @Override + public AgentCard getAgentCard(ClientCallContext context) throws A2AClientException { + A2ACardResolver resolver; + try { + if (agentCard == null) { + resolver = new A2ACardResolver(httpClient, agentUrl, null, getHttpHeaders(context)); + agentCard = resolver.getAgentCard(); + needsExtendedCard = agentCard.supportsAuthenticatedExtendedCard(); + } + if (!needsExtendedCard) { + return agentCard; + } + PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.GetTaskRequest.METHOD, null, + agentCard, context); + String url = agentUrl + String.format("/v1/card"); + A2AHttpClient.GetBuilder getBuilder = httpClient.createGet().url(url); + if (payloadAndHeaders.getHeaders() != null) { + for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { + getBuilder.addHeader(entry.getKey(), entry.getValue()); + } + } + A2AHttpResponse response = getBuilder.get(); + if (!response.success()) { + throw RestErrorMapper.mapRestError(response); + } + String httpResponseBody = response.body(); + agentCard = Utils.OBJECT_MAPPER.readValue(httpResponseBody, AgentCard.class); + needsExtendedCard = false; + return agentCard; + } catch (IOException | InterruptedException e) { + throw new A2AClientException("Failed to get authenticated extended agent card: " + e, e); + } catch (A2AClientError e) { + throw new A2AClientException("Failed to get agent card: " + e, e); + } + } + + @Override + public void close() { + // no-op + } + + private PayloadAndHeaders applyInterceptors(String methodName, MessageOrBuilder payload, + AgentCard agentCard, ClientCallContext clientCallContext) { + PayloadAndHeaders payloadAndHeaders = new PayloadAndHeaders(payload, getHttpHeaders(clientCallContext)); + if (interceptors != null && !interceptors.isEmpty()) { + for (ClientCallInterceptor interceptor : interceptors) { + payloadAndHeaders = interceptor.intercept(methodName, payloadAndHeaders.getPayload(), + payloadAndHeaders.getHeaders(), agentCard, clientCallContext); + } + } + return payloadAndHeaders; + } + + private String sendPostRequest(String url, PayloadAndHeaders payloadAndHeaders) throws IOException, InterruptedException { + A2AHttpClient.PostBuilder builder = createPostBuilder(url, payloadAndHeaders); + A2AHttpResponse response = builder.post(); + if (!response.success()) { + log.fine("Error on POST processing " + JsonFormat.printer().print((MessageOrBuilder) payloadAndHeaders.getPayload())); + throw RestErrorMapper.mapRestError(response); + } + return response.body(); + } + + private A2AHttpClient.PostBuilder createPostBuilder(String url, PayloadAndHeaders payloadAndHeaders) throws JsonProcessingException, InvalidProtocolBufferException { + log.fine(JsonFormat.printer().print((MessageOrBuilder) payloadAndHeaders.getPayload())); + A2AHttpClient.PostBuilder postBuilder = httpClient.createPost() + .url(url) + .addHeader("Content-Type", "application/json") + .body(JsonFormat.printer().print((MessageOrBuilder) payloadAndHeaders.getPayload())); + + if (payloadAndHeaders.getHeaders() != null) { + for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { + postBuilder.addHeader(entry.getKey(), entry.getValue()); + } + } + return postBuilder; + } + + private Map getHttpHeaders(ClientCallContext context) { + return context != null ? context.getHeaders() : null; + } +} diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfig.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfig.java new file mode 100644 index 000000000..bbb583a1b --- /dev/null +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfig.java @@ -0,0 +1,21 @@ +package io.a2a.client.transport.rest; + +import io.a2a.client.http.A2AHttpClient; +import io.a2a.client.transport.spi.ClientTransportConfig; + +public class RestTransportConfig extends ClientTransportConfig { + + private final A2AHttpClient httpClient; + + public RestTransportConfig() { + this.httpClient = null; + } + + public RestTransportConfig(A2AHttpClient httpClient) { + this.httpClient = httpClient; + } + + public A2AHttpClient getHttpClient() { + return httpClient; + } +} \ No newline at end of file diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfigBuilder.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfigBuilder.java new file mode 100644 index 000000000..30f8b412a --- /dev/null +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfigBuilder.java @@ -0,0 +1,28 @@ +package io.a2a.client.transport.rest; + +import io.a2a.client.http.A2AHttpClient; +import io.a2a.client.http.JdkA2AHttpClient; +import io.a2a.client.transport.spi.ClientTransportConfigBuilder; + +public class RestTransportConfigBuilder extends ClientTransportConfigBuilder { + + private A2AHttpClient httpClient; + + public RestTransportConfigBuilder httpClient(A2AHttpClient httpClient) { + this.httpClient = httpClient; + + return this; + } + + @Override + public RestTransportConfig build() { + // No HTTP client provided, fallback to the default one (JDK-based implementation) + if (httpClient == null) { + httpClient = new JdkA2AHttpClient(); + } + + RestTransportConfig config = new RestTransportConfig(httpClient); + config.setInterceptors(this.interceptors); + return config; + } +} \ No newline at end of file diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportProvider.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportProvider.java new file mode 100644 index 000000000..99d155968 --- /dev/null +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportProvider.java @@ -0,0 +1,29 @@ +package io.a2a.client.transport.rest; + +import io.a2a.client.http.JdkA2AHttpClient; +import io.a2a.client.transport.spi.ClientTransportProvider; +import io.a2a.spec.A2AClientException; +import io.a2a.spec.AgentCard; +import io.a2a.spec.TransportProtocol; + +public class RestTransportProvider implements ClientTransportProvider { + + @Override + public String getTransportProtocol() { + return TransportProtocol.HTTP_JSON.asString(); + } + + @Override + public RestTransport create(RestTransportConfig clientTransportConfig, AgentCard agentCard, String agentUrl) throws A2AClientException { + RestTransportConfig transportConfig = clientTransportConfig; + if (transportConfig == null) { + transportConfig = new RestTransportConfig(new JdkA2AHttpClient()); + } + return new RestTransport(clientTransportConfig.getHttpClient(), agentCard, agentUrl, transportConfig.getInterceptors()); + } + + @Override + public Class getTransportProtocolClass() { + return RestTransport.class; + } +} diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/sse/RestSSEEventListener.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/sse/RestSSEEventListener.java new file mode 100644 index 000000000..c34d615cb --- /dev/null +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/sse/RestSSEEventListener.java @@ -0,0 +1,72 @@ +package io.a2a.client.transport.rest.sse; + +import static io.a2a.grpc.StreamResponse.PayloadCase.ARTIFACT_UPDATE; +import static io.a2a.grpc.StreamResponse.PayloadCase.MSG; +import static io.a2a.grpc.StreamResponse.PayloadCase.STATUS_UPDATE; +import static io.a2a.grpc.StreamResponse.PayloadCase.TASK; + +import java.util.concurrent.Future; +import java.util.function.Consumer; +import java.util.logging.Logger; + +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.util.JsonFormat; +import io.a2a.client.transport.rest.RestErrorMapper; +import io.a2a.grpc.StreamResponse; +import io.a2a.grpc.utils.ProtoUtils; +import io.a2a.spec.StreamingEventKind; + +public class RestSSEEventListener { + + private static final Logger log = Logger.getLogger(RestSSEEventListener.class.getName()); + private final Consumer eventHandler; + private final Consumer errorHandler; + + public RestSSEEventListener(Consumer eventHandler, + Consumer errorHandler) { + this.eventHandler = eventHandler; + this.errorHandler = errorHandler; + } + + public void onMessage(String message, Future completableFuture) { + try { + System.out.println("Streaming message received: " + message); + io.a2a.grpc.StreamResponse.Builder builder = io.a2a.grpc.StreamResponse.newBuilder(); + JsonFormat.parser().merge(message, builder); + handleMessage(builder.build(), completableFuture); + } catch (InvalidProtocolBufferException e) { + errorHandler.accept(RestErrorMapper.mapRestError(message, 500)); + } + } + + public void onError(Throwable throwable, Future future) { + if (errorHandler != null) { + errorHandler.accept(throwable); + } + future.cancel(true); // close SSE channel + } + + private void handleMessage(StreamResponse response, Future future) { + StreamingEventKind event; + switch (response.getPayloadCase()) { + case MSG: + event = ProtoUtils.FromProto.message(response.getMsg()); + break; + case TASK: + event = ProtoUtils.FromProto.task(response.getTask()); + break; + case STATUS_UPDATE: + event = ProtoUtils.FromProto.taskStatusUpdateEvent(response.getStatusUpdate()); + break; + case ARTIFACT_UPDATE: + event = ProtoUtils.FromProto.taskArtifactUpdateEvent(response.getArtifactUpdate()); + break; + default: + log.warning("Invalid stream response " + response.getPayloadCase()); + errorHandler.accept(new IllegalStateException("Invalid stream response from server: " + response.getPayloadCase())); + return; + } + eventHandler.accept(event); + } + +} diff --git a/client/transport/rest/src/main/resources/META-INF/services/io.a2a.client.transport.spi.ClientTransportProvider b/client/transport/rest/src/main/resources/META-INF/services/io.a2a.client.transport.spi.ClientTransportProvider new file mode 100644 index 000000000..894866aab --- /dev/null +++ b/client/transport/rest/src/main/resources/META-INF/services/io.a2a.client.transport.spi.ClientTransportProvider @@ -0,0 +1 @@ +io.a2a.client.transport.rest.RestTransportProvider \ No newline at end of file diff --git a/client/transport/rest/src/test/java/io/a2a/client/transport/rest/JsonRestMessages.java b/client/transport/rest/src/test/java/io/a2a/client/transport/rest/JsonRestMessages.java new file mode 100644 index 000000000..e94fcca1f --- /dev/null +++ b/client/transport/rest/src/test/java/io/a2a/client/transport/rest/JsonRestMessages.java @@ -0,0 +1,654 @@ +package io.a2a.client.transport.rest; + +/** + * Request and response messages used by the tests. These have been created following examples from + * the A2A sample messages. + */ +public class JsonRestMessages { + + static final String SEND_MESSAGE_TEST_REQUEST = """ + { + "message": + { + "messageId": "message-1234", + "contextId": "context-1234", + "role": "ROLE_USER", + "content": [{ + "text": "tell me a joke" + }], + "metadata": { + } + } + }"""; + + static final String SEND_MESSAGE_TEST_RESPONSE = """ + { + "task": { + "id": "9b511af4-b27c-47fa-aecf-2a93c08a44f8", + "contextId": "context-1234", + "status": { + "state": "TASK_STATE_SUBMITTED" + }, + "history": [ + { + "messageId": "message-1234", + "contextId": "context-1234", + "taskId": "9b511af4-b27c-47fa-aecf-2a93c08a44f8", + "role": "ROLE_USER", + "content": [ + { + "text": "tell me a joke" + } + ], + "metadata": {} + } + ] + } + }"""; + + static final String CANCEL_TASK_TEST_REQUEST = """ + { + "name": "tasks/de38c76d-d54c-436c-8b9f-4c2703648d64" + }"""; + + static final String CANCEL_TASK_TEST_RESPONSE = """ + { + "id": "de38c76d-d54c-436c-8b9f-4c2703648d64", + "contextId": "c295ea44-7543-4f78-b524-7a38915ad6e4", + "status": { + "state": "TASK_STATE_CANCELLED" + }, + "metadata": {} + }"""; + + static final String GET_TASK_TEST_RESPONSE = """ + { + "id": "de38c76d-d54c-436c-8b9f-4c2703648d64", + "contextId": "c295ea44-7543-4f78-b524-7a38915ad6e4", + "status": { + "state": "TASK_STATE_COMPLETED" + }, + "artifacts": [ + { + "artifactId": "artifact-1", + "parts": [ + { + "text": "Why did the chicken cross the road? To get to the other side!" + } + ] + } + ], + "history": [ + { + "role": "ROLE_USER", + "content": [ + { + "text": "tell me a joke" + }, + { + "file": { + "file_with_uri": "file:///path/to/file.txt", + "mimeType": "text/plain" + } + }, + { + "file": { + "file_with_bytes": "aGVsbG8=", + "mimeType": "text/plain" + } + } + ], + "messageId": "message-123" + } + ], + "metadata": {} + } + """; + + static final String AGENT_CARD = """ + { + "name": "GeoSpatial Route Planner Agent", + "description": "Provides advanced route planning, traffic analysis, and custom map generation services. This agent can calculate optimal routes, estimate travel times considering real-time traffic, and create personalized maps with points of interest.", + "url": "https://georoute-agent.example.com/a2a/v1", + "provider": { + "organization": "Example Geo Services Inc.", + "url": "https://www.examplegeoservices.com" + }, + "iconUrl": "https://georoute-agent.example.com/icon.png", + "version": "1.2.0", + "documentationUrl": "https://docs.examplegeoservices.com/georoute-agent/api", + "capabilities": { + "streaming": true, + "pushNotifications": true, + "stateTransitionHistory": false + }, + "securitySchemes": { + "google": { + "type": "openIdConnect", + "openIdConnectUrl": "https://accounts.google.com/.well-known/openid-configuration" + } + }, + "security": [{ "google": ["openid", "profile", "email"] }], + "defaultInputModes": ["application/json", "text/plain"], + "defaultOutputModes": ["application/json", "image/png"], + "skills": [ + { + "id": "route-optimizer-traffic", + "name": "Traffic-Aware Route Optimizer", + "description": "Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).", + "tags": ["maps", "routing", "navigation", "directions", "traffic"], + "examples": [ + "Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", + "{\\"origin\\": {\\"lat\\": 37.422, \\"lng\\": -122.084}, \\"destination\\": {\\"lat\\": 37.7749, \\"lng\\": -122.4194}, \\"preferences\\": [\\"avoid_ferries\\"]}" + ], + "inputModes": ["application/json", "text/plain"], + "outputModes": [ + "application/json", + "application/vnd.geo+json", + "text/html" + ] + }, + { + "id": "custom-map-generator", + "name": "Personalized Map Generator", + "description": "Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.", + "tags": ["maps", "customization", "visualization", "cartography"], + "examples": [ + "Generate a map of my upcoming road trip with all planned stops highlighted.", + "Show me a map visualizing all coffee shops within a 1-mile radius of my current location." + ], + "inputModes": ["application/json"], + "outputModes": [ + "image/png", + "image/jpeg", + "application/json", + "text/html" + ] + } + ], + "supportsAuthenticatedExtendedCard": false, + "protocolVersion": "0.2.5" + }"""; + + static final String AGENT_CARD_SUPPORTS_EXTENDED = """ + { + "name": "GeoSpatial Route Planner Agent", + "description": "Provides advanced route planning, traffic analysis, and custom map generation services. This agent can calculate optimal routes, estimate travel times considering real-time traffic, and create personalized maps with points of interest.", + "url": "https://georoute-agent.example.com/a2a/v1", + "provider": { + "organization": "Example Geo Services Inc.", + "url": "https://www.examplegeoservices.com" + }, + "iconUrl": "https://georoute-agent.example.com/icon.png", + "version": "1.2.0", + "documentationUrl": "https://docs.examplegeoservices.com/georoute-agent/api", + "capabilities": { + "streaming": true, + "pushNotifications": true, + "stateTransitionHistory": false + }, + "securitySchemes": { + "google": { + "type": "openIdConnect", + "openIdConnectUrl": "https://accounts.google.com/.well-known/openid-configuration" + } + }, + "security": [{ "google": ["openid", "profile", "email"] }], + "defaultInputModes": ["application/json", "text/plain"], + "defaultOutputModes": ["application/json", "image/png"], + "skills": [ + { + "id": "route-optimizer-traffic", + "name": "Traffic-Aware Route Optimizer", + "description": "Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).", + "tags": ["maps", "routing", "navigation", "directions", "traffic"], + "examples": [ + "Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", + "{\\"origin\\": {\\"lat\\": 37.422, \\"lng\\": -122.084}, \\"destination\\": {\\"lat\\": 37.7749, \\"lng\\": -122.4194}, \\"preferences\\": [\\"avoid_ferries\\"]}" + ], + "inputModes": ["application/json", "text/plain"], + "outputModes": [ + "application/json", + "application/vnd.geo+json", + "text/html" + ] + }, + { + "id": "custom-map-generator", + "name": "Personalized Map Generator", + "description": "Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.", + "tags": ["maps", "customization", "visualization", "cartography"], + "examples": [ + "Generate a map of my upcoming road trip with all planned stops highlighted.", + "Show me a map visualizing all coffee shops within a 1-mile radius of my current location." + ], + "inputModes": ["application/json"], + "outputModes": [ + "image/png", + "image/jpeg", + "application/json", + "text/html" + ] + } + ], + "supportsAuthenticatedExtendedCard": true, + "protocolVersion": "0.2.5" + }"""; + + static final String AUTHENTICATION_EXTENDED_AGENT_CARD = """ + { + "name": "GeoSpatial Route Planner Agent Extended", + "description": "Extended description", + "url": "https://georoute-agent.example.com/a2a/v1", + "provider": { + "organization": "Example Geo Services Inc.", + "url": "https://www.examplegeoservices.com" + }, + "iconUrl": "https://georoute-agent.example.com/icon.png", + "version": "1.2.0", + "documentationUrl": "https://docs.examplegeoservices.com/georoute-agent/api", + "capabilities": { + "streaming": true, + "pushNotifications": true, + "stateTransitionHistory": false + }, + "securitySchemes": { + "google": { + "type": "openIdConnect", + "openIdConnectUrl": "https://accounts.google.com/.well-known/openid-configuration" + } + }, + "security": [{ "google": ["openid", "profile", "email"] }], + "defaultInputModes": ["application/json", "text/plain"], + "defaultOutputModes": ["application/json", "image/png"], + "skills": [ + { + "id": "route-optimizer-traffic", + "name": "Traffic-Aware Route Optimizer", + "description": "Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).", + "tags": ["maps", "routing", "navigation", "directions", "traffic"], + "examples": [ + "Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", + "{\\"origin\\": {\\"lat\\": 37.422, \\"lng\\": -122.084}, \\"destination\\": {\\"lat\\": 37.7749, \\"lng\\": -122.4194}, \\"preferences\\": [\\"avoid_ferries\\"]}" + ], + "inputModes": ["application/json", "text/plain"], + "outputModes": [ + "application/json", + "application/vnd.geo+json", + "text/html" + ] + }, + { + "id": "custom-map-generator", + "name": "Personalized Map Generator", + "description": "Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.", + "tags": ["maps", "customization", "visualization", "cartography"], + "examples": [ + "Generate a map of my upcoming road trip with all planned stops highlighted.", + "Show me a map visualizing all coffee shops within a 1-mile radius of my current location." + ], + "inputModes": ["application/json"], + "outputModes": [ + "image/png", + "image/jpeg", + "application/json", + "text/html" + ] + }, + { + "id": "skill-extended", + "name": "Extended Skill", + "description": "This is an extended skill.", + "tags": ["extended"] + } + ], + "supportsAuthenticatedExtendedCard": true, + "protocolVersion": "0.2.5" + }"""; + + static final String SEND_MESSAGE_TEST_REQUEST_WITH_MESSAGE_RESPONSE = """ + { + "jsonrpc": "2.0", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [ + { + "kind": "text", + "text": "tell me a joke" + } + ], + "messageId": "message-1234", + "contextId": "context-1234", + "kind": "message" + }, + "configuration": { + "acceptedOutputModes": ["text"], + "blocking": true + }, + } + }"""; + + static final String SEND_MESSAGE_TEST_RESPONSE_WITH_MESSAGE_RESPONSE = """ + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "role": "agent", + "parts": [ + { + "kind": "text", + "text": "Why did the chicken cross the road? To get to the other side!" + } + ], + "messageId": "msg-456", + "kind": "message" + } + }"""; + + static final String SEND_MESSAGE_WITH_ERROR_TEST_REQUEST = """ + { + "jsonrpc": "2.0", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [ + { + "kind": "text", + "text": "tell me a joke" + } + ], + "messageId": "message-1234", + "contextId": "context-1234", + "kind": "message" + }, + "configuration": { + "acceptedOutputModes": ["text"], + "blocking": true + }, + } + }"""; + + static final String SEND_MESSAGE_ERROR_TEST_RESPONSE = """ + { + "jsonrpc": "2.0", + "error": { + "code": -32702, + "message": "Invalid parameters", + "data": "Hello world" + } + }"""; + + static final String GET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE = """ + { + "name": "tasks/de38c76d-d54c-436c-8b9f-4c2703648d64/pushNotificationConfigs/10", + "pushNotificationConfig": { + "url": "https://example.com/callback", + "authentication": { + "schemes": ["jwt"] + } + } + }"""; + static final String LIST_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE = """ + { + "configs":[ + { + "name": "tasks/de38c76d-d54c-436c-8b9f-4c2703648d64/pushNotificationConfigs/10", + "pushNotificationConfig": { + "url": "https://example.com/callback", + "authentication": { + "schemes": ["jwt"] + } + } + }, + { + "name": "tasks/de38c76d-d54c-436c-8b9f-4c2703648d64/pushNotificationConfigs/5", + "pushNotificationConfig": { + "url": "https://test.com/callback" + } + } + ] + }"""; + + + static final String SET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_REQUEST = """ + { + "parent": "tasks/de38c76d-d54c-436c-8b9f-4c2703648d64", + "config": { + "name": "tasks/de38c76d-d54c-436c-8b9f-4c2703648d64/pushNotificationConfigs", + "pushNotificationConfig": { + "url": "https://example.com/callback", + "authentication": { + "schemes": [ "jwt" ] + } + } + } + }"""; + + static final String SET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE = """ + { + "name": "tasks/de38c76d-d54c-436c-8b9f-4c2703648d64/pushNotificationConfigs/10", + "pushNotificationConfig": { + "url": "https://example.com/callback", + "authentication": { + "schemes": ["jwt"] + } + } + }"""; + + static final String SEND_MESSAGE_WITH_FILE_PART_TEST_REQUEST = """ + { + "jsonrpc": "2.0", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [ + { + "kind": "text", + "text": "analyze this image" + }, + { + "kind": "file", + "file": { + "uri": "file:///path/to/image.jpg", + "mimeType": "image/jpeg" + } + } + ], + "messageId": "message-1234-with-file", + "contextId": "context-1234", + "kind": "message" + }, + "configuration": { + "acceptedOutputModes": ["text"], + "blocking": true + } + } + }"""; + + static final String SEND_MESSAGE_WITH_FILE_PART_TEST_RESPONSE = """ + { + "jsonrpc": "2.0", + "result": { + "id": "de38c76d-d54c-436c-8b9f-4c2703648d64", + "contextId": "c295ea44-7543-4f78-b524-7a38915ad6e4", + "status": { + "state": "completed" + }, + "artifacts": [ + { + "artifactId": "artifact-1", + "name": "image-analysis", + "parts": [ + { + "kind": "text", + "text": "This is an image of a cat sitting on a windowsill." + } + ] + } + ], + "metadata": {}, + "kind": "task" + } + }"""; + + static final String SEND_MESSAGE_WITH_DATA_PART_TEST_REQUEST = """ + { + "jsonrpc": "2.0", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [ + { + "kind": "text", + "text": "process this data" + }, + { + "kind": "data", + "data": { + "temperature": 25.5, + "humidity": 60.2, + "location": "San Francisco", + "timestamp": "2024-01-15T10:30:00Z" + } + } + ], + "messageId": "message-1234-with-data", + "contextId": "context-1234", + "kind": "message" + }, + "configuration": { + "acceptedOutputModes": ["text"], + "blocking": true + } + } + }"""; + + static final String SEND_MESSAGE_WITH_DATA_PART_TEST_RESPONSE = """ + { + "jsonrpc": "2.0", + "result": { + "id": "de38c76d-d54c-436c-8b9f-4c2703648d64", + "contextId": "c295ea44-7543-4f78-b524-7a38915ad6e4", + "status": { + "state": "completed" + }, + "artifacts": [ + { + "artifactId": "artifact-1", + "name": "data-analysis", + "parts": [ + { + "kind": "text", + "text": "Processed weather data: Temperature is 25.5°C, humidity is 60.2% in San Francisco." + } + ] + } + ], + "metadata": {}, + "kind": "task" + } + }"""; + + static final String SEND_MESSAGE_WITH_MIXED_PARTS_TEST_REQUEST = """ + { + "jsonrpc": "2.0", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [ + { + "kind": "text", + "text": "analyze this data and image" + }, + { + "kind": "file", + "file": { + "bytes": "aGVsbG8=", + "name": "chart.png", + "mimeType": "image/png" + } + }, + { + "kind": "data", + "data": { + "chartType": "bar", + "dataPoints": [10, 20, 30, 40], + "labels": ["Q1", "Q2", "Q3", "Q4"] + } + } + ], + "messageId": "message-1234-with-mixed", + "contextId": "context-1234", + "kind": "message" + }, + "configuration": { + "acceptedOutputModes": ["text"], + "blocking": true + } + } + }"""; + + static final String SEND_MESSAGE_WITH_MIXED_PARTS_TEST_RESPONSE = """ + { + "jsonrpc": "2.0", + "result": { + "id": "de38c76d-d54c-436c-8b9f-4c2703648d64", + "contextId": "c295ea44-7543-4f78-b524-7a38915ad6e4", + "status": { + "state": "completed" + }, + "artifacts": [ + { + "artifactId": "artifact-1", + "name": "mixed-analysis", + "parts": [ + { + "kind": "text", + "text": "Analyzed chart image and data: Bar chart showing quarterly data with values [10, 20, 30, 40]." + } + ] + } + ], + "metadata": {}, + "kind": "task" + } + }"""; + + public static final String SEND_MESSAGE_STREAMING_TEST_REQUEST = """ + { + "message": { + "role": "ROLE_USER", + "content": [ + { + "text": "tell me some jokes" + } + ], + "messageId": "message-1234", + "contextId": "context-1234" + }, + "configuration": { + "acceptedOutputModes": ["text"] + } + }"""; + static final String SEND_MESSAGE_STREAMING_TEST_RESPONSE + = "event: message\n" + + "data: {\"task\":{\"id\":\"2\",\"contextId\":\"context-1234\",\"status\":{\"state\":\"TASK_STATE_SUBMITTED\"},\"artifacts\":[{\"artifactId\":\"artifact-1\",\"name\":\"joke\",\"parts\":[{\"text\":\"Why did the chicken cross the road? To get to the other side!\"}]}],\"metadata\":{}}}\n\n"; + + static final String TASK_RESUBSCRIPTION_REQUEST_TEST_RESPONSE + = "event: message\n" + + "data: {\"task\":{\"id\":\"2\",\"contextId\":\"context-1234\",\"status\":{\"state\":\"TASK_STATE_COMPLETED\"},\"artifacts\":[{\"artifactId\":\"artifact-1\",\"name\":\"joke\",\"parts\":[{\"text\":\"Why did the chicken cross the road? To get to the other side!\"}]}],\"metadata\":{}}}\n\n"; + public static final String TASK_RESUBSCRIPTION_TEST_REQUEST = """ + { + "jsonrpc": "2.0", + "method": "tasks/resubscribe", + "params": { + "id": "task-1234" + } + }"""; +} diff --git a/client/transport/rest/src/test/java/io/a2a/client/transport/rest/RestTransportTest.java b/client/transport/rest/src/test/java/io/a2a/client/transport/rest/RestTransportTest.java new file mode 100644 index 000000000..fc6fc23a6 --- /dev/null +++ b/client/transport/rest/src/test/java/io/a2a/client/transport/rest/RestTransportTest.java @@ -0,0 +1,426 @@ +package io.a2a.client.transport.rest; + + +import static io.a2a.client.transport.rest.JsonRestMessages.CANCEL_TASK_TEST_REQUEST; +import static io.a2a.client.transport.rest.JsonRestMessages.CANCEL_TASK_TEST_RESPONSE; +import static io.a2a.client.transport.rest.JsonRestMessages.GET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE; +import static io.a2a.client.transport.rest.JsonRestMessages.GET_TASK_TEST_RESPONSE; +import static io.a2a.client.transport.rest.JsonRestMessages.LIST_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE; +import static io.a2a.client.transport.rest.JsonRestMessages.SEND_MESSAGE_STREAMING_TEST_RESPONSE; +import static io.a2a.client.transport.rest.JsonRestMessages.SEND_MESSAGE_TEST_REQUEST; +import static io.a2a.client.transport.rest.JsonRestMessages.SEND_MESSAGE_TEST_RESPONSE; +import static io.a2a.client.transport.rest.JsonRestMessages.SEND_MESSAGE_STREAMING_TEST_REQUEST; +import static io.a2a.client.transport.rest.JsonRestMessages.SET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_REQUEST; +import static io.a2a.client.transport.rest.JsonRestMessages.SET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE; +import static io.a2a.client.transport.rest.JsonRestMessages.TASK_RESUBSCRIPTION_REQUEST_TEST_RESPONSE; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockserver.model.HttpRequest.request; +import static org.mockserver.model.HttpResponse.response; + +import io.a2a.client.transport.spi.interceptors.ClientCallContext; +import io.a2a.spec.Artifact; +import io.a2a.spec.DeleteTaskPushNotificationConfigParams; +import io.a2a.spec.EventKind; +import io.a2a.spec.FilePart; +import io.a2a.spec.FileWithBytes; +import io.a2a.spec.FileWithUri; +import io.a2a.spec.GetTaskPushNotificationConfigParams; +import io.a2a.spec.ListTaskPushNotificationConfigParams; +import io.a2a.spec.Message; +import io.a2a.spec.MessageSendConfiguration; +import io.a2a.spec.MessageSendParams; +import io.a2a.spec.Part; +import io.a2a.spec.Part.Kind; +import io.a2a.spec.PushNotificationAuthenticationInfo; +import io.a2a.spec.PushNotificationConfig; +import io.a2a.spec.StreamingEventKind; +import io.a2a.spec.Task; +import io.a2a.spec.TaskIdParams; +import io.a2a.spec.TaskPushNotificationConfig; +import io.a2a.spec.TaskQueryParams; +import io.a2a.spec.TaskState; +import io.a2a.spec.TextPart; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.logging.Logger; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockserver.integration.ClientAndServer; +import org.mockserver.matchers.MatchType; +import org.mockserver.model.JsonBody; + +public class RestTransportTest { + + private static final Logger log = Logger.getLogger(RestTransportTest.class.getName()); + private ClientAndServer server; + + @BeforeEach + public void setUp() throws IOException { + server = new ClientAndServer(4001); + } + + @AfterEach + public void tearDown() { + server.stop(); + } + + public RestTransportTest() { + } + + /** + * Test of sendMessage method, of class JSONRestTransport. + */ + @Test + public void testSendMessage() throws Exception { + Message message = new Message.Builder() + .role(Message.Role.USER) + .parts(Collections.singletonList(new TextPart("tell me a joke"))) + .contextId("context-1234") + .messageId("message-1234") + .taskId("") + .build(); + this.server.when( + request() + .withMethod("POST") + .withPath("/v1/message:send") + .withBody(JsonBody.json(SEND_MESSAGE_TEST_REQUEST, MatchType.ONLY_MATCHING_FIELDS)) + ) + .respond( + response() + .withStatusCode(200) + .withBody(SEND_MESSAGE_TEST_RESPONSE) + ); + MessageSendParams messageSendParams = new MessageSendParams(message, null, null); + ClientCallContext context = null; + RestTransport instance = new RestTransport("http://localhost:4001"); + EventKind result = instance.sendMessage(messageSendParams, context); + assertEquals("task", result.getKind()); + Task task = (Task) result; + assertEquals("9b511af4-b27c-47fa-aecf-2a93c08a44f8", task.getId()); + assertEquals("context-1234", task.getContextId()); + assertEquals(TaskState.SUBMITTED, task.getStatus().state()); + assertNull(task.getStatus().message()); + assertNull(task.getMetadata()); + assertEquals(true, task.getArtifacts().isEmpty()); + assertEquals(1, task.getHistory().size()); + Message history = task.getHistory().get(0); + assertEquals("message", history.getKind()); + assertEquals(Message.Role.USER, history.getRole()); + assertEquals("context-1234", history.getContextId()); + assertEquals("message-1234", history.getMessageId()); + assertEquals("9b511af4-b27c-47fa-aecf-2a93c08a44f8", history.getTaskId()); + assertEquals(1, history.getParts().size()); + assertEquals(Kind.TEXT, history.getParts().get(0).getKind()); + assertEquals("tell me a joke", ((TextPart) history.getParts().get(0)).getText()); + assertNull(history.getMetadata()); + assertNull(history.getReferenceTaskIds()); + } + + /** + * Test of cancelTask method, of class JSONRestTransport. + */ + @Test + public void testCancelTask() throws Exception { + this.server.when( + request() + .withMethod("POST") + .withPath("/v1/tasks/de38c76d-d54c-436c-8b9f-4c2703648d64:cancel") + .withBody(JsonBody.json(CANCEL_TASK_TEST_REQUEST, MatchType.ONLY_MATCHING_FIELDS)) + ) + .respond( + response() + .withStatusCode(200) + .withBody(CANCEL_TASK_TEST_RESPONSE) + ); + ClientCallContext context = null; + RestTransport instance = new RestTransport("http://localhost:4001"); + Task task = instance.cancelTask(new TaskIdParams("de38c76d-d54c-436c-8b9f-4c2703648d64", + new HashMap<>()), context); + assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); + assertEquals(TaskState.CANCELED, task.getStatus().state()); + assertNull(task.getStatus().message()); + assertNull(task.getMetadata()); + } + + /** + * Test of getTask method, of class JSONRestTransport. + */ + @Test + public void testGetTask() throws Exception { + this.server.when( + request() + .withMethod("GET") + .withPath("/v1/tasks/de38c76d-d54c-436c-8b9f-4c2703648d64") + ) + .respond( + response() + .withStatusCode(200) + .withBody(GET_TASK_TEST_RESPONSE) + ); + ClientCallContext context = null; + TaskQueryParams request = new TaskQueryParams("de38c76d-d54c-436c-8b9f-4c2703648d64", 10); + RestTransport instance = new RestTransport("http://localhost:4001"); + Task task = instance.getTask(request, context); + assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); + assertEquals(TaskState.COMPLETED, task.getStatus().state()); + assertNull(task.getStatus().message()); + assertNull(task.getMetadata()); + assertEquals(false, task.getArtifacts().isEmpty()); + assertEquals(1, task.getArtifacts().size()); + Artifact artifact = task.getArtifacts().get(0); + assertEquals("artifact-1", artifact.artifactId()); + assertEquals("", artifact.name()); + assertEquals(false, artifact.parts().isEmpty()); + assertEquals(Kind.TEXT, artifact.parts().get(0).getKind()); + assertEquals("Why did the chicken cross the road? To get to the other side!", ((TextPart) artifact.parts().get(0)).getText()); + assertEquals(1, task.getHistory().size()); + Message history = task.getHistory().get(0); + assertEquals("message", history.getKind()); + assertEquals(Message.Role.USER, history.getRole()); + assertEquals("message-123", history.getMessageId()); + assertEquals(3, history.getParts().size()); + assertEquals(Kind.TEXT, history.getParts().get(0).getKind()); + assertEquals("tell me a joke", ((TextPart) history.getParts().get(0)).getText()); + assertEquals(Kind.FILE, history.getParts().get(1).getKind()); + FilePart part = (FilePart) history.getParts().get(1); + assertEquals("text/plain", part.getFile().mimeType()); + assertEquals("file:///path/to/file.txt", ((FileWithUri) part.getFile()).uri()); + part = (FilePart) history.getParts().get(2); + assertEquals(Kind.FILE, part.getKind()); + assertEquals("text/plain", part.getFile().mimeType()); + assertEquals("hello", ((FileWithBytes) part.getFile()).bytes()); + assertNull(history.getMetadata()); + assertNull(history.getReferenceTaskIds()); + } + + /** + * Test of sendMessageStreaming method, of class JSONRestTransport. + */ + @Test + public void testSendMessageStreaming() throws Exception { + this.server.when( + request() + .withMethod("POST") + .withPath("/v1/message:stream") + .withBody(JsonBody.json(SEND_MESSAGE_STREAMING_TEST_REQUEST, MatchType.ONLY_MATCHING_FIELDS)) + ) + .respond( + response() + .withStatusCode(200) + .withHeader("Content-Type", "text/event-stream") + .withBody(SEND_MESSAGE_STREAMING_TEST_RESPONSE) + ); + + RestTransport client = new RestTransport("http://localhost:4001"); + Message message = new Message.Builder() + .role(Message.Role.USER) + .parts(Collections.singletonList(new TextPart("tell me some jokes"))) + .contextId("context-1234") + .messageId("message-1234") + .build(); + MessageSendConfiguration configuration = new MessageSendConfiguration.Builder() + .acceptedOutputModes(List.of("text")) + .blocking(false) + .build(); + MessageSendParams params = new MessageSendParams.Builder() + .message(message) + .configuration(configuration) + .build(); + AtomicReference receivedEvent = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + Consumer eventHandler = event -> { + receivedEvent.set(event); + latch.countDown(); + }; + Consumer errorHandler = error -> { + }; + client.sendMessageStreaming(params, eventHandler, errorHandler, null); + + boolean eventReceived = latch.await(10, TimeUnit.SECONDS); + assertTrue(eventReceived); + assertNotNull(receivedEvent.get()); + assertEquals("task", receivedEvent.get().getKind()); + Task task = (Task) receivedEvent.get(); + assertEquals("2", task.getId()); + } + + /** + * Test of setTaskPushNotificationConfiguration method, of class JSONRestTransport. + */ + @Test + public void testSetTaskPushNotificationConfiguration() throws Exception { + log.info("Testing setTaskPushNotificationConfiguration"); + this.server.when( + request() + .withMethod("POST") + .withPath("/v1/tasks/de38c76d-d54c-436c-8b9f-4c2703648d64/pushNotificationConfigs") + .withBody(JsonBody.json(SET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_REQUEST, MatchType.ONLY_MATCHING_FIELDS)) + ) + .respond( + response() + .withStatusCode(200) + .withBody(SET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE) + ); + RestTransport client = new RestTransport("http://localhost:4001"); + TaskPushNotificationConfig pushedConfig = new TaskPushNotificationConfig( + "de38c76d-d54c-436c-8b9f-4c2703648d64", + new PushNotificationConfig.Builder() + .url("https://example.com/callback") + .authenticationInfo( + new PushNotificationAuthenticationInfo(Collections.singletonList("jwt"), null)) + .build()); + TaskPushNotificationConfig taskPushNotificationConfig = client.setTaskPushNotificationConfiguration(pushedConfig, null); + PushNotificationConfig pushNotificationConfig = taskPushNotificationConfig.pushNotificationConfig(); + assertNotNull(pushNotificationConfig); + assertEquals("https://example.com/callback", pushNotificationConfig.url()); + PushNotificationAuthenticationInfo authenticationInfo = pushNotificationConfig.authentication(); + assertEquals(1, authenticationInfo.schemes().size()); + assertEquals("jwt", authenticationInfo.schemes().get(0)); + } + + /** + * Test of getTaskPushNotificationConfiguration method, of class JSONRestTransport. + */ + @Test + public void testGetTaskPushNotificationConfiguration() throws Exception { + this.server.when( + request() + .withMethod("GET") + .withPath("/v1/tasks/de38c76d-d54c-436c-8b9f-4c2703648d64/pushNotificationConfigs/10") + ) + .respond( + response() + .withStatusCode(200) + .withBody(GET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE) + ); + + RestTransport client = new RestTransport("http://localhost:4001"); + TaskPushNotificationConfig taskPushNotificationConfig = client.getTaskPushNotificationConfiguration( + new GetTaskPushNotificationConfigParams("de38c76d-d54c-436c-8b9f-4c2703648d64", "10", + new HashMap<>()), null); + PushNotificationConfig pushNotificationConfig = taskPushNotificationConfig.pushNotificationConfig(); + assertNotNull(pushNotificationConfig); + assertEquals("https://example.com/callback", pushNotificationConfig.url()); + PushNotificationAuthenticationInfo authenticationInfo = pushNotificationConfig.authentication(); + assertTrue(authenticationInfo.schemes().size() == 1); + assertEquals("jwt", authenticationInfo.schemes().get(0)); + } + + /** + * Test of listTaskPushNotificationConfigurations method, of class JSONRestTransport. + */ + @Test + public void testListTaskPushNotificationConfigurations() throws Exception { + this.server.when( + request() + .withMethod("GET") + .withPath("/v1/tasks/de38c76d-d54c-436c-8b9f-4c2703648d64/pushNotificationConfigs") + ) + .respond( + response() + .withStatusCode(200) + .withBody(LIST_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE) + ); + + RestTransport client = new RestTransport("http://localhost:4001"); + List taskPushNotificationConfigs = client.listTaskPushNotificationConfigurations( + new ListTaskPushNotificationConfigParams("de38c76d-d54c-436c-8b9f-4c2703648d64", new HashMap<>()), null); + assertEquals(2, taskPushNotificationConfigs.size()); + PushNotificationConfig pushNotificationConfig = taskPushNotificationConfigs.get(0).pushNotificationConfig(); + assertNotNull(pushNotificationConfig); + assertEquals("https://example.com/callback", pushNotificationConfig.url()); + assertEquals("10", pushNotificationConfig.id()); + PushNotificationAuthenticationInfo authenticationInfo = pushNotificationConfig.authentication(); + assertTrue(authenticationInfo.schemes().size() == 1); + assertEquals("jwt", authenticationInfo.schemes().get(0)); + assertEquals("", authenticationInfo.credentials()); + pushNotificationConfig = taskPushNotificationConfigs.get(1).pushNotificationConfig(); + assertNotNull(pushNotificationConfig); + assertEquals("https://test.com/callback", pushNotificationConfig.url()); + assertEquals("5", pushNotificationConfig.id()); + authenticationInfo = pushNotificationConfig.authentication(); + assertNull(authenticationInfo); + } + + /** + * Test of deleteTaskPushNotificationConfigurations method, of class JSONRestTransport. + */ + @Test + public void testDeleteTaskPushNotificationConfigurations() throws Exception { + log.info("Testing deleteTaskPushNotificationConfigurations"); + this.server.when( + request() + .withMethod("DELETE") + .withPath("/v1/tasks/de38c76d-d54c-436c-8b9f-4c2703648d64/pushNotificationConfigs/10") + ) + .respond( + response() + .withStatusCode(200) + ); + ClientCallContext context = null; + RestTransport instance = new RestTransport("http://localhost:4001"); + instance.deleteTaskPushNotificationConfigurations(new DeleteTaskPushNotificationConfigParams("de38c76d-d54c-436c-8b9f-4c2703648d64", "10"), context); + } + + /** + * Test of resubscribe method, of class JSONRestTransport. + */ + @Test + public void testResubscribe() throws Exception { + log.info("Testing resubscribe"); + + this.server.when( + request() + .withMethod("POST") + .withPath("/v1/tasks/task-1234:subscribe") + ) + .respond( + response() + .withStatusCode(200) + .withHeader("Content-Type", "text/event-stream") + .withBody(TASK_RESUBSCRIPTION_REQUEST_TEST_RESPONSE) + ); + + RestTransport client = new RestTransport("http://localhost:4001"); + TaskIdParams taskIdParams = new TaskIdParams("task-1234"); + + AtomicReference receivedEvent = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + Consumer eventHandler = event -> { + receivedEvent.set(event); + latch.countDown(); + }; + Consumer errorHandler = error -> {}; + client.resubscribe(taskIdParams, eventHandler, errorHandler, null); + + boolean eventReceived = latch.await(10, TimeUnit.SECONDS); + assertTrue(eventReceived); + + StreamingEventKind eventKind = receivedEvent.get();; + assertNotNull(eventKind); + assertInstanceOf(Task.class, eventKind); + Task task = (Task) eventKind; + assertEquals("2", task.getId()); + assertEquals("context-1234", task.getContextId()); + assertEquals(TaskState.COMPLETED, task.getStatus().state()); + List artifacts = task.getArtifacts(); + assertEquals(1, artifacts.size()); + Artifact artifact = artifacts.get(0); + assertEquals("artifact-1", artifact.artifactId()); + assertEquals("joke", artifact.name()); + Part part = artifact.parts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("Why did the chicken cross the road? To get to the other side!", ((TextPart) part).getText()); + } +} diff --git a/examples/helloworld/server/pom.xml b/examples/helloworld/server/pom.xml index b35ffa8a9..03f06158a 100644 --- a/examples/helloworld/server/pom.xml +++ b/examples/helloworld/server/pom.xml @@ -23,7 +23,6 @@ io.github.a2asdk a2a-java-sdk-reference-jsonrpc - ${project.version} io.quarkus diff --git a/http-client/src/main/java/io/a2a/client/http/A2AHttpClient.java b/http-client/src/main/java/io/a2a/client/http/A2AHttpClient.java index f59e079f2..52c252a8f 100644 --- a/http-client/src/main/java/io/a2a/client/http/A2AHttpClient.java +++ b/http-client/src/main/java/io/a2a/client/http/A2AHttpClient.java @@ -1,6 +1,7 @@ package io.a2a.client.http; import java.io.IOException; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; @@ -10,8 +11,11 @@ public interface A2AHttpClient { PostBuilder createPost(); + DeleteBuilder createDelete(); + interface Builder> { T url(String s); + T addHeaders(Map headers); T addHeader(String name, String value); } @@ -31,4 +35,8 @@ CompletableFuture postAsyncSSE( Consumer errorConsumer, Runnable completeRunnable) throws IOException, InterruptedException; } + + interface DeleteBuilder extends Builder { + A2AHttpResponse delete() throws IOException, InterruptedException; + } } diff --git a/http-client/src/main/java/io/a2a/client/http/JdkA2AHttpClient.java b/http-client/src/main/java/io/a2a/client/http/JdkA2AHttpClient.java index 2cdbb2d37..abcecc8ed 100644 --- a/http-client/src/main/java/io/a2a/client/http/JdkA2AHttpClient.java +++ b/http-client/src/main/java/io/a2a/client/http/JdkA2AHttpClient.java @@ -35,6 +35,11 @@ public PostBuilder createPost() { return new JdkPostBuilder(); } + @Override + public DeleteBuilder createDelete() { + return new JdkDeleteBuilder(); + } + private abstract class JdkBuilder> implements Builder { private String url; private Map headers = new HashMap<>(); @@ -51,6 +56,16 @@ public T addHeader(String name, String value) { return self(); } + @Override + public T addHeaders(Map headers) { + if(headers != null && ! headers.isEmpty()) { + for (Map.Entry entry : headers.entrySet()) { + addHeader(entry.getKey(), entry.getValue()); + } + } + return self(); + } + @SuppressWarnings("unchecked") T self() { return (T) this; @@ -145,6 +160,19 @@ public CompletableFuture getAsyncSSE( .build(); return super.asyncRequest(request, messageConsumer, errorConsumer, completeRunnable); } + + } + + private class JdkDeleteBuilder extends JdkBuilder implements A2AHttpClient.DeleteBuilder { + + @Override + public A2AHttpResponse delete() throws IOException, InterruptedException { + HttpRequest request = super.createRequestBuilder().DELETE().build(); + HttpResponse response = + httpClient.send(request, BodyHandlers.ofString(StandardCharsets.UTF_8)); + return new JdkHttpResponse(response); + } + } private class JdkPostBuilder extends JdkBuilder implements A2AHttpClient.PostBuilder { diff --git a/http-client/src/test/java/io/a2a/client/http/A2ACardResolverTest.java b/http-client/src/test/java/io/a2a/client/http/A2ACardResolverTest.java index 0b855007b..99d26adad 100644 --- a/http-client/src/test/java/io/a2a/client/http/A2ACardResolverTest.java +++ b/http-client/src/test/java/io/a2a/client/http/A2ACardResolverTest.java @@ -14,6 +14,7 @@ import io.a2a.spec.A2AClientError; import io.a2a.spec.A2AClientJSONError; import io.a2a.spec.AgentCard; +import java.util.Map; import org.junit.jupiter.api.Test; public class A2ACardResolverTest { @@ -126,6 +127,11 @@ public PostBuilder createPost() { return null; } + @Override + public DeleteBuilder createDelete() { + return null; + } + class TestGetBuilder implements A2AHttpClient.GetBuilder { @Override @@ -161,7 +167,11 @@ public GetBuilder url(String s) { @Override public GetBuilder addHeader(String name, String value) { + return this; + } + @Override + public GetBuilder addHeaders(Map headers) { return this; } } diff --git a/pom.xml b/pom.xml index d1677324e..cfe777bae 100644 --- a/pom.xml +++ b/pom.xml @@ -81,22 +81,22 @@ ${project.groupId} - a2a-java-sdk-client-config + a2a-java-sdk-client-transport-spi ${project.version} ${project.groupId} - a2a-java-sdk-client-transport-spi + a2a-java-sdk-client-transport-jsonrpc ${project.version} ${project.groupId} - a2a-java-sdk-client-transport-jsonrpc + a2a-java-sdk-client-transport-grpc ${project.version} ${project.groupId} - a2a-java-sdk-client-transport-grpc + a2a-java-sdk-client-transport-rest ${project.version} @@ -119,6 +119,46 @@ a2a-java-sdk-spec-grpc ${project.version} + + ${project.groupId} + a2a-java-sdk-server-common + ${project.version} + + + ${project.groupId} + a2a-java-sdk-transport-grpc + ${project.version} + + + ${project.groupId} + a2a-java-sdk-transport-jsonrpc + ${project.version} + + + ${project.groupId} + a2a-java-sdk-transport-rest + ${project.version} + + + ${project.groupId} + a2a-java-sdk-reference-common + ${project.version} + + + ${project.groupId} + a2a-java-sdk-reference-grpc + ${project.version} + + + ${project.groupId} + a2a-java-sdk-reference-jsonrpc + ${project.version} + + + ${project.groupId} + a2a-java-sdk-reference-rest + ${project.version} + io.grpc grpc-bom @@ -212,6 +252,25 @@ ${logback.version} test + + ${project.groupId} + a2a-java-sdk-tests-server-common + ${project.version} + + + ${project.groupId} + a2a-java-sdk-tests-server-common + test-jar + test + ${project.version} + + + ${project.groupId} + a2a-java-sdk-server-common + test-jar + test + ${project.version} + @@ -302,7 +361,7 @@ sign - + @@ -327,6 +386,7 @@ client/base client/transport/grpc client/transport/jsonrpc + client/transport/rest client/transport/spi common examples/helloworld @@ -334,6 +394,7 @@ reference/common reference/grpc reference/jsonrpc + reference/rest server-common spec spec-grpc @@ -341,6 +402,7 @@ tests/server-common transport/jsonrpc transport/grpc + transport/rest diff --git a/reference/common/pom.xml b/reference/common/pom.xml index 8b2af244f..2e149a343 100644 --- a/reference/common/pom.xml +++ b/reference/common/pom.xml @@ -21,12 +21,10 @@ ${project.groupId} a2a-java-sdk-server-common - ${project.version} ${project.groupId} a2a-java-sdk-tests-server-common - ${project.version} provided @@ -34,7 +32,6 @@ a2a-java-sdk-tests-server-common test-jar test - ${project.version} io.quarkus diff --git a/reference/grpc/pom.xml b/reference/grpc/pom.xml index 6a4ec4618..e11adc2b5 100644 --- a/reference/grpc/pom.xml +++ b/reference/grpc/pom.xml @@ -18,22 +18,18 @@ ${project.groupId} a2a-java-sdk-reference-common - ${project.version} ${project.groupId} a2a-java-sdk-transport-grpc - ${project.version} ${project.groupId} a2a-java-sdk-server-common - ${project.version} ${project.groupId} a2a-java-sdk-tests-server-common - ${project.version} provided @@ -41,12 +37,10 @@ a2a-java-sdk-tests-server-common test-jar test - ${project.version} ${project.groupId} a2a-java-sdk-client-transport-grpc - ${project.version} test diff --git a/reference/jsonrpc/pom.xml b/reference/jsonrpc/pom.xml index a4342f96c..de441f6b3 100644 --- a/reference/jsonrpc/pom.xml +++ b/reference/jsonrpc/pom.xml @@ -21,22 +21,18 @@ ${project.groupId} a2a-java-sdk-reference-common - ${project.version} ${project.groupId} a2a-java-sdk-transport-jsonrpc - ${project.version} ${project.groupId} a2a-java-sdk-server-common - ${project.version} ${project.groupId} a2a-java-sdk-tests-server-common - ${project.version} provided @@ -44,7 +40,6 @@ a2a-java-sdk-tests-server-common test-jar test - ${project.version} io.quarkus diff --git a/reference/rest/README.md b/reference/rest/README.md new file mode 100644 index 000000000..2a7f0f902 --- /dev/null +++ b/reference/rest/README.md @@ -0,0 +1,7 @@ +# A2A Java SDK Reference Server Integration + +This is a reference server for the A2A SDK for Java, that we use to run tests, as well as to demonstrate examples. + +It is based on [Quarkus](https://quarkus.io), and makes use of Quarkus's [Reactive Routes](https://quarkus.io/guides/reactive-routes). + +It is a great choice if you use Quarkus! \ No newline at end of file diff --git a/reference/rest/pom.xml b/reference/rest/pom.xml new file mode 100644 index 000000000..e96a16971 --- /dev/null +++ b/reference/rest/pom.xml @@ -0,0 +1,106 @@ + + + 4.0.0 + + + io.github.a2asdk + a2a-java-sdk-parent + 0.3.0.Beta1-SNAPSHOT + ../../pom.xml + + a2a-java-sdk-reference-rest + + jar + + Java A2A Reference Server: JSON+HTTP/REST + Java SDK for the Agent2Agent Protocol (A2A) - A2A JSON+HTTP/REST Reference Server (based on Quarkus) + + + + ${project.groupId} + a2a-java-sdk-reference-common + + + ${project.groupId} + a2a-java-sdk-transport-rest + + + ${project.groupId} + a2a-java-sdk-server-common + + + ${project.groupId} + a2a-java-sdk-client-transport-rest + test + + + ${project.groupId} + a2a-java-sdk-tests-server-common + provided + + + ${project.groupId} + a2a-java-sdk-tests-server-common + test-jar + test + + + com.google.protobuf + protobuf-java-util + test + + + io.quarkus + quarkus-reactive-routes + + + jakarta.enterprise + jakarta.enterprise.cdi-api + + + jakarta.inject + jakarta.inject-api + + + org.slf4j + slf4j-api + + + io.quarkus + quarkus-junit5 + test + + + io.quarkus + quarkus-rest-client-jackson + test + + + org.junit.jupiter + junit-jupiter-api + test + + + io.rest-assured + rest-assured + test + + + + + + maven-surefire-plugin + 3.5.3 + + + org.jboss.logmanager.LogManager + INFO + ${maven.home} + + + + + + diff --git a/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java b/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java new file mode 100644 index 000000000..1a90cda75 --- /dev/null +++ b/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java @@ -0,0 +1,411 @@ +package io.a2a.server.rest.quarkus; + +import static io.vertx.core.http.HttpHeaders.CONTENT_TYPE; +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; + +import java.util.concurrent.Executor; +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; + +import jakarta.enterprise.inject.Instance; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import io.a2a.server.ServerCallContext; +import io.a2a.server.auth.UnauthenticatedUser; +import io.a2a.server.auth.User; +import io.a2a.server.util.async.Internal; +import io.a2a.spec.InternalError; +import io.a2a.spec.InvalidParamsError; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.MethodNotFoundError; +import io.a2a.transport.rest.handler.RestHandler; +import io.a2a.transport.rest.handler.RestHandler.HTTPRestResponse; +import io.a2a.transport.rest.handler.RestHandler.HTTPRestStreamingResponse; +import io.quarkus.vertx.web.Body; +import io.quarkus.vertx.web.ReactiveRoutes; +import io.quarkus.vertx.web.Route; +import io.smallrye.mutiny.Multi; +import io.vertx.core.AsyncResult; +import io.vertx.core.Handler; +import io.vertx.core.MultiMap; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.http.HttpServerResponse; +import io.vertx.ext.web.RoutingContext; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +@Singleton +public class A2AServerRoutes { + + @Inject + RestHandler jsonRestHandler; + + // Hook so testing can wait until the MultiSseSupport is subscribed. + // Without this we get intermittent failures + private static volatile Runnable streamingMultiSseSupportSubscribedRunnable; + + @Inject + @Internal + Executor executor; + + @Inject + Instance callContextFactory; + + @Route(regex = "^/v1/message:send$", order = 1, methods = {Route.HttpMethod.POST}, consumes = {APPLICATION_JSON}, type = Route.HandlerType.BLOCKING) + public void sendMessage(@Body String body, RoutingContext rc) { + ServerCallContext context = createCallContext(rc); + HTTPRestResponse response = null; + try { + response = jsonRestHandler.sendMessage(body, context); + } catch (Throwable t) { + response = jsonRestHandler.createErrorResponse(new InternalError(t.getMessage())); + } finally { + rc.response() + .setStatusCode(response.getStatusCode()) + .putHeader(CONTENT_TYPE, response.getContentType()) + .end(response.getBody()); + } + } + + @Route(regex = "^/v1/message:stream$", order = 1, methods = {Route.HttpMethod.POST}, consumes = {APPLICATION_JSON}, type = Route.HandlerType.BLOCKING) + public void sendMessageStreaming(@Body String body, RoutingContext rc) { + ServerCallContext context = createCallContext(rc); + HTTPRestStreamingResponse streamingResponse = null; + HTTPRestResponse error = null; + try { + HTTPRestResponse response = jsonRestHandler.sendStreamingMessage(body, context); + if (response instanceof HTTPRestStreamingResponse) { + streamingResponse = (HTTPRestStreamingResponse) response; + } else { + error = response; + } + } finally { + if (error != null) { + rc.response() + .setStatusCode(error.getStatusCode()) + .putHeader(CONTENT_TYPE, APPLICATION_JSON) + .end(error.getBody()); + } else { + Multi events = Multi.createFrom().publisher(streamingResponse.getPublisher()); + executor.execute(() -> { + MultiSseSupport.subscribeObject( + events.map(i -> (Object) i), rc); + }); + } + } + } + + @Route(path = "/v1/tasks/:id", order = 1, methods = {Route.HttpMethod.GET}, type = Route.HandlerType.BLOCKING) + public void getTask(RoutingContext rc) { + String taskId = rc.pathParam("id"); + ServerCallContext context = createCallContext(rc); + HTTPRestResponse response = null; + try { + Integer historyLength = null; + if (rc.request().params().contains("history_length")) { + historyLength = Integer.valueOf(rc.request().params().get("history_length")); + } + response = jsonRestHandler.getTask(taskId, historyLength, context); + } catch (NumberFormatException e) { + response = jsonRestHandler.createErrorResponse(new InvalidParamsError("bad history_length")); + } catch (Throwable t) { + response = jsonRestHandler.createErrorResponse(new InternalError(t.getMessage())); + } finally { + rc.response() + .setStatusCode(response.getStatusCode()) + .putHeader(CONTENT_TYPE, response.getContentType()) + .end(response.getBody()); + } + } + + @Route(regex = "^/v1/tasks/([^/]+):cancel$", order = 1, methods = {Route.HttpMethod.POST}, type = Route.HandlerType.BLOCKING) + public void cancelTask(RoutingContext rc) { + String taskId = rc.pathParam("param0"); + ServerCallContext context = createCallContext(rc); + HTTPRestResponse response = null; + try { + response = jsonRestHandler.cancelTask(taskId, context); + } catch (Throwable t) { + if (t instanceof JSONRPCError error) { + response = jsonRestHandler.createErrorResponse(error); + } else { + response = jsonRestHandler.createErrorResponse(new InternalError(t.getMessage())); + } + } finally { + rc.response() + .setStatusCode(response.getStatusCode()) + .putHeader(CONTENT_TYPE, response.getContentType()) + .end(response.getBody()); + } + } + + @Route(regex = "^/v1/tasks/([^/]+):subscribe$", order = 1, methods = {Route.HttpMethod.POST}, type = Route.HandlerType.BLOCKING) + public void resubscribeTask(RoutingContext rc) { + String taskId = rc.pathParam("param0"); + ServerCallContext context = createCallContext(rc); + HTTPRestStreamingResponse streamingResponse = null; + HTTPRestResponse error = null; + try { + HTTPRestResponse response = jsonRestHandler.resubscribeTask(taskId, context); + if (response instanceof HTTPRestStreamingResponse) { + streamingResponse = (HTTPRestStreamingResponse) response; + } else { + error = response; + } + } finally { + if (error != null) { + rc.response() + .setStatusCode(error.getStatusCode()) + .putHeader(CONTENT_TYPE, APPLICATION_JSON) + .end(error.getBody()); + } else { + Multi events = Multi.createFrom().publisher(streamingResponse.getPublisher()); + executor.execute(() -> { + MultiSseSupport.subscribeObject( + events.map(i -> (Object) i), rc); + }); + } + } + } + + @Route(path = "/v1/tasks/:id/pushNotificationConfigs", order = 1, methods = {Route.HttpMethod.POST}, consumes = {APPLICATION_JSON}, type = Route.HandlerType.BLOCKING) + public void setTaskPushNotificationConfiguration(@Body String body, RoutingContext rc) { + String taskId = rc.pathParam("id"); + ServerCallContext context = createCallContext(rc); + HTTPRestResponse response = null; + try { + response = jsonRestHandler.setTaskPushNotificationConfiguration(taskId, body, context); + } catch (Throwable t) { + response = jsonRestHandler.createErrorResponse(new InternalError(t.getMessage())); + } finally { + rc.response() + .setStatusCode(response.getStatusCode()) + .putHeader(CONTENT_TYPE, response.getContentType()) + .end(response.getBody()); + } + } + + @Route(path = "/v1/tasks/:id/pushNotificationConfigs/:configId", order = 1, methods = {Route.HttpMethod.GET}, type = Route.HandlerType.BLOCKING) + public void getTaskPushNotificationConfiguration(RoutingContext rc) { + String taskId = rc.pathParam("id"); + String configId = rc.pathParam("configId"); + ServerCallContext context = createCallContext(rc); + HTTPRestResponse response = null; + try { + response = jsonRestHandler.getTaskPushNotificationConfiguration(taskId, configId, context); + } catch (Throwable t) { + response = jsonRestHandler.createErrorResponse(new InternalError(t.getMessage())); + } finally { + rc.response() + .setStatusCode(response.getStatusCode()) + .putHeader(CONTENT_TYPE, response.getContentType()) + .end(response.getBody()); + } + } + + @Route(path = "/v1/tasks/:id/pushNotificationConfigs", order = 1, methods = {Route.HttpMethod.GET}, type = Route.HandlerType.BLOCKING) + public void listTaskPushNotificationConfigurations(RoutingContext rc) { + String taskId = rc.pathParam("id"); + ServerCallContext context = createCallContext(rc); + HTTPRestResponse response = null; + try { + response = jsonRestHandler.listTaskPushNotificationConfigurations(taskId, context); + } catch (Throwable t) { + response = jsonRestHandler.createErrorResponse(new InternalError(t.getMessage())); + } finally { + rc.response() + .setStatusCode(response.getStatusCode()) + .putHeader(CONTENT_TYPE, response.getContentType()) + .end(response.getBody()); + } + } + + @Route(path = "/v1/tasks/:id/pushNotificationConfigs/:configId", order = 1, methods = {Route.HttpMethod.DELETE}, type = Route.HandlerType.BLOCKING) + public void deleteTaskPushNotificationConfiguration(RoutingContext rc) { + String taskId = rc.pathParam("id"); + String configId = rc.pathParam("configId"); + ServerCallContext context = createCallContext(rc); + HTTPRestResponse response = null; + try { + response = jsonRestHandler.deleteTaskPushNotificationConfiguration(taskId, configId, context); + } catch (Throwable t) { + response = jsonRestHandler.createErrorResponse(new InternalError(t.getMessage())); + } finally { + rc.response() + .setStatusCode(response.getStatusCode()) + .putHeader(CONTENT_TYPE, response.getContentType()) + .end(response.getBody()); + } + } + + /** + * /** + * Handles incoming GET requests to the agent card endpoint. + * Returns the agent card in JSON format. + * + * @param rc + */ + @Route(path = "/.well-known/agent-card.json", order = 1, methods = Route.HttpMethod.GET, produces = APPLICATION_JSON) + public void getAgentCard(RoutingContext rc) { + HTTPRestResponse response = jsonRestHandler.getAgentCard(); + rc.response() + .setStatusCode(response.getStatusCode()) + .putHeader(CONTENT_TYPE, response.getContentType()) + .end(response.getBody()); + } + + @Route(path = "/v1/card", order = 1, methods = Route.HttpMethod.GET, produces = APPLICATION_JSON) + public void getAuthenticatedExtendedCard(RoutingContext rc) { + HTTPRestResponse response = jsonRestHandler.getAuthenticatedExtendedCard(); + rc.response() + .setStatusCode(response.getStatusCode()) + .putHeader(CONTENT_TYPE, response.getContentType()) + .end(response.getBody()); + } + + @Route(path = "^/v1/.*", order = 100, methods = {Route.HttpMethod.DELETE, Route.HttpMethod.GET, Route.HttpMethod.HEAD, Route.HttpMethod.OPTIONS, Route.HttpMethod.POST, Route.HttpMethod.PUT}, produces = APPLICATION_JSON) + public void methodNotFoundMessage(RoutingContext rc) { + HTTPRestResponse response = jsonRestHandler.createErrorResponse(new MethodNotFoundError()); + rc.response() + .setStatusCode(response.getStatusCode()) + .putHeader(CONTENT_TYPE, response.getContentType()) + .end(response.getBody()); + } + + static void setStreamingMultiSseSupportSubscribedRunnable(Runnable runnable) { + streamingMultiSseSupportSubscribedRunnable = runnable; + } + + private ServerCallContext createCallContext(RoutingContext rc) { + + if (callContextFactory.isUnsatisfied()) { + User user; + if (rc.user() == null) { + user = UnauthenticatedUser.INSTANCE; + } else { + user = new User() { + @Override + public boolean isAuthenticated() { + return rc.userContext().authenticated(); + } + + @Override + public String getUsername() { + return rc.user().subject(); + } + }; + } + Map state = new HashMap<>(); + // TODO Python's impl has + // state['auth'] = request.auth + // in jsonrpc_app.py. Figure out what this maps to in what Vert.X gives us + + Map headers = new HashMap<>(); + Set headerNames = rc.request().headers().names(); + headerNames.forEach(name -> headers.put(name, rc.request().getHeader(name))); + state.put("headers", headers); + + return new ServerCallContext(user, state); + } else { + CallContextFactory builder = callContextFactory.get(); + return builder.build(rc); + } + } + + // Port of import io.quarkus.vertx.web.runtime.MultiSseSupport, which is considered internal API + private static class MultiSseSupport { + + private MultiSseSupport() { + // Avoid direct instantiation. + } + + private static void initialize(HttpServerResponse response) { + if (response.bytesWritten() == 0) { + MultiMap headers = response.headers(); + if (headers.get("content-type") == null) { + headers.set("content-type", "text/event-stream"); + } + response.setChunked(true); + } + } + + private static void onWriteDone(Flow.Subscription subscription, AsyncResult ar, RoutingContext rc) { + if (ar.failed()) { + rc.fail(ar.cause()); + } else { + subscription.request(1); + } + } + + public static void write(Multi multi, RoutingContext rc) { + HttpServerResponse response = rc.response(); + multi.subscribe().withSubscriber(new Flow.Subscriber() { + Flow.Subscription upstream; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.upstream = subscription; + this.upstream.request(1); + + // Notify tests that we are subscribed + Runnable runnable = streamingMultiSseSupportSubscribedRunnable; + if (runnable != null) { + runnable.run(); + } + } + + @Override + public void onNext(Buffer item) { + initialize(response); + response.write(item, new Handler>() { + @Override + public void handle(AsyncResult ar) { + onWriteDone(upstream, ar, rc); + } + }); + } + + @Override + public void onError(Throwable throwable) { + rc.fail(throwable); + } + + @Override + public void onComplete() { + endOfStream(response); + } + }); + } + + public static void subscribeObject(Multi multi, RoutingContext rc) { + AtomicLong count = new AtomicLong(); + write(multi.map(new Function() { + @Override + public Buffer apply(Object o) { + if (o instanceof ReactiveRoutes.ServerSentEvent) { + ReactiveRoutes.ServerSentEvent ev = (ReactiveRoutes.ServerSentEvent) o; + long id = ev.id() != -1 ? ev.id() : count.getAndIncrement(); + String e = ev.event() == null ? "" : "event: " + ev.event() + "\n"; + return Buffer.buffer(e + "data: " + ev.data() + "\nid: " + id + "\n\n"); + } else { + return Buffer.buffer("data: " + o + "\nid: " + count.getAndIncrement() + "\n\n"); + } + } + }), rc); + } + + private static void endOfStream(HttpServerResponse response) { + if (response.bytesWritten() == 0) { // No item + MultiMap headers = response.headers(); + if (headers.get("content-type") == null) { + headers.set("content-type", "text/event-stream"); + } + } + response.end(); + } + } + +} diff --git a/reference/rest/src/main/java/io/a2a/server/rest/quarkus/CallContextFactory.java b/reference/rest/src/main/java/io/a2a/server/rest/quarkus/CallContextFactory.java new file mode 100644 index 000000000..7aa5caf5e --- /dev/null +++ b/reference/rest/src/main/java/io/a2a/server/rest/quarkus/CallContextFactory.java @@ -0,0 +1,8 @@ +package io.a2a.server.rest.quarkus; + +import io.a2a.server.ServerCallContext; +import io.vertx.ext.web.RoutingContext; + +public interface CallContextFactory { + ServerCallContext build(RoutingContext rc); +} diff --git a/reference/rest/src/main/java/io/a2a/server/rest/quarkus/QuarkusRestTransportMetadata.java b/reference/rest/src/main/java/io/a2a/server/rest/quarkus/QuarkusRestTransportMetadata.java new file mode 100644 index 000000000..ee9d3ae98 --- /dev/null +++ b/reference/rest/src/main/java/io/a2a/server/rest/quarkus/QuarkusRestTransportMetadata.java @@ -0,0 +1,11 @@ +package io.a2a.server.rest.quarkus; + +import io.a2a.server.TransportMetadata; +import io.a2a.spec.TransportProtocol; + +public class QuarkusRestTransportMetadata implements TransportMetadata { + @Override + public String getTransportProtocol() { + return TransportProtocol.HTTP_JSON.asString(); + } +} diff --git a/reference/rest/src/main/resources/META-INF/beans.xml b/reference/rest/src/main/resources/META-INF/beans.xml new file mode 100644 index 000000000..e69de29bb diff --git a/reference/rest/src/main/resources/META-INF/services/io.a2a.server.TransportMetadata b/reference/rest/src/main/resources/META-INF/services/io.a2a.server.TransportMetadata new file mode 100644 index 000000000..cb50024df --- /dev/null +++ b/reference/rest/src/main/resources/META-INF/services/io.a2a.server.TransportMetadata @@ -0,0 +1 @@ +io.a2a.server.rest.quarkus.QuarkusRestTransportMetadata \ No newline at end of file diff --git a/reference/rest/src/test/java/io/a2a/server/rest/quarkus/A2ATestRoutes.java b/reference/rest/src/test/java/io/a2a/server/rest/quarkus/A2ATestRoutes.java new file mode 100644 index 000000000..24d49c9a0 --- /dev/null +++ b/reference/rest/src/test/java/io/a2a/server/rest/quarkus/A2ATestRoutes.java @@ -0,0 +1,190 @@ +package io.a2a.server.rest.quarkus; + +import io.a2a.server.rest.quarkus.A2AServerRoutes; + +import static io.vertx.core.http.HttpHeaders.CONTENT_TYPE; +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; +import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN; + +import java.util.concurrent.atomic.AtomicInteger; + +import jakarta.annotation.PostConstruct; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import io.a2a.server.apps.common.TestUtilsBean; +import io.a2a.spec.PushNotificationConfig; +import io.a2a.spec.Task; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskStatusUpdateEvent; +import io.a2a.util.Utils; +import io.quarkus.vertx.web.Body; +import io.quarkus.vertx.web.Param; +import io.quarkus.vertx.web.Route; +import io.vertx.ext.web.RoutingContext; + +/** + * Exposes the {@link TestUtilsBean} via REST using Quarkus Reactive Routes + */ +@Singleton +public class A2ATestRoutes { + @Inject + TestUtilsBean testUtilsBean; + + @Inject + A2AServerRoutes a2AServerRoutes; + + AtomicInteger streamingSubscribedCount = new AtomicInteger(0); + + @PostConstruct + public void init() { + A2AServerRoutes.setStreamingMultiSseSupportSubscribedRunnable(() -> streamingSubscribedCount.incrementAndGet()); + } + + + @Route(path = "/test/task", methods = {Route.HttpMethod.POST}, consumes = {APPLICATION_JSON}, type = Route.HandlerType.BLOCKING) + public void saveTask(@Body String body, RoutingContext rc) { + try { + Task task = Utils.OBJECT_MAPPER.readValue(body, Task.class); + testUtilsBean.saveTask(task); + rc.response() + .setStatusCode(200) + .end(); + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/task/:taskId", methods = {Route.HttpMethod.GET}, produces = {APPLICATION_JSON}, type = Route.HandlerType.BLOCKING) + public void getTask(@Param String taskId, RoutingContext rc) { + try { + Task task = testUtilsBean.getTask(taskId); + if (task == null) { + rc.response() + .setStatusCode(404) + .end(); + return; + } + rc.response() + .setStatusCode(200) + .putHeader(CONTENT_TYPE, APPLICATION_JSON) + .end(Utils.OBJECT_MAPPER.writeValueAsString(task)); + + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/task/:taskId", methods = {Route.HttpMethod.DELETE}, type = Route.HandlerType.BLOCKING) + public void deleteTask(@Param String taskId, RoutingContext rc) { + try { + Task task = testUtilsBean.getTask(taskId); + if (task == null) { + rc.response() + .setStatusCode(404) + .end(); + return; + } + testUtilsBean.deleteTask(taskId); + rc.response() + .setStatusCode(200) + .end(); + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/queue/ensure/:taskId", methods = {Route.HttpMethod.POST}) + public void ensureTaskQueue(@Param String taskId, RoutingContext rc) { + try { + testUtilsBean.ensureQueue(taskId); + rc.response() + .setStatusCode(200) + .end(); + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/queue/enqueueTaskStatusUpdateEvent/:taskId", methods = {Route.HttpMethod.POST}) + public void enqueueTaskStatusUpdateEvent(@Param String taskId, @Body String body, RoutingContext rc) { + + try { + TaskStatusUpdateEvent event = Utils.OBJECT_MAPPER.readValue(body, TaskStatusUpdateEvent.class); + testUtilsBean.enqueueEvent(taskId, event); + rc.response() + .setStatusCode(200) + .end(); + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/queue/enqueueTaskArtifactUpdateEvent/:taskId", methods = {Route.HttpMethod.POST}) + public void enqueueTaskArtifactUpdateEvent(@Param String taskId, @Body String body, RoutingContext rc) { + + try { + TaskArtifactUpdateEvent event = Utils.OBJECT_MAPPER.readValue(body, TaskArtifactUpdateEvent.class); + testUtilsBean.enqueueEvent(taskId, event); + rc.response() + .setStatusCode(200) + .end(); + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/streamingSubscribedCount", methods = {Route.HttpMethod.GET}, produces = {TEXT_PLAIN}) + public void getStreamingSubscribedCount(RoutingContext rc) { + rc.response() + .setStatusCode(200) + .end(String.valueOf(streamingSubscribedCount.get())); + } + + @Route(path = "/test/task/:taskId/config/:configId", methods = {Route.HttpMethod.DELETE}, type = Route.HandlerType.BLOCKING) + public void deleteTaskPushNotificationConfig(@Param String taskId, @Param String configId, RoutingContext rc) { + try { + Task task = testUtilsBean.getTask(taskId); + if (task == null) { + rc.response() + .setStatusCode(404) + .end(); + return; + } + testUtilsBean.deleteTaskPushNotificationConfig(taskId, configId); + rc.response() + .setStatusCode(200) + .end(); + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/task/:taskId", methods = {Route.HttpMethod.POST}, type = Route.HandlerType.BLOCKING) + public void saveTaskPushNotificationConfig(@Param String taskId, @Body String body, RoutingContext rc) { + try { + PushNotificationConfig notificationConfig = Utils.OBJECT_MAPPER.readValue(body, PushNotificationConfig.class); + if (notificationConfig == null) { + rc.response() + .setStatusCode(404) + .end(); + return; + } + testUtilsBean.saveTaskPushNotificationConfig(taskId, notificationConfig); + rc.response() + .setStatusCode(200) + .end(); + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + private void errorResponse(Throwable t, RoutingContext rc) { + t.printStackTrace(); + rc.response() + .setStatusCode(500) + .putHeader(CONTENT_TYPE, TEXT_PLAIN) + .end(); + } + +} diff --git a/reference/rest/src/test/java/io/a2a/server/rest/quarkus/QuarkusA2ARestTest.java b/reference/rest/src/test/java/io/a2a/server/rest/quarkus/QuarkusA2ARestTest.java new file mode 100644 index 000000000..f2231ffc8 --- /dev/null +++ b/reference/rest/src/test/java/io/a2a/server/rest/quarkus/QuarkusA2ARestTest.java @@ -0,0 +1,60 @@ +package io.a2a.server.rest.quarkus; + +import static io.a2a.server.apps.common.AbstractA2AServerTest.APPLICATION_JSON; + +import io.a2a.client.ClientBuilder; +import io.a2a.client.transport.rest.RestTransport; +import io.a2a.client.transport.rest.RestTransportConfigBuilder; +import io.a2a.server.apps.common.AbstractA2AServerTest; +import io.a2a.spec.TransportProtocol; +import io.quarkus.test.junit.QuarkusTest; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +@QuarkusTest +public class QuarkusA2ARestTest extends AbstractA2AServerTest { + + public QuarkusA2ARestTest() { + super(8081); + } + + @Override + protected String getTransportProtocol() { + return TransportProtocol.HTTP_JSON.asString(); + } + + @Override + protected String getTransportUrl() { + return "http://localhost:8081"; + } + + @Override + protected void configureTransport(ClientBuilder builder) { + builder.withTransport(RestTransport.class, new RestTransportConfigBuilder()); + } + @Test + public void testMethodNotFound() throws Exception { + + // Create the client + HttpClient client = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .build(); + // Create the request + HttpRequest.Builder builder = HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + serverPort + "/v1/message:send")) + .PUT(HttpRequest.BodyPublishers.ofString("test")) + .header("Content-Type", APPLICATION_JSON); + HttpResponse response = client.send(builder.build(), HttpResponse.BodyHandlers.ofString()); + Assertions.assertEquals(405, response.statusCode()); + builder = HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + serverPort + "/v1/message:send")) + .DELETE() + .header("Content-Type", APPLICATION_JSON); + response = client.send(builder.build(), HttpResponse.BodyHandlers.ofString()); + Assertions.assertEquals(405, response.statusCode()); + } +} diff --git a/reference/rest/src/test/resources/application.properties b/reference/rest/src/test/resources/application.properties new file mode 100644 index 000000000..d3366bece --- /dev/null +++ b/reference/rest/src/test/resources/application.properties @@ -0,0 +1 @@ +quarkus.arc.selected-alternatives=io.a2a.server.apps.common.TestHttpClient \ No newline at end of file diff --git a/server-common/pom.xml b/server-common/pom.xml index 1bcc2b7b9..7f1d810f6 100644 --- a/server-common/pom.xml +++ b/server-common/pom.xml @@ -20,17 +20,14 @@ ${project.groupId} a2a-java-sdk-spec - ${project.version} ${project.groupId} a2a-java-sdk-http-client - ${project.version} ${project.groupId} a2a-java-sdk-client-transport-jsonrpc - ${project.version} com.fasterxml.jackson.core diff --git a/server-common/src/main/java/io/a2a/server/requesthandlers/RequestHandler.java b/server-common/src/main/java/io/a2a/server/requesthandlers/RequestHandler.java index e45bc3c62..50ba260c0 100644 --- a/server-common/src/main/java/io/a2a/server/requesthandlers/RequestHandler.java +++ b/server-common/src/main/java/io/a2a/server/requesthandlers/RequestHandler.java @@ -20,6 +20,9 @@ public interface RequestHandler { Task onGetTask( TaskQueryParams params, ServerCallContext context) throws JSONRPCError; +// +// List onListTask(ServerCallContext context) +// throws JSONRPCError; Task onCancelTask( TaskIdParams params, diff --git a/server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStore.java b/server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStore.java index 451b60451..4ae6520c6 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStore.java +++ b/server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStore.java @@ -75,4 +75,4 @@ public void deleteInfo(String taskId, String configId) { pushNotificationInfos.remove(taskId); } } -} +} \ No newline at end of file diff --git a/server-common/src/main/java/io/a2a/server/tasks/InMemoryTaskStore.java b/server-common/src/main/java/io/a2a/server/tasks/InMemoryTaskStore.java index 5cb42d8bb..78d0567fb 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/InMemoryTaskStore.java +++ b/server-common/src/main/java/io/a2a/server/tasks/InMemoryTaskStore.java @@ -1,8 +1,5 @@ package io.a2a.server.tasks; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; diff --git a/server-common/src/main/java/io/a2a/server/tasks/TaskStore.java b/server-common/src/main/java/io/a2a/server/tasks/TaskStore.java index 73ac8f38f..26f66f023 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/TaskStore.java +++ b/server-common/src/main/java/io/a2a/server/tasks/TaskStore.java @@ -1,5 +1,6 @@ package io.a2a.server.tasks; + import io.a2a.spec.Task; public interface TaskStore { diff --git a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java index acef022c3..9f12ee792 100644 --- a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java +++ b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java @@ -37,6 +37,7 @@ import io.a2a.spec.TextPart; import io.a2a.util.Utils; import io.quarkus.arc.profile.IfBuildProfile; +import java.util.Map; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assertions; @@ -161,6 +162,11 @@ public PostBuilder createPost() { return new TestHttpClient.TestPostBuilder(); } + @Override + public DeleteBuilder createDelete() { + return null; + } + class TestPostBuilder implements A2AHttpClient.PostBuilder { private volatile String body; @Override @@ -209,6 +215,11 @@ public PostBuilder addHeader(String name, String value) { return this; } + @Override + public PostBuilder addHeaders(Map headers) { + return this; + } + } } } diff --git a/spec-grpc/pom.xml b/spec-grpc/pom.xml index d72663392..ea5ac3fba 100644 --- a/spec-grpc/pom.xml +++ b/spec-grpc/pom.xml @@ -20,7 +20,6 @@ ${project.groupId} a2a-java-sdk-spec - ${project.version} com.google.protobuf @@ -29,10 +28,12 @@ io.grpc grpc-protobuf + provided io.grpc grpc-stub + provided jakarta.enterprise diff --git a/spec-grpc/src/main/java/io/a2a/grpc/utils/ProtoUtils.java b/spec-grpc/src/main/java/io/a2a/grpc/utils/ProtoUtils.java index 1b5bc8f17..5572cf440 100644 --- a/spec-grpc/src/main/java/io/a2a/grpc/utils/ProtoUtils.java +++ b/spec-grpc/src/main/java/io/a2a/grpc/utils/ProtoUtils.java @@ -13,6 +13,7 @@ import com.google.protobuf.ByteString; import com.google.protobuf.Struct; import com.google.protobuf.Value; +import io.a2a.grpc.ListTaskPushNotificationConfigResponse; import io.a2a.grpc.StreamResponse; import io.a2a.spec.APIKeySecurityScheme; @@ -167,9 +168,9 @@ public static io.a2a.grpc.Message message(Message message) { } public static io.a2a.grpc.TaskPushNotificationConfig taskPushNotificationConfig(TaskPushNotificationConfig config) { + String id = config.pushNotificationConfig().id(); io.a2a.grpc.TaskPushNotificationConfig.Builder builder = io.a2a.grpc.TaskPushNotificationConfig.newBuilder(); - String configId = config.pushNotificationConfig().id(); - builder.setName("tasks/" + config.taskId() + "/pushNotificationConfigs/" + (configId != null ? configId : config.taskId())); + builder.setName("tasks/" + config.taskId() + "/pushNotificationConfigs" + (id == null ? "" : ('/' + id))); builder.setPushNotificationConfig(pushNotificationConfig(config.pushNotificationConfig())); return builder.build(); } @@ -367,6 +368,17 @@ private static io.a2a.grpc.AgentCapabilities agentCapabilities(AgentCapabilities return builder.build(); } + public static io.a2a.grpc.SendMessageRequest sendMessageRequest(MessageSendParams request) { + io.a2a.grpc.SendMessageRequest.Builder builder = io.a2a.grpc.SendMessageRequest.newBuilder(); + builder.setRequest(message(request.message())); + if (request.configuration() != null) { + builder.setConfiguration(messageSendConfiguration(request.configuration())); + } + if (request.metadata() != null && ! request.metadata().isEmpty()) { + builder.setMetadata(struct(request.metadata())); + } + return builder.build(); + } private static io.a2a.grpc.AgentExtension agentExtension(AgentExtension agentExtension) { io.a2a.grpc.AgentExtension.Builder builder = io.a2a.grpc.AgentExtension.newBuilder(); if (agentExtension.description() != null) { @@ -521,6 +533,15 @@ private static io.a2a.grpc.AuthorizationCodeOAuthFlow authorizationCodeOAuthFlow return builder.build(); } + public static io.a2a.grpc.ListTaskPushNotificationConfigResponse listTaskPushNotificationConfigResponse(List configs) { + List confs = new ArrayList<>(configs.size()); + ListTaskPushNotificationConfigResponse.Builder response = ListTaskPushNotificationConfigResponse.newBuilder(); + for(TaskPushNotificationConfig config: configs) { + confs.add(taskPushNotificationConfig(config)); + } + return io.a2a.grpc.ListTaskPushNotificationConfigResponse.newBuilder().addAllConfigs(confs).build(); + } + private static io.a2a.grpc.ClientCredentialsOAuthFlow clientCredentialsOAuthFlow(ClientCredentialsOAuthFlow clientCredentialsOAuthFlow) { io.a2a.grpc.ClientCredentialsOAuthFlow.Builder builder = io.a2a.grpc.ClientCredentialsOAuthFlow.newBuilder(); if (clientCredentialsOAuthFlow.refreshUrl() != null) { @@ -661,23 +682,45 @@ public static io.a2a.grpc.SendMessageResponse taskOrMessage(EventKind eventKind) } } + public static io.a2a.grpc.StreamResponse taskOrMessageStream(StreamingEventKind eventKind) { + if (eventKind instanceof Task task) { + return io.a2a.grpc.StreamResponse.newBuilder() + .setTask(task(task)) + .build(); + } else if (eventKind instanceof Message msg) { + return io.a2a.grpc.StreamResponse.newBuilder() + .setMsg(message(msg)) + .build(); + } else if (eventKind instanceof TaskArtifactUpdateEvent update) { + return io.a2a.grpc.StreamResponse.newBuilder() + .setArtifactUpdate(taskArtifactUpdateEvent(update)) + .build(); + } else if (eventKind instanceof TaskStatusUpdateEvent update) { + return io.a2a.grpc.StreamResponse.newBuilder() + .setStatusUpdate(taskStatusUpdateEvent(update)) + .build(); + } else { + throw new IllegalArgumentException("Unsupported event type: " + eventKind); + } + } + } public static class FromProto { - public static TaskQueryParams taskQueryParams(io.a2a.grpc.GetTaskRequest request) { + public static TaskQueryParams taskQueryParams(io.a2a.grpc.GetTaskRequestOrBuilder request) { String name = request.getName(); String id = name.substring(name.lastIndexOf('/') + 1); return new TaskQueryParams(id, request.getHistoryLength()); } - public static TaskIdParams taskIdParams(io.a2a.grpc.CancelTaskRequest request) { + public static TaskIdParams taskIdParams(io.a2a.grpc.CancelTaskRequestOrBuilder request) { String name = request.getName(); String id = name.substring(name.lastIndexOf('/') + 1); return new TaskIdParams(id); } - public static MessageSendParams messageSendParams(io.a2a.grpc.SendMessageRequest request) { + public static MessageSendParams messageSendParams(io.a2a.grpc.SendMessageRequestOrBuilder request) { MessageSendParams.Builder builder = new MessageSendParams.Builder(); builder.message(message(request.getRequest())); if (request.hasConfiguration()) { @@ -700,6 +743,7 @@ public static TaskPushNotificationConfig taskPushNotificationConfig(io.a2a.grpc. private static TaskPushNotificationConfig taskPushNotificationConfig(io.a2a.grpc.TaskPushNotificationConfigOrBuilder config, boolean create) { String name = config.getName(); // "tasks/{id}/pushNotificationConfigs/{push_id}" String[] parts = name.split("/"); + String taskId = parts[1]; String configId = ""; if (create) { if (parts.length < 3) { @@ -707,6 +751,8 @@ private static TaskPushNotificationConfig taskPushNotificationConfig(io.a2a.grpc } if (parts.length == 4) { configId = parts[3]; + } else { + configId = taskId; } } else { if (parts.length < 4) { @@ -714,12 +760,11 @@ private static TaskPushNotificationConfig taskPushNotificationConfig(io.a2a.grpc } configId = parts[3]; } - String taskId = parts[1]; PushNotificationConfig pnc = pushNotification(config.getPushNotificationConfig(), configId); return new TaskPushNotificationConfig(taskId, pnc); } - public static GetTaskPushNotificationConfigParams getTaskPushNotificationConfigParams(io.a2a.grpc.GetTaskPushNotificationConfigRequest request) { + public static GetTaskPushNotificationConfigParams getTaskPushNotificationConfigParams(io.a2a.grpc.GetTaskPushNotificationConfigRequestOrBuilder request) { String name = request.getName(); // "tasks/{id}/pushNotificationConfigs/{push_id}" String[] parts = name.split("/"); String taskId = parts[1]; @@ -734,19 +779,28 @@ public static GetTaskPushNotificationConfigParams getTaskPushNotificationConfigP return new GetTaskPushNotificationConfigParams(taskId, configId); } - public static TaskIdParams taskIdParams(io.a2a.grpc.TaskSubscriptionRequest request) { + public static TaskIdParams taskIdParams(io.a2a.grpc.TaskSubscriptionRequestOrBuilder request) { String name = request.getName(); String id = name.substring(name.lastIndexOf('/') + 1); return new TaskIdParams(id); } - public static ListTaskPushNotificationConfigParams listTaskPushNotificationConfigParams(io.a2a.grpc.ListTaskPushNotificationConfigRequest request) { + public static List listTaskPushNotificationConfigParams(io.a2a.grpc.ListTaskPushNotificationConfigResponseOrBuilder response) { + List configs = response.getConfigsList(); + List result = new ArrayList<>(configs.size()); + for(io.a2a.grpc.TaskPushNotificationConfig config : configs) { + result.add(taskPushNotificationConfig(config, false)); + } + return result; + } + + public static ListTaskPushNotificationConfigParams listTaskPushNotificationConfigParams(io.a2a.grpc.ListTaskPushNotificationConfigRequestOrBuilder request) { String parent = request.getParent(); String id = parent.substring(parent.lastIndexOf('/') + 1); return new ListTaskPushNotificationConfigParams(id); } - public static DeleteTaskPushNotificationConfigParams deleteTaskPushNotificationConfigParams(io.a2a.grpc.DeleteTaskPushNotificationConfigRequest request) { + public static DeleteTaskPushNotificationConfigParams deleteTaskPushNotificationConfigParams(io.a2a.grpc.DeleteTaskPushNotificationConfigRequestOrBuilder request) { String name = request.getName(); // "tasks/{id}/pushNotificationConfigs/{push_id}" String[] parts = name.split("/"); if (parts.length < 4) { @@ -757,13 +811,13 @@ public static DeleteTaskPushNotificationConfigParams deleteTaskPushNotificationC return new DeleteTaskPushNotificationConfigParams(taskId, configId); } - private static AgentCapabilities agentCapabilities(io.a2a.grpc.AgentCapabilities agentCapabilities) { + private static AgentCapabilities agentCapabilities(io.a2a.grpc.AgentCapabilitiesOrBuilder agentCapabilities) { return new AgentCapabilities(agentCapabilities.getStreaming(), agentCapabilities.getPushNotifications(), false, agentCapabilities.getExtensionsList().stream().map(item -> agentExtension(item)).collect(Collectors.toList()) ); } - private static AgentExtension agentExtension(io.a2a.grpc.AgentExtension agentExtension) { + private static AgentExtension agentExtension(io.a2a.grpc.AgentExtensionOrBuilder agentExtension) { return new AgentExtension( agentExtension.getDescription(), struct(agentExtension.getParams()), @@ -772,7 +826,7 @@ private static AgentExtension agentExtension(io.a2a.grpc.AgentExtension agentExt ); } - private static MessageSendConfiguration messageSendConfiguration(io.a2a.grpc.SendMessageConfiguration sendMessageConfiguration) { + private static MessageSendConfiguration messageSendConfiguration(io.a2a.grpc.SendMessageConfigurationOrBuilder sendMessageConfiguration) { return new MessageSendConfiguration( sendMessageConfiguration.getAcceptedOutputModesList().isEmpty() ? null : new ArrayList<>(sendMessageConfiguration.getAcceptedOutputModesList()), @@ -782,8 +836,8 @@ private static MessageSendConfiguration messageSendConfiguration(io.a2a.grpc.Sen ); } - private static PushNotificationConfig pushNotification(io.a2a.grpc.PushNotificationConfig pushNotification, String configId) { - if (pushNotification == null || pushNotification.getDefaultInstanceForType().equals(pushNotification)) { + private static PushNotificationConfig pushNotification(io.a2a.grpc.PushNotificationConfigOrBuilder pushNotification, String configId) { + if(pushNotification == null || pushNotification.getDefaultInstanceForType().equals(pushNotification)) { return null; } return new PushNotificationConfig( @@ -794,18 +848,18 @@ private static PushNotificationConfig pushNotification(io.a2a.grpc.PushNotificat ); } - private static PushNotificationConfig pushNotification(io.a2a.grpc.PushNotificationConfig pushNotification) { + private static PushNotificationConfig pushNotification(io.a2a.grpc.PushNotificationConfigOrBuilder pushNotification) { return pushNotification(pushNotification, pushNotification.getId()); } - private static PushNotificationAuthenticationInfo authenticationInfo(io.a2a.grpc.AuthenticationInfo authenticationInfo) { + private static PushNotificationAuthenticationInfo authenticationInfo(io.a2a.grpc.AuthenticationInfoOrBuilder authenticationInfo) { return new PushNotificationAuthenticationInfo( new ArrayList<>(authenticationInfo.getSchemesList()), authenticationInfo.getCredentials() ); } - public static Task task(io.a2a.grpc.Task task) { + public static Task task(io.a2a.grpc.TaskOrBuilder task) { return new Task( task.getId(), task.getContextId(), @@ -816,7 +870,7 @@ public static Task task(io.a2a.grpc.Task task) { ); } - public static Message message(io.a2a.grpc.Message message) { + public static Message message(io.a2a.grpc.MessageOrBuilder message) { if (message.getMessageId().isEmpty()) { throw new InvalidParamsError(); } @@ -832,7 +886,7 @@ public static Message message(io.a2a.grpc.Message message) { ); } - public static TaskStatusUpdateEvent taskStatusUpdateEvent(io.a2a.grpc.TaskStatusUpdateEvent taskStatusUpdateEvent) { + public static TaskStatusUpdateEvent taskStatusUpdateEvent(io.a2a.grpc.TaskStatusUpdateEventOrBuilder taskStatusUpdateEvent) { return new TaskStatusUpdateEvent.Builder() .taskId(taskStatusUpdateEvent.getTaskId()) .status(taskStatus(taskStatusUpdateEvent.getStatus())) @@ -842,7 +896,7 @@ public static TaskStatusUpdateEvent taskStatusUpdateEvent(io.a2a.grpc.TaskStatus .build(); } - public static TaskArtifactUpdateEvent taskArtifactUpdateEvent(io.a2a.grpc.TaskArtifactUpdateEvent taskArtifactUpdateEvent) { + public static TaskArtifactUpdateEvent taskArtifactUpdateEvent(io.a2a.grpc.TaskArtifactUpdateEventOrBuilder taskArtifactUpdateEvent) { return new TaskArtifactUpdateEvent.Builder() .taskId(taskArtifactUpdateEvent.getTaskId()) .append(taskArtifactUpdateEvent.getAppend()) @@ -853,7 +907,7 @@ public static TaskArtifactUpdateEvent taskArtifactUpdateEvent(io.a2a.grpc.TaskAr .build(); } - private static Artifact artifact(io.a2a.grpc.Artifact artifact) { + private static Artifact artifact(io.a2a.grpc.ArtifactOrBuilder artifact) { return new Artifact( artifact.getArtifactId(), artifact.getName(), @@ -863,7 +917,7 @@ private static Artifact artifact(io.a2a.grpc.Artifact artifact) { ); } - private static Part part(io.a2a.grpc.Part part) { + private static Part part(io.a2a.grpc.PartOrBuilder part) { if (part.hasText()) { return textPart(part.getText()); } else if (part.hasFile()) { @@ -878,7 +932,7 @@ private static TextPart textPart(String text) { return new TextPart(text); } - private static FilePart filePart(io.a2a.grpc.FilePart filePart) { + private static FilePart filePart(io.a2a.grpc.FilePartOrBuilder filePart) { if (filePart.hasFileWithBytes()) { return new FilePart(new FileWithBytes(filePart.getMimeType(), null, filePart.getFileWithBytes().toStringUtf8())); } else if (filePart.hasFileWithUri()) { @@ -887,17 +941,18 @@ private static FilePart filePart(io.a2a.grpc.FilePart filePart) { throw new InvalidRequestError(); } - private static DataPart dataPart(io.a2a.grpc.DataPart dataPart) { + private static DataPart dataPart(io.a2a.grpc.DataPartOrBuilder dataPart) { return new DataPart(struct(dataPart.getData())); } - private static TaskStatus taskStatus(io.a2a.grpc.TaskStatus taskStatus) { + private static TaskStatus taskStatus(io.a2a.grpc.TaskStatusOrBuilder taskStatus) { TaskState state = taskState(taskStatus.getState()); if (state == null) { return null; } - return new TaskStatus(state, - taskStatus.hasUpdate() ? message(taskStatus.getUpdate()) : null, + return new TaskStatus( + taskState(taskStatus.getState()), + taskStatus.hasUpdate() ? message(taskStatus.getUpdateOrBuilder()) : null, OffsetDateTime.ofInstant(Instant.ofEpochSecond(taskStatus.getTimestamp().getSeconds(), taskStatus.getTimestamp().getNanos()), ZoneOffset.UTC) ); } diff --git a/spec/pom.xml b/spec/pom.xml index 13a880aa0..bd06dd1df 100644 --- a/spec/pom.xml +++ b/spec/pom.xml @@ -20,7 +20,6 @@ ${project.groupId} a2a-java-sdk-common - ${project.version} diff --git a/spec/src/main/java/io/a2a/util/Utils.java b/spec/src/main/java/io/a2a/util/Utils.java index c9e982910..ec702130a 100644 --- a/spec/src/main/java/io/a2a/util/Utils.java +++ b/spec/src/main/java/io/a2a/util/Utils.java @@ -93,4 +93,6 @@ public static Task appendArtifactToTask(Task task, TaskArtifactUpdateEvent event .build(); } + + } diff --git a/tck/pom.xml b/tck/pom.xml index e03ef70da..85a4b215a 100644 --- a/tck/pom.xml +++ b/tck/pom.xml @@ -17,14 +17,16 @@ - io.github.a2asdk + ${project.groupId} a2a-java-sdk-reference-jsonrpc - ${project.version} io.github.a2asdk a2a-java-sdk-reference-grpc - ${project.version} + + + io.github.a2asdk + a2a-java-sdk-reference-rest io.quarkus diff --git a/tck/src/main/java/io/a2a/tck/server/AgentCardProducer.java b/tck/src/main/java/io/a2a/tck/server/AgentCardProducer.java index 68e6c7e26..2ab3d9a4e 100644 --- a/tck/src/main/java/io/a2a/tck/server/AgentCardProducer.java +++ b/tck/src/main/java/io/a2a/tck/server/AgentCardProducer.java @@ -1,6 +1,5 @@ package io.a2a.tck.server; -import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -22,8 +21,10 @@ public class AgentCardProducer { @Produces @PublicAgentCard public AgentCard agentCard() { + String sutJsonRpcUrl = getEnvOrDefault("SUT_JSONRPC_URL", DEFAULT_SUT_URL); String sutGrpcUrl = getEnvOrDefault("SUT_GRPC_URL", DEFAULT_SUT_URL); + String sutRestcUrl = getEnvOrDefault("SUT_REST_URL", DEFAULT_SUT_URL); return new AgentCard.Builder() .name("Hello World Agent") .description("Just a hello world agent") @@ -47,7 +48,8 @@ public AgentCard agentCard() { .protocolVersion("0.3.0") .additionalInterfaces(List.of( new AgentInterface(TransportProtocol.JSONRPC.asString(), sutJsonRpcUrl), - new AgentInterface(TransportProtocol.GRPC.asString(), sutGrpcUrl))) + new AgentInterface(TransportProtocol.GRPC.asString(), sutGrpcUrl), + new AgentInterface(TransportProtocol.HTTP_JSON.asString(), sutRestcUrl))) .build(); } diff --git a/tests/server-common/pom.xml b/tests/server-common/pom.xml index fb6cbc5db..1d5234b4e 100644 --- a/tests/server-common/pom.xml +++ b/tests/server-common/pom.xml @@ -21,17 +21,14 @@ ${project.groupId} a2a-java-sdk-spec - ${project.version} ${project.groupId} a2a-java-sdk-client - ${project.version} ${project.groupId} a2a-java-sdk-server-common - ${project.version} jakarta.ws.rs @@ -56,13 +53,16 @@ io.github.a2asdk a2a-java-sdk-client-transport-jsonrpc - ${project.version} test io.github.a2asdk a2a-java-sdk-client-transport-grpc - ${project.version} + test + + + io.github.a2asdk + a2a-java-sdk-client-transport-rest test diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpClient.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpClient.java index 87d2e536b..f161307aa 100644 --- a/tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpClient.java +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpClient.java @@ -15,6 +15,7 @@ import io.a2a.client.http.A2AHttpResponse; import io.a2a.spec.Task; import io.a2a.util.Utils; +import java.util.Map; @Dependent @Alternative @@ -32,6 +33,11 @@ public PostBuilder createPost() { return new TestPostBuilder(); } + @Override + public DeleteBuilder createDelete() { + return null; + } + class TestPostBuilder implements A2AHttpClient.PostBuilder { private volatile String body; @Override @@ -79,5 +85,10 @@ public PostBuilder url(String s) { public PostBuilder addHeader(String name, String value) { return this; } + + @Override + public PostBuilder addHeaders(Map headers) { + return this; + } } } \ No newline at end of file diff --git a/transport/grpc/pom.xml b/transport/grpc/pom.xml index 93ec2a095..76d72db24 100644 --- a/transport/grpc/pom.xml +++ b/transport/grpc/pom.xml @@ -21,19 +21,16 @@ io.github.a2asdk a2a-java-sdk-server-common - ${project.version} ${project.groupId} a2a-java-sdk-server-common - ${project.version} test-jar test ${project.groupId} a2a-java-sdk-spec-grpc - ${project.version} com.google.protobuf diff --git a/transport/jsonrpc/pom.xml b/transport/jsonrpc/pom.xml index 15eecd06e..3565d042f 100644 --- a/transport/jsonrpc/pom.xml +++ b/transport/jsonrpc/pom.xml @@ -19,14 +19,12 @@ - io.github.a2asdk + ${project.groupId} a2a-java-sdk-server-common - ${project.version} ${project.groupId} a2a-java-sdk-server-common - ${project.version} test-jar test diff --git a/transport/rest/pom.xml b/transport/rest/pom.xml new file mode 100644 index 000000000..712abbce4 --- /dev/null +++ b/transport/rest/pom.xml @@ -0,0 +1,70 @@ + + + 4.0.0 + + + io.github.a2asdk + a2a-java-sdk-parent + 0.3.0.Beta1-SNAPSHOT + ../../pom.xml + + a2a-java-sdk-transport-rest + + jar + + Java SDK A2A Transport: JSON+HTTP/REST + Java SDK for the Agent2Agent Protocol (A2A) - JSON+HTTP/REST Transport + + + + io.github.a2asdk + a2a-java-sdk-server-common + + + io.github.a2asdk + a2a-java-sdk-spec-grpc + + + io.github.a2asdk + a2a-java-sdk-spec + + + ${project.groupId} + a2a-java-sdk-server-common + test-jar + test + + + ch.qos.logback + logback-classic + test + + + org.junit.jupiter + junit-jupiter-api + test + + + org.mockito + mockito-core + test + + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + + + com.google.protobuf + protobuf-java-util + + + org.slf4j + slf4j-jdk14 + test + + + + + \ No newline at end of file diff --git a/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java b/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java new file mode 100644 index 000000000..f6029b266 --- /dev/null +++ b/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java @@ -0,0 +1,427 @@ +package io.a2a.transport.rest.handler; + +import static io.a2a.server.util.async.AsyncUtils.createTubeConfig; + +import com.fasterxml.jackson.core.JacksonException; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.util.JsonFormat; +import io.a2a.grpc.utils.ProtoUtils; +import io.a2a.server.ExtendedAgentCard; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +import java.util.List; +import java.util.concurrent.Flow; + +import io.a2a.server.PublicAgentCard; +import io.a2a.server.ServerCallContext; +import io.a2a.server.requesthandlers.RequestHandler; +import io.a2a.spec.AgentCard; +import io.a2a.spec.AuthenticatedExtendedCardNotConfiguredError; +import io.a2a.spec.ContentTypeNotSupportedError; +import io.a2a.spec.DeleteTaskPushNotificationConfigParams; +import io.a2a.spec.EventKind; +import io.a2a.spec.GetTaskPushNotificationConfigParams; +import io.a2a.spec.InternalError; +import io.a2a.spec.InvalidAgentResponseError; +import io.a2a.spec.InvalidParamsError; +import io.a2a.spec.InvalidRequestError; +import io.a2a.spec.JSONParseError; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.ListTaskPushNotificationConfigParams; +import io.a2a.spec.MethodNotFoundError; +import io.a2a.spec.PushNotificationNotSupportedError; +import io.a2a.spec.StreamingEventKind; +import io.a2a.spec.Task; +import io.a2a.spec.TaskIdParams; +import io.a2a.spec.TaskNotCancelableError; +import io.a2a.spec.TaskNotFoundError; +import io.a2a.spec.TaskPushNotificationConfig; +import io.a2a.spec.TaskQueryParams; +import io.a2a.spec.UnsupportedOperationError; +import io.a2a.util.Utils; +import jakarta.enterprise.inject.Instance; +import java.util.concurrent.CompletableFuture; +import java.util.logging.Level; +import java.util.logging.Logger; +import mutiny.zero.ZeroPublisher; + +@ApplicationScoped +public class RestHandler { + + private static final Logger log = Logger.getLogger(RestHandler.class.getName()); + private AgentCard agentCard; + private Instance extendedAgentCard; + private RequestHandler requestHandler; + + protected RestHandler() { + // For CDI + } + + @Inject + public RestHandler(@PublicAgentCard AgentCard agentCard, @ExtendedAgentCard Instance extendedAgentCard, + RequestHandler requestHandler) { + this.agentCard = agentCard; + this.extendedAgentCard = extendedAgentCard; + this.requestHandler = requestHandler; + } + + public RestHandler(AgentCard agentCard, RequestHandler requestHandler) { + this.agentCard = agentCard; + this.requestHandler = requestHandler; + } + + public HTTPRestResponse sendMessage(String body, ServerCallContext context) { + try { + io.a2a.grpc.SendMessageRequest.Builder request = io.a2a.grpc.SendMessageRequest.newBuilder(); + parseRequestBody(body, request); + EventKind result = requestHandler.onMessageSend(ProtoUtils.FromProto.messageSendParams(request), context); + return createSuccessResponse(200, io.a2a.grpc.SendMessageResponse.newBuilder(ProtoUtils.ToProto.taskOrMessage(result))); + } catch (JSONRPCError e) { + return createErrorResponse(e); + } catch (Throwable throwable) { + return createErrorResponse(new InternalError(throwable.getMessage())); + } + } + + public HTTPRestResponse sendStreamingMessage(String body, ServerCallContext context) { + try { + if (!agentCard.capabilities().streaming()) { + return createErrorResponse(new InvalidRequestError("Streaming is not supported by the agent")); + } + io.a2a.grpc.SendMessageRequest.Builder request = io.a2a.grpc.SendMessageRequest.newBuilder(); + parseRequestBody(body, request); + Flow.Publisher publisher = requestHandler.onMessageSendStream(ProtoUtils.FromProto.messageSendParams(request), context); + return createStreamingResponse(publisher); + } catch (JSONRPCError e) { + return new HTTPRestStreamingResponse(ZeroPublisher.fromItems(new HTTPRestErrorResponse(e).toJson())); + } catch (Throwable throwable) { + return new HTTPRestStreamingResponse(ZeroPublisher.fromItems(new HTTPRestErrorResponse(new InternalError(throwable.getMessage())).toJson())); + } + } + + public HTTPRestResponse cancelTask(String taskId, ServerCallContext context) { + try { + if(taskId == null || taskId.isEmpty()) { + throw new InvalidParamsError(); + } + TaskIdParams params = new TaskIdParams(taskId); + Task task = requestHandler.onCancelTask(params, context); + if (task != null) { + return createSuccessResponse(200, io.a2a.grpc.Task.newBuilder(ProtoUtils.ToProto.task(task))); + } + throw new UnsupportedOperationError(); + } catch (JSONRPCError e) { + return createErrorResponse(e); + } catch (Throwable throwable) { + return createErrorResponse(new InternalError(throwable.getMessage())); + } + } + + public HTTPRestResponse setTaskPushNotificationConfiguration(String taskId, String body, ServerCallContext context) { + try { + if (!agentCard.capabilities().pushNotifications()) { + throw new PushNotificationNotSupportedError(); + } + io.a2a.grpc.CreateTaskPushNotificationConfigRequest.Builder builder = io.a2a.grpc.CreateTaskPushNotificationConfigRequest.newBuilder(); + parseRequestBody(body, builder); + TaskPushNotificationConfig result = requestHandler.onSetTaskPushNotificationConfig(ProtoUtils.FromProto.taskPushNotificationConfig(builder), context); + return createSuccessResponse(201, io.a2a.grpc.TaskPushNotificationConfig.newBuilder(ProtoUtils.ToProto.taskPushNotificationConfig(result))); + } catch (JSONRPCError e) { + return createErrorResponse(e); + } catch (Throwable throwable) { + return createErrorResponse(new InternalError(throwable.getMessage())); + } + } + + public HTTPRestResponse resubscribeTask(String taskId, ServerCallContext context) { + try { + if (!agentCard.capabilities().streaming()) { + return createErrorResponse(new InvalidRequestError("Streaming is not supported by the agent")); + } + TaskIdParams params = new TaskIdParams(taskId); + Flow.Publisher publisher = requestHandler.onResubscribeToTask(params, context); + return createStreamingResponse(publisher); + } catch (JSONRPCError e) { + return new HTTPRestStreamingResponse(ZeroPublisher.fromItems(new HTTPRestErrorResponse(e).toJson())); + } catch (Throwable throwable) { + return new HTTPRestStreamingResponse(ZeroPublisher.fromItems(new HTTPRestErrorResponse(new InternalError(throwable.getMessage())).toJson())); + } + } + + public HTTPRestResponse getTask(String taskId, Integer historyLength, ServerCallContext context) { + try { + TaskQueryParams params = new TaskQueryParams(taskId,historyLength); + Task task = requestHandler.onGetTask(params, context); + if (task != null) { + return createSuccessResponse(200, io.a2a.grpc.Task.newBuilder(ProtoUtils.ToProto.task(task))); + } + throw new TaskNotFoundError(); + } catch (JSONRPCError e) { + return createErrorResponse(e); + } catch (Throwable throwable) { + return createErrorResponse(new InternalError(throwable.getMessage())); + } + } + + public HTTPRestResponse getTaskPushNotificationConfiguration(String taskId, String configId, ServerCallContext context) { + try { + if (!agentCard.capabilities().pushNotifications()) { + throw new PushNotificationNotSupportedError(); + } + GetTaskPushNotificationConfigParams params = new GetTaskPushNotificationConfigParams(taskId, configId); + TaskPushNotificationConfig config = requestHandler.onGetTaskPushNotificationConfig(params, context); + return createSuccessResponse(200, io.a2a.grpc.TaskPushNotificationConfig.newBuilder(ProtoUtils.ToProto.taskPushNotificationConfig(config))); + } catch (JSONRPCError e) { + return createErrorResponse(e); + } catch (Throwable throwable) { + return createErrorResponse(new InternalError(throwable.getMessage())); + } + } + + public HTTPRestResponse listTaskPushNotificationConfigurations(String taskId, ServerCallContext context) { + try { + if (!agentCard.capabilities().pushNotifications()) { + throw new PushNotificationNotSupportedError(); + } + ListTaskPushNotificationConfigParams params = new ListTaskPushNotificationConfigParams(taskId); + List configs = requestHandler.onListTaskPushNotificationConfig(params, context); + return createSuccessResponse(200, io.a2a.grpc.ListTaskPushNotificationConfigResponse.newBuilder(ProtoUtils.ToProto.listTaskPushNotificationConfigResponse(configs))); + } catch (JSONRPCError e) { + return createErrorResponse(e); + } catch (Throwable throwable) { + return createErrorResponse(new InternalError(throwable.getMessage())); + } + } + + public HTTPRestResponse deleteTaskPushNotificationConfiguration(String taskId, String configId, ServerCallContext context) { + try { + if (!agentCard.capabilities().pushNotifications()) { + throw new PushNotificationNotSupportedError(); + } + DeleteTaskPushNotificationConfigParams params = new DeleteTaskPushNotificationConfigParams(taskId, configId); + requestHandler.onDeleteTaskPushNotificationConfig(params, context); + return new HTTPRestResponse(204, "application/json", ""); + } catch (JSONRPCError e) { + return createErrorResponse(e); + } catch (Throwable throwable) { + return createErrorResponse(new InternalError(throwable.getMessage())); + } + } + + private void parseRequestBody(String body, com.google.protobuf.Message.Builder builder) throws JSONRPCError { + try { + if (body == null || body.trim().isEmpty()) { + throw new InvalidRequestError("Request body is required"); + } + validate(body); + JsonFormat.parser().merge(body, builder); + } catch (InvalidProtocolBufferException e) { + log.log(Level.SEVERE, "Error parsing JSON request body: {0}", body); + log.log(Level.SEVERE, "Parse error details", e); + throw new InvalidParamsError("Failed to parse request body: " + e.getMessage()); + } + } + + private void validate(String json) { + try { + Utils.OBJECT_MAPPER.readTree(json); + } catch (JacksonException e) { + throw new JSONParseError(JSONParseError.DEFAULT_CODE, "Failed to parse json", e.getMessage()); + } + } + + private HTTPRestResponse createSuccessResponse(int statusCode, com.google.protobuf.Message.Builder builder) { + try { + String jsonBody = JsonFormat.printer().print(builder); + return new HTTPRestResponse(statusCode, "application/json", jsonBody); + } catch (InvalidProtocolBufferException e) { + return createErrorResponse(new InternalError("Failed to serialize response: " + e.getMessage())); + } + } + + public HTTPRestResponse createErrorResponse(JSONRPCError error) { + int statusCode = mapErrorToHttpStatus(error); + return createErrorResponse(statusCode, error); + } + + private HTTPRestResponse createErrorResponse(int statusCode, JSONRPCError error) { + String jsonBody = new HTTPRestErrorResponse(error).toJson(); + return new HTTPRestResponse(statusCode, "application/json", jsonBody); + } + + private HTTPRestStreamingResponse createStreamingResponse(Flow.Publisher publisher) { + return new HTTPRestStreamingResponse(convertToSendStreamingMessageResponse(publisher)); + } + + private Flow.Publisher convertToSendStreamingMessageResponse( + Flow.Publisher publisher) { + // We can't use the normal convertingProcessor since that propagates any errors as an error handled + // via Subscriber.onError() rather than as part of the SendStreamingResponse payload + return ZeroPublisher.create(createTubeConfig(), tube -> { + CompletableFuture.runAsync(() -> { + publisher.subscribe(new Flow.Subscriber() { + Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(StreamingEventKind item) { + try { + String payload = JsonFormat.printer().omittingInsignificantWhitespace().print(ProtoUtils.ToProto.taskOrMessageStream(item)); + tube.send(payload); + subscription.request(1); + } catch (InvalidProtocolBufferException ex) { + onError(ex); + } + } + + @Override + public void onError(Throwable throwable) { + if (throwable instanceof JSONRPCError jsonrpcError) { + tube.send(new HTTPRestErrorResponse(jsonrpcError).toJson()); + } else { + tube.send(new HTTPRestErrorResponse(new InternalError(throwable.getMessage())).toJson()); + } + onComplete(); + } + + @Override + public void onComplete() { + tube.complete(); + } + }); + }); + }); + } + + private int mapErrorToHttpStatus(JSONRPCError error) { + if (error instanceof InvalidRequestError || error instanceof JSONParseError) { + return 400; + } + if (error instanceof InvalidParamsError) { + return 422; + } + if (error instanceof MethodNotFoundError || error instanceof TaskNotFoundError || error instanceof AuthenticatedExtendedCardNotConfiguredError) { + return 404; + } + if (error instanceof TaskNotCancelableError) { + return 409; + } + if (error instanceof PushNotificationNotSupportedError || error instanceof UnsupportedOperationError) { + return 501; + } + if (error instanceof ContentTypeNotSupportedError) { + return 415; + } + if (error instanceof InvalidAgentResponseError) { + return 502; + } + if (error instanceof InternalError ) { + return 500; + } + return 500; + } + + public HTTPRestResponse getAuthenticatedExtendedCard() { + try { + if (!agentCard.supportsAuthenticatedExtendedCard() || !extendedAgentCard.isResolvable()) { + throw new AuthenticatedExtendedCardNotConfiguredError(); + } + return new HTTPRestResponse(200, "application/json", Utils.OBJECT_MAPPER.writeValueAsString(extendedAgentCard.get())); + } catch (JSONRPCError e) { + return createErrorResponse(e); + } catch (Throwable t) { + return createErrorResponse(500, new InternalError(t.getMessage())); + } + } + + public HTTPRestResponse getAgentCard() { + try { + return new HTTPRestResponse(200, "application/json", Utils.OBJECT_MAPPER.writeValueAsString(agentCard)); + } catch (Throwable t) { + return createErrorResponse(500, new InternalError(t.getMessage())); + } + } + + public static class HTTPRestResponse { + + private final int statusCode; + private final String contentType; + private final String body; + + public HTTPRestResponse(int statusCode, String contentType, String body) { + this.statusCode = statusCode; + this.contentType = contentType; + this.body = body; + } + + public int getStatusCode() { + return statusCode; + } + + public String getContentType() { + return contentType; + } + + public String getBody() { + return body; + } + + @Override + public String toString() { + return "HTTPRestResponse{" + "statusCode=" + statusCode + ", contentType=" + contentType + ", body=" + body + '}'; + } + } + + public static class HTTPRestStreamingResponse extends HTTPRestResponse { + + private final Flow.Publisher publisher; + + public HTTPRestStreamingResponse(Flow.Publisher publisher) { + super(200, "text/event-stream", null); + this.publisher = publisher; + } + + public Flow.Publisher getPublisher() { + return publisher; + } + } + + private static class HTTPRestErrorResponse { + + private final String error; + private final String message; + + public HTTPRestErrorResponse(String error, String message) { + this.error = error; + this.message = message; + } + + public HTTPRestErrorResponse(JSONRPCError jsonRpcError) { + this.error = jsonRpcError.getClass().getName(); + this.message = jsonRpcError.getMessage(); + } + + public String getError() { + return error; + } + + public String getMessage() { + return message; + } + + public String toJson() { + return "{\"error\": \"" + error + "\", \"message\": \"" + message + "\"}"; + } + + @Override + public String toString() { + return "HTTPRestErrorResponse{" + "error=" + error + ", message=" + message + '}'; + } + } +} diff --git a/transport/rest/src/main/resources/META-INF/beans.xml b/transport/rest/src/main/resources/META-INF/beans.xml new file mode 100644 index 000000000..0d9b1c17d --- /dev/null +++ b/transport/rest/src/main/resources/META-INF/beans.xml @@ -0,0 +1,7 @@ + + + diff --git a/transport/rest/src/main/resources/a2a-requesthandler-test.properties b/transport/rest/src/main/resources/a2a-requesthandler-test.properties new file mode 100644 index 000000000..723a7f87f --- /dev/null +++ b/transport/rest/src/main/resources/a2a-requesthandler-test.properties @@ -0,0 +1 @@ +preferred-transport=HTTP_JSON \ No newline at end of file diff --git a/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java b/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java new file mode 100644 index 000000000..d8b355dc3 --- /dev/null +++ b/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java @@ -0,0 +1,403 @@ +package io.a2a.transport.rest.handler; + +import com.google.protobuf.InvalidProtocolBufferException; +import java.util.Map; + +import io.a2a.server.ServerCallContext; +import io.a2a.server.auth.UnauthenticatedUser; +import io.a2a.server.requesthandlers.AbstractA2ARequestHandlerTest; +import io.a2a.spec.AgentCard; +import io.a2a.spec.Task; +import io.a2a.server.tasks.TaskUpdater; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Flow; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class RestHandlerTest extends AbstractA2ARequestHandlerTest { + + private final ServerCallContext callContext = new ServerCallContext(UnauthenticatedUser.INSTANCE, Map.of("foo", "bar")); + + @Test + public void testGetTaskSuccess() { + RestHandler handler = new RestHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + + RestHandler.HTTPRestResponse response = handler.getTask(MINIMAL_TASK.getId(),null, callContext); + + Assertions.assertEquals(200, response.getStatusCode()); + Assertions.assertEquals("application/json", response.getContentType()); + Assertions.assertTrue(response.getBody().contains(MINIMAL_TASK.getId())); + + response = handler.getTask(MINIMAL_TASK.getId(),2 , callContext); + + Assertions.assertEquals(200, response.getStatusCode()); + Assertions.assertEquals("application/json", response.getContentType()); + Assertions.assertTrue(response.getBody().contains(MINIMAL_TASK.getId())); + } + + @Test + public void testGetTaskNotFound() { + RestHandler handler = new RestHandler(CARD, requestHandler); + + RestHandler.HTTPRestResponse response = handler.getTask("nonexistent", null, callContext); + + Assertions.assertEquals(404, response.getStatusCode()); + Assertions.assertEquals("application/json", response.getContentType()); + Assertions.assertTrue(response.getBody().contains("TaskNotFoundError")); + } + + @Test + public void testSendMessage() throws InvalidProtocolBufferException { + RestHandler handler = new RestHandler(CARD, requestHandler); + agentExecutorExecute = (context, eventQueue) -> { + eventQueue.enqueueEvent(context.getMessage()); + }; + String requestBody = """ + { + "message": + { + "messageId": "message-1234", + "contextId": "context-1234", + "role": "ROLE_USER", + "content": [{ + "text": "tell me a joke" + }], + "metadata": { + } + }, + "configuration": + { + "blocking": true + } + }"""; + + RestHandler.HTTPRestResponse response = handler.sendMessage(requestBody, callContext); + Assertions.assertEquals(200, response.getStatusCode(), response.toString()); + Assertions.assertEquals("application/json", response.getContentType()); + Assertions.assertNotNull(response.getBody()); + } + + @Test + public void testSendMessageInvalidBody() { + RestHandler handler = new RestHandler(CARD, requestHandler); + + String invalidBody = "invalid json"; + RestHandler.HTTPRestResponse response = handler.sendMessage(invalidBody, callContext); + + Assertions.assertEquals(400, response.getStatusCode()); + Assertions.assertEquals("application/json", response.getContentType()); + Assertions.assertTrue(response.getBody().contains("JSONParseError"),response.getBody()); + } + + @Test + public void testSendMessageWrongValueBody() { + RestHandler handler = new RestHandler(CARD, requestHandler); + String requestBody = """ + { + "message": + { + "messageId": "message-1234", + "contextId": "context-1234", + "role": "user", + "content": [{ + "text": "tell me a joke" + }], + "metadata": { + } + } + }"""; + RestHandler.HTTPRestResponse response = handler.sendMessage(requestBody, callContext); + + Assertions.assertEquals(422, response.getStatusCode()); + Assertions.assertEquals("application/json", response.getContentType()); + Assertions.assertTrue(response.getBody().contains("InvalidParamsError")); + } + + @Test + public void testSendMessageEmptyBody() { + RestHandler handler = new RestHandler(CARD, requestHandler); + + RestHandler.HTTPRestResponse response = handler.sendMessage("", callContext); + + Assertions.assertEquals(400, response.getStatusCode()); + Assertions.assertEquals("application/json", response.getContentType()); + Assertions.assertTrue(response.getBody().contains("InvalidRequestError")); + } + + @Test + public void testCancelTaskSuccess() { + RestHandler handler = new RestHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + + agentExecutorCancel = (context, eventQueue) -> { + // We need to cancel the task or the EventConsumer never finds a 'final' event. + // Looking at the Python implementation, they typically use AgentExecutors that + // don't support cancellation. So my theory is the Agent updates the task to the CANCEL status + Task task = context.getTask(); + TaskUpdater taskUpdater = new TaskUpdater(context, eventQueue); + taskUpdater.cancel(); + }; + + RestHandler.HTTPRestResponse response = handler.cancelTask(MINIMAL_TASK.getId(), callContext); + + Assertions.assertEquals(200, response.getStatusCode()); + Assertions.assertEquals("application/json", response.getContentType()); + Assertions.assertTrue(response.getBody().contains(MINIMAL_TASK.getId())); + } + + @Test + public void testCancelTaskNotFound() { + RestHandler handler = new RestHandler(CARD, requestHandler); + + RestHandler.HTTPRestResponse response = handler.cancelTask("nonexistent", callContext); + + Assertions.assertEquals(404, response.getStatusCode()); + Assertions.assertEquals("application/json", response.getContentType()); + Assertions.assertTrue(response.getBody().contains("TaskNotFoundError")); + } + + @Test + public void testSendStreamingMessageSuccess() { + RestHandler handler = new RestHandler(CARD, requestHandler); + agentExecutorExecute = (context, eventQueue) -> { + eventQueue.enqueueEvent(context.getMessage()); + }; + String requestBody = """ + { + "message": { + "role": "ROLE_USER", + "content": [ + { + "text": "tell me some jokes" + } + ], + "messageId": "message-1234", + "contextId": "context-1234" + }, + "configuration": { + "acceptedOutputModes": ["text"] + } + }"""; + + RestHandler.HTTPRestResponse response = handler.sendStreamingMessage(requestBody, callContext); + Assertions.assertEquals(200, response.getStatusCode(), response.toString()); + Assertions.assertInstanceOf(RestHandler.HTTPRestStreamingResponse.class, response); + RestHandler.HTTPRestStreamingResponse streamingResponse = (RestHandler.HTTPRestStreamingResponse) response; + Assertions.assertNotNull(streamingResponse.getPublisher()); + Assertions.assertEquals("text/event-stream", streamingResponse.getContentType()); + } + + @Test + public void testSendStreamingMessageNotSupported() { + AgentCard card = createAgentCard(false, true, true); + RestHandler handler = new RestHandler(card, requestHandler); + + String requestBody = """ + { + "contextId": "ctx123", + "role": "ROLE_USER", + "content": [{ + "text": "Hello" + }] + } + """; + + RestHandler.HTTPRestResponse response = handler.sendStreamingMessage(requestBody, callContext); + + Assertions.assertEquals(400, response.getStatusCode()); + Assertions.assertTrue(response.getBody().contains("InvalidRequestError")); + } + + @Test + public void testPushNotificationConfigSuccess() { + RestHandler handler = new RestHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + + String requestBody = """ + { + "parent": "tasks/%s", + "config": { + "name": "tasks/%s/pushNotificationConfigs/", + "pushNotificationConfig": { + "url": "https://example.com/callback", + "authentication": { + "schemes": ["jwt"] + } + } + } + }""".formatted(MINIMAL_TASK.getId(), MINIMAL_TASK.getId()); + + RestHandler.HTTPRestResponse response = handler.setTaskPushNotificationConfiguration( MINIMAL_TASK.getId(), requestBody, callContext); + + Assertions.assertEquals(201, response.getStatusCode(), response.toString()); + Assertions.assertEquals("application/json", response.getContentType()); + Assertions.assertNotNull(response.getBody()); + } + + @Test + public void testPushNotificationConfigNotSupported() { + AgentCard card = createAgentCard(true, false, true); + RestHandler handler = new RestHandler(card, requestHandler); + + String requestBody = """ + { + "taskId": "%s", + "pushNotificationConfig": { + "url": "http://example.com" + } + } + """.formatted(MINIMAL_TASK.getId()); + + RestHandler.HTTPRestResponse response = handler.setTaskPushNotificationConfiguration(MINIMAL_TASK.getId(), requestBody, callContext); + + Assertions.assertEquals(501, response.getStatusCode()); + Assertions.assertTrue(response.getBody().contains("PushNotificationNotSupportedError")); + } + + @Test + public void testGetPushNotificationConfig() { + RestHandler handler = new RestHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + + // First, create a push notification config + String createRequestBody = """ + { + "parent": "tasks/%s", + "config": { + "name": "tasks/%s/pushNotificationConfigs/", + "pushNotificationConfig": { + "url": "https://example.com/callback", + "authentication": { + "schemes": ["jwt"] + } + } + } + }""".formatted(MINIMAL_TASK.getId(), MINIMAL_TASK.getId()); + RestHandler.HTTPRestResponse response = handler.setTaskPushNotificationConfiguration(MINIMAL_TASK.getId(), createRequestBody, callContext); + Assertions.assertEquals(201, response.getStatusCode(), response.toString()); + Assertions.assertEquals("application/json", response.getContentType()); + // Now get it + response = handler.getTaskPushNotificationConfiguration(MINIMAL_TASK.getId(), "default-config-id", callContext); + Assertions.assertEquals(200, response.getStatusCode(), response.toString()); + Assertions.assertEquals("application/json", response.getContentType()); + } + + @Test + public void testDeletePushNotificationConfig() { + RestHandler handler = new RestHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + RestHandler.HTTPRestResponse response = handler.deleteTaskPushNotificationConfiguration(MINIMAL_TASK.getId(), "default-config-id", callContext); + Assertions.assertEquals(204, response.getStatusCode()); + } + + @Test + public void testListPushNotificationConfigs() { + RestHandler handler = new RestHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + + RestHandler.HTTPRestResponse response = handler.listTaskPushNotificationConfigurations(MINIMAL_TASK.getId(), callContext); + + Assertions.assertEquals(200, response.getStatusCode()); + Assertions.assertEquals("application/json", response.getContentType()); + Assertions.assertNotNull(response.getBody()); + } + + @Test + public void testHttpStatusCodeMapping() { + RestHandler handler = new RestHandler(CARD, requestHandler); + + // Test 400 for invalid request + RestHandler.HTTPRestResponse response = handler.sendMessage("", callContext); + Assertions.assertEquals(400, response.getStatusCode()); + + // Test 404 for not found + response = handler.getTask("nonexistent", null, callContext); + Assertions.assertEquals(404, response.getStatusCode()); + } + + @Test + public void testStreamingDoesNotBlockMainThread() throws Exception { + RestHandler handler = new RestHandler(CARD, requestHandler); + + // Track if the main thread gets blocked during streaming + AtomicBoolean eventReceived = new AtomicBoolean(false); + CountDownLatch streamStarted = new CountDownLatch(1); + CountDownLatch eventProcessed = new CountDownLatch(1); + agentExecutorExecute = (context, eventQueue) -> { + // Wait a bit to ensure the main thread continues + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + eventQueue.enqueueEvent(context.getMessage()); + }; + + String requestBody = """ + { + "message": { + "role": "ROLE_USER", + "content": [ + { + "text": "tell me some jokes" + } + ], + "messageId": "message-1234", + "contextId": "context-1234" + }, + "configuration": { + "acceptedOutputModes": ["text"] + } + }"""; + + // Start streaming + RestHandler.HTTPRestResponse response = handler.sendStreamingMessage(requestBody, callContext); + + Assertions.assertEquals(200, response.getStatusCode()); + Assertions.assertInstanceOf(RestHandler.HTTPRestStreamingResponse.class, response); + + RestHandler.HTTPRestStreamingResponse streamingResponse = (RestHandler.HTTPRestStreamingResponse) response; + Flow.Publisher publisher = streamingResponse.getPublisher(); + publisher.subscribe(new Flow.Subscriber() { + @Override + public void onSubscribe(Flow.Subscription subscription) { + streamStarted.countDown(); + subscription.request(1); + } + + @Override + public void onNext(String item) { + eventReceived.set(true); + eventProcessed.countDown(); + } + + @Override + public void onError(Throwable throwable) { + eventProcessed.countDown(); + } + + @Override + public void onComplete() { + eventProcessed.countDown(); + } + }); + + // The main thread should not be blocked - we should be able to continue immediately + Assertions.assertTrue(streamStarted.await(100, TimeUnit.MILLISECONDS), + "Streaming subscription should start quickly without blocking main thread"); + + // This proves the main thread is not blocked - we can do other work + // Simulate main thread doing other work + Thread.sleep(50); + + // Wait for the actual event processing to complete + Assertions.assertTrue(eventProcessed.await(2, TimeUnit.SECONDS), + "Event should be processed within reasonable time"); + + // Verify we received the event + Assertions.assertTrue(eventReceived.get(), "Should have received streaming event"); + } +} diff --git a/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestTestTransportMetadata.java b/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestTestTransportMetadata.java new file mode 100644 index 000000000..68aad41bb --- /dev/null +++ b/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestTestTransportMetadata.java @@ -0,0 +1,12 @@ +package io.a2a.transport.rest.handler; + +import io.a2a.server.TransportMetadata; +import io.a2a.spec.TransportProtocol; + +public class RestTestTransportMetadata implements TransportMetadata { + @Override + public String getTransportProtocol() { + return TransportProtocol.HTTP_JSON.asString(); + } + +} diff --git a/transport/rest/src/test/resources/META-INF/services/io.a2a.server.TransportMetadata b/transport/rest/src/test/resources/META-INF/services/io.a2a.server.TransportMetadata new file mode 100644 index 000000000..3604945b4 --- /dev/null +++ b/transport/rest/src/test/resources/META-INF/services/io.a2a.server.TransportMetadata @@ -0,0 +1,2 @@ +# Add a test TransportMetadata so we pass AgentCard validation +io.a2a.transport.rest.handler.RestTestTransportMetadata \ No newline at end of file