diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..4eb10e60b --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,136 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, caste, color, religion, or sexual +identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall + community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of + any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, + without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement: + +Farah Juma - fjuma@redhat.com \ +Kabir Khan - kkhan@redhat.com \ +Stefano Maestri - smaestri@redhat.com + +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of +actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or permanent +ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.1, available at +[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder][Mozilla CoC]. + +For answers to common questions about this code of conduct, see the FAQ at +[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at +[https://www.contributor-covenant.org/translations][translations]. + +[homepage]: https://www.contributor-covenant.org +[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html +[Mozilla CoC]: https://github.com/mozilla/diversity +[FAQ]: https://www.contributor-covenant.org/faq +[translations]: https://www.contributor-covenant.org/translations diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..479228199 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,91 @@ +Contributing to a2a-java +================================== + +Welcome to the A2A Java SDK project! We welcome contributions from the community. This guide will walk you through the steps for getting started on our project. + +- [Forking the Project](#forking-the-project) +- [Issues](#issues) + - [Good First Issues](#good-first-issues) +- [Setting up your Developer Environment](#setting-up-your-developer-environment) +- [Contributing Guidelines](#contributing-guidelines) +- [Community](#community) + + +## Forking the Project +To contribute, you will first need to fork the [a2a-java](https://github.com/a2aproject/a2a-java) repository. + +This can be done by looking in the top-right corner of the repository page and clicking "Fork". +![fork](images/fork.jpg) + +The next step is to clone your newly forked repository onto your local workspace. This can be done by going to your newly forked repository, which should be at `https://github.com/USERNAME/a2a-java`. + +Then, there will be a green button that says "Code". Click on that and copy the URL. + +Then, in your terminal, paste the following command: +```bash +git clone [URL] +``` +Be sure to replace [URL] with the URL that you copied. + +Now you have the repository on your computer! + +## Issues +The `a2a-java` project uses GitHub to manage issues. All issues can be found [here](https://github.com/a2aproject/a2a-java/issues). + +To create a new issue, comment on an existing issue, or assign an issue to yourself, you'll need to first [create a GitHub account](https://github.com/). + + +### Good First Issues +Want to contribute to the a2a-java project but aren't quite sure where to start? Check out our issues with the `good-first-issue` label. These are a triaged set of issues that are great for getting started on our project. These can be found [here](https://github.com/a2aproject/a2a-java/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22good%20first%20issue%22). + +Once you have selected an issue you'd like to work on, make sure it's not already assigned to someone else, and assign it to yourself. + +It is recommended that you use a separate branch for every issue you work on. To keep things straightforward and memorable, you can name each branch using the GitHub issue number. This way, you can have multiple PRs open for different issues. For example, if you were working on [issue-20](https://github.com/a2aproject/a2a-java/issues/20), you could use `issue-20` as your branch name. + +## Setting up your Developer Environment +You will need: + +* Java 17+ +* Git +* An IDE (e.g., IntelliJ IDEA, Eclipse, VSCode, etc.) + +To set up your development environment you need to: + +1. First `cd` to the directory where you cloned the project (eg: `cd a2a-java`) + +2. Add a remote ref to upstream, for pulling future updates. For example: + + ``` + git remote add upstream https://github.com/a2aproject/a2a-java + ``` + +3. To build `a2a-java` and run the tests, use the following command: + + ``` + mvn clean install + ``` + +4. To skip the tests: + + ``` + mvn clean install -DskipTests=true + ``` + +## Contributing Guidelines + +When submitting a PR, please keep the following guidelines in mind: + +1. In general, it's good practice to squash all of your commits into a single commit. For larger changes, it's ok to have multiple meaningful commits. If you need help with squashing your commits, feel free to ask us how to do this on your pull request. We're more than happy to help! + +2. Please link the issue you worked on in the description of your pull request and in your commit message. For example, for issue-20, the PR description and commit message could be: ```Add tests to A2AClientTest for sending a task with a FilePart and with a DataPart + Fixes #20``` + +3. Your PR should include tests for the functionality that you are adding. + +4. Your PR should include appropriate [documentation](https://github.com/a2aproject/a2a-java/blob/main/README.md) for the functionality that you are adding. + +## Code Reviews + +All submissions, including submissions by project members, need to be reviewed by at least one `a2a-java` committer before being merged. + +The [GitHub Pull Request Review Process](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/reviewing-changes-in-pull-requests/about-pull-request-reviews) is followed for every pull request. diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..7a4a3ea24 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 000000000..b76694f2c --- /dev/null +++ b/README.md @@ -0,0 +1,358 @@ +# A2A Java SDK + +[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE) + + + + +

+ A2A Logo +

+

A Java library that helps run agentic applications as A2AServers following the Agent2Agent (A2A) Protocol.

+ + +## Installation + +You can build the A2A Java SDK using `mvn`: + +```bash +mvn clean install +``` + +## Examples + +You can find an example of how to use the A2A Java SDK [here](https://github.com/fjuma/a2a-samples/tree/java-sdk-example/samples/multi_language/python_and_java_multiagent/weather_agent). + +More examples will be added soon. + +## A2A Server + +The A2A Java SDK provides a Java server implementation of the [Agent2Agent (A2A) Protocol](https://google-a2a.github.io/A2A). To run your agentic Java application as an A2A server, simply follow the steps below. + +- [Add the A2A Java SDK Core Maven dependency to your project](#1-add-the-a2a-java-sdk-core-maven-dependency-to-your-project) +- [Add a class that creates an A2A Agent Card](#2-add-a-class-that-creates-an-a2a-agent-card) +- [Add a class that creates an A2A Agent Executor](#3-add-a-class-that-creates-an-a2a-agent-executor) +- [Add an A2A Java SDK Server Maven dependency to your project](#4-add-an-a2a-java-sdk-server-maven-dependency-to-your-project) + +### 1. Add the A2A Java SDK Core Maven dependency to your project + +> **Note**: The A2A Java SDK isn't available yet in Maven Central but will be soon. For now, be +> sure to check out the latest tag (you can see the tags [here](https://github.com/a2aproject/a2a-java/tags)), build from the tag, and reference that version below. For example, if the latest tag is `0.2.3`, you can use the following dependency. + +```xml + + io.a2a.sdk + a2a-java-sdk-core + 0.2.3 + +``` + +### 2. Add a class that creates an A2A Agent Card + +```java +import io.a2a.server.PublicAgentCard; +import io.a2a.spec.AgentCapabilities; +import io.a2a.spec.AgentCard; +import io.a2a.spec.AgentSkill; +... + +@ApplicationScoped +public class WeatherAgentCardProducer { + + @Produces + @PublicAgentCard + public AgentCard agentCard() { + return new AgentCard.Builder() + .name("Weather Agent") + .description("Helps with weather") + .url("http://localhost:10001") + .version("1.0.0") + .capabilities(new AgentCapabilities.Builder() + .streaming(true) + .pushNotifications(false) + .stateTransitionHistory(false) + .build()) + .defaultInputModes(Collections.singletonList("text")) + .defaultOutputModes(Collections.singletonList("text")) + .skills(Collections.singletonList(new AgentSkill.Builder() + .id("weather_search") + .name("Search weather") + .description("Helps with weather in city, or states") + .tags(Collections.singletonList("weather")) + .examples(List.of("weather in LA, CA")) + .build())) + .build(); + } +} +``` + +### 3. Add a class that creates an A2A Agent Executor + +```java +import io.a2a.server.agentexecution.AgentExecutor; +import io.a2a.server.agentexecution.RequestContext; +import io.a2a.server.events.EventQueue; +import io.a2a.server.tasks.TaskUpdater; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.Message; +import io.a2a.spec.Part; +import io.a2a.spec.Task; +import io.a2a.spec.TaskNotCancelableError; +import io.a2a.spec.TaskState; +import io.a2a.spec.TextPart; +... + +@ApplicationScoped +public class WeatherAgentExecutorProducer { + + @Inject + WeatherAgent weatherAgent; + + @Produces + public AgentExecutor agentExecutor() { + return new WeatherAgentExecutor(weatherAgent); + } + + private static class WeatherAgentExecutor implements AgentExecutor { + + private final WeatherAgent weatherAgent; + + public WeatherAgentExecutor(WeatherAgent weatherAgent) { + this.weatherAgent = weatherAgent; + } + + @Override + public void execute(RequestContext context, EventQueue eventQueue) throws JSONRPCError { + TaskUpdater updater = new TaskUpdater(context, eventQueue); + + // mark the task as submitted and start working on it + if (context.getTask() == null) { + updater.submit(); + } + updater.startWork(); + + // extract the text from the message + String userMessage = extractTextFromMessage(context.getMessage()); + + // call the weather agent with the user's message + String response = weatherAgent.chat(userMessage); + + // create the response part + TextPart responsePart = new TextPart(response, null); + List> parts = List.of(responsePart); + + // add the response as an artifact and complete the task + updater.addArtifact(parts, null, null, null); + updater.complete(); + } + + @Override + public void cancel(RequestContext context, EventQueue eventQueue) throws JSONRPCError { + Task task = context.getTask(); + + if (task.getStatus().state() == TaskState.CANCELED) { + // task already cancelled + throw new TaskNotCancelableError(); + } + + if (task.getStatus().state() == TaskState.COMPLETED) { + // task already completed + throw new TaskNotCancelableError(); + } + + // cancel the task + TaskUpdater updater = new TaskUpdater(context, eventQueue); + updater.cancel(); + } + + private String extractTextFromMessage(Message message) { + StringBuilder textBuilder = new StringBuilder(); + if (message.getParts() != null) { + for (Part part : message.getParts()) { + if (part instanceof TextPart textPart) { + textBuilder.append(textPart.getText()); + } + } + } + return textBuilder.toString(); + } + } +} +``` + +### 4. Add an A2A Java SDK Server Maven dependency to your project + +> **Note**: The A2A Java SDK isn't available yet in Maven Central but will be soon. For now, be +> sure to check out the latest tag (you can see the tags [here](https://github.com/a2aproject/a2a-java/tags)), build from the tag, and reference that version below. For example, if the latest tag is `0.2.3`, you can use the following dependency. + +Adding a dependency on an A2A Java SDK Server will allow you to run your agentic Java application as an A2A server. + +The A2A Java SDK provides two A2A server endpoint implementations, one based on Jakarta REST (`a2a-java-sdk-server-jakarta`) and one based on Quarkus Reactive Routes (`a2a-java-sdk-server-quarkus`). You can choose the one that best fits your application. + +Add **one** of the following dependencies to your project: + +```xml + + io.a2a.sdk + a2a-java-sdk-server-jakarta + ${io.a2a.sdk.version} + +``` + +OR + +```xml + + io.a2a.sdk + a2a-java-sdk-server-quarkus + ${io.a2a.sdk.version} + +``` + +## A2A Client + +The A2A Java SDK provides a Java client implementation of the [Agent2Agent (A2A) Protocol](https://google-a2a.github.io/A2A), allowing communication with A2A servers. + +### Sample Usage + +#### Create an A2A client + +```java +// Create an A2AClient (the URL specified is the server agent's URL, be sure to replace it with the actual URL of the A2A server you want to connect to) +A2AClient client = new A2AClient("http://localhost:1234"); +``` + +#### Send a message to the A2A server agent + +```java +// Send a text message to the A2A server agent +Message message = A2A.toUserMessage("tell me a joke"); // the message ID will be automatically generated for you +MessageSendParams params = new MessageSendParams.Builder() + .message(message) + .build(); +SendMessageResponse response = client.sendMessage(params); +``` + +Note that `A2A#toUserMessage` will automatically generate a message ID for you when creating the `Message` +if you don't specify it. You can also explicitly specify a message ID like this: + +```java +Message message = A2A.toUserMessage("tell me a joke", "message-1234"); // messageId is message-1234 +``` + +#### Get the current state of a task + +```java +// Retrieve the task with id "task-1234" +GetTaskResponse response = client.getTask("task-1234"); + +// You can also specify the maximum number of items of history for the task +// to include in the response +GetTaskResponse response = client.getTask(new TaskQueryParams("task-1234", 10)); +``` + +#### Cancel an ongoing task + +```java +// Cancel the task we previously submitted with id "task-1234" +CancelTaskResponse response = client.cancelTask("task-1234"); + +// You can also specify additional properties using a map +Map metadata = ... +CancelTaskResponse response = client.cancelTask(new TaskIdParams("task-1234", metadata)); +``` + +#### Get the push notification configuration for a task + +```java +// Get task push notification configuration +GetTaskPushNotificationConfigResponse response = client.getTaskPushNotificationConfig("task-1234"); + +// You can also specify additional properties using a map +Map metadata = ... +GetTaskPushNotificationConfigResponse response = client.getTaskPushNotificationConfig(new TaskIdParams("task-1234", metadata)); +``` + +#### Set the push notification configuration for a task + +```java +// Set task push notification configuration +PushNotificationConfig pushNotificationConfig = new PushNotificationConfig.Builder() + .url("https://example.com/callback") + .authenticationInfo(new AuthenticationInfo(Collections.singletonList("jwt"), null)) + .build(); +SetTaskPushNotificationResponse response = client.setTaskPushNotificationConfig("task-1234", pushNotificationConfig); +``` + +#### Send a streaming message + +```java +// Send a text message to the remote agent +Message message = A2A.toUserMessage("tell me some jokes"); // the message ID will be automatically generated for you +MessageSendParams params = new MessageSendParams.Builder() + .message(message) + .build(); + +// Create a handler that will be invoked for Task, Message, TaskStatusUpdateEvent, and TaskArtifactUpdateEvent +Consumer eventHandler = event -> {...}; + +// Create a handler that will be invoked if an error is received +Consumer errorHandler = error -> {...}; + +// Create a handler that will be invoked in the event of a failure +Runnable failureHandler = () -> {...}; + +// Send the streaming message to the remote agent +client.sendStreamingMessage(params, eventHandler, errorHandler, failureHandler); +``` + +#### Resubscribe to a task + +```java +// Create a handler that will be invoked for Task, Message, TaskStatusUpdateEvent, and TaskArtifactUpdateEvent +Consumer eventHandler = event -> {...}; + +// Create a handler that will be invoked if an error is received +Consumer errorHandler = error -> {...}; + +// Create a handler that will be invoked in the event of a failure +Runnable failureHandler = () -> {...}; + +// Resubscribe to an ongoing task with id "task-1234" +TaskIdParams taskIdParams = new TaskIdParams("task-1234"); +client.resubscribeToTask("request-1234", taskIdParams, eventHandler, errorHandler, failureHandler); +``` + +#### Retrieve details about the server agent that this client agent is communicating with +```java +AgentCard serverAgentCard = client.getAgentCard(); +``` + +An agent card can also be retrieved using the `A2A#getAgentCard` method: +```java +// http://localhost:1234 is the base URL for the agent whose card we want to retrieve +AgentCard agentCard = A2A.getAgentCard("http://localhost:1234"); +``` + +## Additional Examples + +### Hello World Example + +A complete example of an A2A client communicating with a Python A2A server is available in the [examples/helloworld](src/main/java/io/a2a/examples/helloworld) directory. This example demonstrates: + +- Setting up and using the A2A Java client +- Sending regular and streaming messages +- Receiving and processing responses + +The example includes detailed instructions on how to run both the Python server and the Java client using JBang. Check out the [example's README](examples/client/src/main/java/io/a2a/examples/helloworld/README.md) for more information. + +## License + +This project is licensed under the terms of the [Apache 2.0 License](LICENSE). + +## Contributing + +See [CONTRIBUTING.md](CONTRIBUTING.md) for contribution guidelines. + + + diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..36c3d8d4b --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,31 @@ +# Reporting of CVEs and Security Issues + +## The A2A Java SDK community takes security bugs very seriously + +We aim to take immediate action to address serious security-related problems that involve our project. + +Note that we will only fix such issues in the most recent minor release of the A2A Java SDK. + +## Reporting of Security Issues + +When reporting a security vulnerability it is important to not accidentally broadcast to the world that +the issue exists, as this makes it easier for people to exploit it. The software industry uses the term +embargo to describe the time a +security issue is known internally until it is public knowledge. + +Our preferred way of reporting security issues is listed below. + +### Email the A2A Java SDK team + +To report a security issue, please email fjuma@redhat.com, +kkhan@redhat.com, and/or smaestri@redhat.com. A member of the team will open the required issues. + +### Other considerations + +If you would like to work with us on a fix for the security vulnerability, please include your GitHub username +in the above email, and we will provide you access to a temporary private fork where we can collaborate on a +fix without it being disclosed publicly, **including in your own publicly visible git repository**. + +Do not open a public issue, send a pull request, or disclose any information about the suspected vulnerability +publicly, **including in your own publicly visible git repository**. If you discover any publicly disclosed +security vulnerabilities, please notify us immediately through the emails listed in the section above. \ No newline at end of file diff --git a/core/pom.xml b/core/pom.xml new file mode 100644 index 000000000..8aed247f4 --- /dev/null +++ b/core/pom.xml @@ -0,0 +1,45 @@ + + + 4.0.0 + + + io.a2a.sdk + a2a-java-sdk-parent + 0.2.4-SNAPSHOT + + a2a-java-sdk-core + + jar + + Java SDK A2A Core + Java SDK for the Agent2Agent Protocol (A2A) - Core + + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + + + org.junit.jupiter + junit-jupiter-api + test + + + org.mockito + mockito-core + test + + + org.mock-server + mockserver-netty + test + + + + \ No newline at end of file diff --git a/core/src/main/java/io/a2a/client/A2ACardResolver.java b/core/src/main/java/io/a2a/client/A2ACardResolver.java new file mode 100644 index 000000000..1266f7219 --- /dev/null +++ b/core/src/main/java/io/a2a/client/A2ACardResolver.java @@ -0,0 +1,100 @@ +package io.a2a.client; + +import static io.a2a.util.Utils.unmarshalFrom; + +import java.io.IOException; +import java.util.Map; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import io.a2a.http.A2AHttpClient; +import io.a2a.http.A2AHttpResponse; +import io.a2a.spec.A2AClientError; +import io.a2a.spec.A2AClientJSONError; +import io.a2a.spec.AgentCard; + +public class A2ACardResolver { + private final A2AHttpClient httpClient; + private final String url; + private final Map authHeaders; + + static String DEFAULT_AGENT_CARD_PATH = "/.well-known/agent.json"; + + static final TypeReference AGENT_CARD_TYPE_REFERENCE = new TypeReference<>() {}; + /** + * @param httpClient the http client to use + * @param baseUrl the base URL for the agent whose agent card we want to retrieve + */ + public A2ACardResolver(A2AHttpClient httpClient, String baseUrl) { + this(httpClient, baseUrl, null, null); + } + + /** + * @param httpClient the http client to use + * @param baseUrl the base URL for the agent whose agent card we want to retrieve + * @param agentCardPath optional path to the agent card endpoint relative to the base + * agent URL, defaults to ".well-known/agent.json" + */ + public A2ACardResolver(A2AHttpClient httpClient, String baseUrl, String agentCardPath) { + this(httpClient, baseUrl, agentCardPath, null); + } + + /** + * @param httpClient the http client to use + * @param baseUrl the base URL for the agent whose agent card we want to retrieve + * @param agentCardPath optional path to the agent card endpoint relative to the base + * agent URL, defaults to ".well-known/agent.json" + * @param authHeaders the HTTP authentication headers to use. May be {@code null} + */ + public A2ACardResolver(A2AHttpClient httpClient, String baseUrl, String agentCardPath, Map authHeaders) { + this.httpClient = httpClient; + if (!baseUrl.endsWith("/")) { + baseUrl += "/"; + } + agentCardPath = agentCardPath == null || agentCardPath.isEmpty() ? DEFAULT_AGENT_CARD_PATH : agentCardPath; + if (agentCardPath.startsWith("/")) { + agentCardPath = agentCardPath.substring(1); + } + this.url = baseUrl + agentCardPath; + this.authHeaders = authHeaders; + } + + /** + * Get the agent card for the configured A2A agent. + * + * @return the agent card + * @throws A2AClientError If an HTTP error occurs fetching the card + * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema + */ + public AgentCard getAgentCard() throws A2AClientError, A2AClientJSONError { + A2AHttpClient.GetBuilder builder = httpClient.createGet() + .url(url) + .addHeader("Content-Type", "application/json"); + + if (authHeaders != null) { + for (Map.Entry entry : authHeaders.entrySet()) { + builder.addHeader(entry.getKey(), entry.getValue()); + } + } + + String body; + try { + A2AHttpResponse response = builder.get(); + if (!response.success()) { + throw new A2AClientError("Failed to obtain agent card: " + response.status()); + } + body = response.body(); + } catch (IOException | InterruptedException e) { + throw new A2AClientError("Failed to obtain agent card", e); + } + + try { + return unmarshalFrom(body, AGENT_CARD_TYPE_REFERENCE); + } catch (JsonProcessingException e) { + throw new A2AClientJSONError("Could not unmarshal agent card response", e); + } + + } + + +} diff --git a/core/src/main/java/io/a2a/client/A2AClient.java b/core/src/main/java/io/a2a/client/A2AClient.java new file mode 100644 index 000000000..3a35f0b67 --- /dev/null +++ b/core/src/main/java/io/a2a/client/A2AClient.java @@ -0,0 +1,518 @@ +package io.a2a.client; + +import static io.a2a.spec.A2A.CANCEL_TASK_METHOD; +import static io.a2a.spec.A2A.GET_TASK_METHOD; +import static io.a2a.spec.A2A.GET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD; +import static io.a2a.spec.A2A.JSONRPC_VERSION; +import static io.a2a.spec.A2A.SEND_MESSAGE_METHOD; +import static io.a2a.spec.A2A.SEND_STREAMING_MESSAGE_METHOD; +import static io.a2a.spec.A2A.SEND_TASK_RESUBSCRIPTION_METHOD; +import static io.a2a.spec.A2A.SET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD; +import static io.a2a.util.Assert.checkNotNullParam; +import static io.a2a.util.Utils.OBJECT_MAPPER; +import static io.a2a.util.Utils.unmarshalFrom; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import io.a2a.client.sse.SSEEventListener; +import io.a2a.http.A2AHttpClient; +import io.a2a.http.A2AHttpResponse; +import io.a2a.http.JdkA2AHttpClient; +import io.a2a.spec.A2A; +import io.a2a.spec.A2AClientError; +import io.a2a.spec.A2AClientJSONError; +import io.a2a.spec.A2AServerException; +import io.a2a.spec.AgentCard; +import io.a2a.spec.CancelTaskRequest; +import io.a2a.spec.CancelTaskResponse; +import io.a2a.spec.GetTaskPushNotificationConfigRequest; +import io.a2a.spec.GetTaskPushNotificationConfigResponse; +import io.a2a.spec.GetTaskRequest; +import io.a2a.spec.GetTaskResponse; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.JSONRPCResponse; +import io.a2a.spec.MessageSendParams; +import io.a2a.spec.PushNotificationConfig; +import io.a2a.spec.SendMessageRequest; +import io.a2a.spec.SendMessageResponse; +import io.a2a.spec.SendStreamingMessageRequest; +import io.a2a.spec.SetTaskPushNotificationConfigRequest; +import io.a2a.spec.SetTaskPushNotificationConfigResponse; +import io.a2a.spec.StreamingEventKind; +import io.a2a.spec.TaskIdParams; +import io.a2a.spec.TaskPushNotificationConfig; +import io.a2a.spec.TaskQueryParams; +import io.a2a.spec.TaskResubscriptionRequest; + +/** + * An A2A client. + */ +public class A2AClient { + + private static final TypeReference SEND_MESSAGE_RESPONSE_REFERENCE = new TypeReference<>() {}; + private static final TypeReference GET_TASK_RESPONSE_REFERENCE = new TypeReference<>() {}; + private static final TypeReference CANCEL_TASK_RESPONSE_REFERENCE = new TypeReference<>() {}; + private static final TypeReference GET_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE = new TypeReference<>() {}; + private static final TypeReference SET_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE = new TypeReference<>() {}; + private final A2AHttpClient httpClient; + private final String agentUrl; + private AgentCard agentCard; + + + /** + * Create a new A2AClient. + * + * @param agentCard the agent card for the A2A server this client will be communicating with + */ + public A2AClient(AgentCard agentCard) { + checkNotNullParam("agentCard", agentCard); + this.agentCard = agentCard; + this.agentUrl = agentCard.url(); + this.httpClient = new JdkA2AHttpClient(); + } + + /** + * Create a new A2AClient. + * + * @param agentUrl the URL for the A2A server this client will be communicating with + */ + public A2AClient(String agentUrl) { + checkNotNullParam("agentUrl", agentUrl); + this.agentUrl = agentUrl; + this.httpClient = new JdkA2AHttpClient(); + } + + /** + * Fetches the agent card and initialises an A2A client. + * + * @param httpClient the {@link A2AHttpClient} to use + * @param baseUrl the base URL of the agent's host + * @param agentCardPath the path to the agent card endpoint, relative to the {@code baseUrl}. If {@code null}, the + * value {@link A2ACardResolver#DEFAULT_AGENT_CARD_PATH} will be used + * @return an initialised {@code A2AClient} instance + * @throws A2AClientError If an HTTP error occurs fetching the card + * @throws A2AClientJSONError if the agent card response is invalid + */ + public static A2AClient getClientFromAgentCardUrl(A2AHttpClient httpClient, String baseUrl, + String agentCardPath) throws A2AClientError, A2AClientJSONError { + A2ACardResolver resolver = new A2ACardResolver(httpClient, baseUrl, agentCardPath); + AgentCard card = resolver.getAgentCard(); + return new A2AClient(card); + } + + /** + * Get the agent card for the A2A server this client will be communicating with from + * the default public agent card endpoint. + * + * @return the agent card for the A2A server + * @throws A2AClientError If an HTTP error occurs fetching the card + * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema + */ + public AgentCard getAgentCard() throws A2AClientError, A2AClientJSONError { + if (this.agentCard == null) { + this.agentCard = A2A.getAgentCard(this.httpClient, this.agentUrl); + } + return this.agentCard; + } + + /** + * Get the agent card for the A2A server this client will be communicating with. + * + * @param relativeCardPath the path to the agent card endpoint relative to the base URL of the A2A server + * @param authHeaders the HTTP authentication headers to use + * @return the agent card for the A2A server + * @throws A2AClientError If an HTTP error occurs fetching the card + * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema + */ + public AgentCard getAgentCard(String relativeCardPath, Map authHeaders) throws A2AClientError, A2AClientJSONError { + if (this.agentCard == null) { + this.agentCard = A2A.getAgentCard(this.httpClient, this.agentUrl, relativeCardPath, authHeaders); + } + return this.agentCard; + } + + /** + * Send a message to the remote agent. + * + * @param messageSendParams the parameters for the message to be sent + * @return the response, may contain a message or a task + * @throws A2AServerException if sending the message fails for any reason + */ + public SendMessageResponse sendMessage(MessageSendParams messageSendParams) throws A2AServerException { + return sendMessage(null, messageSendParams); + } + + /** + * Send a message to the remote agent. + * + * @param requestId the request ID to use + * @param messageSendParams the parameters for the message to be sent + * @return the response, may contain a message or a task + * @throws A2AServerException if sending the message fails for any reason + */ + public SendMessageResponse sendMessage(String requestId, MessageSendParams messageSendParams) throws A2AServerException { + SendMessageRequest.Builder sendMessageRequestBuilder = new SendMessageRequest.Builder() + .jsonrpc(JSONRPC_VERSION) + .method(SEND_MESSAGE_METHOD) + .params(messageSendParams); + + if (requestId != null) { + sendMessageRequestBuilder.id(requestId); + } + + SendMessageRequest sendMessageRequest = sendMessageRequestBuilder.build(); + + try { + String httpResponseBody = sendPostRequest(sendMessageRequest); + return unmarshalResponse(httpResponseBody, SEND_MESSAGE_RESPONSE_REFERENCE); + } catch (IOException | InterruptedException e) { + throw new A2AServerException("Failed to send message: " + e); + } + } + + /** + * Retrieve a task from the A2A server. This method can be used to retrieve the generated + * artifacts for a task. + * + * @param id the task ID + * @return the response containing the task + * @throws A2AServerException if retrieving the task fails for any reason + */ + public GetTaskResponse getTask(String id) throws A2AServerException { + return getTask(null, new TaskQueryParams(id)); + } + + /** + * Retrieve a task from the A2A server. This method can be used to retrieve the generated + * artifacts for a task. + * + * @param taskQueryParams the params for the task to be queried + * @return the response containing the task + * @throws A2AServerException if retrieving the task fails for any reason + */ + public GetTaskResponse getTask(TaskQueryParams taskQueryParams) throws A2AServerException { + return getTask(null, taskQueryParams); + } + + /** + * Retrieve the generated artifacts for a task. + * + * @param requestId the request ID to use + * @param taskQueryParams the params for the task to be queried + * @return the response containing the task + * @throws A2AServerException if retrieving the task fails for any reason + */ + public GetTaskResponse getTask(String requestId, TaskQueryParams taskQueryParams) throws A2AServerException { + GetTaskRequest.Builder getTaskRequestBuilder = new GetTaskRequest.Builder() + .jsonrpc(JSONRPC_VERSION) + .method(GET_TASK_METHOD) + .params(taskQueryParams); + + if (requestId != null) { + getTaskRequestBuilder.id(requestId); + } + + GetTaskRequest getTaskRequest = getTaskRequestBuilder.build(); + + try { + String httpResponseBody = sendPostRequest(getTaskRequest); + return unmarshalResponse(httpResponseBody, GET_TASK_RESPONSE_REFERENCE); + } catch (IOException | InterruptedException e) { + throw new A2AServerException("Failed to get task: " + e); + } + } + + /** + * Cancel a task that was previously submitted to the A2A server. + * + * @param id the task ID + * @return the response indicating if the task was cancelled + * @throws A2AServerException if cancelling the task fails for any reason + */ + public CancelTaskResponse cancelTask(String id) throws A2AServerException { + return cancelTask(null, new TaskIdParams(id)); + } + + /** + * Cancel a task that was previously submitted to the A2A server. + * + * @param taskIdParams the params for the task to be cancelled + * @return the response indicating if the task was cancelled + * @throws A2AServerException if cancelling the task fails for any reason + */ + public CancelTaskResponse cancelTask(TaskIdParams taskIdParams) throws A2AServerException { + return cancelTask(null, taskIdParams); + } + + /** + * Cancel a task that was previously submitted to the A2A server. + * + * @param requestId the request ID to use + * @param taskIdParams the params for the task to be cancelled + * @return the response indicating if the task was cancelled + * @throws A2AServerException if retrieving the task fails for any reason + */ + public CancelTaskResponse cancelTask(String requestId, TaskIdParams taskIdParams) throws A2AServerException { + CancelTaskRequest.Builder cancelTaskRequestBuilder = new CancelTaskRequest.Builder() + .jsonrpc(JSONRPC_VERSION) + .method(CANCEL_TASK_METHOD) + .params(taskIdParams); + + if (requestId != null) { + cancelTaskRequestBuilder.id(requestId); + } + + CancelTaskRequest cancelTaskRequest = cancelTaskRequestBuilder.build(); + + try { + String httpResponseBody = sendPostRequest(cancelTaskRequest); + return unmarshalResponse(httpResponseBody, CANCEL_TASK_RESPONSE_REFERENCE); + } catch (IOException | InterruptedException e) { + throw new A2AServerException("Failed to cancel task: " + e); + } + } + + /** + * Get the push notification configuration for a task. + * + * @param id the task ID + * @return the response containing the push notification configuration + * @throws A2AServerException if getting the push notification configuration fails for any reason + */ + public GetTaskPushNotificationConfigResponse getTaskPushNotificationConfig(String id) throws A2AServerException { + return getTaskPushNotificationConfig(null, new TaskIdParams(id)); + } + + /** + * Get the push notification configuration for a task. + * + * @param taskIdParams the params for the task + * @return the response containing the push notification configuration + * @throws A2AServerException if getting the push notification configuration fails for any reason + */ + public GetTaskPushNotificationConfigResponse getTaskPushNotificationConfig(TaskIdParams taskIdParams) throws A2AServerException { + return getTaskPushNotificationConfig(null, taskIdParams); + } + + /** + * Get the push notification configuration for a task. + * + * @param requestId the request ID to use + * @param taskIdParams the params for the task + * @return the response containing the push notification configuration + * @throws A2AServerException if getting the push notification configuration fails for any reason + */ + public GetTaskPushNotificationConfigResponse getTaskPushNotificationConfig(String requestId, TaskIdParams taskIdParams) throws A2AServerException { + GetTaskPushNotificationConfigRequest.Builder getTaskPushNotificationRequestBuilder = new GetTaskPushNotificationConfigRequest.Builder() + .jsonrpc(JSONRPC_VERSION) + .method(GET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD) + .params(taskIdParams); + + if (requestId != null) { + getTaskPushNotificationRequestBuilder.id(requestId); + } + + GetTaskPushNotificationConfigRequest getTaskPushNotificationRequest = getTaskPushNotificationRequestBuilder.build(); + + try { + String httpResponseBody = sendPostRequest(getTaskPushNotificationRequest); + return unmarshalResponse(httpResponseBody, GET_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE); + } catch (IOException | InterruptedException e) { + throw new A2AServerException("Failed to get task push notification config: " + e); + } + } + + /** + * Set push notification configuration for a task. + * + * @param taskId the task ID + * @param pushNotificationConfig the push notification configuration + * @return the response indicating whether setting the task push notification configuration succeeded + * @throws A2AServerException if setting the push notification configuration fails for any reason + */ + public SetTaskPushNotificationConfigResponse setTaskPushNotificationConfig(String taskId, + PushNotificationConfig pushNotificationConfig) throws A2AServerException { + return setTaskPushNotificationConfig(null, taskId, pushNotificationConfig); + } + + /** + * Set push notification configuration for a task. + * + * @param requestId the request ID to use + * @param taskId the task ID + * @param pushNotificationConfig the push notification configuration + * @return the response indicating whether setting the task push notification configuration succeeded + * @throws A2AServerException if setting the push notification configuration fails for any reason + */ + public SetTaskPushNotificationConfigResponse setTaskPushNotificationConfig(String requestId, String taskId, + PushNotificationConfig pushNotificationConfig) throws A2AServerException { + SetTaskPushNotificationConfigRequest.Builder setTaskPushNotificationRequestBuilder = new SetTaskPushNotificationConfigRequest.Builder() + .jsonrpc(JSONRPC_VERSION) + .method(SET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD) + .params(new TaskPushNotificationConfig(taskId, pushNotificationConfig)); + + if (requestId != null) { + setTaskPushNotificationRequestBuilder.id(requestId); + } + + SetTaskPushNotificationConfigRequest setTaskPushNotificationRequest = setTaskPushNotificationRequestBuilder.build(); + + try { + String httpResponseBody = sendPostRequest(setTaskPushNotificationRequest); + return unmarshalResponse(httpResponseBody, SET_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE); + } catch (IOException | InterruptedException e) { + throw new A2AServerException("Failed to set task push notification config: " + e); + } + } + + /** + * Send a streaming message to the remote agent. + * + * @param messageSendParams the parameters for the message to be sent + * @param eventHandler a consumer that will be invoked for each event received from the remote agent + * @param errorHandler a consumer that will be invoked if the remote agent returns an error + * @param failureHandler a consumer that will be invoked if a failure occurs when processing events + * @throws A2AServerException if sending the streaming message fails for any reason + */ + public void sendStreamingMessage(MessageSendParams messageSendParams, Consumer eventHandler, + Consumer errorHandler, Runnable failureHandler) throws A2AServerException { + sendStreamingMessage(null, messageSendParams, eventHandler, errorHandler, failureHandler); + } + + /** + * Send a streaming message to the remote agent. + * + * @param requestId the request ID to use + * @param messageSendParams the parameters for the message to be sent + * @param eventHandler a consumer that will be invoked for each event received from the remote agent + * @param errorHandler a consumer that will be invoked if the remote agent returns an error + * @param failureHandler a consumer that will be invoked if a failure occurs when processing events + * @throws A2AServerException if sending the streaming message fails for any reason + */ + public void sendStreamingMessage(String requestId, MessageSendParams messageSendParams, Consumer eventHandler, + Consumer errorHandler, Runnable failureHandler) throws A2AServerException { + checkNotNullParam("messageSendParams", messageSendParams); + checkNotNullParam("eventHandler", eventHandler); + checkNotNullParam("errorHandler", errorHandler); + checkNotNullParam("failureHandler", failureHandler); + + SendStreamingMessageRequest.Builder sendStreamingMessageRequestBuilder = new SendStreamingMessageRequest.Builder() + .jsonrpc(JSONRPC_VERSION) + .method(SEND_STREAMING_MESSAGE_METHOD) + .params(messageSendParams); + + if (requestId != null) { + sendStreamingMessageRequestBuilder.id(requestId); + } + + AtomicReference> ref = new AtomicReference<>(); + SSEEventListener sseEventListener = new SSEEventListener(eventHandler, errorHandler, failureHandler); + SendStreamingMessageRequest sendStreamingMessageRequest = sendStreamingMessageRequestBuilder.build(); + try { + A2AHttpClient.PostBuilder builder = createPostBuilder(sendStreamingMessageRequest); + ref.set(builder.postAsyncSSE( + msg -> sseEventListener.onMessage(msg, ref.get()), + throwable -> sseEventListener.onError(throwable, ref.get()), + () -> { + // We don't need to do anything special on completion + })); + + } catch (IOException e) { + throw new A2AServerException("Failed to send streaming message request: " + e); + } catch (InterruptedException e) { + throw new A2AServerException("Send streaming message request timed out: " + e); + } + } + + /** + * Resubscribe to an ongoing task. + * + * @param taskIdParams the params for the task to resubscribe to + * @param eventHandler a consumer that will be invoked for each event received from the remote agent + * @param errorHandler a consumer that will be invoked if the remote agent returns an error + * @param failureHandler a consumer that will be invoked if a failure occurs when processing events + * @throws A2AServerException if resubscribing to the task fails for any reason + */ + public void resubscribeToTask(TaskIdParams taskIdParams, Consumer eventHandler, + Consumer errorHandler, Runnable failureHandler) throws A2AServerException { + resubscribeToTask(null, taskIdParams, eventHandler, errorHandler, failureHandler); + } + + /** + * Resubscribe to an ongoing task. + * + * @param requestId the request ID to use + * @param taskIdParams the params for the task to resubscribe to + * @param eventHandler a consumer that will be invoked for each event received from the remote agent + * @param errorHandler a consumer that will be invoked if the remote agent returns an error + * @param failureHandler a consumer that will be invoked if a failure occurs when processing events + * @throws A2AServerException if resubscribing to the task fails for any reason + */ + public void resubscribeToTask(String requestId, TaskIdParams taskIdParams, Consumer eventHandler, + Consumer errorHandler, Runnable failureHandler) throws A2AServerException { + checkNotNullParam("taskIdParams", taskIdParams); + checkNotNullParam("eventHandler", eventHandler); + checkNotNullParam("errorHandler", errorHandler); + checkNotNullParam("failureHandler", failureHandler); + + TaskResubscriptionRequest.Builder taskResubscriptionRequestBuilder = new TaskResubscriptionRequest.Builder() + .jsonrpc(JSONRPC_VERSION) + .method(SEND_TASK_RESUBSCRIPTION_METHOD) + .params(taskIdParams); + + if (requestId != null) { + taskResubscriptionRequestBuilder.id(requestId); + } + + AtomicReference> ref = new AtomicReference<>(); + SSEEventListener sseEventListener = new SSEEventListener(eventHandler, errorHandler, failureHandler); + TaskResubscriptionRequest taskResubscriptionRequest = taskResubscriptionRequestBuilder.build(); + try { + A2AHttpClient.PostBuilder builder = createPostBuilder(taskResubscriptionRequest); + ref.set(builder.postAsyncSSE( + msg -> sseEventListener.onMessage(msg, ref.get()), + throwable -> sseEventListener.onError(throwable, ref.get()), + () -> { + // We don't need to do anything special on completion + })); + + } catch (IOException e) { + throw new A2AServerException("Failed to send task resubscription request: " + e); + } catch (InterruptedException e) { + throw new A2AServerException("Task resubscription request timed out: " + e); + } + } + + private String sendPostRequest(Object value) throws IOException, InterruptedException { + A2AHttpClient.PostBuilder builder = createPostBuilder(value); + A2AHttpResponse response = builder.post(); + if (!response.success()) { + throw new IOException("Request failed " + response.status()); + } + return response.body(); + } + + private A2AHttpClient.PostBuilder createPostBuilder(Object value) throws JsonProcessingException { + return httpClient.createPost() + .url(agentUrl) + .addHeader("Content-Type", "application/json") + .body(OBJECT_MAPPER.writeValueAsString(value)); + + } + + private T unmarshalResponse(String response, TypeReference typeReference) + throws A2AServerException, JsonProcessingException { + T value = unmarshalFrom(response, typeReference); + JSONRPCError error = value.getError(); + if (error != null) { + throw new A2AServerException(error.getMessage() + (error.getData() != null ? ": " + error.getData() : "")); + } + return value; + } +} diff --git a/core/src/main/java/io/a2a/client/sse/SSEEventListener.java b/core/src/main/java/io/a2a/client/sse/SSEEventListener.java new file mode 100644 index 000000000..8ed0e9aa3 --- /dev/null +++ b/core/src/main/java/io/a2a/client/sse/SSEEventListener.java @@ -0,0 +1,61 @@ +package io.a2a.client.sse; + +import static io.a2a.util.Utils.OBJECT_MAPPER; + +import java.util.concurrent.Future; +import java.util.function.Consumer; +import java.util.logging.Logger; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.StreamingEventKind; +import io.a2a.spec.TaskStatusUpdateEvent; + +public class SSEEventListener { + private static final Logger log = Logger.getLogger(SSEEventListener.class.getName()); + private final Consumer eventHandler; + private final Consumer errorHandler; + private final Runnable failureHandler; + + public SSEEventListener(Consumer eventHandler, Consumer errorHandler, Runnable failureHandler) { + this.eventHandler = eventHandler; + this.errorHandler = errorHandler; + this.failureHandler = failureHandler; + } + + public void onMessage(String message, Future completableFuture) { + try { + handleMessage(OBJECT_MAPPER.readTree(message),completableFuture); + } catch (JsonProcessingException e) { + log.warning("Failed to parse JSON message: " + message); + } + } + + public void onError(Throwable throwable, Future future) { + failureHandler.run(); + future.cancel(true); // close SSE channel + } + + private void handleMessage(JsonNode jsonNode, Future future) { + try { + if (jsonNode.has("error")) { + JSONRPCError error = OBJECT_MAPPER.treeToValue(jsonNode.get("error"), JSONRPCError.class); + errorHandler.accept(error); + } else if (jsonNode.has("result")) { + // result can be a Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent + JsonNode result = jsonNode.path("result"); + StreamingEventKind event = OBJECT_MAPPER.treeToValue(result, StreamingEventKind.class); + eventHandler.accept(event); + if (event instanceof TaskStatusUpdateEvent && ((TaskStatusUpdateEvent) event).isFinal()) { + future.cancel(true); // close SSE channel + } + } else { + throw new IllegalArgumentException("Unknown message type"); + } + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + +} diff --git a/core/src/main/java/io/a2a/http/A2AHttpClient.java b/core/src/main/java/io/a2a/http/A2AHttpClient.java new file mode 100644 index 000000000..7a246843a --- /dev/null +++ b/core/src/main/java/io/a2a/http/A2AHttpClient.java @@ -0,0 +1,34 @@ +package io.a2a.http; + +import java.io.IOException; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +public interface A2AHttpClient { + + GetBuilder createGet(); + + PostBuilder createPost(); + + interface Builder> { + T url(String s); + T addHeader(String name, String value); + } + + interface GetBuilder extends Builder { + A2AHttpResponse get() throws IOException, InterruptedException; + CompletableFuture getAsyncSSE( + Consumer messageConsumer, + Consumer errorConsumer, + Runnable completeRunnable) throws IOException, InterruptedException; + } + + interface PostBuilder extends Builder { + PostBuilder body(String body); + A2AHttpResponse post() throws IOException, InterruptedException; + CompletableFuture postAsyncSSE( + Consumer messageConsumer, + Consumer errorConsumer, + Runnable completeRunnable) throws IOException, InterruptedException; + } +} diff --git a/core/src/main/java/io/a2a/http/A2AHttpResponse.java b/core/src/main/java/io/a2a/http/A2AHttpResponse.java new file mode 100644 index 000000000..d6973a5dc --- /dev/null +++ b/core/src/main/java/io/a2a/http/A2AHttpResponse.java @@ -0,0 +1,9 @@ +package io.a2a.http; + +public interface A2AHttpResponse { + int status(); + + boolean success(); + + String body(); +} diff --git a/core/src/main/java/io/a2a/http/JdkA2AHttpClient.java b/core/src/main/java/io/a2a/http/JdkA2AHttpClient.java new file mode 100644 index 000000000..e3b5c0c66 --- /dev/null +++ b/core/src/main/java/io/a2a/http/JdkA2AHttpClient.java @@ -0,0 +1,210 @@ +package io.a2a.http; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandler; +import java.net.http.HttpResponse.BodyHandlers; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Flow; +import java.util.function.Consumer; + +public class JdkA2AHttpClient implements A2AHttpClient { + + private final HttpClient httpClient; + + public JdkA2AHttpClient() { + httpClient = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .followRedirects(HttpClient.Redirect.NORMAL) + .build(); + } + + @Override + public GetBuilder createGet() { + return new JdkGetBuilder(); + } + + @Override + public PostBuilder createPost() { + return new JdkPostBuilder(); + } + + private abstract class JdkBuilder> implements Builder { + private String url; + private Map headers = new HashMap<>(); + + @Override + public T url(String url) { + this.url = url; + return self(); + } + + @Override + public T addHeader(String name, String value) { + headers.put(name, value); + return self(); + } + + @SuppressWarnings("unchecked") + T self() { + return (T) this; + } + + protected HttpRequest.Builder createRequestBuilder() throws IOException { + HttpRequest.Builder builder = HttpRequest.newBuilder() + .uri(URI.create(url)); + for (Map.Entry headerEntry : headers.entrySet()) { + builder.header(headerEntry.getKey(), headerEntry.getValue()); + } + return builder; + } + + protected CompletableFuture asyncRequest( + HttpRequest request, + Consumer messageConsumer, + Consumer errorConsumer, + Runnable completeRunnable + ) { + Flow.Subscriber subscriber = new Flow.Subscriber() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(String item) { + // SSE messages sometimes start with "data:". Strip that off + if (item != null && item.startsWith("data:")) { + item = item.substring(5).trim(); + if (!item.isEmpty()) { + messageConsumer.accept(item); + } + } + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + errorConsumer.accept(throwable); + subscription.cancel(); + } + + @Override + public void onComplete() { + completeRunnable.run(); + subscription.cancel(); + } + }; + + BodyHandler bodyHandler = BodyHandlers.fromLineSubscriber(subscriber); + + // Send the response async, and let the subscriber handle the lines. + return httpClient.sendAsync(request, bodyHandler) + .thenAccept(response -> { + if (!JdkHttpResponse.success(response)) { + subscriber.onError(new IOException("Request failed " + response.statusCode())); + } + }); + } + } + + private class JdkGetBuilder extends JdkBuilder implements A2AHttpClient.GetBuilder { + + private HttpRequest.Builder createRequestBuilder(boolean SSE) throws IOException { + HttpRequest.Builder builder = super.createRequestBuilder().GET(); + if (SSE) { + builder.header("Accept", "text/event-stream"); + } + return builder; + } + + @Override + public A2AHttpResponse get() throws IOException, InterruptedException { + HttpRequest request = createRequestBuilder(false) + .build(); + HttpResponse response = + httpClient.send(request, BodyHandlers.ofString(StandardCharsets.UTF_8)); + return new JdkHttpResponse(response); + } + + @Override + public CompletableFuture getAsyncSSE( + Consumer messageConsumer, + Consumer errorConsumer, + Runnable completeRunnable) throws IOException, InterruptedException { + HttpRequest request = createRequestBuilder(false) + .build(); + return super.asyncRequest(request, messageConsumer, errorConsumer, completeRunnable); + } + } + + private class JdkPostBuilder extends JdkBuilder implements A2AHttpClient.PostBuilder { + String body = ""; + + @Override + public PostBuilder body(String body) { + this.body = body; + return self(); + } + + private HttpRequest.Builder createRequestBuilder(boolean SSE) throws IOException { + HttpRequest.Builder builder = super.createRequestBuilder() + .POST(HttpRequest.BodyPublishers.ofString(body, StandardCharsets.UTF_8)); + if (SSE) { + builder.header("Accept", "text/event-stream"); + } + return builder; + } + + @Override + public A2AHttpResponse post() throws IOException, InterruptedException { + HttpRequest request = createRequestBuilder(false) + .POST(HttpRequest.BodyPublishers.ofString(body, StandardCharsets.UTF_8)) + .build(); + HttpResponse response = + httpClient.send(request, BodyHandlers.ofString(StandardCharsets.UTF_8)); + return new JdkHttpResponse(response); + } + + @Override + public CompletableFuture postAsyncSSE( + Consumer messageConsumer, + Consumer errorConsumer, + Runnable completeRunnable) throws IOException, InterruptedException { + HttpRequest request = createRequestBuilder(false) + .build(); + return super.asyncRequest(request, messageConsumer, errorConsumer, completeRunnable); + } + } + + private record JdkHttpResponse(HttpResponse response) implements A2AHttpResponse { + + @Override + public int status() { + return response.statusCode(); + } + + @Override + public boolean success() {// Send the request and get the response + return success(response); + } + + static boolean success(HttpResponse response) { + return response.statusCode() >= 200 && response.statusCode() < 300; + } + + @Override + public String body() { + return response.body(); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/A2A.java b/core/src/main/java/io/a2a/spec/A2A.java new file mode 100644 index 000000000..4f3d2381c --- /dev/null +++ b/core/src/main/java/io/a2a/spec/A2A.java @@ -0,0 +1,148 @@ +package io.a2a.spec; + +import java.util.Collections; +import java.util.Map; + +import io.a2a.client.A2ACardResolver; +import io.a2a.http.A2AHttpClient; +import io.a2a.http.JdkA2AHttpClient; + + +/** + * Constants and utility methods related to the A2A protocol. + */ +public class A2A { + + public static final String CANCEL_TASK_METHOD = "tasks/cancel"; + public static final String GET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD = "tasks/pushNotificationConfig/get"; + public static final String GET_TASK_METHOD = "tasks/get"; + public static final String SET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD = "tasks/pushNotificationConfig/set"; + public static final String SEND_TASK_RESUBSCRIPTION_METHOD = "tasks/resubscribe"; + public static final String SEND_STREAMING_MESSAGE_METHOD = "message/stream"; + public static final String SEND_MESSAGE_METHOD = "message/send"; + + public static final String JSONRPC_VERSION = "2.0"; + + + /** + * Convert the given text to a user message. + * + * @param text the message text + * @return the user message + */ + public static Message toUserMessage(String text) { + return toMessage(text, Message.Role.USER, null); + } + + /** + * Convert the given text to a user message. + * + * @param text the message text + * @param messageId the message ID to use + * @return the user message + */ + public static Message toUserMessage(String text, String messageId) { + return toMessage(text, Message.Role.USER, messageId); + } + + /** + * Convert the given text to an agent message. + * + * @param text the message text + * @return the agent message + */ + public static Message toAgentMessage(String text) { + return toMessage(text, Message.Role.AGENT, null); + } + + /** + * Convert the given text to an agent message. + * + * @param text the message text + * @param messageId the message ID to use + * @return the agent message + */ + public static Message toAgentMessage(String text, String messageId) { + return toMessage(text, Message.Role.AGENT, messageId); + } + + + private static Message toMessage(String text, Message.Role role, String messageId) { + Message.Builder messageBuilder = new Message.Builder() + .role(role) + .parts(Collections.singletonList(new TextPart(text))); + if (messageId != null) { + messageBuilder.messageId(messageId); + } + return messageBuilder.build(); + } + + /** + * Get the agent card for an A2A agent. + * + * @param agentUrl the base URL for the agent whose agent card we want to retrieve + * @return the agent card + * @throws A2AClientError If an HTTP error occurs fetching the card + * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema + */ + public static AgentCard getAgentCard(String agentUrl) throws A2AClientError, A2AClientJSONError { + return getAgentCard(new JdkA2AHttpClient(), agentUrl); + } + + /** + * Get the agent card for an A2A agent. + * + * @param httpClient the http client to use + * @param agentUrl the base URL for the agent whose agent card we want to retrieve + * @return the agent card + * @throws A2AClientError If an HTTP error occurs fetching the card + * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema + */ + public static AgentCard getAgentCard(A2AHttpClient httpClient, String agentUrl) throws A2AClientError, A2AClientJSONError { + return getAgentCard(httpClient, agentUrl, null, null); + } + + /** + * Get the agent card for an A2A agent. + * + * @param agentUrl the base URL for the agent whose agent card we want to retrieve + * @param relativeCardPath optional path to the agent card endpoint relative to the base + * agent URL, defaults to ".well-known/agent.json" + * @param authHeaders the HTTP authentication headers to use + * @return the agent card + * @throws A2AClientError If an HTTP error occurs fetching the card + * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema + */ + public static AgentCard getAgentCard(String agentUrl, String relativeCardPath, Map authHeaders) throws A2AClientError, A2AClientJSONError { + return getAgentCard(new JdkA2AHttpClient(), agentUrl, relativeCardPath, authHeaders); + } + + /** + * Get the agent card for an A2A agent. + * + * @param httpClient the http client to use + * @param agentUrl the base URL for the agent whose agent card we want to retrieve + * @param relativeCardPath optional path to the agent card endpoint relative to the base + * agent URL, defaults to ".well-known/agent.json" + * @param authHeaders the HTTP authentication headers to use + * @return the agent card + * @throws A2AClientError If an HTTP error occurs fetching the card + * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema + */ + public static AgentCard getAgentCard(A2AHttpClient httpClient, String agentUrl, String relativeCardPath, Map authHeaders) throws A2AClientError, A2AClientJSONError { + A2ACardResolver resolver = new A2ACardResolver(httpClient, agentUrl, relativeCardPath, authHeaders); + return resolver.getAgentCard(); + } + + protected static boolean isValidMethodName(String methodName) { + return methodName != null && (methodName.equals(CANCEL_TASK_METHOD) + || methodName.equals(GET_TASK_METHOD) + || methodName.equals(GET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD) + || methodName.equals(SET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD) + || methodName.equals(SEND_TASK_RESUBSCRIPTION_METHOD) + || methodName.equals(SEND_MESSAGE_METHOD) + || methodName.equals(SEND_STREAMING_MESSAGE_METHOD)); + + } + +} diff --git a/core/src/main/java/io/a2a/spec/A2AClientError.java b/core/src/main/java/io/a2a/spec/A2AClientError.java new file mode 100644 index 000000000..2ec8c864e --- /dev/null +++ b/core/src/main/java/io/a2a/spec/A2AClientError.java @@ -0,0 +1,14 @@ +package io.a2a.spec; + +public class A2AClientError extends Exception { + public A2AClientError() { + } + + public A2AClientError(String message) { + super(message); + } + + public A2AClientError(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/core/src/main/java/io/a2a/spec/A2AClientHTTPError.java b/core/src/main/java/io/a2a/spec/A2AClientHTTPError.java new file mode 100644 index 000000000..95b59a764 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/A2AClientHTTPError.java @@ -0,0 +1,33 @@ +package io.a2a.spec; + +import io.a2a.util.Assert; + +public class A2AClientHTTPError extends A2AClientError { + private final int code; + private final String message; + + public A2AClientHTTPError(int code, String message, Object data) { + Assert.checkNotNullParam("code", code); + Assert.checkNotNullParam("message", message); + this.code = code; + this.message = message; + } + + /** + * Gets the error code + * + * @return the error code + */ + public int getCode() { + return code; + } + + /** + * Gets the error message + * + * @return the error message + */ + public String getMessage() { + return message; + } +} diff --git a/core/src/main/java/io/a2a/spec/A2AClientJSONError.java b/core/src/main/java/io/a2a/spec/A2AClientJSONError.java new file mode 100644 index 000000000..75988da1c --- /dev/null +++ b/core/src/main/java/io/a2a/spec/A2AClientJSONError.java @@ -0,0 +1,15 @@ +package io.a2a.spec; + +public class A2AClientJSONError extends A2AClientError { + + public A2AClientJSONError() { + } + + public A2AClientJSONError(String message) { + super(message); + } + + public A2AClientJSONError(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/core/src/main/java/io/a2a/spec/A2AError.java b/core/src/main/java/io/a2a/spec/A2AError.java new file mode 100644 index 000000000..4c9951df2 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/A2AError.java @@ -0,0 +1,4 @@ +package io.a2a.spec; + +public interface A2AError extends Event { +} diff --git a/core/src/main/java/io/a2a/spec/A2AException.java b/core/src/main/java/io/a2a/spec/A2AException.java new file mode 100644 index 000000000..22b8363e2 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/A2AException.java @@ -0,0 +1,46 @@ +package io.a2a.spec; + +import java.io.IOException; + +/** + * Exception to indicate a general failure related to the A2A protocol. + */ +public class A2AException extends IOException { + + /** + * Constructs a new {@code A2AException} instance. The message is left blank ({@code null}), and no + * cause is specified. + */ + public A2AException() { + } + + /** + * Constructs a new {@code A2AException} instance with an initial message. No cause is specified. + * + * @param msg the message + */ + public A2AException(final String msg) { + super(msg); + } + + /** + * Constructs a new {@code A2AException} instance with an initial cause. If a non-{@code null} cause + * is specified, its message is used to initialize the message of this {@code A2AException}; otherwise + * the message is left blank ({@code null}). + * + * @param cause the cause + */ + public A2AException(final Throwable cause) { + super(cause); + } + + /** + * Constructs a new {@code A2AException} instance with an initial message and cause. + * + * @param msg the message + * @param cause the cause + */ + public A2AException(final String msg, final Throwable cause) { + super(msg, cause); + } +} diff --git a/core/src/main/java/io/a2a/spec/A2AServerException.java b/core/src/main/java/io/a2a/spec/A2AServerException.java new file mode 100644 index 000000000..ca2611c2f --- /dev/null +++ b/core/src/main/java/io/a2a/spec/A2AServerException.java @@ -0,0 +1,23 @@ +package io.a2a.spec; + +/** + * Exception to indicate a general failure related to an A2A server. + */ +public class A2AServerException extends A2AException { + + public A2AServerException() { + super(); + } + + public A2AServerException(final String msg) { + super(msg); + } + + public A2AServerException(final Throwable cause) { + super(cause); + } + + public A2AServerException(final String msg, final Throwable cause) { + super(msg, cause); + } +} diff --git a/core/src/main/java/io/a2a/spec/APIKeySecurityScheme.java b/core/src/main/java/io/a2a/spec/APIKeySecurityScheme.java new file mode 100644 index 000000000..acf33feba --- /dev/null +++ b/core/src/main/java/io/a2a/spec/APIKeySecurityScheme.java @@ -0,0 +1,118 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonValue; + +import io.a2a.util.Assert; + +/** + * Represents an API Key security scheme. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class APIKeySecurityScheme implements SecurityScheme { + + public static final String API_KEY = "apiKey"; + private final String in; + private final String name; + private final String type; + private final String description; + + /** + * Represents the location of the API key. + */ + public enum Location { + COOKIE("cookie"), + HEADER("header"), + QUERY("query"); + + private final String location; + + Location(String location) { + this.location = location; + } + + @JsonValue + public String asString() { + return location; + } + + @JsonCreator + public static Location fromString(String location) { + switch (location) { + case "cookie": + return COOKIE; + case "header": + return HEADER; + case "query": + return QUERY; + default: + throw new IllegalArgumentException("Invalid API key location: " + location); + } + } + } + + public APIKeySecurityScheme(String in, String name, String description) { + this(in, name, description, API_KEY); + } + + @JsonCreator + public APIKeySecurityScheme(@JsonProperty("in") String in, @JsonProperty("name") String name, + @JsonProperty("description") String description, @JsonProperty("type") String type) { + Assert.checkNotNullParam("in", in); + Assert.checkNotNullParam("name", name); + if (! type.equals(API_KEY)) { + throw new IllegalArgumentException("Invalid type for APIKeySecurityScheme"); + } + this.in = in; + this.name = name; + this.description = description; + this.type = type; + } + + @Override + public String getDescription() { + return description; + } + + + public String getIn() { + return in; + } + + public String getName() { + return name; + } + + public String getType() { + return type; + } + + public static class Builder { + private String in; + private String name; + private String description; + + public Builder in(String in) { + this.in = in; + return this; + } + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public APIKeySecurityScheme build() { + return new APIKeySecurityScheme(in, name, description); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/AgentCapabilities.java b/core/src/main/java/io/a2a/spec/AgentCapabilities.java new file mode 100644 index 000000000..641f56ccb --- /dev/null +++ b/core/src/main/java/io/a2a/spec/AgentCapabilities.java @@ -0,0 +1,47 @@ +package io.a2a.spec; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; + +/** + * An agent's capabilities. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record AgentCapabilities(boolean streaming, boolean pushNotifications, boolean stateTransitionHistory, + List extensions) { + + public static class Builder { + + private boolean streaming; + private boolean pushNotifications; + private boolean stateTransitionHistory; + private List extensions; + + public Builder streaming(boolean streaming) { + this.streaming = streaming; + return this; + } + + public Builder pushNotifications(boolean pushNotifications) { + this.pushNotifications = pushNotifications; + return this; + } + + public Builder stateTransitionHistory(boolean stateTransitionHistory) { + this.stateTransitionHistory = stateTransitionHistory; + return this; + } + + public Builder extensions(List extensions) { + this.extensions = extensions; + return this; + } + + public AgentCapabilities build() { + return new AgentCapabilities(streaming, pushNotifications, stateTransitionHistory, extensions); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/AgentCard.java b/core/src/main/java/io/a2a/spec/AgentCard.java new file mode 100644 index 000000000..8429b16f5 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/AgentCard.java @@ -0,0 +1,127 @@ +package io.a2a.spec; + +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import io.a2a.util.Assert; + +/** + * A public metadata file that describes an agent's capabilities, skills, endpoint URL, and + * authentication requirements. Clients use this for discovery. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record AgentCard(String name, String description, String url, AgentProvider provider, + String version, String documentationUrl, AgentCapabilities capabilities, + List defaultInputModes, List defaultOutputModes, List skills, + boolean supportsAuthenticatedExtendedCard, Map securitySchemes, + List>> security, String iconUrl) { + + private static final String TEXT_MODE = "text"; + + public AgentCard { + Assert.checkNotNullParam("capabilities", capabilities); + Assert.checkNotNullParam("defaultInputModes", defaultInputModes); + Assert.checkNotNullParam("defaultOutputModes", defaultOutputModes); + Assert.checkNotNullParam("description", description); + Assert.checkNotNullParam("name", name); + Assert.checkNotNullParam("skills", skills); + Assert.checkNotNullParam("url", url); + Assert.checkNotNullParam("version", version); + } + + public static class Builder { + private String name; + private String description; + private String url; + private AgentProvider provider; + private String version; + private String documentationUrl; + private AgentCapabilities capabilities; + private List defaultInputModes; + private List defaultOutputModes; + private List skills; + private boolean supportsAuthenticatedExtendedCard = false; + private Map securitySchemes; + private List>> security; + private String iconUrl; + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder url(String url) { + this.url = url; + return this; + } + + public Builder provider(AgentProvider provider) { + this.provider = provider; + return this; + } + + public Builder version(String version) { + this.version = version; + return this; + } + + public Builder documentationUrl(String documentationUrl) { + this.documentationUrl = documentationUrl; + return this; + } + + public Builder capabilities(AgentCapabilities capabilities) { + this.capabilities = capabilities; + return this; + } + + public Builder defaultInputModes(List defaultInputModes) { + this.defaultInputModes = defaultInputModes; + return this; + } + + public Builder defaultOutputModes(List defaultOutputModes) { + this.defaultOutputModes = defaultOutputModes; + return this; + } + + public Builder skills(List skills) { + this.skills = skills; + return this; + } + + public Builder supportsAuthenticatedExtendedCard(boolean supportsAuthenticatedExtendedCard) { + this.supportsAuthenticatedExtendedCard = supportsAuthenticatedExtendedCard; + return this; + } + + public Builder securitySchemes(Map securitySchemes) { + this.securitySchemes = securitySchemes; + return this; + } + + public Builder security(List>> security) { + this.security = security; + return this; + } + + public Builder iconUrl(String iconUrl) { + this.iconUrl = iconUrl; + return this; + } + + public AgentCard build() { + return new AgentCard(name, description, url, provider, version, documentationUrl, + capabilities, defaultInputModes, defaultOutputModes, skills, + supportsAuthenticatedExtendedCard, securitySchemes, security, iconUrl); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/AgentExtension.java b/core/src/main/java/io/a2a/spec/AgentExtension.java new file mode 100644 index 000000000..931bc1c16 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/AgentExtension.java @@ -0,0 +1,44 @@ +package io.a2a.spec; + +import java.util.Map; + +import io.a2a.util.Assert; + +public record AgentExtension (String description, Map params, boolean required, String uri) { + + public AgentExtension { + Assert.checkNotNullParam("uri", uri); + } + + public static class Builder { + String description; + Map params; + boolean required; + String uri; + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder params(Map params) { + this.params = params; + return this; + } + + public Builder required(boolean required) { + this.required = required; + return this; + } + + public Builder uri(String uri) { + this.uri = uri; + return this; + } + + public AgentExtension build() { + return new AgentExtension(description, params, required, uri); + } + } + +} diff --git a/core/src/main/java/io/a2a/spec/AgentProvider.java b/core/src/main/java/io/a2a/spec/AgentProvider.java new file mode 100644 index 000000000..363d42b03 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/AgentProvider.java @@ -0,0 +1,18 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import io.a2a.util.Assert; + +/** + * An agent provider. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record AgentProvider(String organization, String url) { + + public AgentProvider { + Assert.checkNotNullParam("organization", organization); + Assert.checkNotNullParam("url", url); + } +} diff --git a/core/src/main/java/io/a2a/spec/AgentSkill.java b/core/src/main/java/io/a2a/spec/AgentSkill.java new file mode 100644 index 000000000..ce9f4811e --- /dev/null +++ b/core/src/main/java/io/a2a/spec/AgentSkill.java @@ -0,0 +1,73 @@ +package io.a2a.spec; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import io.a2a.util.Assert; + +/** + * An agent skill. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record AgentSkill(String id, String name, String description, List tags, + List examples, List inputModes, List outputModes) { + + public AgentSkill { + Assert.checkNotNullParam("description", description); + Assert.checkNotNullParam("id", id); + Assert.checkNotNullParam("name", name); + Assert.checkNotNullParam("tags", tags); + } + + public static class Builder { + + private String id; + private String name; + private String description; + private List tags; + private List examples; + private List inputModes; + private List outputModes; + + public Builder id(String id) { + this.id = id; + return this; + } + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder tags(List tags) { + this.tags = tags; + return this; + } + + public Builder examples(List examples) { + this.examples = examples; + return this; + } + + public Builder inputModes(List inputModes) { + this.inputModes = inputModes; + return this; + } + + public Builder outputModes(List outputModes) { + this.outputModes = outputModes; + return this; + } + + public AgentSkill build() { + return new AgentSkill(id, name, description, tags, examples, inputModes, outputModes); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/Artifact.java b/core/src/main/java/io/a2a/spec/Artifact.java new file mode 100644 index 000000000..5a5baa063 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/Artifact.java @@ -0,0 +1,79 @@ +package io.a2a.spec; + +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import io.a2a.util.Assert; + +/** + * Represents outputs generated by an agent during a task (e.g., generated files or final structured + * data). Contains parts. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record Artifact(String artifactId, String name, String description, List> parts, Map metadata) { + + public Artifact { + Assert.checkNotNullParam("artifactId", artifactId); + Assert.checkNotNullParam("parts", parts); + if (parts.isEmpty()) { + throw new IllegalArgumentException("Parts cannot be empty"); + } + } + + public static class Builder { + private String artifactId; + private String name; + private String description; + private List> parts; + private Map metadata; + + public Builder(){ + } + + public Builder(Artifact existingArtifact) { + artifactId = existingArtifact.artifactId; + name = existingArtifact.name; + description = existingArtifact.description; + parts = existingArtifact.parts; + metadata = existingArtifact.metadata; + } + + public Builder artifactId(String artifactId) { + this.artifactId = artifactId; + return this; + } + + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder parts(List> parts) { + this.parts = parts; + return this; + } + + public Builder parts(Part... parts) { + this.parts = List.of(parts); + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public Artifact build() { + return new Artifact(artifactId, name, description, parts, metadata); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/AuthenticationInfo.java b/core/src/main/java/io/a2a/spec/AuthenticationInfo.java new file mode 100644 index 000000000..d28a1e173 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/AuthenticationInfo.java @@ -0,0 +1,19 @@ +package io.a2a.spec; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import io.a2a.util.Assert; + +/** + * The authentication info for an agent. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record AuthenticationInfo(List schemes, String credentials) { + + public AuthenticationInfo { + Assert.checkNotNullParam("schemes", schemes); + } +} diff --git a/core/src/main/java/io/a2a/spec/AuthorizationCodeOAuthFlow.java b/core/src/main/java/io/a2a/spec/AuthorizationCodeOAuthFlow.java new file mode 100644 index 000000000..0723a1281 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/AuthorizationCodeOAuthFlow.java @@ -0,0 +1,23 @@ +package io.a2a.spec; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; + +import io.a2a.util.Assert; + +/** + * Configuration for the OAuth Authorization Code flow. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record AuthorizationCodeOAuthFlow(String authorizationUrl, String refreshUrl, Map scopes, + String tokenUrl) { + + public AuthorizationCodeOAuthFlow { + Assert.checkNotNullParam("authorizationUrl", authorizationUrl); + Assert.checkNotNullParam("scopes", scopes); + Assert.checkNotNullParam("tokenUrl", tokenUrl); + } +} diff --git a/core/src/main/java/io/a2a/spec/CancelTaskRequest.java b/core/src/main/java/io/a2a/spec/CancelTaskRequest.java new file mode 100644 index 000000000..3349ceb5c --- /dev/null +++ b/core/src/main/java/io/a2a/spec/CancelTaskRequest.java @@ -0,0 +1,78 @@ +package io.a2a.spec; + +import static io.a2a.spec.A2A.CANCEL_TASK_METHOD; +import static io.a2a.spec.A2A.JSONRPC_VERSION; +import static io.a2a.util.Utils.defaultIfNull; + +import java.util.UUID; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import io.a2a.util.Assert; + +/** + * A request that can be used to cancel a task. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class CancelTaskRequest extends NonStreamingJSONRPCRequest { + + @JsonCreator + public CancelTaskRequest(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, + @JsonProperty("method") String method, @JsonProperty("params") TaskIdParams params) { + if (jsonrpc != null && ! jsonrpc.equals(JSONRPC_VERSION)) { + throw new IllegalArgumentException("Invalid JSON-RPC protocol version"); + } + Assert.checkNotNullParam("method", method); + if (! method.equals(CANCEL_TASK_METHOD)) { + throw new IllegalArgumentException("Invalid CancelTaskRequest method"); + } + Assert.checkNotNullParam("params", params); + Assert.isNullOrStringOrInteger(id); + this.jsonrpc = defaultIfNull(jsonrpc, JSONRPC_VERSION); + this.id = id; + this.method = method; + this.params = params; + } + + public CancelTaskRequest(Object id, TaskIdParams params) { + this(null, id, CANCEL_TASK_METHOD, params); + } + + public static class Builder { + private String jsonrpc; + private Object id; + private String method = CANCEL_TASK_METHOD; + private TaskIdParams params; + + public CancelTaskRequest.Builder jsonrpc(String jsonrpc) { + this.jsonrpc = jsonrpc; + return this; + } + + public CancelTaskRequest.Builder id(Object id) { + this.id = id; + return this; + } + + public CancelTaskRequest.Builder method(String method) { + this.method = method; + return this; + } + + public CancelTaskRequest.Builder params(TaskIdParams params) { + this.params = params; + return this; + } + + public CancelTaskRequest build() { + if (id == null) { + id = UUID.randomUUID().toString(); + } + return new CancelTaskRequest(jsonrpc, id, method, params); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/CancelTaskResponse.java b/core/src/main/java/io/a2a/spec/CancelTaskResponse.java new file mode 100644 index 000000000..02bd63461 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/CancelTaskResponse.java @@ -0,0 +1,29 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * A response to a cancel task request. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class CancelTaskResponse extends JSONRPCResponse { + + @JsonCreator + public CancelTaskResponse(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, + @JsonProperty("result") Task result, @JsonProperty("error") JSONRPCError error) { + super(jsonrpc, id, result, error); + } + + public CancelTaskResponse(Object id, JSONRPCError error) { + this(null, id, null, error); + } + + + public CancelTaskResponse(Object id, Task result) { + this(null, id, result, null); + } +} diff --git a/core/src/main/java/io/a2a/spec/ClientCredentialsOAuthFlow.java b/core/src/main/java/io/a2a/spec/ClientCredentialsOAuthFlow.java new file mode 100644 index 000000000..05322cc0b --- /dev/null +++ b/core/src/main/java/io/a2a/spec/ClientCredentialsOAuthFlow.java @@ -0,0 +1,23 @@ +package io.a2a.spec; + + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; + +import io.a2a.util.Assert; + +/** + * Configuration for the OAuth Client Credentials flow. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record ClientCredentialsOAuthFlow(String refreshUrl, Map scopes, String tokenUrl) { + + public ClientCredentialsOAuthFlow { + Assert.checkNotNullParam("scopes", scopes); + Assert.checkNotNullParam("tokenUrl", tokenUrl); + } + +} diff --git a/core/src/main/java/io/a2a/spec/ContentTypeNotSupportedError.java b/core/src/main/java/io/a2a/spec/ContentTypeNotSupportedError.java new file mode 100644 index 000000000..04b37ba25 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/ContentTypeNotSupportedError.java @@ -0,0 +1,23 @@ +package io.a2a.spec; + +import static io.a2a.util.Utils.defaultIfNull; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public class ContentTypeNotSupportedError extends JSONRPCError { + @JsonCreator + public ContentTypeNotSupportedError( + @JsonProperty("code") Integer code, + @JsonProperty("message") String message, + @JsonProperty("data") Object data) { + super( + defaultIfNull(code, -32005), + defaultIfNull(message, "Incompatible content types"), + data); + } +} diff --git a/core/src/main/java/io/a2a/spec/DataPart.java b/core/src/main/java/io/a2a/spec/DataPart.java new file mode 100644 index 000000000..1449f55ac --- /dev/null +++ b/core/src/main/java/io/a2a/spec/DataPart.java @@ -0,0 +1,49 @@ +package io.a2a.spec; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.a2a.util.Assert; + +/** + * A fundamental data unit within a Message or Artifact. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public class DataPart extends Part> { + + private final Map data; + private final Map metadata; + private final Kind kind; + + public DataPart(Map data) { + this(data, null); + } + + @JsonCreator + public DataPart(@JsonProperty("data") Map data, + @JsonProperty("metadata") Map metadata) { + Assert.checkNotNullParam("data", data); + this.data = data; + this.metadata = metadata; + this.kind = Kind.DATA; + } + + @Override + public Kind getKind() { + return kind; + } + + public Map getData() { + return data; + } + + @Override + public Map getMetadata() { + return metadata; + } + +} diff --git a/core/src/main/java/io/a2a/spec/Event.java b/core/src/main/java/io/a2a/spec/Event.java new file mode 100644 index 000000000..4d0daa531 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/Event.java @@ -0,0 +1,4 @@ +package io.a2a.spec; + +public interface Event { +} diff --git a/core/src/main/java/io/a2a/spec/EventKind.java b/core/src/main/java/io/a2a/spec/EventKind.java new file mode 100644 index 000000000..07ab3fe9f --- /dev/null +++ b/core/src/main/java/io/a2a/spec/EventKind.java @@ -0,0 +1,22 @@ +package io.a2a.spec; + +import static io.a2a.spec.Message.MESSAGE; +import static io.a2a.spec.Task.TASK; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; + +@JsonTypeInfo( + use = JsonTypeInfo.Id.NAME, + include = JsonTypeInfo.As.PROPERTY, + property = "kind", + visible = true +) +@JsonSubTypes({ + @JsonSubTypes.Type(value = Task.class, name = TASK), + @JsonSubTypes.Type(value = Message.class, name = MESSAGE) +}) +public interface EventKind { + + String getKind(); +} diff --git a/core/src/main/java/io/a2a/spec/FileContent.java b/core/src/main/java/io/a2a/spec/FileContent.java new file mode 100644 index 000000000..f9609fb8b --- /dev/null +++ b/core/src/main/java/io/a2a/spec/FileContent.java @@ -0,0 +1,11 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; + +@JsonDeserialize(using = FileContentDeserializer.class) +public sealed interface FileContent permits FileWithBytes, FileWithUri { + + String mimeType(); + + String name(); +} diff --git a/core/src/main/java/io/a2a/spec/FileContentDeserializer.java b/core/src/main/java/io/a2a/spec/FileContentDeserializer.java new file mode 100644 index 000000000..aa763db42 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/FileContentDeserializer.java @@ -0,0 +1,38 @@ +package io.a2a.spec; + +import java.io.IOException; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.deser.std.StdDeserializer; + +public class FileContentDeserializer extends StdDeserializer { + + public FileContentDeserializer() { + this(null); + } + + public FileContentDeserializer(Class vc) { + super(vc); + } + + @Override + public FileContent deserialize(JsonParser jsonParser, DeserializationContext context) + throws IOException, JsonProcessingException { + JsonNode node = jsonParser.getCodec().readTree(jsonParser); + JsonNode mimeType = node.get("mimeType"); + JsonNode name = node.get("name"); + JsonNode bytes = node.get("bytes"); + if (bytes != null) { + return new FileWithBytes(mimeType != null ? mimeType.asText() : null, + name != null ? name.asText() : null, bytes.asText()); + } else if (node.has("uri")) { + return new FileWithUri(mimeType != null ? mimeType.asText() : null, + name != null ? name.asText() : null, node.get("uri").asText()); + } else { + throw new IOException("Invalid file format: missing 'bytes' or 'uri'"); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/FilePart.java b/core/src/main/java/io/a2a/spec/FilePart.java new file mode 100644 index 000000000..f4c1d0f8e --- /dev/null +++ b/core/src/main/java/io/a2a/spec/FilePart.java @@ -0,0 +1,48 @@ +package io.a2a.spec; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.a2a.util.Assert; + +/** + * A fundamental file unit within a Message or Artifact. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public class FilePart extends Part { + + private final FileContent file; + private final Map metadata; + private final Kind kind; + + public FilePart(FileContent file) { + this(file, null); + } + + @JsonCreator + public FilePart(@JsonProperty("file") FileContent file, @JsonProperty("metadata") Map metadata) { + Assert.checkNotNullParam("file", file); + this.file = file; + this.metadata = metadata; + this.kind = Kind.FILE; + } + + @Override + public Kind getKind() { + return kind; + } + + public FileContent getFile() { + return file; + } + + @Override + public Map getMetadata() { + return metadata; + } + +} \ No newline at end of file diff --git a/core/src/main/java/io/a2a/spec/FileWithBytes.java b/core/src/main/java/io/a2a/spec/FileWithBytes.java new file mode 100644 index 000000000..e2259e902 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/FileWithBytes.java @@ -0,0 +1,9 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; + +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record FileWithBytes(String mimeType, String name, String bytes) implements FileContent { +} diff --git a/core/src/main/java/io/a2a/spec/FileWithUri.java b/core/src/main/java/io/a2a/spec/FileWithUri.java new file mode 100644 index 000000000..65db42dc5 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/FileWithUri.java @@ -0,0 +1,10 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; + +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record FileWithUri(String mimeType, String name, String uri) implements FileContent { +} + diff --git a/core/src/main/java/io/a2a/spec/GetTaskPushNotificationConfigRequest.java b/core/src/main/java/io/a2a/spec/GetTaskPushNotificationConfigRequest.java new file mode 100644 index 000000000..ba84b0231 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/GetTaskPushNotificationConfigRequest.java @@ -0,0 +1,77 @@ +package io.a2a.spec; + +import static io.a2a.spec.A2A.GET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD; +import static io.a2a.spec.A2A.JSONRPC_VERSION; +import static io.a2a.util.Utils.defaultIfNull; + +import java.util.UUID; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import io.a2a.util.Assert; + +/** + * A get task push notification request. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class GetTaskPushNotificationConfigRequest extends NonStreamingJSONRPCRequest { + + @JsonCreator + public GetTaskPushNotificationConfigRequest(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, + @JsonProperty("method") String method, @JsonProperty("params") TaskIdParams params) { + if (jsonrpc != null && ! jsonrpc.equals(JSONRPC_VERSION)) { + throw new IllegalArgumentException("Invalid JSON-RPC protocol version"); + } + Assert.checkNotNullParam("method", method); + if (! method.equals(GET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD)) { + throw new IllegalArgumentException("Invalid GetTaskPushNotificationRequest method"); + } + Assert.isNullOrStringOrInteger(id); + this.jsonrpc = defaultIfNull(jsonrpc, JSONRPC_VERSION); + this.id = id; + this.method = method; + this.params = params; + } + + public GetTaskPushNotificationConfigRequest(String id, TaskIdParams params) { + this(null, id, GET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD, params); + } + + public static class Builder { + private String jsonrpc; + private Object id; + private String method; + private TaskIdParams params; + + public GetTaskPushNotificationConfigRequest.Builder jsonrpc(String jsonrpc) { + this.jsonrpc = jsonrpc; + return this; + } + + public GetTaskPushNotificationConfigRequest.Builder id(Object id) { + this.id = id; + return this; + } + + public GetTaskPushNotificationConfigRequest.Builder method(String method) { + this.method = method; + return this; + } + + public GetTaskPushNotificationConfigRequest.Builder params(TaskIdParams params) { + this.params = params; + return this; + } + + public GetTaskPushNotificationConfigRequest build() { + if (id == null) { + id = UUID.randomUUID().toString(); + } + return new GetTaskPushNotificationConfigRequest(jsonrpc, id, method, params); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/GetTaskPushNotificationConfigResponse.java b/core/src/main/java/io/a2a/spec/GetTaskPushNotificationConfigResponse.java new file mode 100644 index 000000000..c4340188f --- /dev/null +++ b/core/src/main/java/io/a2a/spec/GetTaskPushNotificationConfigResponse.java @@ -0,0 +1,30 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * A response for a get task push notification request. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class GetTaskPushNotificationConfigResponse extends JSONRPCResponse { + + @JsonCreator + public GetTaskPushNotificationConfigResponse(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, + @JsonProperty("result") TaskPushNotificationConfig result, + @JsonProperty("error") JSONRPCError error) { + super(jsonrpc, id, result, error); + } + + public GetTaskPushNotificationConfigResponse(Object id, JSONRPCError error) { + this(null, id, null, error); + } + + public GetTaskPushNotificationConfigResponse(Object id, TaskPushNotificationConfig result) { + this(null, id, result, null); + } + +} diff --git a/core/src/main/java/io/a2a/spec/GetTaskRequest.java b/core/src/main/java/io/a2a/spec/GetTaskRequest.java new file mode 100644 index 000000000..6c62cc265 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/GetTaskRequest.java @@ -0,0 +1,79 @@ +package io.a2a.spec; + +import static io.a2a.spec.A2A.GET_TASK_METHOD; +import static io.a2a.spec.A2A.JSONRPC_VERSION; +import static io.a2a.util.Utils.defaultIfNull; + +import java.util.UUID; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import io.a2a.util.Assert; + +/** + * A get task request. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class GetTaskRequest extends NonStreamingJSONRPCRequest { + + @JsonCreator + public GetTaskRequest(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, + @JsonProperty("method") String method, @JsonProperty("params") TaskQueryParams params) { + if (jsonrpc != null && ! jsonrpc.equals(JSONRPC_VERSION)) { + throw new IllegalArgumentException("Invalid JSON-RPC protocol version"); + } + Assert.checkNotNullParam("method", method); + if (! method.equals(GET_TASK_METHOD)) { + throw new IllegalArgumentException("Invalid GetTaskRequest method"); + } + Assert.checkNotNullParam("params", params); + Assert.isNullOrStringOrInteger(id); + this.jsonrpc = defaultIfNull(jsonrpc, JSONRPC_VERSION); + this.id = id; + this.method = method; + this.params = params; + } + + public GetTaskRequest(Object id, TaskQueryParams params) { + this(null, id, GET_TASK_METHOD, params); + } + + + public static class Builder { + private String jsonrpc; + private Object id; + private String method = "tasks/get"; + private TaskQueryParams params; + + public GetTaskRequest.Builder jsonrpc(String jsonrpc) { + this.jsonrpc = jsonrpc; + return this; + } + + public GetTaskRequest.Builder id(Object id) { + this.id = id; + return this; + } + + public GetTaskRequest.Builder method(String method) { + this.method = method; + return this; + } + + public GetTaskRequest.Builder params(TaskQueryParams params) { + this.params = params; + return this; + } + + public GetTaskRequest build() { + if (id == null) { + id = UUID.randomUUID().toString(); + } + return new GetTaskRequest(jsonrpc, id, method, params); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/GetTaskResponse.java b/core/src/main/java/io/a2a/spec/GetTaskResponse.java new file mode 100644 index 000000000..e51cb66aa --- /dev/null +++ b/core/src/main/java/io/a2a/spec/GetTaskResponse.java @@ -0,0 +1,28 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * The response for a get task request. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class GetTaskResponse extends JSONRPCResponse { + + @JsonCreator + public GetTaskResponse(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, + @JsonProperty("result") Task result, @JsonProperty("error") JSONRPCError error) { + super(jsonrpc, id, result, error); + } + + public GetTaskResponse(Object id, JSONRPCError error) { + this(null, id, null, error); + } + + public GetTaskResponse(Object id, Task result) { + this(null, id, result, null); + } +} diff --git a/core/src/main/java/io/a2a/spec/HTTPAuthSecurityScheme.java b/core/src/main/java/io/a2a/spec/HTTPAuthSecurityScheme.java new file mode 100644 index 000000000..029419a19 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/HTTPAuthSecurityScheme.java @@ -0,0 +1,81 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import io.a2a.util.Assert; + +/** + * Represents an HTTP authentication security scheme. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class HTTPAuthSecurityScheme implements SecurityScheme { + + public static final String HTTP = "http"; + private final String bearerFormat; + private final String scheme; + private final String description; + private final String type; + + public HTTPAuthSecurityScheme(String bearerFormat, String scheme, String description) { + this(bearerFormat, scheme, description, HTTP); + } + + @JsonCreator + public HTTPAuthSecurityScheme(@JsonProperty("bearerFormat") String bearerFormat, @JsonProperty("scheme") String scheme, + @JsonProperty("description") String description, @JsonProperty("type") String type) { + Assert.checkNotNullParam("scheme", scheme); + if (! type.equals(HTTP)) { + throw new IllegalArgumentException("Invalid type for HTTPAuthSecurityScheme"); + } + this.bearerFormat = bearerFormat; + this.scheme = scheme; + this.description = description; + this.type = type; + } + + @Override + public String getDescription() { + return description; + } + + public String getBearerFormat() { + return bearerFormat; + } + + public String getScheme() { + return scheme; + } + + public String getType() { + return type; + } + + public static class Builder { + private String bearerFormat; + private String scheme; + private String description; + + public Builder bearerFormat(String bearerFormat) { + this.bearerFormat = bearerFormat; + return this; + } + + public Builder scheme(String scheme) { + this.scheme = scheme; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public HTTPAuthSecurityScheme build() { + return new HTTPAuthSecurityScheme(bearerFormat, scheme, description); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/IdJsonMappingException.java b/core/src/main/java/io/a2a/spec/IdJsonMappingException.java new file mode 100644 index 000000000..15e0b07b1 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/IdJsonMappingException.java @@ -0,0 +1,22 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.databind.JsonMappingException; + +public class IdJsonMappingException extends JsonMappingException { + + Object id; + + public IdJsonMappingException(String msg, Object id) { + super(null, msg); + this.id = id; + } + + public IdJsonMappingException(String msg, Throwable cause, Object id) { + super(null, msg, cause); + this.id = id; + } + + public Object getId() { + return id; + } +} diff --git a/core/src/main/java/io/a2a/spec/ImplicitOAuthFlow.java b/core/src/main/java/io/a2a/spec/ImplicitOAuthFlow.java new file mode 100644 index 000000000..46e76cc84 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/ImplicitOAuthFlow.java @@ -0,0 +1,21 @@ +package io.a2a.spec; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; + +import io.a2a.util.Assert; + +/** + * Configuration for the OAuth Implicit flow. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record ImplicitOAuthFlow(String authorizationUrl, String refreshUrl, Map scopes) { + + public ImplicitOAuthFlow { + Assert.checkNotNullParam("authorizationUrl", authorizationUrl); + Assert.checkNotNullParam("scopes", scopes); + } +} diff --git a/core/src/main/java/io/a2a/spec/IntegerJsonrpcId.java b/core/src/main/java/io/a2a/spec/IntegerJsonrpcId.java new file mode 100644 index 000000000..2a01c1f50 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/IntegerJsonrpcId.java @@ -0,0 +1,4 @@ +package io.a2a.spec; + +public class IntegerJsonrpcId { +} diff --git a/core/src/main/java/io/a2a/spec/InternalError.java b/core/src/main/java/io/a2a/spec/InternalError.java new file mode 100644 index 000000000..02929c0d9 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/InternalError.java @@ -0,0 +1,27 @@ +package io.a2a.spec; + +import static io.a2a.util.Utils.defaultIfNull; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public class InternalError extends JSONRPCError { + @JsonCreator + public InternalError( + @JsonProperty("code") Integer code, + @JsonProperty("message") String message, + @JsonProperty("data") Object data) { + super( + defaultIfNull(code, -32603), + defaultIfNull(message, "Internal Error"), + data); + } + + public InternalError(String message) { + this(null, message, null); + } +} diff --git a/core/src/main/java/io/a2a/spec/InvalidAgentResponseError.java b/core/src/main/java/io/a2a/spec/InvalidAgentResponseError.java new file mode 100644 index 000000000..fb1af4210 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/InvalidAgentResponseError.java @@ -0,0 +1,26 @@ +package io.a2a.spec; + +import static io.a2a.util.Utils.defaultIfNull; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * A2A specific error indicating agent returned invalid response for the current method. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public class InvalidAgentResponseError extends JSONRPCError { + @JsonCreator + public InvalidAgentResponseError( + @JsonProperty("code") Integer code, + @JsonProperty("message") String message, + @JsonProperty("data") Object data) { + super( + defaultIfNull(code, -32006), + defaultIfNull(message, "Invalid agent response"), + data); + } +} diff --git a/core/src/main/java/io/a2a/spec/InvalidParamsError.java b/core/src/main/java/io/a2a/spec/InvalidParamsError.java new file mode 100644 index 000000000..4a532d5ff --- /dev/null +++ b/core/src/main/java/io/a2a/spec/InvalidParamsError.java @@ -0,0 +1,31 @@ +package io.a2a.spec; + +import static io.a2a.util.Utils.defaultIfNull; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public class InvalidParamsError extends JSONRPCError { + @JsonCreator + public InvalidParamsError( + @JsonProperty("code") Integer code, + @JsonProperty("message") String message, + @JsonProperty("data") Object data) { + super( + defaultIfNull(code, -32602), + defaultIfNull(message, "Invalid parameters"), + data); + } + + public InvalidParamsError(String message) { + this(null, message, null); + } + + public InvalidParamsError() { + this(null, null, null); + } +} diff --git a/core/src/main/java/io/a2a/spec/InvalidParamsJsonMappingException.java b/core/src/main/java/io/a2a/spec/InvalidParamsJsonMappingException.java new file mode 100644 index 000000000..41aa9a9bc --- /dev/null +++ b/core/src/main/java/io/a2a/spec/InvalidParamsJsonMappingException.java @@ -0,0 +1,12 @@ +package io.a2a.spec; + +public class InvalidParamsJsonMappingException extends IdJsonMappingException { + + public InvalidParamsJsonMappingException(String msg, Object id) { + super(msg, id); + } + + public InvalidParamsJsonMappingException(String msg, Throwable cause, Object id) { + super(msg, cause, id); + } +} diff --git a/core/src/main/java/io/a2a/spec/InvalidRequestError.java b/core/src/main/java/io/a2a/spec/InvalidRequestError.java new file mode 100644 index 000000000..a22fc15ba --- /dev/null +++ b/core/src/main/java/io/a2a/spec/InvalidRequestError.java @@ -0,0 +1,32 @@ +package io.a2a.spec; + +import static io.a2a.util.Utils.defaultIfNull; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public class InvalidRequestError extends JSONRPCError { + + public InvalidRequestError() { + this(null, null, null); + } + + @JsonCreator + public InvalidRequestError( + @JsonProperty("code") Integer code, + @JsonProperty("message") String message, + @JsonProperty("data") Object data) { + super( + defaultIfNull(code, -32600), + defaultIfNull(message, "Request payload validation error"), + data); + } + + public InvalidRequestError(String message) { + this(null, message, null); + } +} diff --git a/core/src/main/java/io/a2a/spec/JSONErrorResponse.java b/core/src/main/java/io/a2a/spec/JSONErrorResponse.java new file mode 100644 index 000000000..3029f5394 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/JSONErrorResponse.java @@ -0,0 +1,9 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; + +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record JSONErrorResponse(String error) { +} diff --git a/core/src/main/java/io/a2a/spec/JSONParseError.java b/core/src/main/java/io/a2a/spec/JSONParseError.java new file mode 100644 index 000000000..93000b5bf --- /dev/null +++ b/core/src/main/java/io/a2a/spec/JSONParseError.java @@ -0,0 +1,32 @@ +package io.a2a.spec; + +import static io.a2a.util.Utils.defaultIfNull; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public class JSONParseError extends JSONRPCError implements A2AError { + + public JSONParseError() { + this(null, null, null); + } + + public JSONParseError(String message) { + this(null, message, null); + } + + @JsonCreator + public JSONParseError( + @JsonProperty("code") Integer code, + @JsonProperty("message") String message, + @JsonProperty("data") Object data) { + super( + defaultIfNull(code, -32700), + defaultIfNull(message, "Invalid JSON payload"), + data); + } +} diff --git a/core/src/main/java/io/a2a/spec/JSONRPCError.java b/core/src/main/java/io/a2a/spec/JSONRPCError.java new file mode 100644 index 000000000..b682203d4 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/JSONRPCError.java @@ -0,0 +1,53 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; + +import io.a2a.util.Assert; + +/** + * Represents a JSONRPC error. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonDeserialize(using = JSONRPCErrorDeserializer.class) +@JsonSerialize(using = JSONRPCErrorSerializer.class) +@JsonIgnoreProperties(ignoreUnknown = true) +public class JSONRPCError extends Error implements Event, A2AError { + + private final Integer code; + private final Object data; + + @JsonCreator + public JSONRPCError( + @JsonProperty("code") Integer code, + @JsonProperty("message") String message, + @JsonProperty("data") Object data) { + super(message); + Assert.checkNotNullParam("code", code); + Assert.checkNotNullParam("message", message); + this.code = code; + this.data = data; + } + + /** + * Gets the error code + * + * @return the error code + */ + public Integer getCode() { + return code; + } + + /** + * Gets the data associated with the error. + * + * @return the data. May be {@code null} + */ + public Object getData() { + return data; + } +} diff --git a/core/src/main/java/io/a2a/spec/JSONRPCErrorDeserializer.java b/core/src/main/java/io/a2a/spec/JSONRPCErrorDeserializer.java new file mode 100644 index 000000000..33f438469 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/JSONRPCErrorDeserializer.java @@ -0,0 +1,31 @@ +package io.a2a.spec; + +import java.io.IOException; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.deser.std.StdDeserializer; + +public class JSONRPCErrorDeserializer extends StdDeserializer { + + public JSONRPCErrorDeserializer() { + this(null); + } + + public JSONRPCErrorDeserializer(Class vc) { + super(vc); + } + + @Override + public JSONRPCError deserialize(JsonParser jsonParser, DeserializationContext context) + throws IOException, JsonProcessingException { + JsonNode node = jsonParser.getCodec().readTree(jsonParser); + int code = node.get("code").asInt(); + String message = node.get("message").asText(); + JsonNode dataNode = node.get("data"); + Object data = dataNode != null ? jsonParser.getCodec().treeToValue(dataNode, Object.class) : null; + return new JSONRPCError(code, message, data); + } +} diff --git a/core/src/main/java/io/a2a/spec/JSONRPCErrorResponse.java b/core/src/main/java/io/a2a/spec/JSONRPCErrorResponse.java new file mode 100644 index 000000000..95ac5d341 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/JSONRPCErrorResponse.java @@ -0,0 +1,31 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import io.a2a.util.Assert; + +/** + * A JSON RPC error response. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class JSONRPCErrorResponse extends JSONRPCResponse { + + @JsonCreator + public JSONRPCErrorResponse(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, + @JsonProperty("result") Void result, @JsonProperty("error") JSONRPCError error) { + super(jsonrpc, id, result, error); + Assert.checkNotNullParam("error", error); + } + + public JSONRPCErrorResponse(Object id, JSONRPCError error) { + this(null, id, null, error); + } + + public JSONRPCErrorResponse(JSONRPCError error) { + this(null, null, null, error); + } +} diff --git a/core/src/main/java/io/a2a/spec/JSONRPCErrorSerializer.java b/core/src/main/java/io/a2a/spec/JSONRPCErrorSerializer.java new file mode 100644 index 000000000..87b427548 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/JSONRPCErrorSerializer.java @@ -0,0 +1,29 @@ +package io.a2a.spec; + +import java.io.IOException; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.ser.std.StdSerializer; + +public class JSONRPCErrorSerializer extends StdSerializer { + + public JSONRPCErrorSerializer() { + this(null); + } + + public JSONRPCErrorSerializer(Class vc) { + super(vc); + } + + @Override + public void serialize(JSONRPCError value, JsonGenerator gen, SerializerProvider provider) throws IOException { + gen.writeStartObject(); + gen.writeNumberField("code", value.getCode()); + gen.writeStringField("message", value.getMessage()); + if (value.getData() != null) { + gen.writeObjectField("data", value.getData()); + } + gen.writeEndObject(); + } +} diff --git a/core/src/main/java/io/a2a/spec/JSONRPCMessage.java b/core/src/main/java/io/a2a/spec/JSONRPCMessage.java new file mode 100644 index 000000000..95cadfd94 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/JSONRPCMessage.java @@ -0,0 +1,11 @@ +package io.a2a.spec; + +/** + * Represents a JSONRPC message. + */ +public sealed interface JSONRPCMessage permits JSONRPCRequest, JSONRPCResponse { + + String getJsonrpc(); + Object getId(); + +} diff --git a/core/src/main/java/io/a2a/spec/JSONRPCRequest.java b/core/src/main/java/io/a2a/spec/JSONRPCRequest.java new file mode 100644 index 000000000..f01318974 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/JSONRPCRequest.java @@ -0,0 +1,53 @@ +package io.a2a.spec; + +import static io.a2a.spec.A2A.JSONRPC_VERSION; +import static io.a2a.util.Utils.defaultIfNull; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; + +import io.a2a.util.Assert; + +/** + * Represents a JSONRPC request. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public abstract sealed class JSONRPCRequest implements JSONRPCMessage permits NonStreamingJSONRPCRequest, StreamingJSONRPCRequest { + + protected String jsonrpc; + protected Object id; + protected String method; + protected T params; + + public JSONRPCRequest() { + } + + public JSONRPCRequest(String jsonrpc, Object id, String method, T params) { + Assert.checkNotNullParam("jsonrpc", jsonrpc); + Assert.checkNotNullParam("method", method); + Assert.isNullOrStringOrInteger(id); + this.jsonrpc = defaultIfNull(jsonrpc, JSONRPC_VERSION); + this.id = id; + this.method = method; + this.params = params; + } + + @Override + public String getJsonrpc() { + return this.jsonrpc; + } + + @Override + public Object getId() { + return this.id; + } + + public String getMethod() { + return this.method; + } + + public T getParams() { + return this.params; + } +} diff --git a/core/src/main/java/io/a2a/spec/JSONRPCRequestDeserializerBase.java b/core/src/main/java/io/a2a/spec/JSONRPCRequestDeserializerBase.java new file mode 100644 index 000000000..2cdc61668 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/JSONRPCRequestDeserializerBase.java @@ -0,0 +1,115 @@ +package io.a2a.spec; + +import static io.a2a.spec.A2A.CANCEL_TASK_METHOD; +import static io.a2a.spec.A2A.GET_TASK_METHOD; +import static io.a2a.spec.A2A.GET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD; +import static io.a2a.spec.A2A.SEND_MESSAGE_METHOD; +import static io.a2a.spec.A2A.SEND_STREAMING_MESSAGE_METHOD; +import static io.a2a.spec.A2A.SEND_TASK_RESUBSCRIPTION_METHOD; +import static io.a2a.spec.A2A.SET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD; +import static io.a2a.spec.A2A.isValidMethodName; +import static io.a2a.util.Utils.OBJECT_MAPPER; + +import java.io.IOException; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.deser.std.StdDeserializer; + +public abstract class JSONRPCRequestDeserializerBase extends StdDeserializer> { + + public JSONRPCRequestDeserializerBase() { + this(null); + } + + public JSONRPCRequestDeserializerBase(Class vc) { + super(vc); + } + + @Override + public JSONRPCRequest deserialize(JsonParser jsonParser, DeserializationContext context) + throws IOException, JsonProcessingException { + JsonNode treeNode = jsonParser.getCodec().readTree(jsonParser); + String jsonrpc = getAndValidateJsonrpc(treeNode, jsonParser); + String method = getAndValidateMethod(treeNode, jsonParser); + Object id = getAndValidateId(treeNode, jsonParser); + JsonNode paramsNode = treeNode.get("params"); + + switch (method) { + case GET_TASK_METHOD: + return new GetTaskRequest(jsonrpc, id, method, getAndValidateParams(paramsNode, jsonParser, treeNode, TaskQueryParams.class)); + case CANCEL_TASK_METHOD: + return new CancelTaskRequest(jsonrpc, id, method, getAndValidateParams(paramsNode, jsonParser, treeNode, TaskIdParams.class)); + case SET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD: + return new SetTaskPushNotificationConfigRequest(jsonrpc, id, method, getAndValidateParams(paramsNode, jsonParser, treeNode, TaskPushNotificationConfig.class)); + case GET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD: + return new GetTaskPushNotificationConfigRequest(jsonrpc, id, method, getAndValidateParams(paramsNode, jsonParser, treeNode, TaskIdParams.class)); + case SEND_MESSAGE_METHOD: + return new SendMessageRequest(jsonrpc, id, method, getAndValidateParams(paramsNode, jsonParser, treeNode, MessageSendParams.class)); + case SEND_TASK_RESUBSCRIPTION_METHOD: + return new TaskResubscriptionRequest(jsonrpc, id, method, getAndValidateParams(paramsNode, jsonParser, treeNode, TaskIdParams.class)); + case SEND_STREAMING_MESSAGE_METHOD: + return new SendStreamingMessageRequest(jsonrpc, id, method, getAndValidateParams(paramsNode, jsonParser, treeNode, MessageSendParams.class)); + default: + throw new MethodNotFoundJsonMappingException("Invalid method", getIdIfPossible(treeNode, jsonParser)); + } + } + + protected T getAndValidateParams(JsonNode paramsNode, JsonParser jsonParser, JsonNode node, Class paramsType) throws JsonMappingException { + if (paramsNode == null) { + return null; + } + try { + return OBJECT_MAPPER.treeToValue(paramsNode, paramsType); + } catch (JsonProcessingException e) { + throw new InvalidParamsJsonMappingException("Invalid params", e, getIdIfPossible(node, jsonParser)); + } + } + + protected String getAndValidateJsonrpc(JsonNode treeNode, JsonParser jsonParser) throws JsonMappingException { + JsonNode jsonrpcNode = treeNode.get("jsonrpc"); + if (jsonrpcNode == null || ! jsonrpcNode.asText().equals(A2A.JSONRPC_VERSION)) { + throw new IdJsonMappingException("Invalid JSON-RPC protocol version", getIdIfPossible(treeNode, jsonParser)); + } + return jsonrpcNode.asText(); + } + + protected String getAndValidateMethod(JsonNode treeNode, JsonParser jsonParser) throws JsonMappingException { + JsonNode methodNode = treeNode.get("method"); + if (methodNode == null) { + throw new IdJsonMappingException("Missing method", getIdIfPossible(treeNode, jsonParser)); + } + String method = methodNode.asText(); + if (! isValidMethodName(method)) { + throw new MethodNotFoundJsonMappingException("Invalid method", getIdIfPossible(treeNode, jsonParser)); + } + return method; + } + + protected Object getAndValidateId(JsonNode treeNode, JsonParser jsonParser) throws JsonProcessingException { + JsonNode idNode = treeNode.get("id"); + Object id = null; + if (idNode != null) { + if (idNode.isTextual()) { + id = OBJECT_MAPPER.treeToValue(idNode, String.class); + } else if (idNode.isNumber()) { + id = OBJECT_MAPPER.treeToValue(idNode, Integer.class); + } else { + throw new JsonMappingException(jsonParser, "Invalid id"); + } + } + return id; + } + + protected Object getIdIfPossible(JsonNode treeNode, JsonParser jsonParser) { + try { + return getAndValidateId(treeNode, jsonParser); + } catch (JsonProcessingException e) { + // id can't be determined + return null; + } + } +} diff --git a/core/src/main/java/io/a2a/spec/JSONRPCResponse.java b/core/src/main/java/io/a2a/spec/JSONRPCResponse.java new file mode 100644 index 000000000..91b9443ee --- /dev/null +++ b/core/src/main/java/io/a2a/spec/JSONRPCResponse.java @@ -0,0 +1,60 @@ +package io.a2a.spec; + +import static io.a2a.spec.A2A.JSONRPC_VERSION; +import static io.a2a.util.Utils.defaultIfNull; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; + +import io.a2a.util.Assert; + +/** + * Represents a JSONRPC response. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public abstract sealed class JSONRPCResponse implements JSONRPCMessage permits SendStreamingMessageResponse, + GetTaskResponse, CancelTaskResponse, SetTaskPushNotificationConfigResponse, GetTaskPushNotificationConfigResponse, + SendMessageResponse, JSONRPCErrorResponse { + + protected String jsonrpc; + protected Object id; + protected T result; + protected JSONRPCError error; + + public JSONRPCResponse() { + } + + public JSONRPCResponse(String jsonrpc, Object id, T result, JSONRPCError error) { + if (jsonrpc != null && ! jsonrpc.equals(JSONRPC_VERSION)) { + throw new IllegalArgumentException("Invalid JSON-RPC protocol version"); + } + if (error != null && result != null) { + throw new IllegalArgumentException("Invalid JSON-RPC error response"); + } + if (error == null && result == null) { + throw new IllegalArgumentException("Invalid JSON-RPC success response"); + } + Assert.isNullOrStringOrInteger(id); + this.jsonrpc = defaultIfNull(jsonrpc, JSONRPC_VERSION); + this.id = id; + this.result = result; + this.error = error; + } + + public String getJsonrpc() { + return this.jsonrpc; + } + + public Object getId() { + return this.id; + } + + public T getResult() { + return this.result; + } + + public JSONRPCError getError() { + return this.error; + } +} diff --git a/core/src/main/java/io/a2a/spec/JsonrpcId.java b/core/src/main/java/io/a2a/spec/JsonrpcId.java new file mode 100644 index 000000000..e4db4b458 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/JsonrpcId.java @@ -0,0 +1,4 @@ +package io.a2a.spec; + +public interface JsonrpcId { +} diff --git a/core/src/main/java/io/a2a/spec/Message.java b/core/src/main/java/io/a2a/spec/Message.java new file mode 100644 index 000000000..f17646a0f --- /dev/null +++ b/core/src/main/java/io/a2a/spec/Message.java @@ -0,0 +1,188 @@ +package io.a2a.spec; + +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonValue; +import com.fasterxml.jackson.core.type.TypeReference; +import io.a2a.util.Assert; + +/** + * An A2A message. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class Message implements EventKind, StreamingEventKind { + + public static final TypeReference TYPE_REFERENCE = new TypeReference<>() {}; + + public static final String MESSAGE = "message"; + private final Role role; + private final List> parts; + private final String messageId; + private String contextId; + private String taskId; + private final Map metadata; + private final String kind; + private final List referenceTaskIds; + + public Message(Role role, List> parts, String messageId, String contextId, String taskId, + List referenceTaskIds, Map metadata) { + this(role, parts, messageId, contextId, taskId, referenceTaskIds, metadata, MESSAGE); + } + + @JsonCreator + public Message(@JsonProperty("role") Role role, @JsonProperty("parts") List> parts, + @JsonProperty("messageId") String messageId, @JsonProperty("contextId") String contextId, + @JsonProperty("taskId") String taskId, @JsonProperty("referenceTaskIds") List referenceTaskIds, + @JsonProperty("metadata") Map metadata, + @JsonProperty("kind") String kind) { + Assert.checkNotNullParam("kind", kind); + Assert.checkNotNullParam("parts", parts); + if (parts.isEmpty()) { + throw new IllegalArgumentException("Parts cannot be empty"); + } + Assert.checkNotNullParam("role", role); + if (! kind.equals(MESSAGE)) { + throw new IllegalArgumentException("Invalid Message"); + } + this.role = role; + this.parts = parts; + this.messageId = messageId == null ? UUID.randomUUID().toString() : messageId; + this.contextId = contextId; + this.taskId = taskId; + this.referenceTaskIds = referenceTaskIds; + this.metadata = metadata; + this.kind = kind; + } + + public Role getRole() { + return role; + } + + public List> getParts() { + return parts; + } + + public String getMessageId() { + return messageId; + } + + public String getContextId() { + return contextId; + } + + public String getTaskId() { + return taskId; + } + + public Map getMetadata() { + return metadata; + } + + public void setTaskId(String taskId) { + this.taskId = taskId; + } + + public void setContextId(String contextId) { + this.contextId = contextId; + } + + public List getReferenceTaskIds() { + return referenceTaskIds; + } + + @Override + public String getKind() { + return kind; + } + + public enum Role { + USER("user"), + AGENT("agent"); + + private String role; + + Role(String role) { + this.role = role; + } + + @JsonValue + public String asString() { + return this.role; + } + } + + public static class Builder { + + private Role role; + private List> parts; + private String messageId; + private String contextId; + private String taskId; + private List referenceTaskIds; + private Map metadata; + + public Builder() { + } + + public Builder(Message message) { + role = message.role; + parts = message.parts; + messageId = message.messageId; + contextId = message.contextId; + taskId = message.taskId; + referenceTaskIds = message.referenceTaskIds; + metadata = message.metadata; + } + + public Builder role(Role role) { + this.role = role; + return this; + } + + public Builder parts(List> parts) { + this.parts = parts; + return this; + } + + public Builder parts(Part...parts) { + this.parts = List.of(parts); + return this; + } + + public Builder messageId(String messageId) { + this.messageId = messageId; + return this; + } + + public Builder contextId(String contextId) { + this.contextId = contextId; + return this; + } + + public Builder taskId(String taskId) { + this.taskId = taskId; + return this; + } + + public Builder referenceTaskIds(List referenceTaskIds) { + this.referenceTaskIds = referenceTaskIds; + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public Message build() { + return new Message(role, parts, messageId, contextId, taskId, referenceTaskIds, metadata); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/MessageSendConfiguration.java b/core/src/main/java/io/a2a/spec/MessageSendConfiguration.java new file mode 100644 index 000000000..f34af9dcf --- /dev/null +++ b/core/src/main/java/io/a2a/spec/MessageSendConfiguration.java @@ -0,0 +1,58 @@ +package io.a2a.spec; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; + +import io.a2a.util.Assert; + +/** + * Represents the configuration of the message to be sent. + * + * If {@code blocking} is true, {@code pushNotification} is ignored. + * Both {@code blocking} and {@code pushNotification} are ignored in streaming interactions. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record MessageSendConfiguration(List acceptedOutputModes, Integer historyLength, + PushNotificationConfig pushNotification, boolean blocking) { + + public MessageSendConfiguration { + Assert.checkNotNullParam("acceptedOutputModes", acceptedOutputModes); + if (historyLength != null && historyLength < 0) { + throw new IllegalArgumentException("Invalid history length"); + } + } + + public static class Builder { + List acceptedOutputModes; + Integer historyLength; + PushNotificationConfig pushNotification; + boolean blocking; + + public Builder acceptedOutputModes(List acceptedOutputModes) { + this.acceptedOutputModes = acceptedOutputModes; + return this; + } + + public Builder pushNotification(PushNotificationConfig pushNotification) { + this.pushNotification = pushNotification; + return this; + } + + public Builder historyLength(Integer historyLength) { + this.historyLength = historyLength; + return this; + } + + public Builder blocking(boolean blocking) { + this.blocking = blocking; + return this; + } + + public MessageSendConfiguration build() { + return new MessageSendConfiguration(acceptedOutputModes, historyLength, pushNotification, blocking); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/MessageSendParams.java b/core/src/main/java/io/a2a/spec/MessageSendParams.java new file mode 100644 index 000000000..a217539b1 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/MessageSendParams.java @@ -0,0 +1,45 @@ +package io.a2a.spec; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import io.a2a.util.Assert; + +/** + * Used to specify parameters when creating a message. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record MessageSendParams(Message message, MessageSendConfiguration configuration, + Map metadata) { + + public MessageSendParams { + Assert.checkNotNullParam("message", message); + } + + public static class Builder { + Message message; + MessageSendConfiguration configuration; + Map metadata; + + public Builder message(Message message) { + this.message = message; + return this; + } + + public Builder configuration(MessageSendConfiguration configuration) { + this.configuration = configuration; + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public MessageSendParams build() { + return new MessageSendParams(message, configuration, metadata); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/MethodNotFoundError.java b/core/src/main/java/io/a2a/spec/MethodNotFoundError.java new file mode 100644 index 000000000..967e938ef --- /dev/null +++ b/core/src/main/java/io/a2a/spec/MethodNotFoundError.java @@ -0,0 +1,27 @@ +package io.a2a.spec; + +import static io.a2a.util.Utils.defaultIfNull; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public class MethodNotFoundError extends JSONRPCError { + @JsonCreator + public MethodNotFoundError( + @JsonProperty("code") Integer code, + @JsonProperty("message") String message, + @JsonProperty("data") Object data) { + super( + defaultIfNull(code, -32601), + defaultIfNull(message, "Method not found"), + data); + } + + public MethodNotFoundError() { + this(-32601, null, null); + } +} diff --git a/core/src/main/java/io/a2a/spec/MethodNotFoundJsonMappingException.java b/core/src/main/java/io/a2a/spec/MethodNotFoundJsonMappingException.java new file mode 100644 index 000000000..7bd167ad1 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/MethodNotFoundJsonMappingException.java @@ -0,0 +1,12 @@ +package io.a2a.spec; + +public class MethodNotFoundJsonMappingException extends IdJsonMappingException { + + public MethodNotFoundJsonMappingException(String msg, Object id) { + super(msg, id); + } + + public MethodNotFoundJsonMappingException(String msg, Throwable cause, Object id) { + super(msg, cause, id); + } +} diff --git a/core/src/main/java/io/a2a/spec/NonStreamingJSONRPCRequest.java b/core/src/main/java/io/a2a/spec/NonStreamingJSONRPCRequest.java new file mode 100644 index 000000000..f79800215 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/NonStreamingJSONRPCRequest.java @@ -0,0 +1,16 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; + +/** + * Represents a non-streaming JSON-RPC request. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonDeserialize(using = NonStreamingJSONRPCRequestDeserializer.class) +public abstract sealed class NonStreamingJSONRPCRequest extends JSONRPCRequest permits GetTaskRequest, + CancelTaskRequest, SetTaskPushNotificationConfigRequest, GetTaskPushNotificationConfigRequest, + SendMessageRequest { +} diff --git a/core/src/main/java/io/a2a/spec/NonStreamingJSONRPCRequestDeserializer.java b/core/src/main/java/io/a2a/spec/NonStreamingJSONRPCRequestDeserializer.java new file mode 100644 index 000000000..a1db9f5c8 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/NonStreamingJSONRPCRequestDeserializer.java @@ -0,0 +1,55 @@ +package io.a2a.spec; + +import static io.a2a.spec.A2A.CANCEL_TASK_METHOD; +import static io.a2a.spec.A2A.GET_TASK_METHOD; +import static io.a2a.spec.A2A.GET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD; +import static io.a2a.spec.A2A.SEND_MESSAGE_METHOD; +import static io.a2a.spec.A2A.SET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD; + +import java.io.IOException; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonNode; + +public class NonStreamingJSONRPCRequestDeserializer extends JSONRPCRequestDeserializerBase> { + + public NonStreamingJSONRPCRequestDeserializer() { + this(null); + } + + public NonStreamingJSONRPCRequestDeserializer(Class vc) { + super(vc); + } + + @Override + public NonStreamingJSONRPCRequest deserialize(JsonParser jsonParser, DeserializationContext context) + throws IOException, JsonProcessingException { + JsonNode treeNode = jsonParser.getCodec().readTree(jsonParser); + String jsonrpc = getAndValidateJsonrpc(treeNode, jsonParser); + String method = getAndValidateMethod(treeNode, jsonParser); + Object id = getAndValidateId(treeNode, jsonParser); + JsonNode paramsNode = treeNode.get("params"); + + switch (method) { + case GET_TASK_METHOD: + return new GetTaskRequest(jsonrpc, id, method, + getAndValidateParams(paramsNode, jsonParser, treeNode, TaskQueryParams.class)); + case CANCEL_TASK_METHOD: + return new CancelTaskRequest(jsonrpc, id, method, + getAndValidateParams(paramsNode, jsonParser, treeNode, TaskIdParams.class)); + case SET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD: + return new SetTaskPushNotificationConfigRequest(jsonrpc, id, method, + getAndValidateParams(paramsNode, jsonParser, treeNode, TaskPushNotificationConfig.class)); + case GET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD: + return new GetTaskPushNotificationConfigRequest(jsonrpc, id, method, + getAndValidateParams(paramsNode, jsonParser, treeNode, TaskIdParams.class)); + case SEND_MESSAGE_METHOD: + return new SendMessageRequest(jsonrpc, id, method, + getAndValidateParams(paramsNode, jsonParser, treeNode, MessageSendParams.class)); + default: + throw new MethodNotFoundJsonMappingException("Invalid method", getIdIfPossible(treeNode, jsonParser)); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/OAuth2SecurityScheme.java b/core/src/main/java/io/a2a/spec/OAuth2SecurityScheme.java new file mode 100644 index 000000000..10141ffb7 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/OAuth2SecurityScheme.java @@ -0,0 +1,69 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import io.a2a.util.Assert; + +/** + * Represents an OAuth2 security scheme. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class OAuth2SecurityScheme implements SecurityScheme { + + public static final String OAUTH2 = "oauth2"; + private final OAuthFlows flows; + private final String description; + private final String type; + + public OAuth2SecurityScheme(OAuthFlows flows, String description) { + this(flows, description, OAUTH2); + } + + @JsonCreator + public OAuth2SecurityScheme(@JsonProperty("flows") OAuthFlows flows, @JsonProperty("description") String description, + @JsonProperty("type") String type) { + Assert.checkNotNullParam("flows", flows); + if (!type.equals(OAUTH2)) { + throw new IllegalArgumentException("Invalid type for OAuth2SecurityScheme"); + } + this.flows = flows; + this.description = description; + this.type = type; + } + + @Override + public String getDescription() { + return description; + } + + public OAuthFlows getFlows() { + return flows; + } + + public String getType() { + return type; + } + + public static class Builder { + private OAuthFlows flows; + private String description; + + public Builder flows(OAuthFlows flows) { + this.flows = flows; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public OAuth2SecurityScheme build() { + return new OAuth2SecurityScheme(flows, description); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/OAuthFlows.java b/core/src/main/java/io/a2a/spec/OAuthFlows.java new file mode 100644 index 000000000..fcd89bc3d --- /dev/null +++ b/core/src/main/java/io/a2a/spec/OAuthFlows.java @@ -0,0 +1,44 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; + +/** + * Allows configuration of the supported OAuth Flows. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record OAuthFlows(AuthorizationCodeOAuthFlow authorizationCode, ClientCredentialsOAuthFlow clientCredentials, + ImplicitOAuthFlow implicit, PasswordOAuthFlow password) { + + public static class Builder { + private AuthorizationCodeOAuthFlow authorizationCode; + private ClientCredentialsOAuthFlow clientCredentials; + private ImplicitOAuthFlow implicit; + private PasswordOAuthFlow password; + + public Builder authorizationCode(AuthorizationCodeOAuthFlow authorizationCode) { + this.authorizationCode = authorizationCode; + return this; + } + + public Builder clientCredentials(ClientCredentialsOAuthFlow clientCredentials) { + this.clientCredentials = clientCredentials; + return this; + } + + public Builder implicit(ImplicitOAuthFlow implicit) { + this.implicit = implicit; + return this; + } + + public Builder password(PasswordOAuthFlow password) { + this.password = password; + return this; + } + + public OAuthFlows build() { + return new OAuthFlows(authorizationCode, clientCredentials, implicit, password); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/OpenIdConnectSecurityScheme.java b/core/src/main/java/io/a2a/spec/OpenIdConnectSecurityScheme.java new file mode 100644 index 000000000..f8f059792 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/OpenIdConnectSecurityScheme.java @@ -0,0 +1,67 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Represents an OpenID Connect security scheme. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class OpenIdConnectSecurityScheme implements SecurityScheme { + + public static final String OPENID_CONNECT = "openIdConnect"; + private final String openIdConnectUrl; + private final String description; + private final String type; + + public OpenIdConnectSecurityScheme(String openIdConnectUrl, String description) { + this(openIdConnectUrl, description, OPENID_CONNECT); + } + + @JsonCreator + public OpenIdConnectSecurityScheme(@JsonProperty("openIdConnectUrl") String openIdConnectUrl, + @JsonProperty("description") String description, @JsonProperty("type") String type) { + if (!type.equals(OPENID_CONNECT)) { + throw new IllegalArgumentException("Invalid type for OpenIdConnectSecurityScheme"); + } + this.openIdConnectUrl = openIdConnectUrl; + this.description = description; + this.type = type; + } + + @Override + public String getDescription() { + return description; + } + + public String getOpenIdConnectUrl() { + return openIdConnectUrl; + } + + public String getType() { + return type; + } + + public static class Builder { + private String openIdConnectUrl; + private String description; + + public Builder openIdConnectUrl(String openIdConnectUrl) { + this.openIdConnectUrl = openIdConnectUrl; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public OpenIdConnectSecurityScheme build() { + return new OpenIdConnectSecurityScheme(openIdConnectUrl, description); + } + } + +} diff --git a/core/src/main/java/io/a2a/spec/Part.java b/core/src/main/java/io/a2a/spec/Part.java new file mode 100644 index 000000000..cf7f04a74 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/Part.java @@ -0,0 +1,46 @@ +package io.a2a.spec; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.annotation.JsonValue; + +/** + * A fundamental unit with a Message or Artifact. + * @param the type of unit + */ +@JsonTypeInfo( + use = JsonTypeInfo.Id.NAME, + include = JsonTypeInfo.As.PROPERTY, + property = "kind", + visible = true +) +@JsonSubTypes({ + @JsonSubTypes.Type(value = TextPart.class, name = "text"), + @JsonSubTypes.Type(value = FilePart.class, name = "file"), + @JsonSubTypes.Type(value = DataPart.class, name = "data") +}) +public abstract class Part { + public enum Kind { + TEXT("text"), + FILE("file"), + DATA("data"); + + private String kind; + + Kind(String kind) { + this.kind = kind; + } + + @JsonValue + public String asString() { + return this.kind; + } + } + + public abstract Kind getKind(); + + public abstract Map getMetadata(); + +} \ No newline at end of file diff --git a/core/src/main/java/io/a2a/spec/PasswordOAuthFlow.java b/core/src/main/java/io/a2a/spec/PasswordOAuthFlow.java new file mode 100644 index 000000000..a42e74942 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/PasswordOAuthFlow.java @@ -0,0 +1,21 @@ +package io.a2a.spec; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; + +import io.a2a.util.Assert; + +/** + * Configuration for the OAuth Resource Owner Password flow. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record PasswordOAuthFlow(String refreshUrl, Map scopes, String tokenUrl) { + + public PasswordOAuthFlow { + Assert.checkNotNullParam("scopes", scopes); + Assert.checkNotNullParam("tokenUrl", tokenUrl); + } +} diff --git a/core/src/main/java/io/a2a/spec/PushNotificationAuthenticationInfo.java b/core/src/main/java/io/a2a/spec/PushNotificationAuthenticationInfo.java new file mode 100644 index 000000000..b5b2bacd2 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/PushNotificationAuthenticationInfo.java @@ -0,0 +1,19 @@ +package io.a2a.spec; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import io.a2a.util.Assert; + +/** + * Defines authentication details for push notifications. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record PushNotificationAuthenticationInfo(List schemes, String credentials) { + + public PushNotificationAuthenticationInfo { + Assert.checkNotNullParam("schemes", schemes); + } +} diff --git a/core/src/main/java/io/a2a/spec/PushNotificationConfig.java b/core/src/main/java/io/a2a/spec/PushNotificationConfig.java new file mode 100644 index 000000000..19e9d3491 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/PushNotificationConfig.java @@ -0,0 +1,48 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import io.a2a.util.Assert; + +/** + * Represents a push notification. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record PushNotificationConfig(String url, String token, PushNotificationAuthenticationInfo authentication, String id) { + + public PushNotificationConfig { + Assert.checkNotNullParam("url", url); + } + + public static class Builder { + private String url; + private String token; + private PushNotificationAuthenticationInfo authentication; + private String id; + + public Builder url(String url) { + this.url = url; + return this; + } + + public Builder token(String token) { + this.token = token; + return this; + } + + public Builder authenticationInfo(PushNotificationAuthenticationInfo authenticationInfo) { + this.authentication = authenticationInfo; + return this; + } + + public Builder id(String id) { + this.id = id; + return this; + } + + public PushNotificationConfig build() { + return new PushNotificationConfig(url, token, authentication, id); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/PushNotificationNotSupportedError.java b/core/src/main/java/io/a2a/spec/PushNotificationNotSupportedError.java new file mode 100644 index 000000000..5ae5e8aa9 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/PushNotificationNotSupportedError.java @@ -0,0 +1,23 @@ +package io.a2a.spec; + +import static io.a2a.util.Utils.defaultIfNull; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public class PushNotificationNotSupportedError extends JSONRPCError { + @JsonCreator + public PushNotificationNotSupportedError( + @JsonProperty("code") Integer code, + @JsonProperty("message") String message, + @JsonProperty("data") Object data) { + super( + defaultIfNull(code, -32003), + defaultIfNull(message, "Push Notification is not supported"), + data); + } +} diff --git a/core/src/main/java/io/a2a/spec/SecurityScheme.java b/core/src/main/java/io/a2a/spec/SecurityScheme.java new file mode 100644 index 000000000..5003d01aa --- /dev/null +++ b/core/src/main/java/io/a2a/spec/SecurityScheme.java @@ -0,0 +1,23 @@ +package io.a2a.spec; + +import static io.a2a.spec.APIKeySecurityScheme.API_KEY; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; + +@JsonTypeInfo( + use = JsonTypeInfo.Id.NAME, + include = JsonTypeInfo.As.PROPERTY, + property = "type", + visible = true +) +@JsonSubTypes({ + @JsonSubTypes.Type(value = APIKeySecurityScheme.class, name = API_KEY), + @JsonSubTypes.Type(value = HTTPAuthSecurityScheme.class, name = HTTPAuthSecurityScheme.HTTP), + @JsonSubTypes.Type(value = OAuth2SecurityScheme.class, name = OAuth2SecurityScheme.OAUTH2), + @JsonSubTypes.Type(value = OpenIdConnectSecurityScheme.class, name = OpenIdConnectSecurityScheme.OPENID_CONNECT) +}) +public sealed interface SecurityScheme permits APIKeySecurityScheme, HTTPAuthSecurityScheme, OAuth2SecurityScheme, OpenIdConnectSecurityScheme { + + String getDescription(); +} diff --git a/core/src/main/java/io/a2a/spec/SendMessageRequest.java b/core/src/main/java/io/a2a/spec/SendMessageRequest.java new file mode 100644 index 000000000..620335aa7 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/SendMessageRequest.java @@ -0,0 +1,81 @@ +package io.a2a.spec; + +import static io.a2a.spec.A2A.JSONRPC_VERSION; +import static io.a2a.spec.A2A.SEND_MESSAGE_METHOD; +import static io.a2a.util.Utils.defaultIfNull; + +import java.util.UUID; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import io.a2a.util.Assert; + +/** + * Used to send a message request. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class SendMessageRequest extends NonStreamingJSONRPCRequest { + + @JsonCreator + public SendMessageRequest(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, + @JsonProperty("method") String method, @JsonProperty("params") MessageSendParams params) { + if (jsonrpc == null || jsonrpc.isEmpty()) { + throw new IllegalArgumentException("JSON-RPC protocol version cannot be null or empty"); + } + if (jsonrpc != null && ! jsonrpc.equals(JSONRPC_VERSION)) { + throw new IllegalArgumentException("Invalid JSON-RPC protocol version"); + } + Assert.checkNotNullParam("method", method); + if (! method.equals(SEND_MESSAGE_METHOD)) { + throw new IllegalArgumentException("Invalid SendMessageRequest method"); + } + Assert.checkNotNullParam("params", params); + Assert.isNullOrStringOrInteger(id); + this.jsonrpc = defaultIfNull(jsonrpc, JSONRPC_VERSION); + this.id = id; + this.method = method; + this.params = params; + } + + public SendMessageRequest(Object id, MessageSendParams params) { + this(JSONRPC_VERSION, id, SEND_MESSAGE_METHOD, params); + } + + public static class Builder { + private String jsonrpc; + private Object id; + private String method; + private MessageSendParams params; + + public Builder jsonrpc(String jsonrpc) { + this.jsonrpc = jsonrpc; + return this; + } + + public Builder id(Object id) { + this.id = id; + return this; + } + + public Builder method(String method) { + this.method = method; + return this; + } + + public Builder params(MessageSendParams params) { + this.params = params; + return this; + } + + public SendMessageRequest build() { + if (id == null) { + id = UUID.randomUUID().toString(); + } + return new SendMessageRequest(jsonrpc, id, method, params); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/SendMessageResponse.java b/core/src/main/java/io/a2a/spec/SendMessageResponse.java new file mode 100644 index 000000000..80b97e272 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/SendMessageResponse.java @@ -0,0 +1,37 @@ +package io.a2a.spec; + +import static io.a2a.spec.A2A.JSONRPC_VERSION; +import static io.a2a.util.Utils.defaultIfNull; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import io.a2a.util.Assert; + +/** + * The response after receiving a send message request. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class SendMessageResponse extends JSONRPCResponse { + + @JsonCreator + public SendMessageResponse(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, + @JsonProperty("result") EventKind result, @JsonProperty("error") JSONRPCError error) { + this.jsonrpc = defaultIfNull(jsonrpc, JSONRPC_VERSION); + Assert.isNullOrStringOrInteger(id); + this.id = id; + this.result = result; + this.error = error; + } + + public SendMessageResponse(Object id, EventKind result) { + this(null, id, result, null); + } + + public SendMessageResponse(Object id, JSONRPCError error) { + this(null, id, null, error); + } +} diff --git a/core/src/main/java/io/a2a/spec/SendStreamingMessageRequest.java b/core/src/main/java/io/a2a/spec/SendStreamingMessageRequest.java new file mode 100644 index 000000000..6503f03e7 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/SendStreamingMessageRequest.java @@ -0,0 +1,78 @@ +package io.a2a.spec; + +import static io.a2a.spec.A2A.JSONRPC_VERSION; +import static io.a2a.spec.A2A.SEND_STREAMING_MESSAGE_METHOD; +import static io.a2a.util.Utils.defaultIfNull; + +import java.util.UUID; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import io.a2a.util.Assert; + +/** + * Used to initiate a task with streaming. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class SendStreamingMessageRequest extends StreamingJSONRPCRequest { + + @JsonCreator + public SendStreamingMessageRequest(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, + @JsonProperty("method") String method, @JsonProperty("params") MessageSendParams params) { + if (jsonrpc != null && ! jsonrpc.equals(JSONRPC_VERSION)) { + throw new IllegalArgumentException("Invalid JSON-RPC protocol version"); + } + Assert.checkNotNullParam("method", method); + if (! method.equals(SEND_STREAMING_MESSAGE_METHOD)) { + throw new IllegalArgumentException("Invalid SendStreamingMessageRequest method"); + } + Assert.checkNotNullParam("params", params); + Assert.isNullOrStringOrInteger(id); + this.jsonrpc = defaultIfNull(jsonrpc, JSONRPC_VERSION); + this.id = id; + this.method = method; + this.params = params; + } + + public SendStreamingMessageRequest(Object id, MessageSendParams params) { + this(null, id, SEND_STREAMING_MESSAGE_METHOD, params); + } + + public static class Builder { + private String jsonrpc; + private Object id; + private String method = SEND_STREAMING_MESSAGE_METHOD; + private MessageSendParams params; + + public Builder jsonrpc(String jsonrpc) { + this.jsonrpc = jsonrpc; + return this; + } + + public Builder id(Object id) { + this.id = id; + return this; + } + + public Builder method(String method) { + this.method = method; + return this; + } + + public Builder params(MessageSendParams params) { + this.params = params; + return this; + } + + public SendStreamingMessageRequest build() { + if (id == null) { + id = UUID.randomUUID().toString(); + } + return new SendStreamingMessageRequest(jsonrpc, id, method, params); + } + } + } diff --git a/core/src/main/java/io/a2a/spec/SendStreamingMessageResponse.java b/core/src/main/java/io/a2a/spec/SendStreamingMessageResponse.java new file mode 100644 index 000000000..3a1cc3553 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/SendStreamingMessageResponse.java @@ -0,0 +1,37 @@ +package io.a2a.spec; + +import static io.a2a.spec.A2A.JSONRPC_VERSION; +import static io.a2a.util.Utils.defaultIfNull; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import io.a2a.util.Assert; + +/** + * The response after receiving a request to initiate a task with streaming. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class SendStreamingMessageResponse extends JSONRPCResponse { + + @JsonCreator + public SendStreamingMessageResponse(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, + @JsonProperty("result") StreamingEventKind result, @JsonProperty("error") JSONRPCError error) { + this.jsonrpc = defaultIfNull(jsonrpc, JSONRPC_VERSION); + Assert.isNullOrStringOrInteger(id); + this.id = id; + this.result = result; + this.error = error; + } + + public SendStreamingMessageResponse(Object id, StreamingEventKind result) { + this(null, id, result, null); + } + + public SendStreamingMessageResponse(Object id, JSONRPCError error) { + this(null, id, null, error); + } +} diff --git a/core/src/main/java/io/a2a/spec/SetTaskPushNotificationConfigRequest.java b/core/src/main/java/io/a2a/spec/SetTaskPushNotificationConfigRequest.java new file mode 100644 index 000000000..b042f6218 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/SetTaskPushNotificationConfigRequest.java @@ -0,0 +1,78 @@ +package io.a2a.spec; + +import static io.a2a.spec.A2A.JSONRPC_VERSION; +import static io.a2a.spec.A2A.SET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD; +import static io.a2a.util.Utils.defaultIfNull; + +import java.util.UUID; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import io.a2a.util.Assert; + +/** + * Used to set a task push notification request. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class SetTaskPushNotificationConfigRequest extends NonStreamingJSONRPCRequest { + + @JsonCreator + public SetTaskPushNotificationConfigRequest(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, + @JsonProperty("method") String method, @JsonProperty("params") TaskPushNotificationConfig params) { + if (jsonrpc != null && ! jsonrpc.equals(JSONRPC_VERSION)) { + throw new IllegalArgumentException("Invalid JSON-RPC protocol version"); + } + Assert.checkNotNullParam("method", method); + if (! method.equals(SET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD)) { + throw new IllegalArgumentException("Invalid SetTaskPushNotificationRequest method"); + } + Assert.checkNotNullParam("params", params); + Assert.isNullOrStringOrInteger(id); + this.jsonrpc = defaultIfNull(jsonrpc, JSONRPC_VERSION); + this.id = id; + this.method = method; + this.params = params; + } + + public SetTaskPushNotificationConfigRequest(String id, TaskPushNotificationConfig taskPushConfig) { + this(null, id, SET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD, taskPushConfig); + } + + public static class Builder { + private String jsonrpc; + private Object id; + private String method = SET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD; + private TaskPushNotificationConfig params; + + public SetTaskPushNotificationConfigRequest.Builder jsonrpc(String jsonrpc) { + this.jsonrpc = jsonrpc; + return this; + } + + public SetTaskPushNotificationConfigRequest.Builder id(Object id) { + this.id = id; + return this; + } + + public SetTaskPushNotificationConfigRequest.Builder method(String method) { + this.method = method; + return this; + } + + public SetTaskPushNotificationConfigRequest.Builder params(TaskPushNotificationConfig params) { + this.params = params; + return this; + } + + public SetTaskPushNotificationConfigRequest build() { + if (id == null) { + id = UUID.randomUUID().toString(); + } + return new SetTaskPushNotificationConfigRequest(jsonrpc, id, method, params); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/SetTaskPushNotificationConfigResponse.java b/core/src/main/java/io/a2a/spec/SetTaskPushNotificationConfigResponse.java new file mode 100644 index 000000000..3af1435ab --- /dev/null +++ b/core/src/main/java/io/a2a/spec/SetTaskPushNotificationConfigResponse.java @@ -0,0 +1,29 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * The response after receiving a set task push notification request. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class SetTaskPushNotificationConfigResponse extends JSONRPCResponse { + + @JsonCreator + public SetTaskPushNotificationConfigResponse(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, + @JsonProperty("result") TaskPushNotificationConfig result, + @JsonProperty("error") JSONRPCError error) { + super(jsonrpc, id, result, error); + } + + public SetTaskPushNotificationConfigResponse(Object id, JSONRPCError error) { + super(null, id, null, error); + } + + public SetTaskPushNotificationConfigResponse(Object id, TaskPushNotificationConfig result) { + this(null, id, result, null); + } +} diff --git a/core/src/main/java/io/a2a/spec/StreamingEventKind.java b/core/src/main/java/io/a2a/spec/StreamingEventKind.java new file mode 100644 index 000000000..a7a6b6232 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/StreamingEventKind.java @@ -0,0 +1,26 @@ +package io.a2a.spec; + +import static io.a2a.spec.Message.MESSAGE; +import static io.a2a.spec.Task.TASK; +import static io.a2a.spec.TaskArtifactUpdateEvent.ARTIFACT_UPDATE; +import static io.a2a.spec.TaskStatusUpdateEvent.STATUS_UPDATE; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; + +@JsonTypeInfo( + use = JsonTypeInfo.Id.NAME, + include = JsonTypeInfo.As.PROPERTY, + property = "kind", + visible = true +) +@JsonSubTypes({ + @JsonSubTypes.Type(value = Task.class, name = TASK), + @JsonSubTypes.Type(value = Message.class, name = MESSAGE), + @JsonSubTypes.Type(value = TaskStatusUpdateEvent.class, name = STATUS_UPDATE), + @JsonSubTypes.Type(value = TaskArtifactUpdateEvent.class, name = ARTIFACT_UPDATE) +}) +public sealed interface StreamingEventKind extends Event permits Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent { + + String getKind(); +} diff --git a/core/src/main/java/io/a2a/spec/StreamingJSONRPCRequest.java b/core/src/main/java/io/a2a/spec/StreamingJSONRPCRequest.java new file mode 100644 index 000000000..7642c41d2 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/StreamingJSONRPCRequest.java @@ -0,0 +1,16 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; + +/** + * Represents a streaming JSON-RPC request. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonDeserialize(using = StreamingJSONRPCRequestDeserializer.class) +public abstract sealed class StreamingJSONRPCRequest extends JSONRPCRequest permits TaskResubscriptionRequest, + SendStreamingMessageRequest { + +} diff --git a/core/src/main/java/io/a2a/spec/StreamingJSONRPCRequestDeserializer.java b/core/src/main/java/io/a2a/spec/StreamingJSONRPCRequestDeserializer.java new file mode 100644 index 000000000..4360e52f7 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/StreamingJSONRPCRequestDeserializer.java @@ -0,0 +1,43 @@ +package io.a2a.spec; + +import static io.a2a.spec.A2A.SEND_STREAMING_MESSAGE_METHOD; +import static io.a2a.spec.A2A.SEND_TASK_RESUBSCRIPTION_METHOD; + +import java.io.IOException; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonNode; + +public class StreamingJSONRPCRequestDeserializer extends JSONRPCRequestDeserializerBase> { + + public StreamingJSONRPCRequestDeserializer() { + this(null); + } + + public StreamingJSONRPCRequestDeserializer(Class vc) { + super(vc); + } + + @Override + public StreamingJSONRPCRequest deserialize(JsonParser jsonParser, DeserializationContext context) + throws IOException, JsonProcessingException { + JsonNode treeNode = jsonParser.getCodec().readTree(jsonParser); + String jsonrpc = getAndValidateJsonrpc(treeNode, jsonParser); + String method = getAndValidateMethod(treeNode, jsonParser); + Object id = getAndValidateId(treeNode, jsonParser); + JsonNode paramsNode = treeNode.get("params"); + + switch (method) { + case SEND_TASK_RESUBSCRIPTION_METHOD: + return new TaskResubscriptionRequest(jsonrpc, id, method, + getAndValidateParams(paramsNode, jsonParser, treeNode, TaskIdParams.class)); + case SEND_STREAMING_MESSAGE_METHOD: + return new SendStreamingMessageRequest(jsonrpc, id, method, + getAndValidateParams(paramsNode, jsonParser, treeNode, MessageSendParams.class)); + default: + throw new MethodNotFoundJsonMappingException("Invalid method", getIdIfPossible(treeNode, jsonParser)); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/StringJsonrpcId.java b/core/src/main/java/io/a2a/spec/StringJsonrpcId.java new file mode 100644 index 000000000..74a28272f --- /dev/null +++ b/core/src/main/java/io/a2a/spec/StringJsonrpcId.java @@ -0,0 +1,4 @@ +package io.a2a.spec; + +public class StringJsonrpcId { +} diff --git a/core/src/main/java/io/a2a/spec/Task.java b/core/src/main/java/io/a2a/spec/Task.java new file mode 100644 index 000000000..86c92473e --- /dev/null +++ b/core/src/main/java/io/a2a/spec/Task.java @@ -0,0 +1,145 @@ +package io.a2a.spec; + +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.type.TypeReference; +import io.a2a.util.Assert; + +/** + * A central unit of work. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class Task implements EventKind, StreamingEventKind { + + public static final TypeReference TYPE_REFERENCE = new TypeReference<>() {}; + + public static final String TASK = "task"; + private final String id; + private final String contextId; + private final TaskStatus status; + private final List artifacts; + private final List history; + private final Map metadata; + private final String kind; + + public Task(String id, String contextId, TaskStatus status, List artifacts, + List history, Map metadata) { + this(id, contextId, status, artifacts, history, metadata, TASK); + } + + @JsonCreator + public Task(@JsonProperty("id") String id, @JsonProperty("contextId") String contextId, @JsonProperty("status") TaskStatus status, + @JsonProperty("artifacts") List artifacts, @JsonProperty("history") List history, + @JsonProperty("metadata") Map metadata, @JsonProperty("kind") String kind) { + Assert.checkNotNullParam("id", id); + Assert.checkNotNullParam("contextId", contextId); + Assert.checkNotNullParam("status", status); + Assert.checkNotNullParam("kind", kind); + if (! kind.equals(TASK)) { + throw new IllegalArgumentException("Invalid Task"); + } + this.id = id; + this.contextId = contextId; + this.status = status; + this.artifacts = artifacts; + this.history = history; + this.metadata = metadata; + this.kind = kind; + } + + public String getId() { + return id; + } + + public String getContextId() { + return contextId; + } + + public TaskStatus getStatus() { + return status; + } + + public List getArtifacts() { + return artifacts; + } + + public List getHistory() { + return history; + } + + public Map getMetadata() { + return metadata; + } + + public String getKind() { + return kind; + } + + public static class Builder { + private String id; + private String contextId; + private TaskStatus status; + private List artifacts; + private List history; + private Map metadata; + + public Builder() { + + } + + public Builder(Task task) { + id = task.id; + contextId = task.contextId; + status = task.status; + artifacts = task.artifacts; + history = task.history; + metadata = task.metadata; + + } + + public Builder id(String id) { + this.id = id; + return this; + } + + public Builder contextId(String contextId) { + this.contextId = contextId; + return this; + } + + public Builder status(TaskStatus status) { + this.status = status; + return this; + } + + public Builder artifacts(List artifacts) { + this.artifacts = artifacts; + return this; + } + + public Builder history(List history) { + this.history = history; + return this; + } + + public Builder history(Message... history) { + this.history = List.of(history); + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public Task build() { + return new Task(id, contextId, status, artifacts, history, metadata); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/TaskArtifactUpdateEvent.java b/core/src/main/java/io/a2a/spec/TaskArtifactUpdateEvent.java new file mode 100644 index 000000000..49485577b --- /dev/null +++ b/core/src/main/java/io/a2a/spec/TaskArtifactUpdateEvent.java @@ -0,0 +1,127 @@ +package io.a2a.spec; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.a2a.util.Assert; + +/** + * A task artifact update event. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class TaskArtifactUpdateEvent implements EventKind, StreamingEventKind { + + public static final String ARTIFACT_UPDATE = "artifact-update"; + private final String taskId; + private final Boolean append; + private final Boolean lastChunk; + private final Artifact artifact; + private final String contextId; + private final Map metadata; + private final String kind; + + public TaskArtifactUpdateEvent(String taskId, Artifact artifact, String contextId, Boolean append, Boolean lastChunk, Map metadata) { + this(taskId, artifact, contextId, append, lastChunk, metadata, ARTIFACT_UPDATE); + } + + @JsonCreator + public TaskArtifactUpdateEvent(@JsonProperty("taskId") String taskId, @JsonProperty("artifact") Artifact artifact, + @JsonProperty("contextId") String contextId, + @JsonProperty("append") Boolean append, + @JsonProperty("lastChunk") Boolean lastChunk, + @JsonProperty("metadata") Map metadata, + @JsonProperty("kind") String kind) { + Assert.checkNotNullParam("taskId", taskId); + Assert.checkNotNullParam("artifact", artifact); + Assert.checkNotNullParam("contextId", contextId); + Assert.checkNotNullParam("kind", kind); + if (! kind.equals(ARTIFACT_UPDATE)) { + throw new IllegalArgumentException("Invalid TaskArtifactUpdateEvent"); + } + this.taskId = taskId; + this.artifact = artifact; + this.contextId = contextId; + this.append = append; + this.lastChunk = lastChunk; + this.metadata = metadata; + this.kind = kind; + } + + public String getTaskId() { + return taskId; + } + + public Artifact getArtifact() { + return artifact; + } + + public String getContextId() { + return contextId; + } + + public Boolean isAppend() { + return append; + } + + public Boolean isLastChunk() { + return lastChunk; + } + + public Map getMetadata() { + return metadata; + } + + @Override + public String getKind() { + return kind; + } + + public static class Builder { + + private String taskId; + private Artifact artifact; + private String contextId; + private Boolean append; + private Boolean lastChunk; + private Map metadata; + + public Builder taskId(String taskId) { + this.taskId = taskId; + return this; + } + + public Builder artifact(Artifact artifact) { + this.artifact = artifact; + return this; + } + + public Builder contextId(String contextId) { + this.contextId = contextId; + return this; + } + + public Builder append(Boolean append) { + this.append = append; + return this; + } + + public Builder lastChunk(Boolean lastChunk) { + this.lastChunk = lastChunk; + return this; + } + + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public TaskArtifactUpdateEvent build() { + return new TaskArtifactUpdateEvent(taskId, artifact, contextId, append, lastChunk, metadata); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/TaskIdParams.java b/core/src/main/java/io/a2a/spec/TaskIdParams.java new file mode 100644 index 000000000..816550eb9 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/TaskIdParams.java @@ -0,0 +1,23 @@ +package io.a2a.spec; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import io.a2a.util.Assert; + +/** + * Task id parameters. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record TaskIdParams(String id, Map metadata) { + + public TaskIdParams { + Assert.checkNotNullParam("id", id); + } + + public TaskIdParams(String id) { + this(id, null); + } +} diff --git a/core/src/main/java/io/a2a/spec/TaskNotCancelableError.java b/core/src/main/java/io/a2a/spec/TaskNotCancelableError.java new file mode 100644 index 000000000..3711fe8dd --- /dev/null +++ b/core/src/main/java/io/a2a/spec/TaskNotCancelableError.java @@ -0,0 +1,28 @@ +package io.a2a.spec; + +import static io.a2a.util.Utils.defaultIfNull; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public class TaskNotCancelableError extends JSONRPCError { + public TaskNotCancelableError() { + this(null, null, null); + } + + @JsonCreator + public TaskNotCancelableError( + @JsonProperty("code") Integer code, + @JsonProperty("message") String message, + @JsonProperty("data") Object data) { + super( + defaultIfNull(code, -32002), + defaultIfNull(message, "Task cannot be canceled"), + data); + } + +} diff --git a/core/src/main/java/io/a2a/spec/TaskNotFoundError.java b/core/src/main/java/io/a2a/spec/TaskNotFoundError.java new file mode 100644 index 000000000..ef24bc1de --- /dev/null +++ b/core/src/main/java/io/a2a/spec/TaskNotFoundError.java @@ -0,0 +1,28 @@ +package io.a2a.spec; + +import static io.a2a.util.Utils.defaultIfNull; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public class TaskNotFoundError extends JSONRPCError { + public TaskNotFoundError() { + this(null, null, null); + } + + @JsonCreator + + public TaskNotFoundError( + @JsonProperty("code") Integer code, + @JsonProperty("message") String message, + @JsonProperty("data") Object data) { + super( + defaultIfNull(code, -32001), + defaultIfNull(message, "Task not found"), + data); + } +} diff --git a/core/src/main/java/io/a2a/spec/TaskPushNotificationConfig.java b/core/src/main/java/io/a2a/spec/TaskPushNotificationConfig.java new file mode 100644 index 000000000..0e4163212 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/TaskPushNotificationConfig.java @@ -0,0 +1,18 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import io.a2a.util.Assert; + +/** + * Task push notification configuration. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record TaskPushNotificationConfig(String taskId, PushNotificationConfig pushNotificationConfig) { + + public TaskPushNotificationConfig { + Assert.checkNotNullParam("taskId", taskId); + Assert.checkNotNullParam("pushNotificationConfig", pushNotificationConfig); + } +} diff --git a/core/src/main/java/io/a2a/spec/TaskQueryParams.java b/core/src/main/java/io/a2a/spec/TaskQueryParams.java new file mode 100644 index 000000000..2587e5fb5 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/TaskQueryParams.java @@ -0,0 +1,34 @@ +package io.a2a.spec; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import io.a2a.util.Assert; + +/** + * Task query parameters. + * + * @param id the ID for the task to be queried + * @param historyLength the maximum number of items of history for the task to include in the response + * @param metadata additional properties + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record TaskQueryParams(String id, Integer historyLength, Map metadata) { + + public TaskQueryParams { + Assert.checkNotNullParam("id", id); + if (historyLength != null && historyLength < 0) { + throw new IllegalArgumentException("Invalid history length"); + } + } + + public TaskQueryParams(String id) { + this(id, null, null); + } + + public TaskQueryParams(String id, Integer historyLength) { + this(id, historyLength, null); + } +} diff --git a/core/src/main/java/io/a2a/spec/TaskResubscriptionRequest.java b/core/src/main/java/io/a2a/spec/TaskResubscriptionRequest.java new file mode 100644 index 000000000..7380cddd1 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/TaskResubscriptionRequest.java @@ -0,0 +1,77 @@ +package io.a2a.spec; + +import static io.a2a.spec.A2A.JSONRPC_VERSION; +import static io.a2a.spec.A2A.SEND_TASK_RESUBSCRIPTION_METHOD; +import static io.a2a.util.Utils.defaultIfNull; + +import java.util.UUID; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import io.a2a.util.Assert; + +/** + * Used to resubscribe to a task. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class TaskResubscriptionRequest extends StreamingJSONRPCRequest { + + @JsonCreator + public TaskResubscriptionRequest(@JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, + @JsonProperty("method") String method, @JsonProperty("params") TaskIdParams params) { + if (jsonrpc != null && ! jsonrpc.equals(JSONRPC_VERSION)) { + throw new IllegalArgumentException("Invalid JSON-RPC protocol version"); + } + Assert.checkNotNullParam("method", method); + if (! method.equals(SEND_TASK_RESUBSCRIPTION_METHOD)) { + throw new IllegalArgumentException("Invalid TaskResubscriptionRequest method"); + } + Assert.checkNotNullParam("params", params); + this.jsonrpc = defaultIfNull(jsonrpc, JSONRPC_VERSION); + this.id = id == null ? UUID.randomUUID().toString() : id; + this.method = method; + this.params = params; + } + + public TaskResubscriptionRequest(Object id, TaskIdParams params) { + this(null, id, SEND_TASK_RESUBSCRIPTION_METHOD, params); + } + + public static class Builder { + private String jsonrpc; + private Object id; + private String method = SEND_TASK_RESUBSCRIPTION_METHOD; + private TaskIdParams params; + + public TaskResubscriptionRequest.Builder jsonrpc(String jsonrpc) { + this.jsonrpc = jsonrpc; + return this; + } + + public TaskResubscriptionRequest.Builder id(Object id) { + this.id = id; + return this; + } + + public TaskResubscriptionRequest.Builder method(String method) { + this.method = method; + return this; + } + + public TaskResubscriptionRequest.Builder params(TaskIdParams params) { + this.params = params; + return this; + } + + public TaskResubscriptionRequest build() { + if (id == null) { + id = UUID.randomUUID().toString(); + } + return new TaskResubscriptionRequest(jsonrpc, id, method, params); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/TaskState.java b/core/src/main/java/io/a2a/spec/TaskState.java new file mode 100644 index 000000000..30d3c1a23 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/TaskState.java @@ -0,0 +1,66 @@ +package io.a2a.spec; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; + +/** + * Represents the state of a task. + */ +public enum TaskState { + SUBMITTED("submitted"), + WORKING("working"), + INPUT_REQUIRED("input-required"), + AUTH_REQUIRED("auth-required"), + COMPLETED("completed", true), + CANCELED("canceled", true), + FAILED("failed", true), + REJECTED("rejected", true), + UNKNOWN("unknown", true); + + private final String state; + private final boolean isFinal; + + TaskState(String state) { + this(state, false); + } + + TaskState(String state, boolean isFinal) { + this.state = state; + this.isFinal = isFinal; + } + + @JsonValue + public String asString() { + return state; + } + + public boolean isFinal(){ + return isFinal; + } + + @JsonCreator + public static TaskState fromString(String state) { + switch (state) { + case "submitted": + return SUBMITTED; + case "working": + return WORKING; + case "input-required": + return INPUT_REQUIRED; + case "auth-required": + return AUTH_REQUIRED; + case "completed": + return COMPLETED; + case "canceled": + return CANCELED; + case "failed": + return FAILED; + case "rejected": + return REJECTED; + case "unknown": + return UNKNOWN; + default: + throw new IllegalArgumentException("Invalid TaskState: " + state); + } + } +} \ No newline at end of file diff --git a/core/src/main/java/io/a2a/spec/TaskStatus.java b/core/src/main/java/io/a2a/spec/TaskStatus.java new file mode 100644 index 000000000..2befdfcbf --- /dev/null +++ b/core/src/main/java/io/a2a/spec/TaskStatus.java @@ -0,0 +1,27 @@ +package io.a2a.spec; + +import java.time.LocalDateTime; + +import com.fasterxml.jackson.annotation.JsonFormat; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; + +import io.a2a.util.Assert; + +/** + * Represents the status of a task. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public record TaskStatus(TaskState state, Message message, + @JsonFormat(shape = JsonFormat.Shape.STRING, pattern = "yyyy-MM-dd'T'HH:mm:ss.SSSSSS") LocalDateTime timestamp) { + + public TaskStatus { + Assert.checkNotNullParam("state", state); + timestamp = timestamp == null ? LocalDateTime.now() : timestamp; + } + + public TaskStatus(TaskState state) { + this(state, null, null); + } +} diff --git a/core/src/main/java/io/a2a/spec/TaskStatusUpdateEvent.java b/core/src/main/java/io/a2a/spec/TaskStatusUpdateEvent.java new file mode 100644 index 000000000..7a44480da --- /dev/null +++ b/core/src/main/java/io/a2a/spec/TaskStatusUpdateEvent.java @@ -0,0 +1,113 @@ +package io.a2a.spec; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.a2a.util.Assert; + +/** + * A task status update event. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class TaskStatusUpdateEvent implements EventKind, StreamingEventKind { + + public static final String STATUS_UPDATE = "status-update"; + private final String taskId; + private final TaskStatus status; + private final String contextId; + private final boolean isFinal; + private final Map metadata; + private final String kind; + + + public TaskStatusUpdateEvent(String taskId, TaskStatus status, String contextId, boolean isFinal, + Map metadata) { + this(taskId, status, contextId, isFinal, metadata, STATUS_UPDATE); + } + + @JsonCreator + public TaskStatusUpdateEvent(@JsonProperty("taskId") String taskId, @JsonProperty("status") TaskStatus status, + @JsonProperty("contextId") String contextId, @JsonProperty("final") boolean isFinal, + @JsonProperty("metadata") Map metadata, @JsonProperty("kind") String kind) { + Assert.checkNotNullParam("taskId", taskId); + Assert.checkNotNullParam("status", status); + Assert.checkNotNullParam("contextId", contextId); + Assert.checkNotNullParam("kind", kind); + if (! kind.equals(STATUS_UPDATE)) { + throw new IllegalArgumentException("Invalid TaskStatusUpdateEvent"); + } + this.taskId = taskId; + this.status = status; + this.contextId = contextId; + this.isFinal = isFinal; + this.metadata = metadata; + this.kind = kind; + } + + public String getTaskId() { + return taskId; + } + + public TaskStatus getStatus() { + return status; + } + + public String getContextId() { + return contextId; + } + + @JsonProperty("final") + public boolean isFinal() { + return isFinal; + } + + public Map getMetadata() { + return metadata; + } + + @Override + public String getKind() { + return kind; + } + + public static class Builder { + private String taskId; + private TaskStatus status; + private String contextId; + private boolean isFinal; + private Map metadata; + + public Builder taskId(String id) { + this.taskId = id; + return this; + } + + public Builder status(TaskStatus status) { + this.status = status; + return this; + } + + public Builder contextId(String contextId) { + this.contextId = contextId; + return this; + } + + public Builder isFinal(boolean isFinal) { + this.isFinal = isFinal; + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public TaskStatusUpdateEvent build() { + return new TaskStatusUpdateEvent(taskId, status, contextId, isFinal, metadata); + } + } +} diff --git a/core/src/main/java/io/a2a/spec/TextPart.java b/core/src/main/java/io/a2a/spec/TextPart.java new file mode 100644 index 000000000..1b62cf747 --- /dev/null +++ b/core/src/main/java/io/a2a/spec/TextPart.java @@ -0,0 +1,46 @@ +package io.a2a.spec; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.a2a.util.Assert; + +/** + * A fundamental text unit of an Artifact or Message. + */ +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public class TextPart extends Part { + private final String text; + private final Map metadata; + private final Kind kind; + + public TextPart(String text) { + this(text, null); + } + + @JsonCreator + public TextPart(@JsonProperty("text") String text, @JsonProperty("metadata") Map metadata) { + Assert.checkNotNullParam("text", text); + this.text = text; + this.metadata = metadata; + this.kind = Kind.TEXT; + } + + @Override + public Kind getKind() { + return kind; + } + + public String getText() { + return text; + } + + @Override + public Map getMetadata() { + return metadata; + } +} \ No newline at end of file diff --git a/core/src/main/java/io/a2a/spec/UnsupportedOperationError.java b/core/src/main/java/io/a2a/spec/UnsupportedOperationError.java new file mode 100644 index 000000000..637df90fe --- /dev/null +++ b/core/src/main/java/io/a2a/spec/UnsupportedOperationError.java @@ -0,0 +1,27 @@ +package io.a2a.spec; + +import static io.a2a.util.Utils.defaultIfNull; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +@JsonInclude(JsonInclude.Include.NON_ABSENT) +@JsonIgnoreProperties(ignoreUnknown = true) +public class UnsupportedOperationError extends JSONRPCError { + @JsonCreator + public UnsupportedOperationError( + @JsonProperty("code") Integer code, + @JsonProperty("message") String message, + @JsonProperty("data") Object data) { + super( + defaultIfNull(code, -32004), + defaultIfNull(message, "This operation is not supported"), + data); + } + + public UnsupportedOperationError() { + this(null, null, null); + } +} diff --git a/core/src/main/java/io/a2a/util/Assert.java b/core/src/main/java/io/a2a/util/Assert.java new file mode 100644 index 000000000..b0077cd23 --- /dev/null +++ b/core/src/main/java/io/a2a/util/Assert.java @@ -0,0 +1,31 @@ +package io.a2a.util; + +public final class Assert { + + /** + * Check that the named parameter is not {@code null}. Use a standard exception message if it is. + * + * @param name the parameter name + * @param value the parameter value + * @param the value type + * @return the value that was passed in + * @throws IllegalArgumentException if the value is {@code null} + */ + @NotNull + public static T checkNotNullParam(String name, T value) throws IllegalArgumentException { + checkNotNullParamChecked("name", name); + checkNotNullParamChecked(name, value); + return value; + } + + private static void checkNotNullParamChecked(final String name, final T value) { + if (value == null) throw new IllegalArgumentException("Parameter '" + name + "' may not be null"); + } + + public static void isNullOrStringOrInteger(Object value) { + if (! (value == null || value instanceof String || value instanceof Integer)) { + throw new IllegalArgumentException("Id must be null, a String, or an Integer"); + } + } + +} diff --git a/core/src/main/java/io/a2a/util/NotNull.java b/core/src/main/java/io/a2a/util/NotNull.java new file mode 100644 index 000000000..3146117f3 --- /dev/null +++ b/core/src/main/java/io/a2a/util/NotNull.java @@ -0,0 +1,13 @@ +package io.a2a.util; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.CLASS) +@Target({ ElementType.FIELD, ElementType.LOCAL_VARIABLE, ElementType.METHOD, ElementType.PARAMETER }) +@Documented +public @interface NotNull { +} \ No newline at end of file diff --git a/core/src/main/java/io/a2a/util/Utils.java b/core/src/main/java/io/a2a/util/Utils.java new file mode 100644 index 000000000..aac6af61c --- /dev/null +++ b/core/src/main/java/io/a2a/util/Utils.java @@ -0,0 +1,30 @@ +package io.a2a.util; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; + +public class Utils { + + public static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + static { + // needed for date/time types + OBJECT_MAPPER.registerModule(new JavaTimeModule()); + } + + public static T unmarshalFrom(String data, TypeReference typeRef) throws JsonProcessingException { + return OBJECT_MAPPER.readValue(data, typeRef); + } + + public static T defaultIfNull(T value, T defaultValue) { + if (value == null) { + return defaultValue; + } + return value; + } + + public static void rethrow(Throwable t) throws T { + throw (T) t; + } +} diff --git a/core/src/main/resources/META-INF/beans.xml b/core/src/main/resources/META-INF/beans.xml new file mode 100644 index 000000000..e69de29bb diff --git a/core/src/test/java/io/a2a/client/A2ACardResolverTest.java b/core/src/test/java/io/a2a/client/A2ACardResolverTest.java new file mode 100644 index 000000000..8265b9514 --- /dev/null +++ b/core/src/test/java/io/a2a/client/A2ACardResolverTest.java @@ -0,0 +1,164 @@ +package io.a2a.client; + +import static io.a2a.client.A2ACardResolver.AGENT_CARD_TYPE_REFERENCE; +import static io.a2a.util.Utils.OBJECT_MAPPER; +import static io.a2a.util.Utils.unmarshalFrom; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.IOException; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +import io.a2a.http.A2AHttpClient; +import io.a2a.http.A2AHttpResponse; +import io.a2a.spec.A2AClientError; +import io.a2a.spec.A2AClientJSONError; +import io.a2a.spec.AgentCard; +import org.junit.jupiter.api.Test; + +public class A2ACardResolverTest { + @Test + public void testConstructorStripsSlashes() throws Exception { + TestHttpClient client = new TestHttpClient(); + client.body = JsonMessages.AGENT_CARD; + + A2ACardResolver resolver = new A2ACardResolver(client, "http://example.com/"); + AgentCard card = resolver.getAgentCard(); + + assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, client.url); + + + resolver = new A2ACardResolver(client, "http://example.com"); + card = resolver.getAgentCard(); + + assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, client.url); + + resolver = new A2ACardResolver(client, "http://example.com/", A2ACardResolver.DEFAULT_AGENT_CARD_PATH); + card = resolver.getAgentCard(); + + assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, client.url); + + resolver = new A2ACardResolver(client, "http://example.com", A2ACardResolver.DEFAULT_AGENT_CARD_PATH); + card = resolver.getAgentCard(); + + assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, client.url); + + resolver = new A2ACardResolver(client, "http://example.com/", A2ACardResolver.DEFAULT_AGENT_CARD_PATH.substring(0)); + card = resolver.getAgentCard(); + + assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, client.url); + + resolver = new A2ACardResolver(client, "http://example.com", A2ACardResolver.DEFAULT_AGENT_CARD_PATH.substring(0)); + card = resolver.getAgentCard(); + + assertEquals("http://example.com" + A2ACardResolver.DEFAULT_AGENT_CARD_PATH, client.url); + } + + + @Test + public void testGetAgentCardSuccess() throws Exception { + TestHttpClient client = new TestHttpClient(); + client.body = JsonMessages.AGENT_CARD; + + A2ACardResolver resolver = new A2ACardResolver(client, "http://example.com/"); + AgentCard card = resolver.getAgentCard(); + + AgentCard expectedCard = unmarshalFrom(JsonMessages.AGENT_CARD, AGENT_CARD_TYPE_REFERENCE); + String expected = OBJECT_MAPPER.writeValueAsString(expectedCard); + + String requestCardString = OBJECT_MAPPER.writeValueAsString(card); + assertEquals(expected, requestCardString); + } + + @Test + public void testGetAgentCardJsonDecodeError() throws Exception { + TestHttpClient client = new TestHttpClient(); + client.body = "X" + JsonMessages.AGENT_CARD; + + A2ACardResolver resolver = new A2ACardResolver(client, "http://example.com/"); + + boolean success = false; + try { + AgentCard card = resolver.getAgentCard(); + success = true; + } catch (A2AClientJSONError expected) { + } + assertFalse(success); + } + + + @Test + public void testGetAgentCardRequestError() throws Exception { + TestHttpClient client = new TestHttpClient(); + client.status = 503; + + A2ACardResolver resolver = new A2ACardResolver(client, "http://example.com/"); + + String msg = null; + try { + AgentCard card = resolver.getAgentCard(); + } catch (A2AClientError expected) { + msg = expected.getMessage(); + } + assertTrue(msg.contains("503")); + } + + private static class TestHttpClient implements A2AHttpClient { + int status = 200; + String body; + String url; + + @Override + public GetBuilder createGet() { + return new TestGetBuilder(); + } + + @Override + public PostBuilder createPost() { + return null; + } + + class TestGetBuilder implements A2AHttpClient.GetBuilder { + + @Override + public A2AHttpResponse get() throws IOException, InterruptedException { + return new A2AHttpResponse() { + @Override + public int status() { + return status; + } + + @Override + public boolean success() { + return status == 200; + } + + @Override + public String body() { + return body; + } + }; + } + + @Override + public CompletableFuture getAsyncSSE(Consumer messageConsumer, Consumer errorConsumer, Runnable completeRunnable) throws IOException, InterruptedException { + return null; + } + + @Override + public GetBuilder url(String s) { + url = s; + return this; + } + + @Override + public GetBuilder addHeader(String name, String value) { + + return this; + } + } + } + +} diff --git a/core/src/test/java/io/a2a/client/A2AClientStreamingTest.java b/core/src/test/java/io/a2a/client/A2AClientStreamingTest.java new file mode 100644 index 000000000..78b0c0945 --- /dev/null +++ b/core/src/test/java/io/a2a/client/A2AClientStreamingTest.java @@ -0,0 +1,177 @@ +package io.a2a.client; + +import static io.a2a.client.JsonStreamingMessages.SEND_MESSAGE_STREAMING_TEST_REQUEST; +import static io.a2a.client.JsonStreamingMessages.SEND_MESSAGE_STREAMING_TEST_RESPONSE; +import static io.a2a.client.JsonStreamingMessages.TASK_RESUBSCRIPTION_REQUEST_TEST_RESPONSE; +import static io.a2a.client.JsonStreamingMessages.TASK_RESUBSCRIPTION_TEST_REQUEST; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockserver.model.HttpRequest.request; +import static org.mockserver.model.HttpResponse.response; + +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import io.a2a.spec.Artifact; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.Message; +import io.a2a.spec.MessageSendConfiguration; +import io.a2a.spec.MessageSendParams; +import io.a2a.spec.Part; +import io.a2a.spec.StreamingEventKind; +import io.a2a.spec.Task; +import io.a2a.spec.TaskIdParams; +import io.a2a.spec.TaskState; +import io.a2a.spec.TextPart; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockserver.integration.ClientAndServer; +import org.mockserver.matchers.MatchType; +import org.mockserver.model.JsonBody; + +public class A2AClientStreamingTest { + + private ClientAndServer server; + + @BeforeEach + public void setUp() { + server = new ClientAndServer(4001); + } + + @AfterEach + public void tearDown() { + server.stop(); + } + + @Test + public void testSendStreamingMessageParams() { + // The goal here is just to verify the correct parameters are being used + // This is a unit test of the parameter construction, not the streaming itself + Message message = new Message.Builder() + .role(Message.Role.USER) + .parts(Collections.singletonList(new TextPart("test message"))) + .contextId("context-test") + .messageId("message-test") + .build(); + + MessageSendConfiguration configuration = new MessageSendConfiguration.Builder() + .acceptedOutputModes(List.of("text")) + .blocking(false) + .build(); + + MessageSendParams params = new MessageSendParams.Builder() + .message(message) + .configuration(configuration) + .build(); + + assertNotNull(params); + assertEquals(message, params.message()); + assertEquals(configuration, params.configuration()); + assertEquals(Message.Role.USER, params.message().getRole()); + assertEquals("test message", ((TextPart) params.message().getParts().get(0)).getText()); + } + + @Test + public void testA2AClientSendStreamingMessage() throws Exception { + this.server.when( + request() + .withMethod("POST") + .withPath("/") + .withBody(JsonBody.json(SEND_MESSAGE_STREAMING_TEST_REQUEST, MatchType.STRICT)) + + ) + .respond( + response() + .withStatusCode(200) + .withHeader("Content-Type", "text/event-stream") + .withBody(SEND_MESSAGE_STREAMING_TEST_RESPONSE) + ); + + A2AClient client = new A2AClient("http://localhost:4001"); + Message message = new Message.Builder() + .role(Message.Role.USER) + .parts(Collections.singletonList(new TextPart("tell me some jokes"))) + .contextId("context-1234") + .messageId("message-1234") + .build(); + MessageSendConfiguration configuration = new MessageSendConfiguration.Builder() + .acceptedOutputModes(List.of("text")) + .blocking(false) + .build(); + MessageSendParams params = new MessageSendParams.Builder() + .message(message) + .configuration(configuration) + .build(); + + AtomicReference receivedEvent = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + Consumer eventHandler = event -> { + receivedEvent.set(event); + latch.countDown(); + }; + Consumer errorHandler = error -> {}; + Runnable failureHandler = () -> {}; + client.sendStreamingMessage("request-1234", params, eventHandler, errorHandler, failureHandler); + + boolean eventReceived = latch.await(10, TimeUnit.SECONDS); + assertTrue(eventReceived); + assertNotNull(receivedEvent.get()); + } + + @Test + public void testA2AClientResubscribeToTask() throws Exception { + this.server.when( + request() + .withMethod("POST") + .withPath("/") + .withBody(JsonBody.json(TASK_RESUBSCRIPTION_TEST_REQUEST, MatchType.STRICT)) + + ) + .respond( + response() + .withStatusCode(200) + .withHeader("Content-Type", "text/event-stream") + .withBody(TASK_RESUBSCRIPTION_REQUEST_TEST_RESPONSE) + ); + + A2AClient client = new A2AClient("http://localhost:4001"); + TaskIdParams taskIdParams = new TaskIdParams("task-1234"); + + AtomicReference receivedEvent = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + Consumer eventHandler = event -> { + receivedEvent.set(event); + latch.countDown(); + }; + Consumer errorHandler = error -> {}; + Runnable failureHandler = () -> {}; + client.resubscribeToTask("request-1234", taskIdParams, eventHandler, errorHandler, failureHandler); + + boolean eventReceived = latch.await(10, TimeUnit.SECONDS); + assertTrue(eventReceived); + + StreamingEventKind eventKind = receivedEvent.get();; + assertNotNull(eventKind); + assertInstanceOf(Task.class, eventKind); + Task task = (Task) eventKind; + assertEquals("2", task.getId()); + assertEquals("context-1234", task.getContextId()); + assertEquals(TaskState.COMPLETED, task.getStatus().state()); + List artifacts = task.getArtifacts(); + assertEquals(1, artifacts.size()); + Artifact artifact = artifacts.get(0); + assertEquals("artifact-1", artifact.artifactId()); + assertEquals("joke", artifact.name()); + Part part = artifact.parts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("Why did the chicken cross the road? To get to the other side!", ((TextPart) part).getText()); + } +} \ No newline at end of file diff --git a/core/src/test/java/io/a2a/client/A2AClientTest.java b/core/src/test/java/io/a2a/client/A2AClientTest.java new file mode 100644 index 000000000..e99734ff1 --- /dev/null +++ b/core/src/test/java/io/a2a/client/A2AClientTest.java @@ -0,0 +1,697 @@ +package io.a2a.client; + +import static io.a2a.client.JsonMessages.AGENT_CARD; +import static io.a2a.client.JsonMessages.AUTHENTICATION_EXTENDED_AGENT_CARD; +import static io.a2a.client.JsonMessages.CANCEL_TASK_TEST_REQUEST; +import static io.a2a.client.JsonMessages.CANCEL_TASK_TEST_RESPONSE; +import static io.a2a.client.JsonMessages.GET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_REQUEST; +import static io.a2a.client.JsonMessages.GET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE; +import static io.a2a.client.JsonMessages.GET_TASK_TEST_REQUEST; +import static io.a2a.client.JsonMessages.GET_TASK_TEST_RESPONSE; +import static io.a2a.client.JsonMessages.SEND_MESSAGE_ERROR_TEST_RESPONSE; +import static io.a2a.client.JsonMessages.SEND_MESSAGE_TEST_REQUEST; +import static io.a2a.client.JsonMessages.SEND_MESSAGE_TEST_REQUEST_WITH_MESSAGE_RESPONSE; +import static io.a2a.client.JsonMessages.SEND_MESSAGE_TEST_RESPONSE; +import static io.a2a.client.JsonMessages.SEND_MESSAGE_TEST_RESPONSE_WITH_MESSAGE_RESPONSE; +import static io.a2a.client.JsonMessages.SEND_MESSAGE_WITH_ERROR_TEST_REQUEST; +import static io.a2a.client.JsonMessages.SEND_MESSAGE_WITH_FILE_PART_TEST_REQUEST; +import static io.a2a.client.JsonMessages.SEND_MESSAGE_WITH_FILE_PART_TEST_RESPONSE; +import static io.a2a.client.JsonMessages.SEND_MESSAGE_WITH_DATA_PART_TEST_REQUEST; +import static io.a2a.client.JsonMessages.SEND_MESSAGE_WITH_DATA_PART_TEST_RESPONSE; +import static io.a2a.client.JsonMessages.SEND_MESSAGE_WITH_MIXED_PARTS_TEST_REQUEST; +import static io.a2a.client.JsonMessages.SEND_MESSAGE_WITH_MIXED_PARTS_TEST_RESPONSE; +import static io.a2a.client.JsonMessages.SET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_REQUEST; +import static io.a2a.client.JsonMessages.SET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockserver.model.HttpRequest.request; +import static org.mockserver.model.HttpResponse.response; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import io.a2a.spec.A2AServerException; +import io.a2a.spec.AgentCard; +import io.a2a.spec.AgentSkill; +import io.a2a.spec.Artifact; +import io.a2a.spec.CancelTaskResponse; +import io.a2a.spec.DataPart; +import io.a2a.spec.FileContent; +import io.a2a.spec.FilePart; +import io.a2a.spec.FileWithBytes; +import io.a2a.spec.FileWithUri; +import io.a2a.spec.GetTaskPushNotificationConfigResponse; +import io.a2a.spec.GetTaskResponse; +import io.a2a.spec.Message; +import io.a2a.spec.MessageSendConfiguration; +import io.a2a.spec.MessageSendParams; +import io.a2a.spec.OpenIdConnectSecurityScheme; +import io.a2a.spec.Part; +import io.a2a.spec.PushNotificationAuthenticationInfo; +import io.a2a.spec.PushNotificationConfig; +import io.a2a.spec.SecurityScheme; +import io.a2a.spec.SendMessageResponse; +import io.a2a.spec.SetTaskPushNotificationConfigResponse; +import io.a2a.spec.Task; +import io.a2a.spec.TaskIdParams; +import io.a2a.spec.TaskPushNotificationConfig; +import io.a2a.spec.TaskQueryParams; +import io.a2a.spec.TaskState; +import io.a2a.spec.TextPart; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockserver.integration.ClientAndServer; +import org.mockserver.matchers.MatchType; +import org.mockserver.model.JsonBody; + +public class A2AClientTest { + + private ClientAndServer server; + + @BeforeEach + public void setUp() { + server = new ClientAndServer(4001); + } + + @AfterEach + public void tearDown() { + server.stop(); + } + + @Test + public void testA2AClientSendMessage() throws Exception { + this.server.when( + request() + .withMethod("POST") + .withPath("/") + .withBody(JsonBody.json(SEND_MESSAGE_TEST_REQUEST, MatchType.STRICT)) + + ) + .respond( + response() + .withStatusCode(200) + .withBody(SEND_MESSAGE_TEST_RESPONSE) + ); + + A2AClient client = new A2AClient("http://localhost:4001"); + Message message = new Message.Builder() + .role(Message.Role.USER) + .parts(Collections.singletonList(new TextPart("tell me a joke"))) + .contextId("context-1234") + .messageId("message-1234") + .build(); + MessageSendConfiguration configuration = new MessageSendConfiguration.Builder() + .acceptedOutputModes(List.of("text")) + .blocking(true) + .build(); + MessageSendParams params = new MessageSendParams.Builder() + .message(message) + .configuration(configuration) + .build(); + + SendMessageResponse response = client.sendMessage("request-1234", params); + + assertEquals("2.0", response.getJsonrpc()); + assertNotNull(response.getId()); + Object result = response.getResult(); + assertInstanceOf(Task.class, result); + Task task = (Task) result; + assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); + assertNotNull(task.getContextId()); + assertEquals(TaskState.COMPLETED,task.getStatus().state()); + assertEquals(1, task.getArtifacts().size()); + Artifact artifact = task.getArtifacts().get(0); + assertEquals("artifact-1", artifact.artifactId()); + assertEquals("joke", artifact.name()); + assertEquals(1, artifact.parts().size()); + Part part = artifact.parts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("Why did the chicken cross the road? To get to the other side!", ((TextPart) part).getText()); + assertTrue(task.getMetadata().isEmpty()); + } + + @Test + public void testA2AClientSendMessageWithMessageResponse() throws Exception { + this.server.when( + request() + .withMethod("POST") + .withPath("/") + .withBody(JsonBody.json(SEND_MESSAGE_TEST_REQUEST_WITH_MESSAGE_RESPONSE, MatchType.STRICT)) + + ) + .respond( + response() + .withStatusCode(200) + .withBody(SEND_MESSAGE_TEST_RESPONSE_WITH_MESSAGE_RESPONSE) + ); + + A2AClient client = new A2AClient("http://localhost:4001"); + Message message = new Message.Builder() + .role(Message.Role.USER) + .parts(Collections.singletonList(new TextPart("tell me a joke"))) + .contextId("context-1234") + .messageId("message-1234") + .build(); + MessageSendConfiguration configuration = new MessageSendConfiguration.Builder() + .acceptedOutputModes(List.of("text")) + .blocking(true) + .build(); + MessageSendParams params = new MessageSendParams.Builder() + .message(message) + .configuration(configuration) + .build(); + + SendMessageResponse response = client.sendMessage("request-1234-with-message-response", params); + + assertEquals("2.0", response.getJsonrpc()); + assertNotNull(response.getId()); + Object result = response.getResult(); + assertInstanceOf(Message.class, result); + Message agentMessage = (Message) result; + assertEquals(Message.Role.AGENT, agentMessage.getRole()); + Part part = agentMessage.getParts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("Why did the chicken cross the road? To get to the other side!", ((TextPart) part).getText()); + assertEquals("msg-456", agentMessage.getMessageId()); + } + + + @Test + public void testA2AClientSendMessageWithError() throws Exception { + this.server.when( + request() + .withMethod("POST") + .withPath("/") + .withBody(JsonBody.json(SEND_MESSAGE_WITH_ERROR_TEST_REQUEST, MatchType.STRICT)) + + ) + .respond( + response() + .withStatusCode(200) + .withBody(SEND_MESSAGE_ERROR_TEST_RESPONSE) + ); + + A2AClient client = new A2AClient("http://localhost:4001"); + Message message = new Message.Builder() + .role(Message.Role.USER) + .parts(Collections.singletonList(new TextPart("tell me a joke"))) + .contextId("context-1234") + .messageId("message-1234") + .build(); + MessageSendConfiguration configuration = new MessageSendConfiguration.Builder() + .acceptedOutputModes(List.of("text")) + .blocking(true) + .build(); + MessageSendParams params = new MessageSendParams.Builder() + .message(message) + .configuration(configuration) + .build(); + + try { + client.sendMessage("request-1234-with-error", params); + fail(); // should not reach here + } catch (A2AServerException e) { + assertTrue(e.getMessage().contains("Invalid parameters: Hello world")); + } + } + + @Test + public void testA2AClientGetTask() throws Exception { + this.server.when( + request() + .withMethod("POST") + .withPath("/") + .withBody(JsonBody.json(GET_TASK_TEST_REQUEST, MatchType.STRICT)) + + ) + .respond( + response() + .withStatusCode(200) + .withBody(GET_TASK_TEST_RESPONSE) + ); + + A2AClient client = new A2AClient("http://localhost:4001"); + GetTaskResponse response = client.getTask("request-1234", + new TaskQueryParams("de38c76d-d54c-436c-8b9f-4c2703648d64", 10)); + + assertEquals("2.0", response.getJsonrpc()); + assertEquals(1, response.getId()); + Object result = response.getResult(); + assertInstanceOf(Task.class, result); + Task task = (Task) result; + assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); + assertEquals("c295ea44-7543-4f78-b524-7a38915ad6e4", task.getContextId()); + assertEquals(TaskState.COMPLETED, task.getStatus().state()); + assertEquals(1, task.getArtifacts().size()); + Artifact artifact = task.getArtifacts().get(0); + assertEquals(1, artifact.parts().size()); + assertEquals("artifact-1", artifact.artifactId()); + Part part = artifact.parts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("Why did the chicken cross the road? To get to the other side!", ((TextPart) part).getText()); + assertTrue(task.getMetadata().isEmpty()); + List history = task.getHistory(); + assertNotNull(history); + assertEquals(1, history.size()); + Message message = history.get(0); + assertEquals(Message.Role.USER, message.getRole()); + List> parts = message.getParts(); + assertNotNull(parts); + assertEquals(3, parts.size()); + part = parts.get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("tell me a joke", ((TextPart)part).getText()); + part = parts.get(1); + assertEquals(Part.Kind.FILE, part.getKind()); + FileContent filePart = ((FilePart) part).getFile(); + assertEquals("file:///path/to/file.txt", ((FileWithUri) filePart).uri()); + assertEquals("text/plain", filePart.mimeType()); + part = parts.get(2); + assertEquals(Part.Kind.FILE, part.getKind()); + filePart = ((FilePart) part).getFile(); + assertEquals("aGVsbG8=", ((FileWithBytes) filePart).bytes()); + assertEquals("hello.txt", filePart.name()); + assertTrue(task.getMetadata().isEmpty()); + } + + @Test + public void testA2AClientCancelTask() throws Exception { + this.server.when( + request() + .withMethod("POST") + .withPath("/") + .withBody(JsonBody.json(CANCEL_TASK_TEST_REQUEST, MatchType.STRICT)) + + ) + .respond( + response() + .withStatusCode(200) + .withBody(CANCEL_TASK_TEST_RESPONSE) + ); + + A2AClient client = new A2AClient("http://localhost:4001"); + CancelTaskResponse response = client.cancelTask("request-1234", + new TaskIdParams("de38c76d-d54c-436c-8b9f-4c2703648d64", new HashMap<>())); + + assertEquals("2.0", response.getJsonrpc()); + assertEquals(1, response.getId()); + Object result = response.getResult(); + assertInstanceOf(Task.class, result); + Task task = (Task) result; + assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); + assertEquals("c295ea44-7543-4f78-b524-7a38915ad6e4", task.getContextId()); + assertEquals(TaskState.CANCELED, task.getStatus().state()); + assertTrue(task.getMetadata().isEmpty()); + } + + @Test + public void testA2AClientGetTaskPushNotificationConfig() throws Exception { + this.server.when( + request() + .withMethod("POST") + .withPath("/") + .withBody(JsonBody.json(GET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_REQUEST, MatchType.STRICT)) + + ) + .respond( + response() + .withStatusCode(200) + .withBody(GET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE) + ); + + A2AClient client = new A2AClient("http://localhost:4001"); + GetTaskPushNotificationConfigResponse response = client.getTaskPushNotificationConfig("1", + new TaskIdParams("de38c76d-d54c-436c-8b9f-4c2703648d64", new HashMap<>())); + assertEquals("2.0", response.getJsonrpc()); + assertEquals(1, response.getId()); + assertInstanceOf(TaskPushNotificationConfig.class, response.getResult()); + TaskPushNotificationConfig taskPushNotificationConfig = (TaskPushNotificationConfig) response.getResult(); + PushNotificationConfig pushNotificationConfig = taskPushNotificationConfig.pushNotificationConfig(); + assertNotNull(pushNotificationConfig); + assertEquals("https://example.com/callback", pushNotificationConfig.url()); + PushNotificationAuthenticationInfo authenticationInfo = pushNotificationConfig.authentication(); + assertTrue(authenticationInfo.schemes().size() == 1); + assertEquals("jwt", authenticationInfo.schemes().get(0)); + } + + @Test + public void testA2AClientSetTaskPushNotificationConfig() throws Exception { + this.server.when( + request() + .withMethod("POST") + .withPath("/") + .withBody(JsonBody.json(SET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_REQUEST, MatchType.STRICT)) + + ) + .respond( + response() + .withStatusCode(200) + .withBody(SET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE) + ); + + A2AClient client = new A2AClient("http://localhost:4001"); + SetTaskPushNotificationConfigResponse response = client.setTaskPushNotificationConfig("1", + "de38c76d-d54c-436c-8b9f-4c2703648d64", + new PushNotificationConfig.Builder() + .url("https://example.com/callback") + .authenticationInfo(new PushNotificationAuthenticationInfo(Collections.singletonList("jwt"), null)) + .build()); + assertEquals("2.0", response.getJsonrpc()); + assertEquals(1, response.getId()); + assertInstanceOf(TaskPushNotificationConfig.class, response.getResult()); + TaskPushNotificationConfig taskPushNotificationConfig = (TaskPushNotificationConfig) response.getResult(); + PushNotificationConfig pushNotificationConfig = taskPushNotificationConfig.pushNotificationConfig(); + assertNotNull(pushNotificationConfig); + assertEquals("https://example.com/callback", pushNotificationConfig.url()); + PushNotificationAuthenticationInfo authenticationInfo = pushNotificationConfig.authentication(); + assertTrue(authenticationInfo.schemes().size() == 1); + assertEquals("jwt", authenticationInfo.schemes().get(0)); + } + + + @Test + public void testA2AClientGetAgentCard() throws Exception { + this.server.when( + request() + .withMethod("GET") + .withPath("/.well-known/agent.json") + ) + .respond( + response() + .withStatusCode(200) + .withBody(AGENT_CARD) + ); + + A2AClient client = new A2AClient("http://localhost:4001"); + AgentCard agentCard = client.getAgentCard(); + assertEquals("GeoSpatial Route Planner Agent", agentCard.name()); + assertEquals("Provides advanced route planning, traffic analysis, and custom map generation services. This agent can calculate optimal routes, estimate travel times considering real-time traffic, and create personalized maps with points of interest.", agentCard.description()); + assertEquals("https://georoute-agent.example.com/a2a/v1", agentCard.url()); + assertEquals("Example Geo Services Inc.", agentCard.provider().organization()); + assertEquals("https://www.examplegeoservices.com", agentCard.provider().url()); + assertEquals("1.2.0", agentCard.version()); + assertEquals("https://docs.examplegeoservices.com/georoute-agent/api", agentCard.documentationUrl()); + assertTrue(agentCard.capabilities().streaming()); + assertTrue(agentCard.capabilities().pushNotifications()); + assertFalse(agentCard.capabilities().stateTransitionHistory()); + Map securitySchemes = agentCard.securitySchemes(); + assertNotNull(securitySchemes); + OpenIdConnectSecurityScheme google = (OpenIdConnectSecurityScheme) securitySchemes.get("google"); + assertEquals("openIdConnect", google.getType()); + assertEquals("https://accounts.google.com/.well-known/openid-configuration", google.getOpenIdConnectUrl()); + List>> security = agentCard.security(); + assertEquals(1, security.size()); + Map> securityMap = security.get(0); + List scopes = securityMap.get("google"); + List expectedScopes = List.of("openid", "profile", "email"); + assertEquals(expectedScopes, scopes); + List defaultInputModes = List.of("application/json", "text/plain"); + assertEquals(defaultInputModes, agentCard.defaultInputModes()); + List defaultOutputModes = List.of("application/json", "image/png"); + assertEquals(defaultOutputModes, agentCard.defaultOutputModes()); + List skills = agentCard.skills(); + assertEquals("route-optimizer-traffic", skills.get(0).id()); + assertEquals("Traffic-Aware Route Optimizer", skills.get(0).name()); + assertEquals("Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).", skills.get(0).description()); + List tags = List.of("maps", "routing", "navigation", "directions", "traffic"); + assertEquals(tags, skills.get(0).tags()); + List examples = List.of("Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", + "{\"origin\": {\"lat\": 37.422, \"lng\": -122.084}, \"destination\": {\"lat\": 37.7749, \"lng\": -122.4194}, \"preferences\": [\"avoid_ferries\"]}"); + assertEquals(examples, skills.get(0).examples()); + assertEquals(defaultInputModes, skills.get(0).inputModes()); + List outputModes = List.of("application/json", "application/vnd.geo+json", "text/html"); + assertEquals(outputModes, skills.get(0).outputModes()); + assertEquals("custom-map-generator", skills.get(1).id()); + assertEquals("Personalized Map Generator", skills.get(1).name()); + assertEquals("Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.", skills.get(1).description()); + tags = List.of("maps", "customization", "visualization", "cartography"); + assertEquals(tags, skills.get(1).tags()); + examples = List.of("Generate a map of my upcoming road trip with all planned stops highlighted.", + "Show me a map visualizing all coffee shops within a 1-mile radius of my current location."); + assertEquals(examples, skills.get(1).examples()); + List inputModes = List.of("application/json"); + assertEquals(inputModes, skills.get(1).inputModes()); + outputModes = List.of("image/png", "image/jpeg", "application/json", "text/html"); + assertEquals(outputModes, skills.get(1).outputModes()); + assertTrue(agentCard.supportsAuthenticatedExtendedCard()); + assertEquals("https://georoute-agent.example.com/icon.png", agentCard.iconUrl()); + } + + @Test + public void testA2AClientGetAuthenticatedExtendedAgentCard() throws Exception { + this.server.when( + request() + .withMethod("GET") + .withPath("/agent/authenticatedExtendedCard") + ) + .respond( + response() + .withStatusCode(200) + .withBody(AUTHENTICATION_EXTENDED_AGENT_CARD) + ); + + A2AClient client = new A2AClient("http://localhost:4001"); + AgentCard agentCard = client.getAgentCard("/agent/authenticatedExtendedCard", null); + assertEquals("GeoSpatial Route Planner Agent Extended", agentCard.name()); + assertEquals("Extended description", agentCard.description()); + assertEquals("https://georoute-agent.example.com/a2a/v1", agentCard.url()); + assertEquals("Example Geo Services Inc.", agentCard.provider().organization()); + assertEquals("https://www.examplegeoservices.com", agentCard.provider().url()); + assertEquals("1.2.0", agentCard.version()); + assertEquals("https://docs.examplegeoservices.com/georoute-agent/api", agentCard.documentationUrl()); + assertTrue(agentCard.capabilities().streaming()); + assertTrue(agentCard.capabilities().pushNotifications()); + assertFalse(agentCard.capabilities().stateTransitionHistory()); + Map securitySchemes = agentCard.securitySchemes(); + assertNotNull(securitySchemes); + OpenIdConnectSecurityScheme google = (OpenIdConnectSecurityScheme) securitySchemes.get("google"); + assertEquals("openIdConnect", google.getType()); + assertEquals("https://accounts.google.com/.well-known/openid-configuration", google.getOpenIdConnectUrl()); + List>> security = agentCard.security(); + assertEquals(1, security.size()); + Map> securityMap = security.get(0); + List scopes = securityMap.get("google"); + List expectedScopes = List.of("openid", "profile", "email"); + assertEquals(expectedScopes, scopes); + List defaultInputModes = List.of("application/json", "text/plain"); + assertEquals(defaultInputModes, agentCard.defaultInputModes()); + List defaultOutputModes = List.of("application/json", "image/png"); + assertEquals(defaultOutputModes, agentCard.defaultOutputModes()); + List skills = agentCard.skills(); + assertEquals("route-optimizer-traffic", skills.get(0).id()); + assertEquals("Traffic-Aware Route Optimizer", skills.get(0).name()); + assertEquals("Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).", skills.get(0).description()); + List tags = List.of("maps", "routing", "navigation", "directions", "traffic"); + assertEquals(tags, skills.get(0).tags()); + List examples = List.of("Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", + "{\"origin\": {\"lat\": 37.422, \"lng\": -122.084}, \"destination\": {\"lat\": 37.7749, \"lng\": -122.4194}, \"preferences\": [\"avoid_ferries\"]}"); + assertEquals(examples, skills.get(0).examples()); + assertEquals(defaultInputModes, skills.get(0).inputModes()); + List outputModes = List.of("application/json", "application/vnd.geo+json", "text/html"); + assertEquals(outputModes, skills.get(0).outputModes()); + assertEquals("custom-map-generator", skills.get(1).id()); + assertEquals("Personalized Map Generator", skills.get(1).name()); + assertEquals("Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.", skills.get(1).description()); + tags = List.of("maps", "customization", "visualization", "cartography"); + assertEquals(tags, skills.get(1).tags()); + examples = List.of("Generate a map of my upcoming road trip with all planned stops highlighted.", + "Show me a map visualizing all coffee shops within a 1-mile radius of my current location."); + assertEquals(examples, skills.get(1).examples()); + List inputModes = List.of("application/json"); + assertEquals(inputModes, skills.get(1).inputModes()); + outputModes = List.of("image/png", "image/jpeg", "application/json", "text/html"); + assertEquals(outputModes, skills.get(1).outputModes()); + assertEquals("skill-extended", skills.get(2).id()); + assertEquals("Extended Skill", skills.get(2).name()); + assertEquals("This is an extended skill.", skills.get(2).description()); + assertEquals(List.of("extended"), skills.get(2).tags()); + assertTrue(agentCard.supportsAuthenticatedExtendedCard()); + assertEquals("https://georoute-agent.example.com/icon.png", agentCard.iconUrl()); + } + + @Test + public void testA2AClientSendMessageWithFilePart() throws Exception { + this.server.when( + request() + .withMethod("POST") + .withPath("/") + .withBody(JsonBody.json(SEND_MESSAGE_WITH_FILE_PART_TEST_REQUEST, MatchType.STRICT)) + + ) + .respond( + response() + .withStatusCode(200) + .withBody(SEND_MESSAGE_WITH_FILE_PART_TEST_RESPONSE) + ); + + A2AClient client = new A2AClient("http://localhost:4001"); + Message message = new Message.Builder() + .role(Message.Role.USER) + .parts(List.of( + new TextPart("analyze this image"), + new FilePart(new FileWithUri("image/jpeg", null, "file:///path/to/image.jpg")) + )) + .contextId("context-1234") + .messageId("message-1234-with-file") + .build(); + MessageSendConfiguration configuration = new MessageSendConfiguration.Builder() + .acceptedOutputModes(List.of("text")) + .blocking(true) + .build(); + MessageSendParams params = new MessageSendParams.Builder() + .message(message) + .configuration(configuration) + .build(); + + SendMessageResponse response = client.sendMessage("request-1234-with-file", params); + + assertEquals("2.0", response.getJsonrpc()); + assertNotNull(response.getId()); + Object result = response.getResult(); + assertInstanceOf(Task.class, result); + Task task = (Task) result; + assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); + assertNotNull(task.getContextId()); + assertEquals(TaskState.COMPLETED, task.getStatus().state()); + assertEquals(1, task.getArtifacts().size()); + Artifact artifact = task.getArtifacts().get(0); + assertEquals("artifact-1", artifact.artifactId()); + assertEquals("image-analysis", artifact.name()); + assertEquals(1, artifact.parts().size()); + Part part = artifact.parts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("This is an image of a cat sitting on a windowsill.", ((TextPart) part).getText()); + assertTrue(task.getMetadata().isEmpty()); + } + + @Test + public void testA2AClientSendMessageWithDataPart() throws Exception { + this.server.when( + request() + .withMethod("POST") + .withPath("/") + .withBody(JsonBody.json(SEND_MESSAGE_WITH_DATA_PART_TEST_REQUEST, MatchType.STRICT)) + + ) + .respond( + response() + .withStatusCode(200) + .withBody(SEND_MESSAGE_WITH_DATA_PART_TEST_RESPONSE) + ); + + A2AClient client = new A2AClient("http://localhost:4001"); + + Map data = new HashMap<>(); + data.put("temperature", 25.5); + data.put("humidity", 60.2); + data.put("location", "San Francisco"); + data.put("timestamp", "2024-01-15T10:30:00Z"); + + Message message = new Message.Builder() + .role(Message.Role.USER) + .parts(List.of( + new TextPart("process this data"), + new DataPart(data) + )) + .contextId("context-1234") + .messageId("message-1234-with-data") + .build(); + MessageSendConfiguration configuration = new MessageSendConfiguration.Builder() + .acceptedOutputModes(List.of("text")) + .blocking(true) + .build(); + MessageSendParams params = new MessageSendParams.Builder() + .message(message) + .configuration(configuration) + .build(); + + SendMessageResponse response = client.sendMessage("request-1234-with-data", params); + + assertEquals("2.0", response.getJsonrpc()); + assertNotNull(response.getId()); + Object result = response.getResult(); + assertInstanceOf(Task.class, result); + Task task = (Task) result; + assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); + assertNotNull(task.getContextId()); + assertEquals(TaskState.COMPLETED, task.getStatus().state()); + assertEquals(1, task.getArtifacts().size()); + Artifact artifact = task.getArtifacts().get(0); + assertEquals("artifact-1", artifact.artifactId()); + assertEquals("data-analysis", artifact.name()); + assertEquals(1, artifact.parts().size()); + Part part = artifact.parts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("Processed weather data: Temperature is 25.5°C, humidity is 60.2% in San Francisco.", ((TextPart) part).getText()); + assertTrue(task.getMetadata().isEmpty()); + } + + @Test + public void testA2AClientSendMessageWithMixedParts() throws Exception { + this.server.when( + request() + .withMethod("POST") + .withPath("/") + .withBody(JsonBody.json(SEND_MESSAGE_WITH_MIXED_PARTS_TEST_REQUEST, MatchType.STRICT)) + + ) + .respond( + response() + .withStatusCode(200) + .withBody(SEND_MESSAGE_WITH_MIXED_PARTS_TEST_RESPONSE) + ); + + A2AClient client = new A2AClient("http://localhost:4001"); + + Map data = new HashMap<>(); + data.put("chartType", "bar"); + data.put("dataPoints", List.of(10, 20, 30, 40)); + data.put("labels", List.of("Q1", "Q2", "Q3", "Q4")); + + Message message = new Message.Builder() + .role(Message.Role.USER) + .parts(List.of( + new TextPart("analyze this data and image"), + new FilePart(new FileWithBytes("image/png", "chart.png", "aGVsbG8=")), + new DataPart(data) + )) + .contextId("context-1234") + .messageId("message-1234-with-mixed") + .build(); + MessageSendConfiguration configuration = new MessageSendConfiguration.Builder() + .acceptedOutputModes(List.of("text")) + .blocking(true) + .build(); + MessageSendParams params = new MessageSendParams.Builder() + .message(message) + .configuration(configuration) + .build(); + + SendMessageResponse response = client.sendMessage("request-1234-with-mixed", params); + + assertEquals("2.0", response.getJsonrpc()); + assertNotNull(response.getId()); + Object result = response.getResult(); + assertInstanceOf(Task.class, result); + Task task = (Task) result; + assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); + assertNotNull(task.getContextId()); + assertEquals(TaskState.COMPLETED, task.getStatus().state()); + assertEquals(1, task.getArtifacts().size()); + Artifact artifact = task.getArtifacts().get(0); + assertEquals("artifact-1", artifact.artifactId()); + assertEquals("mixed-analysis", artifact.name()); + assertEquals(1, artifact.parts().size()); + Part part = artifact.parts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("Analyzed chart image and data: Bar chart showing quarterly data with values [10, 20, 30, 40].", ((TextPart) part).getText()); + assertTrue(task.getMetadata().isEmpty()); + } +} \ No newline at end of file diff --git a/core/src/test/java/io/a2a/client/JsonMessages.java b/core/src/test/java/io/a2a/client/JsonMessages.java new file mode 100644 index 000000000..c7ebd7780 --- /dev/null +++ b/core/src/test/java/io/a2a/client/JsonMessages.java @@ -0,0 +1,616 @@ +package io.a2a.client; + +/** + * Request and response messages used by the tests. These have been created following examples from + * the A2A sample messages. + */ +public class JsonMessages { + + static final String AGENT_CARD = """ + { + "name": "GeoSpatial Route Planner Agent", + "description": "Provides advanced route planning, traffic analysis, and custom map generation services. This agent can calculate optimal routes, estimate travel times considering real-time traffic, and create personalized maps with points of interest.", + "url": "https://georoute-agent.example.com/a2a/v1", + "provider": { + "organization": "Example Geo Services Inc.", + "url": "https://www.examplegeoservices.com" + }, + "iconUrl": "https://georoute-agent.example.com/icon.png", + "version": "1.2.0", + "documentationUrl": "https://docs.examplegeoservices.com/georoute-agent/api", + "capabilities": { + "streaming": true, + "pushNotifications": true, + "stateTransitionHistory": false + }, + "securitySchemes": { + "google": { + "type": "openIdConnect", + "openIdConnectUrl": "https://accounts.google.com/.well-known/openid-configuration" + } + }, + "security": [{ "google": ["openid", "profile", "email"] }], + "defaultInputModes": ["application/json", "text/plain"], + "defaultOutputModes": ["application/json", "image/png"], + "skills": [ + { + "id": "route-optimizer-traffic", + "name": "Traffic-Aware Route Optimizer", + "description": "Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).", + "tags": ["maps", "routing", "navigation", "directions", "traffic"], + "examples": [ + "Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", + "{\\"origin\\": {\\"lat\\": 37.422, \\"lng\\": -122.084}, \\"destination\\": {\\"lat\\": 37.7749, \\"lng\\": -122.4194}, \\"preferences\\": [\\"avoid_ferries\\"]}" + ], + "inputModes": ["application/json", "text/plain"], + "outputModes": [ + "application/json", + "application/vnd.geo+json", + "text/html" + ] + }, + { + "id": "custom-map-generator", + "name": "Personalized Map Generator", + "description": "Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.", + "tags": ["maps", "customization", "visualization", "cartography"], + "examples": [ + "Generate a map of my upcoming road trip with all planned stops highlighted.", + "Show me a map visualizing all coffee shops within a 1-mile radius of my current location." + ], + "inputModes": ["application/json"], + "outputModes": [ + "image/png", + "image/jpeg", + "application/json", + "text/html" + ] + } + ], + "supportsAuthenticatedExtendedCard": true + }"""; + + static final String AUTHENTICATION_EXTENDED_AGENT_CARD = """ + { + "name": "GeoSpatial Route Planner Agent Extended", + "description": "Extended description", + "url": "https://georoute-agent.example.com/a2a/v1", + "provider": { + "organization": "Example Geo Services Inc.", + "url": "https://www.examplegeoservices.com" + }, + "iconUrl": "https://georoute-agent.example.com/icon.png", + "version": "1.2.0", + "documentationUrl": "https://docs.examplegeoservices.com/georoute-agent/api", + "capabilities": { + "streaming": true, + "pushNotifications": true, + "stateTransitionHistory": false + }, + "securitySchemes": { + "google": { + "type": "openIdConnect", + "openIdConnectUrl": "https://accounts.google.com/.well-known/openid-configuration" + } + }, + "security": [{ "google": ["openid", "profile", "email"] }], + "defaultInputModes": ["application/json", "text/plain"], + "defaultOutputModes": ["application/json", "image/png"], + "skills": [ + { + "id": "route-optimizer-traffic", + "name": "Traffic-Aware Route Optimizer", + "description": "Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).", + "tags": ["maps", "routing", "navigation", "directions", "traffic"], + "examples": [ + "Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", + "{\\"origin\\": {\\"lat\\": 37.422, \\"lng\\": -122.084}, \\"destination\\": {\\"lat\\": 37.7749, \\"lng\\": -122.4194}, \\"preferences\\": [\\"avoid_ferries\\"]}" + ], + "inputModes": ["application/json", "text/plain"], + "outputModes": [ + "application/json", + "application/vnd.geo+json", + "text/html" + ] + }, + { + "id": "custom-map-generator", + "name": "Personalized Map Generator", + "description": "Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.", + "tags": ["maps", "customization", "visualization", "cartography"], + "examples": [ + "Generate a map of my upcoming road trip with all planned stops highlighted.", + "Show me a map visualizing all coffee shops within a 1-mile radius of my current location." + ], + "inputModes": ["application/json"], + "outputModes": [ + "image/png", + "image/jpeg", + "application/json", + "text/html" + ] + }, + { + "id": "skill-extended", + "name": "Extended Skill", + "description": "This is an extended skill.", + "tags": ["extended"] + } + ], + "supportsAuthenticatedExtendedCard": true + }"""; + + + static final String SEND_MESSAGE_TEST_REQUEST = """ + { + "jsonrpc": "2.0", + "id": "request-1234", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [ + { + "kind": "text", + "text": "tell me a joke" + } + ], + "messageId": "message-1234", + "contextId": "context-1234", + "kind": "message" + }, + "configuration": { + "acceptedOutputModes": ["text"], + "blocking": true + }, + } + }"""; + + static final String SEND_MESSAGE_TEST_RESPONSE = """ + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "id": "de38c76d-d54c-436c-8b9f-4c2703648d64", + "contextId": "c295ea44-7543-4f78-b524-7a38915ad6e4", + "status": { + "state": "completed" + }, + "artifacts": [ + { + "artifactId": "artifact-1", + "name": "joke", + "parts": [ + { + "kind": "text", + "text": "Why did the chicken cross the road? To get to the other side!" + } + ] + } + ], + "metadata": {}, + "kind": "task" + } + }"""; + + static final String SEND_MESSAGE_TEST_REQUEST_WITH_MESSAGE_RESPONSE = """ + { + "jsonrpc": "2.0", + "id": "request-1234-with-message-response", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [ + { + "kind": "text", + "text": "tell me a joke" + } + ], + "messageId": "message-1234", + "contextId": "context-1234", + "kind": "message" + }, + "configuration": { + "acceptedOutputModes": ["text"], + "blocking": true + }, + } + }"""; + + + static final String SEND_MESSAGE_TEST_RESPONSE_WITH_MESSAGE_RESPONSE = """ + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "role": "agent", + "parts": [ + { + "kind": "text", + "text": "Why did the chicken cross the road? To get to the other side!" + } + ], + "messageId": "msg-456", + "kind": "message" + } + }"""; + + static final String SEND_MESSAGE_WITH_ERROR_TEST_REQUEST = """ + { + "jsonrpc": "2.0", + "id": "request-1234-with-error", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [ + { + "kind": "text", + "text": "tell me a joke" + } + ], + "messageId": "message-1234", + "contextId": "context-1234", + "kind": "message" + }, + "configuration": { + "acceptedOutputModes": ["text"], + "blocking": true + }, + } + }"""; + + static final String SEND_MESSAGE_ERROR_TEST_RESPONSE = """ + { + "jsonrpc": "2.0", + "error": { + "code": -32702, + "message": "Invalid parameters", + "data": "Hello world" + } + }"""; + + static final String GET_TASK_TEST_REQUEST = """ + { + "jsonrpc": "2.0", + "id": "request-1234", + "method": "tasks/get", + "params": { + "id": "de38c76d-d54c-436c-8b9f-4c2703648d64", + "historyLength": 10 + } + } + """; + + static final String GET_TASK_TEST_RESPONSE = """ + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "id": "de38c76d-d54c-436c-8b9f-4c2703648d64", + "contextId": "c295ea44-7543-4f78-b524-7a38915ad6e4", + "status": { + "state": "completed" + }, + "artifacts": [ + { + "artifactId": "artifact-1", + "parts": [ + { + "kind": "text", + "text": "Why did the chicken cross the road? To get to the other side!" + } + ] + } + ], + "history": [ + { + "role": "user", + "parts": [ + { + "kind": "text", + "text": "tell me a joke" + }, + { + "kind": "file", + "file": { + "uri": "file:///path/to/file.txt", + "mimeType": "text/plain" + } + }, + { + "kind": "file", + "file": { + "bytes": "aGVsbG8=", + "name": "hello.txt" + } + } + ], + "messageId": "message-123", + "kind": "message" + } + ], + "metadata": {}, + "kind": "task" + } + } + """; + + static final String CANCEL_TASK_TEST_REQUEST = """ + { + "jsonrpc": "2.0", + "id": "request-1234", + "method": "tasks/cancel", + "params": { + "id": "de38c76d-d54c-436c-8b9f-4c2703648d64", + "metadata": {} + } + } + """; + + static final String CANCEL_TASK_TEST_RESPONSE = """ + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "id": "de38c76d-d54c-436c-8b9f-4c2703648d64", + "contextId": "c295ea44-7543-4f78-b524-7a38915ad6e4", + "status": { + "state": "canceled" + }, + "metadata": {}, + "kind" : "task" + } + } + """; + + static final String GET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_REQUEST = """ + { + "jsonrpc": "2.0", + "id": "1", + "method": "tasks/pushNotificationConfig/get", + "params": { + "id": "de38c76d-d54c-436c-8b9f-4c2703648d64", + "metadata": {}, + } + } + """; + + static final String GET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE = """ + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "taskId": "de38c76d-d54c-436c-8b9f-4c2703648d64", + "pushNotificationConfig": { + "url": "https://example.com/callback", + "authentication": { + "schemes": ["jwt"] + } + } + } + } + """; + + static final String SET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_REQUEST = """ + { + "jsonrpc": "2.0", + "id": "1", + "method": "tasks/pushNotificationConfig/set", + "params": { + "taskId": "de38c76d-d54c-436c-8b9f-4c2703648d64", + "pushNotificationConfig": { + "url": "https://example.com/callback", + "authentication": { + "schemes": ["jwt"] + } + } + } + }"""; + + static final String SET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE = """ + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "taskId": "de38c76d-d54c-436c-8b9f-4c2703648d64", + "pushNotificationConfig": { + "url": "https://example.com/callback", + "authentication": { + "schemes": ["jwt"] + } + } + } + } + """; + + static final String SEND_MESSAGE_WITH_FILE_PART_TEST_REQUEST = """ + { + "jsonrpc": "2.0", + "id": "request-1234-with-file", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [ + { + "kind": "text", + "text": "analyze this image" + }, + { + "kind": "file", + "file": { + "uri": "file:///path/to/image.jpg", + "mimeType": "image/jpeg" + } + } + ], + "messageId": "message-1234-with-file", + "contextId": "context-1234", + "kind": "message" + }, + "configuration": { + "acceptedOutputModes": ["text"], + "blocking": true + } + } + }"""; + + static final String SEND_MESSAGE_WITH_FILE_PART_TEST_RESPONSE = """ + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "id": "de38c76d-d54c-436c-8b9f-4c2703648d64", + "contextId": "c295ea44-7543-4f78-b524-7a38915ad6e4", + "status": { + "state": "completed" + }, + "artifacts": [ + { + "artifactId": "artifact-1", + "name": "image-analysis", + "parts": [ + { + "kind": "text", + "text": "This is an image of a cat sitting on a windowsill." + } + ] + } + ], + "metadata": {}, + "kind": "task" + } + }"""; + + static final String SEND_MESSAGE_WITH_DATA_PART_TEST_REQUEST = """ + { + "jsonrpc": "2.0", + "id": "request-1234-with-data", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [ + { + "kind": "text", + "text": "process this data" + }, + { + "kind": "data", + "data": { + "temperature": 25.5, + "humidity": 60.2, + "location": "San Francisco", + "timestamp": "2024-01-15T10:30:00Z" + } + } + ], + "messageId": "message-1234-with-data", + "contextId": "context-1234", + "kind": "message" + }, + "configuration": { + "acceptedOutputModes": ["text"], + "blocking": true + } + } + }"""; + + static final String SEND_MESSAGE_WITH_DATA_PART_TEST_RESPONSE = """ + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "id": "de38c76d-d54c-436c-8b9f-4c2703648d64", + "contextId": "c295ea44-7543-4f78-b524-7a38915ad6e4", + "status": { + "state": "completed" + }, + "artifacts": [ + { + "artifactId": "artifact-1", + "name": "data-analysis", + "parts": [ + { + "kind": "text", + "text": "Processed weather data: Temperature is 25.5°C, humidity is 60.2% in San Francisco." + } + ] + } + ], + "metadata": {}, + "kind": "task" + } + }"""; + + static final String SEND_MESSAGE_WITH_MIXED_PARTS_TEST_REQUEST = """ + { + "jsonrpc": "2.0", + "id": "request-1234-with-mixed", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [ + { + "kind": "text", + "text": "analyze this data and image" + }, + { + "kind": "file", + "file": { + "bytes": "aGVsbG8=", + "name": "chart.png", + "mimeType": "image/png" + } + }, + { + "kind": "data", + "data": { + "chartType": "bar", + "dataPoints": [10, 20, 30, 40], + "labels": ["Q1", "Q2", "Q3", "Q4"] + } + } + ], + "messageId": "message-1234-with-mixed", + "contextId": "context-1234", + "kind": "message" + }, + "configuration": { + "acceptedOutputModes": ["text"], + "blocking": true + } + } + }"""; + + static final String SEND_MESSAGE_WITH_MIXED_PARTS_TEST_RESPONSE = """ + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "id": "de38c76d-d54c-436c-8b9f-4c2703648d64", + "contextId": "c295ea44-7543-4f78-b524-7a38915ad6e4", + "status": { + "state": "completed" + }, + "artifacts": [ + { + "artifactId": "artifact-1", + "name": "mixed-analysis", + "parts": [ + { + "kind": "text", + "text": "Analyzed chart image and data: Bar chart showing quarterly data with values [10, 20, 30, 40]." + } + ] + } + ], + "metadata": {}, + "kind": "task" + } + }"""; + +} diff --git a/core/src/test/java/io/a2a/client/JsonStreamingMessages.java b/core/src/test/java/io/a2a/client/JsonStreamingMessages.java new file mode 100644 index 000000000..cf80de7b8 --- /dev/null +++ b/core/src/test/java/io/a2a/client/JsonStreamingMessages.java @@ -0,0 +1,148 @@ +package io.a2a.client; + +/** + * Contains JSON strings for testing SSE streaming. + */ +public class JsonStreamingMessages { + + public static final String STREAMING_TASK_EVENT = """ + data: { + "jsonrpc": "2.0", + "id": "1234", + "result": { + "kind": "task", + "id": "task-123", + "contextId": "context-456", + "status": { + "state": "working" + } + } + } + """; + + + public static final String STREAMING_MESSAGE_EVENT = """ + data: { + "jsonrpc": "2.0", + "id": "1234", + "result": { + "kind": "message", + "role": "agent", + "messageId": "msg-123", + "contextId": "context-456", + "parts": [ + { + "kind": "text", + "text": "Hello, world!" + } + ] + } + }"""; + + public static final String STREAMING_STATUS_UPDATE_EVENT = """ + data: { + "jsonrpc": "2.0", + "id": "1234", + "result": { + "taskId": "1", + "contextId": "2", + "status": { + "state": "submitted" + }, + "final": false, + "kind": "status-update" + } + }"""; + + public static final String STREAMING_STATUS_UPDATE_EVENT_FINAL = """ + data: { + "jsonrpc": "2.0", + "id": "1234", + "result": { + "taskId": "1", + "contextId": "2", + "status": { + "state": "completed" + }, + "final": true, + "kind": "status-update" + } + }"""; + + public static final String STREAMING_ARTIFACT_UPDATE_EVENT = """ + data: { + "jsonrpc": "2.0", + "id": "1234", + "result": { + "kind": "artifact-update", + "taskId": "1", + "contextId": "2", + "append": false, + "lastChunk": true, + "artifact": { + "artifactId": "artifact-1", + "parts": [ + { + "kind": "text", + "text": "Why did the chicken cross the road? To get to the other side!" + } + ] + } + } + } + }"""; + + public static final String STREAMING_ERROR_EVENT = """ + data: { + "jsonrpc": "2.0", + "id": "1234", + "error": { + "code": -32602, + "message": "Invalid parameters", + "data": "Missing required field" + } + }"""; + + public static final String SEND_MESSAGE_STREAMING_TEST_REQUEST = """ + { + "jsonrpc": "2.0", + "id": "request-1234", + "method": "message/stream", + "params": { + "message": { + "role": "user", + "parts": [ + { + "kind": "text", + "text": "tell me some jokes" + } + ], + "messageId": "message-1234", + "contextId": "context-1234", + "kind": "message" + }, + "configuration": { + "acceptedOutputModes": ["text"], + "blocking": false + }, + } + }"""; + + static final String SEND_MESSAGE_STREAMING_TEST_RESPONSE = + "event: message\n" + + "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"id\":\"2\",\"contextId\":\"context-1234\",\"status\":{\"state\":\"completed\"},\"artifacts\":[{\"artifactId\":\"artifact-1\",\"name\":\"joke\",\"parts\":[{\"kind\":\"text\",\"text\":\"Why did the chicken cross the road? To get to the other side!\"}]}],\"metadata\":{},\"kind\":\"task\"}}\n\n"; + + static final String TASK_RESUBSCRIPTION_REQUEST_TEST_RESPONSE = + "event: message\n" + + "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"id\":\"2\",\"contextId\":\"context-1234\",\"status\":{\"state\":\"completed\"},\"artifacts\":[{\"artifactId\":\"artifact-1\",\"name\":\"joke\",\"parts\":[{\"kind\":\"text\",\"text\":\"Why did the chicken cross the road? To get to the other side!\"}]}],\"metadata\":{},\"kind\":\"task\"}}\n\n"; + + public static final String TASK_RESUBSCRIPTION_TEST_REQUEST = """ + { + "jsonrpc": "2.0", + "id": "request-1234", + "method": "tasks/resubscribe", + "params": { + "id": "task-1234" + } + }"""; +} \ No newline at end of file diff --git a/core/src/test/java/io/a2a/client/sse/SSEEventListenerTest.java b/core/src/test/java/io/a2a/client/sse/SSEEventListenerTest.java new file mode 100644 index 000000000..1fca0ff9c --- /dev/null +++ b/core/src/test/java/io/a2a/client/sse/SSEEventListenerTest.java @@ -0,0 +1,264 @@ +package io.a2a.client.sse; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import io.a2a.client.JsonStreamingMessages; +import io.a2a.spec.Artifact; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.Message; +import io.a2a.spec.Part; +import io.a2a.spec.StreamingEventKind; +import io.a2a.spec.Task; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; +import io.a2a.spec.TextPart; +import org.junit.jupiter.api.Test; + +public class SSEEventListenerTest { + + @Test + public void testOnEventWithTaskResult() throws Exception { + // Set up event handler + AtomicReference receivedEvent = new AtomicReference<>(); + SSEEventListener listener = new SSEEventListener( + event -> receivedEvent.set(event), + error -> {}, + () -> {}); + + // Parse the task event JSON + String eventData = JsonStreamingMessages.STREAMING_TASK_EVENT.substring( + JsonStreamingMessages.STREAMING_TASK_EVENT.indexOf("{")); + + // Call the onEvent method directly + listener.onMessage(eventData, null); + + // Verify the event was processed correctly + assertNotNull(receivedEvent.get()); + assertTrue(receivedEvent.get() instanceof Task); + Task task = (Task) receivedEvent.get(); + assertEquals("task-123", task.getId()); + assertEquals("context-456", task.getContextId()); + assertEquals(TaskState.WORKING, task.getStatus().state()); + } + + @Test + public void testOnEventWithMessageResult() throws Exception { + // Set up event handler + AtomicReference receivedEvent = new AtomicReference<>(); + SSEEventListener listener = new SSEEventListener( + event -> receivedEvent.set(event), + error -> {}, + () -> {}); + + // Parse the message event JSON + String eventData = JsonStreamingMessages.STREAMING_MESSAGE_EVENT.substring( + JsonStreamingMessages.STREAMING_MESSAGE_EVENT.indexOf("{")); + + // Call onEvent method + listener.onMessage(eventData, null); + + // Verify the event was processed correctly + assertNotNull(receivedEvent.get()); + assertTrue(receivedEvent.get() instanceof Message); + Message message = (Message) receivedEvent.get(); + assertEquals(Message.Role.AGENT, message.getRole()); + assertEquals("msg-123", message.getMessageId()); + assertEquals("context-456", message.getContextId()); + assertEquals(1, message.getParts().size()); + assertTrue(message.getParts().get(0) instanceof TextPart); + assertEquals("Hello, world!", ((TextPart) message.getParts().get(0)).getText()); + } + + @Test + public void testOnEventWithTaskStatusUpdateEventEvent() throws Exception { + // Set up event handler + AtomicReference receivedEvent = new AtomicReference<>(); + SSEEventListener listener = new SSEEventListener( + event -> receivedEvent.set(event), + error -> {}, + () -> {}); + + // Parse the message event JSON + String eventData = JsonStreamingMessages.STREAMING_STATUS_UPDATE_EVENT.substring( + JsonStreamingMessages.STREAMING_STATUS_UPDATE_EVENT.indexOf("{")); + + // Call onEvent method + listener.onMessage(eventData, null); + + // Verify the event was processed correctly + assertNotNull(receivedEvent.get()); + assertTrue(receivedEvent.get() instanceof TaskStatusUpdateEvent); + TaskStatusUpdateEvent taskStatusUpdateEvent = (TaskStatusUpdateEvent) receivedEvent.get(); + assertEquals("1", taskStatusUpdateEvent.getTaskId()); + assertEquals("2", taskStatusUpdateEvent.getContextId()); + assertFalse(taskStatusUpdateEvent.isFinal()); + assertEquals(TaskState.SUBMITTED, taskStatusUpdateEvent.getStatus().state()); + } + + @Test + public void testOnEventWithTaskArtifactUpdateEventEvent() throws Exception { + // Set up event handler + AtomicReference receivedEvent = new AtomicReference<>(); + SSEEventListener listener = new SSEEventListener( + event -> receivedEvent.set(event), + error -> {}, + () -> {}); + + // Parse the message event JSON + String eventData = JsonStreamingMessages.STREAMING_ARTIFACT_UPDATE_EVENT.substring( + JsonStreamingMessages.STREAMING_ARTIFACT_UPDATE_EVENT.indexOf("{")); + + // Call onEvent method + listener.onMessage(eventData, null); + + // Verify the event was processed correctly + assertNotNull(receivedEvent.get()); + assertTrue(receivedEvent.get() instanceof TaskArtifactUpdateEvent); + + TaskArtifactUpdateEvent taskArtifactUpdateEvent = (TaskArtifactUpdateEvent) receivedEvent.get(); + assertEquals("1", taskArtifactUpdateEvent.getTaskId()); + assertEquals("2", taskArtifactUpdateEvent.getContextId()); + assertFalse(taskArtifactUpdateEvent.isAppend()); + assertTrue(taskArtifactUpdateEvent.isLastChunk()); + Artifact artifact = taskArtifactUpdateEvent.getArtifact(); + assertEquals("artifact-1", artifact.artifactId()); + assertEquals(1, artifact.parts().size()); + assertEquals(Part.Kind.TEXT, artifact.parts().get(0).getKind()); + assertEquals("Why did the chicken cross the road? To get to the other side!", ((TextPart) artifact.parts().get(0)).getText()); + } + + @Test + public void testOnEventWithError() throws Exception { + // Set up event handler + AtomicReference receivedError = new AtomicReference<>(); + SSEEventListener listener = new SSEEventListener( + event -> {}, + error -> receivedError.set(error), + () -> {}); + + // Parse the error event JSON + String eventData = JsonStreamingMessages.STREAMING_ERROR_EVENT.substring( + JsonStreamingMessages.STREAMING_ERROR_EVENT.indexOf("{")); + + // Call onEvent method + listener.onMessage(eventData, null); + + // Verify the error was processed correctly + assertNotNull(receivedError.get()); + assertEquals(-32602, receivedError.get().getCode()); + assertEquals("Invalid parameters", receivedError.get().getMessage()); + assertEquals("Missing required field", receivedError.get().getData()); + } + + @Test + public void testOnFailure() { + AtomicBoolean failureHandlerCalled = new AtomicBoolean(false); + SSEEventListener listener = new SSEEventListener( + event -> {}, + error -> {}, + () -> failureHandlerCalled.set(true)); + + // Simulate a failure + CancelCapturingFuture future = new CancelCapturingFuture(); + listener.onError(new RuntimeException("Test exception"), future); + + // Verify the failure handler was called + assertTrue(failureHandlerCalled.get()); + // Verify it got cancelled + assertTrue(future.cancelHandlerCalled); + } + + @Test + public void testFinalTaskStatusUpdateEventCancels() { + TaskStatusUpdateEvent tsue = new TaskStatusUpdateEvent.Builder() + .taskId("1234") + .contextId("xyz") + .status(new TaskStatus(TaskState.COMPLETED)) + .isFinal(true) + .build(); + + // Set up event handler + AtomicReference receivedEvent = new AtomicReference<>(); + SSEEventListener listener = new SSEEventListener( + event -> receivedEvent.set(event), + error -> {}, + () -> {}); + + + } + + @Test + public void testOnEventWithFinalTaskStatusUpdateEventEventCancels() throws Exception { + // Set up event handler + AtomicReference receivedEvent = new AtomicReference<>(); + SSEEventListener listener = new SSEEventListener( + event -> receivedEvent.set(event), + error -> {}, + () -> {}); + + // Parse the message event JSON + String eventData = JsonStreamingMessages.STREAMING_STATUS_UPDATE_EVENT_FINAL.substring( + JsonStreamingMessages.STREAMING_STATUS_UPDATE_EVENT_FINAL.indexOf("{")); + + // Call onEvent method + CancelCapturingFuture future = new CancelCapturingFuture(); + listener.onMessage(eventData, future); + + // Verify the event was processed correctly + assertNotNull(receivedEvent.get()); + assertTrue(receivedEvent.get() instanceof TaskStatusUpdateEvent); + TaskStatusUpdateEvent taskStatusUpdateEvent = (TaskStatusUpdateEvent) receivedEvent.get(); + assertEquals("1", taskStatusUpdateEvent.getTaskId()); + assertEquals("2", taskStatusUpdateEvent.getContextId()); + assertTrue(taskStatusUpdateEvent.isFinal()); + assertEquals(TaskState.COMPLETED, taskStatusUpdateEvent.getStatus().state()); + + assertTrue(future.cancelHandlerCalled); + } + + + private static class CancelCapturingFuture implements Future { + private boolean cancelHandlerCalled; + + public CancelCapturingFuture() { + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + cancelHandlerCalled = true; + return true; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public boolean isDone() { + return false; + } + + @Override + public Void get() throws InterruptedException, ExecutionException { + return null; + } + + @Override + public Void get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + return null; + } + } +} \ No newline at end of file diff --git a/examples/client/pom.xml b/examples/client/pom.xml new file mode 100644 index 000000000..4626af575 --- /dev/null +++ b/examples/client/pom.xml @@ -0,0 +1,41 @@ + + + 4.0.0 + + + io.a2a.sdk + a2a-java-sdk-examples-parent + 0.2.4-SNAPSHOT + + + a2a-java-sdk-examples-client + + Java SDK A2A Examples + Examples for the Java SDK for the Agent2Agent Protocol (A2A) + + + + io.a2a.sdk + a2a-java-sdk-core + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + ${maven-compiler-plugin.version} + + + + + org.apache.maven.plugins + maven-surefire-plugin + ${maven-surefire-plugin.version} + + + + \ No newline at end of file diff --git a/examples/client/src/main/java/io/a2a/examples/helloworld/HelloWorldClient.java b/examples/client/src/main/java/io/a2a/examples/helloworld/HelloWorldClient.java new file mode 100644 index 000000000..cea945492 --- /dev/null +++ b/examples/client/src/main/java/io/a2a/examples/helloworld/HelloWorldClient.java @@ -0,0 +1,60 @@ +package io.a2a.examples.helloworld; + +import java.util.HashMap; +import java.util.Map; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.a2a.client.A2AClient; +import io.a2a.spec.A2A; +import io.a2a.spec.AgentCard; +import io.a2a.spec.Message; +import io.a2a.spec.MessageSendParams; +import io.a2a.spec.SendMessageResponse; + +/** + * A simple example of using the A2A Java SDK to communicate with an A2A server. + * This example is equivalent to the Python example provided in the A2A Python SDK. + */ +public class HelloWorldClient { + + private static final String SERVER_URL = "http://localhost:9999"; + private static final String MESSAGE_TEXT = "how much is 10 USD in INR?"; + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + public static void main(String[] args) { + try { + AgentCard finalAgentCard = null; + AgentCard publicAgentCard = A2A.getAgentCard("http://localhost:9999"); + System.out.println("Successfully fetched public agent card:"); + System.out.println(OBJECT_MAPPER.writeValueAsString(publicAgentCard)); + System.out.println("Using public agent card for client initialization (default)."); + finalAgentCard = publicAgentCard; + + if (publicAgentCard.supportsAuthenticatedExtendedCard()) { + System.out.println("Public card supports authenticated extended card. Attempting to fetch from: " + SERVER_URL + "/agent/authenticatedExtendedCard"); + Map authHeaders = new HashMap<>(); + authHeaders.put("Authorization", "Bearer dummy-token-for-extended-card"); + AgentCard extendedAgentCard = A2A.getAgentCard(SERVER_URL, "/agent/authenticatedExtendedCard", authHeaders); + System.out.println("Successfully fetched authenticated extended agent card:"); + System.out.println(OBJECT_MAPPER.writeValueAsString(extendedAgentCard)); + System.out.println("Using AUTHENTICATED EXTENDED agent card for client initialization."); + finalAgentCard = extendedAgentCard; + } else { + System.out.println("Public card does not indicate support for an extended card. Using public card."); + } + + A2AClient client = new A2AClient(finalAgentCard); + Message message = A2A.toUserMessage(MESSAGE_TEXT); // the message ID will be automatically generated for you + MessageSendParams params = new MessageSendParams.Builder() + .message(message) + .build(); + SendMessageResponse response = client.sendMessage(params); + System.out.println("Message sent with ID: " + response.getId()); + System.out.println("Response: " + response.toString()); + } catch (Exception e) { + System.err.println("An error occurred: " + e.getMessage()); + e.printStackTrace(); + } + } + +} \ No newline at end of file diff --git a/examples/client/src/main/java/io/a2a/examples/helloworld/HelloWorldRunner.java b/examples/client/src/main/java/io/a2a/examples/helloworld/HelloWorldRunner.java new file mode 100644 index 000000000..b87b14ce3 --- /dev/null +++ b/examples/client/src/main/java/io/a2a/examples/helloworld/HelloWorldRunner.java @@ -0,0 +1,22 @@ +///usr/bin/env jbang "$0" "$@" ; exit $? +//DEPS io.a2a.sdk:a2a-java-sdk-core:0.2.4-SNAPSHOT +//SOURCES HelloWorldClient.java + +/** + * JBang script to run the A2A HelloWorldClient example. + * This script automatically handles the dependencies and runs the client. + * + * Prerequisites: + * - JBang installed (see https://www.jbang.dev/documentation/guide/latest/installation.html) + * - A running A2A server (see README.md for instructions on setting up the Python server) + * + * Usage: + * $ jbang HelloWorldRunner.java + * + * The script will communicate with the A2A server at http://localhost:9999 + */ +public class HelloWorldRunner { + public static void main(String[] args) { + io.a2a.examples.helloworld.HelloWorldClient.main(args); + } +} \ No newline at end of file diff --git a/examples/client/src/main/java/io/a2a/examples/helloworld/INSTALL_JBANG.md b/examples/client/src/main/java/io/a2a/examples/helloworld/INSTALL_JBANG.md new file mode 100644 index 000000000..7bfa392f6 --- /dev/null +++ b/examples/client/src/main/java/io/a2a/examples/helloworld/INSTALL_JBANG.md @@ -0,0 +1,56 @@ +# Installing JBang + +[JBang](https://www.jbang.dev/) is a tool that makes it easy to run Java code with zero installation. This guide provides quick installation instructions for different platforms. + +## Linux and macOS + +You can install JBang using `curl` or `wget`: + +```bash +# Using curl +curl -Ls https://sh.jbang.dev | bash -s - app setup + +# OR using wget +wget -q https://sh.jbang.dev -O - | bash -s - app setup +``` + +After installation, you may need to restart your terminal or source your shell configuration file: + +```bash +source ~/.bashrc # For Bash +source ~/.zshrc # For Zsh +``` + +## Windows + +### Using PowerShell + +```powershell +iex "& { $(iwr https://ps.jbang.dev) } app setup" +``` + +### Using Chocolatey + +```powershell +choco install jbang +``` + +### Using Scoop + +```powershell +scoop install jbang +``` + +## Verifying Installation + +To verify that JBang is installed correctly, run: + +```bash +jbang --version +``` + +You should see the JBang version number displayed. + +## Further Information + +For more detailed installation instructions and options, visit the [JBang installation documentation](https://www.jbang.dev/documentation/guide/latest/installation.html). \ No newline at end of file diff --git a/examples/client/src/main/java/io/a2a/examples/helloworld/README.md b/examples/client/src/main/java/io/a2a/examples/helloworld/README.md new file mode 100644 index 000000000..d3d33c0bd --- /dev/null +++ b/examples/client/src/main/java/io/a2a/examples/helloworld/README.md @@ -0,0 +1,91 @@ +# A2A Hello World Example + +This example demonstrates how to use the A2A Java SDK to communicate with an A2A server. The example includes a Java client that sends both regular and streaming messages to a Python A2A server. + +## Prerequisites + +- Java 11 or higher +- [JBang](https://www.jbang.dev/documentation/guide/latest/installation.html) (see [INSTALL_JBANG.md](INSTALL_JBANG.md) for quick installation instructions) +- Python 3.8 or higher +- [uv](https://github.com/astral-sh/uv) (recommended) or pip +- Git + +## Setup and Run the Python A2A Server + +The Python A2A server is part of the [a2a-samples](https://github.com/google-a2a/a2a-samples) project. To set it up and run it: + +1. Clone the a2a-samples repository: + ```bash + git clone https://github.com/google-a2a/a2a-samples.git + cd a2a-samples/samples/python/agents/helloworld + ``` + +2. **Recommended method**: Install dependencies using uv (much faster Python package installer): + ```bash + # Install uv if you don't have it already + # On macOS and Linux + curl -LsSf https://astral.sh/uv/install.sh | sh + # On Windows + powershell -c "irm https://astral.sh/uv/install.ps1 | iex" + + # Install the package using uv + uv venv + source .venv/bin/activate # On Windows: .venv\Scripts\activate + uv pip install -e . + ``` + +4. Run the server with uv (recommended): + ```bash + uv run . + ``` + +The server will start running on `http://localhost:9999`. + +## Run the Java A2A Client with JBang + +The Java client can be run using JBang, which allows you to run Java source files directly without any manual compilation. + +### Build the A2A Java SDK + +First, ensure you have built the `a2a-java` project: + +```bash +cd /path/to/a2a-java +mvn clean install +``` + +### Using the JBang script + +A JBang script is provided in the example directory to make running the client easy: + +1. Make sure you have JBang installed. If not, follow the [JBang installation guide](https://www.jbang.dev/documentation/guide/latest/installation.html). + +2. Navigate to the example directory: + ```bash + cd examples/client/src/main/java/io/a2a/examples/helloworld + ``` + +3. Run the client using the JBang script: + ```bash + jbang HelloWorldRunner.java + ``` + +This script automatically handles the dependencies and sources for you. + +## What the Example Does + +The Java client (`HelloWorldClient.java`) performs the following actions: + +1. Fetches the server's public agent card +2. Fetches the server's extended agent card +3. Creates an A2A client using the extended agent card that connects to the Python server at `http://localhost:9999`. +4. Sends a regular message asking "how much is 10 USD in INR?". +5. Prints the server's response. +6. Sends the same message as a streaming request. +7. Prints each chunk of the server's streaming response as it arrives. + +## Notes + +- Make sure the Python server is running before starting the Java client. +- The client will wait for 10 seconds to collect streaming responses before exiting. +- You can modify the message text or server URL in the `HelloWorldClient.java` file if needed. \ No newline at end of file diff --git a/examples/pom.xml b/examples/pom.xml new file mode 100644 index 000000000..c9cd09015 --- /dev/null +++ b/examples/pom.xml @@ -0,0 +1,64 @@ + + + 4.0.0 + + + io.a2a.sdk + a2a-java-sdk-parent + 0.2.4-SNAPSHOT + + + a2a-java-sdk-examples-parent + pom + + Java SDK A2A Examples + Examples for the Java SDK for the Agent2Agent Protocol (A2A) + + + + + io.quarkus + quarkus-bom + ${quarkus.platform.version} + pom + import + + + io.a2a.sdk + a2a-java-sdk-core + ${project.version} + + + io.a2a.sdk + a2a-java-sdk-server-quarkus + ${project.version} + + + + + + + + io.quarkus + quarkus-maven-plugin + true + + + + build + generate-code + generate-code-tests + + + + + + + + + client + server + + \ No newline at end of file diff --git a/examples/server/pom.xml b/examples/server/pom.xml new file mode 100644 index 000000000..5e0b6c58e --- /dev/null +++ b/examples/server/pom.xml @@ -0,0 +1,57 @@ + + + 4.0.0 + + + io.a2a.sdk + a2a-java-sdk-examples-parent + 0.2.4-SNAPSHOT + + + a2a-java-sdk-examples-server + + Java SDK A2A Examples + Examples for the Java SDK for the Agent2Agent Protocol (A2A) + + + + io.a2a.sdk + a2a-java-sdk-server-quarkus + + + io.quarkus + quarkus-resteasy-jackson + provided + + + jakarta.enterprise + jakarta.enterprise.cdi-api + provided + + + jakarta.ws.rs + jakarta.ws.rs-api + + + + + + + io.quarkus + quarkus-maven-plugin + true + + + + build + generate-code + generate-code-tests + + + + + + + \ No newline at end of file diff --git a/examples/server/src/main/java/io/a2a/examples/helloworld/AgentCardProducer.java b/examples/server/src/main/java/io/a2a/examples/helloworld/AgentCardProducer.java new file mode 100644 index 000000000..f07a6e527 --- /dev/null +++ b/examples/server/src/main/java/io/a2a/examples/helloworld/AgentCardProducer.java @@ -0,0 +1,43 @@ +package io.a2a.examples.helloworld; + +import java.util.Collections; +import java.util.List; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Produces; + +import io.a2a.server.PublicAgentCard; +import io.a2a.spec.AgentCapabilities; +import io.a2a.spec.AgentCard; +import io.a2a.spec.AgentSkill; + +@ApplicationScoped +public class AgentCardProducer { + + @Produces + @PublicAgentCard + public AgentCard agentCard() { + return new AgentCard.Builder() + .name("Hello World Agent") + .description("Just a hello world agent") + .url("http://localhost:9999") + .version("1.0.0") + .documentationUrl("http://example.com/docs") + .capabilities(new AgentCapabilities.Builder() + .streaming(true) + .pushNotifications(true) + .stateTransitionHistory(true) + .build()) + .defaultInputModes(Collections.singletonList("text")) + .defaultOutputModes(Collections.singletonList("text")) + .skills(Collections.singletonList(new AgentSkill.Builder() + .id("hello_world") + .name("Returns hello world") + .description("just returns hello world") + .tags(Collections.singletonList("hello world")) + .examples(List.of("hi", "hello world")) + .build())) + .build(); + } +} + diff --git a/examples/server/src/main/java/io/a2a/examples/helloworld/AgentExecutorProducer.java b/examples/server/src/main/java/io/a2a/examples/helloworld/AgentExecutorProducer.java new file mode 100644 index 000000000..e5392ca45 --- /dev/null +++ b/examples/server/src/main/java/io/a2a/examples/helloworld/AgentExecutorProducer.java @@ -0,0 +1,30 @@ +package io.a2a.examples.helloworld; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Produces; + +import io.a2a.server.agentexecution.AgentExecutor; +import io.a2a.server.agentexecution.RequestContext; +import io.a2a.server.events.EventQueue; +import io.a2a.spec.A2A; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.UnsupportedOperationError; + +@ApplicationScoped +public class AgentExecutorProducer { + + @Produces + public AgentExecutor agentExecutor() { + return new AgentExecutor() { + @Override + public void execute(RequestContext context, EventQueue eventQueue) throws JSONRPCError { + eventQueue.enqueueEvent(A2A.toAgentMessage("Hello World")); + } + + @Override + public void cancel(RequestContext context, EventQueue eventQueue) throws JSONRPCError { + throw new UnsupportedOperationError(); + } + }; + } +} diff --git a/examples/server/src/main/java/io/a2a/examples/helloworld/README.md b/examples/server/src/main/java/io/a2a/examples/helloworld/README.md new file mode 100644 index 000000000..821471901 --- /dev/null +++ b/examples/server/src/main/java/io/a2a/examples/helloworld/README.md @@ -0,0 +1,70 @@ +# A2A Hello World Example + +This example demonstrates how to use the A2A Java SDK to communicate with an A2A client. The example includes a Java server that receives both regular and streaming messages from a Python A2A client. + +## Prerequisites + +- Java 11 or higher +- Python 3.8 or higher +- [uv](https://github.com/astral-sh/uv) +- Git + +## Run the Java A2A Server + +The Java server can be started using `mvn` as follows: + +```bash +cd examples/server +mvn quarkus:dev +``` + +## Setup and Run the Python A2A Client + +The Python A2A client is part of the [a2a-samples](https://github.com/google-a2a/a2a-samples) project. To set it up and run it: + +1. Clone the a2a-samples repository: + ```bash + git clone https://github.com/google-a2a/a2a-samples.git + cd a2a-samples/samples/python/agents/helloworld + ``` + +2. **Recommended method**: Install dependencies using uv (much faster Python package installer): + ```bash + # Install uv if you don't have it already + # On macOS and Linux + curl -LsSf https://astral.sh/uv/install.sh | sh + # On Windows + powershell -c "irm https://astral.sh/uv/install.ps1 | iex" + + # Install the package using uv + uv venv + source .venv/bin/activate # On Windows: .venv\Scripts\activate + uv pip install -e . + ``` + +4. Run the client with uv (recommended): + ```bash + uv run test_client.py + ``` + +The client will connect to the Java server running on `http://localhost:9999`. + +## What the Example Does + +The Python A2A client (`test_client.py`) performs the following actions: + +1. Fetches the server's public agent card +2. Fetches the server's extended agent card if supported by the server (see https://github.com/a2aproject/a2a-java/issues/81) +3. Creates an A2A client using the extended agent card that connects to the Python server at `http://localhost:9999`. +4. Sends a regular message asking "how much is 10 USD in INR?". +5. Prints the server's response. +6. Sends the same message as a streaming request. +7. Prints each chunk of the server's streaming response as it arrives. + +## Notes + +- Make sure the Java server is running before starting the Python client. +- The client will wait for 10 seconds to collect streaming responses before exiting. +- You can modify the server's response in `AgentExecutorProducer.java` if needed. +- You can modify the server's agent card in `AgentCardProducer.java` if needed. +- You can modify the server's URL in `application.properties` and `AgentCardProducer.java` if needed. \ No newline at end of file diff --git a/examples/server/src/main/resources/application.properties b/examples/server/src/main/resources/application.properties new file mode 100644 index 000000000..a2452b339 --- /dev/null +++ b/examples/server/src/main/resources/application.properties @@ -0,0 +1 @@ +%dev.quarkus.http.port=9999 \ No newline at end of file diff --git a/images/fork.jpg b/images/fork.jpg new file mode 100644 index 000000000..b3888091f Binary files /dev/null and b/images/fork.jpg differ diff --git a/pom.xml b/pom.xml new file mode 100644 index 000000000..80a6863b5 --- /dev/null +++ b/pom.xml @@ -0,0 +1,189 @@ + + + 4.0.0 + + io.a2a.sdk + a2a-java-sdk-parent + 0.2.4-SNAPSHOT + + pom + + Java SDK A2A Parent + Java SDK for the Agent2Agent Protocol (A2A) + + + UTF-8 + 3.11.0 + 3.1.2 + 2.17.0 + 4.1.0 + 2.0.1 + 2.1.3 + 3.1.0 + 5.12.2 + 5.17.0 + 5.15.0 + 1.1.1 + 3.22.3 + 5.5.1 + 2.0.17 + + + true + + + + + + io.quarkus + quarkus-bom + ${quarkus.platform.version} + pom + import + + + org.slf4j + slf4j-bom + ${slf4j.version} + pom + import + + + com.fasterxml.jackson.core + jackson-databind + ${jackson.version} + + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + ${jackson.version} + + + io.smallrye.reactive + mutiny-zero + ${mutiny-zero.version} + + + jakarta.enterprise + jakarta.enterprise.cdi-api + ${jakarta.enterprise.cdi-api.version} + + + jakarta.inject + jakarta.inject-api + ${jakarta.inject.jakarta.inject-api.version} + + + jakarta.json + jakarta.json-api + ${jakarta.json-api.version} + provided + + + jakarta.ws.rs + jakarta.ws.rs-api + ${jakarta.ws.rs-api.version} + provided + + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + io.rest-assured + rest-assured + ${rest-assured.version} + test + + + org.mockito + mockito-core + ${mockito-core.version} + test + + + org.mock-server + mockserver-netty + ${mockserver.version} + test + + + + + + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + ${compiler-plugin.version} + + 17 + 17 + + -parameters + + + + + maven-surefire-plugin + ${surefire-plugin.version} + + + ${maven.home} + + + + + io.quarkus + quarkus-maven-plugin + true + ${quarkus.platform.version} + + + + build + generate-code + generate-code-tests + + + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + ${maven-compiler-plugin.version} + + + + + org.apache.maven.plugins + maven-surefire-plugin + ${maven-surefire-plugin.version} + + + + + + + core + sdk-server-common + sdk-jakarta + sdk-quarkus + tck + examples + tests/server-common + + \ No newline at end of file diff --git a/sdk-jakarta/pom.xml b/sdk-jakarta/pom.xml new file mode 100644 index 000000000..d7f85e83f --- /dev/null +++ b/sdk-jakarta/pom.xml @@ -0,0 +1,102 @@ + + + 4.0.0 + + + io.a2a.sdk + a2a-java-sdk-parent + 0.2.4-SNAPSHOT + + a2a-java-sdk-server-jakarta + + jar + + Java A2A SDK for Jakarta + Java SDK for the Agent2Agent Protocol (A2A) - SDK - Jakarta + + + + ${project.groupId} + a2a-java-sdk-core + ${project.version} + + + ${project.groupId} + a2a-java-sdk-server-common + ${project.version} + + + ${project.groupId} + a2a-java-sdk-tests-server-common + ${project.version} + provided + + + ${project.groupId} + a2a-java-sdk-tests-server-common + test-jar + test + ${project.version} + + + com.fasterxml.jackson.core + jackson-databind + provided + + + jakarta.enterprise + jakarta.enterprise.cdi-api + + + jakarta.inject + jakarta.inject-api + + + jakarta.json + jakarta.json-api + provided + + + jakarta.ws.rs + jakarta.ws.rs-api + provided + + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + test + + + io.quarkus + quarkus-junit5 + test + + + io.quarkus + quarkus-resteasy-jackson + test + + + io.quarkus + quarkus-resteasy-client-jackson + test + + + org.jboss.resteasy + resteasy-client + test + + + org.junit.jupiter + junit-jupiter-api + test + + + io.rest-assured + rest-assured + test + + + \ No newline at end of file diff --git a/sdk-jakarta/src/main/java/io/a2a/server/apps/jakarta/A2ARequestFilter.java b/sdk-jakarta/src/main/java/io/a2a/server/apps/jakarta/A2ARequestFilter.java new file mode 100644 index 000000000..68fbb5585 --- /dev/null +++ b/sdk-jakarta/src/main/java/io/a2a/server/apps/jakarta/A2ARequestFilter.java @@ -0,0 +1,63 @@ +package io.a2a.server.apps.jakarta; + +import static io.a2a.spec.A2A.CANCEL_TASK_METHOD; +import static io.a2a.spec.A2A.GET_TASK_METHOD; +import static io.a2a.spec.A2A.GET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD; +import static io.a2a.spec.A2A.SEND_MESSAGE_METHOD; +import static io.a2a.spec.A2A.SEND_STREAMING_MESSAGE_METHOD; +import static io.a2a.spec.A2A.SEND_TASK_RESUBSCRIPTION_METHOD; +import static io.a2a.spec.A2A.SET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; + +import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.container.ContainerRequestFilter; +import jakarta.ws.rs.container.PreMatching; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.ext.Provider; + +@Provider +@PreMatching +public class A2ARequestFilter implements ContainerRequestFilter { + + @Override + public void filter(ContainerRequestContext requestContext) { + if (requestContext.getMethod().equals("POST") && requestContext.hasEntity()) { + try (InputStream entityInputStream = requestContext.getEntityStream()) { + byte[] requestBodyBytes = entityInputStream.readAllBytes(); + String requestBody = new String(requestBodyBytes); + // ensure the request is treated as a streaming request or a non-streaming request + // based on the method in the request body + if (isStreamingRequest(requestBody)) { + putAcceptHeader(requestContext, MediaType.SERVER_SENT_EVENTS); + } else if (isNonStreamingRequest(requestBody)) { + putAcceptHeader(requestContext, MediaType.APPLICATION_JSON); + } + // reset the entity stream + requestContext.setEntityStream(new ByteArrayInputStream(requestBodyBytes)); + } catch(IOException e){ + throw new RuntimeException("Unable to read the request body"); + } + } + } + + private static boolean isStreamingRequest(String requestBody) { + return requestBody.contains(SEND_STREAMING_MESSAGE_METHOD) || + requestBody.contains(SEND_TASK_RESUBSCRIPTION_METHOD); + } + + private static boolean isNonStreamingRequest(String requestBody) { + return requestBody.contains(GET_TASK_METHOD) || + requestBody.contains(CANCEL_TASK_METHOD) || + requestBody.contains(SEND_MESSAGE_METHOD) || + requestBody.contains(SET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD) || + requestBody.contains(GET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD); + } + + private static void putAcceptHeader(ContainerRequestContext requestContext, String mediaType) { + requestContext.getHeaders().putSingle("Accept", mediaType); + } + +} \ No newline at end of file diff --git a/sdk-jakarta/src/main/java/io/a2a/server/apps/jakarta/A2AServerResource.java b/sdk-jakarta/src/main/java/io/a2a/server/apps/jakarta/A2AServerResource.java new file mode 100644 index 000000000..ba587285d --- /dev/null +++ b/sdk-jakarta/src/main/java/io/a2a/server/apps/jakarta/A2AServerResource.java @@ -0,0 +1,244 @@ +package io.a2a.server.apps.jakarta; + +import java.util.concurrent.Executor; +import java.util.concurrent.Flow; + +import jakarta.enterprise.inject.Instance; +import jakarta.inject.Inject; +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.ext.ExceptionMapper; +import jakarta.ws.rs.ext.Provider; +import jakarta.ws.rs.sse.Sse; +import jakarta.ws.rs.sse.SseEventSink; + +import com.fasterxml.jackson.core.JsonParseException; +import com.fasterxml.jackson.databind.JsonMappingException; + +import io.a2a.server.ExtendedAgentCard; +import io.a2a.server.requesthandlers.JSONRPCHandler; +import io.a2a.spec.AgentCard; +import io.a2a.spec.CancelTaskRequest; +import io.a2a.spec.GetTaskPushNotificationConfigRequest; +import io.a2a.spec.GetTaskRequest; +import io.a2a.spec.IdJsonMappingException; +import io.a2a.spec.InvalidParamsError; +import io.a2a.spec.InvalidParamsJsonMappingException; +import io.a2a.spec.InvalidRequestError; +import io.a2a.spec.JSONErrorResponse; +import io.a2a.spec.JSONParseError; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.JSONRPCErrorResponse; +import io.a2a.spec.JSONRPCRequest; +import io.a2a.spec.JSONRPCResponse; +import io.a2a.spec.MethodNotFoundError; +import io.a2a.spec.MethodNotFoundJsonMappingException; +import io.a2a.spec.NonStreamingJSONRPCRequest; +import io.a2a.spec.SendMessageRequest; +import io.a2a.spec.SendStreamingMessageRequest; +import io.a2a.spec.SetTaskPushNotificationConfigRequest; +import io.a2a.spec.StreamingJSONRPCRequest; +import io.a2a.spec.TaskResubscriptionRequest; +import io.a2a.spec.UnsupportedOperationError; +import io.a2a.server.util.async.Internal; + +@Path("/") +public class A2AServerResource { + + @Inject + JSONRPCHandler jsonRpcHandler; + + @Inject + @ExtendedAgentCard + Instance extendedAgentCard; + + // Hook so testing can wait until the async Subscription is subscribed. + private static volatile Runnable streamingIsSubscribedRunnable; + + @Inject + @Internal + Executor executor; + + /** + * Handles incoming POST requests to the main A2A endpoint. Dispatches the + * request to the appropriate JSON-RPC handler method and returns the response. + * + * @param request the JSON-RPC request + * @return the JSON-RPC response which may be an error response + */ + @POST + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + public JSONRPCResponse handleNonStreamingRequests(NonStreamingJSONRPCRequest request) { + return processNonStreamingRequest(request); + } + + /** + * Handles incoming POST requests to the main A2A endpoint that involve Server-Sent Events (SSE). + * Dispatches the request to the appropriate JSON-RPC handler method and returns the response. + */ + @POST + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.SERVER_SENT_EVENTS) + public void handleStreamingRequests(StreamingJSONRPCRequest request, @Context SseEventSink sseEventSink, @Context Sse sse) { + System.out.println("=====> Streaming"); + executor.execute(() -> processStreamingRequest(request, sseEventSink, sse)); + System.out.println("=====> Streaming - done"); + } + + /** + * Handles incoming GET requests to the agent card endpoint. + * Returns the agent card in JSON format. + * + * @return the agent card + */ + @GET + @Path("/.well-known/agent.json") + @Produces(MediaType.APPLICATION_JSON) + public AgentCard getAgentCard() { + return jsonRpcHandler.getAgentCard(); + } + + /** + * Handles incoming GET requests to the authenticated extended agent card endpoint. + * Returns the agent card in JSON format. + * + * @return the authenticated extended agent card + */ + @GET + @Path("/agent/authenticatedExtendedCard") + @Produces(MediaType.APPLICATION_JSON) + public Response getAuthenticatedExtendedAgentCard() { + // TODO need to add authentication for this endpoint + // https://github.com/a2aproject/a2a-java/issues/77 + if (! jsonRpcHandler.getAgentCard().supportsAuthenticatedExtendedCard()) { + JSONErrorResponse errorResponse = new JSONErrorResponse("Extended agent card not supported or not enabled."); + return Response.status(Response.Status.NOT_FOUND) + .entity(errorResponse).build(); + } + if (! extendedAgentCard.isResolvable()) { + JSONErrorResponse errorResponse = new JSONErrorResponse("Authenticated extended agent card is supported but not configured on the server."); + return Response.status(Response.Status.NOT_FOUND) + .entity(errorResponse).build(); + } + return Response.ok(extendedAgentCard.get()) + .type(MediaType.APPLICATION_JSON) + .build(); + } + + private JSONRPCResponse processNonStreamingRequest(NonStreamingJSONRPCRequest request) { + if (request instanceof GetTaskRequest) { + return jsonRpcHandler.onGetTask((GetTaskRequest) request); + } else if (request instanceof CancelTaskRequest) { + return jsonRpcHandler.onCancelTask((CancelTaskRequest) request); + } else if (request instanceof SetTaskPushNotificationConfigRequest) { + return jsonRpcHandler.setPushNotification((SetTaskPushNotificationConfigRequest) request); + } else if (request instanceof GetTaskPushNotificationConfigRequest) { + return jsonRpcHandler.getPushNotification((GetTaskPushNotificationConfigRequest) request); + } else if (request instanceof SendMessageRequest) { + return jsonRpcHandler.onMessageSend((SendMessageRequest) request); + } else { + return generateErrorResponse(request, new UnsupportedOperationError()); + } + } + + private void processStreamingRequest(StreamingJSONRPCRequest request, SseEventSink sseEventSink, Sse sse) { + Flow.Publisher> publisher; + if (request instanceof SendStreamingMessageRequest) { + publisher = jsonRpcHandler.onMessageSendStream((SendStreamingMessageRequest) request); + handleStreamingResponse(publisher, sseEventSink, sse); + } else if (request instanceof TaskResubscriptionRequest) { + publisher = jsonRpcHandler.onResubscribeToTask((TaskResubscriptionRequest) request); + handleStreamingResponse(publisher, sseEventSink, sse); + } + } + + private void handleStreamingResponse(Flow.Publisher> publisher, SseEventSink sseEventSink, Sse sse) { + publisher.subscribe(new Flow.Subscriber>() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(Long.MAX_VALUE); + System.out.println("SUBSCRIBING!"); + // Notify tests that we are subscribed + Runnable runnable = streamingIsSubscribedRunnable; + if (runnable != null) { + runnable.run(); + } + } + + @Override + public void onNext(JSONRPCResponse item) { + + sseEventSink.send(sse.newEventBuilder() + .mediaType(MediaType.APPLICATION_JSON_TYPE) + .data(item) + .build()); + } + + @Override + public void onError(Throwable throwable) { + // TODO + sseEventSink.close(); + } + + @Override + public void onComplete() { + sseEventSink.close(); + } + }); + } + + private JSONRPCResponse generateErrorResponse(JSONRPCRequest request, JSONRPCError error) { + return new JSONRPCErrorResponse(request.getId(), error); + } + + static void setStreamingIsSubscribedRunnable(Runnable streamingIsSubscribedRunnable) { + A2AServerResource.streamingIsSubscribedRunnable = streamingIsSubscribedRunnable; + } + + @Provider + public class JsonParseExceptionMapper implements ExceptionMapper { + + @Override + public Response toResponse(JsonParseException exception) { + // parse error, not possible to determine the request id + return Response.ok(new JSONRPCErrorResponse(new JSONParseError())).type(MediaType.APPLICATION_JSON).build(); + } + + } + + @Provider + public static class JsonMappingExceptionMapper implements ExceptionMapper { + + @Override + public Response toResponse(JsonMappingException exception) { + if (exception.getCause() instanceof JsonParseException) { + return Response.ok(new JSONRPCErrorResponse(new JSONParseError())).type(MediaType.APPLICATION_JSON).build(); + } else if (exception instanceof MethodNotFoundJsonMappingException) { + Object id = ((MethodNotFoundJsonMappingException) exception).getId(); + return Response.ok(new JSONRPCErrorResponse(id, new MethodNotFoundError())) + .type(MediaType.APPLICATION_JSON).build(); + } else if (exception instanceof InvalidParamsJsonMappingException) { + Object id = ((InvalidParamsJsonMappingException) exception).getId(); + return Response.ok(new JSONRPCErrorResponse(id, new InvalidParamsError())) + .type(MediaType.APPLICATION_JSON).build(); + } else if (exception instanceof IdJsonMappingException) { + Object id = ((IdJsonMappingException) exception).getId(); + return Response.ok(new JSONRPCErrorResponse(id, new InvalidRequestError())) + .type(MediaType.APPLICATION_JSON).build(); + } + // not possible to determine the request id + return Response.ok(new JSONRPCErrorResponse(new InvalidRequestError())).type(MediaType.APPLICATION_JSON).build(); + } + + } +} \ No newline at end of file diff --git a/sdk-jakarta/src/main/resources/META-INF/beans.xml b/sdk-jakarta/src/main/resources/META-INF/beans.xml new file mode 100644 index 000000000..e69de29bb diff --git a/sdk-jakarta/src/test/java/io/a2a/server/apps/jakarta/JakartaA2AServerTest.java b/sdk-jakarta/src/test/java/io/a2a/server/apps/jakarta/JakartaA2AServerTest.java new file mode 100644 index 000000000..06ad61c9a --- /dev/null +++ b/sdk-jakarta/src/test/java/io/a2a/server/apps/jakarta/JakartaA2AServerTest.java @@ -0,0 +1,32 @@ +package io.a2a.server.apps.jakarta; + +import jakarta.inject.Inject; + +import io.a2a.server.apps.common.AbstractA2AServerTest; +import io.a2a.server.events.InMemoryQueueManager; +import io.a2a.server.tasks.TaskStore; +import io.quarkus.test.junit.QuarkusTest; + +@QuarkusTest +public class JakartaA2AServerTest extends AbstractA2AServerTest { + @Inject + TaskStore taskStore; + + @Inject + InMemoryQueueManager queueManager; + + @Override + protected TaskStore getTaskStore() { + return taskStore; + } + + @Override + protected InMemoryQueueManager getQueueManager() { + return queueManager; + } + + @Override + protected void setStreamingSubscribedRunnable(Runnable runnable) { + A2AServerResource.setStreamingIsSubscribedRunnable(runnable); + } +} diff --git a/sdk-quarkus/pom.xml b/sdk-quarkus/pom.xml new file mode 100644 index 000000000..bacccc073 --- /dev/null +++ b/sdk-quarkus/pom.xml @@ -0,0 +1,80 @@ + + + 4.0.0 + + + io.a2a.sdk + a2a-java-sdk-parent + 0.2.4-SNAPSHOT + + a2a-java-sdk-server-quarkus + + jar + + Java A2A SDK for Quarkus + Java SDK for the Agent2Agent Protocol (A2A) - SDK - Quarkus + + + + ${project.groupId} + a2a-java-sdk-core + ${project.version} + + + ${project.groupId} + a2a-java-sdk-server-common + ${project.version} + + + ${project.groupId} + a2a-java-sdk-tests-server-common + ${project.version} + provided + + + ${project.groupId} + a2a-java-sdk-tests-server-common + test-jar + test + ${project.version} + + + io.quarkus + quarkus-reactive-routes + + + jakarta.enterprise + jakarta.enterprise.cdi-api + + + jakarta.inject + jakarta.inject-api + + + org.slf4j + slf4j-api + + + io.quarkus + quarkus-junit5 + test + + + io.quarkus + quarkus-rest-client-jackson + test + + + org.junit.jupiter + junit-jupiter-api + test + + + io.rest-assured + rest-assured + test + + + \ No newline at end of file diff --git a/sdk-quarkus/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java b/sdk-quarkus/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java new file mode 100644 index 000000000..4246e965e --- /dev/null +++ b/sdk-quarkus/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java @@ -0,0 +1,327 @@ +package io.a2a.server.apps.quarkus; + +import static io.vertx.core.http.HttpHeaders.CONTENT_TYPE; +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; + +import java.util.concurrent.Executor; +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; + +import jakarta.enterprise.inject.Instance; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; +import jakarta.ws.rs.core.Response; + +import com.fasterxml.jackson.core.JsonParseException; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.io.JsonEOFException; +import io.a2a.server.ExtendedAgentCard; +import io.a2a.server.requesthandlers.JSONRPCHandler; +import io.a2a.spec.A2A; +import io.a2a.spec.AgentCard; +import io.a2a.spec.CancelTaskRequest; +import io.a2a.spec.GetTaskPushNotificationConfigRequest; +import io.a2a.spec.GetTaskRequest; +import io.a2a.spec.IdJsonMappingException; +import io.a2a.spec.InternalError; +import io.a2a.spec.InvalidParamsError; +import io.a2a.spec.InvalidParamsJsonMappingException; +import io.a2a.spec.InvalidRequestError; +import io.a2a.spec.JSONErrorResponse; +import io.a2a.spec.JSONParseError; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.JSONRPCErrorResponse; +import io.a2a.spec.JSONRPCRequest; +import io.a2a.spec.JSONRPCResponse; +import io.a2a.spec.MethodNotFoundError; +import io.a2a.spec.MethodNotFoundJsonMappingException; +import io.a2a.spec.NonStreamingJSONRPCRequest; +import io.a2a.spec.SendMessageRequest; +import io.a2a.spec.SendStreamingMessageRequest; +import io.a2a.spec.SetTaskPushNotificationConfigRequest; +import io.a2a.spec.StreamingJSONRPCRequest; +import io.a2a.spec.TaskResubscriptionRequest; +import io.a2a.spec.UnsupportedOperationError; +import io.a2a.util.Utils; +import io.a2a.server.util.async.Internal; +import io.quarkus.vertx.web.Body; +import io.quarkus.vertx.web.ReactiveRoutes; +import io.quarkus.vertx.web.Route; +import io.quarkus.vertx.web.RoutingExchange; +import io.smallrye.mutiny.Multi; +import io.vertx.core.AsyncResult; +import io.vertx.core.Handler; +import io.vertx.core.MultiMap; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.http.HttpServerResponse; +import io.vertx.core.json.Json; +import io.vertx.ext.web.RoutingContext; + +@Singleton +public class A2AServerRoutes { + + @Inject + JSONRPCHandler jsonRpcHandler; + + @Inject + @ExtendedAgentCard + Instance extendedAgentCard; + + // Hook so testing can wait until the MultiSseSupport is subscribed. + private static volatile Runnable streamingMultiSseSupportSubscribedRunnable; + + @Inject + @Internal + Executor executor; + + @Route(path = "/", methods = {Route.HttpMethod.POST}, consumes = {APPLICATION_JSON}, type = Route.HandlerType.BLOCKING) + public void invokeJSONRPCHandler(@Body String body, RoutingContext rc) { + boolean streaming = false; + JSONRPCResponse nonStreamingResponse = null; + Multi> streamingResponse = null; + JSONRPCErrorResponse error = null; + + try { + if (isStreamingRequest(body)) { + streaming = true; + StreamingJSONRPCRequest request = Utils.OBJECT_MAPPER.readValue(body, StreamingJSONRPCRequest.class); + streamingResponse = processStreamingRequest(request); + } else { + NonStreamingJSONRPCRequest request = Utils.OBJECT_MAPPER.readValue(body, NonStreamingJSONRPCRequest.class); + nonStreamingResponse = processNonStreamingRequest(request); + } + } catch (JsonProcessingException e) { + error = handleError(e); + } catch (Throwable t) { + error = new JSONRPCErrorResponse(new InternalError(t.getMessage())); + } finally { + if (error != null) { + rc.response() + .setStatusCode(200) + .putHeader(CONTENT_TYPE, APPLICATION_JSON) + .end(Json.encodeToBuffer(error)); + } else if (streaming) { + final Multi> finalStreamingResponse = streamingResponse; + executor.execute(() -> { + MultiSseSupport.subscribeObject( + finalStreamingResponse.map(i -> (Object)i), rc); + }); + + } else { + rc.response() + .setStatusCode(200) + .putHeader(CONTENT_TYPE, APPLICATION_JSON) + .end(Json.encodeToBuffer(nonStreamingResponse)); + } + } + } + + private JSONRPCErrorResponse handleError(JsonProcessingException exception) { + Object id = null; + JSONRPCError jsonRpcError = null; + if (exception.getCause() instanceof JsonParseException) { + jsonRpcError = new JSONParseError(); + } else if (exception instanceof JsonEOFException) { + jsonRpcError = new JSONParseError(exception.getMessage()); + } else if (exception instanceof MethodNotFoundJsonMappingException err) { + id = err.getId(); + jsonRpcError = new MethodNotFoundError(); + } else if (exception instanceof InvalidParamsJsonMappingException err) { + id = err.getId(); + jsonRpcError = new InvalidParamsError(); + } else if (exception instanceof IdJsonMappingException err) { + id = err.getId(); + jsonRpcError = new InvalidRequestError(); + } else { + jsonRpcError = new InvalidRequestError(); + } + return new JSONRPCErrorResponse(id, jsonRpcError); + } + + /** + /** + * Handles incoming GET requests to the agent card endpoint. + * Returns the agent card in JSON format. + * + * @return the agent card + */ + @Route(path = "/.well-known/agent.json", methods = Route.HttpMethod.GET, produces = APPLICATION_JSON) + public AgentCard getAgentCard() { + return jsonRpcHandler.getAgentCard(); + } + + /** + * Handles incoming GET requests to the authenticated extended agent card endpoint. + * Returns the agent card in JSON format. + * + * @return the authenticated extended agent card + */ + @Route(path = "/agent/authenticatedExtendedCard", methods = Route.HttpMethod.GET, produces = APPLICATION_JSON) + public void getAuthenticatedExtendedAgentCard(RoutingExchange re) { + // TODO need to add authentication for this endpoint + // https://github.com/a2aproject/a2a-java/issues/77 + try { + if (! jsonRpcHandler.getAgentCard().supportsAuthenticatedExtendedCard()) { + JSONErrorResponse errorResponse = new JSONErrorResponse("Extended agent card not supported or not enabled."); + re.response().setStatusCode(Response.Status.NOT_FOUND.getStatusCode()) + .end(Utils.OBJECT_MAPPER.writeValueAsString(errorResponse)); + return; + } + if (! extendedAgentCard.isResolvable()) { + JSONErrorResponse errorResponse = new JSONErrorResponse("Authenticated extended agent card is supported but not configured on the server."); + re.response().setStatusCode(Response.Status.NOT_FOUND.getStatusCode()) + .end(Utils.OBJECT_MAPPER.writeValueAsString(errorResponse)); + return; + } + + re.response().end(Utils.OBJECT_MAPPER.writeValueAsString(extendedAgentCard.get())); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + private JSONRPCResponse processNonStreamingRequest(NonStreamingJSONRPCRequest request) { + if (request instanceof GetTaskRequest) { + return jsonRpcHandler.onGetTask((GetTaskRequest) request); + } else if (request instanceof CancelTaskRequest) { + return jsonRpcHandler.onCancelTask((CancelTaskRequest) request); + } else if (request instanceof SetTaskPushNotificationConfigRequest) { + return jsonRpcHandler.setPushNotification((SetTaskPushNotificationConfigRequest) request); + } else if (request instanceof GetTaskPushNotificationConfigRequest) { + return jsonRpcHandler.getPushNotification((GetTaskPushNotificationConfigRequest) request); + } else if (request instanceof SendMessageRequest) { + return jsonRpcHandler.onMessageSend((SendMessageRequest) request); + } else { + return generateErrorResponse(request, new UnsupportedOperationError()); + } + } + + private Multi> processStreamingRequest(JSONRPCRequest request) { + Flow.Publisher> publisher; + if (request instanceof SendStreamingMessageRequest) { + publisher = jsonRpcHandler.onMessageSendStream((SendStreamingMessageRequest) request); + } else if (request instanceof TaskResubscriptionRequest) { + publisher = jsonRpcHandler.onResubscribeToTask((TaskResubscriptionRequest) request); + } else { + return Multi.createFrom().item(generateErrorResponse(request, new UnsupportedOperationError())); + } + return Multi.createFrom().publisher(publisher); + } + + private JSONRPCResponse generateErrorResponse(JSONRPCRequest request, JSONRPCError error) { + return new JSONRPCErrorResponse(request.getId(), error); + } + + private static boolean isStreamingRequest(String requestBody) { + return requestBody.contains(A2A.SEND_STREAMING_MESSAGE_METHOD) || + requestBody.contains(A2A.SEND_TASK_RESUBSCRIPTION_METHOD); + } + + private static boolean isNonStreamingRequest(String requestBody) { + return requestBody.contains(A2A.GET_TASK_METHOD) || + requestBody.contains(A2A.CANCEL_TASK_METHOD) || + requestBody.contains(A2A.SEND_MESSAGE_METHOD) || + requestBody.contains(A2A.SET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD) || + requestBody.contains(A2A.GET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD); + } + + static void setStreamingMultiSseSupportSubscribedRunnable(Runnable runnable) { + streamingMultiSseSupportSubscribedRunnable = runnable; + } + + // Port of import io.quarkus.vertx.web.runtime.MultiSseSupport, which is considered internal API + private static class MultiSseSupport { + + private MultiSseSupport() { + // Avoid direct instantiation. + } + + private static void initialize(HttpServerResponse response) { + if (response.bytesWritten() == 0) { + MultiMap headers = response.headers(); + if (headers.get("content-type") == null) { + headers.set("content-type", "text/event-stream"); + } + response.setChunked(true); + } + } + + private static void onWriteDone(Flow.Subscription subscription, AsyncResult ar, RoutingContext rc) { + if (ar.failed()) { + rc.fail(ar.cause()); + } else { + subscription.request(1); + } + } + + public static void write(Multi multi, RoutingContext rc) { + HttpServerResponse response = rc.response(); + multi.subscribe().withSubscriber(new Flow.Subscriber() { + Flow.Subscription upstream; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.upstream = subscription; + this.upstream.request(1); + + // Notify tests that we are subscribed + Runnable runnable = streamingMultiSseSupportSubscribedRunnable; + if (runnable != null) { + runnable.run(); + } + } + + @Override + public void onNext(Buffer item) { + initialize(response); + response.write(item, new Handler>() { + @Override + public void handle(AsyncResult ar) { + onWriteDone(upstream, ar, rc); + } + }); + } + + @Override + public void onError(Throwable throwable) { + rc.fail(throwable); + } + + @Override + public void onComplete() { + endOfStream(response); + } + }); + } + + public static void subscribeObject(Multi multi, RoutingContext rc) { + AtomicLong count = new AtomicLong(); + write(multi.map(new Function() { + @Override + public Buffer apply(Object o) { + if (o instanceof ReactiveRoutes.ServerSentEvent) { + ReactiveRoutes.ServerSentEvent ev = (ReactiveRoutes.ServerSentEvent) o; + long id = ev.id() != -1 ? ev.id() : count.getAndIncrement(); + String e = ev.event() == null ? "" : "event: " + ev.event() + "\n"; + return Buffer.buffer(e + "data: " + Json.encodeToBuffer(ev.data()) + "\nid: " + id + "\n\n"); + } else { + return Buffer.buffer("data: " + Json.encodeToBuffer(o) + "\nid: " + count.getAndIncrement() + "\n\n"); + } + } + }), rc); + } + + private static void endOfStream(HttpServerResponse response) { + if (response.bytesWritten() == 0) { // No item + MultiMap headers = response.headers(); + if (headers.get("content-type") == null) { + headers.set("content-type", "text/event-stream"); + } + } + response.end(); + } + } + +} + diff --git a/sdk-quarkus/src/main/resources/META-INF/beans.xml b/sdk-quarkus/src/main/resources/META-INF/beans.xml new file mode 100644 index 000000000..e69de29bb diff --git a/sdk-quarkus/src/test/java/io/a2a/server/apps/quarkus/QuarkusA2AServerTest.java b/sdk-quarkus/src/test/java/io/a2a/server/apps/quarkus/QuarkusA2AServerTest.java new file mode 100644 index 000000000..4dce72cee --- /dev/null +++ b/sdk-quarkus/src/test/java/io/a2a/server/apps/quarkus/QuarkusA2AServerTest.java @@ -0,0 +1,33 @@ +package io.a2a.server.apps.quarkus; + +import jakarta.inject.Inject; + +import io.a2a.server.apps.common.AbstractA2AServerTest; +import io.a2a.server.events.InMemoryQueueManager; +import io.a2a.server.tasks.TaskStore; +import io.quarkus.test.junit.QuarkusTest; + +@QuarkusTest +public class QuarkusA2AServerTest extends AbstractA2AServerTest { + + @Inject + TaskStore taskStore; + + @Inject + InMemoryQueueManager queueManager; + + @Override + protected TaskStore getTaskStore() { + return taskStore; + } + + @Override + protected InMemoryQueueManager getQueueManager() { + return queueManager; + } + + @Override + protected void setStreamingSubscribedRunnable(Runnable runnable) { + A2AServerRoutes.setStreamingMultiSseSupportSubscribedRunnable(runnable); + } +} diff --git a/sdk-quarkus/src/test/resources/application.properties b/sdk-quarkus/src/test/resources/application.properties new file mode 100644 index 000000000..d3366bece --- /dev/null +++ b/sdk-quarkus/src/test/resources/application.properties @@ -0,0 +1 @@ +quarkus.arc.selected-alternatives=io.a2a.server.apps.common.TestHttpClient \ No newline at end of file diff --git a/sdk-server-common/pom.xml b/sdk-server-common/pom.xml new file mode 100644 index 000000000..44e2e298a --- /dev/null +++ b/sdk-server-common/pom.xml @@ -0,0 +1,71 @@ + + + 4.0.0 + + + io.a2a.sdk + a2a-java-sdk-parent + 0.2.4-SNAPSHOT + + a2a-java-sdk-server-common + + jar + + Java SDK A2A Core + Java SDK for the Agent2Agent Protocol (A2A) - Server Common + + + + ${project.groupId} + a2a-java-sdk-core + ${project.version} + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + + + io.smallrye.reactive + mutiny-zero + + + jakarta.enterprise + jakarta.enterprise.cdi-api + + + jakarta.inject + jakarta.inject-api + + + org.slf4j + slf4j-api + + + io.quarkus + quarkus-arc + test + + + org.junit.jupiter + junit-jupiter-api + test + + + org.mockito + mockito-core + test + + + org.mock-server + mockserver-netty + test + + + + \ No newline at end of file diff --git a/sdk-server-common/src/main/java/io/a2a/server/ExtendedAgentCard.java b/sdk-server-common/src/main/java/io/a2a/server/ExtendedAgentCard.java new file mode 100644 index 000000000..c9cc7eaf0 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/ExtendedAgentCard.java @@ -0,0 +1,18 @@ +package io.a2a.server; + +import static java.lang.annotation.ElementType.FIELD; +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.ElementType.TYPE; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import jakarta.inject.Qualifier; + +@Qualifier +@Retention(RUNTIME) +@Target({FIELD, TYPE, METHOD, PARAMETER}) +public @interface ExtendedAgentCard { +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/JSONRPCException.java b/sdk-server-common/src/main/java/io/a2a/server/JSONRPCException.java new file mode 100644 index 000000000..d7d7d2c8a --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/JSONRPCException.java @@ -0,0 +1,15 @@ +package io.a2a.server; + +import io.a2a.spec.JSONRPCError; + +public class JSONRPCException extends Exception{ + private final JSONRPCError error; + + public JSONRPCException(JSONRPCError error) { + this.error = error; + } + + public JSONRPCError getError() { + return error; + } +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/PublicAgentCard.java b/sdk-server-common/src/main/java/io/a2a/server/PublicAgentCard.java new file mode 100644 index 000000000..68c670bbe --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/PublicAgentCard.java @@ -0,0 +1,18 @@ +package io.a2a.server; + +import static java.lang.annotation.ElementType.FIELD; +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.ElementType.TYPE; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import jakarta.inject.Qualifier; + +@Qualifier +@Retention(RUNTIME) +@Target({FIELD, TYPE, METHOD, PARAMETER}) +public @interface PublicAgentCard { +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/ServerCallContext.java b/sdk-server-common/src/main/java/io/a2a/server/ServerCallContext.java new file mode 100644 index 000000000..70fb344df --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/ServerCallContext.java @@ -0,0 +1,5 @@ +package io.a2a.server; + +public class ServerCallContext { + // TODO port the fields +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/agentexecution/AgentExecutor.java b/sdk-server-common/src/main/java/io/a2a/server/agentexecution/AgentExecutor.java new file mode 100644 index 000000000..f753fed54 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/agentexecution/AgentExecutor.java @@ -0,0 +1,10 @@ +package io.a2a.server.agentexecution; + +import io.a2a.server.events.EventQueue; +import io.a2a.spec.JSONRPCError; + +public interface AgentExecutor { + void execute(RequestContext context, EventQueue eventQueue) throws JSONRPCError; + + void cancel(RequestContext context, EventQueue eventQueue) throws JSONRPCError; +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/agentexecution/RequestContext.java b/sdk-server-common/src/main/java/io/a2a/server/agentexecution/RequestContext.java new file mode 100644 index 000000000..e62a87d3c --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/agentexecution/RequestContext.java @@ -0,0 +1,193 @@ +package io.a2a.server.agentexecution; + +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import java.util.stream.Collectors; + +import io.a2a.server.ServerCallContext; +import io.a2a.spec.InvalidParamsError; +import io.a2a.spec.Message; +import io.a2a.spec.MessageSendConfiguration; +import io.a2a.spec.MessageSendParams; +import io.a2a.spec.Part; +import io.a2a.spec.Task; +import io.a2a.spec.TextPart; + +public class RequestContext { + + private MessageSendParams params; + private String taskId; + private String contextId; + private Task task; + private List relatedTasks; + + public RequestContext(MessageSendParams params, String taskId, String contextId, Task task, List relatedTasks) throws InvalidParamsError { + this.params = params; + this.taskId = taskId; + this.contextId = contextId; + this.task = task; + this.relatedTasks = relatedTasks == null ? new ArrayList<>() : relatedTasks; + + // if the taskId and contextId were specified, they must match the params + if (params != null) { + if (taskId != null && ! params.message().getTaskId().equals(taskId)) { + throw new InvalidParamsError("bad task id"); + } else { + checkOrGenerateTaskId(); + } + if (contextId != null && ! params.message().getContextId().equals(contextId)) { + throw new InvalidParamsError("bad context id"); + } else { + checkOrGenerateContextId(); + } + } + } + + public MessageSendParams getParams() { + return params; + } + + public String getTaskId() { + return taskId; + } + + public String getContextId() { + return contextId; + } + + public Task getTask() { + return task; + } + + public List getRelatedTasks() { + return relatedTasks; + } + + public Message getMessage() { + return params != null ? params.message() : null; + } + + public MessageSendConfiguration getConfiguration() { + return params != null ? params.configuration() : null; + } + + public String getUserInput(String delimiter) { + if (params == null) { + return ""; + } + if (delimiter == null) { + delimiter = "\n"; + } + return getMessageText(params.message(), delimiter); + } + + public void attachRelatedTask(Task task) { + relatedTasks.add(task); + } + + private void checkOrGenerateTaskId() { + if (params == null) { + return; + } + if (taskId == null && params.message().getTaskId() == null) { + params.message().setTaskId(UUID.randomUUID().toString()); + } + if (params.message().getTaskId() != null) { + this.taskId = params.message().getTaskId(); + } + } + + private void checkOrGenerateContextId() { + if (params == null) { + return; + } + if (contextId == null && params.message().getContextId() == null) { + params.message().setContextId(UUID.randomUUID().toString()); + } + if (params.message().getContextId() != null) { + this.contextId = params.message().getContextId(); + } + } + + private String getMessageText(Message message, String delimiter) { + List textParts = getTextParts(message.getParts()); + return String.join(delimiter, textParts); + } + + private List getTextParts(List> parts) { + return parts.stream() + .filter(part -> part.getKind() == Part.Kind.TEXT) + .map(part -> (TextPart) part) + .map(TextPart::getText) + .collect(Collectors.toList()); + } + + public static class Builder { + private MessageSendParams params; + private String taskId; + private String contextId; + private Task task; + private List relatedTasks; + private ServerCallContext serverCallContext; + + public Builder setParams(MessageSendParams params) { + this.params = params; + return this; + } + + public Builder setTaskId(String taskId) { + this.taskId = taskId; + return this; + } + + public Builder setContextId(String contextId) { + this.contextId = contextId; + return this; + } + + public Builder setTask(Task task) { + this.task = task; + return this; + } + + public Builder setRelatedTasks(List relatedTasks) { + this.relatedTasks = relatedTasks; + return this; + } + + public Builder setServerCallContext(ServerCallContext serverCallContext) { + this.serverCallContext = serverCallContext; + return this; + } + + public MessageSendParams getParams() { + return params; + } + + public String getTaskId() { + return taskId; + } + + public String getContextId() { + return contextId; + } + + public Task getTask() { + return task; + } + + public List getRelatedTasks() { + return relatedTasks; + } + + public ServerCallContext getServerCallContext() { + return serverCallContext; + } + + public RequestContext build() { + return new RequestContext(params, taskId, contextId, task, relatedTasks); + } + } + +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/agentexecution/SimpleRequestContextBuilder.java b/sdk-server-common/src/main/java/io/a2a/server/agentexecution/SimpleRequestContextBuilder.java new file mode 100644 index 000000000..933258a45 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/agentexecution/SimpleRequestContextBuilder.java @@ -0,0 +1,35 @@ +package io.a2a.server.agentexecution; + +import java.util.ArrayList; +import java.util.List; + +import io.a2a.server.tasks.TaskStore; +import io.a2a.spec.Task; + +public class SimpleRequestContextBuilder extends RequestContext.Builder { + private final TaskStore taskStore; + private final boolean shouldPopulateReferredTasks; + + public SimpleRequestContextBuilder(TaskStore taskStore, boolean shouldPopulateReferredTasks) { + this.taskStore = taskStore; + this.shouldPopulateReferredTasks = shouldPopulateReferredTasks; + } + + @Override + public RequestContext build() { + List relatedTasks = null; + if (taskStore != null && shouldPopulateReferredTasks && getParams() != null + && getParams().message().getReferenceTaskIds() != null) { + relatedTasks = new ArrayList<>(); + for (String taskId : getParams().message().getReferenceTaskIds()) { + Task task = taskStore.get(taskId); + if (task != null) { + relatedTasks.add(task); + } + } + } + + super.setRelatedTasks(relatedTasks); + return super.build(); + } +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/auth/UnauthenticatedUser.java b/sdk-server-common/src/main/java/io/a2a/server/auth/UnauthenticatedUser.java new file mode 100644 index 000000000..e4ec7a69c --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/auth/UnauthenticatedUser.java @@ -0,0 +1,13 @@ +package io.a2a.server.auth; + +public class UnauthenticatedUser implements User { + @Override + public boolean isAuthenticated() { + return false; + } + + @Override + public String getUsername() { + return ""; + } +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/auth/User.java b/sdk-server-common/src/main/java/io/a2a/server/auth/User.java new file mode 100644 index 000000000..f41e98444 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/auth/User.java @@ -0,0 +1,6 @@ +package io.a2a.server.auth; + +public interface User { + boolean isAuthenticated(); + String getUsername(); +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/events/EnhancedRunnable.java b/sdk-server-common/src/main/java/io/a2a/server/events/EnhancedRunnable.java new file mode 100644 index 000000000..d5316d999 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/events/EnhancedRunnable.java @@ -0,0 +1,35 @@ +package io.a2a.server.events; + +import java.util.ArrayList; +import java.util.List; + +public abstract class EnhancedRunnable implements Runnable { + private volatile Throwable error; + private final List doneCallbacks = new ArrayList<>(); + + public Throwable getError() { + return error; + } + + public void setError(Throwable error) { + this.error = error; + } + + public void addDoneCallback(DoneCallback doneCallback) { + synchronized (doneCallbacks) { + doneCallbacks.add(doneCallback); + } + } + + public void invokeDoneCallbacks() { + synchronized (doneCallbacks) { + for (DoneCallback doneCallback : doneCallbacks) { + doneCallback.done(this); + } + } + } + + public interface DoneCallback { + void done(EnhancedRunnable agentRunnable); + } +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/events/EventConsumer.java b/sdk-server-common/src/main/java/io/a2a/server/events/EventConsumer.java new file mode 100644 index 000000000..ca38f43b7 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/events/EventConsumer.java @@ -0,0 +1,111 @@ +package io.a2a.server.events; + + +import java.util.concurrent.Flow; + +import io.a2a.spec.A2AServerException; +import io.a2a.spec.Event; +import io.a2a.spec.Message; +import io.a2a.spec.Task; +import io.a2a.spec.TaskStatusUpdateEvent; +import mutiny.zero.BackpressureStrategy; +import mutiny.zero.TubeConfiguration; +import mutiny.zero.ZeroPublisher; + +public class EventConsumer { + private final EventQueue queue; + private Throwable error; + + private static final String ERROR_MSG = "Agent did not return any response"; + private static final int NO_WAIT = -1; + private static final int QUEUE_WAIT_MILLISECONDS = 500; + + public EventConsumer(EventQueue queue) { + this.queue = queue; + } + + public Event consumeOne() throws A2AServerException, EventQueueClosedException { + Event event = queue.dequeueEvent(NO_WAIT); + if (event == null) { + throw new A2AServerException(ERROR_MSG, new InternalError(ERROR_MSG)); + } + return event; + } + + public Flow.Publisher consumeAll() { + TubeConfiguration conf = new TubeConfiguration() + .withBackpressureStrategy(BackpressureStrategy.BUFFER) + .withBufferSize(256); + return ZeroPublisher.create(conf, tube -> { + boolean completed = false; + try { + while (true) { + if (error != null) { + completed = true; + tube.fail(error); + return; + } + // We use a timeout when waiting for an event from the queue. + // This is required because it allows the loop to check if + // `self._exception` has been set by the `agent_task_callback`. + // Without the timeout, loop might hang indefinitely if no events are + // enqueued by the agent and the agent simply threw an exception + + // TODO the callback mentioned above seems unused in the Python 0.2.1 tag + Event event; + try { + event = queue.dequeueEvent(QUEUE_WAIT_MILLISECONDS); + if (event == null) { + continue; + } + if (event instanceof Throwable thr) { + tube.fail(thr); + return; + } + tube.send(event); + } catch (EventQueueClosedException e) { + completed = true; + tube.complete(); + return; + } catch (Exception e) { + // Continue polling until there is a final event + continue; + } + + boolean isFinalEvent = false; + if (event instanceof TaskStatusUpdateEvent tue && tue.isFinal()) { + isFinalEvent = true; + } else if (event instanceof Message) { + isFinalEvent = true; + } else if (event instanceof Task task) { + switch (task.getStatus().state()) { + case COMPLETED: + case CANCELED: + case FAILED: + case REJECTED: + case UNKNOWN: + isFinalEvent = true; + } + } + + if (isFinalEvent) { + queue.close(); + break; + } + } + } finally { + if (!completed) { + tube.complete(); + } + } + }); + } + + public EnhancedRunnable.DoneCallback createAgentRunnableDoneCallback() { + return agentRunnable -> { + if (agentRunnable.getError() != null) { + error = agentRunnable.getError(); + } + }; + } +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/events/EventQueue.java b/sdk-server-common/src/main/java/io/a2a/server/events/EventQueue.java new file mode 100644 index 000000000..585210c3e --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/events/EventQueue.java @@ -0,0 +1,185 @@ +package io.a2a.server.events; + +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import io.a2a.server.util.TempLoggerWrapper; +import io.a2a.spec.Event; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public abstract class EventQueue { + + private static final Logger log = new TempLoggerWrapper(LoggerFactory.getLogger(EventQueue.class)); + + private final EventQueue parent; + // TODO decide on a capacity (or more appropriate queue data structures) + private final BlockingQueue queue = new ArrayBlockingQueue(1000); + private volatile boolean closed = false; + + + + protected EventQueue() { + this(null); + } + + protected EventQueue(EventQueue parent) { + log.trace("Creating {}, parent: {}", this, parent); + this.parent = parent; + } + + public static EventQueue create() { + return new MainQueue(); + } + + public abstract void awaitQueuePollerStart() throws InterruptedException ; + + abstract void signalQueuePollerStarted(); + + public void enqueueEvent(Event event) { + if (closed) { + log.warn("Queue is closed. Event will not be enqueued. {} {}", this, event); + return; + } + // Call toString() since for errors we don't really want the full stacktrace + queue.add(event); + log.debug("Enqueued event {} {}", event instanceof Throwable ? event.toString() : event, this); + } + + abstract EventQueue tap(); + + public Event dequeueEvent(int waitMilliSeconds) throws EventQueueClosedException { + if (closed && queue.isEmpty()) { + log.debug("Queue is closed, and empty. Sending termination message. {}", this); + throw new EventQueueClosedException(); + } + try { + if (waitMilliSeconds <= 0) { + Event event = queue.poll(); + if (event != null) { + // Call toString() since for errors we don't really want the full stacktrace + log.debug("Dequeued event (no wait) {} {}", this, event instanceof Throwable ? event.toString() : event); + } + return event; + } + try { + Event event = queue.poll(waitMilliSeconds, TimeUnit.MILLISECONDS); + if (event != null) { + // Call toString() since for errors we don't really want the full stacktrace + log.debug("Dequeued event (waiting) {} {}", this, event instanceof Throwable ? event.toString() : event); + } + return event; + } catch (InterruptedException e) { + log.debug("Interrupted dequeue (waiting) {}", this); + Thread.currentThread().interrupt(); + return null; + } + } finally { + signalQueuePollerStarted(); + } + } + + public void taskDone() { + // TODO Not sure if needed yet. BlockingQueue.poll()/.take() remove the events. + } + + public abstract void close(); + + public void doClose() { + synchronized (this) { + if (closed) { + return; + } + log.debug("Closing {}", this); + closed = true; + } + // Although the Python implementation drains the queue on closing, + // here it makes events go missing + // TODO do we actually need to drain it? If we do, we need some mechanism to determine that noone is + // polling any longer and drain it asynchronously once it is all done. That could perhaps be done + // via an EnhancedRunnable.DoneCallback. + //queue.drainTo(new ArrayList<>()); + } + + static class MainQueue extends EventQueue { + private final List children = new CopyOnWriteArrayList<>(); + private final CountDownLatch pollingStartedLatch = new CountDownLatch(1); + private final AtomicBoolean pollingStarted = new AtomicBoolean(false); + + EventQueue tap() { + ChildQueue child = new ChildQueue(this); + children.add(child); + return child; + } + + public void enqueueEvent(Event event) { + super.enqueueEvent(event); + children.forEach(eq -> eq.internalEnqueueEvent(event)); + } + + @Override + public void awaitQueuePollerStart() throws InterruptedException { + log.debug("Waiting for queue poller to start on {}", this); + pollingStartedLatch.await(10, TimeUnit.SECONDS); + log.debug("Queue poller started on {}", this); + } + + @Override + void signalQueuePollerStarted() { + if (pollingStarted.get()) { + return; + } + log.debug("Signalling that queue polling started {}", this); + pollingStartedLatch.countDown(); + pollingStarted.set(true); + } + + @Override + public void close() { + doClose(); + children.forEach(EventQueue::doClose); + } + } + + static class ChildQueue extends EventQueue { + private final MainQueue parent; + + public ChildQueue(MainQueue parent) { + this.parent = parent; + } + + @Override + public void enqueueEvent(Event event) { + parent.enqueueEvent(event); + } + + private void internalEnqueueEvent(Event event) { + super.enqueueEvent(event); + } + + @Override + EventQueue tap() { + throw new IllegalStateException("Can only tap the main queue"); + } + + @Override + public void awaitQueuePollerStart() throws InterruptedException { + parent.awaitQueuePollerStart(); + } + + @Override + void signalQueuePollerStarted() { + parent.signalQueuePollerStarted(); + } + + @Override + public void close() { + parent.close(); + } + } +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/events/EventQueueClosedException.java b/sdk-server-common/src/main/java/io/a2a/server/events/EventQueueClosedException.java new file mode 100644 index 000000000..f8feaa824 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/events/EventQueueClosedException.java @@ -0,0 +1,4 @@ +package io.a2a.server.events; + +public class EventQueueClosedException extends Exception { +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java b/sdk-server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java new file mode 100644 index 000000000..72bb81eb3 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java @@ -0,0 +1,66 @@ +package io.a2a.server.events; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import jakarta.enterprise.context.ApplicationScoped; + +@ApplicationScoped +public class InMemoryQueueManager implements QueueManager { + private final Map queues = Collections.synchronizedMap(new HashMap<>()); + + @Override + public void add(String taskId, EventQueue queue) { + synchronized (queues) { + if (queues.containsKey(taskId)) { + throw new TaskQueueExistsException(); + } + queues.put(taskId, queue); + } + } + + @Override + public EventQueue get(String taskId) { + return queues.get(taskId); + } + + @Override + public EventQueue tap(String taskId) { + synchronized (taskId) { + EventQueue queue = queues.get(taskId); + if (queue == null) { + return queue; + } + return queue.tap(); + } + } + + @Override + public void close(String taskId) { + synchronized (queues) { + EventQueue existing = queues.remove(taskId); + if (existing == null) { + throw new NoTaskQueueException(); + } + } + } + + @Override + public EventQueue createOrTap(String taskId) { + synchronized (queues) { + EventQueue queue = queues.get(taskId); + if (queue != null) { + return queue.tap(); + } + queue = EventQueue.create(); + queues.put(taskId, queue); + return queue; + } + } + + @Override + public void awaitQueuePollerStart(EventQueue eventQueue) throws InterruptedException { + eventQueue.awaitQueuePollerStart(); + } +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/events/NoTaskQueueException.java b/sdk-server-common/src/main/java/io/a2a/server/events/NoTaskQueueException.java new file mode 100644 index 000000000..d97c3a261 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/events/NoTaskQueueException.java @@ -0,0 +1,22 @@ +package io.a2a.server.events; + +public class NoTaskQueueException extends RuntimeException { + public NoTaskQueueException() { + } + + public NoTaskQueueException(String message) { + super(message); + } + + public NoTaskQueueException(String message, Throwable cause) { + super(message, cause); + } + + public NoTaskQueueException(Throwable cause) { + super(cause); + } + + public NoTaskQueueException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) { + super(message, cause, enableSuppression, writableStackTrace); + } +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/events/QueueManager.java b/sdk-server-common/src/main/java/io/a2a/server/events/QueueManager.java new file mode 100644 index 000000000..ddba13cb8 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/events/QueueManager.java @@ -0,0 +1,15 @@ +package io.a2a.server.events; + +public interface QueueManager { + void add(String taskId, EventQueue queue); + + EventQueue get(String taskId); + + EventQueue tap(String taskId); + + void close(String taskId); + + EventQueue createOrTap(String taskId); + + void awaitQueuePollerStart(EventQueue eventQueue) throws InterruptedException; +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/events/TaskQueueExistsException.java b/sdk-server-common/src/main/java/io/a2a/server/events/TaskQueueExistsException.java new file mode 100644 index 000000000..bfd429277 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/events/TaskQueueExistsException.java @@ -0,0 +1,22 @@ +package io.a2a.server.events; + +public class TaskQueueExistsException extends RuntimeException { + public TaskQueueExistsException() { + } + + public TaskQueueExistsException(String message) { + super(message); + } + + public TaskQueueExistsException(String message, Throwable cause) { + super(message, cause); + } + + public TaskQueueExistsException(Throwable cause) { + super(cause); + } + + public TaskQueueExistsException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) { + super(message, cause, enableSuppression, writableStackTrace); + } +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java b/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java new file mode 100644 index 000000000..501e71315 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java @@ -0,0 +1,389 @@ +package io.a2a.server.requesthandlers; + +import static io.a2a.server.util.async.AsyncUtils.convertingProcessor; +import static io.a2a.server.util.async.AsyncUtils.createTubeConfig; +import static io.a2a.server.util.async.AsyncUtils.processor; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +import io.a2a.server.agentexecution.AgentExecutor; +import io.a2a.server.agentexecution.RequestContext; +import io.a2a.server.agentexecution.SimpleRequestContextBuilder; +import io.a2a.server.events.EnhancedRunnable; +import io.a2a.server.events.EventConsumer; +import io.a2a.server.events.EventQueue; +import io.a2a.server.events.QueueManager; +import io.a2a.server.events.TaskQueueExistsException; +import io.a2a.server.tasks.PushNotifier; +import io.a2a.server.tasks.ResultAggregator; +import io.a2a.server.tasks.TaskManager; +import io.a2a.server.tasks.TaskStore; +import io.a2a.server.util.TempLoggerWrapper; +import io.a2a.server.util.async.Internal; +import io.a2a.spec.Event; +import io.a2a.spec.EventKind; +import io.a2a.spec.InternalError; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.Message; +import io.a2a.spec.MessageSendParams; +import io.a2a.spec.PushNotificationConfig; +import io.a2a.spec.StreamingEventKind; +import io.a2a.spec.Task; +import io.a2a.spec.TaskIdParams; +import io.a2a.spec.TaskNotFoundError; +import io.a2a.spec.TaskPushNotificationConfig; +import io.a2a.spec.TaskQueryParams; +import io.a2a.spec.UnsupportedOperationError; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@ApplicationScoped +public class DefaultRequestHandler implements RequestHandler { + + private static final Logger log = new TempLoggerWrapper(LoggerFactory.getLogger(DefaultRequestHandler.class)); + + private final AgentExecutor agentExecutor; + private final TaskStore taskStore; + private final QueueManager queueManager; + private final PushNotifier pushNotifier; + private final Supplier requestContextBuilder; + + private final Map> runningAgents = Collections.synchronizedMap(new HashMap<>()); + + private final Executor executor; + + @Inject + public DefaultRequestHandler(AgentExecutor agentExecutor, TaskStore taskStore, + QueueManager queueManager, PushNotifier pushNotifier, @Internal Executor executor) { + this.agentExecutor = agentExecutor; + this.taskStore = taskStore; + this.queueManager = queueManager; + this.pushNotifier = pushNotifier; + this.executor = executor; + // TODO In Python this is also a constructor parameter defaulting to this SimpleRequestContextBuilder + // implementation if the parameter is null. Skip that for now, since otherwise I get CDI errors, and + // I am unsure about the correct scope. + // Also reworked to make a Supplier since otherwise the builder gets polluted with wrong tasks + this.requestContextBuilder = () -> new SimpleRequestContextBuilder(taskStore, false); + } + + @Override + public Task onGetTask(TaskQueryParams params) throws JSONRPCError { + log.debug("onGetTask {}", params.id()); + Task task = taskStore.get(params.id()); + if (task == null) { + log.debug("No task found for {}. Throwing TaskNotFoundError", params.id()); + throw new TaskNotFoundError(); + } + if (params.historyLength() != null && task.getHistory() != null && params.historyLength() < task.getHistory().size()) { + List history; + if (params.historyLength() <= 0) { + history = new ArrayList<>(); + } else { + history = task.getHistory().subList( + task.getHistory().size() - params.historyLength(), + task.getHistory().size() - 1); + } + + task = new Task.Builder(task) + .history(history) + .build(); + } + + log.debug("Task found {}", task); + return task; + } + + @Override + public Task onCancelTask(TaskIdParams params) throws JSONRPCError { + Task task = taskStore.get(params.id()); + if (task == null) { + throw new TaskNotFoundError(); + } + TaskManager taskManager = new TaskManager( + task.getId(), + task.getContextId(), + taskStore, + null); + + ResultAggregator resultAggregator = new ResultAggregator(taskManager, null); + + EventQueue queue = queueManager.tap(task.getId()); + if (queue == null) { + queue = EventQueue.create(); + } + agentExecutor.cancel( + requestContextBuilder.get() + .setTaskId(task.getId()) + .setContextId(task.getContextId()) + .setTask(task) + .build(), + queue); + + CompletableFuture cf = runningAgents.get(task.getId()); + if (cf != null) { + cf.cancel(true); + } + + EventConsumer consumer = new EventConsumer(queue); + EventKind type = resultAggregator.consumeAll(consumer); + if (type instanceof Task tempTask) { + return tempTask; + } + + throw new InternalError("Agent did not return a valid response"); + } + + @Override + public EventKind onMessageSend(MessageSendParams params) throws JSONRPCError { + log.debug("onMessageSend - task: {}; context {}", params.message().getTaskId(), params.message().getContextId()); + TaskManager taskManager = new TaskManager( + params.message().getTaskId(), + params.message().getContextId(), + taskStore, + params.message()); + + Task task = taskManager.getTask(); + if (task != null) { + log.debug("Found task updating with message {}", params.message()); + task = taskManager.updateWithMessage(params.message(), task); + + if (shouldAddPushInfo(params)) { + log.debug("Adding push info"); + pushNotifier.setInfo(task.getId(), params.configuration().pushNotification()); + } + } + + RequestContext requestContext = requestContextBuilder.get() + .setParams(params) + .setTaskId(task == null ? null : task.getId()) + .setContextId(params.message().getContextId()) + .setTask(task) + .build(); + + String taskId = requestContext.getTaskId(); + log.debug("Request context taskId: {}", taskId); + + EventQueue queue = queueManager.createOrTap(taskId); + ResultAggregator resultAggregator = new ResultAggregator(taskManager, null); + + EnhancedRunnable producerRunnable = registerAndExecuteAgentAsync(taskId, requestContext, queue); + + EventConsumer consumer = new EventConsumer(queue); + + producerRunnable.addDoneCallback(consumer.createAgentRunnableDoneCallback()); + + boolean interrupted = false; + ResultAggregator.EventTypeAndInterrupt etai = resultAggregator.consumeAndBreakOnInterrupt(consumer); + + try { + if (etai == null) { + log.debug("No result, throwing InternalError"); + throw new InternalError("No result"); + } + interrupted = etai.interrupted(); + log.debug("Was interrupted: {}", interrupted); + + EventKind kind = etai.eventType(); + if (kind instanceof Task taskResult && !taskId.equals(taskResult.getId())) { + throw new InternalError("Task ID mismatch in agent response"); + } + + } finally { + if (interrupted) { + // TODO Make this async + cleanupProducer(taskId); + } else { + cleanupProducer(taskId); + } + } + + log.debug("Returning: {}", etai.eventType()); + return etai.eventType(); + } + + @Override + public Flow.Publisher onMessageSendStream(MessageSendParams params) throws JSONRPCError { + TaskManager taskManager = new TaskManager( + params.message().getTaskId(), + params.message().getContextId(), + taskStore, + params.message()); + + Task task = taskManager.getTask(); + if (task != null) { + task = taskManager.updateWithMessage(params.message(), task); + + if (shouldAddPushInfo(params)) { + pushNotifier.setInfo(task.getId(), params.configuration().pushNotification()); + } + } + + RequestContext requestContext = requestContextBuilder.get() + .setParams(params) + .setTaskId(task == null ? null : task.getId()) + .setContextId(params.message().getContextId()) + .setTask(task) + .build(); + + AtomicReference taskId = new AtomicReference<>(requestContext.getTaskId()); + EventQueue queue = queueManager.createOrTap(taskId.get()); + ResultAggregator resultAggregator = new ResultAggregator(taskManager, null); + + EnhancedRunnable producerRunnable = registerAndExecuteAgentAsync(taskId.get(), requestContext, queue); + + EventConsumer consumer = new EventConsumer(queue); + + producerRunnable.addDoneCallback(consumer.createAgentRunnableDoneCallback()); + + try { + Flow.Publisher results = resultAggregator.consumeAndEmit(consumer); + + Flow.Publisher eventPublisher = + processor(createTubeConfig(), results, ((errorConsumer, event) -> { + if (event instanceof Task createdTask) { + if (!Objects.equals(taskId.get(), createdTask.getId())) { + errorConsumer.accept(new InternalError("Task ID mismatch in agent response")); + } + + // TODO the Python implementation no longer has the following block but removing it causes + // failures here + try { + queueManager.add(createdTask.getId(), queue); + taskId.set(createdTask.getId()); + } catch (TaskQueueExistsException e) { + // TODO Log + } + if (pushNotifier != null && + params.configuration() != null && + params.configuration().pushNotification() != null) { + + pushNotifier.setInfo( + createdTask.getId(), + params.configuration().pushNotification()); + } + + } + if (pushNotifier != null && taskId.get() != null) { + EventKind latest = resultAggregator.getCurrentResult(); + if (latest instanceof Task latestTask) { + pushNotifier.sendNotification(latestTask); + } + } + + return true; + })); + + return convertingProcessor(eventPublisher, event -> (StreamingEventKind) event); + } finally { + cleanupProducer(taskId.get()); + } + } + + @Override + public TaskPushNotificationConfig onSetTaskPushNotificationConfig(TaskPushNotificationConfig params) throws JSONRPCError { + if (pushNotifier == null) { + throw new UnsupportedOperationError(); + } + Task task = taskStore.get(params.taskId()); + if (task == null) { + throw new TaskNotFoundError(); + } + + pushNotifier.setInfo(params.taskId(), params.pushNotificationConfig()); + + return params; + } + + @Override + public TaskPushNotificationConfig onGetTaskPushNotificationConfig(TaskIdParams params) throws JSONRPCError { + if (pushNotifier == null) { + throw new UnsupportedOperationError(); + } + Task task = taskStore.get(params.id()); + if (task == null) { + throw new TaskNotFoundError(); + } + + PushNotificationConfig pushNotificationConfig = pushNotifier.getInfo(params.id()); + if (pushNotificationConfig == null) { + throw new InternalError("No push notification config found"); + } + + return new TaskPushNotificationConfig(params.id(), pushNotificationConfig); + } + + @Override + public Flow.Publisher onResubscribeToTask(TaskIdParams params) throws JSONRPCError { + Task task = taskStore.get(params.id()); + if (task == null) { + throw new TaskNotFoundError(); + } + + TaskManager taskManager = new TaskManager(task.getId(), task.getContextId(), taskStore, null); + ResultAggregator resultAggregator = new ResultAggregator(taskManager, null); + EventQueue queue = queueManager.tap(task.getId()); + + if (queue == null) { + throw new TaskNotFoundError(); + } + + EventConsumer consumer = new EventConsumer(queue); + Flow.Publisher results = resultAggregator.consumeAndEmit(consumer); + return convertingProcessor(results, e -> (StreamingEventKind) e); + } + + private boolean shouldAddPushInfo(MessageSendParams params) { + return pushNotifier != null && params.configuration() != null && params.configuration().pushNotification() != null; + } + + private EnhancedRunnable registerAndExecuteAgentAsync(String taskId, RequestContext requestContext, EventQueue queue) { + EnhancedRunnable runnable = new EnhancedRunnable() { + @Override + public void run() { + agentExecutor.execute(requestContext, queue); + try { + queueManager.awaitQueuePollerStart(queue); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + }; + + CompletableFuture cf = CompletableFuture.runAsync(runnable, executor) + .whenComplete((v, err) -> { + if (err != null) { + runnable.setError(err); + } + queue.close(); + runnable.invokeDoneCallbacks(); + }); + runningAgents.put(taskId, cf); + return runnable; + } + + private void cleanupProducer(String taskId) { + // TODO the Python implementation waits for the producerRunnable + CompletableFuture cf = runningAgents.get(taskId); + if (cf != null) { + cf.whenComplete((v, t) -> { + queueManager.close(taskId); + runningAgents.remove(taskId); + }); + } + } + +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/JSONRPCHandler.java b/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/JSONRPCHandler.java new file mode 100644 index 000000000..131b07fc8 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/JSONRPCHandler.java @@ -0,0 +1,194 @@ +package io.a2a.server.requesthandlers; + +import static io.a2a.server.util.async.AsyncUtils.createTubeConfig; + +import java.util.concurrent.Flow; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +import io.a2a.server.PublicAgentCard; +import io.a2a.spec.AgentCard; +import io.a2a.spec.CancelTaskRequest; +import io.a2a.spec.CancelTaskResponse; +import io.a2a.spec.EventKind; +import io.a2a.spec.GetTaskPushNotificationConfigRequest; +import io.a2a.spec.GetTaskPushNotificationConfigResponse; +import io.a2a.spec.GetTaskRequest; +import io.a2a.spec.GetTaskResponse; +import io.a2a.spec.InternalError; +import io.a2a.spec.InvalidRequestError; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.SendMessageRequest; +import io.a2a.spec.SendMessageResponse; +import io.a2a.spec.SendStreamingMessageRequest; +import io.a2a.spec.SendStreamingMessageResponse; +import io.a2a.spec.SetTaskPushNotificationConfigRequest; +import io.a2a.spec.SetTaskPushNotificationConfigResponse; +import io.a2a.spec.StreamingEventKind; +import io.a2a.spec.Task; +import io.a2a.spec.TaskNotFoundError; +import io.a2a.spec.TaskPushNotificationConfig; +import io.a2a.spec.TaskResubscriptionRequest; +import mutiny.zero.ZeroPublisher; + +@ApplicationScoped +public class JSONRPCHandler { + + private AgentCard agentCard; + private RequestHandler requestHandler; + + @Inject + public JSONRPCHandler(@PublicAgentCard AgentCard agentCard, RequestHandler requestHandler) { + this.agentCard = agentCard; + this.requestHandler = requestHandler; + } + + public SendMessageResponse onMessageSend(SendMessageRequest request) { + try { + EventKind taskOrMessage = requestHandler.onMessageSend(request.getParams()); + return new SendMessageResponse(request.getId(), taskOrMessage); + } catch (JSONRPCError e) { + return new SendMessageResponse(request.getId(), e); + } catch (Throwable t) { + return new SendMessageResponse(request.getId(), new InternalError(t.getMessage())); + } + } + + + public Flow.Publisher onMessageSendStream(SendStreamingMessageRequest request) { + if (!agentCard.capabilities().streaming()) { + return ZeroPublisher.fromItems( + new SendStreamingMessageResponse( + request.getId(), + new InvalidRequestError("Streaming is not supported by the agent"))); + } + + try { + Flow.Publisher publisher = requestHandler.onMessageSendStream(request.getParams()); + // We can't use the convertingProcessor convenience method since that propagates any errors as an error handled + // via Subscriber.onError() rather than as part of the SendStreamingResponse payload + return convertToSendStreamingMessageResponse(request.getId(), publisher); + } catch (JSONRPCError e) { + return ZeroPublisher.fromItems(new SendStreamingMessageResponse(request.getId(), e)); + } catch (Throwable throwable) { + return ZeroPublisher.fromItems(new SendStreamingMessageResponse(request.getId(), new InternalError(throwable.getMessage()))); + } + } + + public CancelTaskResponse onCancelTask(CancelTaskRequest request) { + try { + Task task = requestHandler.onCancelTask(request.getParams()); + if (task != null) { + return new CancelTaskResponse(request.getId(), task); + } + return new CancelTaskResponse(request.getId(), new TaskNotFoundError()); + } catch (JSONRPCError e) { + return new CancelTaskResponse(request.getId(), e); + } catch (Throwable t) { + return new CancelTaskResponse(request.getId(), new InternalError(t.getMessage())); + } + } + + public Flow.Publisher onResubscribeToTask(TaskResubscriptionRequest request) { + if (!agentCard.capabilities().streaming()) { + return ZeroPublisher.fromItems( + new SendStreamingMessageResponse( + request.getId(), + new InvalidRequestError("Streaming is not supported by the agent"))); + } + + try { + Flow.Publisher publisher = requestHandler.onResubscribeToTask(request.getParams()); + // We can't use the convertingProcessor convenience method since that propagates any errors as an error handled + // via Subscriber.onError() rather than as part of the SendStreamingResponse payload + return convertToSendStreamingMessageResponse(request.getId(), publisher); + } catch (JSONRPCError e) { + return ZeroPublisher.fromItems(new SendStreamingMessageResponse(request.getId(), e)); + } catch (Throwable throwable) { + return ZeroPublisher.fromItems(new SendStreamingMessageResponse(request.getId(), new InternalError(throwable.getMessage()))); + } + } + + public GetTaskPushNotificationConfigResponse getPushNotification(GetTaskPushNotificationConfigRequest request) { + try { + TaskPushNotificationConfig config = requestHandler.onGetTaskPushNotificationConfig(request.getParams()); + return new GetTaskPushNotificationConfigResponse(request.getId(), config); + } catch (JSONRPCError e) { + return new GetTaskPushNotificationConfigResponse(request.getId().toString(), e); + } catch (Throwable t) { + return new GetTaskPushNotificationConfigResponse(request.getId(), new InternalError(t.getMessage())); + } + } + + public SetTaskPushNotificationConfigResponse setPushNotification(SetTaskPushNotificationConfigRequest request) { + if (!agentCard.capabilities().pushNotifications()) { + return new SetTaskPushNotificationConfigResponse(request.getId(), + new InvalidRequestError("Push notifications are not supported by the agent")); + } + try { + TaskPushNotificationConfig config = requestHandler.onSetTaskPushNotificationConfig(request.getParams()); + return new SetTaskPushNotificationConfigResponse(request.getId().toString(), config); + } catch (JSONRPCError e) { + return new SetTaskPushNotificationConfigResponse(request.getId(), e); + } catch (Throwable t) { + return new SetTaskPushNotificationConfigResponse(request.getId(), new InternalError(t.getMessage())); + } + } + + public GetTaskResponse onGetTask(GetTaskRequest request) { + try { + Task task = requestHandler.onGetTask(request.getParams()); + return new GetTaskResponse(request.getId(), task); + } catch (JSONRPCError e) { + return new GetTaskResponse(request.getId(), e); + } catch (Throwable t) { + return new GetTaskResponse(request.getId(), new InternalError(t.getMessage())); + } + } + + public AgentCard getAgentCard() { + return agentCard; + } + + private Flow.Publisher convertToSendStreamingMessageResponse( + Object requestId, + Flow.Publisher publisher) { + // We can't use the normal convertingProcessor since that propagates any errors as an error handled + // via Subscriber.onError() rather than as part of the SendStreamingResponse payload + return ZeroPublisher.create(createTubeConfig(), tube -> { + publisher.subscribe(new Flow.Subscriber() { + Flow.Subscription subscription; + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(StreamingEventKind item) { + tube.send(new SendStreamingMessageResponse(requestId, item)); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + if (throwable instanceof JSONRPCError jsonrpcError) { + tube.send(new SendStreamingMessageResponse(requestId, jsonrpcError)); + } else { + tube.send( + new SendStreamingMessageResponse( + requestId, new + InternalError(throwable.getMessage()))); + } + onComplete(); + } + + @Override + public void onComplete() { + tube.complete(); + } + }); + }); + } +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/RequestHandler.java b/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/RequestHandler.java new file mode 100644 index 000000000..e2902bca0 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/RequestHandler.java @@ -0,0 +1,28 @@ +package io.a2a.server.requesthandlers; + +import java.util.concurrent.Flow; + +import io.a2a.spec.EventKind; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.MessageSendParams; +import io.a2a.spec.StreamingEventKind; +import io.a2a.spec.Task; +import io.a2a.spec.TaskIdParams; +import io.a2a.spec.TaskPushNotificationConfig; +import io.a2a.spec.TaskQueryParams; + +public interface RequestHandler { + Task onGetTask(TaskQueryParams params) throws JSONRPCError; + + Task onCancelTask(TaskIdParams params) throws JSONRPCError; + + EventKind onMessageSend(MessageSendParams params) throws JSONRPCError; + + Flow.Publisher onMessageSendStream(MessageSendParams params) throws JSONRPCError; + + TaskPushNotificationConfig onSetTaskPushNotificationConfig(TaskPushNotificationConfig params) throws JSONRPCError; + + TaskPushNotificationConfig onGetTaskPushNotificationConfig(TaskIdParams params) throws JSONRPCError; + + Flow.Publisher onResubscribeToTask(TaskIdParams params) throws JSONRPCError; +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotifier.java b/sdk-server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotifier.java new file mode 100644 index 000000000..6fb1fb39a --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotifier.java @@ -0,0 +1,78 @@ +package io.a2a.server.tasks; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +import com.fasterxml.jackson.core.JsonProcessingException; +import io.a2a.http.A2AHttpClient; +import io.a2a.http.JdkA2AHttpClient; +import io.a2a.spec.PushNotificationConfig; +import io.a2a.spec.Task; +import io.a2a.util.Utils; + +@ApplicationScoped +public class InMemoryPushNotifier implements PushNotifier { + private final A2AHttpClient httpClient; + private final Map pushNotificationInfos = Collections.synchronizedMap(new HashMap<>()); + + @Inject + public InMemoryPushNotifier() { + this.httpClient = new JdkA2AHttpClient(); + } + + public InMemoryPushNotifier(A2AHttpClient httpClient) { + this.httpClient = httpClient; + } + + @Override + public void setInfo(String taskId, PushNotificationConfig notificationConfig) { + pushNotificationInfos.put(taskId, notificationConfig); + } + + @Override + public PushNotificationConfig getInfo(String taskId) { + return pushNotificationInfos.get(taskId); + } + + @Override + public void deleteInfo(String taskId) { + pushNotificationInfos.remove(taskId); + } + + @Override + public void sendNotification(Task task) { + PushNotificationConfig pushInfo = pushNotificationInfos.get(task.getId()); + if (pushInfo == null) { + return; + } + String url = pushInfo.url(); + + // TODO auth + + String body; + try { + body = Utils.OBJECT_MAPPER.writeValueAsString(task); + } catch (JsonProcessingException e) { + e.printStackTrace(); + throw new RuntimeException("Error writing value as string: " + e.getMessage(), e); + } catch (Throwable throwable) { + throwable.printStackTrace(); + throw new RuntimeException("Error writing value as string: " + throwable.getMessage(), throwable); + } + + try { + httpClient.createPost() + .url(url) + .body(body) + .post(); + } catch (IOException | InterruptedException e) { + throw new RuntimeException("Error pushing data to " + url + ": " + e.getMessage(), e); + } + + } +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/tasks/InMemoryTaskStore.java b/sdk-server-common/src/main/java/io/a2a/server/tasks/InMemoryTaskStore.java new file mode 100644 index 000000000..4edfe21e4 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/tasks/InMemoryTaskStore.java @@ -0,0 +1,30 @@ +package io.a2a.server.tasks; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import jakarta.enterprise.context.ApplicationScoped; + +import io.a2a.spec.Task; + +@ApplicationScoped +public class InMemoryTaskStore implements TaskStore { + + private final Map tasks = Collections.synchronizedMap(new HashMap<>()); + + @Override + public void save(Task task) { + tasks.put(task.getId(), task); + } + + @Override + public Task get(String taskId) { + return tasks.get(taskId); + } + + @Override + public void delete(String taskId) { + tasks.remove(taskId); + } +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/tasks/PushNotifier.java b/sdk-server-common/src/main/java/io/a2a/server/tasks/PushNotifier.java new file mode 100644 index 000000000..2dfc7dff0 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/tasks/PushNotifier.java @@ -0,0 +1,14 @@ +package io.a2a.server.tasks; + +import io.a2a.spec.PushNotificationConfig; +import io.a2a.spec.Task; + +public interface PushNotifier { + void setInfo(String taskId, PushNotificationConfig notificationConfig); + + PushNotificationConfig getInfo(String taskId); + + void deleteInfo(String taskId); + + void sendNotification(Task task); +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java b/sdk-server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java new file mode 100644 index 000000000..b9aef07ff --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java @@ -0,0 +1,151 @@ +package io.a2a.server.tasks; + +import static io.a2a.server.util.async.AsyncUtils.consumer; +import static io.a2a.server.util.async.AsyncUtils.createTubeConfig; +import static io.a2a.server.util.async.AsyncUtils.processor; + +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import io.a2a.server.events.EventConsumer; +import io.a2a.server.util.TempLoggerWrapper; +import io.a2a.spec.A2AServerException; +import io.a2a.spec.Event; +import io.a2a.spec.EventKind; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.Message; +import io.a2a.spec.Task; +import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatusUpdateEvent; +import io.a2a.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ResultAggregator { + private static final Logger log = new TempLoggerWrapper(LoggerFactory.getLogger(ResultAggregator.class)); + + private final TaskManager taskManager; + private volatile Message message; + + public ResultAggregator(TaskManager taskManager, Message message) { + this.taskManager = taskManager; + this.message = message; + } + + public EventKind getCurrentResult() { + if (message != null) { + return message; + } + return taskManager.getTask(); + } + + public Flow.Publisher consumeAndEmit(EventConsumer consumer) { + Flow.Publisher all = consumer.consumeAll(); + + return processor(createTubeConfig(), all, ((errorConsumer, event) -> { + callTaskManagerProcess(event); + return true; + })); + } + + public EventKind consumeAll(EventConsumer consumer) { + AtomicReference returnedEvent = new AtomicReference<>(); + Flow.Publisher all = consumer.consumeAll(); + AtomicReference error = new AtomicReference<>(); + consumer( + createTubeConfig(), + all, + (event) -> { + if (event instanceof Message msg) { + message = msg; + if (returnedEvent.get() == null) { + returnedEvent.set(msg); + return false; + } + } + callTaskManagerProcess(event); + return true; + }, + error::set); + + if (returnedEvent.get() != null) { + return returnedEvent.get(); + } + return taskManager.getTask(); + } + + public EventTypeAndInterrupt consumeAndBreakOnInterrupt(EventConsumer consumer) throws JSONRPCError { + Flow.Publisher all = consumer.consumeAll(); + AtomicReference message = new AtomicReference<>(); + AtomicBoolean interrupted = new AtomicBoolean(false); + AtomicReference errorRef = new AtomicReference<>(); + consumer( + createTubeConfig(), + all, + (event -> { + if (event instanceof Throwable t) { + errorRef.set(t); + return false; + } + if (event instanceof Message msg) { + this.message = msg; + message.set(msg); + return false; + } + + callTaskManagerProcess(event); + + if ((event instanceof Task task && task.getStatus().state() == TaskState.AUTH_REQUIRED) + || (event instanceof TaskStatusUpdateEvent tsue && tsue.getStatus().state() == TaskState.AUTH_REQUIRED)) { + // auth-required is a special state: the message should be + // escalated back to the caller, but the agent is expected to + // continue producing events once the authorization is received + // out-of-band. This is in contrast to input-required, where a + // new request is expected in order for the agent to make progress, + // so the agent should exit. + + // TODO There is the following line in the Python code I don't totally get + // asyncio.create_task(self._continue_consuming(event_stream)) + // I think it means the continueConsuming() call should be done in another thread + continueConsuming(all); + + interrupted.set(true); + return false; + } + return true; + }), + errorRef::set); + + Throwable error = errorRef.get(); + if (error != null) { + Utils.rethrow(error); + } + + return new EventTypeAndInterrupt( + message.get() != null ? message.get() : taskManager.getTask(), interrupted.get()); + } + + private void continueConsuming(Flow.Publisher all) { + consumer(createTubeConfig(), + all, + event -> { + callTaskManagerProcess(event); + return true; + }, + t -> {}); + } + + private void callTaskManagerProcess(Event event) { + try { + taskManager.process(event); + } catch (A2AServerException e) { + // TODO Decide what to do in case of failure + e.printStackTrace(); + } + } + + public record EventTypeAndInterrupt(EventKind eventType, boolean interrupted) { + + } +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/tasks/TaskManager.java b/sdk-server-common/src/main/java/io/a2a/server/tasks/TaskManager.java new file mode 100644 index 000000000..a0b5a31bd --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/tasks/TaskManager.java @@ -0,0 +1,202 @@ +package io.a2a.server.tasks; + +import static io.a2a.spec.TaskState.SUBMITTED; +import static io.a2a.util.Assert.checkNotNullParam; + +import java.util.ArrayList; +import java.util.List; + +import io.a2a.spec.Event; +import io.a2a.spec.A2AServerException; +import io.a2a.spec.Artifact; +import io.a2a.spec.InvalidParamsError; +import io.a2a.spec.Message; +import io.a2a.spec.Part; +import io.a2a.spec.Task; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; + +public class TaskManager { + private volatile String taskId; + private volatile String contextId; + private final TaskStore taskStore; + private final Message initialMessage; + private volatile Task currentTask; + + public TaskManager(String taskId, String contextId, TaskStore taskStore, Message initialMessage) { + checkNotNullParam("taskStore", taskStore); + this.taskId = taskId; + this.contextId = contextId; + this.taskStore = taskStore; + this.initialMessage = initialMessage; + } + + String getTaskId() { + return taskId; + } + + String getContextId() { + return contextId; + } + + public Task getTask() { + if (taskId == null) { + return null; + } + if (currentTask != null) { + return currentTask; + } + currentTask = taskStore.get(taskId); + return currentTask; + } + + Task saveTaskEvent(Task task) throws A2AServerException { + checkIdsAndUpdateIfNecessary(task.getId(), task.getContextId()); + return saveTask(task); + } + + Task saveTaskEvent(TaskStatusUpdateEvent event) throws A2AServerException { + checkIdsAndUpdateIfNecessary(event.getTaskId(), event.getContextId()); + Task task = ensureTask(event.getTaskId(), event.getContextId()); + + + Task.Builder builder = new Task.Builder(task) + .status(event.getStatus()); + + if (task.getStatus().message() != null) { + List newHistory = task.getHistory() == null ? new ArrayList<>() : new ArrayList<>(task.getHistory()); + newHistory.add(task.getStatus().message()); + builder.history(newHistory); + } + + task = builder.build(); + return saveTask(task); + } + + Task saveTaskEvent(TaskArtifactUpdateEvent event) throws A2AServerException { + checkIdsAndUpdateIfNecessary(event.getTaskId(), event.getContextId()); + Task task = ensureTask(event.getTaskId(), event.getContextId()); + + // Append artifacts + List artifacts = task.getArtifacts() == null ? new ArrayList<>() : new ArrayList<>(task.getArtifacts()); + + Artifact newArtifact = event.getArtifact(); + String artifactId = newArtifact.artifactId(); + boolean appendParts = event.isAppend() != null && event.isAppend(); + + Artifact existingArtifact = null; + int existingArtifactIndex = -1; + + for (int i = 0; i < artifacts.size(); i++) { + Artifact curr = artifacts.get(i); + if (curr.artifactId() != null && curr.artifactId().equals(artifactId)) { + existingArtifact = curr; + existingArtifactIndex = i; + break; + } + } + + if (!appendParts) { + // This represents the first chunk for this artifact index + if (existingArtifactIndex >= 0) { + // Replace the existing artifact entirely with the new artifact + artifacts.set(existingArtifactIndex, newArtifact); + } else { + // Append the new artifact since no artifact with this id/index exists yet + artifacts.add(newArtifact); + } + + } else if (existingArtifact != null) { + // Append new parts to the existing artifact's parts list + // Do this to a copy + + List> parts = new ArrayList<>(existingArtifact.parts()); + parts.addAll(newArtifact.parts()); + Artifact updated = new Artifact.Builder(existingArtifact) + .parts(parts) + .build(); + artifacts.set(existingArtifactIndex, updated); + } else { + // We received a chunk to append, but we don't have an existing artifact. + // We will ignore this chunk + } + + task = new Task.Builder(task) + .artifacts(artifacts) + .build(); + + return saveTask(task); + } + + public Event process(Event event) throws A2AServerException { + if (event instanceof Task task) { + saveTask(task); + } else if (event instanceof TaskStatusUpdateEvent taskStatusUpdateEvent) { + saveTaskEvent(taskStatusUpdateEvent); + } else if (event instanceof TaskArtifactUpdateEvent taskArtifactUpdateEvent) { + saveTaskEvent(taskArtifactUpdateEvent); + } + return event; + } + + public Task updateWithMessage(Message message, Task task) { + List history = task.getHistory() == null ? new ArrayList<>() : new ArrayList<>(task.getHistory()); + if (task.getStatus().message() != null) { + history.add(task.getStatus().message()); + } + history.add(message); + task = new Task.Builder(task) + .history(history) + .build(); + saveTask(task); + return task; + } + + private void checkIdsAndUpdateIfNecessary(String eventTaskId, String eventContextId) throws A2AServerException { + if (taskId != null && !eventTaskId.equals(taskId)) { + throw new A2AServerException( + "Invalid task id", + new InvalidParamsError(String.format("Task in event doesn't match TaskManager "))); + } + if (taskId == null) { + taskId = eventTaskId; + } + if (contextId == null) { + contextId = eventContextId; + } + } + + private Task ensureTask(String eventTaskId, String eventContextId) { + Task task = currentTask; + if (task != null) { + return task; + } + task = taskStore.get(taskId); + if (task == null) { + task = createTask(eventTaskId, eventContextId); + saveTask(task); + } + return task; + } + + private Task createTask(String taskId, String contextId) { + List history = initialMessage != null ? List.of(initialMessage) : null; + return new Task.Builder() + .id(taskId) + .contextId(contextId) + .status(new TaskStatus(SUBMITTED)) + .history(history) + .build(); + } + + private Task saveTask(Task task) { + taskStore.save(task); + if (taskId == null) { + taskId = task.getId(); + contextId = task.getContextId(); + } + currentTask = task; + return currentTask; + } +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/tasks/TaskStore.java b/sdk-server-common/src/main/java/io/a2a/server/tasks/TaskStore.java new file mode 100644 index 000000000..73ac8f38f --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/tasks/TaskStore.java @@ -0,0 +1,11 @@ +package io.a2a.server.tasks; + +import io.a2a.spec.Task; + +public interface TaskStore { + void save(Task task); + + Task get(String taskId); + + void delete(String taskId); +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/tasks/TaskUpdater.java b/sdk-server-common/src/main/java/io/a2a/server/tasks/TaskUpdater.java new file mode 100644 index 000000000..2e7a7cafb --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/tasks/TaskUpdater.java @@ -0,0 +1,112 @@ +package io.a2a.server.tasks; + +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import io.a2a.server.agentexecution.RequestContext; +import io.a2a.server.events.EventQueue; +import io.a2a.spec.Artifact; +import io.a2a.spec.Message; +import io.a2a.spec.Part; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; + +public class TaskUpdater { + private final EventQueue eventQueue; + private final String taskId; + private final String contextId; + + public TaskUpdater(RequestContext context, EventQueue eventQueue) { + this.eventQueue = eventQueue; + this.taskId = context.getTaskId(); + this.contextId = context.getContextId(); + } + + private void updateStatus(TaskState taskState) { + updateStatus(taskState, null); + } + + private void updateStatus(TaskState state, Message message) { + TaskStatusUpdateEvent event = new TaskStatusUpdateEvent.Builder() + .taskId(taskId) + .contextId(contextId) + .isFinal(state.isFinal()) + .status(new TaskStatus(state, message, null)) + .build(); + eventQueue.enqueueEvent(event); + } + + public void addArtifact(List> parts, String artifactId, String name, Map metadata) { + if (artifactId == null) { + artifactId = UUID.randomUUID().toString(); + } + TaskArtifactUpdateEvent event = new TaskArtifactUpdateEvent.Builder() + .taskId(taskId) + .contextId(contextId) + .artifact( + new Artifact.Builder() + .artifactId(artifactId) + .name(name) + .parts(parts) + .metadata(metadata) + .build() + ) + .build(); + eventQueue.enqueueEvent(event); + } + + public void complete() { + complete(null); + } + + public void complete(Message message) { + updateStatus(TaskState.COMPLETED, message); + } + + public void fail() { + fail(null); + } + + public void fail(Message message) { + updateStatus(TaskState.FAILED, message); + } + + public void submit() { + submit(null); + } + + public void submit(Message message) { + updateStatus(TaskState.SUBMITTED, message); + } + + public void startWork() { + startWork(null); + } + + public void startWork(Message message) { + updateStatus(TaskState.WORKING, message); + } + + public void cancel() { + cancel(null); + } + + public void cancel(Message message) { + updateStatus(TaskState.CANCELED, message); + } + + public Message newAgentMessage(List> parts, Map metadata) { + return new Message.Builder() + .role(Message.Role.AGENT) + .taskId(taskId) + .contextId(contextId) + .messageId(UUID.randomUUID().toString()) + .metadata(metadata) + .parts(parts) + .build(); + } + +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/util/TempLoggerWrapper.java b/sdk-server-common/src/main/java/io/a2a/server/util/TempLoggerWrapper.java new file mode 100644 index 000000000..cecc6797d --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/util/TempLoggerWrapper.java @@ -0,0 +1,370 @@ +package io.a2a.server.util; + +import org.slf4j.Logger; +import org.slf4j.Marker; +import org.slf4j.event.Level; +import org.slf4j.helpers.CheckReturnValue; +import org.slf4j.spi.LoggingEventBuilder; + +/** + * Temporary wrapper class for Logger to log debug and trace messages at info level since I am not able to figure + * out how to enable debug logging via configuration. + */ +public class TempLoggerWrapper implements Logger { + private final Logger delegate; + + public TempLoggerWrapper(Logger delegate) { + this.delegate = delegate; + } + + @Override + public String getName() { + return delegate.getName(); + } + + @Override + public LoggingEventBuilder makeLoggingEventBuilder(Level level) { + return delegate.makeLoggingEventBuilder(level); + } + + @CheckReturnValue + @Override + public LoggingEventBuilder atLevel(Level level) { + return delegate.atLevel(level); + } + + @Override + public boolean isEnabledForLevel(Level level) { + return delegate.isEnabledForLevel(level); + } + + @Override + public boolean isTraceEnabled() { + return delegate.isTraceEnabled(); + } + + @Override + public void trace(String msg) { + delegate.info(msg); + } + + @Override + public void trace(String format, Object arg) { + delegate.info(format, arg); + } + + @Override + public void trace(String format, Object arg1, Object arg2) { + delegate.info(format, arg1, arg2); + } + + @Override + public void trace(String format, Object... arguments) { + delegate.info(format, arguments); + } + + @Override + public void trace(String msg, Throwable t) { + delegate.info(msg, t); + } + + @Override + public boolean isTraceEnabled(Marker marker) { + return delegate.isInfoEnabled(marker); + } + + @CheckReturnValue + @Override + public LoggingEventBuilder atTrace() { + return delegate.atTrace(); + } + + @Override + public void trace(Marker marker, String msg) { + delegate.info(marker, msg); + } + + @Override + public void trace(Marker marker, String format, Object arg) { + delegate.info(marker, format, arg); + } + + @Override + public void trace(Marker marker, String format, Object arg1, Object arg2) { + delegate.info(marker, format, arg1, arg2); + } + + @Override + public void trace(Marker marker, String format, Object... argArray) { + delegate.info(marker, format, argArray); + } + + @Override + public void trace(Marker marker, String msg, Throwable t) { + delegate.info(marker, msg, t); + } + + @Override + public boolean isDebugEnabled() { + return delegate.isInfoEnabled(); + } + + @Override + public void debug(String msg) { + delegate.info(msg); + } + + @Override + public void debug(String format, Object arg) { + delegate.info(format, arg); + } + + @Override + public void debug(String format, Object arg1, Object arg2) { + delegate.info(format, arg1, arg2); + } + + @Override + public void debug(String format, Object... arguments) { + delegate.info(format, arguments); + } + + @Override + public void debug(String msg, Throwable t) { + delegate.info(msg, t); + } + + @Override + public boolean isDebugEnabled(Marker marker) { + return delegate.isDebugEnabled(marker); + } + + @Override + public void debug(Marker marker, String msg) { + delegate.info(marker, msg); + } + + @Override + public void debug(Marker marker, String format, Object arg) { + delegate.info(marker, format, arg); + } + + @Override + public void debug(Marker marker, String format, Object arg1, Object arg2) { + delegate.info(marker, format, arg1, arg2); + } + + @Override + public void debug(Marker marker, String format, Object... arguments) { + delegate.info(marker, format, arguments); + } + + @Override + public void debug(Marker marker, String msg, Throwable t) { + delegate.info(marker, msg, t); + } + + @CheckReturnValue + @Override + public LoggingEventBuilder atDebug() { + return delegate.atDebug(); + } + + @Override + public boolean isInfoEnabled() { + return delegate.isInfoEnabled(); + } + + @Override + public void info(String msg) { + delegate.info(msg); + } + + @Override + public void info(String format, Object arg) { + delegate.info(format, arg); + } + + @Override + public void info(String format, Object arg1, Object arg2) { + delegate.info(format, arg1, arg2); + } + + @Override + public void info(String format, Object... arguments) { + delegate.info(format, arguments); + } + + @Override + public void info(String msg, Throwable t) { + delegate.info(msg, t); + } + + @Override + public boolean isInfoEnabled(Marker marker) { + return delegate.isInfoEnabled(marker); + } + + @Override + public void info(Marker marker, String msg) { + delegate.info(marker, msg); + } + + @Override + public void info(Marker marker, String format, Object arg) { + delegate.info(marker, format, arg); + } + + @Override + public void info(Marker marker, String format, Object arg1, Object arg2) { + delegate.info(marker, format, arg1, arg2); + } + + @Override + public void info(Marker marker, String format, Object... arguments) { + delegate.info(marker, format, arguments); + } + + @Override + public void info(Marker marker, String msg, Throwable t) { + delegate.info(marker, msg, t); + } + + @CheckReturnValue + @Override + public LoggingEventBuilder atInfo() { + return delegate.atInfo(); + } + + @Override + public boolean isWarnEnabled() { + return delegate.isWarnEnabled(); + } + + @Override + public void warn(String msg) { + delegate.warn(msg); + } + + @Override + public void warn(String format, Object arg) { + delegate.warn(format, arg); + } + + @Override + public void warn(String format, Object... arguments) { + delegate.warn(format, arguments); + } + + @Override + public void warn(String format, Object arg1, Object arg2) { + delegate.warn(format, arg1, arg2); + } + + @Override + public void warn(String msg, Throwable t) { + delegate.warn(msg, t); + } + + @Override + public boolean isWarnEnabled(Marker marker) { + return delegate.isWarnEnabled(marker); + } + + @Override + public void warn(Marker marker, String msg) { + delegate.warn(marker, msg); + } + + @Override + public void warn(Marker marker, String format, Object arg) { + delegate.warn(marker, format, arg); + } + + @Override + public void warn(Marker marker, String format, Object arg1, Object arg2) { + delegate.warn(marker, format, arg1, arg2); + } + + @Override + public void warn(Marker marker, String format, Object... arguments) { + delegate.warn(marker, format, arguments); + } + + @Override + public void warn(Marker marker, String msg, Throwable t) { + delegate.warn(marker, msg, t); + } + + @CheckReturnValue + @Override + public LoggingEventBuilder atWarn() { + return delegate.atWarn(); + } + + @Override + public boolean isErrorEnabled() { + return delegate.isErrorEnabled(); + } + + @Override + public void error(String msg) { + delegate.error(msg); + } + + @Override + public void error(String format, Object arg) { + delegate.error(format, arg); + } + + @Override + public void error(String format, Object arg1, Object arg2) { + delegate.error(format, arg1, arg2); + } + + @Override + public void error(String format, Object... arguments) { + delegate.error(format, arguments); + } + + @Override + public void error(String msg, Throwable t) { + delegate.error(msg, t); + } + + @Override + public boolean isErrorEnabled(Marker marker) { + return delegate.isErrorEnabled(marker); + } + + @Override + public void error(Marker marker, String msg) { + delegate.error(marker, msg); + } + + @Override + public void error(Marker marker, String format, Object arg) { + delegate.error(marker, format, arg); + } + + @Override + public void error(Marker marker, String format, Object arg1, Object arg2) { + delegate.error(marker, format, arg1, arg2); + } + + @Override + public void error(Marker marker, String format, Object... arguments) { + delegate.error(marker, format, arguments); + } + + @Override + public void error(Marker marker, String msg, Throwable t) { + delegate.error(marker, msg, t); + } + + @CheckReturnValue + @Override + public LoggingEventBuilder atError() { + return delegate.atError(); + } +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/util/async/AsyncExecutorProducer.java b/sdk-server-common/src/main/java/io/a2a/server/util/async/AsyncExecutorProducer.java new file mode 100644 index 000000000..1b3066ca6 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/util/async/AsyncExecutorProducer.java @@ -0,0 +1,33 @@ +package io.a2a.server.util.async; + +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.enterprise.inject.Produces; +import jakarta.inject.Singleton; + +@Singleton +public class AsyncExecutorProducer { + + private ExecutorService executor; + + @PostConstruct + public void init() { + executor = Executors.newCachedThreadPool(); + } + + @PreDestroy + public void close() { + executor.shutdown(); + } + + @Produces + @Internal + public Executor produce() { + return executor; + } + +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/util/async/AsyncUtils.java b/sdk-server-common/src/main/java/io/a2a/server/util/async/AsyncUtils.java new file mode 100644 index 000000000..25444cde9 --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/util/async/AsyncUtils.java @@ -0,0 +1,240 @@ +package io.a2a.server.util.async; + +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; + +import io.a2a.util.Assert; +import mutiny.zero.BackpressureStrategy; +import mutiny.zero.Tube; +import mutiny.zero.TubeConfiguration; +import mutiny.zero.ZeroPublisher; +import mutiny.zero.operators.Transform; + +public class AsyncUtils { + + private static final int DEFAULT_TUBE_BUFFER_SIZE = 256; + + public static TubeConfiguration createTubeConfig() { + return createTubeConfig(DEFAULT_TUBE_BUFFER_SIZE); + } + + public static TubeConfiguration createTubeConfig(int bufferSize) { + return new TubeConfiguration() + .withBackpressureStrategy(BackpressureStrategy.BUFFER) + .withBufferSize(256); + } + + public static void consumer( + TubeConfiguration config, + Flow.Publisher source, + Function nextFunction, + Consumer errorConsumer) { + + BiFunction, T, Boolean> nextBiFunction = new BiFunction, T, Boolean>() { + @Override + public Boolean apply(Consumer throwableConsumer, T t) { + return nextFunction.apply(t); + } + }; + + ZeroPublisher.create(config, tube -> { + source.subscribe(new ConsumingSubscriber<>(nextBiFunction, errorConsumer)); + }) + .subscribe(new Flow.Subscriber() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(Object item) { + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + } + + public static Flow.Publisher processor( + TubeConfiguration config, + Flow.Publisher source, + BiFunction, T, Boolean> nextFunction) { + + return ZeroPublisher.create(config, tube -> { + source.subscribe(new ProcessingSubscriber<>(tube, nextFunction)); + }); + } + + public static Flow.Publisher convertingProcessor(Flow.Publisher source, Function converterFunction) { + return new Transform<>(source, converterFunction); + } + + + private static abstract class AbstractSubscriber implements Flow.Subscriber { + private Flow.Subscription subscription; + private final BiFunction, T, Boolean> nextFunction; + private final Consumer publishNextConsumer; + private final Consumer failureOrCompleteConsumer; + + protected AbstractSubscriber( + BiFunction, T, Boolean> nextFunction, + Consumer publishNextConsumer, + Consumer failureOrCompleteConsumer) { + Assert.checkNotNullParam("nextFunction", nextFunction); + this.nextFunction = nextFunction; + this.publishNextConsumer = publishNextConsumer != null ? publishNextConsumer : t -> {}; + this.failureOrCompleteConsumer = failureOrCompleteConsumer != null ? failureOrCompleteConsumer : t -> {}; + } + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + this.subscription.request(1); + } + + @Override + public void onNext(T item) { + AtomicReference errorRaised = new AtomicReference<>(); + + Consumer errorConsumer = t -> { + errorRaised.set(t); + onError(t); + }; + boolean continueProcessing = false; + if (errorRaised.get() == null) { + try { + continueProcessing = nextFunction.apply(errorConsumer, item); + } catch (Throwable t) { + errorConsumer.accept(t); + } + } + if (!continueProcessing || errorRaised.get() != null) { + subscription.cancel(); + } else { + if (publishNextConsumer != null) { + publishNextConsumer.accept(item); + } + subscription.request(1); + } + } + + + @Override + public void onError(Throwable throwable) { + subscription.cancel(); + if (failureOrCompleteConsumer != null) { + failureOrCompleteConsumer.accept(throwable); + } + } + + @Override + public void onComplete() { + subscription.cancel(); + if (failureOrCompleteConsumer != null) { + failureOrCompleteConsumer.accept(null); + } + } + } + + private static class ConsumingSubscriber extends AbstractSubscriber { + public ConsumingSubscriber(BiFunction, T, Boolean> nextFunction, + Consumer failureOrCompleteConsumer) { + super(nextFunction, null, failureOrCompleteConsumer); + } + } + + private static class ProcessingSubscriber extends AbstractSubscriber { + private Flow.Subscription subscription; + private final Tube tube; + + public ProcessingSubscriber(Tube tube, BiFunction, T, Boolean> nextFunction) { + super( + nextFunction, + tube::send, + t -> { + if (t == null) { + tube.complete(); + } else { + tube.fail(t); + } + } + ); + Assert.checkNotNullParam("tube", tube); + this.tube = tube; + } + } + + private static class ConvertingProcessingSubscriber implements Flow.Subscriber { + private Flow.Subscription subscription; + private Tube tube; + private final BiFunction, T, N> converterBiFunction; + + public ConvertingProcessingSubscriber(Tube tube, Function converterFunction) { + Assert.checkNotNullParam("tube", tube); + Assert.checkNotNullParam("converterFunction", converterFunction); + this.tube = tube; + this.converterBiFunction = (throwableConsumer, t) -> converterFunction.apply(t); + } + + public ConvertingProcessingSubscriber(Tube tube, BiFunction, T, N> converterBiFunction) { + Assert.checkNotNullParam("tube", tube); + Assert.checkNotNullParam("converterBiFunction", converterBiFunction); + this.tube = tube; + this.converterBiFunction = converterBiFunction; + } + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + this.subscription.request(1); + } + + @Override + public void onNext(T item) { + AtomicBoolean errorRaised = new AtomicBoolean(false); + Consumer errorConsumer = t -> { + errorRaised.set(true); + onError(t); + }; + + N converted = null; + try { + converted = converterBiFunction.apply(errorConsumer, item); + } catch (Throwable t) { + errorConsumer.accept(t); + return; + } + if (!errorRaised.get()) { + tube.send(converted); + subscription.request(1); + } + } + + @Override + public void onError(Throwable throwable) { + subscription.cancel(); + tube.fail(throwable); + } + + @Override + public void onComplete() { + subscription.cancel(); + tube.complete(); + } + } +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/util/async/Internal.java b/sdk-server-common/src/main/java/io/a2a/server/util/async/Internal.java new file mode 100644 index 000000000..2b8cd100e --- /dev/null +++ b/sdk-server-common/src/main/java/io/a2a/server/util/async/Internal.java @@ -0,0 +1,11 @@ +package io.a2a.server.util.async; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +import jakarta.inject.Qualifier; + +@Qualifier +@Retention(RetentionPolicy.RUNTIME) +public @interface Internal { +} diff --git a/sdk-server-common/src/main/resources/META-INF/beans.xml b/sdk-server-common/src/main/resources/META-INF/beans.xml new file mode 100644 index 000000000..e69de29bb diff --git a/sdk-server-common/src/test/java/io/a2a/server/agentexecution/RequestContextTest.java b/sdk-server-common/src/test/java/io/a2a/server/agentexecution/RequestContextTest.java new file mode 100644 index 000000000..8e3f04af6 --- /dev/null +++ b/sdk-server-common/src/test/java/io/a2a/server/agentexecution/RequestContextTest.java @@ -0,0 +1,248 @@ +package io.a2a.server.agentexecution; + +import io.a2a.spec.InvalidParamsError; +import io.a2a.spec.Message; +import io.a2a.spec.MessageSendParams; +import io.a2a.spec.Task; +import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskState; +import io.a2a.spec.TextPart; +import org.junit.jupiter.api.Test; +import org.mockito.MockedStatic; + +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; + +public class RequestContextTest { + + @Test + public void testInitWithoutParams() { + RequestContext context = new RequestContext(null, null, null, null, null); + assertNull(context.getMessage()); + assertNull(context.getTaskId()); + assertNull(context.getContextId()); + assertNull(context.getTask()); + assertTrue(context.getRelatedTasks().isEmpty()); + } + + @Test + public void testInitWithParamsNoIds() { + var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).build(); + var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); + + UUID taskId = UUID.fromString("00000000-0000-0000-0000-000000000001"); + UUID contextId = UUID.fromString("00000000-0000-0000-0000-000000000002"); + + try (MockedStatic mockedUUID = mockStatic(UUID.class)) { + mockedUUID.when(UUID::randomUUID) + .thenReturn(taskId) + .thenReturn(contextId); + + RequestContext context = new RequestContext(mockParams, null, null, null, null); + + assertEquals(mockParams.message(), context.getMessage()); + assertEquals(taskId.toString(), context.getTaskId()); + assertEquals(mockParams.message().getTaskId(), taskId.toString()); + assertEquals(contextId.toString(), context.getContextId()); + assertEquals(mockParams.message().getContextId(), contextId.toString()); + } + } + + @Test + public void testInitWithTaskId() { + String taskId = "task-123"; + var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId(taskId).build(); + var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); + + RequestContext context = new RequestContext(mockParams, taskId, null, null, null); + + assertEquals(taskId, context.getTaskId()); + assertEquals(taskId, mockParams.message().getTaskId()); + } + + @Test + public void testInitWithContextId() { + String contextId = "context-456"; + var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).contextId(contextId).build(); + var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); + RequestContext context = new RequestContext(mockParams, null, contextId, null, null); + + assertEquals(contextId, context.getContextId()); + assertEquals(contextId, mockParams.message().getContextId()); + } + + @Test + public void testInitWithBothIds() { + String taskId = "task-123"; + String contextId = "context-456"; + var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId(taskId).contextId(contextId).build(); + var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); + RequestContext context = new RequestContext(mockParams, taskId, contextId, null, null); + + assertEquals(taskId, context.getTaskId()); + assertEquals(taskId, mockParams.message().getTaskId()); + assertEquals(contextId, context.getContextId()); + assertEquals(contextId, mockParams.message().getContextId()); + } + + @Test + public void testInitWithTask() { + var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).build(); + var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); + var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); + + RequestContext context = new RequestContext(mockParams, null, null, mockTask, null); + + assertEquals(mockTask, context.getTask()); + } + + @Test + public void testGetUserInputNoParams() { + RequestContext context = new RequestContext(null, null, null, null, null); + assertEquals("", context.getUserInput(null)); + } + + @Test + public void testAttachRelatedTask() { + var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); + + RequestContext context = new RequestContext(null, null, null, null, null); + assertEquals(0, context.getRelatedTasks().size()); + + context.attachRelatedTask(mockTask); + assertEquals(1, context.getRelatedTasks().size()); + assertEquals(mockTask, context.getRelatedTasks().get(0)); + + Task anotherTask = mock(Task.class); + context.attachRelatedTask(anotherTask); + assertEquals(2, context.getRelatedTasks().size()); + assertEquals(anotherTask, context.getRelatedTasks().get(1)); + } + + @Test + public void testCheckOrGenerateTaskIdWithExistingTaskId() { + String existingId = "existing-task-id"; + var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId(existingId).build(); + var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); + + RequestContext context = new RequestContext(mockParams, null, null, null, null); + + assertEquals(existingId, context.getTaskId()); + assertEquals(existingId, mockParams.message().getTaskId()); + } + + @Test + public void testCheckOrGenerateContextIdWithExistingContextId() { + String existingId = "existing-context-id"; + + var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).contextId(existingId).build(); + var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); + + RequestContext context = new RequestContext(mockParams, null, null, null, null); + + assertEquals(existingId, context.getContextId()); + assertEquals(existingId, mockParams.message().getContextId()); + } + + @Test + public void testInitRaisesErrorOnTaskIdMismatch() { + var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId("task-123").build(); + var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); + var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); + + InvalidParamsError error = assertThrows(InvalidParamsError.class, () -> + new RequestContext(mockParams, "wrong-task-id", null, mockTask, null)); + + assertTrue(error.getMessage().contains("bad task id")); + } + + @Test + public void testInitRaisesErrorOnContextIdMismatch() { + var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId("task-123").contextId("context-456").build(); + var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); + var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); + + InvalidParamsError error = assertThrows(InvalidParamsError.class, () -> + new RequestContext(mockParams, mockTask.getId(), "wrong-context-id", mockTask, null)); + + assertTrue(error.getMessage().contains("bad context id")); + } + + @Test + public void testWithRelatedTasksProvided() { + var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); + + List relatedTasks = new ArrayList<>(); + relatedTasks.add(mockTask); + relatedTasks.add(mock(Task.class)); + + RequestContext context = new RequestContext(null, null, null, null, relatedTasks); + + assertEquals(relatedTasks, context.getRelatedTasks()); + assertEquals(2, context.getRelatedTasks().size()); + } + + @Test + public void testMessagePropertyWithoutParams() { + RequestContext context = new RequestContext(null, null, null, null, null); + assertNull(context.getMessage()); + } + + @Test + public void testMessagePropertyWithParams() { + var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).build(); + var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); + + RequestContext context = new RequestContext(mockParams, null, null, null, null); + assertEquals(mockParams.message(), context.getMessage()); + } + + @Test + public void testInitWithExistingIdsInMessage() { + String existingTaskId = "existing-task-id"; + String existingContextId = "existing-context-id"; + + var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))) + .taskId(existingTaskId).contextId(existingContextId).build(); + var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); + + RequestContext context = new RequestContext(mockParams, null, null, null, null); + + assertEquals(existingTaskId, context.getTaskId()); + assertEquals(existingContextId, context.getContextId()); + } + + @Test + public void testInitWithTaskIdAndExistingTaskIdMatch() { + var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId("task-123").contextId("context-456").build(); + var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); + var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); + + + RequestContext context = new RequestContext(mockParams, mockTask.getId(), null, mockTask, null); + + assertEquals(mockTask.getId(), context.getTaskId()); + assertEquals(mockTask, context.getTask()); + } + + @Test + public void testInitWithContextIdAndExistingContextIdMatch() { + var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId("task-123").contextId("context-456").build(); + var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); + var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); + + + RequestContext context = new RequestContext(mockParams, mockTask.getId(), mockTask.getContextId(), mockTask, null); + + assertEquals(mockTask.getContextId(), context.getContextId()); + assertEquals(mockTask, context.getTask()); + } +} diff --git a/sdk-server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java b/sdk-server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java new file mode 100644 index 000000000..7d8d6cbad --- /dev/null +++ b/sdk-server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java @@ -0,0 +1,264 @@ +package io.a2a.server.events; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicReference; + +import com.fasterxml.jackson.core.JsonProcessingException; +import io.a2a.spec.A2AError; +import io.a2a.spec.A2AServerException; +import io.a2a.spec.Artifact; +import io.a2a.spec.Event; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.Message; +import io.a2a.spec.Task; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; +import io.a2a.spec.TextPart; +import io.a2a.util.Utils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class EventConsumerTest { + + private EventQueue eventQueue; + private EventConsumer eventConsumer; + + + private static final String MINIMAL_TASK = """ + { + "id": "123", + "contextId": "session-xyz", + "status": {"state": "submitted"}, + "kind": "task" + } + """; + + private static final String MESSAGE_PAYLOAD = """ + { + "role": "agent", + "parts": [{"kind": "text", "text": "test message"}], + "messageId": "111", + "kind": "message" + } + """; + + @BeforeEach + public void init() { + eventQueue = EventQueue.create(); + eventConsumer = new EventConsumer(eventQueue); + } + + @Test + public void testConsumeOneTaskEvent() throws Exception { + Task event = Utils.unmarshalFrom(MINIMAL_TASK, Task.TYPE_REFERENCE); + enqueueAndConsumeOneEvent(event); + } + + @Test + public void testConsumeOneMessageEvent() throws Exception { + Event event = Utils.unmarshalFrom(MESSAGE_PAYLOAD, Message.TYPE_REFERENCE); + enqueueAndConsumeOneEvent(event); + } + + @Test + public void testConsumeOneA2AErrorEvent() throws Exception { + Event event = new A2AError() {}; + enqueueAndConsumeOneEvent(event); + } + + @Test + public void testConsumeOneJsonRpcErrorEvent() throws Exception { + Event event = new JSONRPCError(123, "Some Error", null); + enqueueAndConsumeOneEvent(event); + } + + @Test + public void testConsumeOneQueueEmpty() throws A2AServerException { + assertThrows(A2AServerException.class, () -> eventConsumer.consumeOne()); + } + + @Test + public void testConsumeAllMultipleEvents() throws JsonProcessingException { + List events = List.of( + Utils.unmarshalFrom(MINIMAL_TASK, Task.TYPE_REFERENCE), + new TaskArtifactUpdateEvent.Builder() + .taskId("task-123") + .contextId("session-xyz") + .artifact(new Artifact.Builder() + .artifactId("11") + .parts(new TextPart("text")) + .build()) + .build(), + new TaskStatusUpdateEvent.Builder() + .taskId("task-123") + .contextId("session-xyz") + .status(new TaskStatus(TaskState.WORKING)) + .isFinal(true) + .build()); + + for (Event event : events) { + eventQueue.enqueueEvent(event); + } + + Flow.Publisher publisher = eventConsumer.consumeAll(); + final List receivedEvents = new ArrayList<>(); + final AtomicReference error = new AtomicReference<>(); + + publisher.subscribe(new Flow.Subscriber<>() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(Event item) { + receivedEvents.add(item); + subscription.request(1); + + } + + @Override + public void onError(Throwable throwable) { + error.set(throwable); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + assertNull(error.get()); + assertEquals(events.size(), receivedEvents.size()); + for (int i = 0; i < events.size(); i++) { + assertSame(events.get(i), receivedEvents.get(i)); + } + } + + @Test + public void testConsumeUntilMessage() throws Exception { + List events = List.of( + Utils.unmarshalFrom(MINIMAL_TASK, Task.TYPE_REFERENCE), + new TaskArtifactUpdateEvent.Builder() + .taskId("task-123") + .contextId("session-xyz") + .artifact(new Artifact.Builder() + .artifactId("11") + .parts(new TextPart("text")) + .build()) + .build(), + new TaskStatusUpdateEvent.Builder() + .taskId("task-123") + .contextId("session-xyz") + .status(new TaskStatus(TaskState.WORKING)) + .isFinal(true) + .build()); + + for (Event event : events) { + eventQueue.enqueueEvent(event); + } + + Flow.Publisher publisher = eventConsumer.consumeAll(); + final List receivedEvents = new ArrayList<>(); + final AtomicReference error = new AtomicReference<>(); + + publisher.subscribe(new Flow.Subscriber<>() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(Event item) { + receivedEvents.add(item); + subscription.request(1); + + } + + @Override + public void onError(Throwable throwable) { + error.set(throwable); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + assertNull(error.get()); + assertEquals(3, receivedEvents.size()); + for (int i = 0; i < 3; i++) { + assertSame(events.get(i), receivedEvents.get(i)); + } + } + + @Test + public void testConsumeMessageEvents() throws Exception { + Message message = Utils.unmarshalFrom(MESSAGE_PAYLOAD, Message.TYPE_REFERENCE); + Message message2 = new Message.Builder(message).build(); + + List events = List.of(message, message2); + + for (Event event : events) { + eventQueue.enqueueEvent(event); + } + + Flow.Publisher publisher = eventConsumer.consumeAll(); + final List receivedEvents = new ArrayList<>(); + final AtomicReference error = new AtomicReference<>(); + + publisher.subscribe(new Flow.Subscriber<>() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(Event item) { + receivedEvents.add(item); + subscription.request(1); + + } + + @Override + public void onError(Throwable throwable) { + error.set(throwable); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + assertNull(error.get()); + // The stream is closed after the first Message + assertEquals(1, receivedEvents.size()); + assertSame(message, receivedEvents.get(0)); + } + + private void enqueueAndConsumeOneEvent(Event event) throws Exception { + eventQueue.enqueueEvent(event); + Event result = eventConsumer.consumeOne(); + assertSame(event, result); + } +} diff --git a/sdk-server-common/src/test/java/io/a2a/server/events/EventQueueTest.java b/sdk-server-common/src/test/java/io/a2a/server/events/EventQueueTest.java new file mode 100644 index 000000000..9510f8d3f --- /dev/null +++ b/sdk-server-common/src/test/java/io/a2a/server/events/EventQueueTest.java @@ -0,0 +1,116 @@ +package io.a2a.server.events; + +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; + +import java.util.List; + +import io.a2a.spec.Artifact; +import io.a2a.spec.Event; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.Message; +import io.a2a.spec.Task; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskNotFoundError; +import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; +import io.a2a.spec.TextPart; +import io.a2a.util.Utils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class EventQueueTest { + + private EventQueue eventQueue; + + private static final String MINIMAL_TASK = """ + { + "id": "123", + "contextId": "session-xyz", + "status": {"state": "submitted"}, + "kind": "task" + } + """; + + private static final String MESSAGE_PAYLOAD = """ + { + "role": "agent", + "parts": [{"kind": "text", "text": "test message"}], + "messageId": "111", + "kind": "message" + } + """; + + + @BeforeEach + public void init() { + eventQueue = EventQueue.create(); + + } + + @Test + public void testEnqueueAndDequeueEvent() throws Exception { + Event event = Utils.unmarshalFrom(MESSAGE_PAYLOAD, Message.TYPE_REFERENCE); + eventQueue.enqueueEvent(event); + Event dequeuedEvent = eventQueue.dequeueEvent(200); + assertSame(event, dequeuedEvent); + } + + @Test + public void testDequeueEventNoWait() throws Exception { + Event event = Utils.unmarshalFrom(MINIMAL_TASK, Task.TYPE_REFERENCE); + eventQueue.enqueueEvent(event); + Event dequeuedEvent = eventQueue.dequeueEvent(-1); + assertSame(event, dequeuedEvent); + } + + @Test + public void testDequeueEventEmptyQueueNoWait() throws Exception { + Event dequeuedEvent = eventQueue.dequeueEvent(-1); + assertNull(dequeuedEvent); + } + + @Test + public void testDequeueEventWait() throws Exception { + Event event = new TaskStatusUpdateEvent.Builder() + .taskId("task-123") + .contextId("session-xyz") + .status(new TaskStatus(TaskState.WORKING)) + .isFinal(true) + .build(); + + eventQueue.enqueueEvent(event); + Event dequeuedEvent = eventQueue.dequeueEvent(1000); + assertSame(event, dequeuedEvent); + } + + @Test + public void testTaskDone() throws Exception { + Event event = new TaskArtifactUpdateEvent.Builder() + .taskId("task-123") + .contextId("session-xyz") + .artifact(new Artifact.Builder() + .artifactId("11") + .parts(new TextPart("text")) + .build()) + .build(); + eventQueue.enqueueEvent(event); + Event dequeuedEvent = eventQueue.dequeueEvent(1000); + assertSame(event, dequeuedEvent); + eventQueue.taskDone(); + } + + @Test + public void testEnqueueDifferentEventTypes() throws Exception { + List events = List.of( + new TaskNotFoundError(), + new JSONRPCError(111, "rpc error", null)); + + for (Event event : events) { + eventQueue.enqueueEvent(event); + Event dequeuedEvent = eventQueue.dequeueEvent(100); + assertSame(event, dequeuedEvent); + } + } +} diff --git a/sdk-server-common/src/test/java/io/a2a/server/requesthandlers/JSONRPCHandlerTest.java b/sdk-server-common/src/test/java/io/a2a/server/requesthandlers/JSONRPCHandlerTest.java new file mode 100644 index 000000000..01498df89 --- /dev/null +++ b/sdk-server-common/src/test/java/io/a2a/server/requesthandlers/JSONRPCHandlerTest.java @@ -0,0 +1,1341 @@ +package io.a2a.server.requesthandlers; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.Flow; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import jakarta.enterprise.context.Dependent; + +import io.a2a.http.A2AHttpClient; +import io.a2a.http.A2AHttpResponse; +import io.a2a.server.agentexecution.AgentExecutor; +import io.a2a.server.agentexecution.RequestContext; +import io.a2a.server.events.EventConsumer; +import io.a2a.server.events.EventQueue; +import io.a2a.server.events.InMemoryQueueManager; +import io.a2a.server.tasks.InMemoryPushNotifier; +import io.a2a.server.tasks.InMemoryTaskStore; +import io.a2a.server.tasks.PushNotifier; +import io.a2a.server.tasks.ResultAggregator; +import io.a2a.server.tasks.TaskStore; +import io.a2a.server.tasks.TaskUpdater; +import io.a2a.spec.AgentCapabilities; +import io.a2a.spec.AgentCard; +import io.a2a.spec.Artifact; +import io.a2a.spec.CancelTaskRequest; +import io.a2a.spec.CancelTaskResponse; +import io.a2a.spec.Event; +import io.a2a.spec.GetTaskPushNotificationConfigRequest; +import io.a2a.spec.GetTaskPushNotificationConfigResponse; +import io.a2a.spec.GetTaskRequest; +import io.a2a.spec.GetTaskResponse; +import io.a2a.spec.InternalError; +import io.a2a.spec.InvalidRequestError; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.Message; +import io.a2a.spec.MessageSendParams; +import io.a2a.spec.PushNotificationConfig; +import io.a2a.spec.SendMessageRequest; +import io.a2a.spec.SendMessageResponse; +import io.a2a.spec.SendStreamingMessageRequest; +import io.a2a.spec.SendStreamingMessageResponse; +import io.a2a.spec.SetTaskPushNotificationConfigRequest; +import io.a2a.spec.SetTaskPushNotificationConfigResponse; +import io.a2a.spec.StreamingEventKind; +import io.a2a.spec.Task; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskIdParams; +import io.a2a.spec.TaskNotFoundError; +import io.a2a.spec.TaskPushNotificationConfig; +import io.a2a.spec.TaskQueryParams; +import io.a2a.spec.TaskResubscriptionRequest; +import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; +import io.a2a.spec.TextPart; +import io.a2a.spec.UnsupportedOperationError; +import io.a2a.util.Utils; +import io.quarkus.arc.profile.IfBuildProfile; +import mutiny.zero.ZeroPublisher; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.mockito.MockedConstruction; +import org.mockito.Mockito; + +public class JSONRPCHandlerTest { + + private static final AgentCard CARD = createAgentCard(true, true, true); + + private static final Task MINIMAL_TASK = new Task.Builder() + .id("task-123") + .contextId("session-xyz") + .status(new TaskStatus(TaskState.SUBMITTED)) + .build(); + + private static final Message MESSAGE = new Message.Builder() + .messageId("111") + .role(Message.Role.AGENT) + .parts(new TextPart("test message")) + .build(); + + AgentExecutor executor; + TaskStore taskStore; + RequestHandler requestHandler; + AgentExecutorMethod agentExecutorExecute; + AgentExecutorMethod agentExecutorCancel; + private InMemoryQueueManager queueManager; + private TestHttpClient httpClient; + + private final Executor internalExecutor = Executors.newCachedThreadPool(); + + + @BeforeEach + public void init() { + executor = new AgentExecutor() { + @Override + public void execute(RequestContext context, EventQueue eventQueue) throws JSONRPCError { + if (agentExecutorExecute != null) { + agentExecutorExecute.invoke(context, eventQueue); + } + } + + @Override + public void cancel(RequestContext context, EventQueue eventQueue) throws JSONRPCError { + if (agentExecutorCancel != null) { + agentExecutorCancel.invoke(context, eventQueue); + } + } + }; + + taskStore = new InMemoryTaskStore(); + queueManager = new InMemoryQueueManager(); + httpClient = new TestHttpClient(); + PushNotifier pushNotifier = new InMemoryPushNotifier(httpClient); + + requestHandler = new DefaultRequestHandler(executor, taskStore, queueManager, pushNotifier, internalExecutor); + } + + @AfterEach + public void cleanup() { + agentExecutorExecute = null; + agentExecutorCancel = null; + } + + @Test + public void testOnGetTaskSuccess() throws Exception { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + GetTaskRequest request = new GetTaskRequest("1", new TaskQueryParams(MINIMAL_TASK.getId())); + GetTaskResponse response = handler.onGetTask(request); + assertEquals(request.getId(), response.getId()); + assertSame(MINIMAL_TASK, response.getResult()); + assertNull(response.getError()); + } + + @Test + public void testOnGetTaskNotFound() throws Exception { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + GetTaskRequest request = new GetTaskRequest("1", new TaskQueryParams(MINIMAL_TASK.getId())); + GetTaskResponse response = handler.onGetTask(request); + assertEquals(request.getId(), response.getId()); + assertInstanceOf(TaskNotFoundError.class, response.getError()); + assertNull(response.getResult()); + } + + @Test + public void testOnCancelTaskSuccess() throws Exception { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + + agentExecutorCancel = (context, eventQueue) -> { + // We need to cancel the task or the EventConsumer never finds a 'final' event. + // Looking at the Python implementation, they typically use AgentExecutors that + // don't support cancellation. So my theory is the Agent updates the task to the CANCEL status + Task task = context.getTask(); + TaskUpdater taskUpdater = new TaskUpdater(context, eventQueue); + taskUpdater.cancel(); + }; + + CancelTaskRequest request = new CancelTaskRequest("111", new TaskIdParams(MINIMAL_TASK.getId())); + CancelTaskResponse response = handler.onCancelTask(request); + + assertNull(response.getError()); + assertEquals(request.getId(), response.getId()); + Task task = response.getResult(); + assertEquals(MINIMAL_TASK.getId(), task.getId()); + assertEquals(MINIMAL_TASK.getContextId(), task.getContextId()); + assertEquals(TaskState.CANCELED, task.getStatus().state()); + } + + @Test + public void testOnCancelTaskNotSupported() { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + + agentExecutorCancel = (context, eventQueue) -> { + throw new UnsupportedOperationError(); + }; + + CancelTaskRequest request = new CancelTaskRequest("1", new TaskIdParams(MINIMAL_TASK.getId())); + CancelTaskResponse response = handler.onCancelTask(request); + assertEquals(request.getId(), response.getId()); + assertNull(response.getResult()); + assertInstanceOf(UnsupportedOperationError.class, response.getError()); + } + + @Test + public void testOnCancelTaskNotFound() { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + CancelTaskRequest request = new CancelTaskRequest("1", new TaskIdParams(MINIMAL_TASK.getId())); + CancelTaskResponse response = handler.onCancelTask(request); + assertEquals(request.getId(), response.getId()); + assertNull(response.getResult()); + assertInstanceOf(TaskNotFoundError.class, response.getError()); + } + + @Test + public void testOnMessageNewMessageSuccess() { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + agentExecutorExecute = (context, eventQueue) -> { + eventQueue.enqueueEvent(context.getMessage()); + }; + Message message = new Message.Builder(MESSAGE) + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .build(); + SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); + SendMessageResponse response = handler.onMessageSend(request); + assertNull(response.getError()); + // The Python implementation returns a Task here, but then again they are using hardcoded mocks and + // bypassing the whole EventQueue. + // If we were to send a Task in agentExecutorExecute EventConsumer.consumeAll() would not exit due to + // the Task not having a 'final' state + // + // See testOnMessageNewMessageSuccessMocks() for a test more similar to the Python implementation + assertSame(message, response.getResult()); + } + + @Test + public void testOnMessageNewMessageSuccessMocks() { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + + Message message = new Message.Builder(MESSAGE) + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .build(); + + SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); + SendMessageResponse response; + try (MockedConstruction mocked = Mockito.mockConstruction( + EventConsumer.class, + (mock, context) -> {Mockito.doReturn(ZeroPublisher.fromItems(MINIMAL_TASK)).when(mock).consumeAll();})){ + response = handler.onMessageSend(request); + } + assertNull(response.getError()); + assertSame(MINIMAL_TASK, response.getResult()); + } + + @Test + public void testOnMessageNewMessageWithExistingTaskSuccess() { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + agentExecutorExecute = (context, eventQueue) -> { + eventQueue.enqueueEvent(context.getMessage()); + }; + Message message = new Message.Builder(MESSAGE) + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .build(); + SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); + SendMessageResponse response = handler.onMessageSend(request); + assertNull(response.getError()); + // The Python implementation returns a Task here, but then again they are using hardcoded mocks and + // bypassing the whole EventQueue. + // If we were to send a Task in agentExecutorExecute EventConsumer.consumeAll() would not exit due to + // the Task not having a 'final' state + // + // See testOnMessageNewMessageWithExistingTaskSuccessMocks() for a test more similar to the Python implementation + assertSame(message, response.getResult()); + } + + @Test + public void testOnMessageNewMessageWithExistingTaskSuccessMocks() { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + + Message message = new Message.Builder(MESSAGE) + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .build(); + SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); + SendMessageResponse response; + try (MockedConstruction mocked = Mockito.mockConstruction( + EventConsumer.class, + (mock, context) -> { + Mockito.doReturn(ZeroPublisher.fromItems(MINIMAL_TASK)).when(mock).consumeAll();})){ + response = handler.onMessageSend(request); + } + assertNull(response.getError()); + assertSame(MINIMAL_TASK, response.getResult()); + + } + + @Test + public void testOnMessageError() { + // See testMessageOnErrorMocks() for a test more similar to the Python implementation, using mocks for + // EventConsumer.consumeAll() + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + agentExecutorExecute = (context, eventQueue) -> { + eventQueue.enqueueEvent(new UnsupportedOperationError()); + }; + Message message = new Message.Builder(MESSAGE) + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .build(); + SendMessageRequest request = new SendMessageRequest( + "1", new MessageSendParams(message, null, null)); + SendMessageResponse response = handler.onMessageSend(request); + assertInstanceOf(UnsupportedOperationError.class, response.getError()); + assertNull(response.getResult()); + } + + @Test + public void testOnMessageErrorMocks() { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + Message message = new Message.Builder(MESSAGE) + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .build(); + SendMessageRequest request = new SendMessageRequest( + "1", new MessageSendParams(message, null, null)); + SendMessageResponse response; + try (MockedConstruction mocked = Mockito.mockConstruction( + EventConsumer.class, + (mock, context) -> { + Mockito.doReturn(ZeroPublisher.fromItems(new UnsupportedOperationError())).when(mock).consumeAll();})){ + response = handler.onMessageSend(request); + } + + assertInstanceOf(UnsupportedOperationError.class, response.getError()); + assertNull(response.getResult()); + } + + @Test + public void testOnMessageStreamNewMessageSuccess() { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + agentExecutorExecute = (context, eventQueue) -> { + eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage()); + }; + + Message message = new Message.Builder(MESSAGE) + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .build(); + + SendStreamingMessageRequest request = new SendStreamingMessageRequest( + "1", new MessageSendParams(message, null, null)); + Flow.Publisher response = handler.onMessageSendStream(request); + + List results = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(1); + + response.subscribe(new Flow.Subscriber<>() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(SendStreamingMessageResponse item) { + results.add(item.getResult()); + subscription.request(1); + latch.countDown(); + } + + @Override + public void onError(Throwable throwable) { + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + // The Python implementation has several events emitted since it uses mocks. Also, in the + // implementation, a Message is considered a 'final' Event in EventConsumer.consumeAll() + // so there would be no more Events. + // + // See testOnMessageStreamNewMessageSuccessMocks() for a test more similar to the Python implementation + assertEquals(1, results.size()); + assertSame(message, results.get(0)); + } + + @Test + public void testOnMessageStreamNewMessageSuccessMocks() { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + + // This is used to send events from a mock + List events = List.of( + MINIMAL_TASK, + new TaskArtifactUpdateEvent.Builder() + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .artifact(new Artifact.Builder() + .artifactId("art1") + .parts(new TextPart("text")) + .build()) + .build(), + new TaskStatusUpdateEvent.Builder() + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .status(new TaskStatus(TaskState.COMPLETED)) + .build()); + + Message message = new Message.Builder(MESSAGE) + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .build(); + + SendStreamingMessageRequest request = new SendStreamingMessageRequest( + "1", new MessageSendParams(message, null, null)); + Flow.Publisher response; + try (MockedConstruction mocked = Mockito.mockConstruction( + EventConsumer.class, + (mock, context) -> { + Mockito.doReturn(ZeroPublisher.fromIterable(events)).when(mock).consumeAll();})){ + response = handler.onMessageSendStream(request); + } + + List results = new ArrayList<>(); + + response.subscribe(new Flow.Subscriber() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(SendStreamingMessageResponse item) { + results.add((Event) item.getResult()); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + + } + + @Override + public void onComplete() { + + } + }); + + assertEquals(events, results); + } + + @Test + public void testOnMessageStreamNewMessageExistingTaskSuccess() throws Exception { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + agentExecutorExecute = (context, eventQueue) -> { + eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage()); + }; + + Task task = new Task.Builder(MINIMAL_TASK) + .history(new ArrayList<>()) + .build(); + taskStore.save(task); + + Message message = new Message.Builder(MESSAGE) + .taskId(task.getId()) + .contextId(task.getContextId()) + .build(); + + + SendStreamingMessageRequest request = new SendStreamingMessageRequest( + "1", new MessageSendParams(message, null, null)); + Flow.Publisher response = handler.onMessageSendStream(request); + + // This Publisher never completes so we subscribe in a new thread. + // I _think_ that is as expected, and testOnMessageStreamNewMessageSendPushNotificationSuccess seems + // to confirm this + final List results = new ArrayList<>(); + final AtomicReference subscriptionRef = new AtomicReference<>(); + final CountDownLatch latch = new CountDownLatch(1); + + Executors.newSingleThreadExecutor().execute(() -> { + response.subscribe(new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription subscription) { + subscriptionRef.set(subscription); + subscription.request(1); + } + + @Override + public void onNext(SendStreamingMessageResponse item) { + results.add(item.getResult()); + subscriptionRef.get().request(1); + latch.countDown(); + } + + @Override + public void onError(Throwable throwable) { + subscriptionRef.get().cancel(); + } + + @Override + public void onComplete() { + subscriptionRef.get().cancel(); + } + }); + }); + + assertTrue(latch.await(1, TimeUnit.SECONDS)); + subscriptionRef.get().cancel(); + // The Python implementation has several events emitted since it uses mocks. + // + // See testOnMessageStreamNewMessageExistingTaskSuccessMocks() for a test more similar to the Python implementation + Task expected = new Task.Builder(task) + .history(message) + .build(); + assertEquals(1, results.size()); + StreamingEventKind receivedType = results.get(0); + assertInstanceOf(Task.class, receivedType); + Task received = (Task) receivedType; + assertEquals(expected.getId(), received.getId()); + assertEquals(expected.getContextId(), received.getContextId()); + assertEquals(expected.getStatus(), received.getStatus()); + assertEquals(expected.getHistory(), received.getHistory()); + } + + @Test + public void testOnMessageStreamNewMessageExistingTaskSuccessMocks() { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + + Task task = new Task.Builder(MINIMAL_TASK) + .history(new ArrayList<>()) + .build(); + taskStore.save(task); + + // This is used to send events from a mock + List events = List.of( + new TaskArtifactUpdateEvent.Builder() + .taskId(task.getId()) + .contextId(task.getContextId()) + .artifact(new Artifact.Builder() + .artifactId("11") + .parts(new TextPart("text")) + .build()) + .build(), + new TaskStatusUpdateEvent.Builder() + .taskId(task.getId()) + .contextId(task.getContextId()) + .status(new TaskStatus(TaskState.WORKING)) + .build()); + + Message message = new Message.Builder(MESSAGE) + .taskId(task.getId()) + .contextId(task.getContextId()) + .build(); + + SendStreamingMessageRequest request = new SendStreamingMessageRequest( + "1", new MessageSendParams(message, null, null)); + Flow.Publisher response; + try (MockedConstruction mocked = Mockito.mockConstruction( + EventConsumer.class, + (mock, context) -> { + Mockito.doReturn(ZeroPublisher.fromIterable(events)).when(mock).consumeAll();})){ + response = handler.onMessageSendStream(request); + } + + List results = new ArrayList<>(); + + // Unlike testOnMessageStreamNewMessageExistingTaskSuccess() the ZeroPublisher.fromIterable() + // used to mock the events completes once it has sent all the items. So no special thread + // handling is needed. + response.subscribe(new Flow.Subscriber() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(SendStreamingMessageResponse item) { + results.add((Event) item.getResult()); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + + } + + @Override + public void onComplete() { + + } + }); + + assertEquals(events, results); + } + + + @Test + public void testSetPushNotificationSuccess() { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + + TaskPushNotificationConfig taskPushConfig = + new TaskPushNotificationConfig( + MINIMAL_TASK.getId(), new PushNotificationConfig.Builder().url("http://example.com").build()); + SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); + SetTaskPushNotificationConfigResponse response = handler.setPushNotification(request); + assertSame(taskPushConfig, response.getResult()); + } + + @Test + public void testGetPushNotificationSuccess() { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + agentExecutorExecute = (context, eventQueue) -> { + eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage()); + }; + + + TaskPushNotificationConfig taskPushConfig = + new TaskPushNotificationConfig( + MINIMAL_TASK.getId(), new PushNotificationConfig.Builder().url("http://example.com").build()); + + SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); + handler.setPushNotification(request); + + GetTaskPushNotificationConfigRequest getRequest = + new GetTaskPushNotificationConfigRequest("111", new TaskIdParams(MINIMAL_TASK.getId())); + GetTaskPushNotificationConfigResponse getResponse = handler.getPushNotification(getRequest); + + assertEquals(taskPushConfig, getResponse.getResult()); + } + + @Test + public void testOnMessageStreamNewMessageSendPushNotificationSuccess() throws Exception { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + + List events = List.of( + MINIMAL_TASK, + new TaskArtifactUpdateEvent.Builder() + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .artifact(new Artifact.Builder() + .artifactId("11") + .parts(new TextPart("text")) + .build()) + .build(), + new TaskStatusUpdateEvent.Builder() + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .status(new TaskStatus(TaskState.COMPLETED)) + .build()); + + + agentExecutorExecute = (context, eventQueue) -> { + // Hardcode the events to send here + for (Event event : events) { + eventQueue.enqueueEvent(event); + } + }; + + + TaskPushNotificationConfig config = new TaskPushNotificationConfig( + MINIMAL_TASK.getId(), + new PushNotificationConfig.Builder().url("http://example.com").build()); + SetTaskPushNotificationConfigRequest stpnRequest = new SetTaskPushNotificationConfigRequest("1", config); + SetTaskPushNotificationConfigResponse stpnResponse = handler.setPushNotification(stpnRequest); + assertNull(stpnResponse.getError()); + + Message msg = new Message.Builder(MESSAGE) + .taskId(MINIMAL_TASK.getId()) + .build(); + SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(msg, null, null)); + Flow.Publisher response = handler.onMessageSendStream(request); + + final List results = Collections.synchronizedList(new ArrayList<>()); + final AtomicReference subscriptionRef = new AtomicReference<>(); + final CountDownLatch latch = new CountDownLatch(6); + httpClient.latch = latch; + + Executors.newSingleThreadExecutor().execute(() -> { + response.subscribe(new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription subscription) { + subscriptionRef.set(subscription); + subscription.request(1); + latch.countDown(); + } + + @Override + public void onNext(SendStreamingMessageResponse item) { + System.out.println("-> " + item.getResult()); + results.add(item.getResult()); + System.out.println(results); + subscriptionRef.get().request(1); + latch.countDown(); + } + + @Override + public void onError(Throwable throwable) { + subscriptionRef.get().cancel(); + } + + @Override + public void onComplete() { + subscriptionRef.get().cancel(); + } + }); + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + subscriptionRef.get().cancel(); + if (results.size() != 3) { + // TODO - this is very strange. The results array is synchronized, and the latch is counted down + // AFTER adding items to the list. Still, I am seeing intermittently, but frequently that + // the results list only has two items. + long end = System.currentTimeMillis() + 5000; + while (results.size() != 3 && System.currentTimeMillis() < end) { + Thread.sleep(1000); + } + } + assertEquals(3, results.size()); + assertEquals(3, httpClient.tasks.size()); + + Task curr = httpClient.tasks.get(0); + assertEquals(MINIMAL_TASK.getId(), curr.getId()); + assertEquals(MINIMAL_TASK.getContextId(), curr.getContextId()); + assertEquals(MINIMAL_TASK.getStatus().state(), curr.getStatus().state()); + assertEquals(0, curr.getArtifacts() == null ? 0 : curr.getArtifacts().size()); + + curr = httpClient.tasks.get(1); + assertEquals(MINIMAL_TASK.getId(), curr.getId()); + assertEquals(MINIMAL_TASK.getContextId(), curr.getContextId()); + assertEquals(MINIMAL_TASK.getStatus().state(), curr.getStatus().state()); + assertEquals(1, curr.getArtifacts().size()); + assertEquals(1, curr.getArtifacts().get(0).parts().size()); + assertEquals("text", ((TextPart)curr.getArtifacts().get(0).parts().get(0)).getText()); + + curr = httpClient.tasks.get(2); + assertEquals(MINIMAL_TASK.getId(), curr.getId()); + assertEquals(MINIMAL_TASK.getContextId(), curr.getContextId()); + assertEquals(TaskState.COMPLETED, curr.getStatus().state()); + assertEquals(1, curr.getArtifacts().size()); + assertEquals(1, curr.getArtifacts().get(0).parts().size()); + assertEquals("text", ((TextPart)curr.getArtifacts().get(0).parts().get(0)).getText()); + } + + @Test + public void testOnResubscribeExistingTaskSuccess() { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + queueManager.createOrTap(MINIMAL_TASK.getId()); + + agentExecutorExecute = (context, eventQueue) -> { + // The only thing hitting the agent is the onMessageSend() and we should use the message + eventQueue.enqueueEvent(context.getMessage()); + //eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage()); + }; + + TaskResubscriptionRequest request = new TaskResubscriptionRequest("1", new TaskIdParams(MINIMAL_TASK.getId())); + Flow.Publisher response = handler.onResubscribeToTask(request); + + // We need to send some events in order for those to end up in the queue + Message message = new Message.Builder() + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .role(Message.Role.AGENT) + .parts(new TextPart("text")) + .build(); + SendMessageResponse smr = + handler.onMessageSend(new SendMessageRequest("1", new MessageSendParams(message, null, null))); + assertNull(smr.getError()); + + + List results = new ArrayList<>(); + + response.subscribe(new Flow.Subscriber<>() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(SendStreamingMessageResponse item) { + results.add(item.getResult()); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + // The Python implementation has several events emitted since it uses mocks. + // + // See testOnMessageStreamNewMessageExistingTaskSuccessMocks() for a test more similar to the Python implementation + assertEquals(1, results.size()); + } + + + @Test + public void testOnResubscribeExistingTaskSuccessMocks() throws Exception { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + queueManager.createOrTap(MINIMAL_TASK.getId()); + + List events = List.of( + new TaskArtifactUpdateEvent.Builder() + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .artifact(new Artifact.Builder() + .artifactId("11") + .parts(new TextPart("text")) + .build()) + .build(), + new TaskStatusUpdateEvent.Builder() + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .status(new TaskStatus(TaskState.WORKING)) + .build()); + + TaskResubscriptionRequest request = new TaskResubscriptionRequest("1", new TaskIdParams(MINIMAL_TASK.getId())); + Flow.Publisher response; + try (MockedConstruction mocked = Mockito.mockConstruction( + EventConsumer.class, + (mock, context) -> { + Mockito.doReturn(ZeroPublisher.fromIterable(events)).when(mock).consumeAll();})){ + response = handler.onResubscribeToTask(request); + } + + List results = new ArrayList<>(); + + // Unlike testOnResubscribeExistingTaskSuccess() the ZeroPublisher.fromIterable() + // used to mock the events completes once it has sent all the items. So no special thread + // handling is needed. + response.subscribe(new Flow.Subscriber<>() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(SendStreamingMessageResponse item) { + results.add(item.getResult()); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + // The Python implementation has several events emitted since it uses mocks. + // + // See testOnMessageStreamNewMessageExistingTaskSuccessMocks() for a test more similar to the Python implementation + assertEquals(events, results); + } + + @Test + public void testOnResubscribeNoExistingTaskError() { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + + TaskResubscriptionRequest request = new TaskResubscriptionRequest("1", new TaskIdParams(MINIMAL_TASK.getId())); + Flow.Publisher response = handler.onResubscribeToTask(request); + + List results = new ArrayList<>(); + AtomicReference error = new AtomicReference<>(); + + response.subscribe(new Flow.Subscriber<>() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(SendStreamingMessageResponse item) { + results.add(item); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + error.set(throwable); + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + assertEquals(1, results.size()); + assertNull(results.get(0).getResult()); + assertInstanceOf(TaskNotFoundError.class, results.get(0).getError()); + } + + @Test + public void testStreamingNotSupportedError() { + AgentCard card = createAgentCard(false, true, true); + JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler); + + SendStreamingMessageRequest request = new SendStreamingMessageRequest.Builder() + .id("1") + .params(new MessageSendParams.Builder() + .message(MESSAGE) + .build()) + .build(); + Flow.Publisher response = handler.onMessageSendStream(request); + + List results = new ArrayList<>(); + AtomicReference error = new AtomicReference<>(); + + response.subscribe(new Flow.Subscriber() { + private Flow.Subscription subscription; + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(SendStreamingMessageResponse item) { + results.add(item); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + error.set(throwable); + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + assertEquals(1, results.size()); + if (results.get(0).getError() != null && results.get(0).getError() instanceof InvalidRequestError ire) { + assertEquals("Streaming is not supported by the agent", ire.getMessage()); + } else { + fail("Expected a response containing an error"); + } + } + + @Test + public void testStreamingNotSupportedErrorOnResubscribeToTask() { + // This test does not exist in the Python implementation + AgentCard card = createAgentCard(false, true, true); + JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler); + + TaskResubscriptionRequest request = new TaskResubscriptionRequest("1", new TaskIdParams(MINIMAL_TASK.getId())); + Flow.Publisher response = handler.onResubscribeToTask(request); + + List results = new ArrayList<>(); + AtomicReference error = new AtomicReference<>(); + + response.subscribe(new Flow.Subscriber() { + private Flow.Subscription subscription; + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(SendStreamingMessageResponse item) { + results.add(item); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + error.set(throwable); + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + assertEquals(1, results.size()); + if (results.get(0).getError() != null && results.get(0).getError() instanceof InvalidRequestError ire) { + assertEquals("Streaming is not supported by the agent", ire.getMessage()); + } else { + fail("Expected a response containing an error"); + } + } + + + @Test + public void testPushNotificationsNotSupportedError() { + AgentCard card = createAgentCard(true, false, true); + JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler); + taskStore.save(MINIMAL_TASK); + + TaskPushNotificationConfig config = + new TaskPushNotificationConfig( + MINIMAL_TASK.getId(), + new PushNotificationConfig.Builder() + .url("http://example.com") + .build()); + + SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest.Builder() + .params(config) + .build(); + SetTaskPushNotificationConfigResponse response = handler.setPushNotification(request); + assertInstanceOf(InvalidRequestError.class, response.getError()); + assertEquals("Push notifications are not supported by the agent", response.getError().getMessage()); + } + + @Test + public void testOnGetPushNotificationNoPushNotifier() { + // Create request handler without a push notifier + DefaultRequestHandler requestHandler = + new DefaultRequestHandler(executor, taskStore, queueManager, null, internalExecutor); + AgentCard card = createAgentCard(false, true, false); + JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler); + + taskStore.save(MINIMAL_TASK); + + GetTaskPushNotificationConfigRequest request = + new GetTaskPushNotificationConfigRequest("id", new TaskIdParams(MINIMAL_TASK.getId())); + GetTaskPushNotificationConfigResponse response = handler.getPushNotification(request); + + assertNotNull(response.getError()); + assertInstanceOf(UnsupportedOperationError.class, response.getError()); + assertEquals("This operation is not supported", response.getError().getMessage()); + } + + @Test + public void testOnSetPushNotificationNoPushNotifier() { + // Create request handler without a push notifier + DefaultRequestHandler requestHandler = + new DefaultRequestHandler(executor, taskStore, queueManager, null, internalExecutor); + AgentCard card = createAgentCard(false, true, false); + JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler); + + taskStore.save(MINIMAL_TASK); + + TaskPushNotificationConfig config = + new TaskPushNotificationConfig( + MINIMAL_TASK.getId(), + new PushNotificationConfig.Builder() + .url("http://example.com") + .build()); + + SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest.Builder() + .params(config) + .build(); + SetTaskPushNotificationConfigResponse response = handler.setPushNotification(request); + + assertInstanceOf(UnsupportedOperationError.class, response.getError()); + assertEquals("This operation is not supported", response.getError().getMessage()); + } + + @Test + public void testOnMessageSendInternalError() { + DefaultRequestHandler mocked = Mockito.mock(DefaultRequestHandler.class); + Mockito.doThrow(new InternalError("Internal Error")).when(mocked).onMessageSend(Mockito.any(MessageSendParams.class)); + + JSONRPCHandler handler = new JSONRPCHandler(CARD, mocked); + + SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(MESSAGE, null, null)); + SendMessageResponse response = handler.onMessageSend(request); + + System.out.println(response); + assertInstanceOf(InternalError.class, response.getError()); + } + + @Test + public void testOnMessageStreamInternalError() { + DefaultRequestHandler mocked = Mockito.mock(DefaultRequestHandler.class); + Mockito.doThrow(new InternalError("Internal Error")).when(mocked).onMessageSendStream(Mockito.any(MessageSendParams.class)); + + JSONRPCHandler handler = new JSONRPCHandler(CARD, mocked); + + SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(MESSAGE, null, null)); + Flow.Publisher response = handler.onMessageSendStream(request); + + + List results = new ArrayList<>(); + AtomicReference error = new AtomicReference<>(); + + response.subscribe(new Flow.Subscriber() { + private Flow.Subscription subscription; + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(SendStreamingMessageResponse item) { + results.add(item); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + error.set(throwable); + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + assertEquals(1, results.size()); + assertInstanceOf(InternalError.class, results.get(0).getError()); + } + + @Test + @Disabled + public void testDefaultRequestHandlerWithCustomComponents() { + // Not much happening in the Python test beyond checking that the DefaultRequestHandler + // constructor sets the fields as expected + } + + @Test + public void testOnMessageSendErrorHandling() { + DefaultRequestHandler requestHandler = + new DefaultRequestHandler(executor, taskStore, queueManager, null, internalExecutor); + AgentCard card = createAgentCard(false, true, false); + JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler); + + taskStore.save(MINIMAL_TASK); + + Message message = new Message.Builder(MESSAGE) + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .build(); + + SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); + SendMessageResponse response; + + try (MockedConstruction mocked = Mockito.mockConstruction( + ResultAggregator.class, + (mock, context) -> + Mockito.doThrow( + new UnsupportedOperationError()) + .when(mock).consumeAndBreakOnInterrupt(Mockito.any(EventConsumer.class)))){ + response = handler.onMessageSend(request); + } + + assertInstanceOf(UnsupportedOperationError.class, response.getError()); + + } + + @Test + public void testOnMessageSendTaskIdMismatch() { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + + agentExecutorExecute = ((context, eventQueue) -> { + eventQueue.enqueueEvent(MINIMAL_TASK); + }); + SendMessageRequest request = new SendMessageRequest("1", + new MessageSendParams(MESSAGE, null, null)); + SendMessageResponse response = handler.onMessageSend(request); + assertInstanceOf(InternalError.class, response.getError()); + + } + + @Test + public void testOnMessageStreamTaskIdMismatch() { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); + taskStore.save(MINIMAL_TASK); + + agentExecutorExecute = ((context, eventQueue) -> { + eventQueue.enqueueEvent(MINIMAL_TASK); + }); + + SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(MESSAGE, null, null)); + Flow.Publisher response = handler.onMessageSendStream(request); + + List results = new ArrayList<>(); + AtomicReference error = new AtomicReference<>(); + + response.subscribe(new Flow.Subscriber() { + private Flow.Subscription subscription; + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(SendStreamingMessageResponse item) { + results.add(item); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + error.set(throwable); + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + assertNull(error.get()); + assertEquals(1, results.size()); + assertInstanceOf(InternalError.class, results.get(0).getError()); + } + + private static AgentCard createAgentCard(boolean streaming, boolean pushNotifications, boolean stateTransitionHistory) { + return new AgentCard.Builder() + .name("test-card") + .description("A test agent card") + .url("http://example.com") + .version("1.0") + .documentationUrl("http://example.com/docs") + .capabilities(new AgentCapabilities.Builder() + .streaming(streaming) + .pushNotifications(pushNotifications) + .stateTransitionHistory(stateTransitionHistory) + .build()) + .defaultInputModes(new ArrayList<>()) + .defaultOutputModes(new ArrayList<>()) + .skills(new ArrayList<>()) + .build(); + } + + private interface AgentExecutorMethod { + void invoke(RequestContext context, EventQueue eventQueue) throws JSONRPCError; + } + + @Dependent + @IfBuildProfile("test") + private static class TestHttpClient implements A2AHttpClient { + final List tasks = Collections.synchronizedList(new ArrayList<>()); + volatile CountDownLatch latch; + + @Override + public GetBuilder createGet() { + return null; + } + + @Override + public PostBuilder createPost() { + return new TestPostBuilder(); + } + + class TestPostBuilder implements A2AHttpClient.PostBuilder { + private volatile String body; + @Override + public PostBuilder body(String body) { + this.body = body; + return this; + } + + @Override + public A2AHttpResponse post() throws IOException, InterruptedException { + tasks.add(Utils.OBJECT_MAPPER.readValue(body, Task.TYPE_REFERENCE)); + try { + return new A2AHttpResponse() { + @Override + public int status() { + return 200; + } + + @Override + public boolean success() { + return true; + } + + @Override + public String body() { + return ""; + } + }; + } finally { + latch.countDown(); + } + } + + @Override + public CompletableFuture postAsyncSSE(Consumer messageConsumer, Consumer errorConsumer, Runnable completeRunnable) throws IOException, InterruptedException { + return null; + } + + @Override + public PostBuilder url(String s) { + return this; + } + + @Override + public PostBuilder addHeader(String name, String value) { + return this; + } + } + } +} diff --git a/sdk-server-common/src/test/java/io/a2a/server/tasks/InMemoryTaskStoreTest.java b/sdk-server-common/src/test/java/io/a2a/server/tasks/InMemoryTaskStoreTest.java new file mode 100644 index 000000000..1ae4174dc --- /dev/null +++ b/sdk-server-common/src/test/java/io/a2a/server/tasks/InMemoryTaskStoreTest.java @@ -0,0 +1,50 @@ +package io.a2a.server.tasks; + +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; + +import io.a2a.spec.Task; +import io.a2a.util.Utils; +import org.junit.jupiter.api.Test; + +public class InMemoryTaskStoreTest { + private static final String TASK_JSON = """ + { + "id": "task-abc", + "contextId" : "session-xyz", + "status": {"state": "submitted"}, + "kind": "task" + }"""; + + @Test + public void testSaveAndGet() throws Exception { + InMemoryTaskStore store = new InMemoryTaskStore(); + Task task = Utils.unmarshalFrom(TASK_JSON, Task.TYPE_REFERENCE); + store.save(task); + Task retrieved = store.get(task.getId()); + assertSame(task, retrieved); + } + + @Test + public void testGetNonExistent() throws Exception { + InMemoryTaskStore store = new InMemoryTaskStore(); + Task retrieved = store.get("nonexistent"); + assertNull(retrieved); + } + + @Test + public void testDelete() throws Exception { + InMemoryTaskStore store = new InMemoryTaskStore(); + Task task = Utils.unmarshalFrom(TASK_JSON, Task.TYPE_REFERENCE); + store.save(task); + store.delete(task.getId()); + Task retrieved = store.get(task.getId()); + assertNull(retrieved); + } + + @Test + public void testDeleteNonExistent() throws Exception { + InMemoryTaskStore store = new InMemoryTaskStore(); + store.delete("non-existent"); + } +} diff --git a/sdk-server-common/src/test/java/io/a2a/server/tasks/TaskManagerTest.java b/sdk-server-common/src/test/java/io/a2a/server/tasks/TaskManagerTest.java new file mode 100644 index 000000000..768a52bb2 --- /dev/null +++ b/sdk-server-common/src/test/java/io/a2a/server/tasks/TaskManagerTest.java @@ -0,0 +1,178 @@ +package io.a2a.server.tasks; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; + +import java.util.Collections; +import java.util.HashMap; + +import io.a2a.spec.A2AServerException; +import io.a2a.spec.Artifact; +import io.a2a.spec.Message; +import io.a2a.spec.Task; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; +import io.a2a.spec.TextPart; +import io.a2a.util.Utils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class TaskManagerTest { + private static final String TASK_JSON = """ + { + "id": "task-abc", + "contextId" : "session-xyz", + "status": {"state": "submitted"}, + "kind": "task" + }"""; + + Task minimalTask; + TaskStore taskStore; + TaskManager taskManager; + + @BeforeEach + public void init() throws Exception { + minimalTask = Utils.unmarshalFrom(TASK_JSON, Task.TYPE_REFERENCE); + taskStore = new InMemoryTaskStore(); + taskManager = new TaskManager(minimalTask.getId(), minimalTask.getContextId(), taskStore, null); + } + + @Test + public void testGetTaskExisting() { + Task expectedTask = minimalTask; + taskStore.save(expectedTask); + Task retrieved = taskManager.getTask(); + assertSame(expectedTask, retrieved); + } + + @Test + public void testGetTaskNonExistent() { + Task retrieved = taskManager.getTask(); + assertNull(retrieved); + } + + @Test + public void testSaveTaskEventNewTask() throws A2AServerException { + Task saved = taskManager.saveTaskEvent(minimalTask); + Task retrieved = taskManager.getTask(); + assertSame(minimalTask, retrieved); + assertSame(retrieved, saved); + } + + @Test + public void testSaveTaskEventStatusUpdate() throws A2AServerException { + Task initialTask = minimalTask; + taskStore.save(initialTask); + + TaskStatus newStatus = new TaskStatus( + TaskState.WORKING, + new Message.Builder() + .role(Message.Role.AGENT) + .parts(Collections.singletonList(new TextPart("content"))) + .messageId("messageId") + .build(), + null); + TaskStatusUpdateEvent event = new TaskStatusUpdateEvent( + minimalTask.getId(), + newStatus, + minimalTask.getContextId(), + false, + new HashMap<>()); + + + Task saved = taskManager.saveTaskEvent(event); + Task updated = taskManager.getTask(); + + assertNotSame(initialTask, updated); + assertSame(updated, saved); + + assertEquals(initialTask.getId(), updated.getId()); + assertEquals(initialTask.getContextId(), updated.getContextId()); + // TODO type does not get unmarshalled + //assertEquals(initialTask.getType(), updated.getType()); + assertSame(newStatus, updated.getStatus()); + } + + @Test + public void testSaveTaskEventArtifactUpdate() throws A2AServerException { + Task initialTask = minimalTask; + Artifact newArtifact = new Artifact.Builder() + .artifactId("artifact-id") + .name("artifact-1") + .parts(Collections.singletonList(new TextPart("content"))) + .build(); + TaskArtifactUpdateEvent event = new TaskArtifactUpdateEvent.Builder() + .taskId(minimalTask.getId()) + .contextId(minimalTask.getContextId()) + .artifact(newArtifact) + .build(); + Task saved = taskManager.saveTaskEvent(event); + + Task updatedTask = taskManager.getTask(); + assertSame(updatedTask, saved); + + assertNotSame(initialTask, updatedTask); + assertEquals(initialTask.getId(), updatedTask.getId()); + assertEquals(initialTask.getContextId(), updatedTask.getContextId()); + assertSame(initialTask.getStatus().state(), updatedTask.getStatus().state()); + assertEquals(1, updatedTask.getArtifacts().size()); + assertEquals(newArtifact, updatedTask.getArtifacts().get(0)); + } + + @Test + public void testEnsureTaskExisting() { + // This tests the 'working case' of the internal logic to check a task being updated existas + // We are already testing that + } + + @Test + public void testEnsureTaskNonExistentForStatusUpdate() throws A2AServerException { + // Tests that an update event instantiates a new task and that + TaskManager taskManagerWithoutId = new TaskManager(null, null, taskStore, null); + TaskStatusUpdateEvent event = new TaskStatusUpdateEvent.Builder() + .taskId("new-task") + .contextId("some-context") + .status(new TaskStatus(TaskState.SUBMITTED)) + .isFinal(false) + .build(); + + Task task = taskManagerWithoutId.saveTaskEvent(event); + assertEquals(event.getTaskId(), taskManagerWithoutId.getTaskId()); + assertEquals(event.getContextId(), taskManagerWithoutId.getContextId()); + + Task newTask = taskManagerWithoutId.getTask(); + assertEquals(event.getTaskId(), newTask.getId()); + assertEquals(event.getContextId(), newTask.getContextId()); + assertEquals(TaskState.SUBMITTED, newTask.getStatus().state()); + assertSame(newTask, task); + } + + @Test + public void testSaveTaskEventNewTaskNoTaskId() throws A2AServerException { + TaskManager taskManagerWithoutId = new TaskManager(null, null, taskStore, null); + Task task = new Task.Builder() + .id("new-task-id") + .contextId("some-context") + .status(new TaskStatus(TaskState.WORKING)) + .build(); + + Task saved = taskManagerWithoutId.saveTaskEvent(task); + assertEquals(task.getId(), taskManagerWithoutId.getTaskId()); + assertEquals(task.getContextId(), taskManagerWithoutId.getContextId()); + + Task retrieved = taskManagerWithoutId.getTask(); + assertSame(task, retrieved); + assertSame(retrieved, saved); + } + + @Test + public void testGetTaskNoTaskId() { + TaskManager taskManagerWithoutId = new TaskManager(null, null, taskStore, null); + Task retrieved = taskManagerWithoutId.getTask(); + assertNull(retrieved); + } +} diff --git a/sdk-server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java b/sdk-server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java new file mode 100644 index 000000000..291767b63 --- /dev/null +++ b/sdk-server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java @@ -0,0 +1,173 @@ +package io.a2a.server.tasks; + +import static io.a2a.spec.Message.Role.AGENT; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; + +import java.util.List; +import java.util.Map; + +import io.a2a.server.agentexecution.RequestContext; +import io.a2a.server.events.EventQueue; +import io.a2a.spec.Event; +import io.a2a.spec.Message; +import io.a2a.spec.Part; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatusUpdateEvent; +import io.a2a.spec.TextPart; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class TaskUpdaterTest { + public static final String TEST_TASK_ID = "test-task-id"; + public static final String TEST_TASK_CONTEXT_ID = "test-task-context-id"; + + private static final Message SAMPLE_MESSAGE = new Message.Builder() + .taskId(TEST_TASK_ID) + .contextId(TEST_TASK_CONTEXT_ID) + .parts(new TextPart("Test message")) + .role(AGENT) + .build(); + + private static final List> SAMPLE_PARTS = List.of(new TextPart("Test message")); + + EventQueue eventQueue; + private TaskUpdater taskUpdater; + + + + @BeforeEach + public void init() { + eventQueue = EventQueue.create(); + RequestContext context = new RequestContext.Builder() + .setTaskId(TEST_TASK_ID) + .setContextId(TEST_TASK_CONTEXT_ID) + .build(); + taskUpdater = new TaskUpdater(context, eventQueue); + } + + @Test + public void testAddArtifactWithCustomIdAndName() throws Exception { + taskUpdater.addArtifact(SAMPLE_PARTS, "custom-artifact-id", "Custom Artifact", null); + Event event = eventQueue.dequeueEvent(0); + assertNotNull(event); + assertInstanceOf(TaskArtifactUpdateEvent.class, event); + + TaskArtifactUpdateEvent taue = (TaskArtifactUpdateEvent) event; + assertEquals(TEST_TASK_ID, taue.getTaskId()); + assertEquals(TEST_TASK_CONTEXT_ID, taue.getContextId()); + assertEquals("custom-artifact-id", taue.getArtifact().artifactId()); + assertEquals("Custom Artifact", taue.getArtifact().name()); + assertSame(SAMPLE_PARTS, taue.getArtifact().parts()); + + + assertNull(eventQueue.dequeueEvent(0)); + } + + @Test + public void testCompleteWithoutMessage() throws Exception { + taskUpdater.complete(); + checkTaskStatusUpdateEventOnQueue(true, TaskState.COMPLETED, null); + } + + @Test + public void testCompleteWithMessage() throws Exception { + taskUpdater.complete(SAMPLE_MESSAGE); + checkTaskStatusUpdateEventOnQueue(true, TaskState.COMPLETED, SAMPLE_MESSAGE); + } + + @Test + public void testSubmitWithoutMessage() throws Exception { + taskUpdater.submit(); + checkTaskStatusUpdateEventOnQueue(false, TaskState.SUBMITTED, null); + } + + @Test + public void testSubmitWithMessage() throws Exception { + taskUpdater.submit(SAMPLE_MESSAGE); + checkTaskStatusUpdateEventOnQueue(false, TaskState.SUBMITTED, SAMPLE_MESSAGE); + } + + @Test + public void testStartWorkWithoutMessage() throws Exception { + taskUpdater.startWork(); + checkTaskStatusUpdateEventOnQueue(false, TaskState.WORKING, null); + } + + @Test + public void testStartWorkWithMessage() throws Exception { + taskUpdater.startWork(SAMPLE_MESSAGE); + checkTaskStatusUpdateEventOnQueue(false, TaskState.WORKING, SAMPLE_MESSAGE); + } + + @Test + public void testFailedWithoutMessage() throws Exception { + taskUpdater.fail(); + checkTaskStatusUpdateEventOnQueue(true, TaskState.FAILED, null); + } + + @Test + public void testFailedWithMessage() throws Exception { + taskUpdater.fail(SAMPLE_MESSAGE); + checkTaskStatusUpdateEventOnQueue(true, TaskState.FAILED, SAMPLE_MESSAGE); + } + + @Test + public void testCanceledWithoutMessage() throws Exception { + taskUpdater.cancel(); + checkTaskStatusUpdateEventOnQueue(true, TaskState.CANCELED, null); + } + + @Test + public void testCanceledWithMessage() throws Exception { + taskUpdater.cancel(SAMPLE_MESSAGE); + checkTaskStatusUpdateEventOnQueue(true, TaskState.CANCELED, SAMPLE_MESSAGE); + } + + @Test + public void testNewAgentMessage() throws Exception { + Message message = taskUpdater.newAgentMessage(SAMPLE_PARTS, null); + + assertEquals(AGENT, message.getRole()); + assertEquals(TEST_TASK_ID, message.getTaskId()); + assertEquals(TEST_TASK_CONTEXT_ID, message.getContextId()); + assertNotNull(message.getMessageId()); + assertEquals(SAMPLE_PARTS, message.getParts()); + assertNull(message.getMetadata()); + } + + @Test + public void testNewAgentMessageWithMetadata() throws Exception { + Map metadata = Map.of("key", "value"); + Message message = taskUpdater.newAgentMessage(SAMPLE_PARTS, metadata); + + assertEquals(AGENT, message.getRole()); + assertEquals(TEST_TASK_ID, message.getTaskId()); + assertEquals(TEST_TASK_CONTEXT_ID, message.getContextId()); + assertNotNull(message.getMessageId()); + assertEquals(SAMPLE_PARTS, message.getParts()); + assertEquals(metadata, message.getMetadata()); + } + + private TaskStatusUpdateEvent checkTaskStatusUpdateEventOnQueue(boolean isFinal, TaskState state, Message statusMessage) throws Exception { + Event event = eventQueue.dequeueEvent(0); + + assertNotNull(event); + assertInstanceOf(TaskStatusUpdateEvent.class, event); + + TaskStatusUpdateEvent tsue = (TaskStatusUpdateEvent) event; + assertEquals(TEST_TASK_ID, tsue.getTaskId()); + assertEquals(TEST_TASK_CONTEXT_ID, tsue.getContextId()); + assertEquals(isFinal, tsue.isFinal()); + assertEquals(state, tsue.getStatus().state()); + assertEquals(statusMessage, tsue.getStatus().message()); + + assertNull(eventQueue.dequeueEvent(0)); + + return tsue; + } +} diff --git a/sdk-server-common/src/test/java/io/a2a/server/util/async/AsyncUtilsTest.java b/sdk-server-common/src/test/java/io/a2a/server/util/async/AsyncUtilsTest.java new file mode 100644 index 000000000..2bbd0a9fc --- /dev/null +++ b/sdk-server-common/src/test/java/io/a2a/server/util/async/AsyncUtilsTest.java @@ -0,0 +1,696 @@ +package io.a2a.server.util.async; + +import static io.a2a.server.util.async.AsyncUtils.consumer; +import static io.a2a.server.util.async.AsyncUtils.convertingProcessor; +import static io.a2a.server.util.async.AsyncUtils.createTubeConfig; +import static io.a2a.server.util.async.AsyncUtils.processor; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Flow; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +import mutiny.zero.ZeroPublisher; +import org.junit.jupiter.api.Test; + +public class AsyncUtilsTest { + + @Test + public void testConsumer() throws Exception { + List toSend = List.of("A", "B", "C"); + Flow.Publisher publisher = ZeroPublisher.fromIterable(toSend); + + List received = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(3); + + AtomicReference error = new AtomicReference<>(); + consumer(createTubeConfig(), + publisher, + s -> { + received.add(s); + latch.countDown(); + return true; + }, + error::set); + + + latch.await(2, TimeUnit.SECONDS); + assertEquals(toSend, received); + assertNull(error.get()); + } + + @Test + public void testCancelConsumer() throws Exception { + List toSend = List.of("A", "B", "C", "D"); + Flow.Publisher publisher = ZeroPublisher.fromIterable(toSend); + + List received = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(3); + + AtomicReference error = new AtomicReference<>(); + consumer(createTubeConfig(), + publisher, + s -> { + latch.countDown(); + if (s.equals("C")) { + return false; + } + received.add(s); + return true; + }, + error::set); + + Thread.sleep(500); + assertEquals(toSend.subList(0, 2), received); + assertNull(error.get()); + } + + @Test + public void testErrorConsumer() throws Exception { + List toSend = List.of("A", "B", "C", "D"); + Flow.Publisher publisher = ZeroPublisher.fromIterable(toSend); + + List received = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(3); + + AtomicReference error = new AtomicReference<>(); + consumer(createTubeConfig(), + publisher, + s -> { + latch.countDown(); + if (s.equals("C")) { + throw new IllegalStateException(); + } + received.add(s); + return true; + }, + error::set); + + Thread.sleep(500); + assertEquals(toSend.subList(0, 2), received); + assertInstanceOf(IllegalStateException.class, error.get()); + } + + @Test + public void testProcessor() throws Exception { + List toSend = List.of("A", "B", "C"); + Flow.Publisher publisher = ZeroPublisher.fromIterable(toSend); + + List received = new ArrayList<>(); + List processed = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(3); + + Flow.Publisher processedPublisher = + processor(createTubeConfig(), publisher, (errorConsumer, s) -> { + processed.add(s); + latch.countDown(); + return true; + }); + + processedPublisher.subscribe(new Flow.Subscriber() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(String item) { + subscription.request(1); + received.add(item); + latch.countDown(); + } + + @Override + public void onError(Throwable throwable) { + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + latch.await(2, TimeUnit.SECONDS); + assertEquals(toSend, received); + assertEquals(toSend, processed); + } + + @Test + public void testErrorProcessor() throws Exception { + List toSend = List.of("A", "B", "C", "D"); + Flow.Publisher publisher = ZeroPublisher.fromIterable(toSend); + + AtomicReference error = new AtomicReference<>(); + List received = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(3); + + + Flow.Publisher processedPublisher = + processor(createTubeConfig(), publisher, (errorConsumer, s) -> { + latch.countDown(); + if (s.equals("C")) { + errorConsumer.accept(new IllegalStateException()); + } + return true; + }); + + processedPublisher.subscribe(new Flow.Subscriber() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(String item) { + subscription.request(1); + received.add(item); + latch.countDown(); + } + + @Override + public void onError(Throwable throwable) { + error.set(throwable); + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + latch.await(2, TimeUnit.SECONDS); + Thread.sleep(500); + assertEquals(toSend.subList(0, 2), received); + assertNotNull(error.get()); + } + + @Test + public void testUncaughtErrorProcessor() throws Exception { + List toSend = List.of("A", "B", "C", "D"); + Flow.Publisher publisher = ZeroPublisher.fromIterable(toSend); + + AtomicReference error = new AtomicReference<>(); + List received = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(3); + + + Flow.Publisher processedPublisher = + processor(createTubeConfig(), publisher, (errorConsumer, s) -> { + latch.countDown(); + if (s.equals("C")) { + throw new IllegalStateException(); + } + return true; + }); + + processedPublisher.subscribe(new Flow.Subscriber() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(String item) { + subscription.request(1); + received.add(item); + latch.countDown(); + } + + @Override + public void onError(Throwable throwable) { + error.set(throwable); + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + latch.await(2, TimeUnit.SECONDS); + Thread.sleep(500); + assertEquals(toSend.subList(0, 2), received); + assertNotNull(error.get()); + } + + @Test + public void testConvertingProcessor() throws Exception { + List toSend = List.of(1, 2, 3); + Flow.Publisher publisher = ZeroPublisher.fromIterable(toSend); + + List received = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(3); + + Flow.Publisher convertingPublisher = + convertingProcessor(publisher, String::valueOf); + + convertingPublisher.subscribe(new Flow.Subscriber() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(String item) { + subscription.request(1); + received.add(item); + latch.countDown(); + } + + @Override + public void onError(Throwable throwable) { + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + latch.await(2, TimeUnit.SECONDS); + assertEquals(toSend.stream().map(String::valueOf).toList(), received); + } + + @Test + public void testChainedConvertingProcessors() throws Exception { + List toSend = List.of(1, 2, 3); + Flow.Publisher publisher = ZeroPublisher.fromIterable(toSend); + + List received = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(3); + + Flow.Publisher convertingPublisher = + convertingProcessor(publisher, String::valueOf); + Flow.Publisher convertingPublisher2 = + convertingProcessor(convertingPublisher, Long::valueOf); + + convertingPublisher2.subscribe(new Flow.Subscriber() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(Long item) { + subscription.request(1); + received.add(item); + latch.countDown(); + } + + @Override + public void onError(Throwable throwable) { + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + latch.await(2, TimeUnit.SECONDS); + assertEquals(toSend.stream().map(Long::valueOf).toList(), received); + } + + @Test + public void testErrorConvertingProcessor() throws Exception { + List toSend = List.of(1, 2, 3, 4); + Flow.Publisher publisher = ZeroPublisher.fromIterable(toSend); + + List received = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(2); + + Flow.Publisher convertingPublisher = + convertingProcessor(publisher, i -> { + if (i == 3) { + throw new IllegalStateException(); + } + return String.valueOf(i); + }); + + convertingPublisher.subscribe(new Flow.Subscriber() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(String item) { + subscription.request(1); + received.add(item); + latch.countDown(); + } + + @Override + public void onError(Throwable throwable) { + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + latch.await(2, TimeUnit.SECONDS); + assertEquals(toSend.stream().map(String::valueOf).toList().subList(0, 2), received); + } + + @Test + public void testConvertingAndProcessingProcessor() throws Exception { + List toSend = List.of(1, 2, 3); + Flow.Publisher publisher = ZeroPublisher.fromIterable(toSend); + + List received = new ArrayList<>(); + List processed = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(3); + + Flow.Publisher processedPublisher = + processor(createTubeConfig(), publisher, (errorConsumer, i) -> { + processed.add(i); + return true; + }); + + Flow.Publisher convertingPublisher = + convertingProcessor(processedPublisher, String::valueOf); + + convertingPublisher.subscribe(new Flow.Subscriber() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(String item) { + subscription.request(1); + received.add(item); + latch.countDown(); + } + + @Override + public void onError(Throwable throwable) { + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + latch.await(2, TimeUnit.SECONDS); + assertEquals(toSend, processed); + assertEquals(toSend.stream().map(String::valueOf).toList(), received); + } + + @Test + public void testCancelProcessor() throws Exception { + List toSend = List.of("A", "B", "C", "D"); + Flow.Publisher publisher = ZeroPublisher.fromIterable(toSend); + + List received = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(3); + + Flow.Publisher processedPublisher = + processor(createTubeConfig(), publisher, (errorConsumer, s) -> { + latch.countDown(); + if (s.equals("C")) { + return false; + } + return true; + }); + + processedPublisher.subscribe(new Flow.Subscriber() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(String item) { + subscription.request(1); + received.add(item); + latch.countDown(); + } + + @Override + public void onError(Throwable throwable) { + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + latch.await(2, TimeUnit.SECONDS); + Thread.sleep(500); + assertEquals(toSend.subList(0, 2), received); + } + + @Test + public void testMutinyZeroErrorPropagationSanityTest() { + Flow.Publisher source = ZeroPublisher.fromItems("a", "b", "c"); + + Flow.Publisher processor = ZeroPublisher.create(createTubeConfig(), tube -> { + source.subscribe(new Flow.Subscriber() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(String item) { + if (item.equals("c")) { + onError(new IllegalStateException()); + } + tube.send(item); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + tube.fail(throwable); + subscription.cancel(); + } + + @Override + public void onComplete() { + tube.complete(); + } + }); + }); + + Flow.Publisher processor2 = ZeroPublisher.create(createTubeConfig(), tube -> { + processor.subscribe(new Flow.Subscriber() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(String item) { + tube.send(item); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + tube.fail(throwable); + subscription.cancel(); + } + + @Override + public void onComplete() { + tube.complete(); + } + }); + }); + + List results = new ArrayList<>(); + + processor2.subscribe(new Flow.Subscriber() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(String item) { + results.add(item); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + results.add(throwable); + subscription.cancel(); + } + + @Override + public void onComplete() { + } + }); + + assertEquals(3, results.size()); + assertEquals("a", results.get(0)); + assertEquals("b", results.get(1)); + assertInstanceOf(IllegalStateException.class, results.get(2)); + } + + @Test + public void testAsyncUtilsErrorPropagation() { + Flow.Publisher source = ZeroPublisher.fromItems("a", "b", "c"); + + Flow.Publisher processor = processor(createTubeConfig(), source, new BiFunction, String, Boolean>() { + @Override + public Boolean apply(Consumer throwableConsumer, String item) { + System.out.println("-> (1) " + item); + if (item.equals("c")) { + throw new IllegalStateException(); + } + return true; + } + }); + + Flow.Publisher processor2 = processor(createTubeConfig(), processor, new BiFunction, String, Boolean>() { + @Override + public Boolean apply(Consumer throwableConsumer, String s) { + return true; + } + }); + + Flow.Publisher> processor3 = convertingProcessor(processor2, List::of); + + List results = new ArrayList<>(); + AtomicReference error = new AtomicReference<>(); + + consumer(createTubeConfig(), + processor3, + results::add, + t -> { + results.add(t); + error.set(t); + }); + + assertEquals(3, results.size()); + assertEquals(List.of("a"), results.get(0)); + assertEquals(List.of("b"), results.get(1)); + assertInstanceOf(IllegalStateException.class, results.get(2)); + assertInstanceOf(IllegalStateException.class, error.get()); + } + + @Test + public void testMutinyZeroEventPropagationSanity() throws Exception { + Flow.Publisher source = ZeroPublisher.fromItems("one", "two", "three"); + + CountDownLatch latch = new CountDownLatch(3); + + final List results = Collections.synchronizedList(new ArrayList<>()); + + source.subscribe(new Flow.Subscriber<>() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(String item) { + results.add(item); + subscription.request(1); + latch.countDown(); + } + + @Override + public void onError(Throwable throwable) { + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + source.subscribe(new Flow.Subscriber<>() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(String item) { + results.add(item); + subscription.request(1); + latch.countDown(); + } + + @Override + public void onError(Throwable throwable) { + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + System.out.println("---hi"); + latch.await(2, TimeUnit.SECONDS); + assertEquals(6, results.size()); + } + +} diff --git a/tck/pom.xml b/tck/pom.xml new file mode 100644 index 000000000..b89071fc1 --- /dev/null +++ b/tck/pom.xml @@ -0,0 +1,58 @@ + + + 4.0.0 + + + io.a2a.sdk + a2a-java-sdk-parent + 0.2.4-SNAPSHOT + + + a2a-tck-server + + Java SDK A2A TCK Server + Server example to use with the A2A TCK + + + + io.a2a.sdk + a2a-java-sdk-server-quarkus + ${project.version} + + + io.quarkus + quarkus-rest-jackson + provided + + + jakarta.enterprise + jakarta.enterprise.cdi-api + provided + + + jakarta.ws.rs + jakarta.ws.rs-api + + + + + + + io.quarkus + quarkus-maven-plugin + true + + + + build + generate-code + generate-code-tests + + + + + + + \ No newline at end of file diff --git a/tck/src/main/java/io/a2a/tck/server/AgentCardProducer.java b/tck/src/main/java/io/a2a/tck/server/AgentCardProducer.java new file mode 100644 index 000000000..7abf29a12 --- /dev/null +++ b/tck/src/main/java/io/a2a/tck/server/AgentCardProducer.java @@ -0,0 +1,43 @@ +package io.a2a.tck.server; + +import java.util.Collections; +import java.util.List; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Produces; + +import io.a2a.server.PublicAgentCard; +import io.a2a.spec.AgentCapabilities; +import io.a2a.spec.AgentCard; +import io.a2a.spec.AgentSkill; + +@ApplicationScoped +public class AgentCardProducer { + + @Produces + @PublicAgentCard + public AgentCard agentCard() { + return new AgentCard.Builder() + .name("Hello World Agent") + .description("Just a hello world agent") + .url("http://localhost:9999") + .version("1.0.0") + .documentationUrl("http://example.com/docs") + .capabilities(new AgentCapabilities.Builder() + .streaming(true) + .pushNotifications(true) + .stateTransitionHistory(true) + .build()) + .defaultInputModes(Collections.singletonList("text")) + .defaultOutputModes(Collections.singletonList("text")) + .skills(Collections.singletonList(new AgentSkill.Builder() + .id("hello_world") + .name("Returns hello world") + .description("just returns hello world") + .tags(Collections.singletonList("hello world")) + .examples(List.of("hi", "hello world")) + .build())) + .build(); + } +} + diff --git a/tck/src/main/java/io/a2a/tck/server/AgentExecutorProducer.java b/tck/src/main/java/io/a2a/tck/server/AgentExecutorProducer.java new file mode 100644 index 000000000..592546591 --- /dev/null +++ b/tck/src/main/java/io/a2a/tck/server/AgentExecutorProducer.java @@ -0,0 +1,91 @@ +package io.a2a.tck.server; + +import jakarta.annotation.PreDestroy; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Produces; + +import io.a2a.server.agentexecution.AgentExecutor; +import io.a2a.server.agentexecution.RequestContext; +import io.a2a.server.events.EventQueue; +import io.a2a.server.tasks.TaskUpdater; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.Task; +import io.a2a.spec.TaskNotCancelableError; +import io.a2a.spec.TaskNotFoundError; +import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; + +@ApplicationScoped +public class AgentExecutorProducer { + + @Produces + public AgentExecutor agentExecutor() { + return new FireAndForgetAgentExecutor(); + } + + private static class FireAndForgetAgentExecutor implements AgentExecutor { + @Override + public void execute(RequestContext context, EventQueue eventQueue) throws JSONRPCError { + Task task = context.getTask(); + + if (context.getMessage().getTaskId() != null && task == null && context.getMessage().getTaskId().startsWith("non-existent")) { + throw new TaskNotFoundError(); + } + + if (task == null) { + task = new Task.Builder() + .id(context.getTaskId()) + .contextId(context.getContextId()) + .status(new TaskStatus(TaskState.SUBMITTED)) + .history(context.getMessage()) + .build(); + eventQueue.enqueueEvent(task); + } + + TaskUpdater updater = new TaskUpdater(context, eventQueue); + + // Immediately set to WORKING state + updater.startWork(); + System.out.println("====> task set to WORKING, starting background execution"); + + // Method returns immediately - task continues in background + System.out.println("====> execute() method returning immediately, task running in background"); + } + + @Override + public void cancel(RequestContext context, EventQueue eventQueue) throws JSONRPCError { + System.out.println("====> task cancel request received"); + Task task = context.getTask(); + + if (task.getStatus().state() == TaskState.CANCELED) { + System.out.println("====> task already canceled"); + throw new TaskNotCancelableError(); + } + + if (task.getStatus().state() == TaskState.COMPLETED) { + System.out.println("====> task already completed"); + throw new TaskNotCancelableError(); + } + + TaskUpdater updater = new TaskUpdater(context, eventQueue); + updater.cancel(); + eventQueue.enqueueEvent(new TaskStatusUpdateEvent.Builder() + .taskId(task.getId()) + .contextId(task.getContextId()) + .status(new TaskStatus(TaskState.CANCELED)) + .isFinal(true) + .build()); + + System.out.println("====> task canceled"); + } + + /** + * Cleanup method for proper resource management + */ + @PreDestroy + public void cleanup() { + System.out.println("====> shutting down task executor"); + } + } +} \ No newline at end of file diff --git a/tck/src/main/resources/application.properties b/tck/src/main/resources/application.properties new file mode 100644 index 000000000..a2452b339 --- /dev/null +++ b/tck/src/main/resources/application.properties @@ -0,0 +1 @@ +%dev.quarkus.http.port=9999 \ No newline at end of file diff --git a/tests/server-common/pom.xml b/tests/server-common/pom.xml new file mode 100644 index 000000000..866b379c0 --- /dev/null +++ b/tests/server-common/pom.xml @@ -0,0 +1,75 @@ + + + 4.0.0 + + + io.a2a.sdk + a2a-java-sdk-parent + 0.2.4-SNAPSHOT + ../../pom.xml + + a2a-java-sdk-tests-server-common + + jar + + Java A2A SDK Server Tests Common + Java SDK for the Agent2Agent Protocol (A2A) - SDK - Server Tests Common + + + + ${project.groupId} + a2a-java-sdk-core + ${project.version} + + + ${project.groupId} + a2a-java-sdk-server-common + ${project.version} + + + jakarta.ws.rs + jakarta.ws.rs-api + test + + + org.junit.jupiter + junit-jupiter-api + test + + + io.rest-assured + rest-assured + test + + + io.quarkus + quarkus-arc + test + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + true + + + + org.apache.maven.plugins + maven-jar-plugin + + + + test-jar + + + + + + + \ No newline at end of file diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java new file mode 100644 index 000000000..ac7bf5c5d --- /dev/null +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java @@ -0,0 +1,878 @@ +package io.a2a.server.apps.common; + +import static io.restassured.RestAssured.given; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.wildfly.common.Assert.assertNotNull; +import static org.wildfly.common.Assert.assertTrue; + +import java.io.EOFException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; + +import jakarta.ws.rs.core.MediaType; + +import com.fasterxml.jackson.core.JsonProcessingException; +import io.a2a.server.events.InMemoryQueueManager; +import io.a2a.server.tasks.TaskStore; +import io.a2a.spec.AgentCard; +import io.a2a.spec.Artifact; +import io.a2a.spec.CancelTaskRequest; +import io.a2a.spec.CancelTaskResponse; +import io.a2a.spec.Event; +import io.a2a.spec.GetTaskPushNotificationConfigRequest; +import io.a2a.spec.GetTaskPushNotificationConfigResponse; +import io.a2a.spec.GetTaskRequest; +import io.a2a.spec.GetTaskResponse; +import io.a2a.spec.InvalidParamsError; +import io.a2a.spec.InvalidRequestError; +import io.a2a.spec.JSONParseError; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.JSONRPCErrorResponse; +import io.a2a.spec.Message; +import io.a2a.spec.MessageSendParams; +import io.a2a.spec.MethodNotFoundError; +import io.a2a.spec.Part; +import io.a2a.spec.PushNotificationConfig; +import io.a2a.spec.SendMessageRequest; +import io.a2a.spec.SendMessageResponse; +import io.a2a.spec.SendStreamingMessageRequest; +import io.a2a.spec.SendStreamingMessageResponse; +import io.a2a.spec.SetTaskPushNotificationConfigRequest; +import io.a2a.spec.SetTaskPushNotificationConfigResponse; +import io.a2a.spec.StreamingJSONRPCRequest; +import io.a2a.spec.Task; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskIdParams; +import io.a2a.spec.TaskNotFoundError; +import io.a2a.spec.TaskPushNotificationConfig; +import io.a2a.spec.TaskQueryParams; +import io.a2a.spec.TaskResubscriptionRequest; +import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; +import io.a2a.spec.TextPart; +import io.a2a.spec.UnsupportedOperationError; +import io.a2a.util.Utils; +import io.restassured.RestAssured; +import io.restassured.specification.RequestSpecification; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public abstract class AbstractA2AServerTest { + + private static final Task MINIMAL_TASK = new Task.Builder() + .id("task-123") + .contextId("session-xyz") + .status(new TaskStatus(TaskState.SUBMITTED)) + .build(); + + private static final Task CANCEL_TASK = new Task.Builder() + .id("cancel-task-123") + .contextId("session-xyz") + .status(new TaskStatus(TaskState.SUBMITTED)) + .build(); + + private static final Task CANCEL_TASK_NOT_SUPPORTED = new Task.Builder() + .id("cancel-task-not-supported-123") + .contextId("session-xyz") + .status(new TaskStatus(TaskState.SUBMITTED)) + .build(); + + private static final Task SEND_MESSAGE_NOT_SUPPORTED = new Task.Builder() + .id("task-not-supported-123") + .contextId("session-xyz") + .status(new TaskStatus(TaskState.SUBMITTED)) + .build(); + + private static final Message MESSAGE = new Message.Builder() + .messageId("111") + .role(Message.Role.AGENT) + .parts(new TextPart("test message")) + .build(); + + @Test + public void testGetTaskSuccess() { + testGetTask(); + } + + private void testGetTask() { + testGetTask(null); + } + + private void testGetTask(String mediaType) { + getTaskStore().save(MINIMAL_TASK); + try { + GetTaskRequest request = new GetTaskRequest("1", new TaskQueryParams(MINIMAL_TASK.getId())); + RequestSpecification requestSpecification = RestAssured.given() + .contentType(MediaType.APPLICATION_JSON) + .body(request); + if (mediaType != null) { + requestSpecification = requestSpecification.accept(mediaType); + } + GetTaskResponse response = requestSpecification + .when() + .post("/") + .then() + .statusCode(200) + .extract() + .as(GetTaskResponse.class); + assertEquals("1", response.getId()); + assertEquals("task-123", response.getResult().getId()); + assertEquals("session-xyz", response.getResult().getContextId()); + assertEquals(TaskState.SUBMITTED, response.getResult().getStatus().state()); + assertNull(response.getError()); + } catch (Exception e) { + } finally { + getTaskStore().delete(MINIMAL_TASK.getId()); + } + } + + @Test + public void testGetTaskNotFound() { + assertTrue(getTaskStore().get("non-existent-task") == null); + GetTaskRequest request = new GetTaskRequest("1", new TaskQueryParams("non-existent-task")); + GetTaskResponse response = given() + .contentType(MediaType.APPLICATION_JSON) + .body(request) + .when() + .post("/") + .then() + .statusCode(200) + .extract() + .as(GetTaskResponse.class); + assertEquals("1", response.getId()); + // this should be an instance of TaskNotFoundError, see https://github.com/a2aproject/a2a-java/issues/23 + assertInstanceOf(JSONRPCError.class, response.getError()); + assertEquals(new TaskNotFoundError().getCode(), response.getError().getCode()); + assertNull(response.getResult()); + } + + @Test + public void testCancelTaskSuccess() { + getTaskStore().save(CANCEL_TASK); + try { + CancelTaskRequest request = new CancelTaskRequest("1", new TaskIdParams(CANCEL_TASK.getId())); + CancelTaskResponse response = given() + .contentType(MediaType.APPLICATION_JSON) + .body(request) + .when() + .post("/") + .then() + .statusCode(200) + .extract() + .as(CancelTaskResponse.class); + assertNull(response.getError()); + assertEquals(request.getId(), response.getId()); + Task task = response.getResult(); + assertEquals(CANCEL_TASK.getId(), task.getId()); + assertEquals(CANCEL_TASK.getContextId(), task.getContextId()); + assertEquals(TaskState.CANCELED, task.getStatus().state()); + } catch (Exception e) { + } finally { + getTaskStore().delete(CANCEL_TASK.getId()); + } + } + + @Test + public void testCancelTaskNotSupported() { + getTaskStore().save(CANCEL_TASK_NOT_SUPPORTED); + try { + CancelTaskRequest request = new CancelTaskRequest("1", new TaskIdParams(CANCEL_TASK_NOT_SUPPORTED.getId())); + CancelTaskResponse response = given() + .contentType(MediaType.APPLICATION_JSON) + .body(request) + .when() + .post("/") + .then() + .statusCode(200) + .extract() + .as(CancelTaskResponse.class); + assertEquals(request.getId(), response.getId()); + assertNull(response.getResult()); + // this should be an instance of UnsupportedOperationError, see https://github.com/a2aproject/a2a-java/issues/23 + assertInstanceOf(JSONRPCError.class, response.getError()); + assertEquals(new UnsupportedOperationError().getCode(), response.getError().getCode()); + } catch (Exception e) { + } finally { + getTaskStore().delete(CANCEL_TASK_NOT_SUPPORTED.getId()); + } + } + + @Test + public void testCancelTaskNotFound() { + CancelTaskRequest request = new CancelTaskRequest("1", new TaskIdParams("non-existent-task")); + CancelTaskResponse response = given() + .contentType(MediaType.APPLICATION_JSON) + .body(request) + .when() + .post("/") + .then() + .statusCode(200) + .extract() + .as(CancelTaskResponse.class); + assertEquals(request.getId(), response.getId()); + assertNull(response.getResult()); + // this should be an instance of UnsupportedOperationError, see https://github.com/a2aproject/a2a-java/issues/23 + assertInstanceOf(JSONRPCError.class, response.getError()); + assertEquals(new TaskNotFoundError().getCode(), response.getError().getCode()); + } + + @Test + public void testSendMessageNewMessageSuccess() { + assertTrue(getTaskStore().get(MINIMAL_TASK.getId()) == null); + Message message = new Message.Builder(MESSAGE) + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .build(); + SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); + SendMessageResponse response = given() + .contentType(MediaType.APPLICATION_JSON) + .body(request) + .when() + .post("/") + .then() + .statusCode(200) + .extract() + .as(SendMessageResponse.class); + assertNull(response.getError()); + Message messageResponse = (Message) response.getResult(); + assertEquals(MESSAGE.getMessageId(), messageResponse.getMessageId()); + assertEquals(MESSAGE.getRole(), messageResponse.getRole()); + Part part = messageResponse.getParts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("test message", ((TextPart) part).getText()); + } + + @Test + public void testSendMessageExistingTaskSuccess() { + getTaskStore().save(MINIMAL_TASK); + try { + Message message = new Message.Builder(MESSAGE) + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .build(); + SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); + SendMessageResponse response = given() + .contentType(MediaType.APPLICATION_JSON) + .body(request) + .when() + .post("/") + .then() + .statusCode(200) + .extract() + .as(SendMessageResponse.class); + assertNull(response.getError()); + Message messageResponse = (Message) response.getResult(); + assertEquals(MESSAGE.getMessageId(), messageResponse.getMessageId()); + assertEquals(MESSAGE.getRole(), messageResponse.getRole()); + Part part = messageResponse.getParts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("test message", ((TextPart) part).getText()); + } catch (Exception e) { + } finally { + getTaskStore().delete(MINIMAL_TASK.getId()); + } + } + + @Test + public void testSetPushNotificationSuccess() { + getTaskStore().save(MINIMAL_TASK); + try { + TaskPushNotificationConfig taskPushConfig = + new TaskPushNotificationConfig( + MINIMAL_TASK.getId(), new PushNotificationConfig.Builder().url("http://example.com").build()); + SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); + SetTaskPushNotificationConfigResponse response = given() + .contentType(MediaType.APPLICATION_JSON) + .body(request) + .when() + .post("/") + .then() + .statusCode(200) + .extract() + .as(SetTaskPushNotificationConfigResponse.class); + assertNull(response.getError()); + assertEquals(request.getId(), response.getId()); + TaskPushNotificationConfig config = response.getResult(); + assertEquals(MINIMAL_TASK.getId(), config.taskId()); + assertEquals("http://example.com", config.pushNotificationConfig().url()); + } catch (Exception e) { + } finally { + getTaskStore().delete(MINIMAL_TASK.getId()); + } + } + + @Test + public void testGetPushNotificationSuccess() { + getTaskStore().save(MINIMAL_TASK); + try { + TaskPushNotificationConfig taskPushConfig = + new TaskPushNotificationConfig( + MINIMAL_TASK.getId(), new PushNotificationConfig.Builder().url("http://example.com").build()); + + SetTaskPushNotificationConfigRequest setTaskPushNotificationRequest = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); + SetTaskPushNotificationConfigResponse setTaskPushNotificationResponse = given() + .contentType(MediaType.APPLICATION_JSON) + .body(setTaskPushNotificationRequest) + .when() + .post("/") + .then() + .statusCode(200) + .extract() + .as(SetTaskPushNotificationConfigResponse.class); + assertNotNull(setTaskPushNotificationResponse); + + GetTaskPushNotificationConfigRequest request = + new GetTaskPushNotificationConfigRequest("111", new TaskIdParams(MINIMAL_TASK.getId())); + GetTaskPushNotificationConfigResponse response = given() + .contentType(MediaType.APPLICATION_JSON) + .body(request) + .when() + .post("/") + .then() + .statusCode(200) + .extract() + .as(GetTaskPushNotificationConfigResponse.class); + assertNull(response.getError()); + assertEquals(request.getId(), response.getId()); + TaskPushNotificationConfig config = response.getResult(); + assertEquals(MINIMAL_TASK.getId(), config.taskId()); + assertEquals("http://example.com", config.pushNotificationConfig().url()); + } catch (Exception e) { + } finally { + getTaskStore().delete(MINIMAL_TASK.getId()); + } + } + + @Test + public void testError() { + Message message = new Message.Builder(MESSAGE) + .taskId(SEND_MESSAGE_NOT_SUPPORTED.getId()) + .contextId(SEND_MESSAGE_NOT_SUPPORTED.getContextId()) + .build(); + SendMessageRequest request = new SendMessageRequest( + "1", new MessageSendParams(message, null, null)); + SendMessageResponse response = given() + .contentType(MediaType.APPLICATION_JSON) + .body(request) + .when() + .post("/") + .then() + .statusCode(200) + .extract() + .as(SendMessageResponse.class); + assertEquals(request.getId(), response.getId()); + assertNull(response.getResult()); + // this should be an instance of UnsupportedOperationError, see https://github.com/a2aproject/a2a-java/issues/23 + assertInstanceOf(JSONRPCError.class, response.getError()); + assertEquals(new UnsupportedOperationError().getCode(), response.getError().getCode()); + } + + @Test + public void testGetAgentCard() { + AgentCard agentCard = given() + .contentType(MediaType.APPLICATION_JSON) + .when() + .get("/.well-known/agent.json") + .then() + .statusCode(200) + .extract() + .as(AgentCard.class); + assertNotNull(agentCard); + assertEquals("test-card", agentCard.name()); + assertEquals("A test agent card", agentCard.description()); + assertEquals("http://localhost:8081", agentCard.url()); + assertEquals("1.0", agentCard.version()); + assertEquals("http://example.com/docs", agentCard.documentationUrl()); + assertTrue(agentCard.capabilities().pushNotifications()); + assertTrue(agentCard.capabilities().streaming()); + assertTrue(agentCard.capabilities().stateTransitionHistory()); + assertTrue(agentCard.skills().isEmpty()); + } + + @Test + public void testGetExtendAgentCardNotSupported() { + given() + .contentType(MediaType.APPLICATION_JSON) + .when() + .get("/agent/authenticatedExtendedCard") + .then() + .statusCode(404) + .body("error", equalTo("Extended agent card not supported or not enabled.")); + } + + @Test + public void testMalformedJSONRPCRequest() { + // missing closing bracket + String malformedRequest = "{\"jsonrpc\": \"2.0\", \"method\": \"message/send\", \"params\": {\"foo\": \"bar\"}"; + JSONRPCErrorResponse response = given() + .contentType(MediaType.APPLICATION_JSON) + .body(malformedRequest) + .when() + .post("/") + .then() + .statusCode(200) + .extract() + .as(JSONRPCErrorResponse.class); + assertNotNull(response.getError()); + assertEquals(new JSONParseError().getCode(), response.getError().getCode()); + } + + @Test + public void testInvalidParamsJSONRPCRequest() { + String invalidParamsRequest = """ + {"jsonrpc": "2.0", "method": "message/send", "params": "not_a_dict", "id": "1"} + """; + testInvalidParams(invalidParamsRequest); + + invalidParamsRequest = """ + {"jsonrpc": "2.0", "method": "message/send", "params": {"message": {"parts": "invalid"}}, "id": "1"} + """; + testInvalidParams(invalidParamsRequest); + } + + private void testInvalidParams(String invalidParamsRequest) { + JSONRPCErrorResponse response = given() + .contentType(MediaType.APPLICATION_JSON) + .body(invalidParamsRequest) + .when() + .post("/") + .then() + .statusCode(200) + .extract() + .as(JSONRPCErrorResponse.class); + assertNotNull(response.getError()); + assertEquals(new InvalidParamsError().getCode(), response.getError().getCode()); + assertEquals("1", response.getId()); + } + + @Test + public void testInvalidJSONRPCRequestMissingJsonrpc() { + String invalidRequest = """ + { + "method": "message/send", + "params": {} + } + """; + JSONRPCErrorResponse response = given() + .contentType(MediaType.APPLICATION_JSON) + .body(invalidRequest) + .when() + .post("/") + .then() + .statusCode(200) + .extract() + .as(JSONRPCErrorResponse.class); + assertNotNull(response.getError()); + assertEquals(new InvalidRequestError().getCode(), response.getError().getCode()); + } + + @Test + public void testInvalidJSONRPCRequestMissingMethod() { + String invalidRequest = """ + {"jsonrpc": "2.0", "params": {}} + """; + JSONRPCErrorResponse response = given() + .contentType(MediaType.APPLICATION_JSON) + .body(invalidRequest) + .when() + .post("/") + .then() + .statusCode(200) + .extract() + .as(JSONRPCErrorResponse.class); + assertNotNull(response.getError()); + assertEquals(new InvalidRequestError().getCode(), response.getError().getCode()); + } + + @Test + public void testInvalidJSONRPCRequestInvalidId() { + String invalidRequest = """ + {"jsonrpc": "2.0", "method": "message/send", "params": {}, "id": {"bad": "type"}} + """; + JSONRPCErrorResponse response = given() + .contentType(MediaType.APPLICATION_JSON) + .body(invalidRequest) + .when() + .post("/") + .then() + .statusCode(200) + .extract() + .as(JSONRPCErrorResponse.class); + assertNotNull(response.getError()); + assertEquals(new InvalidRequestError().getCode(), response.getError().getCode()); + } + + @Test + public void testInvalidJSONRPCRequestNonExistentMethod() { + String invalidRequest = """ + {"jsonrpc": "2.0", "method" : "nonexistent/method", "params": {}} + """; + JSONRPCErrorResponse response = given() + .contentType(MediaType.APPLICATION_JSON) + .body(invalidRequest) + .when() + .post("/") + .then() + .statusCode(200) + .extract() + .as(JSONRPCErrorResponse.class); + assertNotNull(response.getError()); + assertEquals(new MethodNotFoundError().getCode(), response.getError().getCode()); + } + + @Test + public void testNonStreamingMethodWithAcceptHeader() { + testGetTask(MediaType.APPLICATION_JSON); + } + + + @Test + public void testSendMessageStreamExistingTaskSuccess() { + getTaskStore().save(MINIMAL_TASK); + try { + Message message = new Message.Builder(MESSAGE) + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .build(); + SendStreamingMessageRequest request = new SendStreamingMessageRequest( + "1", new MessageSendParams(message, null, null)); + + CompletableFuture>> responseFuture = initialiseStreamingRequest(request, null); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference errorRef = new AtomicReference<>(); + + responseFuture.thenAccept(response -> { + if (response.statusCode() != 200) { + //errorRef.set(new IllegalStateException("Status code was " + response.statusCode())); + throw new IllegalStateException("Status code was " + response.statusCode()); + } + response.body().forEach(line -> { + try { + SendStreamingMessageResponse jsonResponse = extractJsonResponseFromSseLine(line); + if (jsonResponse != null) { + assertNull(jsonResponse.getError()); + Message messageResponse = (Message) jsonResponse.getResult(); + assertEquals(MESSAGE.getMessageId(), messageResponse.getMessageId()); + assertEquals(MESSAGE.getRole(), messageResponse.getRole()); + Part part = messageResponse.getParts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("test message", ((TextPart) part).getText()); + latch.countDown(); + } + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + }); + }).exceptionally(t -> { + if (!isStreamClosedError(t)) { + errorRef.set(t); + } + latch.countDown(); + return null; + }); + + boolean dataRead = latch.await(20, TimeUnit.SECONDS); + Assertions.assertTrue(dataRead); + Assertions.assertNull(errorRef.get()); + } catch (Exception e) { + } finally { + getTaskStore().delete(MINIMAL_TASK.getId()); + } + } + + @Test + public void testResubscribeExistingTaskSuccess() throws Exception { + ExecutorService executorService = Executors.newSingleThreadExecutor(); + getTaskStore().save(MINIMAL_TASK); + + try { + // attempting to send a streaming message instead of explicitly calling queueManager#createOrTap + // does not work because after the message is sent, the queue becomes null but task resubscription + // requires the queue to still be active + getQueueManager().createOrTap(MINIMAL_TASK.getId()); + + CountDownLatch taskResubscriptionRequestSent = new CountDownLatch(1); + CountDownLatch taskResubscriptionResponseReceived = new CountDownLatch(2); + AtomicReference firstResponse = new AtomicReference<>(); + AtomicReference secondResponse = new AtomicReference<>(); + + // resubscribe to the task, requires the task and its queue to still be active + TaskResubscriptionRequest taskResubscriptionRequest = new TaskResubscriptionRequest("1", new TaskIdParams(MINIMAL_TASK.getId())); + + // Count down the latch when the MultiSseSupport on the server has started subscribing + setStreamingSubscribedRunnable(taskResubscriptionRequestSent::countDown); + + CompletableFuture>> responseFuture = initialiseStreamingRequest(taskResubscriptionRequest, null); + + AtomicReference errorRef = new AtomicReference<>(); + + responseFuture.thenAccept(response -> { + + if (response.statusCode() != 200) { + //errorRef.set(new IllegalStateException("Status code was " + response.statusCode())); + throw new IllegalStateException("Status code was " + response.statusCode()); + } + try { + response.body().forEach(line -> { + try { + SendStreamingMessageResponse jsonResponse = extractJsonResponseFromSseLine(line); + if (jsonResponse != null) { + SendStreamingMessageResponse sendStreamingMessageResponse = Utils.OBJECT_MAPPER.readValue(line.substring("data: ".length()).trim(), SendStreamingMessageResponse.class); + if (taskResubscriptionResponseReceived.getCount() == 2) { + firstResponse.set(sendStreamingMessageResponse); + } else { + secondResponse.set(sendStreamingMessageResponse); + } + taskResubscriptionResponseReceived.countDown(); + if (taskResubscriptionResponseReceived.getCount() == 0) { + throw new BreakException(); + } + } + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + }); + } catch (BreakException e) { + } + }).exceptionally(t -> { + if (!isStreamClosedError(t)) { + errorRef.set(t); + } + return null; + }); + + try { + taskResubscriptionRequestSent.await(); + List events = List.of( + new TaskArtifactUpdateEvent.Builder() + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .artifact(new Artifact.Builder() + .artifactId("11") + .parts(new TextPart("text")) + .build()) + .build(), + new TaskStatusUpdateEvent.Builder() + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .status(new TaskStatus(TaskState.COMPLETED)) + .isFinal(true) + .build()); + + for (Event event : events) { + getQueueManager().get(MINIMAL_TASK.getId()).enqueueEvent(event); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + // wait for the client to receive the responses + taskResubscriptionResponseReceived.await(); + + assertNotNull(firstResponse.get()); + SendStreamingMessageResponse sendStreamingMessageResponse = firstResponse.get(); + assertNull(sendStreamingMessageResponse.getError()); + TaskArtifactUpdateEvent taskArtifactUpdateEvent = (TaskArtifactUpdateEvent) sendStreamingMessageResponse.getResult(); + assertEquals(MINIMAL_TASK.getId(), taskArtifactUpdateEvent.getTaskId()); + assertEquals(MINIMAL_TASK.getContextId(), taskArtifactUpdateEvent.getContextId()); + Part part = taskArtifactUpdateEvent.getArtifact().parts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("text", ((TextPart) part).getText()); + + assertNotNull(secondResponse.get()); + sendStreamingMessageResponse = secondResponse.get(); + assertNull(sendStreamingMessageResponse.getError()); + TaskStatusUpdateEvent taskStatusUpdateEvent = (TaskStatusUpdateEvent) sendStreamingMessageResponse.getResult(); + assertEquals(MINIMAL_TASK.getId(), taskStatusUpdateEvent.getTaskId()); + assertEquals(MINIMAL_TASK.getContextId(), taskStatusUpdateEvent.getContextId()); + assertEquals(TaskState.COMPLETED, taskStatusUpdateEvent.getStatus().state()); + assertNotNull(taskStatusUpdateEvent.getStatus().timestamp()); + } finally { + setStreamingSubscribedRunnable(null); + getTaskStore().delete(MINIMAL_TASK.getId()); + executorService.shutdown(); + if (!executorService.awaitTermination(10, TimeUnit.SECONDS)) { + executorService.shutdownNow(); + } + } + } + + @Test + public void testResubscribeNoExistingTaskError() throws Exception { + TaskResubscriptionRequest request = new TaskResubscriptionRequest("1", new TaskIdParams("non-existent-task")); + + CompletableFuture>> responseFuture = initialiseStreamingRequest(request, null); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference errorRef = new AtomicReference<>(); + + responseFuture.thenAccept(response -> { + if (response.statusCode() != 200) { + //errorRef.set(new IllegalStateException("Status code was " + response.statusCode())); + throw new IllegalStateException("Status code was " + response.statusCode()); + } + response.body().forEach(line -> { + try { + SendStreamingMessageResponse jsonResponse = extractJsonResponseFromSseLine(line); + if (jsonResponse != null) { + assertEquals(request.getId(), jsonResponse.getId()); + assertNull(jsonResponse.getResult()); + // this should be an instance of TaskNotFoundError, see https://github.com/a2aproject/a2a-java/issues/23 + assertInstanceOf(JSONRPCError.class, jsonResponse.getError()); + assertEquals(new TaskNotFoundError().getCode(), jsonResponse.getError().getCode()); + latch.countDown(); + } + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + }); + }).exceptionally(t -> { + if (!isStreamClosedError(t)) { + errorRef.set(t); + } + latch.countDown(); + return null; + }); + + boolean dataRead = latch.await(20, TimeUnit.SECONDS); + Assertions.assertTrue(dataRead); + Assertions.assertNull(errorRef.get()); + } + + @Test + public void testStreamingMethodWithAcceptHeader() throws Exception { + testSendStreamingMessage(MediaType.SERVER_SENT_EVENTS); + } + + @Test + public void testSendMessageStreamNewMessageSuccess() throws Exception { + testSendStreamingMessage(null); + } + + private void testSendStreamingMessage(String mediaType) throws Exception { + Message message = new Message.Builder(MESSAGE) + .taskId(MINIMAL_TASK.getId()) + .contextId(MINIMAL_TASK.getContextId()) + .build(); + SendStreamingMessageRequest request = new SendStreamingMessageRequest( + "1", new MessageSendParams(message, null, null)); + + CompletableFuture>> responseFuture = initialiseStreamingRequest(request, mediaType); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference errorRef = new AtomicReference<>(); + + responseFuture.thenAccept(response -> { + if (response.statusCode() != 200) { + //errorRef.set(new IllegalStateException("Status code was " + response.statusCode())); + throw new IllegalStateException("Status code was " + response.statusCode()); + } + response.body().forEach(line -> { + try { + SendStreamingMessageResponse jsonResponse = extractJsonResponseFromSseLine(line); + if (jsonResponse != null) { + assertNull(jsonResponse.getError()); + Message messageResponse = (Message) jsonResponse.getResult(); + assertEquals(MESSAGE.getMessageId(), messageResponse.getMessageId()); + assertEquals(MESSAGE.getRole(), messageResponse.getRole()); + Part part = messageResponse.getParts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("test message", ((TextPart) part).getText()); + latch.countDown(); + } + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + }); + }).exceptionally(t -> { + if (!isStreamClosedError(t)) { + errorRef.set(t); + } + latch.countDown(); + return null; + }); + + + boolean dataRead = latch.await(20, TimeUnit.SECONDS); + Assertions.assertTrue(dataRead); + Assertions.assertNull(errorRef.get()); + + } + + private SendStreamingMessageResponse extractJsonResponseFromSseLine(String line) throws JsonProcessingException { + line = extractSseData(line); + if (line != null) { + return Utils.OBJECT_MAPPER.readValue(line, SendStreamingMessageResponse.class); + } + return null; + } + + private static String extractSseData(String line) { + if (line.startsWith("data:")) { + line = line.substring(5).trim(); + return line; + } + return null; + } + + private boolean isStreamClosedError(Throwable throwable) { + // Unwrap the CompletionException + Throwable cause = throwable; + + while (cause != null) { + if (cause instanceof EOFException) { + return true; + } + cause = cause.getCause(); + } + return false; + } + + private CompletableFuture>> initialiseStreamingRequest( + StreamingJSONRPCRequest request, String mediaType) throws Exception { + + // Create the client + HttpClient client = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .build(); + + // Create the request + HttpRequest.Builder builder = HttpRequest.newBuilder() + .uri(URI.create("http://localhost:8081/")) + .POST(HttpRequest.BodyPublishers.ofString(Utils.OBJECT_MAPPER.writeValueAsString(request))) + .header("Content-Type", "application/json"); + if (mediaType != null) { + builder.header("Accept", mediaType); + } + HttpRequest httpRequest = builder.build(); + + + // Send request async and return the CompletableFuture + return client.sendAsync(httpRequest, HttpResponse.BodyHandlers.ofLines()); + } + + protected abstract TaskStore getTaskStore(); + + protected abstract InMemoryQueueManager getQueueManager(); + + protected abstract void setStreamingSubscribedRunnable(Runnable runnable); + + private static class BreakException extends RuntimeException { + + } +} diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentCardProducer.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentCardProducer.java new file mode 100644 index 000000000..f68f44967 --- /dev/null +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentCardProducer.java @@ -0,0 +1,38 @@ +package io.a2a.server.apps.common; + +import java.util.ArrayList; +import java.util.Collections; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Produces; + +import io.a2a.server.PublicAgentCard; +import io.a2a.spec.AgentCapabilities; +import io.a2a.spec.AgentCard; +import io.quarkus.arc.profile.IfBuildProfile; + +@ApplicationScoped +@IfBuildProfile("test") +public class AgentCardProducer { + + @Produces + @PublicAgentCard + public AgentCard agentCard() { + return new AgentCard.Builder() + .name("test-card") + .description("A test agent card") + .url("http://localhost:8081") + .version("1.0") + .documentationUrl("http://example.com/docs") + .capabilities(new AgentCapabilities.Builder() + .streaming(true) + .pushNotifications(true) + .stateTransitionHistory(true) + .build()) + .defaultInputModes(Collections.singletonList("text")) + .defaultOutputModes(Collections.singletonList("text")) + .skills(new ArrayList<>()) + .build(); + } +} + diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentExecutorProducer.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentExecutorProducer.java new file mode 100644 index 000000000..9576da4a2 --- /dev/null +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentExecutorProducer.java @@ -0,0 +1,40 @@ +package io.a2a.server.apps.common; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Produces; + +import io.a2a.server.agentexecution.AgentExecutor; +import io.a2a.server.agentexecution.RequestContext; +import io.a2a.server.events.EventQueue; +import io.a2a.server.tasks.TaskUpdater; +import io.a2a.spec.JSONRPCError; +import io.a2a.spec.UnsupportedOperationError; +import io.quarkus.arc.profile.IfBuildProfile; + +@ApplicationScoped +@IfBuildProfile("test") +public class AgentExecutorProducer { + + @Produces + public AgentExecutor agentExecutor() { + return new AgentExecutor() { + @Override + public void execute(RequestContext context, EventQueue eventQueue) throws JSONRPCError { + if (context.getTaskId().equals("task-not-supported-123")) { + eventQueue.enqueueEvent(new UnsupportedOperationError()); + } + eventQueue.enqueueEvent(context.getMessage() != null ? context.getMessage() : context.getTask()); + } + + @Override + public void cancel(RequestContext context, EventQueue eventQueue) throws JSONRPCError { + if (context.getTask().getId().equals("cancel-task-123")) { + TaskUpdater taskUpdater = new TaskUpdater(context, eventQueue); + taskUpdater.cancel(); + } else if (context.getTask().getId().equals("cancel-task-not-supported-123")) { + throw new UnsupportedOperationError(); + } + } + }; + } +} diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpClient.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpClient.java new file mode 100644 index 000000000..c5deef68f --- /dev/null +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpClient.java @@ -0,0 +1,83 @@ +package io.a2a.server.apps.common; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.function.Consumer; + +import jakarta.enterprise.context.Dependent; +import jakarta.enterprise.inject.Alternative; + +import io.a2a.http.A2AHttpClient; +import io.a2a.http.A2AHttpResponse; +import io.a2a.spec.Task; +import io.a2a.util.Utils; + +@Dependent +@Alternative +public class TestHttpClient implements A2AHttpClient { + final List tasks = Collections.synchronizedList(new ArrayList<>()); + volatile CountDownLatch latch; + + @Override + public GetBuilder createGet() { + return null; + } + + @Override + public PostBuilder createPost() { + return new TestPostBuilder(); + } + + class TestPostBuilder implements A2AHttpClient.PostBuilder { + private volatile String body; + @Override + public PostBuilder body(String body) { + this.body = body; + return this; + } + + @Override + public A2AHttpResponse post() throws IOException, InterruptedException { + tasks.add(Utils.OBJECT_MAPPER.readValue(body, Task.TYPE_REFERENCE)); + try { + return new A2AHttpResponse() { + @Override + public int status() { + return 200; + } + + @Override + public boolean success() { + return true; + } + + @Override + public String body() { + return ""; + } + }; + } finally { + latch.countDown(); + } + } + + @Override + public CompletableFuture postAsyncSSE(Consumer messageConsumer, Consumer errorConsumer, Runnable completeRunnable) throws IOException, InterruptedException { + return null; + } + + @Override + public PostBuilder url(String s) { + return this; + } + + @Override + public PostBuilder addHeader(String name, String value) { + return this; + } + } +} \ No newline at end of file diff --git a/tests/server-common/src/test/resources/META-INF/beans.xml b/tests/server-common/src/test/resources/META-INF/beans.xml new file mode 100644 index 000000000..e69de29bb