diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 9dec62a96..bdc097791 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -4,6 +4,7 @@ on: # Handle all branches for now push: pull_request: + workflow_dispatch: # Only run the latest job concurrency: diff --git a/.github/workflows/release-to-maven-central.yml b/.github/workflows/release-to-maven-central.yml new file mode 100644 index 000000000..c0f5a9724 --- /dev/null +++ b/.github/workflows/release-to-maven-central.yml @@ -0,0 +1,52 @@ +name: Publish release to Maven Central + +on: + push: + tags: + - 'v?[0-9]+.[0-9]+.[0-9]+*' # Trigger on tags like v1.0.0, 1.2.3, v1.2.3.Alpha1 etc. + +jobs: + publish: + # Only run this job for the main repository, not for forks + if: github.repository == 'a2aproject/a2a-java' + runs-on: ubuntu-latest + permissions: + contents: read + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: 'temurin' + cache: maven + + # Use secrets to import GPG key + - name: Import GPG key + uses: crazy-max/ghaction-import-gpg@v6 + with: + gpg_private_key: ${{ secrets.GPG_SIGNING_KEY }} + passphrase: ${{ secrets.GPG_SIGNING_PASSPHRASE }} + + # Create settings.xml for Maven since it needs the 'central-a2asdk-temp' server. + # Populate wqith username and password from secrets + - name: Create settings.xml + run: | + mkdir -p ~/.m2 + echo "central-a2asdk-temp${{ secrets.CENTRAL_TOKEN_USERNAME }}${{ secrets.CENTRAL_TOKEN_PASSWORD }}" > ~/.m2/settings.xml + + # Deploy to Maven Central + # -s uses the settings file we created. + - name: Publish to Maven Central + run: > + mvn -B deploy + -s ~/.m2/settings.xml + -P release + -DskipTests + -Drelease.auto.publish=true + env: + # GPG passphrase is set as an environment variable for the gpg plugin to use + GPG_PASSPHRASE: ${{ secrets.GPG_SIGNING_PASSPHRASE }} diff --git a/.github/workflows/run-tck.yml b/.github/workflows/run-tck.yml index d2cba0890..7793157dc 100644 --- a/.github/workflows/run-tck.yml +++ b/.github/workflows/run-tck.yml @@ -8,10 +8,11 @@ on: pull_request: branches: - main + workflow_dispatch: env: # Tag of the TCK - TCK_VERSION: v0.2.3 + TCK_VERSION: v0.2.5 # Tells uv to not need a venv, and instead use system UV_SYSTEM_PYTHON: 1 diff --git a/CONTRIBUTING_INTEGRATIONS.md b/CONTRIBUTING_INTEGRATIONS.md new file mode 100644 index 000000000..9fc06a520 --- /dev/null +++ b/CONTRIBUTING_INTEGRATIONS.md @@ -0,0 +1,15 @@ +# Contributing A2A SDK Integrations + +To add your A2A SDK Integration for your chosen runtime to the list of integrations in the [README](README.md#server-integrations), open a pull request adding it to the list. + +The pull request should contain a link to your project page. + +Then the project page itself needs to contain the following information as a minimum: + +* How to use the integration. + * Ideally there should be a sample demonstrating how to use it +* The integration should have tests, extending [AbstractA2AServerTest](tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java) +* The integration should pass the [TCK](https://github.com/a2aproject/a2a-tck), and make it obvious how to see that it has passed. +* Ideally, the integration should be deployed in Maven Central. If that is not possible, provide clear instructions for how to build it. + +If some of the above points are problematic, feel free to point that out in your pull request, and we can discuss further. For example, AbstractA2AServerTest is currently written with only the initial runtimes in mind, and might need some tweaking. \ No newline at end of file diff --git a/README.md b/README.md index 5460f910b..2186dde3b 100644 --- a/README.md +++ b/README.md @@ -33,32 +33,23 @@ The A2A Java SDK provides a Java server implementation of the [Agent2Agent (A2A) - [Add a class that creates an A2A Agent Card](#2-add-a-class-that-creates-an-a2a-agent-card) - [Add a class that creates an A2A Agent Executor](#3-add-a-class-that-creates-an-a2a-agent-executor) -### 1. Add an A2A Java SDK Server Maven dependency to your project +### 1. Add the A2A Java SDK Server Maven dependency to your project Adding a dependency on an A2A Java SDK Server will provide access to the core classes that make up the A2A specification and allow you to run your agentic Java application as an A2A server agent. -The A2A Java SDK provides two A2A server endpoint implementations, one based on Jakarta REST (`a2a-java-sdk-server-jakarta`) and one based on Quarkus Reactive Routes (`a2a-java-sdk-server-quarkus`). You can choose the one that best fits your application. +The A2A Java SDK provides a [reference A2A server implementation](reference-impl/README.md) based on [Quarkus](https://quarkus.io) for use with our tests and examples. However, the project is designed in such a way that it is trivial to integrate with various Java runtimes. -Add **one** of the following dependencies to your project: +[Server Integrations](#server-integrations) contains a list of community contributed integrations of the server with various runtimes. You might be able to use one of these for your target runtime, or you can use them as inspiration to create your own. -> *⚠️ The `io.github.a2asdk` `groupId` below is temporary and will likely change for future releases.* - -```xml - - io.github.a2asdk - a2a-java-sdk-server-jakarta - - ${io.a2a.sdk.version} - -``` +To use the reference implementation add the following dependency to your project: -OR +> *⚠️ The `io.github.a2asdk` `groupId` below is temporary and will likely change for future releases.* ```xml io.github.a2asdk - a2a-java-sdk-server-quarkus + a2a-java-reference-server ${io.a2a.sdk.version} @@ -98,6 +89,7 @@ public class WeatherAgentCardProducer { .tags(Collections.singletonList("weather")) .examples(List.of("weather in LA, CA")) .build())) + .protocolVersion("0.2.5") .build(); } } @@ -264,18 +256,21 @@ Map metadata = ... CancelTaskResponse response = client.cancelTask(new TaskIdParams("task-1234", metadata)); ``` -#### Get the push notification configuration for a task +#### Get a push notification configuration for a task ```java // Get task push notification configuration GetTaskPushNotificationConfigResponse response = client.getTaskPushNotificationConfig("task-1234"); -// You can also specify additional properties using a map +// The push notification configuration ID can also be optionally specified +GetTaskPushNotificationConfigResponse response = client.getTaskPushNotificationConfig("task-1234", "config-4567"); + +// Additional properties can be specified using a map Map metadata = ... -GetTaskPushNotificationConfigResponse response = client.getTaskPushNotificationConfig(new TaskIdParams("task-1234", metadata)); +GetTaskPushNotificationConfigResponse response = client.getTaskPushNotificationConfig(new GetTaskPushNotificationConfigParams("task-1234", "config-1234", metadata)); ``` -#### Set the push notification configuration for a task +#### Set a push notification configuration for a task ```java // Set task push notification configuration @@ -286,6 +281,26 @@ PushNotificationConfig pushNotificationConfig = new PushNotificationConfig.Build SetTaskPushNotificationResponse response = client.setTaskPushNotificationConfig("task-1234", pushNotificationConfig); ``` +#### List the push notification configurations for a task + +```java +ListTaskPushNotificationConfigResponse response = client.listTaskPushNotificationConfig("task-1234"); + +// Additional properties can be specified using a map +Map metadata = ... +ListTaskPushNotificationConfigResponse response = client.listTaskPushNotificationConfig(new ListTaskPushNotificationConfigParams("task-123", metadata)); +``` + +#### Delete a push notification configuration for a task + +```java +DeleteTaskPushNotificationConfigResponse response = client.deleteTaskPushNotificationConfig("task-1234", "config-4567"); + +// Additional properties can be specified using a map +Map metadata = ... +DeleteTaskPushNotificationConfigResponse response = client.deleteTaskPushNotificationConfig(new DeleteTaskPushNotificationConfigParams("task-1234", "config-4567", metadata)); +``` + #### Send a streaming message ```java @@ -372,5 +387,13 @@ This project is licensed under the terms of the [Apache 2.0 License](LICENSE). See [CONTRIBUTING.md](CONTRIBUTING.md) for contribution guidelines. +## Server Integrations +The following list contains community contributed integrations with various Java Runtimes. + +To contribute an integration, please see [CONTRIBUTING_INTEGRATIONS.md](CONTRIBUTING_INTEGRATIONS.md). + +* [reference-impl/README.md](reference-impl/README.md) - Reference implementation, based on Quarkus. +* https://github.com/wildfly-extras/a2a-java-sdk-server-jakarta - This integration is based on Jakarta EE, and should work in all runtimes supporting the [Jakarta EE Web Profile](https://jakarta.ee/specifications/webprofile/). + diff --git a/client/pom.xml b/client/pom.xml index 64f983e42..acae086f4 100644 --- a/client/pom.xml +++ b/client/pom.xml @@ -7,7 +7,7 @@ io.github.a2asdk a2a-java-sdk-parent - 0.2.3.Beta2-SNAPSHOT + 0.2.6.Beta1-SNAPSHOT a2a-java-sdk-client diff --git a/client/src/main/java/io/a2a/client/A2ACardResolver.java b/client/src/main/java/io/a2a/client/A2ACardResolver.java index 1266f7219..88d1e351f 100644 --- a/client/src/main/java/io/a2a/client/A2ACardResolver.java +++ b/client/src/main/java/io/a2a/client/A2ACardResolver.java @@ -3,6 +3,8 @@ import static io.a2a.util.Utils.unmarshalFrom; import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; import java.util.Map; import com.fasterxml.jackson.core.JsonProcessingException; @@ -18,14 +20,16 @@ public class A2ACardResolver { private final String url; private final Map authHeaders; - static String DEFAULT_AGENT_CARD_PATH = "/.well-known/agent.json"; + private static final String DEFAULT_AGENT_CARD_PATH = "/.well-known/agent.json"; + + private static final TypeReference AGENT_CARD_TYPE_REFERENCE = new TypeReference<>() {}; - static final TypeReference AGENT_CARD_TYPE_REFERENCE = new TypeReference<>() {}; /** * @param httpClient the http client to use * @param baseUrl the base URL for the agent whose agent card we want to retrieve + * @throws A2AClientError if the URL for the agent is invalid */ - public A2ACardResolver(A2AHttpClient httpClient, String baseUrl) { + public A2ACardResolver(A2AHttpClient httpClient, String baseUrl) throws A2AClientError { this(httpClient, baseUrl, null, null); } @@ -34,8 +38,9 @@ public A2ACardResolver(A2AHttpClient httpClient, String baseUrl) { * @param baseUrl the base URL for the agent whose agent card we want to retrieve * @param agentCardPath optional path to the agent card endpoint relative to the base * agent URL, defaults to ".well-known/agent.json" + * @throws A2AClientError if the URL for the agent is invalid */ - public A2ACardResolver(A2AHttpClient httpClient, String baseUrl, String agentCardPath) { + public A2ACardResolver(A2AHttpClient httpClient, String baseUrl, String agentCardPath) throws A2AClientError { this(httpClient, baseUrl, agentCardPath, null); } @@ -45,17 +50,17 @@ public A2ACardResolver(A2AHttpClient httpClient, String baseUrl, String agentCar * @param agentCardPath optional path to the agent card endpoint relative to the base * agent URL, defaults to ".well-known/agent.json" * @param authHeaders the HTTP authentication headers to use. May be {@code null} + * @throws A2AClientError if the URL for the agent is invalid */ - public A2ACardResolver(A2AHttpClient httpClient, String baseUrl, String agentCardPath, Map authHeaders) { + public A2ACardResolver(A2AHttpClient httpClient, String baseUrl, String agentCardPath, + Map authHeaders) throws A2AClientError { this.httpClient = httpClient; - if (!baseUrl.endsWith("/")) { - baseUrl += "/"; - } agentCardPath = agentCardPath == null || agentCardPath.isEmpty() ? DEFAULT_AGENT_CARD_PATH : agentCardPath; - if (agentCardPath.startsWith("/")) { - agentCardPath = agentCardPath.substring(1); + try { + this.url = new URI(baseUrl).resolve(agentCardPath).toString(); + } catch (URISyntaxException e) { + throw new A2AClientError("Invalid agent URL", e); } - this.url = baseUrl + agentCardPath; this.authHeaders = authHeaders; } diff --git a/client/src/main/java/io/a2a/client/A2AClient.java b/client/src/main/java/io/a2a/client/A2AClient.java index ea08baea4..80f8c401c 100644 --- a/client/src/main/java/io/a2a/client/A2AClient.java +++ b/client/src/main/java/io/a2a/client/A2AClient.java @@ -21,6 +21,10 @@ import io.a2a.spec.AgentCard; import io.a2a.spec.CancelTaskRequest; import io.a2a.spec.CancelTaskResponse; +import io.a2a.spec.DeleteTaskPushNotificationConfigParams; +import io.a2a.spec.DeleteTaskPushNotificationConfigRequest; +import io.a2a.spec.DeleteTaskPushNotificationConfigResponse; +import io.a2a.spec.GetTaskPushNotificationConfigParams; import io.a2a.spec.GetTaskPushNotificationConfigRequest; import io.a2a.spec.GetTaskPushNotificationConfigResponse; import io.a2a.spec.GetTaskRequest; @@ -28,6 +32,9 @@ import io.a2a.spec.JSONRPCError; import io.a2a.spec.JSONRPCMessage; import io.a2a.spec.JSONRPCResponse; +import io.a2a.spec.ListTaskPushNotificationConfigParams; +import io.a2a.spec.ListTaskPushNotificationConfigRequest; +import io.a2a.spec.ListTaskPushNotificationConfigResponse; import io.a2a.spec.MessageSendParams; import io.a2a.spec.PushNotificationConfig; import io.a2a.spec.SendMessageRequest; @@ -52,6 +59,8 @@ public class A2AClient { private static final TypeReference CANCEL_TASK_RESPONSE_REFERENCE = new TypeReference<>() {}; private static final TypeReference GET_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE = new TypeReference<>() {}; private static final TypeReference SET_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE = new TypeReference<>() {}; + private static final TypeReference LIST_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE = new TypeReference<>() {}; + private static final TypeReference DELETE_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE = new TypeReference<>() {}; private final A2AHttpClient httpClient; private final String agentUrl; private AgentCard agentCard; @@ -164,7 +173,7 @@ public SendMessageResponse sendMessage(String requestId, MessageSendParams messa String httpResponseBody = sendPostRequest(sendMessageRequest); return unmarshalResponse(httpResponseBody, SEND_MESSAGE_RESPONSE_REFERENCE); } catch (IOException | InterruptedException e) { - throw new A2AServerException("Failed to send message: " + e); + throw new A2AServerException("Failed to send message: " + e, e.getCause()); } } @@ -216,7 +225,7 @@ public GetTaskResponse getTask(String requestId, TaskQueryParams taskQueryParams String httpResponseBody = sendPostRequest(getTaskRequest); return unmarshalResponse(httpResponseBody, GET_TASK_RESPONSE_REFERENCE); } catch (IOException | InterruptedException e) { - throw new A2AServerException("Failed to get task: " + e); + throw new A2AServerException("Failed to get task: " + e, e.getCause()); } } @@ -266,45 +275,57 @@ public CancelTaskResponse cancelTask(String requestId, TaskIdParams taskIdParams String httpResponseBody = sendPostRequest(cancelTaskRequest); return unmarshalResponse(httpResponseBody, CANCEL_TASK_RESPONSE_REFERENCE); } catch (IOException | InterruptedException e) { - throw new A2AServerException("Failed to cancel task: " + e); + throw new A2AServerException("Failed to cancel task: " + e, e.getCause()); } } /** * Get the push notification configuration for a task. * - * @param id the task ID + * @param taskId the task ID + * @return the response containing the push notification configuration + * @throws A2AServerException if getting the push notification configuration fails for any reason + */ + public GetTaskPushNotificationConfigResponse getTaskPushNotificationConfig(String taskId) throws A2AServerException { + return getTaskPushNotificationConfig(null, new GetTaskPushNotificationConfigParams(taskId)); + } + + /** + * Get the push notification configuration for a task. + * + * @param taskId the task ID + * @param pushNotificationConfigId the push notification configuration ID * @return the response containing the push notification configuration * @throws A2AServerException if getting the push notification configuration fails for any reason */ - public GetTaskPushNotificationConfigResponse getTaskPushNotificationConfig(String id) throws A2AServerException { - return getTaskPushNotificationConfig(null, new TaskIdParams(id)); + public GetTaskPushNotificationConfigResponse getTaskPushNotificationConfig(String taskId, String pushNotificationConfigId) throws A2AServerException { + return getTaskPushNotificationConfig(null, new GetTaskPushNotificationConfigParams(taskId, pushNotificationConfigId)); } /** * Get the push notification configuration for a task. * - * @param taskIdParams the params for the task + * @param getTaskPushNotificationConfigParams the params for the task * @return the response containing the push notification configuration * @throws A2AServerException if getting the push notification configuration fails for any reason */ - public GetTaskPushNotificationConfigResponse getTaskPushNotificationConfig(TaskIdParams taskIdParams) throws A2AServerException { - return getTaskPushNotificationConfig(null, taskIdParams); + public GetTaskPushNotificationConfigResponse getTaskPushNotificationConfig(GetTaskPushNotificationConfigParams getTaskPushNotificationConfigParams) throws A2AServerException { + return getTaskPushNotificationConfig(null, getTaskPushNotificationConfigParams); } /** * Get the push notification configuration for a task. * * @param requestId the request ID to use - * @param taskIdParams the params for the task + * @param getTaskPushNotificationConfigParams the params for the task * @return the response containing the push notification configuration * @throws A2AServerException if getting the push notification configuration fails for any reason */ - public GetTaskPushNotificationConfigResponse getTaskPushNotificationConfig(String requestId, TaskIdParams taskIdParams) throws A2AServerException { + public GetTaskPushNotificationConfigResponse getTaskPushNotificationConfig(String requestId, GetTaskPushNotificationConfigParams getTaskPushNotificationConfigParams) throws A2AServerException { GetTaskPushNotificationConfigRequest.Builder getTaskPushNotificationRequestBuilder = new GetTaskPushNotificationConfigRequest.Builder() .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) .method(GetTaskPushNotificationConfigRequest.METHOD) - .params(taskIdParams); + .params(getTaskPushNotificationConfigParams); if (requestId != null) { getTaskPushNotificationRequestBuilder.id(requestId); @@ -316,7 +337,7 @@ public GetTaskPushNotificationConfigResponse getTaskPushNotificationConfig(Strin String httpResponseBody = sendPostRequest(getTaskPushNotificationRequest); return unmarshalResponse(httpResponseBody, GET_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE); } catch (IOException | InterruptedException e) { - throw new A2AServerException("Failed to get task push notification config: " + e); + throw new A2AServerException("Failed to get task push notification config: " + e, e.getCause()); } } @@ -359,7 +380,137 @@ public SetTaskPushNotificationConfigResponse setTaskPushNotificationConfig(Strin String httpResponseBody = sendPostRequest(setTaskPushNotificationRequest); return unmarshalResponse(httpResponseBody, SET_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE); } catch (IOException | InterruptedException e) { - throw new A2AServerException("Failed to set task push notification config: " + e); + throw new A2AServerException("Failed to set task push notification config: " + e, e.getCause()); + } + } + + /** + * Retrieves the push notification configurations for a specified task. + * + * @param requestId the request ID to use + * @param taskId the task ID to use + * @return the response containing the push notification configuration + * @throws A2AServerException if getting the push notification configuration fails for any reason + */ + public ListTaskPushNotificationConfigResponse listTaskPushNotificationConfig(String requestId, String taskId) throws A2AServerException { + return listTaskPushNotificationConfig(requestId, new ListTaskPushNotificationConfigParams(taskId)); + } + + /** + * Retrieves the push notification configurations for a specified task. + * + * @param taskId the task ID to use + * @return the response containing the push notification configuration + * @throws A2AServerException if getting the push notification configuration fails for any reason + */ + public ListTaskPushNotificationConfigResponse listTaskPushNotificationConfig(String taskId) throws A2AServerException { + return listTaskPushNotificationConfig(null, new ListTaskPushNotificationConfigParams(taskId)); + } + + /** + * Retrieves the push notification configurations for a specified task. + * + * @param listTaskPushNotificationConfigParams the params for retrieving the push notification configuration + * @return the response containing the push notification configuration + * @throws A2AServerException if getting the push notification configuration fails for any reason + */ + public ListTaskPushNotificationConfigResponse listTaskPushNotificationConfig(ListTaskPushNotificationConfigParams listTaskPushNotificationConfigParams) throws A2AServerException { + return listTaskPushNotificationConfig(null, listTaskPushNotificationConfigParams); + } + + /** + * Retrieves the push notification configurations for a specified task. + * + * @param requestId the request ID to use + * @param listTaskPushNotificationConfigParams the params for retrieving the push notification configuration + * @return the response containing the push notification configuration + * @throws A2AServerException if getting the push notification configuration fails for any reason + */ + public ListTaskPushNotificationConfigResponse listTaskPushNotificationConfig(String requestId, + ListTaskPushNotificationConfigParams listTaskPushNotificationConfigParams) throws A2AServerException { + ListTaskPushNotificationConfigRequest.Builder listTaskPushNotificationRequestBuilder = new ListTaskPushNotificationConfigRequest.Builder() + .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) + .method(ListTaskPushNotificationConfigRequest.METHOD) + .params(listTaskPushNotificationConfigParams); + + if (requestId != null) { + listTaskPushNotificationRequestBuilder.id(requestId); + } + + ListTaskPushNotificationConfigRequest listTaskPushNotificationRequest = listTaskPushNotificationRequestBuilder.build(); + + try { + String httpResponseBody = sendPostRequest(listTaskPushNotificationRequest); + return unmarshalResponse(httpResponseBody, LIST_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE); + } catch (IOException | InterruptedException e) { + throw new A2AServerException("Failed to list task push notification config: " + e, e.getCause()); + } + } + + /** + * Delete the push notification configuration for a specified task. + * + * @param requestId the request ID to use + * @param taskId the task ID + * @param pushNotificationConfigId the push notification config ID + * @return the response + * @throws A2AServerException if deleting the push notification configuration fails for any reason + */ + public DeleteTaskPushNotificationConfigResponse deleteTaskPushNotificationConfig(String requestId, String taskId, + String pushNotificationConfigId) throws A2AServerException { + return deleteTaskPushNotificationConfig(requestId, new DeleteTaskPushNotificationConfigParams(taskId, pushNotificationConfigId)); + } + + /** + * Delete the push notification configuration for a specified task. + * + * @param taskId the task ID + * @param pushNotificationConfigId the push notification config ID + * @return the response + * @throws A2AServerException if deleting the push notification configuration fails for any reason + */ + public DeleteTaskPushNotificationConfigResponse deleteTaskPushNotificationConfig(String taskId, + String pushNotificationConfigId) throws A2AServerException { + return deleteTaskPushNotificationConfig(null, new DeleteTaskPushNotificationConfigParams(taskId, pushNotificationConfigId)); + } + + /** + * Delete the push notification configuration for a specified task. + * + * @param deleteTaskPushNotificationConfigParams the params for deleting the push notification configuration + * @return the response + * @throws A2AServerException if deleting the push notification configuration fails for any reason + */ + public DeleteTaskPushNotificationConfigResponse deleteTaskPushNotificationConfig(DeleteTaskPushNotificationConfigParams deleteTaskPushNotificationConfigParams) throws A2AServerException { + return deleteTaskPushNotificationConfig(null, deleteTaskPushNotificationConfigParams); + } + + /** + * Delete the push notification configuration for a specified task. + * + * @param requestId the request ID to use + * @param deleteTaskPushNotificationConfigParams the params for deleting the push notification configuration + * @return the response + * @throws A2AServerException if deleting the push notification configuration fails for any reason + */ + public DeleteTaskPushNotificationConfigResponse deleteTaskPushNotificationConfig(String requestId, + DeleteTaskPushNotificationConfigParams deleteTaskPushNotificationConfigParams) throws A2AServerException { + DeleteTaskPushNotificationConfigRequest.Builder deleteTaskPushNotificationRequestBuilder = new DeleteTaskPushNotificationConfigRequest.Builder() + .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) + .method(DeleteTaskPushNotificationConfigRequest.METHOD) + .params(deleteTaskPushNotificationConfigParams); + + if (requestId != null) { + deleteTaskPushNotificationRequestBuilder.id(requestId); + } + + DeleteTaskPushNotificationConfigRequest deleteTaskPushNotificationRequest = deleteTaskPushNotificationRequestBuilder.build(); + + try { + String httpResponseBody = sendPostRequest(deleteTaskPushNotificationRequest); + return unmarshalResponse(httpResponseBody, DELETE_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE); + } catch (IOException | InterruptedException e) { + throw new A2AServerException("Failed to delete task push notification config: " + e, e.getCause()); } } @@ -416,9 +567,9 @@ public void sendStreamingMessage(String requestId, MessageSendParams messageSend })); } catch (IOException e) { - throw new A2AServerException("Failed to send streaming message request: " + e); + throw new A2AServerException("Failed to send streaming message request: " + e, e.getCause()); } catch (InterruptedException e) { - throw new A2AServerException("Send streaming message request timed out: " + e); + throw new A2AServerException("Send streaming message request timed out: " + e, e.getCause()); } } @@ -475,9 +626,9 @@ public void resubscribeToTask(String requestId, TaskIdParams taskIdParams, Consu })); } catch (IOException e) { - throw new A2AServerException("Failed to send task resubscription request: " + e); + throw new A2AServerException("Failed to send task resubscription request: " + e, e.getCause()); } catch (InterruptedException e) { - throw new A2AServerException("Task resubscription request timed out: " + e); + throw new A2AServerException("Task resubscription request timed out: " + e, e.getCause()); } } @@ -503,7 +654,7 @@ private T unmarshalResponse(String response, TypeRef T value = Utils.unmarshalFrom(response, typeReference); JSONRPCError error = value.getError(); if (error != null) { - throw new A2AServerException(error.getMessage() + (error.getData() != null ? ": " + error.getData() : "")); + throw new A2AServerException(error.getMessage() + (error.getData() != null ? ": " + error.getData() : ""), error); } return value; } diff --git a/client/src/main/java/io/a2a/http/JdkA2AHttpClient.java b/client/src/main/java/io/a2a/http/JdkA2AHttpClient.java index e3b5c0c66..c3d5907a2 100644 --- a/client/src/main/java/io/a2a/http/JdkA2AHttpClient.java +++ b/client/src/main/java/io/a2a/http/JdkA2AHttpClient.java @@ -141,7 +141,7 @@ public CompletableFuture getAsyncSSE( Consumer messageConsumer, Consumer errorConsumer, Runnable completeRunnable) throws IOException, InterruptedException { - HttpRequest request = createRequestBuilder(false) + HttpRequest request = createRequestBuilder(true) .build(); return super.asyncRequest(request, messageConsumer, errorConsumer, completeRunnable); } @@ -180,7 +180,7 @@ public CompletableFuture postAsyncSSE( Consumer messageConsumer, Consumer errorConsumer, Runnable completeRunnable) throws IOException, InterruptedException { - HttpRequest request = createRequestBuilder(false) + HttpRequest request = createRequestBuilder(true) .build(); return super.asyncRequest(request, messageConsumer, errorConsumer, completeRunnable); } diff --git a/client/src/test/java/io/a2a/client/A2ACardResolverTest.java b/client/src/test/java/io/a2a/client/A2ACardResolverTest.java index 8265b9514..8d9ff0f5b 100644 --- a/client/src/test/java/io/a2a/client/A2ACardResolverTest.java +++ b/client/src/test/java/io/a2a/client/A2ACardResolverTest.java @@ -1,6 +1,5 @@ package io.a2a.client; -import static io.a2a.client.A2ACardResolver.AGENT_CARD_TYPE_REFERENCE; import static io.a2a.util.Utils.OBJECT_MAPPER; import static io.a2a.util.Utils.unmarshalFrom; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -11,6 +10,7 @@ import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; +import com.fasterxml.jackson.core.type.TypeReference; import io.a2a.http.A2AHttpClient; import io.a2a.http.A2AHttpResponse; import io.a2a.spec.A2AClientError; @@ -19,6 +19,10 @@ import org.junit.jupiter.api.Test; public class A2ACardResolverTest { + + private static final String AGENT_CARD_PATH = "/.well-known/agent.json"; + private static final TypeReference AGENT_CARD_TYPE_REFERENCE = new TypeReference<>() {}; + @Test public void testConstructorStripsSlashes() throws Exception { TestHttpClient client = new TestHttpClient(); @@ -27,33 +31,37 @@ public void testConstructorStripsSlashes() throws Exception { A2ACardResolver resolver = new A2ACardResolver(client, "http://example.com/"); AgentCard card = resolver.getAgentCard(); - assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, client.url); + assertEquals("http://example.com" + AGENT_CARD_PATH, client.url); resolver = new A2ACardResolver(client, "http://example.com"); card = resolver.getAgentCard(); - assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, client.url); + assertEquals("http://example.com" + AGENT_CARD_PATH, client.url); - resolver = new A2ACardResolver(client, "http://example.com/", A2ACardResolver.DEFAULT_AGENT_CARD_PATH); + // baseUrl with trailing slash, agentCardParth with leading slash + resolver = new A2ACardResolver(client, "http://example.com/", AGENT_CARD_PATH); card = resolver.getAgentCard(); - assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, client.url); + assertEquals("http://example.com" + AGENT_CARD_PATH, client.url); - resolver = new A2ACardResolver(client, "http://example.com", A2ACardResolver.DEFAULT_AGENT_CARD_PATH); + // baseUrl without trailing slash, agentCardPath with leading slash + resolver = new A2ACardResolver(client, "http://example.com", AGENT_CARD_PATH); card = resolver.getAgentCard(); - assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, client.url); + assertEquals("http://example.com" + AGENT_CARD_PATH, client.url); - resolver = new A2ACardResolver(client, "http://example.com/", A2ACardResolver.DEFAULT_AGENT_CARD_PATH.substring(0)); + // baseUrl with trailing slash, agentCardPath without leading slash + resolver = new A2ACardResolver(client, "http://example.com/", AGENT_CARD_PATH.substring(1)); card = resolver.getAgentCard(); - assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, client.url); + assertEquals("http://example.com" + AGENT_CARD_PATH, client.url); - resolver = new A2ACardResolver(client, "http://example.com", A2ACardResolver.DEFAULT_AGENT_CARD_PATH.substring(0)); + // baseUrl without trailing slash, agentCardPath without leading slash + resolver = new A2ACardResolver(client, "http://example.com", AGENT_CARD_PATH.substring(1)); card = resolver.getAgentCard(); - assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, client.url); + assertEquals("http://example.com" + AGENT_CARD_PATH, client.url); } diff --git a/client/src/test/java/io/a2a/client/A2AClientTest.java b/client/src/test/java/io/a2a/client/A2AClientTest.java index e99734ff1..eda03e02a 100644 --- a/client/src/test/java/io/a2a/client/A2AClientTest.java +++ b/client/src/test/java/io/a2a/client/A2AClientTest.java @@ -46,6 +46,7 @@ import io.a2a.spec.FilePart; import io.a2a.spec.FileWithBytes; import io.a2a.spec.FileWithUri; +import io.a2a.spec.GetTaskPushNotificationConfigParams; import io.a2a.spec.GetTaskPushNotificationConfigResponse; import io.a2a.spec.GetTaskResponse; import io.a2a.spec.Message; @@ -328,7 +329,7 @@ public void testA2AClientGetTaskPushNotificationConfig() throws Exception { A2AClient client = new A2AClient("http://localhost:4001"); GetTaskPushNotificationConfigResponse response = client.getTaskPushNotificationConfig("1", - new TaskIdParams("de38c76d-d54c-436c-8b9f-4c2703648d64", new HashMap<>())); + new GetTaskPushNotificationConfigParams("de38c76d-d54c-436c-8b9f-4c2703648d64", null, new HashMap<>())); assertEquals("2.0", response.getJsonrpc()); assertEquals(1, response.getId()); assertInstanceOf(TaskPushNotificationConfig.class, response.getResult()); @@ -442,6 +443,7 @@ public void testA2AClientGetAgentCard() throws Exception { assertEquals(outputModes, skills.get(1).outputModes()); assertTrue(agentCard.supportsAuthenticatedExtendedCard()); assertEquals("https://georoute-agent.example.com/icon.png", agentCard.iconUrl()); + assertEquals("0.2.5", agentCard.protocolVersion()); } @Test @@ -514,6 +516,7 @@ public void testA2AClientGetAuthenticatedExtendedAgentCard() throws Exception { assertEquals(List.of("extended"), skills.get(2).tags()); assertTrue(agentCard.supportsAuthenticatedExtendedCard()); assertEquals("https://georoute-agent.example.com/icon.png", agentCard.iconUrl()); + assertEquals("0.2.5", agentCard.protocolVersion()); } @Test diff --git a/client/src/test/java/io/a2a/client/JsonMessages.java b/client/src/test/java/io/a2a/client/JsonMessages.java index c7ebd7780..fecf216d0 100644 --- a/client/src/test/java/io/a2a/client/JsonMessages.java +++ b/client/src/test/java/io/a2a/client/JsonMessages.java @@ -67,7 +67,8 @@ public class JsonMessages { ] } ], - "supportsAuthenticatedExtendedCard": true + "supportsAuthenticatedExtendedCard": true, + "protocolVersion": "0.2.5" }"""; static final String AUTHENTICATION_EXTENDED_AGENT_CARD = """ @@ -137,7 +138,8 @@ public class JsonMessages { "tags": ["extended"] } ], - "supportsAuthenticatedExtendedCard": true + "supportsAuthenticatedExtendedCard": true, + "protocolVersion": "0.2.5" }"""; diff --git a/common/pom.xml b/common/pom.xml index 9b3d81137..e4c01920e 100644 --- a/common/pom.xml +++ b/common/pom.xml @@ -7,7 +7,7 @@ io.github.a2asdk a2a-java-sdk-parent - 0.2.3.Beta2-SNAPSHOT + 0.2.6.Beta1-SNAPSHOT a2a-java-sdk-common diff --git a/examples/helloworld/client/pom.xml b/examples/helloworld/client/pom.xml index 26f701934..8edb46cb6 100644 --- a/examples/helloworld/client/pom.xml +++ b/examples/helloworld/client/pom.xml @@ -7,7 +7,7 @@ io.github.a2asdk a2a-java-sdk-examples-parent - 0.2.3.Beta2-SNAPSHOT + 0.2.6.Beta1-SNAPSHOT a2a-java-sdk-examples-client diff --git a/examples/helloworld/client/src/main/java/io/a2a/examples/helloworld/HelloWorldRunner.java b/examples/helloworld/client/src/main/java/io/a2a/examples/helloworld/HelloWorldRunner.java index 8583ef27a..144972238 100644 --- a/examples/helloworld/client/src/main/java/io/a2a/examples/helloworld/HelloWorldRunner.java +++ b/examples/helloworld/client/src/main/java/io/a2a/examples/helloworld/HelloWorldRunner.java @@ -1,5 +1,5 @@ ///usr/bin/env jbang "$0" "$@" ; exit $? -//DEPS io.github.a2asdk:a2a-java-sdk-client:0.2.3.Beta2-SNAPSHOT +//DEPS io.github.a2asdk:a2a-java-sdk-client:0.2.6.Beta1-SNAPSHOT //SOURCES HelloWorldClient.java /** @@ -19,4 +19,4 @@ public class HelloWorldRunner { public static void main(String[] args) { io.a2a.examples.helloworld.HelloWorldClient.main(args); } -} \ No newline at end of file +} diff --git a/examples/helloworld/pom.xml b/examples/helloworld/pom.xml index 7c9cf8928..7900d3cef 100644 --- a/examples/helloworld/pom.xml +++ b/examples/helloworld/pom.xml @@ -7,7 +7,7 @@ io.github.a2asdk a2a-java-sdk-parent - 0.2.3.Beta2-SNAPSHOT + 0.2.6.Beta1-SNAPSHOT ../../pom.xml @@ -33,7 +33,7 @@ io.github.a2asdk - a2a-java-sdk-server-quarkus + a2a-java-reference-server ${project.version} diff --git a/examples/helloworld/server/pom.xml b/examples/helloworld/server/pom.xml index 9fce226f5..ee7c5324b 100644 --- a/examples/helloworld/server/pom.xml +++ b/examples/helloworld/server/pom.xml @@ -7,7 +7,7 @@ io.github.a2asdk a2a-java-sdk-examples-parent - 0.2.3.Beta2-SNAPSHOT + 0.2.6.Beta1-SNAPSHOT a2a-java-sdk-examples-server @@ -18,7 +18,7 @@ io.github.a2asdk - a2a-java-sdk-server-quarkus + a2a-java-reference-server io.quarkus diff --git a/examples/helloworld/server/src/main/java/io/a2a/examples/helloworld/AgentCardProducer.java b/examples/helloworld/server/src/main/java/io/a2a/examples/helloworld/AgentCardProducer.java index f07a6e527..06450c935 100644 --- a/examples/helloworld/server/src/main/java/io/a2a/examples/helloworld/AgentCardProducer.java +++ b/examples/helloworld/server/src/main/java/io/a2a/examples/helloworld/AgentCardProducer.java @@ -37,6 +37,7 @@ public AgentCard agentCard() { .tags(Collections.singletonList("hello world")) .examples(List.of("hi", "hello world")) .build())) + .protocolVersion("0.2.5") .build(); } } diff --git a/pom.xml b/pom.xml index 5508c0cba..e34d09477 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ io.github.a2asdk a2a-java-sdk-parent - 0.2.3.Beta2-SNAPSHOT + 0.2.6.Beta1-SNAPSHOT pom @@ -62,6 +62,9 @@ true + + + false @@ -206,6 +209,7 @@ central-a2asdk-temp + ${release.auto.publish} @@ -271,8 +275,7 @@ common spec client - sdk-jakarta - sdk-quarkus + reference-impl tck examples/helloworld tests/server-common diff --git a/reference-impl/README.md b/reference-impl/README.md new file mode 100644 index 000000000..2a7f0f902 --- /dev/null +++ b/reference-impl/README.md @@ -0,0 +1,7 @@ +# A2A Java SDK Reference Server Integration + +This is a reference server for the A2A SDK for Java, that we use to run tests, as well as to demonstrate examples. + +It is based on [Quarkus](https://quarkus.io), and makes use of Quarkus's [Reactive Routes](https://quarkus.io/guides/reactive-routes). + +It is a great choice if you use Quarkus! \ No newline at end of file diff --git a/sdk-quarkus/pom.xml b/reference-impl/pom.xml similarity index 91% rename from sdk-quarkus/pom.xml rename to reference-impl/pom.xml index daa9418f2..b128aa3f0 100644 --- a/sdk-quarkus/pom.xml +++ b/reference-impl/pom.xml @@ -7,14 +7,14 @@ io.github.a2asdk a2a-java-sdk-parent - 0.2.3.Beta2-SNAPSHOT + 0.2.6.Beta1-SNAPSHOT - a2a-java-sdk-server-quarkus + a2a-java-reference-server jar - Java A2A SDK for Quarkus - Java SDK for the Agent2Agent Protocol (A2A) - SDK - Quarkus + Java A2A Reference Server + Java SDK for the Agent2Agent Protocol (A2A) - A2A Reference Server (based on Quarkus) diff --git a/sdk-quarkus/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java b/reference-impl/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java similarity index 76% rename from sdk-quarkus/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java rename to reference-impl/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java index 7c45c0f9f..988736c91 100644 --- a/sdk-quarkus/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java +++ b/reference-impl/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java @@ -3,6 +3,9 @@ import static io.vertx.core.http.HttpHeaders.CONTENT_TYPE; import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; import java.util.concurrent.Executor; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicLong; @@ -16,10 +19,16 @@ import com.fasterxml.jackson.core.JsonParseException; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.io.JsonEOFException; +import com.fasterxml.jackson.databind.JsonNode; import io.a2a.server.ExtendedAgentCard; +import io.a2a.server.ServerCallContext; +import io.a2a.server.auth.UnauthenticatedUser; +import io.a2a.server.auth.User; import io.a2a.server.requesthandlers.JSONRPCHandler; +import io.a2a.server.util.async.Internal; import io.a2a.spec.AgentCard; import io.a2a.spec.CancelTaskRequest; +import io.a2a.spec.DeleteTaskPushNotificationConfigRequest; import io.a2a.spec.GetTaskPushNotificationConfigRequest; import io.a2a.spec.GetTaskRequest; import io.a2a.spec.IdJsonMappingException; @@ -33,6 +42,7 @@ import io.a2a.spec.JSONRPCErrorResponse; import io.a2a.spec.JSONRPCRequest; import io.a2a.spec.JSONRPCResponse; +import io.a2a.spec.ListTaskPushNotificationConfigRequest; import io.a2a.spec.MethodNotFoundError; import io.a2a.spec.MethodNotFoundJsonMappingException; import io.a2a.spec.NonStreamingJSONRPCRequest; @@ -43,7 +53,6 @@ import io.a2a.spec.TaskResubscriptionRequest; import io.a2a.spec.UnsupportedOperationError; import io.a2a.util.Utils; -import io.a2a.server.util.async.Internal; import io.quarkus.vertx.web.Body; import io.quarkus.vertx.web.ReactiveRoutes; import io.quarkus.vertx.web.Route; @@ -68,15 +77,20 @@ public class A2AServerRoutes { Instance extendedAgentCard; // Hook so testing can wait until the MultiSseSupport is subscribed. + // Without this we get intermittent failures private static volatile Runnable streamingMultiSseSupportSubscribedRunnable; @Inject @Internal Executor executor; + @Inject + Instance callContextFactory; + @Route(path = "/", methods = {Route.HttpMethod.POST}, consumes = {APPLICATION_JSON}, type = Route.HandlerType.BLOCKING) public void invokeJSONRPCHandler(@Body String body, RoutingContext rc) { boolean streaming = false; + ServerCallContext context = createCallContext(rc); JSONRPCResponse nonStreamingResponse = null; Multi> streamingResponse = null; JSONRPCErrorResponse error = null; @@ -85,10 +99,10 @@ public void invokeJSONRPCHandler(@Body String body, RoutingContext rc) { if (isStreamingRequest(body)) { streaming = true; StreamingJSONRPCRequest request = Utils.OBJECT_MAPPER.readValue(body, StreamingJSONRPCRequest.class); - streamingResponse = processStreamingRequest(request); + streamingResponse = processStreamingRequest(request, context); } else { NonStreamingJSONRPCRequest request = Utils.OBJECT_MAPPER.readValue(body, NonStreamingJSONRPCRequest.class); - nonStreamingResponse = processNonStreamingRequest(request); + nonStreamingResponse = processNonStreamingRequest(request, context); } } catch (JsonProcessingException e) { error = handleError(e); @@ -179,28 +193,34 @@ public void getAuthenticatedExtendedAgentCard(RoutingExchange re) { } } - private JSONRPCResponse processNonStreamingRequest(NonStreamingJSONRPCRequest request) { - if (request instanceof GetTaskRequest) { - return jsonRpcHandler.onGetTask((GetTaskRequest) request); - } else if (request instanceof CancelTaskRequest) { - return jsonRpcHandler.onCancelTask((CancelTaskRequest) request); - } else if (request instanceof SetTaskPushNotificationConfigRequest) { - return jsonRpcHandler.setPushNotification((SetTaskPushNotificationConfigRequest) request); - } else if (request instanceof GetTaskPushNotificationConfigRequest) { - return jsonRpcHandler.getPushNotification((GetTaskPushNotificationConfigRequest) request); - } else if (request instanceof SendMessageRequest) { - return jsonRpcHandler.onMessageSend((SendMessageRequest) request); + private JSONRPCResponse processNonStreamingRequest( + NonStreamingJSONRPCRequest request, ServerCallContext context) { + if (request instanceof GetTaskRequest req) { + return jsonRpcHandler.onGetTask(req, context); + } else if (request instanceof CancelTaskRequest req) { + return jsonRpcHandler.onCancelTask(req, context); + } else if (request instanceof SetTaskPushNotificationConfigRequest req) { + return jsonRpcHandler.setPushNotificationConfig(req, context); + } else if (request instanceof GetTaskPushNotificationConfigRequest req) { + return jsonRpcHandler.getPushNotificationConfig(req, context); + } else if (request instanceof SendMessageRequest req) { + return jsonRpcHandler.onMessageSend(req, context); + } else if (request instanceof ListTaskPushNotificationConfigRequest req) { + return jsonRpcHandler.listPushNotificationConfig(req, context); + } else if (request instanceof DeleteTaskPushNotificationConfigRequest req) { + return jsonRpcHandler.deletePushNotificationConfig(req, context); } else { return generateErrorResponse(request, new UnsupportedOperationError()); } } - private Multi> processStreamingRequest(JSONRPCRequest request) { + private Multi> processStreamingRequest( + JSONRPCRequest request, ServerCallContext context) { Flow.Publisher> publisher; - if (request instanceof SendStreamingMessageRequest) { - publisher = jsonRpcHandler.onMessageSendStream((SendStreamingMessageRequest) request); - } else if (request instanceof TaskResubscriptionRequest) { - publisher = jsonRpcHandler.onResubscribeToTask((TaskResubscriptionRequest) request); + if (request instanceof SendStreamingMessageRequest req) { + publisher = jsonRpcHandler.onMessageSendStream(req, context); + } else if (request instanceof TaskResubscriptionRequest req) { + publisher = jsonRpcHandler.onResubscribeToTask(req, context); } else { return Multi.createFrom().item(generateErrorResponse(request, new UnsupportedOperationError())); } @@ -212,22 +232,56 @@ private JSONRPCResponse generateErrorResponse(JSONRPCRequest request, JSON } private static boolean isStreamingRequest(String requestBody) { - return requestBody.contains(SendStreamingMessageRequest.METHOD) || - requestBody.contains(TaskResubscriptionRequest.METHOD); - } - - private static boolean isNonStreamingRequest(String requestBody) { - return requestBody.contains(GetTaskRequest.METHOD) || - requestBody.contains(CancelTaskRequest.METHOD) || - requestBody.contains(SendMessageRequest.METHOD) || - requestBody.contains(SetTaskPushNotificationConfigRequest.METHOD) || - requestBody.contains(GetTaskPushNotificationConfigRequest.METHOD); + try { + JsonNode node = Utils.OBJECT_MAPPER.readTree(requestBody); + JsonNode method = node != null ? node.get("method") : null; + return method != null && (SendStreamingMessageRequest.METHOD.equals(method.asText()) + || TaskResubscriptionRequest.METHOD.equals(method.asText())); + } catch (Exception e) { + return false; + } } static void setStreamingMultiSseSupportSubscribedRunnable(Runnable runnable) { streamingMultiSseSupportSubscribedRunnable = runnable; } + private ServerCallContext createCallContext(RoutingContext rc) { + + if (callContextFactory.isUnsatisfied()) { + User user; + if (rc.user() == null) { + user = UnauthenticatedUser.INSTANCE; + } else { + user = new User() { + @Override + public boolean isAuthenticated() { + return rc.userContext().authenticated(); + } + + @Override + public String getUsername() { + return rc.user().subject(); + } + }; + } + Map state = new HashMap<>(); + // TODO Python's impl has + // state['auth'] = request.auth + // in jsonrpc_app.py. Figure out what this maps to in what Vert.X gives us + + Map headers = new HashMap<>(); + Set headerNames = rc.request().headers().names(); + headerNames.forEach(name -> headers.put(name, rc.request().getHeader(name))); + state.put("headers", headers); + + return new ServerCallContext(user, state); + } else { + CallContextFactory builder = callContextFactory.get(); + return builder.build(rc); + } + } + // Port of import io.quarkus.vertx.web.runtime.MultiSseSupport, which is considered internal API private static class MultiSseSupport { diff --git a/reference-impl/src/main/java/io/a2a/server/apps/quarkus/CallContextFactory.java b/reference-impl/src/main/java/io/a2a/server/apps/quarkus/CallContextFactory.java new file mode 100644 index 000000000..d40bc65f0 --- /dev/null +++ b/reference-impl/src/main/java/io/a2a/server/apps/quarkus/CallContextFactory.java @@ -0,0 +1,8 @@ +package io.a2a.server.apps.quarkus; + +import io.a2a.server.ServerCallContext; +import io.vertx.ext.web.RoutingContext; + +public interface CallContextFactory { + ServerCallContext build(RoutingContext rc); +} diff --git a/sdk-quarkus/src/main/resources/META-INF/beans.xml b/reference-impl/src/main/resources/META-INF/beans.xml similarity index 100% rename from sdk-quarkus/src/main/resources/META-INF/beans.xml rename to reference-impl/src/main/resources/META-INF/beans.xml diff --git a/reference-impl/src/test/java/io/a2a/server/apps/quarkus/A2ATestRoutes.java b/reference-impl/src/test/java/io/a2a/server/apps/quarkus/A2ATestRoutes.java new file mode 100644 index 000000000..5af126bf5 --- /dev/null +++ b/reference-impl/src/test/java/io/a2a/server/apps/quarkus/A2ATestRoutes.java @@ -0,0 +1,188 @@ +package io.a2a.server.apps.quarkus; + +import static io.vertx.core.http.HttpHeaders.CONTENT_TYPE; +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; +import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN; + +import java.util.concurrent.atomic.AtomicInteger; + +import jakarta.annotation.PostConstruct; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import io.a2a.server.apps.common.TestUtilsBean; +import io.a2a.spec.PushNotificationConfig; +import io.a2a.spec.Task; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskStatusUpdateEvent; +import io.a2a.util.Utils; +import io.quarkus.vertx.web.Body; +import io.quarkus.vertx.web.Param; +import io.quarkus.vertx.web.Route; +import io.vertx.ext.web.RoutingContext; + +/** + * Exposes the {@link TestUtilsBean} via REST using Quarkus Reactive Routes + */ +@Singleton +public class A2ATestRoutes { + @Inject + TestUtilsBean testUtilsBean; + + @Inject + A2AServerRoutes a2AServerRoutes; + + AtomicInteger streamingSubscribedCount = new AtomicInteger(0); + + @PostConstruct + public void init() { + A2AServerRoutes.setStreamingMultiSseSupportSubscribedRunnable(() -> streamingSubscribedCount.incrementAndGet()); + } + + + @Route(path = "/test/task", methods = {Route.HttpMethod.POST}, consumes = {APPLICATION_JSON}, type = Route.HandlerType.BLOCKING) + public void saveTask(@Body String body, RoutingContext rc) { + try { + Task task = Utils.OBJECT_MAPPER.readValue(body, Task.class); + testUtilsBean.saveTask(task); + rc.response() + .setStatusCode(200) + .end(); + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/task/:taskId", methods = {Route.HttpMethod.GET}, produces = {APPLICATION_JSON}, type = Route.HandlerType.BLOCKING) + public void getTask(@Param String taskId, RoutingContext rc) { + try { + Task task = testUtilsBean.getTask(taskId); + if (task == null) { + rc.response() + .setStatusCode(404) + .end(); + return; + } + rc.response() + .setStatusCode(200) + .putHeader(CONTENT_TYPE, APPLICATION_JSON) + .end(Utils.OBJECT_MAPPER.writeValueAsString(task)); + + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/task/:taskId", methods = {Route.HttpMethod.DELETE}, type = Route.HandlerType.BLOCKING) + public void deleteTask(@Param String taskId, RoutingContext rc) { + try { + Task task = testUtilsBean.getTask(taskId); + if (task == null) { + rc.response() + .setStatusCode(404) + .end(); + return; + } + testUtilsBean.deleteTask(taskId); + rc.response() + .setStatusCode(200) + .end(); + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/queue/ensure/:taskId", methods = {Route.HttpMethod.POST}) + public void ensureTaskQueue(@Param String taskId, RoutingContext rc) { + try { + testUtilsBean.ensureQueue(taskId); + rc.response() + .setStatusCode(200) + .end(); + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/queue/enqueueTaskStatusUpdateEvent/:taskId", methods = {Route.HttpMethod.POST}) + public void enqueueTaskStatusUpdateEvent(@Param String taskId, @Body String body, RoutingContext rc) { + + try { + TaskStatusUpdateEvent event = Utils.OBJECT_MAPPER.readValue(body, TaskStatusUpdateEvent.class); + testUtilsBean.enqueueEvent(taskId, event); + rc.response() + .setStatusCode(200) + .end(); + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/queue/enqueueTaskArtifactUpdateEvent/:taskId", methods = {Route.HttpMethod.POST}) + public void enqueueTaskArtifactUpdateEvent(@Param String taskId, @Body String body, RoutingContext rc) { + + try { + TaskArtifactUpdateEvent event = Utils.OBJECT_MAPPER.readValue(body, TaskArtifactUpdateEvent.class); + testUtilsBean.enqueueEvent(taskId, event); + rc.response() + .setStatusCode(200) + .end(); + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/streamingSubscribedCount", methods = {Route.HttpMethod.GET}, produces = {TEXT_PLAIN}) + public void getStreamingSubscribedCount(RoutingContext rc) { + rc.response() + .setStatusCode(200) + .end(String.valueOf(streamingSubscribedCount.get())); + } + + @Route(path = "/test/task/:taskId/config/:configId", methods = {Route.HttpMethod.DELETE}, type = Route.HandlerType.BLOCKING) + public void deleteTaskPushNotificationConfig(@Param String taskId, @Param String configId, RoutingContext rc) { + try { + Task task = testUtilsBean.getTask(taskId); + if (task == null) { + rc.response() + .setStatusCode(404) + .end(); + return; + } + testUtilsBean.deleteTaskPushNotificationConfig(taskId, configId); + rc.response() + .setStatusCode(200) + .end(); + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + @Route(path = "/test/task/:taskId", methods = {Route.HttpMethod.POST}, type = Route.HandlerType.BLOCKING) + public void saveTaskPushNotificationConfig(@Param String taskId, @Body String body, RoutingContext rc) { + try { + PushNotificationConfig notificationConfig = Utils.OBJECT_MAPPER.readValue(body, PushNotificationConfig.class); + if (notificationConfig == null) { + rc.response() + .setStatusCode(404) + .end(); + return; + } + testUtilsBean.saveTaskPushNotificationConfig(taskId, notificationConfig); + rc.response() + .setStatusCode(200) + .end(); + } catch (Throwable t) { + errorResponse(t, rc); + } + } + + private void errorResponse(Throwable t, RoutingContext rc) { + t.printStackTrace(); + rc.response() + .setStatusCode(500) + .putHeader(CONTENT_TYPE, TEXT_PLAIN) + .end(); + } + +} diff --git a/reference-impl/src/test/java/io/a2a/server/apps/quarkus/QuarkusA2AServerTest.java b/reference-impl/src/test/java/io/a2a/server/apps/quarkus/QuarkusA2AServerTest.java new file mode 100644 index 000000000..f9ed48643 --- /dev/null +++ b/reference-impl/src/test/java/io/a2a/server/apps/quarkus/QuarkusA2AServerTest.java @@ -0,0 +1,12 @@ +package io.a2a.server.apps.quarkus; + +import io.a2a.server.apps.common.AbstractA2AServerTest; +import io.quarkus.test.junit.QuarkusTest; + +@QuarkusTest +public class QuarkusA2AServerTest extends AbstractA2AServerTest { + + public QuarkusA2AServerTest() { + super(8081); + } +} diff --git a/sdk-quarkus/src/test/resources/application.properties b/reference-impl/src/test/resources/application.properties similarity index 100% rename from sdk-quarkus/src/test/resources/application.properties rename to reference-impl/src/test/resources/application.properties diff --git a/sdk-jakarta/pom.xml b/sdk-jakarta/pom.xml deleted file mode 100644 index 394956c24..000000000 --- a/sdk-jakarta/pom.xml +++ /dev/null @@ -1,242 +0,0 @@ - - - 4.0.0 - - - io.github.a2asdk - a2a-java-sdk-parent - 0.2.3.Beta2-SNAPSHOT - - a2a-java-sdk-server-jakarta - - jar - - Java A2A SDK for Jakarta - Java SDK for the Agent2Agent Protocol (A2A) - SDK - Jakarta - - 5.1.0.Beta11 - 1.10.0.Final - - 1.2.6 - 10.0.0.Final - 3.3.4 - ${project.build.directory}${file.separator}wildfly - - 8787 - - - - - org.jboss.shrinkwrap - shrinkwrap-bom - ${version.org.jboss.shrinkwrap.shrinkwrap} - pom - import - - - org.jboss.arquillian - arquillian-bom - ${version.org.jboss.arquillian} - pom - import - - - org.wildfly.arquillian - wildfly-arquillian-bom - ${version.org.wildfly.arquillian} - pom - import - - - - - - ${project.groupId} - a2a-java-sdk-spec - ${project.version} - - - ${project.groupId} - a2a-java-sdk-server-common - ${project.version} - - - ${project.groupId} - a2a-java-sdk-tests-server-common - ${project.version} - provided - - - ${project.groupId} - a2a-java-sdk-tests-server-common - test-jar - test - ${project.version} - - - com.fasterxml.jackson.core - jackson-databind - provided - - - jakarta.enterprise - jakarta.enterprise.cdi-api - provided - - - jakarta.inject - jakarta.inject-api - provided - - - jakarta.json - jakarta.json-api - provided - - - jakarta.ws.rs - jakarta.ws.rs-api - provided - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - test - - - io.rest-assured - rest-assured - test - - - org.jboss.arquillian.junit5 - arquillian-junit5-container - test - - - org.wildfly.arquillian - wildfly-arquillian-container-managed - test - - - org.jboss.shrinkwrap - shrinkwrap-api - test - - - org.junit.jupiter - junit-jupiter - test - - - org.jboss.threads - jboss-threads - 3.9.1 - test - - - org.hamcrest - hamcrest - test - - - - - - org.wildfly.glow - wildfly-glow-arquillian-plugin - 1.4.1.Final - - - - org.wildfly - wildfly-galleon-pack - 36.0.1.Final - - - standalone.xml - - - - scan - - scan - - test-compile - - - - - org.wildfly.plugins - wildfly-maven-plugin - 5.1.3.Final - - ${project.build.directory}/glow-scan/provisioning.xml - ${jboss.home} - ${jboss.home} - - - - - - - - - - - test-provisioning - - package - - test-compile - - - - - org.apache.maven.plugins - maven-surefire-plugin - 3.5.3 - - - ${jboss.home} - arquillian.xml - ${arquillian.java.vm.args} - - - - - org.apache.maven.plugins - maven-dependency-plugin - 3.8.1 - - ${project.build.directory}/lib - - - - copy - generate-test-resources - - copy-dependencies - - - test - provided - - - - - - - - - debug.profile - debug - - -agentlib:jdwp=transport=dt_socket,address=*:${server.debug.port},server=y,suspend=y - - - - \ No newline at end of file diff --git a/sdk-jakarta/src/main/java/io/a2a/server/apps/jakarta/A2ARequestFilter.java b/sdk-jakarta/src/main/java/io/a2a/server/apps/jakarta/A2ARequestFilter.java deleted file mode 100644 index 32e47e783..000000000 --- a/sdk-jakarta/src/main/java/io/a2a/server/apps/jakarta/A2ARequestFilter.java +++ /dev/null @@ -1,68 +0,0 @@ -package io.a2a.server.apps.jakarta; - -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; - -import io.a2a.spec.CancelTaskRequest; -import io.a2a.spec.GetTaskPushNotificationConfigRequest; -import io.a2a.spec.GetTaskRequest; -import io.a2a.spec.SendMessageRequest; -import io.a2a.spec.SendStreamingMessageRequest; -import io.a2a.spec.SetTaskPushNotificationConfigRequest; -import io.a2a.spec.TaskResubscriptionRequest; -import jakarta.ws.rs.container.ContainerRequestContext; -import jakarta.ws.rs.container.ContainerRequestFilter; -import jakarta.ws.rs.container.PreMatching; -import jakarta.ws.rs.core.MediaType; -import jakarta.ws.rs.ext.Provider; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -@Provider -@PreMatching -public class A2ARequestFilter implements ContainerRequestFilter { - private static final Logger LOGGER = LoggerFactory.getLogger(A2ARequestFilter.class); - - @Override - public void filter(ContainerRequestContext requestContext) { - if (requestContext.getMethod().equals("POST") && requestContext.hasEntity()) { - try (InputStream entityInputStream = requestContext.getEntityStream()) { - byte[] requestBodyBytes = entityInputStream.readAllBytes(); - String requestBody = new String(requestBodyBytes); - // ensure the request is treated as a streaming request or a non-streaming request - // based on the method in the request body - if (isStreamingRequest(requestBody)) { - LOGGER.debug("Handling request as streaming: {}", requestBody); - putAcceptHeader(requestContext, MediaType.SERVER_SENT_EVENTS); - } else if (isNonStreamingRequest(requestBody)) { - LOGGER.debug("Handling request as non-streaming: {}", requestBody); - putAcceptHeader(requestContext, MediaType.APPLICATION_JSON); - } - // reset the entity stream - requestContext.setEntityStream(new ByteArrayInputStream(requestBodyBytes)); - } catch(IOException e){ - throw new RuntimeException("Unable to read the request body"); - } - } - } - - private static boolean isStreamingRequest(String requestBody) { - return requestBody.contains(SendStreamingMessageRequest.METHOD) || - requestBody.contains(TaskResubscriptionRequest.METHOD); - } - - private static boolean isNonStreamingRequest(String requestBody) { - return requestBody.contains(GetTaskRequest.METHOD) || - requestBody.contains(CancelTaskRequest.METHOD) || - requestBody.contains(SendMessageRequest.METHOD) || - requestBody.contains(SetTaskPushNotificationConfigRequest.METHOD) || - requestBody.contains(GetTaskPushNotificationConfigRequest.METHOD); - } - - private static void putAcceptHeader(ContainerRequestContext requestContext, String mediaType) { - requestContext.getHeaders().putSingle("Accept", mediaType); - } - -} \ No newline at end of file diff --git a/sdk-jakarta/src/main/java/io/a2a/server/apps/jakarta/A2AServerResource.java b/sdk-jakarta/src/main/java/io/a2a/server/apps/jakarta/A2AServerResource.java deleted file mode 100644 index 1e4db8c39..000000000 --- a/sdk-jakarta/src/main/java/io/a2a/server/apps/jakarta/A2AServerResource.java +++ /dev/null @@ -1,257 +0,0 @@ -package io.a2a.server.apps.jakarta; - -import java.util.concurrent.Executor; -import java.util.concurrent.Flow; - -import jakarta.enterprise.inject.Instance; -import jakarta.inject.Inject; -import jakarta.ws.rs.Consumes; -import jakarta.ws.rs.GET; -import jakarta.ws.rs.POST; -import jakarta.ws.rs.Path; -import jakarta.ws.rs.Produces; -import jakarta.ws.rs.core.Context; -import jakarta.ws.rs.core.MediaType; -import jakarta.ws.rs.core.Response; -import jakarta.ws.rs.ext.ExceptionMapper; -import jakarta.ws.rs.ext.Provider; -import jakarta.ws.rs.sse.Sse; -import jakarta.ws.rs.sse.SseEventSink; - -import com.fasterxml.jackson.core.JsonParseException; -import com.fasterxml.jackson.databind.JsonMappingException; -import io.a2a.server.ExtendedAgentCard; -import io.a2a.server.requesthandlers.JSONRPCHandler; -import io.a2a.server.util.async.Internal; -import io.a2a.spec.AgentCard; -import io.a2a.spec.CancelTaskRequest; -import io.a2a.spec.GetTaskPushNotificationConfigRequest; -import io.a2a.spec.GetTaskRequest; -import io.a2a.spec.IdJsonMappingException; -import io.a2a.spec.InvalidParamsError; -import io.a2a.spec.InvalidParamsJsonMappingException; -import io.a2a.spec.InvalidRequestError; -import io.a2a.spec.JSONErrorResponse; -import io.a2a.spec.JSONParseError; -import io.a2a.spec.JSONRPCError; -import io.a2a.spec.JSONRPCErrorResponse; -import io.a2a.spec.JSONRPCRequest; -import io.a2a.spec.JSONRPCResponse; -import io.a2a.spec.MethodNotFoundError; -import io.a2a.spec.MethodNotFoundJsonMappingException; -import io.a2a.spec.NonStreamingJSONRPCRequest; -import io.a2a.spec.SendMessageRequest; -import io.a2a.spec.SendStreamingMessageRequest; -import io.a2a.spec.SetTaskPushNotificationConfigRequest; -import io.a2a.spec.StreamingJSONRPCRequest; -import io.a2a.spec.TaskResubscriptionRequest; -import io.a2a.spec.UnsupportedOperationError; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -@Path("/") -public class A2AServerResource { - - private static final Logger LOGGER = LoggerFactory.getLogger(A2AServerResource.class); - - @Inject - JSONRPCHandler jsonRpcHandler; - - @Inject - @ExtendedAgentCard - Instance extendedAgentCard; - - // Hook so testing can wait until the async Subscription is subscribed. - private static volatile Runnable streamingIsSubscribedRunnable; - - @Inject - @Internal - Executor executor; - - /** - * Handles incoming POST requests to the main A2A endpoint. Dispatches the - * request to the appropriate JSON-RPC handler method and returns the response. - * - * @param request the JSON-RPC request - * @return the JSON-RPC response which may be an error response - */ - @POST - @Consumes(MediaType.APPLICATION_JSON) - @Produces(MediaType.APPLICATION_JSON) - public JSONRPCResponse handleNonStreamingRequests(NonStreamingJSONRPCRequest request) { - LOGGER.debug("Handling non-streaming request"); - try { - return processNonStreamingRequest(request); - } finally { - LOGGER.debug("Completed non-streaming request"); - } - } - - /** - * Handles incoming POST requests to the main A2A endpoint that involve Server-Sent Events (SSE). - * Dispatches the request to the appropriate JSON-RPC handler method and returns the response. - */ - @POST - @Consumes(MediaType.APPLICATION_JSON) - @Produces(MediaType.SERVER_SENT_EVENTS) - public void handleStreamingRequests(StreamingJSONRPCRequest request, @Context SseEventSink sseEventSink, @Context Sse sse) { - LOGGER.debug("Handling streaming request"); - executor.execute(() -> processStreamingRequest(request, sseEventSink, sse)); - LOGGER.debug("Submitted streaming request for async processing"); - } - - /** - * Handles incoming GET requests to the agent card endpoint. - * Returns the agent card in JSON format. - * - * @return the agent card - */ - @GET - @Path("/.well-known/agent.json") - @Produces(MediaType.APPLICATION_JSON) - public AgentCard getAgentCard() { - return jsonRpcHandler.getAgentCard(); - } - - /** - * Handles incoming GET requests to the authenticated extended agent card endpoint. - * Returns the agent card in JSON format. - * - * @return the authenticated extended agent card - */ - @GET - @Path("/agent/authenticatedExtendedCard") - @Produces(MediaType.APPLICATION_JSON) - public Response getAuthenticatedExtendedAgentCard() { - // TODO need to add authentication for this endpoint - // https://github.com/a2aproject/a2a-java/issues/77 - if (! jsonRpcHandler.getAgentCard().supportsAuthenticatedExtendedCard()) { - JSONErrorResponse errorResponse = new JSONErrorResponse("Extended agent card not supported or not enabled."); - return Response.status(Response.Status.NOT_FOUND) - .entity(errorResponse).build(); - } - if (! extendedAgentCard.isResolvable()) { - JSONErrorResponse errorResponse = new JSONErrorResponse("Authenticated extended agent card is supported but not configured on the server."); - return Response.status(Response.Status.NOT_FOUND) - .entity(errorResponse).build(); - } - return Response.ok(extendedAgentCard.get()) - .type(MediaType.APPLICATION_JSON) - .build(); - } - - private JSONRPCResponse processNonStreamingRequest(NonStreamingJSONRPCRequest request) { - if (request instanceof GetTaskRequest) { - return jsonRpcHandler.onGetTask((GetTaskRequest) request); - } else if (request instanceof CancelTaskRequest) { - return jsonRpcHandler.onCancelTask((CancelTaskRequest) request); - } else if (request instanceof SetTaskPushNotificationConfigRequest) { - return jsonRpcHandler.setPushNotification((SetTaskPushNotificationConfigRequest) request); - } else if (request instanceof GetTaskPushNotificationConfigRequest) { - return jsonRpcHandler.getPushNotification((GetTaskPushNotificationConfigRequest) request); - } else if (request instanceof SendMessageRequest) { - return jsonRpcHandler.onMessageSend((SendMessageRequest) request); - } else { - return generateErrorResponse(request, new UnsupportedOperationError()); - } - } - - private void processStreamingRequest(StreamingJSONRPCRequest request, SseEventSink sseEventSink, Sse sse) { - Flow.Publisher> publisher; - if (request instanceof SendStreamingMessageRequest) { - publisher = jsonRpcHandler.onMessageSendStream((SendStreamingMessageRequest) request); - handleStreamingResponse(publisher, sseEventSink, sse); - } else if (request instanceof TaskResubscriptionRequest) { - publisher = jsonRpcHandler.onResubscribeToTask((TaskResubscriptionRequest) request); - handleStreamingResponse(publisher, sseEventSink, sse); - } - } - - private void handleStreamingResponse(Flow.Publisher> publisher, SseEventSink sseEventSink, Sse sse) { - publisher.subscribe(new Flow.Subscriber>() { - private Flow.Subscription subscription; - - @Override - public void onSubscribe(Flow.Subscription subscription) { - this.subscription = subscription; - subscription.request(Long.MAX_VALUE); - // Notify tests that we are subscribed - Runnable runnable = streamingIsSubscribedRunnable; - if (runnable != null) { - runnable.run(); - } - } - - @Override - public void onNext(JSONRPCResponse item) { - - sseEventSink.send(sse.newEventBuilder() - .mediaType(MediaType.APPLICATION_JSON_TYPE) - .data(item) - .build()); - } - - @Override - public void onError(Throwable throwable) { - // TODO - sseEventSink.close(); - } - - @Override - public void onComplete() { - sseEventSink.close(); - } - }); - } - - private JSONRPCResponse generateErrorResponse(JSONRPCRequest request, JSONRPCError error) { - return new JSONRPCErrorResponse(request.getId(), error); - } - - static void setStreamingIsSubscribedRunnable(Runnable streamingIsSubscribedRunnable) { - A2AServerResource.streamingIsSubscribedRunnable = streamingIsSubscribedRunnable; - } - - @Provider - public static class JsonParseExceptionMapper implements ExceptionMapper { - - public JsonParseExceptionMapper() { - } - - @Override - public Response toResponse(JsonParseException exception) { - // parse error, not possible to determine the request id - return Response.ok(new JSONRPCErrorResponse(new JSONParseError())).type(MediaType.APPLICATION_JSON).build(); - } - - } - - @Provider - public static class JsonMappingExceptionMapper implements ExceptionMapper { - - public JsonMappingExceptionMapper(){ - } - - @Override - public Response toResponse(JsonMappingException exception) { - if (exception.getCause() instanceof JsonParseException) { - return Response.ok(new JSONRPCErrorResponse(new JSONParseError())).type(MediaType.APPLICATION_JSON).build(); - } else if (exception instanceof MethodNotFoundJsonMappingException) { - Object id = ((MethodNotFoundJsonMappingException) exception).getId(); - return Response.ok(new JSONRPCErrorResponse(id, new MethodNotFoundError())) - .type(MediaType.APPLICATION_JSON).build(); - } else if (exception instanceof InvalidParamsJsonMappingException) { - Object id = ((InvalidParamsJsonMappingException) exception).getId(); - return Response.ok(new JSONRPCErrorResponse(id, new InvalidParamsError())) - .type(MediaType.APPLICATION_JSON).build(); - } else if (exception instanceof IdJsonMappingException) { - Object id = ((IdJsonMappingException) exception).getId(); - return Response.ok(new JSONRPCErrorResponse(id, new InvalidRequestError())) - .type(MediaType.APPLICATION_JSON).build(); - } - // not possible to determine the request id - return Response.ok(new JSONRPCErrorResponse(new InvalidRequestError())).type(MediaType.APPLICATION_JSON).build(); - } - - } -} \ No newline at end of file diff --git a/sdk-jakarta/src/main/resources/META-INF/beans.xml b/sdk-jakarta/src/main/resources/META-INF/beans.xml deleted file mode 100644 index 9dfae34df..000000000 --- a/sdk-jakarta/src/main/resources/META-INF/beans.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - \ No newline at end of file diff --git a/sdk-jakarta/src/scripts/configure_logger.cli b/sdk-jakarta/src/scripts/configure_logger.cli deleted file mode 100644 index a45fb245d..000000000 --- a/sdk-jakarta/src/scripts/configure_logger.cli +++ /dev/null @@ -1,2 +0,0 @@ -/subsystem=logging/logger=org.jboss.weld:add(level=DEBUG) -/subsystem=logging/logger=io.a2a:add(level=DEBUG) \ No newline at end of file diff --git a/sdk-jakarta/src/test/java/io/a2a/server/apps/jakarta/JakartaA2AServerTest.java b/sdk-jakarta/src/test/java/io/a2a/server/apps/jakarta/JakartaA2AServerTest.java deleted file mode 100644 index 5bdb5225c..000000000 --- a/sdk-jakarta/src/test/java/io/a2a/server/apps/jakarta/JakartaA2AServerTest.java +++ /dev/null @@ -1,96 +0,0 @@ -package io.a2a.server.apps.jakarta; - - - -import io.a2a.server.apps.common.AbstractA2AServerTest; -import io.a2a.server.apps.common.AgentCardProducer; -import io.a2a.server.apps.common.AgentExecutorProducer; -import io.a2a.server.events.InMemoryQueueManager; -import io.a2a.server.tasks.TaskStore; -import jakarta.enterprise.context.ApplicationScoped; -import jakarta.inject.Inject; -import java.io.File; -import java.io.IOException; -import java.nio.file.DirectoryStream; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import org.jboss.arquillian.container.test.api.Deployment; -import org.jboss.arquillian.junit5.container.annotation.ArquillianTest; -import org.jboss.shrinkwrap.api.ShrinkWrap; -import org.jboss.shrinkwrap.api.spec.WebArchive; - -@ArquillianTest -@ApplicationScoped -public class JakartaA2AServerTest extends AbstractA2AServerTest { - - public JakartaA2AServerTest() { - super(8080); - } - - @Deployment - public static WebArchive createTestArchive() throws IOException { - final List prefixes = List.of( - "a2a-java-sdk-client", - "a2a-java-sdk-common", - "a2a-java-sdk-server-common", - "a2a-java-sdk-spec", - "jackson", - "mutiny", - "slf4j", - "rest-assured", - "groovy", - "http", - "commons", - "xml-path", - "json-path", - "hamcrest" - ); - List libraries = new ArrayList<>(); - try (DirectoryStream stream = Files.newDirectoryStream(Paths.get("target").resolve("lib"))) { - for (Path file : stream) { - String fileName = file.getFileName().toString(); - if (prefixes.stream().anyMatch(fileName::startsWith)) { - libraries.add(file.toFile()); - } - } - } - WebArchive archive = ShrinkWrap.create(WebArchive.class, "ROOT.war") - .addAsLibraries(libraries.toArray(new File[libraries.size()])) - .addClass(AbstractA2AServerTest.class) - .addClass(AgentCardProducer.class) - .addClass(AgentExecutorProducer.class) - .addClass(JakartaA2AServerTest.class) - .addClass(A2ARequestFilter.class) - .addClass(A2AServerResource.class) - .addClass(RestApplication.class) - .addAsManifestResource("META-INF/beans.xml", "beans.xml") - .addAsWebInfResource("META-INF/beans.xml", "beans.xml") - .addAsWebInfResource("WEB-INF/web.xml", "web.xml"); - return archive; - } - - @Inject - TaskStore taskStore; - - @Inject - InMemoryQueueManager queueManager; - - @Override - protected TaskStore getTaskStore() { - return taskStore; - } - - @Override - protected InMemoryQueueManager getQueueManager() { - return queueManager; - } - - @Override - protected void setStreamingSubscribedRunnable(Runnable runnable) { - A2AServerResource.setStreamingIsSubscribedRunnable(runnable); - } -} diff --git a/sdk-jakarta/src/test/java/io/a2a/server/apps/jakarta/RestApplication.java b/sdk-jakarta/src/test/java/io/a2a/server/apps/jakarta/RestApplication.java deleted file mode 100644 index e2fcf6af3..000000000 --- a/sdk-jakarta/src/test/java/io/a2a/server/apps/jakarta/RestApplication.java +++ /dev/null @@ -1,8 +0,0 @@ -package io.a2a.server.apps.jakarta; - -import jakarta.ws.rs.ApplicationPath; -import jakarta.ws.rs.core.Application; - -@ApplicationPath("/") -public class RestApplication extends Application { -} diff --git a/sdk-jakarta/src/test/resources/WEB-INF/web.xml b/sdk-jakarta/src/test/resources/WEB-INF/web.xml deleted file mode 100644 index 2678fbc83..000000000 --- a/sdk-jakarta/src/test/resources/WEB-INF/web.xml +++ /dev/null @@ -1,12 +0,0 @@ - - - - - 30 - - - diff --git a/sdk-jakarta/src/test/resources/arquillian.xml b/sdk-jakarta/src/test/resources/arquillian.xml deleted file mode 100644 index 804afef5d..000000000 --- a/sdk-jakarta/src/test/resources/arquillian.xml +++ /dev/null @@ -1,19 +0,0 @@ - - - - - - - - - - - - ${basedir}/target/wildfly - - ${arquillian.java.vm.args} - true - - - - diff --git a/sdk-jakarta/src/test/resources/logging.properties b/sdk-jakarta/src/test/resources/logging.properties deleted file mode 100644 index 17885c831..000000000 --- a/sdk-jakarta/src/test/resources/logging.properties +++ /dev/null @@ -1,29 +0,0 @@ -# -# Copyright The WildFly Authors -# SPDX-License-Identifier: Apache-2.0 -# - -# Additional logger names to configure (root logger is always configured) -loggers=sun.rmi,org.jboss.shrinkwrap,org.apache.http.wire -logger.org.jboss.shrinkwrap.level=INFO -logger.sun.rmi.level=WARNING -logger.org.apache.http.wire.level=WARN - -# Root logger level -logger.level=WARN - -# Root logger handlers -logger.handlers=FILE - -# File handler configuration -handler.FILE=org.jboss.logmanager.handlers.FileHandler -handler.FILE.properties=autoFlush,append,fileName -handler.FILE.autoFlush=true -handler.FILE.fileName=./target/test.log -handler.FILE.formatter=PATTERN -handler.FILE.append=true - -# Formatter pattern configuration -formatter.PATTERN=org.jboss.logmanager.formatters.PatternFormatter -formatter.PATTERN.properties=pattern -formatter.PATTERN.pattern=%d{HH:mm:ss,SSS} %-5p [%c] (%t) %s%e%n diff --git a/sdk-quarkus/src/test/java/io/a2a/server/apps/quarkus/QuarkusA2AServerTest.java b/sdk-quarkus/src/test/java/io/a2a/server/apps/quarkus/QuarkusA2AServerTest.java deleted file mode 100644 index dab1954ba..000000000 --- a/sdk-quarkus/src/test/java/io/a2a/server/apps/quarkus/QuarkusA2AServerTest.java +++ /dev/null @@ -1,37 +0,0 @@ -package io.a2a.server.apps.quarkus; - -import jakarta.inject.Inject; - -import io.a2a.server.apps.common.AbstractA2AServerTest; -import io.a2a.server.events.InMemoryQueueManager; -import io.a2a.server.tasks.TaskStore; -import io.quarkus.test.junit.QuarkusTest; - -@QuarkusTest -public class QuarkusA2AServerTest extends AbstractA2AServerTest { - - @Inject - TaskStore taskStore; - - @Inject - InMemoryQueueManager queueManager; - - public QuarkusA2AServerTest() { - super(8081); - } - - @Override - protected TaskStore getTaskStore() { - return taskStore; - } - - @Override - protected InMemoryQueueManager getQueueManager() { - return queueManager; - } - - @Override - protected void setStreamingSubscribedRunnable(Runnable runnable) { - A2AServerRoutes.setStreamingMultiSseSupportSubscribedRunnable(runnable); - } -} diff --git a/sdk-server-common/pom.xml b/sdk-server-common/pom.xml index f35b91280..8fc372323 100644 --- a/sdk-server-common/pom.xml +++ b/sdk-server-common/pom.xml @@ -7,7 +7,7 @@ io.github.a2asdk a2a-java-sdk-parent - 0.2.3.Beta2-SNAPSHOT + 0.2.6.Beta1-SNAPSHOT a2a-java-sdk-server-common diff --git a/sdk-server-common/src/main/java/io/a2a/server/ServerCallContext.java b/sdk-server-common/src/main/java/io/a2a/server/ServerCallContext.java index 70fb344df..558f01eda 100644 --- a/sdk-server-common/src/main/java/io/a2a/server/ServerCallContext.java +++ b/sdk-server-common/src/main/java/io/a2a/server/ServerCallContext.java @@ -1,5 +1,26 @@ package io.a2a.server; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import io.a2a.server.auth.User; + public class ServerCallContext { - // TODO port the fields + // TODO Not totally sure yet about these field types + private final Map modelConfig = new ConcurrentHashMap<>(); + private final Map state; + private final User user; + + public ServerCallContext(User user, Map state) { + this.user = user; + this.state = state; + } + + public Map getState() { + return state; + } + + public User getUser() { + return user; + } } diff --git a/sdk-server-common/src/main/java/io/a2a/server/agentexecution/RequestContext.java b/sdk-server-common/src/main/java/io/a2a/server/agentexecution/RequestContext.java index bac7673a1..585b4fce4 100644 --- a/sdk-server-common/src/main/java/io/a2a/server/agentexecution/RequestContext.java +++ b/sdk-server-common/src/main/java/io/a2a/server/agentexecution/RequestContext.java @@ -22,13 +22,21 @@ public class RequestContext { private String contextId; private Task task; private List relatedTasks; - - public RequestContext(MessageSendParams params, String taskId, String contextId, Task task, List relatedTasks) throws InvalidParamsError { + private final ServerCallContext callContext; + + public RequestContext( + MessageSendParams params, + String taskId, + String contextId, + Task task, + List relatedTasks, + ServerCallContext callContext) throws InvalidParamsError { this.params = params; this.taskId = taskId; this.contextId = contextId; this.task = task; this.relatedTasks = relatedTasks == null ? new ArrayList<>() : relatedTasks; + this.callContext = callContext; // if the taskId and contextId were specified, they must match the params if (params != null) { @@ -73,6 +81,10 @@ public MessageSendConfiguration getConfiguration() { return params != null ? params.configuration() : null; } + public ServerCallContext getCallContext() { + return callContext; + } + public String getUserInput(String delimiter) { if (params == null) { return ""; @@ -187,7 +199,7 @@ public ServerCallContext getServerCallContext() { } public RequestContext build() { - return new RequestContext(params, taskId, contextId, task, relatedTasks); + return new RequestContext(params, taskId, contextId, task, relatedTasks, serverCallContext); } } diff --git a/sdk-server-common/src/main/java/io/a2a/server/auth/UnauthenticatedUser.java b/sdk-server-common/src/main/java/io/a2a/server/auth/UnauthenticatedUser.java index e4ec7a69c..9988ebbcf 100644 --- a/sdk-server-common/src/main/java/io/a2a/server/auth/UnauthenticatedUser.java +++ b/sdk-server-common/src/main/java/io/a2a/server/auth/UnauthenticatedUser.java @@ -1,6 +1,12 @@ package io.a2a.server.auth; public class UnauthenticatedUser implements User { + + public static UnauthenticatedUser INSTANCE = new UnauthenticatedUser(); + + private UnauthenticatedUser() { + } + @Override public boolean isAuthenticated() { return false; diff --git a/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java b/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java index 5d6cbec84..e79a97201 100644 --- a/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java +++ b/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java @@ -16,20 +16,28 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; +import io.a2a.server.ServerCallContext; +import io.a2a.server.agentexecution.AgentExecutor; +import io.a2a.server.agentexecution.RequestContext; +import io.a2a.server.agentexecution.SimpleRequestContextBuilder; import io.a2a.server.events.EnhancedRunnable; import io.a2a.server.events.EventConsumer; import io.a2a.server.events.EventQueue; import io.a2a.server.events.QueueManager; import io.a2a.server.events.TaskQueueExistsException; -import io.a2a.server.tasks.PushNotifier; +import io.a2a.server.tasks.PushNotificationConfigStore; +import io.a2a.server.tasks.PushNotificationSender; import io.a2a.server.tasks.ResultAggregator; import io.a2a.server.tasks.TaskManager; import io.a2a.server.tasks.TaskStore; import io.a2a.server.util.async.Internal; +import io.a2a.spec.DeleteTaskPushNotificationConfigParams; import io.a2a.spec.Event; import io.a2a.spec.EventKind; +import io.a2a.spec.GetTaskPushNotificationConfigParams; import io.a2a.spec.InternalError; import io.a2a.spec.JSONRPCError; +import io.a2a.spec.ListTaskPushNotificationConfigParams; import io.a2a.spec.Message; import io.a2a.spec.MessageSendParams; import io.a2a.spec.PushNotificationConfig; @@ -43,9 +51,6 @@ import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; -import io.a2a.server.agentexecution.AgentExecutor; -import io.a2a.server.agentexecution.RequestContext; -import io.a2a.server.agentexecution.SimpleRequestContextBuilder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -57,7 +62,8 @@ public class DefaultRequestHandler implements RequestHandler { private final AgentExecutor agentExecutor; private final TaskStore taskStore; private final QueueManager queueManager; - private final PushNotifier pushNotifier; + private final PushNotificationConfigStore pushConfigStore; + private final PushNotificationSender pushSender; private final Supplier requestContextBuilder; private final ConcurrentMap> runningAgents = new ConcurrentHashMap<>(); @@ -66,11 +72,13 @@ public class DefaultRequestHandler implements RequestHandler { @Inject public DefaultRequestHandler(AgentExecutor agentExecutor, TaskStore taskStore, - QueueManager queueManager, PushNotifier pushNotifier, @Internal Executor executor) { + QueueManager queueManager, PushNotificationConfigStore pushConfigStore, + PushNotificationSender pushSender, @Internal Executor executor) { this.agentExecutor = agentExecutor; this.taskStore = taskStore; this.queueManager = queueManager; - this.pushNotifier = pushNotifier; + this.pushConfigStore = pushConfigStore; + this.pushSender = pushSender; this.executor = executor; // TODO In Python this is also a constructor parameter defaulting to this SimpleRequestContextBuilder // implementation if the parameter is null. Skip that for now, since otherwise I get CDI errors, and @@ -80,7 +88,7 @@ public DefaultRequestHandler(AgentExecutor agentExecutor, TaskStore taskStore, } @Override - public Task onGetTask(TaskQueryParams params) throws JSONRPCError { + public Task onGetTask(TaskQueryParams params, ServerCallContext context) throws JSONRPCError { LOGGER.debug("onGetTask {}", params.id()); Task task = taskStore.get(params.id()); if (task == null) { @@ -107,7 +115,7 @@ public Task onGetTask(TaskQueryParams params) throws JSONRPCError { } @Override - public Task onCancelTask(TaskIdParams params) throws JSONRPCError { + public Task onCancelTask(TaskIdParams params, ServerCallContext context) throws JSONRPCError { Task task = taskStore.get(params.id()); if (task == null) { throw new TaskNotFoundError(); @@ -129,6 +137,7 @@ public Task onCancelTask(TaskIdParams params) throws JSONRPCError { .setTaskId(task.getId()) .setContextId(task.getContextId()) .setTask(task) + .setServerCallContext(context) .build(), queue); @@ -145,9 +154,9 @@ public Task onCancelTask(TaskIdParams params) throws JSONRPCError { } @Override - public EventKind onMessageSend(MessageSendParams params) throws JSONRPCError { + public EventKind onMessageSend(MessageSendParams params, ServerCallContext context) throws JSONRPCError { LOGGER.debug("onMessageSend - task: {}; context {}", params.message().getTaskId(), params.message().getContextId()); - MessageSendSetup mss = initMessageSend(params); + MessageSendSetup mss = initMessageSend(params, context); String taskId = mss.requestContext.getTaskId(); LOGGER.debug("Request context taskId: {}", taskId); @@ -193,9 +202,10 @@ public EventKind onMessageSend(MessageSendParams params) throws JSONRPCError { } @Override - public Flow.Publisher onMessageSendStream(MessageSendParams params) throws JSONRPCError { + public Flow.Publisher onMessageSendStream( + MessageSendParams params, ServerCallContext context) throws JSONRPCError { LOGGER.debug("onMessageSendStream - task: {}; context {}", params.message().getTaskId(), params.message().getContextId()); - MessageSendSetup mss = initMessageSend(params); + MessageSendSetup mss = initMessageSend(params, context); AtomicReference taskId = new AtomicReference<>(mss.requestContext.getTaskId()); EventQueue queue = queueManager.createOrTap(taskId.get()); @@ -226,20 +236,20 @@ public Flow.Publisher onMessageSendStream(MessageSendParams } catch (TaskQueueExistsException e) { // TODO Log } - if (pushNotifier != null && + if (pushConfigStore != null && params.configuration() != null && params.configuration().pushNotification() != null) { - pushNotifier.setInfo( + pushConfigStore.setInfo( createdTask.getId(), params.configuration().pushNotification()); } } - if (pushNotifier != null && taskId.get() != null) { + if (pushSender != null && taskId.get() != null) { EventKind latest = resultAggregator.getCurrentResult(); if (latest instanceof Task latestTask) { - pushNotifier.sendNotification(latestTask); + pushSender.sendNotification(latestTask); } } @@ -253,8 +263,9 @@ public Flow.Publisher onMessageSendStream(MessageSendParams } @Override - public TaskPushNotificationConfig onSetTaskPushNotificationConfig(TaskPushNotificationConfig params) throws JSONRPCError { - if (pushNotifier == null) { + public TaskPushNotificationConfig onSetTaskPushNotificationConfig( + TaskPushNotificationConfig params, ServerCallContext context) throws JSONRPCError { + if (pushConfigStore == null) { throw new UnsupportedOperationError(); } Task task = taskStore.get(params.taskId()); @@ -262,14 +273,15 @@ public TaskPushNotificationConfig onSetTaskPushNotificationConfig(TaskPushNotifi throw new TaskNotFoundError(); } - pushNotifier.setInfo(params.taskId(), params.pushNotificationConfig()); + pushConfigStore.setInfo(params.taskId(), params.pushNotificationConfig()); return params; } @Override - public TaskPushNotificationConfig onGetTaskPushNotificationConfig(TaskIdParams params) throws JSONRPCError { - if (pushNotifier == null) { + public TaskPushNotificationConfig onGetTaskPushNotificationConfig( + GetTaskPushNotificationConfigParams params, ServerCallContext context) throws JSONRPCError { + if (pushConfigStore == null) { throw new UnsupportedOperationError(); } Task task = taskStore.get(params.id()); @@ -277,16 +289,29 @@ public TaskPushNotificationConfig onGetTaskPushNotificationConfig(TaskIdParams p throw new TaskNotFoundError(); } - PushNotificationConfig pushNotificationConfig = pushNotifier.getInfo(params.id()); - if (pushNotificationConfig == null) { + List pushNotificationConfigList = pushConfigStore.getInfo(params.id()); + if (pushNotificationConfigList == null || pushNotificationConfigList.isEmpty()) { throw new InternalError("No push notification config found"); } - return new TaskPushNotificationConfig(params.id(), pushNotificationConfig); + return new TaskPushNotificationConfig(params.id(), getPushNotificationConfig(pushNotificationConfigList, params.pushNotificationConfigId())); + } + + private PushNotificationConfig getPushNotificationConfig(List notificationConfigList, + String configId) { + if (configId != null) { + for (PushNotificationConfig notificationConfig : notificationConfigList) { + if (configId.equals(notificationConfig.id())) { + return notificationConfig; + } + } + } + return notificationConfigList.get(0); } @Override - public Flow.Publisher onResubscribeToTask(TaskIdParams params) throws JSONRPCError { + public Flow.Publisher onResubscribeToTask( + TaskIdParams params, ServerCallContext context) throws JSONRPCError { Task task = taskStore.get(params.id()); if (task == null) { throw new TaskNotFoundError(); @@ -305,8 +330,46 @@ public Flow.Publisher onResubscribeToTask(TaskIdParams param return convertingProcessor(results, e -> (StreamingEventKind) e); } + @Override + public List onListTaskPushNotificationConfig( + ListTaskPushNotificationConfigParams params, ServerCallContext context) throws JSONRPCError { + if (pushConfigStore == null) { + throw new UnsupportedOperationError(); + } + + Task task = taskStore.get(params.id()); + if (task == null) { + throw new TaskNotFoundError(); + } + + List pushNotificationConfigList = pushConfigStore.getInfo(params.id()); + List taskPushNotificationConfigList = new ArrayList<>(); + if (pushNotificationConfigList != null) { + for (PushNotificationConfig pushNotificationConfig : pushNotificationConfigList) { + TaskPushNotificationConfig taskPushNotificationConfig = new TaskPushNotificationConfig(params.id(), pushNotificationConfig); + taskPushNotificationConfigList.add(taskPushNotificationConfig); + } + } + return taskPushNotificationConfigList; + } + + @Override + public void onDeleteTaskPushNotificationConfig( + DeleteTaskPushNotificationConfigParams params, ServerCallContext context) { + if (pushConfigStore == null) { + throw new UnsupportedOperationError(); + } + + Task task = taskStore.get(params.id()); + if (task == null) { + throw new TaskNotFoundError(); + } + + pushConfigStore.deleteInfo(params.id(), params.pushNotificationConfigId()); + } + private boolean shouldAddPushInfo(MessageSendParams params) { - return pushNotifier != null && params.configuration() != null && params.configuration().pushNotification() != null; + return pushConfigStore != null && params.configuration() != null && params.configuration().pushNotification() != null; } private EnhancedRunnable registerAndExecuteAgentAsync(String taskId, RequestContext requestContext, EventQueue queue) { @@ -343,7 +406,7 @@ private void cleanupProducer(String taskId) { }); } - private MessageSendSetup initMessageSend(MessageSendParams params) { + private MessageSendSetup initMessageSend(MessageSendParams params, ServerCallContext context) { TaskManager taskManager = new TaskManager( params.message().getTaskId(), params.message().getContextId(), @@ -357,7 +420,7 @@ private MessageSendSetup initMessageSend(MessageSendParams params) { if (shouldAddPushInfo(params)) { LOGGER.debug("Adding push info"); - pushNotifier.setInfo(task.getId(), params.configuration().pushNotification()); + pushConfigStore.setInfo(task.getId(), params.configuration().pushNotification()); } } @@ -366,6 +429,7 @@ private MessageSendSetup initMessageSend(MessageSendParams params) { .setTaskId(task == null ? null : task.getId()) .setContextId(params.message().getContextId()) .setTask(task) + .setServerCallContext(context) .build(); return new MessageSendSetup(taskManager, task, requestContext); } diff --git a/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/JSONRPCHandler.java b/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/JSONRPCHandler.java index fac014b0e..fb120f981 100644 --- a/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/JSONRPCHandler.java +++ b/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/JSONRPCHandler.java @@ -1,16 +1,20 @@ package io.a2a.server.requesthandlers; import static io.a2a.server.util.async.AsyncUtils.createTubeConfig; - -import java.util.concurrent.Flow; - import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Instance; import jakarta.inject.Inject; +import java.util.List; +import java.util.concurrent.Flow; + import io.a2a.server.PublicAgentCard; +import io.a2a.server.ServerCallContext; import io.a2a.spec.AgentCard; import io.a2a.spec.CancelTaskRequest; import io.a2a.spec.CancelTaskResponse; +import io.a2a.spec.DeleteTaskPushNotificationConfigRequest; +import io.a2a.spec.DeleteTaskPushNotificationConfigResponse; import io.a2a.spec.EventKind; import io.a2a.spec.GetTaskPushNotificationConfigRequest; import io.a2a.spec.GetTaskPushNotificationConfigResponse; @@ -19,6 +23,9 @@ import io.a2a.spec.InternalError; import io.a2a.spec.InvalidRequestError; import io.a2a.spec.JSONRPCError; +import io.a2a.spec.ListTaskPushNotificationConfigRequest; +import io.a2a.spec.ListTaskPushNotificationConfigResponse; +import io.a2a.spec.PushNotificationNotSupportedError; import io.a2a.spec.SendMessageRequest; import io.a2a.spec.SendMessageResponse; import io.a2a.spec.SendStreamingMessageRequest; @@ -47,9 +54,9 @@ public JSONRPCHandler(@PublicAgentCard AgentCard agentCard, RequestHandler reque this.requestHandler = requestHandler; } - public SendMessageResponse onMessageSend(SendMessageRequest request) { + public SendMessageResponse onMessageSend(SendMessageRequest request, ServerCallContext context) { try { - EventKind taskOrMessage = requestHandler.onMessageSend(request.getParams()); + EventKind taskOrMessage = requestHandler.onMessageSend(request.getParams(), context); return new SendMessageResponse(request.getId(), taskOrMessage); } catch (JSONRPCError e) { return new SendMessageResponse(request.getId(), e); @@ -59,7 +66,8 @@ public SendMessageResponse onMessageSend(SendMessageRequest request) { } - public Flow.Publisher onMessageSendStream(SendStreamingMessageRequest request) { + public Flow.Publisher onMessageSendStream( + SendStreamingMessageRequest request, ServerCallContext context) { if (!agentCard.capabilities().streaming()) { return ZeroPublisher.fromItems( new SendStreamingMessageResponse( @@ -68,7 +76,8 @@ public Flow.Publisher onMessageSendStream(SendStre } try { - Flow.Publisher publisher = requestHandler.onMessageSendStream(request.getParams()); + Flow.Publisher publisher = + requestHandler.onMessageSendStream(request.getParams(), context); // We can't use the convertingProcessor convenience method since that propagates any errors as an error handled // via Subscriber.onError() rather than as part of the SendStreamingResponse payload return convertToSendStreamingMessageResponse(request.getId(), publisher); @@ -79,9 +88,9 @@ public Flow.Publisher onMessageSendStream(SendStre } } - public CancelTaskResponse onCancelTask(CancelTaskRequest request) { + public CancelTaskResponse onCancelTask(CancelTaskRequest request, ServerCallContext context) { try { - Task task = requestHandler.onCancelTask(request.getParams()); + Task task = requestHandler.onCancelTask(request.getParams(), context); if (task != null) { return new CancelTaskResponse(request.getId(), task); } @@ -93,7 +102,8 @@ public CancelTaskResponse onCancelTask(CancelTaskRequest request) { } } - public Flow.Publisher onResubscribeToTask(TaskResubscriptionRequest request) { + public Flow.Publisher onResubscribeToTask( + TaskResubscriptionRequest request, ServerCallContext context) { if (!agentCard.capabilities().streaming()) { return ZeroPublisher.fromItems( new SendStreamingMessageResponse( @@ -102,7 +112,8 @@ public Flow.Publisher onResubscribeToTask(TaskResu } try { - Flow.Publisher publisher = requestHandler.onResubscribeToTask(request.getParams()); + Flow.Publisher publisher = + requestHandler.onResubscribeToTask(request.getParams(), context); // We can't use the convertingProcessor convenience method since that propagates any errors as an error handled // via Subscriber.onError() rather than as part of the SendStreamingResponse payload return convertToSendStreamingMessageResponse(request.getId(), publisher); @@ -113,9 +124,15 @@ public Flow.Publisher onResubscribeToTask(TaskResu } } - public GetTaskPushNotificationConfigResponse getPushNotification(GetTaskPushNotificationConfigRequest request) { + public GetTaskPushNotificationConfigResponse getPushNotificationConfig( + GetTaskPushNotificationConfigRequest request, ServerCallContext context) { + if (!agentCard.capabilities().pushNotifications()) { + return new GetTaskPushNotificationConfigResponse(request.getId(), + new PushNotificationNotSupportedError()); + } try { - TaskPushNotificationConfig config = requestHandler.onGetTaskPushNotificationConfig(request.getParams()); + TaskPushNotificationConfig config = + requestHandler.onGetTaskPushNotificationConfig(request.getParams(), context); return new GetTaskPushNotificationConfigResponse(request.getId(), config); } catch (JSONRPCError e) { return new GetTaskPushNotificationConfigResponse(request.getId().toString(), e); @@ -124,13 +141,15 @@ public GetTaskPushNotificationConfigResponse getPushNotification(GetTaskPushNoti } } - public SetTaskPushNotificationConfigResponse setPushNotification(SetTaskPushNotificationConfigRequest request) { + public SetTaskPushNotificationConfigResponse setPushNotificationConfig( + SetTaskPushNotificationConfigRequest request, ServerCallContext context) { if (!agentCard.capabilities().pushNotifications()) { return new SetTaskPushNotificationConfigResponse(request.getId(), - new InvalidRequestError("Push notifications are not supported by the agent")); + new PushNotificationNotSupportedError()); } try { - TaskPushNotificationConfig config = requestHandler.onSetTaskPushNotificationConfig(request.getParams()); + TaskPushNotificationConfig config = + requestHandler.onSetTaskPushNotificationConfig(request.getParams(), context); return new SetTaskPushNotificationConfigResponse(request.getId().toString(), config); } catch (JSONRPCError e) { return new SetTaskPushNotificationConfigResponse(request.getId(), e); @@ -139,9 +158,9 @@ public SetTaskPushNotificationConfigResponse setPushNotification(SetTaskPushNoti } } - public GetTaskResponse onGetTask(GetTaskRequest request) { + public GetTaskResponse onGetTask(GetTaskRequest request, ServerCallContext context) { try { - Task task = requestHandler.onGetTask(request.getParams()); + Task task = requestHandler.onGetTask(request.getParams(), context); return new GetTaskResponse(request.getId(), task); } catch (JSONRPCError e) { return new GetTaskResponse(request.getId(), e); @@ -150,6 +169,39 @@ public GetTaskResponse onGetTask(GetTaskRequest request) { } } + public ListTaskPushNotificationConfigResponse listPushNotificationConfig( + ListTaskPushNotificationConfigRequest request, ServerCallContext context) { + if ( !agentCard.capabilities().pushNotifications()) { + return new ListTaskPushNotificationConfigResponse(request.getId(), + new PushNotificationNotSupportedError()); + } + try { + List pushNotificationConfigList = + requestHandler.onListTaskPushNotificationConfig(request.getParams(), context); + return new ListTaskPushNotificationConfigResponse(request.getId(), pushNotificationConfigList); + } catch (JSONRPCError e) { + return new ListTaskPushNotificationConfigResponse(request.getId(), e); + } catch (Throwable t) { + return new ListTaskPushNotificationConfigResponse(request.getId(), new InternalError(t.getMessage())); + } + } + + public DeleteTaskPushNotificationConfigResponse deletePushNotificationConfig( + DeleteTaskPushNotificationConfigRequest request, ServerCallContext context) { + if ( !agentCard.capabilities().pushNotifications()) { + return new DeleteTaskPushNotificationConfigResponse(request.getId(), + new PushNotificationNotSupportedError()); + } + try { + requestHandler.onDeleteTaskPushNotificationConfig(request.getParams(), context); + return new DeleteTaskPushNotificationConfigResponse(request.getId()); + } catch (JSONRPCError e) { + return new DeleteTaskPushNotificationConfigResponse(request.getId(), e); + } catch (Throwable t) { + return new DeleteTaskPushNotificationConfigResponse(request.getId(), new InternalError(t.getMessage())); + } + } + public AgentCard getAgentCard() { return agentCard; } diff --git a/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/RequestHandler.java b/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/RequestHandler.java index e2902bca0..e45bc3c62 100644 --- a/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/RequestHandler.java +++ b/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/RequestHandler.java @@ -1,9 +1,14 @@ package io.a2a.server.requesthandlers; +import java.util.List; import java.util.concurrent.Flow; +import io.a2a.server.ServerCallContext; +import io.a2a.spec.DeleteTaskPushNotificationConfigParams; import io.a2a.spec.EventKind; +import io.a2a.spec.GetTaskPushNotificationConfigParams; import io.a2a.spec.JSONRPCError; +import io.a2a.spec.ListTaskPushNotificationConfigParams; import io.a2a.spec.MessageSendParams; import io.a2a.spec.StreamingEventKind; import io.a2a.spec.Task; @@ -12,17 +17,39 @@ import io.a2a.spec.TaskQueryParams; public interface RequestHandler { - Task onGetTask(TaskQueryParams params) throws JSONRPCError; + Task onGetTask( + TaskQueryParams params, + ServerCallContext context) throws JSONRPCError; - Task onCancelTask(TaskIdParams params) throws JSONRPCError; + Task onCancelTask( + TaskIdParams params, + ServerCallContext context) throws JSONRPCError; - EventKind onMessageSend(MessageSendParams params) throws JSONRPCError; + EventKind onMessageSend( + MessageSendParams params, + ServerCallContext context) throws JSONRPCError; - Flow.Publisher onMessageSendStream(MessageSendParams params) throws JSONRPCError; + Flow.Publisher onMessageSendStream( + MessageSendParams params, + ServerCallContext context) throws JSONRPCError; - TaskPushNotificationConfig onSetTaskPushNotificationConfig(TaskPushNotificationConfig params) throws JSONRPCError; + TaskPushNotificationConfig onSetTaskPushNotificationConfig( + TaskPushNotificationConfig params, + ServerCallContext context) throws JSONRPCError; - TaskPushNotificationConfig onGetTaskPushNotificationConfig(TaskIdParams params) throws JSONRPCError; + TaskPushNotificationConfig onGetTaskPushNotificationConfig( + GetTaskPushNotificationConfigParams params, + ServerCallContext context) throws JSONRPCError; - Flow.Publisher onResubscribeToTask(TaskIdParams params) throws JSONRPCError; + Flow.Publisher onResubscribeToTask( + TaskIdParams params, + ServerCallContext context) throws JSONRPCError; + + List onListTaskPushNotificationConfig( + ListTaskPushNotificationConfigParams params, + ServerCallContext context) throws JSONRPCError; + + void onDeleteTaskPushNotificationConfig( + DeleteTaskPushNotificationConfigParams params, + ServerCallContext context) throws JSONRPCError; } diff --git a/sdk-server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java b/sdk-server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java new file mode 100644 index 000000000..33ac4445c --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java @@ -0,0 +1,96 @@ +package io.a2a.server.tasks; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import com.fasterxml.jackson.core.JsonProcessingException; + +import io.a2a.http.A2AHttpClient; +import io.a2a.http.JdkA2AHttpClient; +import io.a2a.spec.PushNotificationConfig; +import io.a2a.spec.Task; +import io.a2a.util.Utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@ApplicationScoped +public class BasePushNotificationSender implements PushNotificationSender { + + private static final Logger LOGGER = LoggerFactory.getLogger(BasePushNotificationSender.class); + + private final A2AHttpClient httpClient; + private final PushNotificationConfigStore configStore; + + @Inject + public BasePushNotificationSender(PushNotificationConfigStore configStore) { + this.httpClient = new JdkA2AHttpClient(); + this.configStore = configStore; + } + + public BasePushNotificationSender(PushNotificationConfigStore configStore, A2AHttpClient httpClient) { + this.configStore = configStore; + this.httpClient = httpClient; + } + + @Override + public void sendNotification(Task task) { + List pushConfigs = configStore.getInfo(task.getId()); + if (pushConfigs == null || pushConfigs.isEmpty()) { + return; + } + + List> dispatchResults = pushConfigs + .stream() + .map(pushConfig -> dispatch(task, pushConfig)) + .toList(); + CompletableFuture allFutures = CompletableFuture.allOf(dispatchResults.toArray(new CompletableFuture[0])); + CompletableFuture dispatchResult = allFutures.thenApply(v -> dispatchResults.stream() + .allMatch(CompletableFuture::join)); + try { + boolean allSent = dispatchResult.get(); + if (! allSent) { + LOGGER.warn("Some push notifications failed to send for taskId: " + task.getId()); + } + } catch (InterruptedException | ExecutionException e) { + LOGGER.warn("Some push notifications failed to send for taskId " + task.getId() + ": {}", e.getMessage(), e); + } + } + + private CompletableFuture dispatch(Task task, PushNotificationConfig pushInfo) { + return CompletableFuture.supplyAsync(() -> dispatchNotification(task, pushInfo)); + } + + private boolean dispatchNotification(Task task, PushNotificationConfig pushInfo) { + String url = pushInfo.url(); + + // TODO auth + + String body; + try { + body = Utils.OBJECT_MAPPER.writeValueAsString(task); + } catch (JsonProcessingException e) { + LOGGER.debug("Error writing value as string: {}", e.getMessage(), e); + return false; + } catch (Throwable throwable) { + LOGGER.debug("Error writing value as string: {}", throwable.getMessage(), throwable); + return false; + } + + try { + httpClient.createPost() + .url(url) + .body(body) + .post(); + } catch (IOException | InterruptedException e) { + LOGGER.debug("Error pushing data to " + url + ": {}", e.getMessage(), e); + return false; + } + return true; + } +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStore.java b/sdk-server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStore.java new file mode 100644 index 000000000..e66fc1669 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStore.java @@ -0,0 +1,77 @@ +package io.a2a.server.tasks; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import io.a2a.spec.PushNotificationConfig; + +/** + * In-memory implementation of the PushNotificationConfigStore interface. + * + * Stores push notification configurations in memory + */ +@ApplicationScoped +public class InMemoryPushNotificationConfigStore implements PushNotificationConfigStore { + + private final Map> pushNotificationInfos = Collections.synchronizedMap(new HashMap<>()); + + @Inject + public InMemoryPushNotificationConfigStore() { + } + + @Override + public void setInfo(String taskId, PushNotificationConfig notificationConfig) { + List notificationConfigList = pushNotificationInfos.getOrDefault(taskId, new ArrayList<>()); + PushNotificationConfig.Builder builder = new PushNotificationConfig.Builder(notificationConfig); + if (notificationConfig.id() == null) { + builder.id(taskId); + } + notificationConfig = builder.build(); + + Iterator notificationConfigIterator = notificationConfigList.iterator(); + while (notificationConfigIterator.hasNext()) { + PushNotificationConfig config = notificationConfigIterator.next(); + if (config.id().equals(notificationConfig.id())) { + notificationConfigIterator.remove(); + break; + } + } + notificationConfigList.add(notificationConfig); + pushNotificationInfos.put(taskId, notificationConfigList); + } + + @Override + public List getInfo(String taskId) { + return pushNotificationInfos.get(taskId); + } + + @Override + public void deleteInfo(String taskId, String configId) { + if (configId == null) { + configId = taskId; + } + List notificationConfigList = pushNotificationInfos.get(taskId); + if (notificationConfigList == null || notificationConfigList.isEmpty()) { + return; + } + + Iterator notificationConfigIterator = notificationConfigList.iterator(); + while (notificationConfigIterator.hasNext()) { + PushNotificationConfig config = notificationConfigIterator.next(); + if (configId.equals(config.id())) { + notificationConfigIterator.remove(); + break; + } + } + if (notificationConfigList.isEmpty()) { + pushNotificationInfos.remove(taskId); + } + } +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotifier.java b/sdk-server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotifier.java deleted file mode 100644 index 6fb1fb39a..000000000 --- a/sdk-server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotifier.java +++ /dev/null @@ -1,78 +0,0 @@ -package io.a2a.server.tasks; - -import java.io.IOException; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -import jakarta.enterprise.context.ApplicationScoped; -import jakarta.inject.Inject; - -import com.fasterxml.jackson.core.JsonProcessingException; -import io.a2a.http.A2AHttpClient; -import io.a2a.http.JdkA2AHttpClient; -import io.a2a.spec.PushNotificationConfig; -import io.a2a.spec.Task; -import io.a2a.util.Utils; - -@ApplicationScoped -public class InMemoryPushNotifier implements PushNotifier { - private final A2AHttpClient httpClient; - private final Map pushNotificationInfos = Collections.synchronizedMap(new HashMap<>()); - - @Inject - public InMemoryPushNotifier() { - this.httpClient = new JdkA2AHttpClient(); - } - - public InMemoryPushNotifier(A2AHttpClient httpClient) { - this.httpClient = httpClient; - } - - @Override - public void setInfo(String taskId, PushNotificationConfig notificationConfig) { - pushNotificationInfos.put(taskId, notificationConfig); - } - - @Override - public PushNotificationConfig getInfo(String taskId) { - return pushNotificationInfos.get(taskId); - } - - @Override - public void deleteInfo(String taskId) { - pushNotificationInfos.remove(taskId); - } - - @Override - public void sendNotification(Task task) { - PushNotificationConfig pushInfo = pushNotificationInfos.get(task.getId()); - if (pushInfo == null) { - return; - } - String url = pushInfo.url(); - - // TODO auth - - String body; - try { - body = Utils.OBJECT_MAPPER.writeValueAsString(task); - } catch (JsonProcessingException e) { - e.printStackTrace(); - throw new RuntimeException("Error writing value as string: " + e.getMessage(), e); - } catch (Throwable throwable) { - throwable.printStackTrace(); - throw new RuntimeException("Error writing value as string: " + throwable.getMessage(), throwable); - } - - try { - httpClient.createPost() - .url(url) - .body(body) - .post(); - } catch (IOException | InterruptedException e) { - throw new RuntimeException("Error pushing data to " + url + ": " + e.getMessage(), e); - } - - } -} diff --git a/sdk-server-common/src/main/java/io/a2a/server/tasks/PushNotificationConfigStore.java b/sdk-server-common/src/main/java/io/a2a/server/tasks/PushNotificationConfigStore.java new file mode 100644 index 000000000..68f132620 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/tasks/PushNotificationConfigStore.java @@ -0,0 +1,33 @@ +package io.a2a.server.tasks; + +import java.util.List; + +import io.a2a.spec.PushNotificationConfig; + +/** + * Interface for storing and retrieving push notification configurations for tasks. + */ +public interface PushNotificationConfigStore { + + /** + * Sets or updates the push notification configuration for a task. + * @param taskId the task ID + * @param notificationConfig the push notification configuration + */ + void setInfo(String taskId, PushNotificationConfig notificationConfig); + + /** + * Retrieves the push notification configuration for a task. + * @param taskId the task ID + * @return the push notification configurations for a task + */ + List getInfo(String taskId); + + /** + * Deletes the push notification configuration for a task. + * @param taskId the task ID + * @param configId the push notification configuration + */ + void deleteInfo(String taskId, String configId); + +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/tasks/PushNotificationSender.java b/sdk-server-common/src/main/java/io/a2a/server/tasks/PushNotificationSender.java new file mode 100644 index 000000000..81d577f46 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/tasks/PushNotificationSender.java @@ -0,0 +1,15 @@ +package io.a2a.server.tasks; + +import io.a2a.spec.Task; + +/** + * Interface for sending push notifications for tasks. + */ +public interface PushNotificationSender { + + /** + * Sends a push notification containing the latest task state. + * @param task the task + */ + void sendNotification(Task task); +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/tasks/PushNotifier.java b/sdk-server-common/src/main/java/io/a2a/server/tasks/PushNotifier.java deleted file mode 100644 index 2dfc7dff0..000000000 --- a/sdk-server-common/src/main/java/io/a2a/server/tasks/PushNotifier.java +++ /dev/null @@ -1,14 +0,0 @@ -package io.a2a.server.tasks; - -import io.a2a.spec.PushNotificationConfig; -import io.a2a.spec.Task; - -public interface PushNotifier { - void setInfo(String taskId, PushNotificationConfig notificationConfig); - - PushNotificationConfig getInfo(String taskId); - - void deleteInfo(String taskId); - - void sendNotification(Task task); -} diff --git a/sdk-server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java b/sdk-server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java index 5211d1deb..ad9ed9124 100644 --- a/sdk-server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java +++ b/sdk-server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java @@ -4,6 +4,7 @@ import static io.a2a.server.util.async.AsyncUtils.createTubeConfig; import static io.a2a.server.util.async.AsyncUtils.processor; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -100,11 +101,7 @@ public EventTypeAndInterrupt consumeAndBreakOnInterrupt(EventConsumer consumer) // new request is expected in order for the agent to make progress, // so the agent should exit. - // TODO There is the following line in the Python code I don't totally get - // asyncio.create_task(self._continue_consuming(event_stream)) - // I think it means the continueConsuming() call should be done in another thread - continueConsuming(all); - + CompletableFuture.runAsync(() -> continueConsuming(all)); interrupted.set(true); return false; } diff --git a/sdk-server-common/src/main/java/io/a2a/server/tasks/TaskManager.java b/sdk-server-common/src/main/java/io/a2a/server/tasks/TaskManager.java index 089dc7abe..cebddf85a 100644 --- a/sdk-server-common/src/main/java/io/a2a/server/tasks/TaskManager.java +++ b/sdk-server-common/src/main/java/io/a2a/server/tasks/TaskManager.java @@ -16,8 +16,13 @@ import io.a2a.spec.TaskArtifactUpdateEvent; import io.a2a.spec.TaskStatus; import io.a2a.spec.TaskStatusUpdateEvent; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class TaskManager { + + private static final Logger LOGGER = LoggerFactory.getLogger(TaskManager.class); + private volatile String taskId; private volatile String contextId; private final TaskStore taskStore; @@ -101,16 +106,18 @@ Task saveTaskEvent(TaskArtifactUpdateEvent event) throws A2AServerException { // This represents the first chunk for this artifact index if (existingArtifactIndex >= 0) { // Replace the existing artifact entirely with the new artifact + LOGGER.debug("Replacing artifact at id {} for task {}", artifactId, taskId); artifacts.set(existingArtifactIndex, newArtifact); } else { // Append the new artifact since no artifact with this id/index exists yet + LOGGER.debug("Adding artifact at id {} for task {}", artifactId, taskId); artifacts.add(newArtifact); } } else if (existingArtifact != null) { // Append new parts to the existing artifact's parts list // Do this to a copy - + LOGGER.debug("Appending parts to artifact id {} for task {}", artifactId, taskId); List> parts = new ArrayList<>(existingArtifact.parts()); parts.addAll(newArtifact.parts()); Artifact updated = new Artifact.Builder(existingArtifact) @@ -120,6 +127,9 @@ Task saveTaskEvent(TaskArtifactUpdateEvent event) throws A2AServerException { } else { // We received a chunk to append, but we don't have an existing artifact. // We will ignore this chunk + LOGGER.warn( + "Received append=true for nonexistent artifact index for artifact {} in task {}. Ignoring chunk.", + artifactId, taskId); } task = new Task.Builder(task) diff --git a/sdk-server-common/src/test/java/io/a2a/server/agentexecution/RequestContextTest.java b/sdk-server-common/src/test/java/io/a2a/server/agentexecution/RequestContextTest.java index c9bb79061..081cc873a 100644 --- a/sdk-server-common/src/test/java/io/a2a/server/agentexecution/RequestContextTest.java +++ b/sdk-server-common/src/test/java/io/a2a/server/agentexecution/RequestContextTest.java @@ -25,7 +25,7 @@ public class RequestContextTest { @Test public void testInitWithoutParams() { - RequestContext context = new RequestContext(null, null, null, null, null); + RequestContext context = new RequestContext(null, null, null, null, null, null); assertNull(context.getMessage()); assertNull(context.getTaskId()); assertNull(context.getContextId()); @@ -46,7 +46,7 @@ public void testInitWithParamsNoIds() { .thenReturn(taskId) .thenReturn(contextId); - RequestContext context = new RequestContext(mockParams, null, null, null, null); + RequestContext context = new RequestContext(mockParams, null, null, null, null, null); assertEquals(mockParams.message(), context.getMessage()); assertEquals(taskId.toString(), context.getTaskId()); @@ -62,7 +62,7 @@ public void testInitWithTaskId() { var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId(taskId).build(); var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); - RequestContext context = new RequestContext(mockParams, taskId, null, null, null); + RequestContext context = new RequestContext(mockParams, taskId, null, null, null, null); assertEquals(taskId, context.getTaskId()); assertEquals(taskId, mockParams.message().getTaskId()); @@ -73,7 +73,7 @@ public void testInitWithContextId() { String contextId = "context-456"; var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).contextId(contextId).build(); var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); - RequestContext context = new RequestContext(mockParams, null, contextId, null, null); + RequestContext context = new RequestContext(mockParams, null, contextId, null, null, null); assertEquals(contextId, context.getContextId()); assertEquals(contextId, mockParams.message().getContextId()); @@ -85,7 +85,7 @@ public void testInitWithBothIds() { String contextId = "context-456"; var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId(taskId).contextId(contextId).build(); var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); - RequestContext context = new RequestContext(mockParams, taskId, contextId, null, null); + RequestContext context = new RequestContext(mockParams, taskId, contextId, null, null, null); assertEquals(taskId, context.getTaskId()); assertEquals(taskId, mockParams.message().getTaskId()); @@ -99,14 +99,14 @@ public void testInitWithTask() { var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); - RequestContext context = new RequestContext(mockParams, null, null, mockTask, null); + RequestContext context = new RequestContext(mockParams, null, null, mockTask, null, null); assertEquals(mockTask, context.getTask()); } @Test public void testGetUserInputNoParams() { - RequestContext context = new RequestContext(null, null, null, null, null); + RequestContext context = new RequestContext(null, null, null, null, null, null); assertEquals("", context.getUserInput(null)); } @@ -114,7 +114,7 @@ public void testGetUserInputNoParams() { public void testAttachRelatedTask() { var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); - RequestContext context = new RequestContext(null, null, null, null, null); + RequestContext context = new RequestContext(null, null, null, null, null, null); assertEquals(0, context.getRelatedTasks().size()); context.attachRelatedTask(mockTask); @@ -133,7 +133,7 @@ public void testCheckOrGenerateTaskIdWithExistingTaskId() { var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId(existingId).build(); var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); - RequestContext context = new RequestContext(mockParams, null, null, null, null); + RequestContext context = new RequestContext(mockParams, null, null, null, null, null); assertEquals(existingId, context.getTaskId()); assertEquals(existingId, mockParams.message().getTaskId()); @@ -146,7 +146,7 @@ public void testCheckOrGenerateContextIdWithExistingContextId() { var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).contextId(existingId).build(); var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); - RequestContext context = new RequestContext(mockParams, null, null, null, null); + RequestContext context = new RequestContext(mockParams, null, null, null, null, null); assertEquals(existingId, context.getContextId()); assertEquals(existingId, mockParams.message().getContextId()); @@ -159,7 +159,7 @@ public void testInitRaisesErrorOnTaskIdMismatch() { var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); InvalidParamsError error = assertThrows(InvalidParamsError.class, () -> - new RequestContext(mockParams, "wrong-task-id", null, mockTask, null)); + new RequestContext(mockParams, "wrong-task-id", null, mockTask, null, null)); assertTrue(error.getMessage().contains("bad task id")); } @@ -171,7 +171,7 @@ public void testInitRaisesErrorOnContextIdMismatch() { var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); InvalidParamsError error = assertThrows(InvalidParamsError.class, () -> - new RequestContext(mockParams, mockTask.getId(), "wrong-context-id", mockTask, null)); + new RequestContext(mockParams, mockTask.getId(), "wrong-context-id", mockTask, null, null)); assertTrue(error.getMessage().contains("bad context id")); } @@ -184,7 +184,7 @@ public void testWithRelatedTasksProvided() { relatedTasks.add(mockTask); relatedTasks.add(mock(Task.class)); - RequestContext context = new RequestContext(null, null, null, null, relatedTasks); + RequestContext context = new RequestContext(null, null, null, null, relatedTasks, null); assertEquals(relatedTasks, context.getRelatedTasks()); assertEquals(2, context.getRelatedTasks().size()); @@ -192,7 +192,7 @@ public void testWithRelatedTasksProvided() { @Test public void testMessagePropertyWithoutParams() { - RequestContext context = new RequestContext(null, null, null, null, null); + RequestContext context = new RequestContext(null, null, null, null, null, null); assertNull(context.getMessage()); } @@ -201,7 +201,7 @@ public void testMessagePropertyWithParams() { var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).build(); var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); - RequestContext context = new RequestContext(mockParams, null, null, null, null); + RequestContext context = new RequestContext(mockParams, null, null, null, null, null); assertEquals(mockParams.message(), context.getMessage()); } @@ -214,7 +214,7 @@ public void testInitWithExistingIdsInMessage() { .taskId(existingTaskId).contextId(existingContextId).build(); var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); - RequestContext context = new RequestContext(mockParams, null, null, null, null); + RequestContext context = new RequestContext(mockParams, null, null, null, null, null); assertEquals(existingTaskId, context.getTaskId()); assertEquals(existingContextId, context.getContextId()); @@ -227,7 +227,7 @@ public void testInitWithTaskIdAndExistingTaskIdMatch() { var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); - RequestContext context = new RequestContext(mockParams, mockTask.getId(), null, mockTask, null); + RequestContext context = new RequestContext(mockParams, mockTask.getId(), null, mockTask, null, null); assertEquals(mockTask.getId(), context.getTaskId()); assertEquals(mockTask, context.getTask()); @@ -240,7 +240,7 @@ public void testInitWithContextIdAndExistingContextIdMatch() { var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); - RequestContext context = new RequestContext(mockParams, mockTask.getId(), mockTask.getContextId(), mockTask, null); + RequestContext context = new RequestContext(mockParams, mockTask.getId(), mockTask.getContextId(), mockTask, null, null); assertEquals(mockTask.getContextId(), context.getContextId()); assertEquals(mockTask, context.getTask()); diff --git a/sdk-server-common/src/test/java/io/a2a/server/requesthandlers/JSONRPCHandlerTest.java b/sdk-server-common/src/test/java/io/a2a/server/requesthandlers/JSONRPCHandlerTest.java index 49913c0b6..f67fa87cf 100644 --- a/sdk-server-common/src/test/java/io/a2a/server/requesthandlers/JSONRPCHandlerTest.java +++ b/sdk-server-common/src/test/java/io/a2a/server/requesthandlers/JSONRPCHandlerTest.java @@ -12,6 +12,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; @@ -21,19 +22,22 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; -import io.a2a.spec.InternalError; import jakarta.enterprise.context.Dependent; import io.a2a.http.A2AHttpClient; import io.a2a.http.A2AHttpResponse; +import io.a2a.server.ServerCallContext; import io.a2a.server.agentexecution.AgentExecutor; import io.a2a.server.agentexecution.RequestContext; +import io.a2a.server.auth.UnauthenticatedUser; import io.a2a.server.events.EventConsumer; import io.a2a.server.events.EventQueue; import io.a2a.server.events.InMemoryQueueManager; -import io.a2a.server.tasks.InMemoryPushNotifier; +import io.a2a.server.tasks.BasePushNotificationSender; +import io.a2a.server.tasks.InMemoryPushNotificationConfigStore; import io.a2a.server.tasks.InMemoryTaskStore; -import io.a2a.server.tasks.PushNotifier; +import io.a2a.server.tasks.PushNotificationConfigStore; +import io.a2a.server.tasks.PushNotificationSender; import io.a2a.server.tasks.ResultAggregator; import io.a2a.server.tasks.TaskStore; import io.a2a.server.tasks.TaskUpdater; @@ -42,7 +46,11 @@ import io.a2a.spec.Artifact; import io.a2a.spec.CancelTaskRequest; import io.a2a.spec.CancelTaskResponse; +import io.a2a.spec.DeleteTaskPushNotificationConfigParams; +import io.a2a.spec.DeleteTaskPushNotificationConfigRequest; +import io.a2a.spec.DeleteTaskPushNotificationConfigResponse; import io.a2a.spec.Event; +import io.a2a.spec.GetTaskPushNotificationConfigParams; import io.a2a.spec.GetTaskPushNotificationConfigRequest; import io.a2a.spec.GetTaskPushNotificationConfigResponse; import io.a2a.spec.GetTaskRequest; @@ -50,9 +58,13 @@ import io.a2a.spec.InternalError; import io.a2a.spec.InvalidRequestError; import io.a2a.spec.JSONRPCError; +import io.a2a.spec.ListTaskPushNotificationConfigParams; +import io.a2a.spec.ListTaskPushNotificationConfigRequest; +import io.a2a.spec.ListTaskPushNotificationConfigResponse; import io.a2a.spec.Message; import io.a2a.spec.MessageSendParams; import io.a2a.spec.PushNotificationConfig; +import io.a2a.spec.PushNotificationNotSupportedError; import io.a2a.spec.SendMessageRequest; import io.a2a.spec.SendMessageResponse; import io.a2a.spec.SendStreamingMessageRequest; @@ -108,6 +120,8 @@ public class JSONRPCHandlerTest { private final Executor internalExecutor = Executors.newCachedThreadPool(); + private final ServerCallContext callContext = new ServerCallContext(UnauthenticatedUser.INSTANCE, Map.of("foo", "bar")); + @BeforeEach public void init() { @@ -130,9 +144,10 @@ public void cancel(RequestContext context, EventQueue eventQueue) throws JSONRPC taskStore = new InMemoryTaskStore(); queueManager = new InMemoryQueueManager(); httpClient = new TestHttpClient(); - PushNotifier pushNotifier = new InMemoryPushNotifier(httpClient); + PushNotificationConfigStore pushConfigStore = new InMemoryPushNotificationConfigStore(); + PushNotificationSender pushSender = new BasePushNotificationSender(pushConfigStore, httpClient); - requestHandler = new DefaultRequestHandler(executor, taskStore, queueManager, pushNotifier, internalExecutor); + requestHandler = new DefaultRequestHandler(executor, taskStore, queueManager, pushConfigStore, pushSender, internalExecutor); } @AfterEach @@ -146,7 +161,7 @@ public void testOnGetTaskSuccess() throws Exception { JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); taskStore.save(MINIMAL_TASK); GetTaskRequest request = new GetTaskRequest("1", new TaskQueryParams(MINIMAL_TASK.getId())); - GetTaskResponse response = handler.onGetTask(request); + GetTaskResponse response = handler.onGetTask(request, callContext); assertEquals(request.getId(), response.getId()); assertSame(MINIMAL_TASK, response.getResult()); assertNull(response.getError()); @@ -156,7 +171,7 @@ public void testOnGetTaskSuccess() throws Exception { public void testOnGetTaskNotFound() throws Exception { JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); GetTaskRequest request = new GetTaskRequest("1", new TaskQueryParams(MINIMAL_TASK.getId())); - GetTaskResponse response = handler.onGetTask(request); + GetTaskResponse response = handler.onGetTask(request, callContext); assertEquals(request.getId(), response.getId()); assertInstanceOf(TaskNotFoundError.class, response.getError()); assertNull(response.getResult()); @@ -177,7 +192,7 @@ public void testOnCancelTaskSuccess() throws Exception { }; CancelTaskRequest request = new CancelTaskRequest("111", new TaskIdParams(MINIMAL_TASK.getId())); - CancelTaskResponse response = handler.onCancelTask(request); + CancelTaskResponse response = handler.onCancelTask(request, callContext); assertNull(response.getError()); assertEquals(request.getId(), response.getId()); @@ -197,7 +212,7 @@ public void testOnCancelTaskNotSupported() { }; CancelTaskRequest request = new CancelTaskRequest("1", new TaskIdParams(MINIMAL_TASK.getId())); - CancelTaskResponse response = handler.onCancelTask(request); + CancelTaskResponse response = handler.onCancelTask(request, callContext); assertEquals(request.getId(), response.getId()); assertNull(response.getResult()); assertInstanceOf(UnsupportedOperationError.class, response.getError()); @@ -207,7 +222,7 @@ public void testOnCancelTaskNotSupported() { public void testOnCancelTaskNotFound() { JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); CancelTaskRequest request = new CancelTaskRequest("1", new TaskIdParams(MINIMAL_TASK.getId())); - CancelTaskResponse response = handler.onCancelTask(request); + CancelTaskResponse response = handler.onCancelTask(request, callContext); assertEquals(request.getId(), response.getId()); assertNull(response.getResult()); assertInstanceOf(TaskNotFoundError.class, response.getError()); @@ -224,7 +239,7 @@ public void testOnMessageNewMessageSuccess() { .contextId(MINIMAL_TASK.getContextId()) .build(); SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); - SendMessageResponse response = handler.onMessageSend(request); + SendMessageResponse response = handler.onMessageSend(request, callContext); assertNull(response.getError()); // The Python implementation returns a Task here, but then again they are using hardcoded mocks and // bypassing the whole EventQueue. @@ -249,7 +264,7 @@ public void testOnMessageNewMessageSuccessMocks() { try (MockedConstruction mocked = Mockito.mockConstruction( EventConsumer.class, (mock, context) -> {Mockito.doReturn(ZeroPublisher.fromItems(MINIMAL_TASK)).when(mock).consumeAll();})){ - response = handler.onMessageSend(request); + response = handler.onMessageSend(request, callContext); } assertNull(response.getError()); assertSame(MINIMAL_TASK, response.getResult()); @@ -267,7 +282,7 @@ public void testOnMessageNewMessageWithExistingTaskSuccess() { .contextId(MINIMAL_TASK.getContextId()) .build(); SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); - SendMessageResponse response = handler.onMessageSend(request); + SendMessageResponse response = handler.onMessageSend(request, callContext); assertNull(response.getError()); // The Python implementation returns a Task here, but then again they are using hardcoded mocks and // bypassing the whole EventQueue. @@ -293,7 +308,7 @@ public void testOnMessageNewMessageWithExistingTaskSuccessMocks() { EventConsumer.class, (mock, context) -> { Mockito.doReturn(ZeroPublisher.fromItems(MINIMAL_TASK)).when(mock).consumeAll();})){ - response = handler.onMessageSend(request); + response = handler.onMessageSend(request, callContext); } assertNull(response.getError()); assertSame(MINIMAL_TASK, response.getResult()); @@ -314,7 +329,7 @@ public void testOnMessageError() { .build(); SendMessageRequest request = new SendMessageRequest( "1", new MessageSendParams(message, null, null)); - SendMessageResponse response = handler.onMessageSend(request); + SendMessageResponse response = handler.onMessageSend(request, callContext); assertInstanceOf(UnsupportedOperationError.class, response.getError()); assertNull(response.getResult()); } @@ -333,7 +348,7 @@ public void testOnMessageErrorMocks() { EventConsumer.class, (mock, context) -> { Mockito.doReturn(ZeroPublisher.fromItems(new UnsupportedOperationError())).when(mock).consumeAll();})){ - response = handler.onMessageSend(request); + response = handler.onMessageSend(request, callContext); } assertInstanceOf(UnsupportedOperationError.class, response.getError()); @@ -354,7 +369,7 @@ public void testOnMessageStreamNewMessageSuccess() { SendStreamingMessageRequest request = new SendStreamingMessageRequest( "1", new MessageSendParams(message, null, null)); - Flow.Publisher response = handler.onMessageSendStream(request); + Flow.Publisher response = handler.onMessageSendStream(request, callContext); List results = new ArrayList<>(); CountDownLatch latch = new CountDownLatch(1); @@ -428,7 +443,7 @@ public void testOnMessageStreamNewMessageSuccessMocks() { EventConsumer.class, (mock, context) -> { Mockito.doReturn(ZeroPublisher.fromIterable(events)).when(mock).consumeAll();})){ - response = handler.onMessageSendStream(request); + response = handler.onMessageSendStream(request, callContext); } List results = new ArrayList<>(); @@ -482,7 +497,7 @@ public void testOnMessageStreamNewMessageExistingTaskSuccess() throws Exception SendStreamingMessageRequest request = new SendStreamingMessageRequest( "1", new MessageSendParams(message, null, null)); - Flow.Publisher response = handler.onMessageSendStream(request); + Flow.Publisher response = handler.onMessageSendStream(request, callContext); // This Publisher never completes so we subscribe in a new thread. // I _think_ that is as expected, and testOnMessageStreamNewMessageSendPushNotificationSuccess seems @@ -573,7 +588,7 @@ public void testOnMessageStreamNewMessageExistingTaskSuccessMocks() { EventConsumer.class, (mock, context) -> { Mockito.doReturn(ZeroPublisher.fromIterable(events)).when(mock).consumeAll();})){ - response = handler.onMessageSendStream(request); + response = handler.onMessageSendStream(request, callContext); } List results = new ArrayList<>(); @@ -612,7 +627,7 @@ public void onComplete() { @Test - public void testSetPushNotificationSuccess() { + public void testSetPushNotificationConfigSuccess() { JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); taskStore.save(MINIMAL_TASK); @@ -620,12 +635,12 @@ public void testSetPushNotificationSuccess() { new TaskPushNotificationConfig( MINIMAL_TASK.getId(), new PushNotificationConfig.Builder().url("http://example.com").build()); SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); - SetTaskPushNotificationConfigResponse response = handler.setPushNotification(request); + SetTaskPushNotificationConfigResponse response = handler.setPushNotificationConfig(request, callContext); assertSame(taskPushConfig, response.getResult()); } @Test - public void testGetPushNotificationSuccess() { + public void testGetPushNotificationConfigSuccess() { JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); taskStore.save(MINIMAL_TASK); agentExecutorExecute = (context, eventQueue) -> { @@ -638,13 +653,15 @@ public void testGetPushNotificationSuccess() { MINIMAL_TASK.getId(), new PushNotificationConfig.Builder().url("http://example.com").build()); SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); - handler.setPushNotification(request); + handler.setPushNotificationConfig(request, callContext); GetTaskPushNotificationConfigRequest getRequest = - new GetTaskPushNotificationConfigRequest("111", new TaskIdParams(MINIMAL_TASK.getId())); - GetTaskPushNotificationConfigResponse getResponse = handler.getPushNotification(getRequest); + new GetTaskPushNotificationConfigRequest("111", new GetTaskPushNotificationConfigParams(MINIMAL_TASK.getId())); + GetTaskPushNotificationConfigResponse getResponse = handler.getPushNotificationConfig(getRequest, callContext); - assertEquals(taskPushConfig, getResponse.getResult()); + TaskPushNotificationConfig expectedConfig = new TaskPushNotificationConfig(MINIMAL_TASK.getId(), + new PushNotificationConfig.Builder().id(MINIMAL_TASK.getId()).url("http://example.com").build()); + assertEquals(expectedConfig, getResponse.getResult()); } @Test @@ -681,14 +698,14 @@ public void testOnMessageStreamNewMessageSendPushNotificationSuccess() throws Ex MINIMAL_TASK.getId(), new PushNotificationConfig.Builder().url("http://example.com").build()); SetTaskPushNotificationConfigRequest stpnRequest = new SetTaskPushNotificationConfigRequest("1", config); - SetTaskPushNotificationConfigResponse stpnResponse = handler.setPushNotification(stpnRequest); + SetTaskPushNotificationConfigResponse stpnResponse = handler.setPushNotificationConfig(stpnRequest, callContext); assertNull(stpnResponse.getError()); Message msg = new Message.Builder(MESSAGE) .taskId(MINIMAL_TASK.getId()) .build(); SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(msg, null, null)); - Flow.Publisher response = handler.onMessageSendStream(request); + Flow.Publisher response = handler.onMessageSendStream(request, callContext); final List results = Collections.synchronizedList(new ArrayList<>()); final AtomicReference subscriptionRef = new AtomicReference<>(); @@ -765,7 +782,7 @@ public void testOnResubscribeExistingTaskSuccess() { }; TaskResubscriptionRequest request = new TaskResubscriptionRequest("1", new TaskIdParams(MINIMAL_TASK.getId())); - Flow.Publisher response = handler.onResubscribeToTask(request); + Flow.Publisher response = handler.onResubscribeToTask(request, callContext); // We need to send some events in order for those to end up in the queue Message message = new Message.Builder() @@ -775,7 +792,9 @@ public void testOnResubscribeExistingTaskSuccess() { .parts(new TextPart("text")) .build(); SendMessageResponse smr = - handler.onMessageSend(new SendMessageRequest("1", new MessageSendParams(message, null, null))); + handler.onMessageSend( + new SendMessageRequest("1", new MessageSendParams(message, null, null)), + callContext); assertNull(smr.getError()); @@ -841,7 +860,7 @@ public void testOnResubscribeExistingTaskSuccessMocks() throws Exception { EventConsumer.class, (mock, context) -> { Mockito.doReturn(ZeroPublisher.fromIterable(events)).when(mock).consumeAll();})){ - response = handler.onResubscribeToTask(request); + response = handler.onResubscribeToTask(request, callContext); } List results = new ArrayList<>(); @@ -886,7 +905,7 @@ public void testOnResubscribeNoExistingTaskError() { JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); TaskResubscriptionRequest request = new TaskResubscriptionRequest("1", new TaskIdParams(MINIMAL_TASK.getId())); - Flow.Publisher response = handler.onResubscribeToTask(request); + Flow.Publisher response = handler.onResubscribeToTask(request, callContext); List results = new ArrayList<>(); AtomicReference error = new AtomicReference<>(); @@ -934,7 +953,7 @@ public void testStreamingNotSupportedError() { .message(MESSAGE) .build()) .build(); - Flow.Publisher response = handler.onMessageSendStream(request); + Flow.Publisher response = handler.onMessageSendStream(request, callContext); List results = new ArrayList<>(); AtomicReference error = new AtomicReference<>(); @@ -980,7 +999,7 @@ public void testStreamingNotSupportedErrorOnResubscribeToTask() { JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler); TaskResubscriptionRequest request = new TaskResubscriptionRequest("1", new TaskIdParams(MINIMAL_TASK.getId())); - Flow.Publisher response = handler.onResubscribeToTask(request); + Flow.Publisher response = handler.onResubscribeToTask(request, callContext); List results = new ArrayList<>(); AtomicReference error = new AtomicReference<>(); @@ -1036,24 +1055,23 @@ public void testPushNotificationsNotSupportedError() { SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest.Builder() .params(config) .build(); - SetTaskPushNotificationConfigResponse response = handler.setPushNotification(request); - assertInstanceOf(InvalidRequestError.class, response.getError()); - assertEquals("Push notifications are not supported by the agent", response.getError().getMessage()); + SetTaskPushNotificationConfigResponse response = handler.setPushNotificationConfig(request, callContext); + assertInstanceOf(PushNotificationNotSupportedError.class, response.getError()); } @Test - public void testOnGetPushNotificationNoPushNotifier() { + public void testOnGetPushNotificationNoPushNotifierConfig() { // Create request handler without a push notifier DefaultRequestHandler requestHandler = - new DefaultRequestHandler(executor, taskStore, queueManager, null, internalExecutor); + new DefaultRequestHandler(executor, taskStore, queueManager, null, null, internalExecutor); AgentCard card = createAgentCard(false, true, false); JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler); taskStore.save(MINIMAL_TASK); GetTaskPushNotificationConfigRequest request = - new GetTaskPushNotificationConfigRequest("id", new TaskIdParams(MINIMAL_TASK.getId())); - GetTaskPushNotificationConfigResponse response = handler.getPushNotification(request); + new GetTaskPushNotificationConfigRequest("id", new GetTaskPushNotificationConfigParams(MINIMAL_TASK.getId())); + GetTaskPushNotificationConfigResponse response = handler.getPushNotificationConfig(request, callContext); assertNotNull(response.getError()); assertInstanceOf(UnsupportedOperationError.class, response.getError()); @@ -1061,10 +1079,10 @@ public void testOnGetPushNotificationNoPushNotifier() { } @Test - public void testOnSetPushNotificationNoPushNotifier() { + public void testOnSetPushNotificationNoPushNotifierConfig() { // Create request handler without a push notifier DefaultRequestHandler requestHandler = - new DefaultRequestHandler(executor, taskStore, queueManager, null, internalExecutor); + new DefaultRequestHandler(executor, taskStore, queueManager, null, null, internalExecutor); AgentCard card = createAgentCard(false, true, false); JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler); @@ -1080,7 +1098,7 @@ public void testOnSetPushNotificationNoPushNotifier() { SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest.Builder() .params(config) .build(); - SetTaskPushNotificationConfigResponse response = handler.setPushNotification(request); + SetTaskPushNotificationConfigResponse response = handler.setPushNotificationConfig(request, callContext); assertInstanceOf(UnsupportedOperationError.class, response.getError()); assertEquals("This operation is not supported", response.getError().getMessage()); @@ -1089,12 +1107,13 @@ public void testOnSetPushNotificationNoPushNotifier() { @Test public void testOnMessageSendInternalError() { DefaultRequestHandler mocked = Mockito.mock(DefaultRequestHandler.class); - Mockito.doThrow(new InternalError("Internal Error")).when(mocked).onMessageSend(Mockito.any(MessageSendParams.class)); + Mockito.doThrow(new InternalError("Internal Error")).when(mocked) + .onMessageSend(Mockito.any(MessageSendParams.class), Mockito.any(ServerCallContext.class)); JSONRPCHandler handler = new JSONRPCHandler(CARD, mocked); SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(MESSAGE, null, null)); - SendMessageResponse response = handler.onMessageSend(request); + SendMessageResponse response = handler.onMessageSend(request, callContext); assertInstanceOf(InternalError.class, response.getError()); } @@ -1102,12 +1121,13 @@ public void testOnMessageSendInternalError() { @Test public void testOnMessageStreamInternalError() { DefaultRequestHandler mocked = Mockito.mock(DefaultRequestHandler.class); - Mockito.doThrow(new InternalError("Internal Error")).when(mocked).onMessageSendStream(Mockito.any(MessageSendParams.class)); + Mockito.doThrow(new InternalError("Internal Error")).when(mocked) + .onMessageSendStream(Mockito.any(MessageSendParams.class), Mockito.any(ServerCallContext.class)); JSONRPCHandler handler = new JSONRPCHandler(CARD, mocked); SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(MESSAGE, null, null)); - Flow.Publisher response = handler.onMessageSendStream(request); + Flow.Publisher response = handler.onMessageSendStream(request, callContext); List results = new ArrayList<>(); @@ -1153,7 +1173,7 @@ public void testDefaultRequestHandlerWithCustomComponents() { @Test public void testOnMessageSendErrorHandling() { DefaultRequestHandler requestHandler = - new DefaultRequestHandler(executor, taskStore, queueManager, null, internalExecutor); + new DefaultRequestHandler(executor, taskStore, queueManager, null, null, internalExecutor); AgentCard card = createAgentCard(false, true, false); JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler); @@ -1173,7 +1193,7 @@ public void testOnMessageSendErrorHandling() { Mockito.doThrow( new UnsupportedOperationError()) .when(mock).consumeAndBreakOnInterrupt(Mockito.any(EventConsumer.class)))){ - response = handler.onMessageSend(request); + response = handler.onMessageSend(request, callContext); } assertInstanceOf(UnsupportedOperationError.class, response.getError()); @@ -1190,7 +1210,7 @@ public void testOnMessageSendTaskIdMismatch() { }); SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(MESSAGE, null, null)); - SendMessageResponse response = handler.onMessageSend(request); + SendMessageResponse response = handler.onMessageSend(request, callContext); assertInstanceOf(InternalError.class, response.getError()); } @@ -1205,7 +1225,7 @@ public void testOnMessageStreamTaskIdMismatch() { }); SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(MESSAGE, null, null)); - Flow.Publisher response = handler.onMessageSendStream(request); + Flow.Publisher response = handler.onMessageSendStream(request, callContext); List results = new ArrayList<>(); AtomicReference error = new AtomicReference<>(); @@ -1241,6 +1261,181 @@ public void onComplete() { assertInstanceOf(InternalError.class, results.get(0).getError()); } + @Test + public void testListPushNotificationConfig() { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + agentExecutorExecute = (context, eventQueue) -> { + eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage()); + }; + + TaskPushNotificationConfig taskPushConfig = + new TaskPushNotificationConfig( + MINIMAL_TASK.getId(), new PushNotificationConfig.Builder() + .url("http://example.com") + .id(MINIMAL_TASK.getId()) + .build()); + SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); + handler.setPushNotificationConfig(request, callContext); + + ListTaskPushNotificationConfigRequest listRequest = + new ListTaskPushNotificationConfigRequest("111", new ListTaskPushNotificationConfigParams(MINIMAL_TASK.getId())); + ListTaskPushNotificationConfigResponse listResponse = handler.listPushNotificationConfig(listRequest, callContext); + + assertEquals("111", listResponse.getId()); + assertEquals(1, listResponse.getResult().size()); + assertEquals(taskPushConfig, listResponse.getResult().get(0)); + } + + @Test + public void testListPushNotificationConfigNotSupported() { + AgentCard card = createAgentCard(true, false, true); + JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler); + taskStore.save(MINIMAL_TASK); + agentExecutorExecute = (context, eventQueue) -> { + eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage()); + }; + + TaskPushNotificationConfig taskPushConfig = + new TaskPushNotificationConfig( + MINIMAL_TASK.getId(), new PushNotificationConfig.Builder() + .url("http://example.com") + .id(MINIMAL_TASK.getId()) + .build()); + SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); + handler.setPushNotificationConfig(request, callContext); + + ListTaskPushNotificationConfigRequest listRequest = + new ListTaskPushNotificationConfigRequest("111", new ListTaskPushNotificationConfigParams(MINIMAL_TASK.getId())); + ListTaskPushNotificationConfigResponse listResponse = + handler.listPushNotificationConfig(listRequest, callContext); + + assertEquals("111", listResponse.getId()); + assertNull(listResponse.getResult()); + assertInstanceOf(PushNotificationNotSupportedError.class, listResponse.getError()); + } + + @Test + public void testListPushNotificationConfigNoPushConfigStore() { + DefaultRequestHandler requestHandler = + new DefaultRequestHandler(executor, taskStore, queueManager, null, null, internalExecutor); + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + agentExecutorExecute = (context, eventQueue) -> { + eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage()); + }; + + ListTaskPushNotificationConfigRequest listRequest = + new ListTaskPushNotificationConfigRequest("111", new ListTaskPushNotificationConfigParams(MINIMAL_TASK.getId())); + ListTaskPushNotificationConfigResponse listResponse = + handler.listPushNotificationConfig(listRequest, callContext); + + assertEquals("111", listResponse.getId()); + assertNull(listResponse.getResult()); + assertInstanceOf(UnsupportedOperationError.class, listResponse.getError()); + } + + @Test + public void testListPushNotificationConfigTaskNotFound() { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + agentExecutorExecute = (context, eventQueue) -> { + eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage()); + }; + + ListTaskPushNotificationConfigRequest listRequest = + new ListTaskPushNotificationConfigRequest("111", new ListTaskPushNotificationConfigParams(MINIMAL_TASK.getId())); + ListTaskPushNotificationConfigResponse listResponse = + handler.listPushNotificationConfig(listRequest, callContext); + + assertEquals("111", listResponse.getId()); + assertNull(listResponse.getResult()); + assertInstanceOf(TaskNotFoundError.class, listResponse.getError()); + } + + @Test + public void testDeletePushNotificationConfig() { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + agentExecutorExecute = (context, eventQueue) -> { + eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage()); + }; + + TaskPushNotificationConfig taskPushConfig = + new TaskPushNotificationConfig( + MINIMAL_TASK.getId(), new PushNotificationConfig.Builder() + .url("http://example.com") + .id(MINIMAL_TASK.getId()) + .build()); + SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); + handler.setPushNotificationConfig(request, callContext); + + DeleteTaskPushNotificationConfigRequest deleteRequest = + new DeleteTaskPushNotificationConfigRequest("111", new DeleteTaskPushNotificationConfigParams(MINIMAL_TASK.getId(), MINIMAL_TASK.getId())); + DeleteTaskPushNotificationConfigResponse deleteResponse = + handler.deletePushNotificationConfig(deleteRequest, callContext); + + assertEquals("111", deleteResponse.getId()); + assertNull(deleteResponse.getError()); + assertNull(deleteResponse.getResult()); + } + + @Test + public void testDeletePushNotificationConfigNotSupported() { + AgentCard card = createAgentCard(true, false, true); + JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler); + taskStore.save(MINIMAL_TASK); + agentExecutorExecute = (context, eventQueue) -> { + eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage()); + }; + + TaskPushNotificationConfig taskPushConfig = + new TaskPushNotificationConfig( + MINIMAL_TASK.getId(), new PushNotificationConfig.Builder() + .url("http://example.com") + .id(MINIMAL_TASK.getId()) + .build()); + SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); + handler.setPushNotificationConfig(request, callContext); + + DeleteTaskPushNotificationConfigRequest deleteRequest = + new DeleteTaskPushNotificationConfigRequest("111", new DeleteTaskPushNotificationConfigParams(MINIMAL_TASK.getId(), MINIMAL_TASK.getId())); + DeleteTaskPushNotificationConfigResponse deleteResponse = + handler.deletePushNotificationConfig(deleteRequest, callContext); + + assertEquals("111", deleteResponse.getId()); + assertNull(deleteResponse.getResult()); + assertInstanceOf(PushNotificationNotSupportedError.class, deleteResponse.getError()); + } + + @Test + public void testDeletePushNotificationConfigNoPushConfigStore() { + DefaultRequestHandler requestHandler = + new DefaultRequestHandler(executor, taskStore, queueManager, null, null, internalExecutor); + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + agentExecutorExecute = (context, eventQueue) -> { + eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage()); + }; + + TaskPushNotificationConfig taskPushConfig = + new TaskPushNotificationConfig( + MINIMAL_TASK.getId(), new PushNotificationConfig.Builder() + .url("http://example.com") + .id(MINIMAL_TASK.getId()) + .build()); + SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); + handler.setPushNotificationConfig(request, callContext); + + DeleteTaskPushNotificationConfigRequest deleteRequest = + new DeleteTaskPushNotificationConfigRequest("111", new DeleteTaskPushNotificationConfigParams(MINIMAL_TASK.getId(), MINIMAL_TASK.getId())); + DeleteTaskPushNotificationConfigResponse deleteResponse = + handler.deletePushNotificationConfig(deleteRequest, callContext); + + assertEquals("111", deleteResponse.getId()); + assertNull(deleteResponse.getResult()); + assertInstanceOf(UnsupportedOperationError.class, deleteResponse.getError()); + } + private static AgentCard createAgentCard(boolean streaming, boolean pushNotifications, boolean stateTransitionHistory) { return new AgentCard.Builder() .name("test-card") @@ -1256,6 +1451,7 @@ private static AgentCard createAgentCard(boolean streaming, boolean pushNotifica .defaultInputModes(new ArrayList<>()) .defaultOutputModes(new ArrayList<>()) .skills(new ArrayList<>()) + .protocolVersion("0.2.5") .build(); } diff --git a/spec/pom.xml b/spec/pom.xml index cddec7a1c..ce67def18 100644 --- a/spec/pom.xml +++ b/spec/pom.xml @@ -7,7 +7,7 @@ io.github.a2asdk a2a-java-sdk-parent - 0.2.3.Beta2-SNAPSHOT + 0.2.6.Beta1-SNAPSHOT a2a-java-sdk-spec diff --git a/spec/src/main/java/io/a2a/spec/AgentCard.java b/spec/src/main/java/io/a2a/spec/AgentCard.java index 8429b16f5..e394fb626 100644 --- a/spec/src/main/java/io/a2a/spec/AgentCard.java +++ b/spec/src/main/java/io/a2a/spec/AgentCard.java @@ -17,7 +17,8 @@ public record AgentCard(String name, String description, String url, AgentProvid String version, String documentationUrl, AgentCapabilities capabilities, List defaultInputModes, List defaultOutputModes, List skills, boolean supportsAuthenticatedExtendedCard, Map securitySchemes, - List>> security, String iconUrl) { + List>> security, String iconUrl, List additionalInterfaces, + String preferredTransport, String protocolVersion) { private static final String TEXT_MODE = "text"; @@ -30,6 +31,7 @@ public record AgentCard(String name, String description, String url, AgentProvid Assert.checkNotNullParam("skills", skills); Assert.checkNotNullParam("url", url); Assert.checkNotNullParam("version", version); + Assert.checkNotNullParam("protocolVersion", protocolVersion); } public static class Builder { @@ -47,6 +49,9 @@ public static class Builder { private Map securitySchemes; private List>> security; private String iconUrl; + private List additionalInterfaces; + String preferredTransport; + String protocolVersion; public Builder name(String name) { this.name = name; @@ -118,10 +123,26 @@ public Builder iconUrl(String iconUrl) { return this; } + public Builder additionalInterfaces(List additionalInterfaces) { + this.additionalInterfaces = additionalInterfaces; + return this; + } + + public Builder preferredTransport(String preferredTransport) { + this.preferredTransport = preferredTransport; + return this; + } + + public Builder protocolVersion(String protocolVersion) { + this.protocolVersion = protocolVersion; + return this; + } + public AgentCard build() { return new AgentCard(name, description, url, provider, version, documentationUrl, capabilities, defaultInputModes, defaultOutputModes, skills, - supportsAuthenticatedExtendedCard, securitySchemes, security, iconUrl); + supportsAuthenticatedExtendedCard, securitySchemes, security, iconUrl, + additionalInterfaces, preferredTransport, protocolVersion); } } } diff --git a/spec/src/main/java/io/a2a/spec/AgentInterface.java b/spec/src/main/java/io/a2a/spec/AgentInterface.java new file mode 100644 index 000000000..ab2b7307d --- /dev/null +++ b/spec/src/main/java/io/a2a/spec/AgentInterface.java @@ -0,0 +1,18 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import io.a2a.util.Assert; + +/** + * Provides a declaration of the target url and the supported transport to interact with the agent. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record AgentInterface(String transport, String url) { + + public AgentInterface { + Assert.checkNotNullParam("transport", transport); + Assert.checkNotNullParam("url", url); + } +} diff --git a/spec/src/main/java/io/a2a/spec/CancelTaskResponse.java b/spec/src/main/java/io/a2a/spec/CancelTaskResponse.java index 02bd63461..9ef775118 100644 --- a/spec/src/main/java/io/a2a/spec/CancelTaskResponse.java +++ b/spec/src/main/java/io/a2a/spec/CancelTaskResponse.java @@ -15,7 +15,7 @@ public final class CancelTaskResponse extends JSONRPCResponse { @JsonCreator public CancelTaskResponse(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, @JsonProperty("result") Task result, @JsonProperty("error") JSONRPCError error) { - super(jsonrpc, id, result, error); + super(jsonrpc, id, result, error, Task.class); } public CancelTaskResponse(Object id, JSONRPCError error) { diff --git a/spec/src/main/java/io/a2a/spec/DeleteTaskPushNotificationConfigParams.java b/spec/src/main/java/io/a2a/spec/DeleteTaskPushNotificationConfigParams.java new file mode 100644 index 000000000..a64421a4c --- /dev/null +++ b/spec/src/main/java/io/a2a/spec/DeleteTaskPushNotificationConfigParams.java @@ -0,0 +1,50 @@ +package io.a2a.spec; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; + +import io.a2a.util.Assert; + +/** + * Parameters for removing pushNotificationConfiguration associated with a Task. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record DeleteTaskPushNotificationConfigParams(String id, String pushNotificationConfigId, Map metadata) { + + public DeleteTaskPushNotificationConfigParams { + Assert.checkNotNullParam("id", id); + Assert.checkNotNullParam("pushNotificationConfigId", pushNotificationConfigId); + } + + public DeleteTaskPushNotificationConfigParams(String id, String pushNotificationConfigId) { + this(id, pushNotificationConfigId, null); + } + + public static class Builder { + String id; + String pushNotificationConfigId; + Map metadata; + + public Builder id(String id) { + this.id = id; + return this; + } + + public Builder pushNotificationConfigId(String pushNotificationConfigId) { + this.pushNotificationConfigId = pushNotificationConfigId; + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public DeleteTaskPushNotificationConfigParams build() { + return new DeleteTaskPushNotificationConfigParams(id, pushNotificationConfigId, metadata); + } + } +} diff --git a/spec/src/main/java/io/a2a/spec/DeleteTaskPushNotificationConfigRequest.java b/spec/src/main/java/io/a2a/spec/DeleteTaskPushNotificationConfigRequest.java new file mode 100644 index 000000000..99f50ebfd --- /dev/null +++ b/spec/src/main/java/io/a2a/spec/DeleteTaskPushNotificationConfigRequest.java @@ -0,0 +1,77 @@ +package io.a2a.spec; + +import java.util.UUID; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import io.a2a.util.Assert; +import io.a2a.util.Utils; + +/** + * A delete task push notification config request. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class DeleteTaskPushNotificationConfigRequest extends NonStreamingJSONRPCRequest { + + public static final String METHOD = "tasks/pushNotificationConfig/delete"; + + @JsonCreator + public DeleteTaskPushNotificationConfigRequest(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, + @JsonProperty("method") String method, + @JsonProperty("params") DeleteTaskPushNotificationConfigParams params) { + if (jsonrpc != null && ! jsonrpc.equals(JSONRPC_VERSION)) { + throw new IllegalArgumentException("Invalid JSON-RPC protocol version"); + } + Assert.checkNotNullParam("method", method); + if (! method.equals(METHOD)) { + throw new IllegalArgumentException("Invalid DeleteTaskPushNotificationConfigRequest method"); + } + Assert.isNullOrStringOrInteger(id); + this.jsonrpc = Utils.defaultIfNull(jsonrpc, JSONRPC_VERSION); + this.id = id; + this.method = method; + this.params = params; + } + + public DeleteTaskPushNotificationConfigRequest(String id, DeleteTaskPushNotificationConfigParams params) { + this(null, id, METHOD, params); + } + + public static class Builder { + private String jsonrpc; + private Object id; + private String method; + private DeleteTaskPushNotificationConfigParams params; + + public Builder jsonrpc(String jsonrpc) { + this.jsonrpc = jsonrpc; + return this; + } + + public Builder id(Object id) { + this.id = id; + return this; + } + + public Builder method(String method) { + this.method = method; + return this; + } + + public Builder params(DeleteTaskPushNotificationConfigParams params) { + this.params = params; + return this; + } + + public DeleteTaskPushNotificationConfigRequest build() { + if (id == null) { + id = UUID.randomUUID().toString(); + } + return new DeleteTaskPushNotificationConfigRequest(jsonrpc, id, method, params); + } + } +} diff --git a/spec/src/main/java/io/a2a/spec/DeleteTaskPushNotificationConfigResponse.java b/spec/src/main/java/io/a2a/spec/DeleteTaskPushNotificationConfigResponse.java new file mode 100644 index 000000000..0f65b5ad5 --- /dev/null +++ b/spec/src/main/java/io/a2a/spec/DeleteTaskPushNotificationConfigResponse.java @@ -0,0 +1,32 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; + +/** + * A response for a delete task push notification config request. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonSerialize(using = JSONRPCVoidResponseSerializer.class) +public final class DeleteTaskPushNotificationConfigResponse extends JSONRPCResponse { + + @JsonCreator + public DeleteTaskPushNotificationConfigResponse(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, + @JsonProperty("result") Void result, + @JsonProperty("error") JSONRPCError error) { + super(jsonrpc, id, result, error, Void.class); + } + + public DeleteTaskPushNotificationConfigResponse(Object id, JSONRPCError error) { + this(null, id, null, error); + } + + public DeleteTaskPushNotificationConfigResponse(Object id) { + this(null, id, null, null); + } + +} diff --git a/spec/src/main/java/io/a2a/spec/GetTaskPushNotificationConfigParams.java b/spec/src/main/java/io/a2a/spec/GetTaskPushNotificationConfigParams.java new file mode 100644 index 000000000..d8952f87d --- /dev/null +++ b/spec/src/main/java/io/a2a/spec/GetTaskPushNotificationConfigParams.java @@ -0,0 +1,53 @@ +package io.a2a.spec; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; + +import io.a2a.util.Assert; + +/** + * Parameters for fetching a pushNotificationConfiguration associated with a Task. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record GetTaskPushNotificationConfigParams(String id, String pushNotificationConfigId, Map metadata) { + + public GetTaskPushNotificationConfigParams { + Assert.checkNotNullParam("id", id); + } + + public GetTaskPushNotificationConfigParams(String id) { + this(id, null, null); + } + + public GetTaskPushNotificationConfigParams(String id, String pushNotificationConfigId) { + this(id, pushNotificationConfigId, null); + } + + public static class Builder { + String id; + String pushNotificationConfigId; + Map metadata; + + public Builder id(String id) { + this.id = id; + return this; + } + + public Builder pushNotificationConfigId(String pushNotificationConfigId) { + this.pushNotificationConfigId = pushNotificationConfigId; + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public GetTaskPushNotificationConfigParams build() { + return new GetTaskPushNotificationConfigParams(id, pushNotificationConfigId, metadata); + } + } +} diff --git a/spec/src/main/java/io/a2a/spec/GetTaskPushNotificationConfigRequest.java b/spec/src/main/java/io/a2a/spec/GetTaskPushNotificationConfigRequest.java index ab39015bb..b353e0cc8 100644 --- a/spec/src/main/java/io/a2a/spec/GetTaskPushNotificationConfigRequest.java +++ b/spec/src/main/java/io/a2a/spec/GetTaskPushNotificationConfigRequest.java @@ -14,13 +14,13 @@ */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) -public final class GetTaskPushNotificationConfigRequest extends NonStreamingJSONRPCRequest { +public final class GetTaskPushNotificationConfigRequest extends NonStreamingJSONRPCRequest { public static final String METHOD = "tasks/pushNotificationConfig/get"; @JsonCreator public GetTaskPushNotificationConfigRequest(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, - @JsonProperty("method") String method, @JsonProperty("params") TaskIdParams params) { + @JsonProperty("method") String method, @JsonProperty("params") GetTaskPushNotificationConfigParams params) { if (jsonrpc != null && ! jsonrpc.equals(JSONRPC_VERSION)) { throw new IllegalArgumentException("Invalid JSON-RPC protocol version"); } @@ -35,7 +35,7 @@ public GetTaskPushNotificationConfigRequest(@JsonProperty("jsonrpc") String json this.params = params; } - public GetTaskPushNotificationConfigRequest(String id, TaskIdParams params) { + public GetTaskPushNotificationConfigRequest(String id, GetTaskPushNotificationConfigParams params) { this(null, id, METHOD, params); } @@ -43,7 +43,7 @@ public static class Builder { private String jsonrpc; private Object id; private String method; - private TaskIdParams params; + private GetTaskPushNotificationConfigParams params; public GetTaskPushNotificationConfigRequest.Builder jsonrpc(String jsonrpc) { this.jsonrpc = jsonrpc; @@ -60,7 +60,7 @@ public GetTaskPushNotificationConfigRequest.Builder method(String method) { return this; } - public GetTaskPushNotificationConfigRequest.Builder params(TaskIdParams params) { + public GetTaskPushNotificationConfigRequest.Builder params(GetTaskPushNotificationConfigParams params) { this.params = params; return this; } diff --git a/spec/src/main/java/io/a2a/spec/GetTaskPushNotificationConfigResponse.java b/spec/src/main/java/io/a2a/spec/GetTaskPushNotificationConfigResponse.java index c4340188f..116799a9e 100644 --- a/spec/src/main/java/io/a2a/spec/GetTaskPushNotificationConfigResponse.java +++ b/spec/src/main/java/io/a2a/spec/GetTaskPushNotificationConfigResponse.java @@ -16,7 +16,7 @@ public final class GetTaskPushNotificationConfigResponse extends JSONRPCResponse public GetTaskPushNotificationConfigResponse(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, @JsonProperty("result") TaskPushNotificationConfig result, @JsonProperty("error") JSONRPCError error) { - super(jsonrpc, id, result, error); + super(jsonrpc, id, result, error, TaskPushNotificationConfig.class); } public GetTaskPushNotificationConfigResponse(Object id, JSONRPCError error) { diff --git a/spec/src/main/java/io/a2a/spec/GetTaskResponse.java b/spec/src/main/java/io/a2a/spec/GetTaskResponse.java index e51cb66aa..0d27a8e68 100644 --- a/spec/src/main/java/io/a2a/spec/GetTaskResponse.java +++ b/spec/src/main/java/io/a2a/spec/GetTaskResponse.java @@ -15,7 +15,7 @@ public final class GetTaskResponse extends JSONRPCResponse { @JsonCreator public GetTaskResponse(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, @JsonProperty("result") Task result, @JsonProperty("error") JSONRPCError error) { - super(jsonrpc, id, result, error); + super(jsonrpc, id, result, error, Task.class); } public GetTaskResponse(Object id, JSONRPCError error) { diff --git a/spec/src/main/java/io/a2a/spec/JSONRPCErrorResponse.java b/spec/src/main/java/io/a2a/spec/JSONRPCErrorResponse.java index 95ac5d341..ea7846655 100644 --- a/spec/src/main/java/io/a2a/spec/JSONRPCErrorResponse.java +++ b/spec/src/main/java/io/a2a/spec/JSONRPCErrorResponse.java @@ -17,7 +17,7 @@ public final class JSONRPCErrorResponse extends JSONRPCResponse { @JsonCreator public JSONRPCErrorResponse(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, @JsonProperty("result") Void result, @JsonProperty("error") JSONRPCError error) { - super(jsonrpc, id, result, error); + super(jsonrpc, id, result, error, Void.class); Assert.checkNotNullParam("error", error); } diff --git a/spec/src/main/java/io/a2a/spec/JSONRPCRequestDeserializerBase.java b/spec/src/main/java/io/a2a/spec/JSONRPCRequestDeserializerBase.java index fe21ffab8..cf0134efe 100644 --- a/spec/src/main/java/io/a2a/spec/JSONRPCRequestDeserializerBase.java +++ b/spec/src/main/java/io/a2a/spec/JSONRPCRequestDeserializerBase.java @@ -4,13 +4,10 @@ import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonMappingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.deser.std.StdDeserializer; -import java.io.IOException; - public abstract class JSONRPCRequestDeserializerBase extends StdDeserializer> { public JSONRPCRequestDeserializerBase() { @@ -21,35 +18,6 @@ public JSONRPCRequestDeserializerBase(Class vc) { super(vc); } - @Override - public JSONRPCRequest deserialize(JsonParser jsonParser, DeserializationContext context) - throws IOException, JsonProcessingException { - JsonNode treeNode = jsonParser.getCodec().readTree(jsonParser); - String jsonrpc = getAndValidateJsonrpc(treeNode, jsonParser); - String method = getAndValidateMethod(treeNode, jsonParser); - Object id = getAndValidateId(treeNode, jsonParser); - JsonNode paramsNode = treeNode.get("params"); - - switch (method) { - case GetTaskRequest.METHOD: - return new GetTaskRequest(jsonrpc, id, method, getAndValidateParams(paramsNode, jsonParser, treeNode, TaskQueryParams.class)); - case CancelTaskRequest.METHOD: - return new CancelTaskRequest(jsonrpc, id, method, getAndValidateParams(paramsNode, jsonParser, treeNode, TaskIdParams.class)); - case SetTaskPushNotificationConfigRequest.METHOD: - return new SetTaskPushNotificationConfigRequest(jsonrpc, id, method, getAndValidateParams(paramsNode, jsonParser, treeNode, TaskPushNotificationConfig.class)); - case GetTaskPushNotificationConfigRequest.METHOD: - return new GetTaskPushNotificationConfigRequest(jsonrpc, id, method, getAndValidateParams(paramsNode, jsonParser, treeNode, TaskIdParams.class)); - case SendMessageRequest.METHOD: - return new SendMessageRequest(jsonrpc, id, method, getAndValidateParams(paramsNode, jsonParser, treeNode, MessageSendParams.class)); - case TaskResubscriptionRequest.METHOD: - return new TaskResubscriptionRequest(jsonrpc, id, method, getAndValidateParams(paramsNode, jsonParser, treeNode, TaskIdParams.class)); - case SendStreamingMessageRequest.METHOD: - return new SendStreamingMessageRequest(jsonrpc, id, method, getAndValidateParams(paramsNode, jsonParser, treeNode, MessageSendParams.class)); - default: - throw new MethodNotFoundJsonMappingException("Invalid method", getIdIfPossible(treeNode, jsonParser)); - } - } - protected T getAndValidateParams(JsonNode paramsNode, JsonParser jsonParser, JsonNode node, Class paramsType) throws JsonMappingException { if (paramsNode == null) { return null; @@ -112,7 +80,9 @@ protected static boolean isValidMethodName(String methodName) { || methodName.equals(SetTaskPushNotificationConfigRequest.METHOD) || methodName.equals(TaskResubscriptionRequest.METHOD) || methodName.equals(SendMessageRequest.METHOD) - || methodName.equals(SendStreamingMessageRequest.METHOD)); + || methodName.equals(SendStreamingMessageRequest.METHOD) + || methodName.equals(ListTaskPushNotificationConfigRequest.METHOD) + || methodName.equals(DeleteTaskPushNotificationConfigRequest.METHOD)); } } diff --git a/spec/src/main/java/io/a2a/spec/JSONRPCResponse.java b/spec/src/main/java/io/a2a/spec/JSONRPCResponse.java index 1be348043..3a382b1a7 100644 --- a/spec/src/main/java/io/a2a/spec/JSONRPCResponse.java +++ b/spec/src/main/java/io/a2a/spec/JSONRPCResponse.java @@ -2,6 +2,7 @@ import static io.a2a.util.Utils.defaultIfNull; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; @@ -14,7 +15,7 @@ @JsonIgnoreProperties(ignoreUnknown = true) public abstract sealed class JSONRPCResponse implements JSONRPCMessage permits SendStreamingMessageResponse, GetTaskResponse, CancelTaskResponse, SetTaskPushNotificationConfigResponse, GetTaskPushNotificationConfigResponse, - SendMessageResponse, JSONRPCErrorResponse { + SendMessageResponse, DeleteTaskPushNotificationConfigResponse, ListTaskPushNotificationConfigResponse, JSONRPCErrorResponse { protected String jsonrpc; protected Object id; @@ -24,14 +25,14 @@ public abstract sealed class JSONRPCResponse implements JSONRPCMessage permit public JSONRPCResponse() { } - public JSONRPCResponse(String jsonrpc, Object id, T result, JSONRPCError error) { + public JSONRPCResponse(String jsonrpc, Object id, T result, JSONRPCError error, Class resultType) { if (jsonrpc != null && ! jsonrpc.equals(JSONRPC_VERSION)) { throw new IllegalArgumentException("Invalid JSON-RPC protocol version"); } if (error != null && result != null) { throw new IllegalArgumentException("Invalid JSON-RPC error response"); } - if (error == null && result == null) { + if (error == null && result == null && ! Void.class.equals(resultType)) { throw new IllegalArgumentException("Invalid JSON-RPC success response"); } Assert.isNullOrStringOrInteger(id); diff --git a/spec/src/main/java/io/a2a/spec/JSONRPCVoidResponseSerializer.java b/spec/src/main/java/io/a2a/spec/JSONRPCVoidResponseSerializer.java new file mode 100644 index 000000000..200bc4cd4 --- /dev/null +++ b/spec/src/main/java/io/a2a/spec/JSONRPCVoidResponseSerializer.java @@ -0,0 +1,32 @@ +package io.a2a.spec; + +import java.io.IOException; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.ser.std.StdSerializer; +import com.fasterxml.jackson.databind.type.TypeFactory; + +public class JSONRPCVoidResponseSerializer extends StdSerializer> { + + private static final JSONRPCErrorSerializer JSON_RPC_ERROR_SERIALIZER = new JSONRPCErrorSerializer(); + + public JSONRPCVoidResponseSerializer() { + super(TypeFactory.defaultInstance().constructParametricType(JSONRPCResponse.class, + Void.class)); + } + + @Override + public void serialize(JSONRPCResponse value, JsonGenerator gen, SerializerProvider provider) throws IOException { + gen.writeStartObject(); + gen.writeStringField("jsonrpc", value.getJsonrpc()); + gen.writeObjectField("id", value.getId()); + if (value.getError() != null) { + gen.writeFieldName("error"); + JSON_RPC_ERROR_SERIALIZER.serialize(value.getError(), gen, provider); + } else { + gen.writeNullField("result"); + } + gen.writeEndObject(); + } +} diff --git a/spec/src/main/java/io/a2a/spec/ListTaskPushNotificationConfigParams.java b/spec/src/main/java/io/a2a/spec/ListTaskPushNotificationConfigParams.java new file mode 100644 index 000000000..5ebb12f76 --- /dev/null +++ b/spec/src/main/java/io/a2a/spec/ListTaskPushNotificationConfigParams.java @@ -0,0 +1,24 @@ +package io.a2a.spec; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; + +import io.a2a.util.Assert; + +/** + * Parameters for getting list of pushNotificationConfigurations associated with a Task. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record ListTaskPushNotificationConfigParams(String id, Map metadata) { + + public ListTaskPushNotificationConfigParams { + Assert.checkNotNullParam("id", id); + } + + public ListTaskPushNotificationConfigParams(String id) { + this(id, null); + } +} diff --git a/spec/src/main/java/io/a2a/spec/ListTaskPushNotificationConfigRequest.java b/spec/src/main/java/io/a2a/spec/ListTaskPushNotificationConfigRequest.java new file mode 100644 index 000000000..90ba0f1f5 --- /dev/null +++ b/spec/src/main/java/io/a2a/spec/ListTaskPushNotificationConfigRequest.java @@ -0,0 +1,77 @@ +package io.a2a.spec; + +import java.util.UUID; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import io.a2a.util.Assert; +import io.a2a.util.Utils; + +/** + * A list task push notification config request. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class ListTaskPushNotificationConfigRequest extends NonStreamingJSONRPCRequest { + + public static final String METHOD = "tasks/pushNotificationConfig/list"; + + @JsonCreator + public ListTaskPushNotificationConfigRequest(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, + @JsonProperty("method") String method, + @JsonProperty("params") ListTaskPushNotificationConfigParams params) { + if (jsonrpc != null && ! jsonrpc.equals(JSONRPC_VERSION)) { + throw new IllegalArgumentException("Invalid JSON-RPC protocol version"); + } + Assert.checkNotNullParam("method", method); + if (! method.equals(METHOD)) { + throw new IllegalArgumentException("Invalid ListTaskPushNotificationConfigRequest method"); + } + Assert.isNullOrStringOrInteger(id); + this.jsonrpc = Utils.defaultIfNull(jsonrpc, JSONRPC_VERSION); + this.id = id; + this.method = method; + this.params = params; + } + + public ListTaskPushNotificationConfigRequest(String id, ListTaskPushNotificationConfigParams params) { + this(null, id, METHOD, params); + } + + public static class Builder { + private String jsonrpc; + private Object id; + private String method; + private ListTaskPushNotificationConfigParams params; + + public Builder jsonrpc(String jsonrpc) { + this.jsonrpc = jsonrpc; + return this; + } + + public Builder id(Object id) { + this.id = id; + return this; + } + + public Builder method(String method) { + this.method = method; + return this; + } + + public Builder params(ListTaskPushNotificationConfigParams params) { + this.params = params; + return this; + } + + public ListTaskPushNotificationConfigRequest build() { + if (id == null) { + id = UUID.randomUUID().toString(); + } + return new ListTaskPushNotificationConfigRequest(jsonrpc, id, method, params); + } + } +} diff --git a/spec/src/main/java/io/a2a/spec/ListTaskPushNotificationConfigResponse.java b/spec/src/main/java/io/a2a/spec/ListTaskPushNotificationConfigResponse.java new file mode 100644 index 000000000..cc610416e --- /dev/null +++ b/spec/src/main/java/io/a2a/spec/ListTaskPushNotificationConfigResponse.java @@ -0,0 +1,32 @@ +package io.a2a.spec; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * A response for a list task push notification config request. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class ListTaskPushNotificationConfigResponse extends JSONRPCResponse> { + + @JsonCreator + public ListTaskPushNotificationConfigResponse(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, + @JsonProperty("result") List result, + @JsonProperty("error") JSONRPCError error) { + super(jsonrpc, id, result, error, (Class>) (Class) List.class); + } + + public ListTaskPushNotificationConfigResponse(Object id, JSONRPCError error) { + this(null, id, null, error); + } + + public ListTaskPushNotificationConfigResponse(Object id, List result) { + this(null, id, result, null); + } + +} diff --git a/spec/src/main/java/io/a2a/spec/NonStreamingJSONRPCRequest.java b/spec/src/main/java/io/a2a/spec/NonStreamingJSONRPCRequest.java index f79800215..1c0a696e7 100644 --- a/spec/src/main/java/io/a2a/spec/NonStreamingJSONRPCRequest.java +++ b/spec/src/main/java/io/a2a/spec/NonStreamingJSONRPCRequest.java @@ -12,5 +12,5 @@ @JsonDeserialize(using = NonStreamingJSONRPCRequestDeserializer.class) public abstract sealed class NonStreamingJSONRPCRequest extends JSONRPCRequest permits GetTaskRequest, CancelTaskRequest, SetTaskPushNotificationConfigRequest, GetTaskPushNotificationConfigRequest, - SendMessageRequest { + SendMessageRequest, DeleteTaskPushNotificationConfigRequest, ListTaskPushNotificationConfigRequest { } diff --git a/spec/src/main/java/io/a2a/spec/NonStreamingJSONRPCRequestDeserializer.java b/spec/src/main/java/io/a2a/spec/NonStreamingJSONRPCRequestDeserializer.java index f2978b66f..c97c524c5 100644 --- a/spec/src/main/java/io/a2a/spec/NonStreamingJSONRPCRequestDeserializer.java +++ b/spec/src/main/java/io/a2a/spec/NonStreamingJSONRPCRequestDeserializer.java @@ -1,12 +1,12 @@ package io.a2a.spec; +import java.io.IOException; + import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonNode; -import java.io.IOException; - public class NonStreamingJSONRPCRequestDeserializer extends JSONRPCRequestDeserializerBase> { public NonStreamingJSONRPCRequestDeserializer() { @@ -38,10 +38,16 @@ public NonStreamingJSONRPCRequest deserialize(JsonParser jsonParser, Deserial getAndValidateParams(paramsNode, jsonParser, treeNode, TaskPushNotificationConfig.class)); case GetTaskPushNotificationConfigRequest.METHOD: return new GetTaskPushNotificationConfigRequest(jsonrpc, id, method, - getAndValidateParams(paramsNode, jsonParser, treeNode, TaskIdParams.class)); + getAndValidateParams(paramsNode, jsonParser, treeNode, GetTaskPushNotificationConfigParams.class)); case SendMessageRequest.METHOD: return new SendMessageRequest(jsonrpc, id, method, getAndValidateParams(paramsNode, jsonParser, treeNode, MessageSendParams.class)); + case ListTaskPushNotificationConfigRequest.METHOD: + return new ListTaskPushNotificationConfigRequest(jsonrpc, id, method, + getAndValidateParams(paramsNode, jsonParser, treeNode, ListTaskPushNotificationConfigParams.class)); + case DeleteTaskPushNotificationConfigRequest.METHOD: + return new DeleteTaskPushNotificationConfigRequest(jsonrpc, id, method, + getAndValidateParams(paramsNode, jsonParser, treeNode, DeleteTaskPushNotificationConfigParams.class)); default: throw new MethodNotFoundJsonMappingException("Invalid method", getIdIfPossible(treeNode, jsonParser)); } diff --git a/spec/src/main/java/io/a2a/spec/PushNotificationConfig.java b/spec/src/main/java/io/a2a/spec/PushNotificationConfig.java index 19e9d3491..34270637f 100644 --- a/spec/src/main/java/io/a2a/spec/PushNotificationConfig.java +++ b/spec/src/main/java/io/a2a/spec/PushNotificationConfig.java @@ -21,6 +21,16 @@ public static class Builder { private PushNotificationAuthenticationInfo authentication; private String id; + public Builder() { + } + + public Builder(PushNotificationConfig notificationConfig) { + this.url = notificationConfig.url; + this.token = notificationConfig.token; + this.authentication = notificationConfig.authentication; + this.id = notificationConfig.id; + } + public Builder url(String url) { this.url = url; return this; diff --git a/spec/src/main/java/io/a2a/spec/PushNotificationNotSupportedError.java b/spec/src/main/java/io/a2a/spec/PushNotificationNotSupportedError.java index b97094cc0..d639b7bab 100644 --- a/spec/src/main/java/io/a2a/spec/PushNotificationNotSupportedError.java +++ b/spec/src/main/java/io/a2a/spec/PushNotificationNotSupportedError.java @@ -13,6 +13,10 @@ public class PushNotificationNotSupportedError extends JSONRPCError { public final static Integer DEFAULT_CODE = -32003; + public PushNotificationNotSupportedError() { + this(null, null, null); + } + @JsonCreator public PushNotificationNotSupportedError( @JsonProperty("code") Integer code, diff --git a/spec/src/main/java/io/a2a/spec/SendMessageResponse.java b/spec/src/main/java/io/a2a/spec/SendMessageResponse.java index fa95bad36..901beba90 100644 --- a/spec/src/main/java/io/a2a/spec/SendMessageResponse.java +++ b/spec/src/main/java/io/a2a/spec/SendMessageResponse.java @@ -18,11 +18,7 @@ public final class SendMessageResponse extends JSONRPCResponse { @JsonCreator public SendMessageResponse(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, @JsonProperty("result") EventKind result, @JsonProperty("error") JSONRPCError error) { - this.jsonrpc = defaultIfNull(jsonrpc, JSONRPC_VERSION); - Assert.isNullOrStringOrInteger(id); - this.id = id; - this.result = result; - this.error = error; + super(jsonrpc, id, result, error, EventKind.class); } public SendMessageResponse(Object id, EventKind result) { diff --git a/spec/src/main/java/io/a2a/spec/SendStreamingMessageResponse.java b/spec/src/main/java/io/a2a/spec/SendStreamingMessageResponse.java index b3597bfb6..f3bcb9676 100644 --- a/spec/src/main/java/io/a2a/spec/SendStreamingMessageResponse.java +++ b/spec/src/main/java/io/a2a/spec/SendStreamingMessageResponse.java @@ -18,11 +18,7 @@ public final class SendStreamingMessageResponse extends JSONRPCResponse io.github.a2asdk a2a-java-sdk-parent - 0.2.3.Beta2-SNAPSHOT + 0.2.6.Beta1-SNAPSHOT a2a-tck-server @@ -18,7 +18,7 @@ io.github.a2asdk - a2a-java-sdk-server-quarkus + a2a-java-reference-server ${project.version} diff --git a/tck/src/main/java/io/a2a/tck/server/AgentCardProducer.java b/tck/src/main/java/io/a2a/tck/server/AgentCardProducer.java index 7abf29a12..610443ac8 100644 --- a/tck/src/main/java/io/a2a/tck/server/AgentCardProducer.java +++ b/tck/src/main/java/io/a2a/tck/server/AgentCardProducer.java @@ -37,6 +37,7 @@ public AgentCard agentCard() { .tags(Collections.singletonList("hello world")) .examples(List.of("hi", "hello world")) .build())) + .protocolVersion("0.2.5") .build(); } } diff --git a/tck/src/main/java/io/a2a/tck/server/AgentExecutorProducer.java b/tck/src/main/java/io/a2a/tck/server/AgentExecutorProducer.java index 592546591..8d2cee335 100644 --- a/tck/src/main/java/io/a2a/tck/server/AgentExecutorProducer.java +++ b/tck/src/main/java/io/a2a/tck/server/AgentExecutorProducer.java @@ -29,10 +29,6 @@ private static class FireAndForgetAgentExecutor implements AgentExecutor { public void execute(RequestContext context, EventQueue eventQueue) throws JSONRPCError { Task task = context.getTask(); - if (context.getMessage().getTaskId() != null && task == null && context.getMessage().getTaskId().startsWith("non-existent")) { - throw new TaskNotFoundError(); - } - if (task == null) { task = new Task.Builder() .id(context.getTaskId()) @@ -43,12 +39,21 @@ public void execute(RequestContext context, EventQueue eventQueue) throws JSONRP eventQueue.enqueueEvent(task); } + if (context.getMessage().getMessageId().startsWith("test-resubscribe-message-id")) { + int timeoutMs = Integer.parseInt(System.getenv().getOrDefault("RESUBSCRIBE_TIMEOUT_MS", "3000")); + System.out.println("====> task id starts with test-resubscribe-message-id, sleeping for " + timeoutMs + " ms"); + try { + Thread.sleep(timeoutMs); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } TaskUpdater updater = new TaskUpdater(context, eventQueue); // Immediately set to WORKING state updater.startWork(); System.out.println("====> task set to WORKING, starting background execution"); - + // Method returns immediately - task continues in background System.out.println("====> execute() method returning immediately, task running in background"); } diff --git a/tests/server-common/pom.xml b/tests/server-common/pom.xml index 0e205b547..831dcc86c 100644 --- a/tests/server-common/pom.xml +++ b/tests/server-common/pom.xml @@ -7,7 +7,7 @@ io.github.a2asdk a2a-java-sdk-parent - 0.2.3.Beta2-SNAPSHOT + 0.2.6.Beta1-SNAPSHOT ../../pom.xml a2a-java-sdk-tests-server-common diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java index 0992a157f..f9e5cee17 100644 --- a/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java @@ -5,282 +5,248 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.fail; import static org.wildfly.common.Assert.assertNotNull; import static org.wildfly.common.Assert.assertTrue; -import java.io.EOFException; +import java.io.IOException; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Stream; -import jakarta.ws.rs.core.MediaType; - -import com.fasterxml.jackson.core.JsonProcessingException; import io.a2a.server.events.InMemoryQueueManager; import io.a2a.server.tasks.TaskStore; +import jakarta.ws.rs.core.MediaType; + +import io.a2a.client.A2AClient; +import io.a2a.spec.A2AServerException; import io.a2a.spec.AgentCard; import io.a2a.spec.Artifact; -import io.a2a.spec.CancelTaskRequest; import io.a2a.spec.CancelTaskResponse; +import io.a2a.spec.DeleteTaskPushNotificationConfigResponse; import io.a2a.spec.Event; -import io.a2a.spec.GetTaskPushNotificationConfigRequest; import io.a2a.spec.GetTaskPushNotificationConfigResponse; -import io.a2a.spec.GetTaskRequest; import io.a2a.spec.GetTaskResponse; import io.a2a.spec.InvalidParamsError; import io.a2a.spec.InvalidRequestError; import io.a2a.spec.JSONParseError; import io.a2a.spec.JSONRPCError; import io.a2a.spec.JSONRPCErrorResponse; +import io.a2a.spec.ListTaskPushNotificationConfigResponse; import io.a2a.spec.Message; import io.a2a.spec.MessageSendParams; import io.a2a.spec.MethodNotFoundError; import io.a2a.spec.Part; import io.a2a.spec.PushNotificationConfig; -import io.a2a.spec.SendMessageRequest; import io.a2a.spec.SendMessageResponse; -import io.a2a.spec.SendStreamingMessageRequest; -import io.a2a.spec.SendStreamingMessageResponse; -import io.a2a.spec.SetTaskPushNotificationConfigRequest; import io.a2a.spec.SetTaskPushNotificationConfigResponse; -import io.a2a.spec.StreamingJSONRPCRequest; import io.a2a.spec.Task; import io.a2a.spec.TaskArtifactUpdateEvent; import io.a2a.spec.TaskIdParams; import io.a2a.spec.TaskNotFoundError; import io.a2a.spec.TaskPushNotificationConfig; import io.a2a.spec.TaskQueryParams; -import io.a2a.spec.TaskResubscriptionRequest; import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatus; import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; import io.a2a.spec.UnsupportedOperationError; import io.a2a.util.Utils; -import io.restassured.RestAssured; -import io.restassured.specification.RequestSpecification; + import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +/** + * This test requires doing some work on the server to add/get/delete tasks, and enqueue events. This is exposed via REST, + * which delegates to {@link TestUtilsBean}. + */ public abstract class AbstractA2AServerTest { - + private static final Task MINIMAL_TASK = new Task.Builder() .id("task-123") .contextId("session-xyz") .status(new TaskStatus(TaskState.SUBMITTED)) .build(); - + private static final Task CANCEL_TASK = new Task.Builder() .id("cancel-task-123") .contextId("session-xyz") .status(new TaskStatus(TaskState.SUBMITTED)) .build(); - + private static final Task CANCEL_TASK_NOT_SUPPORTED = new Task.Builder() .id("cancel-task-not-supported-123") .contextId("session-xyz") .status(new TaskStatus(TaskState.SUBMITTED)) .build(); - + private static final Task SEND_MESSAGE_NOT_SUPPORTED = new Task.Builder() .id("task-not-supported-123") .contextId("session-xyz") .status(new TaskStatus(TaskState.SUBMITTED)) .build(); - + private static final Message MESSAGE = new Message.Builder() .messageId("111") .role(Message.Role.AGENT) .parts(new TextPart("test message")) .build(); - + public static final String APPLICATION_JSON = "application/json"; + private final int serverPort; - + private A2AClient client; + protected AbstractA2AServerTest(int serverPort) { this.serverPort = serverPort; + this.client = new A2AClient("http://localhost:" + serverPort); } - + + @Test + public void testTaskStoreMethodsSanityTest() throws Exception { + Task task = new Task.Builder(MINIMAL_TASK).id("abcde").build(); + saveTaskInTaskStore(task); + Task saved = getTaskFromTaskStore(task.getId()); + assertEquals(task.getId(), saved.getId()); + assertEquals(task.getContextId(), saved.getContextId()); + assertEquals(task.getStatus().state(), saved.getStatus().state()); + + deleteTaskInTaskStore(task.getId()); + Task saved2 = getTaskFromTaskStore(task.getId()); + assertNull(saved2); + } + @Test - public void testGetTaskSuccess() { + public void testGetTaskSuccess() throws Exception { testGetTask(); } - - private void testGetTask() { + + private void testGetTask() throws Exception { testGetTask(null); } - - private void testGetTask(String mediaType) { + + private void testGetTask(String mediaType) throws Exception { getTaskStore().save(MINIMAL_TASK); try { - GetTaskRequest request = new GetTaskRequest("1", new TaskQueryParams(MINIMAL_TASK.getId())); - RequestSpecification requestSpecification = RestAssured.given() - .contentType(MediaType.APPLICATION_JSON) - .body(request); - if (mediaType != null) { - requestSpecification = requestSpecification.accept(mediaType); - } - GetTaskResponse response = requestSpecification - .when() - .post("/") - .then() - .statusCode(200) - .extract() - .as(GetTaskResponse.class); + GetTaskResponse response = client.getTask("1", new TaskQueryParams(MINIMAL_TASK.getId())); assertEquals("1", response.getId()); assertEquals("task-123", response.getResult().getId()); assertEquals("session-xyz", response.getResult().getContextId()); assertEquals(TaskState.SUBMITTED, response.getResult().getStatus().state()); assertNull(response.getError()); - } catch (Exception e) { + } catch (A2AServerException e) { + fail("Unexpected exception during getTask: " + e.getMessage(), e); } finally { getTaskStore().delete(MINIMAL_TASK.getId()); } } - + @Test - public void testGetTaskNotFound() { + public void testGetTaskNotFound() throws Exception { assertTrue(getTaskStore().get("non-existent-task") == null); - GetTaskRequest request = new GetTaskRequest("1", new TaskQueryParams("non-existent-task")); - GetTaskResponse response = given() - .contentType(MediaType.APPLICATION_JSON) - .body(request) - .when() - .post("/") - .then() - .statusCode(200) - .extract() - .as(GetTaskResponse.class); - assertEquals("1", response.getId()); - // this should be an instance of TaskNotFoundError, see https://github.com/a2aproject/a2a-java/issues/23 - assertInstanceOf(JSONRPCError.class, response.getError()); - assertEquals(new TaskNotFoundError().getCode(), response.getError().getCode()); - assertNull(response.getResult()); + try { + GetTaskResponse response = client.getTask("1", new TaskQueryParams("non-existent-task")); + assertEquals("1", response.getId()); + assertInstanceOf(JSONRPCError.class, response.getError()); + assertEquals(new TaskNotFoundError().getCode(), response.getError().getCode()); + assertNull(response.getResult()); + } catch (A2AServerException e) { + fail("Unexpected exception during getTask for non-existent task: " + e.getMessage(), e); + } } - + @Test - public void testCancelTaskSuccess() { + public void testCancelTaskSuccess() throws Exception { getTaskStore().save(CANCEL_TASK); try { - CancelTaskRequest request = new CancelTaskRequest("1", new TaskIdParams(CANCEL_TASK.getId())); - CancelTaskResponse response = given() - .contentType(MediaType.APPLICATION_JSON) - .body(request) - .when() - .post("/") - .then() - .statusCode(200) - .extract() - .as(CancelTaskResponse.class); + CancelTaskResponse response = client.cancelTask("1", new TaskIdParams(CANCEL_TASK.getId())); + assertEquals("1", response.getId()); assertNull(response.getError()); - assertEquals(request.getId(), response.getId()); Task task = response.getResult(); - assertEquals(CANCEL_TASK.getId(), task.getId()); assertEquals(CANCEL_TASK.getContextId(), task.getContextId()); assertEquals(TaskState.CANCELED, task.getStatus().state()); - } catch (Exception e) { + } catch (A2AServerException e) { + fail("Unexpected exception during cancel task success test: " + e.getMessage(), e); } finally { getTaskStore().delete(CANCEL_TASK.getId()); } } - + @Test - public void testCancelTaskNotSupported() { + public void testCancelTaskNotSupported() throws Exception { getTaskStore().save(CANCEL_TASK_NOT_SUPPORTED); try { - CancelTaskRequest request = new CancelTaskRequest("1", new TaskIdParams(CANCEL_TASK_NOT_SUPPORTED.getId())); - CancelTaskResponse response = given() - .contentType(MediaType.APPLICATION_JSON) - .body(request) - .when() - .post("/") - .then() - .statusCode(200) - .extract() - .as(CancelTaskResponse.class); - assertEquals(request.getId(), response.getId()); + CancelTaskResponse response = client.cancelTask("1", new TaskIdParams(CANCEL_TASK_NOT_SUPPORTED.getId())); assertNull(response.getResult()); - // this should be an instance of UnsupportedOperationError, see https://github.com/a2aproject/a2a-java/issues/23 + assertEquals("1", response.getId()); assertInstanceOf(JSONRPCError.class, response.getError()); assertEquals(new UnsupportedOperationError().getCode(), response.getError().getCode()); - } catch (Exception e) { + } catch (A2AServerException e) { + fail("Unexpected exception during cancel task not supported test: " + e.getMessage(), e); } finally { getTaskStore().delete(CANCEL_TASK_NOT_SUPPORTED.getId()); } } - + @Test public void testCancelTaskNotFound() { - CancelTaskRequest request = new CancelTaskRequest("1", new TaskIdParams("non-existent-task")); - CancelTaskResponse response = given() - .contentType(MediaType.APPLICATION_JSON) - .body(request) - .when() - .post("/") - .then() - .statusCode(200) - .extract() - .as(CancelTaskResponse.class) - ; - assertEquals(request.getId(), response.getId()); - assertNull(response.getResult()); - // this should be an instance of UnsupportedOperationError, see https://github.com/a2aproject/a2a-java/issues/23 - assertInstanceOf(JSONRPCError.class, response.getError()); - assertEquals(new TaskNotFoundError().getCode(), response.getError().getCode()); + try { + CancelTaskResponse response = client.cancelTask("1", new TaskIdParams("non-existent-task")); + assertEquals("1", response.getId()); + assertNull(response.getResult()); + assertInstanceOf(JSONRPCError.class, response.getError()); + assertEquals(new TaskNotFoundError().getCode(), response.getError().getCode()); + } catch (A2AServerException e) { + fail("Unexpected exception during cancel task not found test: " + e.getMessage(), e); + } } - + @Test - public void testSendMessageNewMessageSuccess() { + public void testSendMessageNewMessageSuccess() throws Exception { assertTrue(getTaskStore().get(MINIMAL_TASK.getId()) == null); Message message = new Message.Builder(MESSAGE) .taskId(MINIMAL_TASK.getId()) .contextId(MINIMAL_TASK.getContextId()) .build(); - SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); - SendMessageResponse response = given() - .contentType(MediaType.APPLICATION_JSON) - .body(request) - .when() - .post("/") - .then() - .statusCode(200) - .extract() - .as(SendMessageResponse.class); - assertNull(response.getError()); - Message messageResponse = (Message) response.getResult(); - assertEquals(MESSAGE.getMessageId(), messageResponse.getMessageId()); - assertEquals(MESSAGE.getRole(), messageResponse.getRole()); - Part part = messageResponse.getParts().get(0); - assertEquals(Part.Kind.TEXT, part.getKind()); - assertEquals("test message", ((TextPart) part).getText()); + MessageSendParams messageSendParams = new MessageSendParams(message, null, null); + + try { + SendMessageResponse response = client.sendMessage("1", messageSendParams); + assertEquals("1", response.getId()); + assertNull(response.getError()); + Message messageResponse = (Message) response.getResult(); + assertEquals(MESSAGE.getMessageId(), messageResponse.getMessageId()); + assertEquals(MESSAGE.getRole(), messageResponse.getRole()); + Part part = messageResponse.getParts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("test message", ((TextPart) part).getText()); + } catch (A2AServerException e) { + fail("Unexpected exception during send new message test: " + e.getMessage(), e); + } } - + @Test - public void testSendMessageExistingTaskSuccess() { + public void testSendMessageExistingTaskSuccess() throws Exception { getTaskStore().save(MINIMAL_TASK); try { Message message = new Message.Builder(MESSAGE) .taskId(MINIMAL_TASK.getId()) .contextId(MINIMAL_TASK.getContextId()) .build(); - SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); - SendMessageResponse response = given() - .contentType(MediaType.APPLICATION_JSON) - .body(request) - .when() - .post("/") - .then() - .statusCode(200) - .extract() - .as(SendMessageResponse.class); + MessageSendParams messageSendParams = new MessageSendParams(message, null, null); + + SendMessageResponse response = client.sendMessage("1", messageSendParams); + assertEquals("1", response.getId()); assertNull(response.getError()); Message messageResponse = (Message) response.getResult(); assertEquals(MESSAGE.getMessageId(), messageResponse.getMessageId()); @@ -288,106 +254,80 @@ public void testSendMessageExistingTaskSuccess() { Part part = messageResponse.getParts().get(0); assertEquals(Part.Kind.TEXT, part.getKind()); assertEquals("test message", ((TextPart) part).getText()); - } catch (Exception e) { + } catch (A2AServerException e) { + fail("Unexpected exception during send message to existing task test: " + e.getMessage(), e); } finally { getTaskStore().delete(MINIMAL_TASK.getId()); } } - + @Test - public void testSetPushNotificationSuccess() { + public void testSetPushNotificationSuccess() throws Exception { getTaskStore().save(MINIMAL_TASK); try { - TaskPushNotificationConfig taskPushConfig = - new TaskPushNotificationConfig( - MINIMAL_TASK.getId(), new PushNotificationConfig.Builder().url("http://example.com").build()); - SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); - SetTaskPushNotificationConfigResponse response = given() - .contentType(MediaType.APPLICATION_JSON) - .body(request) - .when() - .post("/") - .then() - .statusCode(200) - .extract() - .as(SetTaskPushNotificationConfigResponse.class); + PushNotificationConfig pushNotificationConfig = new PushNotificationConfig.Builder() + .url("http://example.com") + .build(); + SetTaskPushNotificationConfigResponse response = client.setTaskPushNotificationConfig("1", + MINIMAL_TASK.getId(), pushNotificationConfig); + assertEquals("1", response.getId()); assertNull(response.getError()); - assertEquals(request.getId(), response.getId()); TaskPushNotificationConfig config = response.getResult(); assertEquals(MINIMAL_TASK.getId(), config.taskId()); assertEquals("http://example.com", config.pushNotificationConfig().url()); - } catch (Exception e) { + } catch (A2AServerException e) { + fail("Unexpected exception during set push notification test: " + e.getMessage(), e); } finally { getTaskStore().delete(MINIMAL_TASK.getId()); } } - + @Test - public void testGetPushNotificationSuccess() { + public void testGetPushNotificationSuccess() throws Exception { getTaskStore().save(MINIMAL_TASK); try { - TaskPushNotificationConfig taskPushConfig = - new TaskPushNotificationConfig( - MINIMAL_TASK.getId(), new PushNotificationConfig.Builder().url("http://example.com").build()); - - SetTaskPushNotificationConfigRequest setTaskPushNotificationRequest = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); - SetTaskPushNotificationConfigResponse setTaskPushNotificationResponse = given() - .contentType(MediaType.APPLICATION_JSON) - .body(setTaskPushNotificationRequest) - .when() - .post("/") - .then() - .statusCode(200) - .extract() - .as(SetTaskPushNotificationConfigResponse.class); - assertNotNull(setTaskPushNotificationResponse); - - GetTaskPushNotificationConfigRequest request = - new GetTaskPushNotificationConfigRequest("111", new TaskIdParams(MINIMAL_TASK.getId())); - GetTaskPushNotificationConfigResponse response = given() - .contentType(MediaType.APPLICATION_JSON) - .body(request) - .when() - .post("/") - .then() - .statusCode(200) - .extract() - .as(GetTaskPushNotificationConfigResponse.class); + PushNotificationConfig pushNotificationConfig = new PushNotificationConfig.Builder() + .url("http://example.com") + .build(); + + // First set the push notification config + SetTaskPushNotificationConfigResponse setResponse = client.setTaskPushNotificationConfig("1", + MINIMAL_TASK.getId(), pushNotificationConfig); + assertNotNull(setResponse); + + // Then get the push notification config + GetTaskPushNotificationConfigResponse response = client.getTaskPushNotificationConfig("2", pushNotificationConfig.id()); + assertEquals("2", response.getId()); assertNull(response.getError()); - assertEquals(request.getId(), response.getId()); TaskPushNotificationConfig config = response.getResult(); assertEquals(MINIMAL_TASK.getId(), config.taskId()); assertEquals("http://example.com", config.pushNotificationConfig().url()); - } catch (Exception e) { + } catch (A2AServerException e) { + fail("Unexpected exception during get push notification test: " + e.getMessage(), e); } finally { getTaskStore().delete(MINIMAL_TASK.getId()); } } - + @Test public void testError() { Message message = new Message.Builder(MESSAGE) .taskId(SEND_MESSAGE_NOT_SUPPORTED.getId()) .contextId(SEND_MESSAGE_NOT_SUPPORTED.getContextId()) .build(); - SendMessageRequest request = new SendMessageRequest( - "1", new MessageSendParams(message, null, null)); - SendMessageResponse response = given() - .contentType(MediaType.APPLICATION_JSON) - .body(request) - .when() - .post("/") - .then() - .statusCode(200) - .extract() - .as(SendMessageResponse.class); - assertEquals(request.getId(), response.getId()); - assertNull(response.getResult()); - // this should be an instance of UnsupportedOperationError, see https://github.com/a2aproject/a2a-java/issues/23 - assertInstanceOf(JSONRPCError.class, response.getError()); - assertEquals(new UnsupportedOperationError().getCode(), response.getError().getCode()); + MessageSendParams messageSendParams = new MessageSendParams(message, null, null); + + try { + SendMessageResponse response = client.sendMessage("1", messageSendParams); + assertEquals("1", response.getId()); + assertNull(response.getResult()); + assertInstanceOf(JSONRPCError.class, response.getError()); + assertEquals(new UnsupportedOperationError().getCode(), response.getError().getCode()); + } catch (A2AServerException e) { + fail("Unexpected exception during error handling test: " + e.getMessage(), e); + } } - + @Test public void testGetAgentCard() { AgentCard agentCard = given() @@ -409,7 +349,7 @@ public void testGetAgentCard() { assertTrue(agentCard.capabilities().stateTransitionHistory()); assertTrue(agentCard.skills().isEmpty()); } - + @Test public void testGetExtendAgentCardNotSupported() { given() @@ -420,7 +360,7 @@ public void testGetExtendAgentCardNotSupported() { .statusCode(404) .body("error", equalTo("Extended agent card not supported or not enabled.")); } - + @Test public void testMalformedJSONRPCRequest() { // missing closing bracket @@ -437,20 +377,20 @@ public void testMalformedJSONRPCRequest() { assertNotNull(response.getError()); assertEquals(new JSONParseError().getCode(), response.getError().getCode()); } - + @Test public void testInvalidParamsJSONRPCRequest() { String invalidParamsRequest = """ {"jsonrpc": "2.0", "method": "message/send", "params": "not_a_dict", "id": "1"} """; testInvalidParams(invalidParamsRequest); - + invalidParamsRequest = """ {"jsonrpc": "2.0", "method": "message/send", "params": {"message": {"parts": "invalid"}}, "id": "1"} """; testInvalidParams(invalidParamsRequest); } - + private void testInvalidParams(String invalidParamsRequest) { JSONRPCErrorResponse response = given() .contentType(MediaType.APPLICATION_JSON) @@ -465,7 +405,7 @@ private void testInvalidParams(String invalidParamsRequest) { assertEquals(new InvalidParamsError().getCode(), response.getError().getCode()); assertEquals("1", response.getId()); } - + @Test public void testInvalidJSONRPCRequestMissingJsonrpc() { String invalidRequest = """ @@ -486,7 +426,7 @@ public void testInvalidJSONRPCRequestMissingJsonrpc() { assertNotNull(response.getError()); assertEquals(new InvalidRequestError().getCode(), response.getError().getCode()); } - + @Test public void testInvalidJSONRPCRequestMissingMethod() { String invalidRequest = """ @@ -504,7 +444,7 @@ public void testInvalidJSONRPCRequestMissingMethod() { assertNotNull(response.getError()); assertEquals(new InvalidRequestError().getCode(), response.getError().getCode()); } - + @Test public void testInvalidJSONRPCRequestInvalidId() { String invalidRequest = """ @@ -522,7 +462,7 @@ public void testInvalidJSONRPCRequestInvalidId() { assertNotNull(response.getError()); assertEquals(new InvalidRequestError().getCode(), response.getError().getCode()); } - + @Test public void testInvalidJSONRPCRequestNonExistentMethod() { String invalidRequest = """ @@ -540,129 +480,138 @@ public void testInvalidJSONRPCRequestNonExistentMethod() { assertNotNull(response.getError()); assertEquals(new MethodNotFoundError().getCode(), response.getError().getCode()); } - + @Test - public void testNonStreamingMethodWithAcceptHeader() { + public void testNonStreamingMethodWithAcceptHeader() throws Exception { testGetTask(MediaType.APPLICATION_JSON); } - - + + + @Test - public void testSendMessageStreamExistingTaskSuccess() { + public void testSendMessageStreamExistingTaskSuccess() throws Exception { getTaskStore().save(MINIMAL_TASK); try { Message message = new Message.Builder(MESSAGE) .taskId(MINIMAL_TASK.getId()) .contextId(MINIMAL_TASK.getContextId()) .build(); - SendStreamingMessageRequest request = new SendStreamingMessageRequest( - "1", new MessageSendParams(message, null, null)); - - CompletableFuture>> responseFuture = initialiseStreamingRequest(request, null); - + MessageSendParams messageSendParams = new MessageSendParams(message, null, null); + CountDownLatch latch = new CountDownLatch(1); AtomicReference errorRef = new AtomicReference<>(); - - responseFuture.thenAccept(response -> { - if (response.statusCode() != 200) { - //errorRef.set(new IllegalStateException("Status code was " + response.statusCode())); - throw new IllegalStateException("Status code was " + response.statusCode()); - } - response.body().forEach(line -> { - try { - SendStreamingMessageResponse jsonResponse = extractJsonResponseFromSseLine(line); - if (jsonResponse != null) { - assertNull(jsonResponse.getError()); - Message messageResponse = (Message) jsonResponse.getResult(); - assertEquals(MESSAGE.getMessageId(), messageResponse.getMessageId()); - assertEquals(MESSAGE.getRole(), messageResponse.getRole()); - Part part = messageResponse.getParts().get(0); - assertEquals(Part.Kind.TEXT, part.getKind()); - assertEquals("test message", ((TextPart) part).getText()); + AtomicReference messageResponseRef = new AtomicReference<>(); + + // Replace the native HttpClient with A2AClient's sendStreamingMessage method. + client.sendStreamingMessage( + "1", + messageSendParams, + // eventHandler + (streamingEvent) -> { + try { + if (streamingEvent instanceof Message) { + messageResponseRef.set((Message) streamingEvent); + latch.countDown(); + } + } catch (Exception e) { + errorRef.set(e); latch.countDown(); } - } catch (JsonProcessingException e) { - throw new RuntimeException(e); + }, + // errorHandler + (jsonRpcError) -> { + errorRef.set(new RuntimeException("JSON-RPC Error: " + jsonRpcError.getMessage())); + latch.countDown(); + }, + // failureHandler + () -> { + if (errorRef.get() == null) { + errorRef.set(new RuntimeException("Stream processing failed")); + } + latch.countDown(); } - }); - }).exceptionally(t -> { - if (!isStreamClosedError(t)) { - errorRef.set(t); - } - latch.countDown(); - return null; - }); - + ); + boolean dataRead = latch.await(20, TimeUnit.SECONDS); Assertions.assertTrue(dataRead); Assertions.assertNull(errorRef.get()); + + Message messageResponse = messageResponseRef.get(); + Assertions.assertNotNull(messageResponse); + assertEquals(MESSAGE.getMessageId(), messageResponse.getMessageId()); + assertEquals(MESSAGE.getRole(), messageResponse.getRole()); + Part part = messageResponse.getParts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("test message", ((TextPart) part).getText()); } catch (Exception e) { + fail("Unexpected exception during error handling test: " + e.getMessage(), e); } finally { getTaskStore().delete(MINIMAL_TASK.getId()); } } - + @Test + @Timeout(value = 3, unit = TimeUnit.MINUTES) public void testResubscribeExistingTaskSuccess() throws Exception { ExecutorService executorService = Executors.newSingleThreadExecutor(); getTaskStore().save(MINIMAL_TASK); - + try { // attempting to send a streaming message instead of explicitly calling queueManager#createOrTap // does not work because after the message is sent, the queue becomes null but task resubscription // requires the queue to still be active - getQueueManager().createOrTap(MINIMAL_TASK.getId()); - + ensureQueueForTask(MINIMAL_TASK.getId()); + CountDownLatch taskResubscriptionRequestSent = new CountDownLatch(1); CountDownLatch taskResubscriptionResponseReceived = new CountDownLatch(2); - AtomicReference firstResponse = new AtomicReference<>(); - AtomicReference secondResponse = new AtomicReference<>(); - + AtomicReference firstResponse = new AtomicReference<>(); + AtomicReference secondResponse = new AtomicReference<>(); + AtomicReference errorRef = new AtomicReference<>(); + // resubscribe to the task, requires the task and its queue to still be active - TaskResubscriptionRequest taskResubscriptionRequest = new TaskResubscriptionRequest("1", new TaskIdParams(MINIMAL_TASK.getId())); - + TaskIdParams taskIdParams = new TaskIdParams(MINIMAL_TASK.getId()); + // Count down the latch when the MultiSseSupport on the server has started subscribing - setStreamingSubscribedRunnable(taskResubscriptionRequestSent::countDown); - - CompletableFuture>> responseFuture = initialiseStreamingRequest(taskResubscriptionRequest, null); - - AtomicReference errorRef = new AtomicReference<>(); - - responseFuture.thenAccept(response -> { - - if (response.statusCode() != 200) { - //errorRef.set(new IllegalStateException("Status code was " + response.statusCode())); - throw new IllegalStateException("Status code was " + response.statusCode()); - } - try { - response.body().forEach(line -> { + awaitStreamingSubscription() + .whenComplete((unused, throwable) -> taskResubscriptionRequestSent.countDown()); + + // Use A2AClient-like resubscribeToTask Method + client.resubscribeToTask( + "1", // requestId + taskIdParams, + // eventHandler + (streamingEvent) -> { try { - SendStreamingMessageResponse jsonResponse = extractJsonResponseFromSseLine(line); - if (jsonResponse != null) { - SendStreamingMessageResponse sendStreamingMessageResponse = Utils.OBJECT_MAPPER.readValue(line.substring("data: ".length()).trim(), SendStreamingMessageResponse.class); + if (streamingEvent instanceof TaskArtifactUpdateEvent) { if (taskResubscriptionResponseReceived.getCount() == 2) { - firstResponse.set(sendStreamingMessageResponse); + firstResponse.set((TaskArtifactUpdateEvent) streamingEvent); } else { - secondResponse.set(sendStreamingMessageResponse); + secondResponse.set((TaskStatusUpdateEvent) streamingEvent); } taskResubscriptionResponseReceived.countDown(); - if (taskResubscriptionResponseReceived.getCount() == 0) { - throw new BreakException(); - } } - } catch (JsonProcessingException e) { - throw new RuntimeException(e); + } catch (Exception e) { + errorRef.set(e); + taskResubscriptionResponseReceived.countDown(); + taskResubscriptionResponseReceived.countDown(); // Make sure the counter is zeroed } - }); - } catch (BreakException e) { - } - }).exceptionally(t -> { - if (!isStreamClosedError(t)) { - errorRef.set(t); - } - return null; - }); - + }, + // errorHandler + (jsonRpcError) -> { + errorRef.set(new RuntimeException("JSON-RPC Error: " + jsonRpcError.getMessage())); + taskResubscriptionResponseReceived.countDown(); + taskResubscriptionResponseReceived.countDown(); // Make sure the counter is zeroed + }, + // failureHandler + () -> { + if (errorRef.get() == null) { + errorRef.set(new RuntimeException("Stream processing failed")); + } + taskResubscriptionResponseReceived.countDown(); + taskResubscriptionResponseReceived.countDown(); // Make sure the counter is zeroed + } + ); + try { taskResubscriptionRequestSent.await(); List events = List.of( @@ -680,37 +629,34 @@ public void testResubscribeExistingTaskSuccess() throws Exception { .status(new TaskStatus(TaskState.COMPLETED)) .isFinal(true) .build()); - + for (Event event : events) { - getQueueManager().get(MINIMAL_TASK.getId()).enqueueEvent(event); + enqueueEventOnServer(event); } } catch (InterruptedException e) { Thread.currentThread().interrupt(); } - + // wait for the client to receive the responses taskResubscriptionResponseReceived.await(); - + + Assertions.assertNull(errorRef.get()); + assertNotNull(firstResponse.get()); - SendStreamingMessageResponse sendStreamingMessageResponse = firstResponse.get(); - assertNull(sendStreamingMessageResponse.getError()); - TaskArtifactUpdateEvent taskArtifactUpdateEvent = (TaskArtifactUpdateEvent) sendStreamingMessageResponse.getResult(); + TaskArtifactUpdateEvent taskArtifactUpdateEvent = firstResponse.get(); assertEquals(MINIMAL_TASK.getId(), taskArtifactUpdateEvent.getTaskId()); assertEquals(MINIMAL_TASK.getContextId(), taskArtifactUpdateEvent.getContextId()); Part part = taskArtifactUpdateEvent.getArtifact().parts().get(0); assertEquals(Part.Kind.TEXT, part.getKind()); assertEquals("text", ((TextPart) part).getText()); - + assertNotNull(secondResponse.get()); - sendStreamingMessageResponse = secondResponse.get(); - assertNull(sendStreamingMessageResponse.getError()); - TaskStatusUpdateEvent taskStatusUpdateEvent = (TaskStatusUpdateEvent) sendStreamingMessageResponse.getResult(); + TaskStatusUpdateEvent taskStatusUpdateEvent = secondResponse.get(); assertEquals(MINIMAL_TASK.getId(), taskStatusUpdateEvent.getTaskId()); assertEquals(MINIMAL_TASK.getContextId(), taskStatusUpdateEvent.getContextId()); assertEquals(TaskState.COMPLETED, taskStatusUpdateEvent.getStatus().state()); assertNotNull(taskStatusUpdateEvent.getStatus().timestamp()); } finally { - setStreamingSubscribedRunnable(null); getTaskStore().delete(MINIMAL_TASK.getId()); executorService.shutdown(); if (!executorService.awaitTermination(10, TimeUnit.SECONDS)) { @@ -718,169 +664,496 @@ public void testResubscribeExistingTaskSuccess() throws Exception { } } } - + @Test public void testResubscribeNoExistingTaskError() throws Exception { - TaskResubscriptionRequest request = new TaskResubscriptionRequest("1", new TaskIdParams("non-existent-task")); - - CompletableFuture>> responseFuture = initialiseStreamingRequest(request, null); - + TaskIdParams taskIdParams = new TaskIdParams("non-existent-task"); + CountDownLatch latch = new CountDownLatch(1); AtomicReference errorRef = new AtomicReference<>(); - - responseFuture.thenAccept(response -> { - if (response.statusCode() != 200) { - //errorRef.set(new IllegalStateException("Status code was " + response.statusCode())); - throw new IllegalStateException("Status code was " + response.statusCode()); - } - response.body().forEach(line -> { - try { - SendStreamingMessageResponse jsonResponse = extractJsonResponseFromSseLine(line); - if (jsonResponse != null) { - assertEquals(request.getId(), jsonResponse.getId()); - assertNull(jsonResponse.getResult()); - // this should be an instance of TaskNotFoundError, see https://github.com/a2aproject/a2a-java/issues/23 - assertInstanceOf(JSONRPCError.class, jsonResponse.getError()); - assertEquals(new TaskNotFoundError().getCode(), jsonResponse.getError().getCode()); - latch.countDown(); + AtomicReference jsonRpcErrorRef = new AtomicReference<>(); + + // Use A2AClient-like resubscribeToTask Method + client.resubscribeToTask( + "1", // requestId + taskIdParams, + // eventHandler + (streamingEvent) -> { + // Do not expect to receive any success events, as the task does not exist + errorRef.set(new RuntimeException("Unexpected event received for non-existent task")); + latch.countDown(); + }, + // errorHandler + (jsonRpcError) -> { + jsonRpcErrorRef.set(jsonRpcError); + latch.countDown(); + }, + // failureHandler + () -> { + if (errorRef.get() == null && jsonRpcErrorRef.get() == null) { + errorRef.set(new RuntimeException("Expected error for non-existent task")); } - } catch (JsonProcessingException e) { - throw new RuntimeException(e); + latch.countDown(); } - }); - }).exceptionally(t -> { - if (!isStreamClosedError(t)) { - errorRef.set(t); - } - latch.countDown(); - return null; - }); - + ); + boolean dataRead = latch.await(20, TimeUnit.SECONDS); Assertions.assertTrue(dataRead); Assertions.assertNull(errorRef.get()); + + // Validation returns the expected TaskNotFoundError + JSONRPCError jsonRpcError = jsonRpcErrorRef.get(); + Assertions.assertNotNull(jsonRpcError); + assertEquals(new TaskNotFoundError().getCode(), jsonRpcError.getCode()); } - + @Test public void testStreamingMethodWithAcceptHeader() throws Exception { testSendStreamingMessage(MediaType.SERVER_SENT_EVENTS); } - + @Test public void testSendMessageStreamNewMessageSuccess() throws Exception { testSendStreamingMessage(null); } - + private void testSendStreamingMessage(String mediaType) throws Exception { Message message = new Message.Builder(MESSAGE) .taskId(MINIMAL_TASK.getId()) .contextId(MINIMAL_TASK.getContextId()) .build(); - SendStreamingMessageRequest request = new SendStreamingMessageRequest( - "1", new MessageSendParams(message, null, null)); - - CompletableFuture>> responseFuture = initialiseStreamingRequest(request, mediaType); - + MessageSendParams messageSendParams = new MessageSendParams(message, null, null); + CountDownLatch latch = new CountDownLatch(1); AtomicReference errorRef = new AtomicReference<>(); - - responseFuture.thenAccept(response -> { - if (response.statusCode() != 200) { - //errorRef.set(new IllegalStateException("Status code was " + response.statusCode())); - throw new IllegalStateException("Status code was " + response.statusCode()); - } - response.body().forEach(line -> { - try { - SendStreamingMessageResponse jsonResponse = extractJsonResponseFromSseLine(line); - if (jsonResponse != null) { - assertNull(jsonResponse.getError()); - Message messageResponse = (Message) jsonResponse.getResult(); - assertEquals(MESSAGE.getMessageId(), messageResponse.getMessageId()); - assertEquals(MESSAGE.getRole(), messageResponse.getRole()); - Part part = messageResponse.getParts().get(0); - assertEquals(Part.Kind.TEXT, part.getKind()); - assertEquals("test message", ((TextPart) part).getText()); + AtomicReference messageResponseRef = new AtomicReference<>(); + + // Using A2AClient's sendStreamingMessage method + client.sendStreamingMessage( + "1", // requestId + messageSendParams, + // eventHandler + (streamingEvent) -> { + try { + if (streamingEvent instanceof Message) { + messageResponseRef.set((Message) streamingEvent); + latch.countDown(); + } + } catch (Exception e) { + errorRef.set(e); latch.countDown(); } - } catch (JsonProcessingException e) { - throw new RuntimeException(e); + }, + // errorHandler + (jsonRpcError) -> { + errorRef.set(new RuntimeException("JSON-RPC Error: " + jsonRpcError.getMessage())); + latch.countDown(); + }, + // failureHandler + () -> { + if (errorRef.get() == null) { + errorRef.set(new RuntimeException("Stream processing failed")); + } + latch.countDown(); } - }); - }).exceptionally(t -> { - if (!isStreamClosedError(t)) { - errorRef.set(t); - } - latch.countDown(); - return null; - }); - - + ); + boolean dataRead = latch.await(20, TimeUnit.SECONDS); Assertions.assertTrue(dataRead); Assertions.assertNull(errorRef.get()); - + + Message messageResponse = messageResponseRef.get(); + Assertions.assertNotNull(messageResponse); + assertEquals(MESSAGE.getMessageId(), messageResponse.getMessageId()); + assertEquals(MESSAGE.getRole(), messageResponse.getRole()); + Part part = messageResponse.getParts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("test message", ((TextPart) part).getText()); + } - - private SendStreamingMessageResponse extractJsonResponseFromSseLine(String line) throws JsonProcessingException { - line = extractSseData(line); - if (line != null) { - return Utils.OBJECT_MAPPER.readValue(line, SendStreamingMessageResponse.class); + + @Test + public void testListPushNotificationConfigWithConfigId() throws Exception { + getTaskStore().save(MINIMAL_TASK); + PushNotificationConfig notificationConfig1 = + new PushNotificationConfig.Builder() + .url("http://example.com") + .id("config1") + .build(); + PushNotificationConfig notificationConfig2 = + new PushNotificationConfig.Builder() + .url("http://example.com") + .id("config2") + .build(); + savePushNotificationConfigInStore(MINIMAL_TASK.getId(), notificationConfig1); + savePushNotificationConfigInStore(MINIMAL_TASK.getId(), notificationConfig2); + + try { + ListTaskPushNotificationConfigResponse listResponse = client.listTaskPushNotificationConfig("111", MINIMAL_TASK.getId()); + assertEquals("111", listResponse.getId()); + assertEquals(2, listResponse.getResult().size()); + assertEquals(new TaskPushNotificationConfig(MINIMAL_TASK.getId(), notificationConfig1), listResponse.getResult().get(0)); + assertEquals(new TaskPushNotificationConfig(MINIMAL_TASK.getId(), notificationConfig2), listResponse.getResult().get(1)); + } catch (Exception e) { + fail(); + } finally { + deletePushNotificationConfigInStore(MINIMAL_TASK.getId(), "config1"); + deletePushNotificationConfigInStore(MINIMAL_TASK.getId(), "config2"); + getTaskStore().delete(MINIMAL_TASK.getId()); } - return null; } - - private static String extractSseData(String line) { - if (line.startsWith("data:")) { - line = line.substring(5).trim(); - return line; + + @Test + public void testListPushNotificationConfigWithoutConfigId() throws Exception { + getTaskStore().save(MINIMAL_TASK); + PushNotificationConfig notificationConfig1 = + new PushNotificationConfig.Builder() + .url("http://1.example.com") + .build(); + PushNotificationConfig notificationConfig2 = + new PushNotificationConfig.Builder() + .url("http://2.example.com") + .build(); + savePushNotificationConfigInStore(MINIMAL_TASK.getId(), notificationConfig1); + + // will overwrite the previous one + savePushNotificationConfigInStore(MINIMAL_TASK.getId(), notificationConfig2); + try { + ListTaskPushNotificationConfigResponse listResponse = client.listTaskPushNotificationConfig("111", MINIMAL_TASK.getId()); + assertEquals("111", listResponse.getId()); + assertEquals(1, listResponse.getResult().size()); + + PushNotificationConfig expectedNotificationConfig = new PushNotificationConfig.Builder() + .url("http://2.example.com") + .id(MINIMAL_TASK.getId()) + .build(); + assertEquals(new TaskPushNotificationConfig(MINIMAL_TASK.getId(), expectedNotificationConfig), + listResponse.getResult().get(0)); + } catch (Exception e) { + fail(); + } finally { + deletePushNotificationConfigInStore(MINIMAL_TASK.getId(), MINIMAL_TASK.getId()); + getTaskStore().delete(MINIMAL_TASK.getId()); } - return null; } - - private boolean isStreamClosedError(Throwable throwable) { - // Unwrap the CompletionException - Throwable cause = throwable; - - while (cause != null) { - if (cause instanceof EOFException) { - return true; + + @Test + public void testListPushNotificationConfigTaskNotFound() { + try { + client.listTaskPushNotificationConfig("111", "non-existent-task"); + fail(); + } catch (A2AServerException e) { + assertInstanceOf(TaskNotFoundError.class, e.getCause()); + } + } + + @Test + public void testListPushNotificationConfigEmptyList() throws Exception { + getTaskStore().save(MINIMAL_TASK); + try { + ListTaskPushNotificationConfigResponse listResponse = client.listTaskPushNotificationConfig("111", MINIMAL_TASK.getId()); + assertEquals("111", listResponse.getId()); + assertEquals(0, listResponse.getResult().size()); + } catch (Exception e) { + fail(); + } finally { + getTaskStore().delete(MINIMAL_TASK.getId()); + } + } + + @Test + public void testDeletePushNotificationConfigWithValidConfigId() throws Exception { + getTaskStore().save(MINIMAL_TASK); + saveTaskInTaskStore(new Task.Builder() + .id("task-456") + .contextId("session-xyz") + .status(new TaskStatus(TaskState.SUBMITTED)) + .build()); + + PushNotificationConfig notificationConfig1 = + new PushNotificationConfig.Builder() + .url("http://example.com") + .id("config1") + .build(); + PushNotificationConfig notificationConfig2 = + new PushNotificationConfig.Builder() + .url("http://example.com") + .id("config2") + .build(); + savePushNotificationConfigInStore(MINIMAL_TASK.getId(), notificationConfig1); + savePushNotificationConfigInStore(MINIMAL_TASK.getId(), notificationConfig2); + savePushNotificationConfigInStore("task-456", notificationConfig1); + + try { + // specify the config ID to delete + DeleteTaskPushNotificationConfigResponse deleteResponse = client.deleteTaskPushNotificationConfig(MINIMAL_TASK.getId(), + "config1"); + assertNull(deleteResponse.getError()); + assertNull(deleteResponse.getResult()); + + // should now be 1 left + ListTaskPushNotificationConfigResponse listResponse = client.listTaskPushNotificationConfig(MINIMAL_TASK.getId()); + assertEquals(1, listResponse.getResult().size()); + + // should remain unchanged, this is a different task + listResponse = client.listTaskPushNotificationConfig("task-456"); + assertEquals(1, listResponse.getResult().size()); + } catch (Exception e) { + fail(); + } finally { + deletePushNotificationConfigInStore(MINIMAL_TASK.getId(), "config1"); + deletePushNotificationConfigInStore(MINIMAL_TASK.getId(), "config2"); + deletePushNotificationConfigInStore("task-456", "config1"); + getTaskStore().delete(MINIMAL_TASK.getId()); + deleteTaskInTaskStore("task-456"); + } + } + + @Test + public void testDeletePushNotificationConfigWithNonExistingConfigId() throws Exception { + getTaskStore().save(MINIMAL_TASK); + PushNotificationConfig notificationConfig1 = + new PushNotificationConfig.Builder() + .url("http://example.com") + .id("config1") + .build(); + PushNotificationConfig notificationConfig2 = + new PushNotificationConfig.Builder() + .url("http://example.com") + .id("config2") + .build(); + savePushNotificationConfigInStore(MINIMAL_TASK.getId(), notificationConfig1); + savePushNotificationConfigInStore(MINIMAL_TASK.getId(), notificationConfig2); + + try { + DeleteTaskPushNotificationConfigResponse deleteResponse = client.deleteTaskPushNotificationConfig(MINIMAL_TASK.getId(), + "non-existent-config-id"); + assertNull(deleteResponse.getError()); + assertNull(deleteResponse.getResult()); + + // should remain unchanged + ListTaskPushNotificationConfigResponse listResponse = client.listTaskPushNotificationConfig(MINIMAL_TASK.getId()); + assertEquals(2, listResponse.getResult().size()); + } catch (Exception e) { + fail(); + } finally { + deletePushNotificationConfigInStore(MINIMAL_TASK.getId(), "config1"); + deletePushNotificationConfigInStore(MINIMAL_TASK.getId(), "config2"); + getTaskStore().delete(MINIMAL_TASK.getId()); + } + } + + @Test + public void testDeletePushNotificationConfigTaskNotFound() { + try { + client.deleteTaskPushNotificationConfig("non-existent-task", "non-existent-config-id"); + fail(); + } catch (A2AServerException e) { + assertInstanceOf(TaskNotFoundError.class, e.getCause()); + } + } + + @Test + public void testDeletePushNotificationConfigSetWithoutConfigId() throws Exception { + getTaskStore().save(MINIMAL_TASK); + PushNotificationConfig notificationConfig1 = + new PushNotificationConfig.Builder() + .url("http://1.example.com") + .build(); + PushNotificationConfig notificationConfig2 = + new PushNotificationConfig.Builder() + .url("http://2.example.com") + .build(); + savePushNotificationConfigInStore(MINIMAL_TASK.getId(), notificationConfig1); + + // this one will overwrite the previous one + savePushNotificationConfigInStore(MINIMAL_TASK.getId(), notificationConfig2); + + try { + DeleteTaskPushNotificationConfigResponse deleteResponse = client.deleteTaskPushNotificationConfig(MINIMAL_TASK.getId(), + MINIMAL_TASK.getId()); + assertNull(deleteResponse.getError()); + assertNull(deleteResponse.getResult()); + + // should now be 0 + ListTaskPushNotificationConfigResponse listResponse = client.listTaskPushNotificationConfig(MINIMAL_TASK.getId()); + assertEquals(0, listResponse.getResult().size()); + } catch (Exception e) { + fail(); + } finally { + deletePushNotificationConfigInStore(MINIMAL_TASK.getId(), MINIMAL_TASK.getId()); + getTaskStore().delete(MINIMAL_TASK.getId()); + } + } + + protected void saveTaskInTaskStore(Task task) throws Exception { + HttpClient client = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .build(); + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + serverPort + "/test/task")) + .POST(HttpRequest.BodyPublishers.ofString(Utils.OBJECT_MAPPER.writeValueAsString(task))) + .header("Content-Type", APPLICATION_JSON) + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + if (response.statusCode() != 200) { + throw new RuntimeException(String.format("Saving task failed! Status: %d, Body: %s", response.statusCode(), response.body())); + } + } + + protected Task getTaskFromTaskStore(String taskId) throws Exception { + HttpClient client = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .build(); + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + serverPort + "/test/task/" + taskId)) + .GET() + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + if (response.statusCode() == 404) { + return null; + } + if (response.statusCode() != 200) { + throw new RuntimeException(String.format("Getting task failed! Status: %d, Body: %s", response.statusCode(), response.body())); + } + return Utils.OBJECT_MAPPER.readValue(response.body(), Task.TYPE_REFERENCE); + } + + protected void deleteTaskInTaskStore(String taskId) throws Exception { + HttpClient client = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .build(); + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create(("http://localhost:" + serverPort + "/test/task/" + taskId))) + .DELETE() + .build(); + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + if (response.statusCode() != 200) { + throw new RuntimeException(response.statusCode() + ": Deleting task failed!" + response.body()); + } + } + + protected void ensureQueueForTask(String taskId) throws Exception { + HttpClient client = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .build(); + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + serverPort + "/test/queue/ensure/" + taskId)) + .POST(HttpRequest.BodyPublishers.noBody()) + .build(); + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + if (response.statusCode() != 200) { + throw new RuntimeException(String.format("Ensuring queue failed! Status: %d, Body: %s", response.statusCode(), response.body())); + } + } + + protected void enqueueEventOnServer(Event event) throws Exception { + String path; + if (event instanceof TaskArtifactUpdateEvent e) { + path = "test/queue/enqueueTaskArtifactUpdateEvent/" + e.getTaskId(); + } else if (event instanceof TaskStatusUpdateEvent e) { + path = "test/queue/enqueueTaskStatusUpdateEvent/" + e.getTaskId(); + } else { + throw new RuntimeException("Unknown event type " + event.getClass() + ". If you need the ability to" + + " handle more types, please add the REST endpoints."); + } + HttpClient client = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .build(); + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + serverPort + "/" + path)) + .header("Content-Type", APPLICATION_JSON) + .POST(HttpRequest.BodyPublishers.ofString(Utils.OBJECT_MAPPER.writeValueAsString(event))) + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + if (response.statusCode() != 200) { + throw new RuntimeException(response.statusCode() + ": Queueing event failed!" + response.body()); + } + } + + private CompletableFuture awaitStreamingSubscription() { + int cnt = getStreamingSubscribedCount(); + AtomicInteger initialCount = new AtomicInteger(cnt); + + return CompletableFuture.runAsync(() -> { + try { + boolean done = false; + long end = System.currentTimeMillis() + 15000; + while (System.currentTimeMillis() < end) { + int count = getStreamingSubscribedCount(); + if (count > initialCount.get()) { + done = true; + break; + } + Thread.sleep(500); + } + if (!done) { + throw new RuntimeException("Timed out waiting for subscription"); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted"); } - cause = cause.getCause(); + }); + } + + private int getStreamingSubscribedCount() { + HttpClient client = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .build(); + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + serverPort + "/test/streamingSubscribedCount")) + .GET() + .build(); + try { + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + String body = response.body().trim(); + return Integer.parseInt(body); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); } - return false; } - - private CompletableFuture>> initialiseStreamingRequest( - StreamingJSONRPCRequest request, String mediaType) throws Exception { - - // Create the client + + protected void deletePushNotificationConfigInStore(String taskId, String configId) throws Exception { HttpClient client = HttpClient.newBuilder() .version(HttpClient.Version.HTTP_2) .build(); - - // Create the request - HttpRequest.Builder builder = HttpRequest.newBuilder() - .uri(URI.create("http://localhost:" + serverPort + - "/")) - .POST(HttpRequest.BodyPublishers.ofString(Utils.OBJECT_MAPPER.writeValueAsString(request))) - .header("Content-Type", "application/json"); - if (mediaType != null) { - builder.header("Accept", mediaType); + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create(("http://localhost:" + serverPort + "/test/task/" + taskId + "/config/" + configId))) + .DELETE() + .build(); + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + if (response.statusCode() != 200) { + throw new RuntimeException(response.statusCode() + ": Deleting task failed!" + response.body()); } - HttpRequest httpRequest = builder.build(); - - - // Send request async and return the CompletableFuture - return client.sendAsync(httpRequest, HttpResponse.BodyHandlers.ofLines()); } - + + protected void savePushNotificationConfigInStore(String taskId, PushNotificationConfig notificationConfig) throws Exception { + HttpClient client = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .build(); + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + serverPort + "/test/task/" + taskId)) + .POST(HttpRequest.BodyPublishers.ofString(Utils.OBJECT_MAPPER.writeValueAsString(notificationConfig))) + .header("Content-Type", APPLICATION_JSON) + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + if (response.statusCode() != 200) { + throw new RuntimeException(response.statusCode() + ": Creating task push notification config failed! " + response.body()); + } + } + protected abstract TaskStore getTaskStore(); - + protected abstract InMemoryQueueManager getQueueManager(); - + protected abstract void setStreamingSubscribedRunnable(Runnable runnable); - + private static class BreakException extends RuntimeException { - + } -} +} \ No newline at end of file diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentCardProducer.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentCardProducer.java index f68f44967..3354c1522 100644 --- a/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentCardProducer.java +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentCardProducer.java @@ -32,6 +32,7 @@ public AgentCard agentCard() { .defaultInputModes(Collections.singletonList("text")) .defaultOutputModes(Collections.singletonList("text")) .skills(new ArrayList<>()) + .protocolVersion("0.2.5") .build(); } } diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/TestUtilsBean.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/TestUtilsBean.java new file mode 100644 index 000000000..c65766dd8 --- /dev/null +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/TestUtilsBean.java @@ -0,0 +1,60 @@ +package io.a2a.server.apps.common; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +import io.a2a.server.events.QueueManager; +import io.a2a.server.tasks.PushNotificationConfigStore; +import io.a2a.server.tasks.TaskStore; +import io.a2a.spec.Event; +import io.a2a.spec.PushNotificationConfig; +import io.a2a.spec.Task; + +/** + * Contains utilities to interact with the server side for the tests. + * The intent for this bean is to be exposed via REST. + * + *

There is a Quarkus implementation in {@code A2ATestRoutes} which shows the contract for how to + * expose it via REST. For other REST frameworks, you will need to provide an implementation that works in a similar + * way to {@code A2ATestRoutes}.

+ */ +@ApplicationScoped +public class TestUtilsBean { + + @Inject + TaskStore taskStore; + + @Inject + QueueManager queueManager; + + @Inject + PushNotificationConfigStore pushNotificationConfigStore; + + public void saveTask(Task task) { + taskStore.save(task); + } + + public Task getTask(String taskId) { + return taskStore.get(taskId); + } + + public void deleteTask(String taskId) { + taskStore.delete(taskId); + } + + public void ensureQueue(String taskId) { + queueManager.createOrTap(taskId); + } + + public void enqueueEvent(String taskId, Event event) { + queueManager.get(taskId).enqueueEvent(event); + } + + public void deleteTaskPushNotificationConfig(String taskId, String configId) { + pushNotificationConfigStore.deleteInfo(taskId, configId); + } + + public void saveTaskPushNotificationConfig(String taskId, PushNotificationConfig notificationConfig) { + pushNotificationConfigStore.setInfo(taskId, notificationConfig); + } +}