diff --git a/client/src/main/java/io/a2a/A2A.java b/client/src/main/java/io/a2a/A2A.java index c8dca80e4..cf422abce 100644 --- a/client/src/main/java/io/a2a/A2A.java +++ b/client/src/main/java/io/a2a/A2A.java @@ -4,13 +4,13 @@ import java.util.Map; import io.a2a.client.A2ACardResolver; -import io.a2a.http.A2AHttpClient; -import io.a2a.http.JdkA2AHttpClient; import io.a2a.spec.A2AClientError; import io.a2a.spec.A2AClientJSONError; import io.a2a.spec.AgentCard; import io.a2a.spec.Message; import io.a2a.spec.TextPart; +import io.a2a.transport.A2ATransport; +import io.a2a.transport.http.JdkA2AHttpTransport; /** @@ -80,20 +80,20 @@ private static Message toMessage(String text, Message.Role role, String messageI * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema */ public static AgentCard getAgentCard(String agentUrl) throws A2AClientError, A2AClientJSONError { - return getAgentCard(new JdkA2AHttpClient(), agentUrl); + return getAgentCard(new JdkA2AHttpTransport(), agentUrl); } /** * Get the agent card for an A2A agent. * - * @param httpClient the http client to use + * @param transport the transport to use * @param agentUrl the base URL for the agent whose agent card we want to retrieve * @return the agent card * @throws A2AClientError If an HTTP error occurs fetching the card * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema */ - public static AgentCard getAgentCard(A2AHttpClient httpClient, String agentUrl) throws A2AClientError, A2AClientJSONError { - return getAgentCard(httpClient, agentUrl, null, null); + public static AgentCard getAgentCard(A2ATransport transport, String agentUrl) throws A2AClientError, A2AClientJSONError { + return getAgentCard(transport, agentUrl, null, null); } /** @@ -108,13 +108,13 @@ public static AgentCard getAgentCard(A2AHttpClient httpClient, String agentUrl) * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema */ public static AgentCard getAgentCard(String agentUrl, String relativeCardPath, Map authHeaders) throws A2AClientError, A2AClientJSONError { - return getAgentCard(new JdkA2AHttpClient(), agentUrl, relativeCardPath, authHeaders); + return getAgentCard(new JdkA2AHttpTransport(), agentUrl, relativeCardPath, authHeaders); } /** * Get the agent card for an A2A agent. * - * @param httpClient the http client to use + * @param transport the transport to use * @param agentUrl the base URL for the agent whose agent card we want to retrieve * @param relativeCardPath optional path to the agent card endpoint relative to the base * agent URL, defaults to ".well-known/agent.json" @@ -123,8 +123,8 @@ public static AgentCard getAgentCard(String agentUrl, String relativeCardPath, M * @throws A2AClientError If an HTTP error occurs fetching the card * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema */ - public static AgentCard getAgentCard(A2AHttpClient httpClient, String agentUrl, String relativeCardPath, Map authHeaders) throws A2AClientError, A2AClientJSONError { - A2ACardResolver resolver = new A2ACardResolver(httpClient, agentUrl, relativeCardPath, authHeaders); + public static AgentCard getAgentCard(A2ATransport transport, String agentUrl, String relativeCardPath, Map authHeaders) throws A2AClientError, A2AClientJSONError { + A2ACardResolver resolver = new A2ACardResolver(transport, agentUrl, relativeCardPath, authHeaders); return resolver.getAgentCard(); } } diff --git a/client/src/main/java/io/a2a/client/A2ACardResolver.java b/client/src/main/java/io/a2a/client/A2ACardResolver.java index 1266f7219..740a33e92 100644 --- a/client/src/main/java/io/a2a/client/A2ACardResolver.java +++ b/client/src/main/java/io/a2a/client/A2ACardResolver.java @@ -1,20 +1,15 @@ package io.a2a.client; -import static io.a2a.util.Utils.unmarshalFrom; - -import java.io.IOException; import java.util.Map; -import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; -import io.a2a.http.A2AHttpClient; -import io.a2a.http.A2AHttpResponse; +import io.a2a.transport.A2ATransport; import io.a2a.spec.A2AClientError; import io.a2a.spec.A2AClientJSONError; import io.a2a.spec.AgentCard; public class A2ACardResolver { - private final A2AHttpClient httpClient; + private final A2ATransport transport; private final String url; private final Map authHeaders; @@ -22,32 +17,32 @@ public class A2ACardResolver { static final TypeReference AGENT_CARD_TYPE_REFERENCE = new TypeReference<>() {}; /** - * @param httpClient the http client to use + * @param transport the transport to use * @param baseUrl the base URL for the agent whose agent card we want to retrieve */ - public A2ACardResolver(A2AHttpClient httpClient, String baseUrl) { - this(httpClient, baseUrl, null, null); + public A2ACardResolver(A2ATransport transport, String baseUrl) { + this(transport, baseUrl, null, null); } /** - * @param httpClient the http client to use + * @param transport the transport to use * @param baseUrl the base URL for the agent whose agent card we want to retrieve * @param agentCardPath optional path to the agent card endpoint relative to the base * agent URL, defaults to ".well-known/agent.json" */ - public A2ACardResolver(A2AHttpClient httpClient, String baseUrl, String agentCardPath) { - this(httpClient, baseUrl, agentCardPath, null); + public A2ACardResolver(A2ATransport transport, String baseUrl, String agentCardPath) { + this(transport, baseUrl, agentCardPath, null); } /** - * @param httpClient the http client to use + * @param transport the transport to use * @param baseUrl the base URL for the agent whose agent card we want to retrieve * @param agentCardPath optional path to the agent card endpoint relative to the base * agent URL, defaults to ".well-known/agent.json" * @param authHeaders the HTTP authentication headers to use. May be {@code null} */ - public A2ACardResolver(A2AHttpClient httpClient, String baseUrl, String agentCardPath, Map authHeaders) { - this.httpClient = httpClient; + public A2ACardResolver(A2ATransport transport, String baseUrl, String agentCardPath, Map authHeaders) { + this.transport = transport; if (!baseUrl.endsWith("/")) { baseUrl += "/"; } @@ -67,33 +62,7 @@ public A2ACardResolver(A2AHttpClient httpClient, String baseUrl, String agentCar * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema */ public AgentCard getAgentCard() throws A2AClientError, A2AClientJSONError { - A2AHttpClient.GetBuilder builder = httpClient.createGet() - .url(url) - .addHeader("Content-Type", "application/json"); - - if (authHeaders != null) { - for (Map.Entry entry : authHeaders.entrySet()) { - builder.addHeader(entry.getKey(), entry.getValue()); - } - } - - String body; - try { - A2AHttpResponse response = builder.get(); - if (!response.success()) { - throw new A2AClientError("Failed to obtain agent card: " + response.status()); - } - body = response.body(); - } catch (IOException | InterruptedException e) { - throw new A2AClientError("Failed to obtain agent card", e); - } - - try { - return unmarshalFrom(body, AGENT_CARD_TYPE_REFERENCE); - } catch (JsonProcessingException e) { - throw new A2AClientJSONError("Could not unmarshal agent card response", e); - } - + return transport.getAgentCard(url, authHeaders); } diff --git a/client/src/main/java/io/a2a/client/A2AClient.java b/client/src/main/java/io/a2a/client/A2AClient.java index ea08baea4..40ed34479 100644 --- a/client/src/main/java/io/a2a/client/A2AClient.java +++ b/client/src/main/java/io/a2a/client/A2AClient.java @@ -1,6 +1,7 @@ package io.a2a.client; import static io.a2a.util.Assert.checkNotNullParam; +import static io.a2a.util.Utils.OBJECT_MAPPER; import java.io.IOException; import java.util.Map; @@ -11,9 +12,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import io.a2a.client.sse.SSEEventListener; -import io.a2a.http.A2AHttpClient; -import io.a2a.http.A2AHttpResponse; -import io.a2a.http.JdkA2AHttpClient; +import io.a2a.transport.A2ATransport; import io.a2a.A2A; import io.a2a.spec.A2AClientError; import io.a2a.spec.A2AClientJSONError; @@ -27,7 +26,6 @@ import io.a2a.spec.GetTaskResponse; import io.a2a.spec.JSONRPCError; import io.a2a.spec.JSONRPCMessage; -import io.a2a.spec.JSONRPCResponse; import io.a2a.spec.MessageSendParams; import io.a2a.spec.PushNotificationConfig; import io.a2a.spec.SendMessageRequest; @@ -40,7 +38,7 @@ import io.a2a.spec.TaskPushNotificationConfig; import io.a2a.spec.TaskQueryParams; import io.a2a.spec.TaskResubscriptionRequest; -import io.a2a.util.Utils; +import io.a2a.transport.http.JdkA2AHttpTransport; /** * An A2A client. @@ -52,7 +50,7 @@ public class A2AClient { private static final TypeReference CANCEL_TASK_RESPONSE_REFERENCE = new TypeReference<>() {}; private static final TypeReference GET_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE = new TypeReference<>() {}; private static final TypeReference SET_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE = new TypeReference<>() {}; - private final A2AHttpClient httpClient; + private final A2ATransport transport; private final String agentUrl; private AgentCard agentCard; @@ -66,7 +64,7 @@ public A2AClient(AgentCard agentCard) { checkNotNullParam("agentCard", agentCard); this.agentCard = agentCard; this.agentUrl = agentCard.url(); - this.httpClient = new JdkA2AHttpClient(); + this.transport = new JdkA2AHttpTransport(); } /** @@ -77,13 +75,25 @@ public A2AClient(AgentCard agentCard) { public A2AClient(String agentUrl) { checkNotNullParam("agentUrl", agentUrl); this.agentUrl = agentUrl; - this.httpClient = new JdkA2AHttpClient(); + this.transport = new JdkA2AHttpTransport(); + } + + /** + * Create a new A2AClient. + * + * @param agentUrl the URL for the A2A server this client will be communicating with + * @param transport the transport to use + */ + public A2AClient(String agentUrl, A2ATransport transport) { + checkNotNullParam("agentUrl", agentUrl); + this.agentUrl = agentUrl; + this.transport = transport; } /** * Fetches the agent card and initialises an A2A client. * - * @param httpClient the {@link A2AHttpClient} to use + * @param transport the {@link A2ATransport} to use * @param baseUrl the base URL of the agent's host * @param agentCardPath the path to the agent card endpoint, relative to the {@code baseUrl}. If {@code null}, the * value {@link A2ACardResolver#DEFAULT_AGENT_CARD_PATH} will be used @@ -91,9 +101,9 @@ public A2AClient(String agentUrl) { * @throws A2AClientError If an HTTP error occurs fetching the card * @throws A2AClientJSONError if the agent card response is invalid */ - public static A2AClient getClientFromAgentCardUrl(A2AHttpClient httpClient, String baseUrl, + public static A2AClient getClientFromAgentCardUrl(A2ATransport transport, String baseUrl, String agentCardPath) throws A2AClientError, A2AClientJSONError { - A2ACardResolver resolver = new A2ACardResolver(httpClient, baseUrl, agentCardPath); + A2ACardResolver resolver = new A2ACardResolver(transport, baseUrl, agentCardPath); AgentCard card = resolver.getAgentCard(); return new A2AClient(card); } @@ -108,7 +118,7 @@ public static A2AClient getClientFromAgentCardUrl(A2AHttpClient httpClient, Stri */ public AgentCard getAgentCard() throws A2AClientError, A2AClientJSONError { if (this.agentCard == null) { - this.agentCard = A2A.getAgentCard(this.httpClient, this.agentUrl); + this.agentCard = A2A.getAgentCard(this.transport, this.agentUrl); } return this.agentCard; } @@ -124,7 +134,7 @@ public AgentCard getAgentCard() throws A2AClientError, A2AClientJSONError { */ public AgentCard getAgentCard(String relativeCardPath, Map authHeaders) throws A2AClientError, A2AClientJSONError { if (this.agentCard == null) { - this.agentCard = A2A.getAgentCard(this.httpClient, this.agentUrl, relativeCardPath, authHeaders); + this.agentCard = A2A.getAgentCard(this.transport, this.agentUrl, relativeCardPath, authHeaders); } return this.agentCard; } @@ -161,8 +171,7 @@ public SendMessageResponse sendMessage(String requestId, MessageSendParams messa SendMessageRequest sendMessageRequest = sendMessageRequestBuilder.build(); try { - String httpResponseBody = sendPostRequest(sendMessageRequest); - return unmarshalResponse(httpResponseBody, SEND_MESSAGE_RESPONSE_REFERENCE); + return transport.sendMessage(sendMessageRequest, agentUrl, SEND_MESSAGE_RESPONSE_REFERENCE); } catch (IOException | InterruptedException e) { throw new A2AServerException("Failed to send message: " + e); } @@ -213,8 +222,7 @@ public GetTaskResponse getTask(String requestId, TaskQueryParams taskQueryParams GetTaskRequest getTaskRequest = getTaskRequestBuilder.build(); try { - String httpResponseBody = sendPostRequest(getTaskRequest); - return unmarshalResponse(httpResponseBody, GET_TASK_RESPONSE_REFERENCE); + return transport.sendMessage(getTaskRequest, agentUrl, GET_TASK_RESPONSE_REFERENCE); } catch (IOException | InterruptedException e) { throw new A2AServerException("Failed to get task: " + e); } @@ -263,8 +271,7 @@ public CancelTaskResponse cancelTask(String requestId, TaskIdParams taskIdParams CancelTaskRequest cancelTaskRequest = cancelTaskRequestBuilder.build(); try { - String httpResponseBody = sendPostRequest(cancelTaskRequest); - return unmarshalResponse(httpResponseBody, CANCEL_TASK_RESPONSE_REFERENCE); + return transport.sendMessage(cancelTaskRequest, agentUrl, CANCEL_TASK_RESPONSE_REFERENCE); } catch (IOException | InterruptedException e) { throw new A2AServerException("Failed to cancel task: " + e); } @@ -313,8 +320,7 @@ public GetTaskPushNotificationConfigResponse getTaskPushNotificationConfig(Strin GetTaskPushNotificationConfigRequest getTaskPushNotificationRequest = getTaskPushNotificationRequestBuilder.build(); try { - String httpResponseBody = sendPostRequest(getTaskPushNotificationRequest); - return unmarshalResponse(httpResponseBody, GET_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE); + return transport.sendMessage(getTaskPushNotificationRequest, agentUrl, GET_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE); } catch (IOException | InterruptedException e) { throw new A2AServerException("Failed to get task push notification config: " + e); } @@ -356,8 +362,7 @@ public SetTaskPushNotificationConfigResponse setTaskPushNotificationConfig(Strin SetTaskPushNotificationConfigRequest setTaskPushNotificationRequest = setTaskPushNotificationRequestBuilder.build(); try { - String httpResponseBody = sendPostRequest(setTaskPushNotificationRequest); - return unmarshalResponse(httpResponseBody, SET_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE); + return transport.sendMessage(setTaskPushNotificationRequest, agentUrl, SET_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE); } catch (IOException | InterruptedException e) { throw new A2AServerException("Failed to set task push notification config: " + e); } @@ -388,7 +393,7 @@ public void sendStreamingMessage(MessageSendParams messageSendParams, Consumer eventHandler, - Consumer errorHandler, Runnable failureHandler) throws A2AServerException { + Consumer errorHandler, Runnable failureHandler) throws A2AServerException { checkNotNullParam("messageSendParams", messageSendParams); checkNotNullParam("eventHandler", eventHandler); checkNotNullParam("errorHandler", errorHandler); @@ -407,14 +412,12 @@ public void sendStreamingMessage(String requestId, MessageSendParams messageSend SSEEventListener sseEventListener = new SSEEventListener(eventHandler, errorHandler, failureHandler); SendStreamingMessageRequest sendStreamingMessageRequest = sendStreamingMessageRequestBuilder.build(); try { - A2AHttpClient.PostBuilder builder = createPostBuilder(sendStreamingMessageRequest); - ref.set(builder.postAsyncSSE( - msg -> sseEventListener.onMessage(msg, ref.get()), + transport.sendMessageStreaming(sendStreamingMessageRequest, agentUrl, SEND_MESSAGE_RESPONSE_REFERENCE, + response -> sseEventListener.onMessage(response, ref.get()), throwable -> sseEventListener.onError(throwable, ref.get()), () -> { // We don't need to do anything special on completion - })); - + }); } catch (IOException e) { throw new A2AServerException("Failed to send streaming message request: " + e); } catch (InterruptedException e) { @@ -466,45 +469,16 @@ public void resubscribeToTask(String requestId, TaskIdParams taskIdParams, Consu SSEEventListener sseEventListener = new SSEEventListener(eventHandler, errorHandler, failureHandler); TaskResubscriptionRequest taskResubscriptionRequest = taskResubscriptionRequestBuilder.build(); try { - A2AHttpClient.PostBuilder builder = createPostBuilder(taskResubscriptionRequest); - ref.set(builder.postAsyncSSE( - msg -> sseEventListener.onMessage(msg, ref.get()), + transport.sendMessageStreaming(taskResubscriptionRequest, agentUrl, GET_TASK_RESPONSE_REFERENCE, + response -> sseEventListener.onMessage(response, ref.get()), throwable -> sseEventListener.onError(throwable, ref.get()), () -> { // We don't need to do anything special on completion - })); - + }); } catch (IOException e) { throw new A2AServerException("Failed to send task resubscription request: " + e); } catch (InterruptedException e) { throw new A2AServerException("Task resubscription request timed out: " + e); } } - - private String sendPostRequest(Object value) throws IOException, InterruptedException { - A2AHttpClient.PostBuilder builder = createPostBuilder(value); - A2AHttpResponse response = builder.post(); - if (!response.success()) { - throw new IOException("Request failed " + response.status()); - } - return response.body(); - } - - private A2AHttpClient.PostBuilder createPostBuilder(Object value) throws JsonProcessingException { - return httpClient.createPost() - .url(agentUrl) - .addHeader("Content-Type", "application/json") - .body(Utils.OBJECT_MAPPER.writeValueAsString(value)); - - } - - private T unmarshalResponse(String response, TypeReference typeReference) - throws A2AServerException, JsonProcessingException { - T value = Utils.unmarshalFrom(response, typeReference); - JSONRPCError error = value.getError(); - if (error != null) { - throw new A2AServerException(error.getMessage() + (error.getData() != null ? ": " + error.getData() : "")); - } - return value; - } } diff --git a/client/src/main/java/io/a2a/client/sse/SSEEventListener.java b/client/src/main/java/io/a2a/client/sse/SSEEventListener.java index 8ed0e9aa3..8485d2da4 100644 --- a/client/src/main/java/io/a2a/client/sse/SSEEventListener.java +++ b/client/src/main/java/io/a2a/client/sse/SSEEventListener.java @@ -7,8 +7,9 @@ import java.util.logging.Logger; import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.JsonNode; import io.a2a.spec.JSONRPCError; +import io.a2a.spec.JSONRPCResponse; +import io.a2a.spec.SendStreamingMessageResponse; import io.a2a.spec.StreamingEventKind; import io.a2a.spec.TaskStatusUpdateEvent; @@ -26,36 +27,34 @@ public SSEEventListener(Consumer eventHandler, Consumer completableFuture) { try { - handleMessage(OBJECT_MAPPER.readTree(message),completableFuture); + SendStreamingMessageResponse sendStreamingMessageResponse = OBJECT_MAPPER.readValue(message, SendStreamingMessageResponse.class); + handleMessage(sendStreamingMessageResponse,completableFuture); } catch (JsonProcessingException e) { log.warning("Failed to parse JSON message: " + message); } } + public void onMessage(JSONRPCResponse response, Future completableFuture) { + handleMessage(response,completableFuture); + } + public void onError(Throwable throwable, Future future) { failureHandler.run(); future.cancel(true); // close SSE channel } - private void handleMessage(JsonNode jsonNode, Future future) { - try { - if (jsonNode.has("error")) { - JSONRPCError error = OBJECT_MAPPER.treeToValue(jsonNode.get("error"), JSONRPCError.class); - errorHandler.accept(error); - } else if (jsonNode.has("result")) { - // result can be a Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent - JsonNode result = jsonNode.path("result"); - StreamingEventKind event = OBJECT_MAPPER.treeToValue(result, StreamingEventKind.class); - eventHandler.accept(event); - if (event instanceof TaskStatusUpdateEvent && ((TaskStatusUpdateEvent) event).isFinal()) { - future.cancel(true); // close SSE channel - } - } else { - throw new IllegalArgumentException("Unknown message type"); + private void handleMessage(JSONRPCResponse response, Future future) { + if (null != response.getError()) { + errorHandler.accept(response.getError()); + } else if (null != response.getResult()) { + // result can be a Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent + StreamingEventKind event = (StreamingEventKind) response.getResult(); + eventHandler.accept(event); + if (event instanceof TaskStatusUpdateEvent && ((TaskStatusUpdateEvent) event).isFinal()) { + future.cancel(true); // close SSE channel } - } catch (JsonProcessingException e) { - throw new RuntimeException(e); + } else { + throw new IllegalArgumentException("Unknown message type"); } } - } diff --git a/client/src/main/java/io/a2a/transport/A2ATransport.java b/client/src/main/java/io/a2a/transport/A2ATransport.java new file mode 100644 index 000000000..e5d154ad8 --- /dev/null +++ b/client/src/main/java/io/a2a/transport/A2ATransport.java @@ -0,0 +1,31 @@ +package io.a2a.transport; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.a2a.spec.A2AClientError; +import io.a2a.spec.AgentCard; +import io.a2a.spec.Event; +import io.a2a.spec.JSONRPCRequest; +import io.a2a.spec.JSONRPCResponse; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +public interface A2ATransport { + + AgentCard getAgentCard(String method, Map authInfo) throws A2AClientError; + + void sendEvent(Event event, String method) throws IOException, InterruptedException; + + > T sendMessage( + JSONRPCRequest request, String operation, TypeReference responseTypeRef) throws IOException, InterruptedException; + + > CompletableFuture sendMessageStreaming( + JSONRPCRequest request, + String operation, + TypeReference responseTypeRef, + Consumer responseConsumer, + Consumer errorConsumer, + Runnable completeRunnable) throws IOException, InterruptedException; +} diff --git a/client/src/main/java/io/a2a/http/A2AHttpResponse.java b/client/src/main/java/io/a2a/transport/http/A2AHttpResponse.java similarity index 76% rename from client/src/main/java/io/a2a/http/A2AHttpResponse.java rename to client/src/main/java/io/a2a/transport/http/A2AHttpResponse.java index d6973a5dc..f813c2e8c 100644 --- a/client/src/main/java/io/a2a/http/A2AHttpResponse.java +++ b/client/src/main/java/io/a2a/transport/http/A2AHttpResponse.java @@ -1,4 +1,4 @@ -package io.a2a.http; +package io.a2a.transport.http; public interface A2AHttpResponse { int status(); diff --git a/client/src/main/java/io/a2a/http/A2AHttpClient.java b/client/src/main/java/io/a2a/transport/http/A2AHttpTransport.java similarity index 89% rename from client/src/main/java/io/a2a/http/A2AHttpClient.java rename to client/src/main/java/io/a2a/transport/http/A2AHttpTransport.java index 7a246843a..ef930d359 100644 --- a/client/src/main/java/io/a2a/http/A2AHttpClient.java +++ b/client/src/main/java/io/a2a/transport/http/A2AHttpTransport.java @@ -1,10 +1,12 @@ -package io.a2a.http; +package io.a2a.transport.http; + +import io.a2a.transport.A2ATransport; import java.io.IOException; import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; -public interface A2AHttpClient { +public interface A2AHttpTransport extends A2ATransport { GetBuilder createGet(); diff --git a/client/src/main/java/io/a2a/http/JdkA2AHttpClient.java b/client/src/main/java/io/a2a/transport/http/JdkA2AHttpTransport.java similarity index 59% rename from client/src/main/java/io/a2a/http/JdkA2AHttpClient.java rename to client/src/main/java/io/a2a/transport/http/JdkA2AHttpTransport.java index e3b5c0c66..3ac6883e5 100644 --- a/client/src/main/java/io/a2a/http/JdkA2AHttpClient.java +++ b/client/src/main/java/io/a2a/transport/http/JdkA2AHttpTransport.java @@ -1,12 +1,22 @@ -package io.a2a.http; +package io.a2a.transport.http; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import io.a2a.spec.A2AClientError; +import io.a2a.spec.A2AClientJSONError; +import io.a2a.spec.A2AServerException; +import io.a2a.spec.AgentCard; +import io.a2a.spec.Event; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.JSONRPCRequest; +import io.a2a.spec.JSONRPCResponse; +import io.a2a.util.Utils; import java.io.IOException; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; -import java.net.http.HttpResponse.BodyHandler; -import java.net.http.HttpResponse.BodyHandlers; import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Map; @@ -14,17 +24,89 @@ import java.util.concurrent.Flow; import java.util.function.Consumer; -public class JdkA2AHttpClient implements A2AHttpClient { +import static io.a2a.util.Utils.OBJECT_MAPPER; +import static io.a2a.util.Utils.unmarshalFrom; + +public class JdkA2AHttpTransport implements A2AHttpTransport { + private static final TypeReference AGENT_CARD_TYPE_REFERENCE = new TypeReference<>() { }; + private final HttpClient httpClient; - public JdkA2AHttpClient() { + public JdkA2AHttpTransport() { httpClient = HttpClient.newBuilder() .version(HttpClient.Version.HTTP_2) .followRedirects(HttpClient.Redirect.NORMAL) .build(); } + @Override + public AgentCard getAgentCard(String method, Map authInfo) throws A2AClientError { + GetBuilder builder = createGet() + .url(method) + .addHeader("Content-Type", "application/json"); + + if (authInfo != null) { + for (Map.Entry entry : authInfo.entrySet()) { + builder.addHeader(entry.getKey(), entry.getValue()); + } + } + + String body; + try { + A2AHttpResponse response = builder.get(); + if (!response.success()) { + throw new A2AClientError("Failed to obtain agent card: " + response.status()); + } + body = response.body(); + } catch (IOException | InterruptedException e) { + throw new A2AClientError("Failed to obtain agent card", e); + } + + try { + return unmarshalFrom(body, AGENT_CARD_TYPE_REFERENCE); + } catch (JsonProcessingException e) { + throw new A2AClientJSONError("Could not unmarshal agent card response", e); + } + } + + @Override + public void sendEvent(Event event, String method) throws IOException, InterruptedException { + String body = Utils.OBJECT_MAPPER.writeValueAsString(event); + createPost().url(method).body(body).post(); + } + + @Override + public > T sendMessage( + JSONRPCRequest request, String operation, TypeReference responseTypeRef) throws IOException, InterruptedException { + + PostBuilder postBuilder = createPostBuilder(request, operation); + A2AHttpResponse response = postBuilder.post(); + + if (!response.success()) { + throw new IOException("Request failed " + response.status()); + } + + return unmarshalResponse(response.body(), responseTypeRef); + } + + @Override + public > CompletableFuture sendMessageStreaming( + JSONRPCRequest request, String operation, TypeReference responseTypeRef, + Consumer responseConsumer, Consumer errorConsumer, Runnable completeRunnable) throws IOException, InterruptedException { + PostBuilder postBuilder = createPostBuilder(request, operation); + + return postBuilder.postAsyncSSE(message -> { + try { + T response = unmarshalResponse(message, responseTypeRef); + responseConsumer.accept(response); + } catch (Exception e) { + throw new RuntimeException(e); + } + }, errorConsumer, completeRunnable); + } + + @Override public GetBuilder createGet() { return new JdkGetBuilder(); @@ -35,9 +117,27 @@ public PostBuilder createPost() { return new JdkPostBuilder(); } + + private PostBuilder createPostBuilder(JSONRPCRequest request, String method) throws JsonProcessingException { + return createPost() + .url(method) + .addHeader("Content-Type", "application/json") + .body(OBJECT_MAPPER.writeValueAsString(request)); + } + + private > T unmarshalResponse(String response, TypeReference typeReference) + throws A2AServerException, JsonProcessingException { + T value = unmarshalFrom(response, typeReference); + JSONRPCError error = value.getError(); + if (error != null) { + throw new A2AServerException(error.getMessage() + (error.getData() != null ? ": " + error.getData() : "")); + } + return value; + } + private abstract class JdkBuilder> implements Builder { private String url; - private Map headers = new HashMap<>(); + private final Map headers = new HashMap<>(); @Override public T url(String url) { @@ -105,7 +205,7 @@ public void onComplete() { } }; - BodyHandler bodyHandler = BodyHandlers.fromLineSubscriber(subscriber); + HttpResponse.BodyHandler bodyHandler = HttpResponse.BodyHandlers.fromLineSubscriber(subscriber); // Send the response async, and let the subscriber handle the lines. return httpClient.sendAsync(request, bodyHandler) @@ -117,7 +217,7 @@ public void onComplete() { } } - private class JdkGetBuilder extends JdkBuilder implements A2AHttpClient.GetBuilder { + private class JdkGetBuilder extends JdkBuilder implements A2AHttpTransport.GetBuilder { private HttpRequest.Builder createRequestBuilder(boolean SSE) throws IOException { HttpRequest.Builder builder = super.createRequestBuilder().GET(); @@ -132,7 +232,7 @@ public A2AHttpResponse get() throws IOException, InterruptedException { HttpRequest request = createRequestBuilder(false) .build(); HttpResponse response = - httpClient.send(request, BodyHandlers.ofString(StandardCharsets.UTF_8)); + httpClient.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); return new JdkHttpResponse(response); } @@ -147,7 +247,7 @@ public CompletableFuture getAsyncSSE( } } - private class JdkPostBuilder extends JdkBuilder implements A2AHttpClient.PostBuilder { + private class JdkPostBuilder extends JdkBuilder implements A2AHttpTransport.PostBuilder { String body = ""; @Override @@ -171,7 +271,7 @@ public A2AHttpResponse post() throws IOException, InterruptedException { .POST(HttpRequest.BodyPublishers.ofString(body, StandardCharsets.UTF_8)) .build(); HttpResponse response = - httpClient.send(request, BodyHandlers.ofString(StandardCharsets.UTF_8)); + httpClient.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); return new JdkHttpResponse(response); } diff --git a/client/src/test/java/io/a2a/client/A2ACardResolverTest.java b/client/src/test/java/io/a2a/client/A2ACardResolverTest.java index 8265b9514..d0d43b15e 100644 --- a/client/src/test/java/io/a2a/client/A2ACardResolverTest.java +++ b/client/src/test/java/io/a2a/client/A2ACardResolverTest.java @@ -11,55 +11,55 @@ import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; -import io.a2a.http.A2AHttpClient; -import io.a2a.http.A2AHttpResponse; +import io.a2a.transport.http.A2AHttpResponse; import io.a2a.spec.A2AClientError; import io.a2a.spec.A2AClientJSONError; import io.a2a.spec.AgentCard; +import io.a2a.transport.http.JdkA2AHttpTransport; import org.junit.jupiter.api.Test; public class A2ACardResolverTest { @Test public void testConstructorStripsSlashes() throws Exception { - TestHttpClient client = new TestHttpClient(); - client.body = JsonMessages.AGENT_CARD; + TestHttpTransport transport = new TestHttpTransport(); + transport.body = JsonMessages.AGENT_CARD; - A2ACardResolver resolver = new A2ACardResolver(client, "http://example.com/"); + A2ACardResolver resolver = new A2ACardResolver(transport, "http://example.com/"); AgentCard card = resolver.getAgentCard(); - assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, client.url); + assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, transport.url); - resolver = new A2ACardResolver(client, "http://example.com"); + resolver = new A2ACardResolver(transport, "http://example.com"); card = resolver.getAgentCard(); - assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, client.url); + assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, transport.url); - resolver = new A2ACardResolver(client, "http://example.com/", A2ACardResolver.DEFAULT_AGENT_CARD_PATH); + resolver = new A2ACardResolver(transport, "http://example.com/", A2ACardResolver.DEFAULT_AGENT_CARD_PATH); card = resolver.getAgentCard(); - assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, client.url); + assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, transport.url); - resolver = new A2ACardResolver(client, "http://example.com", A2ACardResolver.DEFAULT_AGENT_CARD_PATH); + resolver = new A2ACardResolver(transport, "http://example.com", A2ACardResolver.DEFAULT_AGENT_CARD_PATH); card = resolver.getAgentCard(); - assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, client.url); + assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, transport.url); - resolver = new A2ACardResolver(client, "http://example.com/", A2ACardResolver.DEFAULT_AGENT_CARD_PATH.substring(0)); + resolver = new A2ACardResolver(transport, "http://example.com/", A2ACardResolver.DEFAULT_AGENT_CARD_PATH.substring(0)); card = resolver.getAgentCard(); - assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, client.url); + assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, transport.url); - resolver = new A2ACardResolver(client, "http://example.com", A2ACardResolver.DEFAULT_AGENT_CARD_PATH.substring(0)); + resolver = new A2ACardResolver(transport, "http://example.com", A2ACardResolver.DEFAULT_AGENT_CARD_PATH.substring(0)); card = resolver.getAgentCard(); - assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, client.url); + assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, transport.url); } @Test public void testGetAgentCardSuccess() throws Exception { - TestHttpClient client = new TestHttpClient(); + TestHttpTransport client = new TestHttpTransport(); client.body = JsonMessages.AGENT_CARD; A2ACardResolver resolver = new A2ACardResolver(client, "http://example.com/"); @@ -74,7 +74,7 @@ public void testGetAgentCardSuccess() throws Exception { @Test public void testGetAgentCardJsonDecodeError() throws Exception { - TestHttpClient client = new TestHttpClient(); + TestHttpTransport client = new TestHttpTransport(); client.body = "X" + JsonMessages.AGENT_CARD; A2ACardResolver resolver = new A2ACardResolver(client, "http://example.com/"); @@ -91,7 +91,7 @@ public void testGetAgentCardJsonDecodeError() throws Exception { @Test public void testGetAgentCardRequestError() throws Exception { - TestHttpClient client = new TestHttpClient(); + TestHttpTransport client = new TestHttpTransport(); client.status = 503; A2ACardResolver resolver = new A2ACardResolver(client, "http://example.com/"); @@ -105,7 +105,7 @@ public void testGetAgentCardRequestError() throws Exception { assertTrue(msg.contains("503")); } - private static class TestHttpClient implements A2AHttpClient { + private static class TestHttpTransport extends JdkA2AHttpTransport { int status = 200; String body; String url; @@ -120,7 +120,7 @@ public PostBuilder createPost() { return null; } - class TestGetBuilder implements A2AHttpClient.GetBuilder { + class TestGetBuilder implements GetBuilder { @Override public A2AHttpResponse get() throws IOException, InterruptedException { @@ -161,4 +161,4 @@ public GetBuilder addHeader(String name, String value) { } } -} +} \ No newline at end of file diff --git a/sdk-quarkus/src/test/resources/application.properties b/sdk-quarkus/src/test/resources/application.properties index d3366bece..7fe85eb85 100644 --- a/sdk-quarkus/src/test/resources/application.properties +++ b/sdk-quarkus/src/test/resources/application.properties @@ -1 +1 @@ -quarkus.arc.selected-alternatives=io.a2a.server.apps.common.TestHttpClient \ No newline at end of file +quarkus.arc.selected-alternatives=io.a2a.server.apps.common.TestHttpTransport \ No newline at end of file diff --git a/sdk-server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotifier.java b/sdk-server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotifier.java index 6fb1fb39a..994e09342 100644 --- a/sdk-server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotifier.java +++ b/sdk-server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotifier.java @@ -5,28 +5,26 @@ import java.util.HashMap; import java.util.Map; +import io.a2a.transport.A2ATransport; +import io.a2a.transport.http.JdkA2AHttpTransport; import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; -import com.fasterxml.jackson.core.JsonProcessingException; -import io.a2a.http.A2AHttpClient; -import io.a2a.http.JdkA2AHttpClient; import io.a2a.spec.PushNotificationConfig; import io.a2a.spec.Task; -import io.a2a.util.Utils; @ApplicationScoped public class InMemoryPushNotifier implements PushNotifier { - private final A2AHttpClient httpClient; + private final A2ATransport transport; private final Map pushNotificationInfos = Collections.synchronizedMap(new HashMap<>()); @Inject public InMemoryPushNotifier() { - this.httpClient = new JdkA2AHttpClient(); + this.transport = new JdkA2AHttpTransport(); } - public InMemoryPushNotifier(A2AHttpClient httpClient) { - this.httpClient = httpClient; + public InMemoryPushNotifier(A2ATransport transport) { + this.transport = transport; } @Override @@ -54,22 +52,8 @@ public void sendNotification(Task task) { // TODO auth - String body; try { - body = Utils.OBJECT_MAPPER.writeValueAsString(task); - } catch (JsonProcessingException e) { - e.printStackTrace(); - throw new RuntimeException("Error writing value as string: " + e.getMessage(), e); - } catch (Throwable throwable) { - throwable.printStackTrace(); - throw new RuntimeException("Error writing value as string: " + throwable.getMessage(), throwable); - } - - try { - httpClient.createPost() - .url(url) - .body(body) - .post(); + transport.sendEvent(task, url); } catch (IOException | InterruptedException e) { throw new RuntimeException("Error pushing data to " + url + ": " + e.getMessage(), e); } diff --git a/sdk-server-common/src/test/java/io/a2a/server/requesthandlers/JSONRPCHandlerTest.java b/sdk-server-common/src/test/java/io/a2a/server/requesthandlers/JSONRPCHandlerTest.java index 49913c0b6..10b367fe8 100644 --- a/sdk-server-common/src/test/java/io/a2a/server/requesthandlers/JSONRPCHandlerTest.java +++ b/sdk-server-common/src/test/java/io/a2a/server/requesthandlers/JSONRPCHandlerTest.java @@ -22,10 +22,10 @@ import java.util.function.Consumer; import io.a2a.spec.InternalError; +import io.a2a.transport.http.JdkA2AHttpTransport; import jakarta.enterprise.context.Dependent; -import io.a2a.http.A2AHttpClient; -import io.a2a.http.A2AHttpResponse; +import io.a2a.transport.http.A2AHttpResponse; import io.a2a.server.agentexecution.AgentExecutor; import io.a2a.server.agentexecution.RequestContext; import io.a2a.server.events.EventConsumer; @@ -47,7 +47,6 @@ import io.a2a.spec.GetTaskPushNotificationConfigResponse; import io.a2a.spec.GetTaskRequest; import io.a2a.spec.GetTaskResponse; -import io.a2a.spec.InternalError; import io.a2a.spec.InvalidRequestError; import io.a2a.spec.JSONRPCError; import io.a2a.spec.Message; @@ -104,7 +103,7 @@ public class JSONRPCHandlerTest { AgentExecutorMethod agentExecutorExecute; AgentExecutorMethod agentExecutorCancel; private InMemoryQueueManager queueManager; - private TestHttpClient httpClient; + private TestHttpTransport transport; private final Executor internalExecutor = Executors.newCachedThreadPool(); @@ -129,8 +128,8 @@ public void cancel(RequestContext context, EventQueue eventQueue) throws JSONRPC taskStore = new InMemoryTaskStore(); queueManager = new InMemoryQueueManager(); - httpClient = new TestHttpClient(); - PushNotifier pushNotifier = new InMemoryPushNotifier(httpClient); + transport = new TestHttpTransport(); + PushNotifier pushNotifier = new InMemoryPushNotifier(transport); requestHandler = new DefaultRequestHandler(executor, taskStore, queueManager, pushNotifier, internalExecutor); } @@ -693,7 +692,7 @@ public void testOnMessageStreamNewMessageSendPushNotificationSuccess() throws Ex final List results = Collections.synchronizedList(new ArrayList<>()); final AtomicReference subscriptionRef = new AtomicReference<>(); final CountDownLatch latch = new CountDownLatch(6); - httpClient.latch = latch; + transport.latch = latch; Executors.newSingleThreadExecutor().execute(() -> { response.subscribe(new Flow.Subscriber<>() { @@ -727,15 +726,15 @@ public void onComplete() { assertTrue(latch.await(5, TimeUnit.SECONDS)); subscriptionRef.get().cancel(); assertEquals(3, results.size()); - assertEquals(3, httpClient.tasks.size()); + assertEquals(3, transport.tasks.size()); - Task curr = httpClient.tasks.get(0); + Task curr = transport.tasks.get(0); assertEquals(MINIMAL_TASK.getId(), curr.getId()); assertEquals(MINIMAL_TASK.getContextId(), curr.getContextId()); assertEquals(MINIMAL_TASK.getStatus().state(), curr.getStatus().state()); assertEquals(0, curr.getArtifacts() == null ? 0 : curr.getArtifacts().size()); - curr = httpClient.tasks.get(1); + curr = transport.tasks.get(1); assertEquals(MINIMAL_TASK.getId(), curr.getId()); assertEquals(MINIMAL_TASK.getContextId(), curr.getContextId()); assertEquals(MINIMAL_TASK.getStatus().state(), curr.getStatus().state()); @@ -743,7 +742,7 @@ public void onComplete() { assertEquals(1, curr.getArtifacts().get(0).parts().size()); assertEquals("text", ((TextPart)curr.getArtifacts().get(0).parts().get(0)).getText()); - curr = httpClient.tasks.get(2); + curr = transport.tasks.get(2); assertEquals(MINIMAL_TASK.getId(), curr.getId()); assertEquals(MINIMAL_TASK.getContextId(), curr.getContextId()); assertEquals(TaskState.COMPLETED, curr.getStatus().state()); @@ -1265,7 +1264,7 @@ private interface AgentExecutorMethod { @Dependent @IfBuildProfile("test") - private static class TestHttpClient implements A2AHttpClient { + private static class TestHttpTransport extends JdkA2AHttpTransport { final List tasks = Collections.synchronizedList(new ArrayList<>()); volatile CountDownLatch latch; @@ -1279,7 +1278,7 @@ public PostBuilder createPost() { return new TestPostBuilder(); } - class TestPostBuilder implements A2AHttpClient.PostBuilder { + class TestPostBuilder implements PostBuilder { private volatile String body; @Override public PostBuilder body(String body) { diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpClient.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpTransport.java similarity index 91% rename from tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpClient.java rename to tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpTransport.java index c5deef68f..2200e40a9 100644 --- a/tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpClient.java +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpTransport.java @@ -8,17 +8,17 @@ import java.util.concurrent.CountDownLatch; import java.util.function.Consumer; +import io.a2a.transport.http.JdkA2AHttpTransport; import jakarta.enterprise.context.Dependent; import jakarta.enterprise.inject.Alternative; -import io.a2a.http.A2AHttpClient; -import io.a2a.http.A2AHttpResponse; +import io.a2a.transport.http.A2AHttpResponse; import io.a2a.spec.Task; import io.a2a.util.Utils; @Dependent @Alternative -public class TestHttpClient implements A2AHttpClient { +public class TestHttpTransport extends JdkA2AHttpTransport { final List tasks = Collections.synchronizedList(new ArrayList<>()); volatile CountDownLatch latch; @@ -32,7 +32,7 @@ public PostBuilder createPost() { return new TestPostBuilder(); } - class TestPostBuilder implements A2AHttpClient.PostBuilder { + class TestPostBuilder implements PostBuilder { private volatile String body; @Override public PostBuilder body(String body) {