Skip to content

Commit 82a6b7c

Browse files
holtskinnerpstephengoogleswapydapygemini-code-assist[bot]kthota-g
authored
feat: Add RESTful API Serving (#348)
refactor: Update client code to support multi-transport refactor: Refactor client into BaseClient + ClientTransport #363 --------- Co-authored-by: pstephengoogle <[email protected]> Co-authored-by: swapydapy <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: kthota-g <[email protected]> Co-authored-by: Aneesh Garg <[email protected]> Co-authored-by: Mike Smith <[email protected]>
1 parent 9ad8b96 commit 82a6b7c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+5459
-1904
lines changed

.github/actions/spelling/allow.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ lifecycles
5353
linting
5454
Llm
5555
lstrips
56+
mikeas
5657
mockurl
5758
notif
5859
oauthoidc
@@ -67,6 +68,7 @@ pyi
6768
pypistats
6869
pyupgrade
6970
pyversions
71+
redef
7072
respx
7173
resub
7274
RUF
@@ -76,5 +78,6 @@ sse
7678
tagwords
7779
taskupdate
7880
testuuid
81+
Tful
7982
typeerror
8083
vulnz

.github/workflows/unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,6 @@ jobs:
5757
- name: Install dependencies
5858
run: uv sync --dev --extra sql --extra encryption --extra grpc --extra telemetry
5959
- name: Run tests and check coverage
60-
run: uv run pytest --cov=a2a --cov-report=xml --cov-fail-under=89
60+
run: uv run pytest --cov=a2a --cov-report term --cov-fail-under=89
6161
- name: Show coverage summary in log
6262
run: uv run coverage report

.vscode/launch.json

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@
4747
"-s"
4848
],
4949
"console": "integratedTerminal",
50-
"justMyCode": true
50+
"justMyCode": true,
51+
"python": "${workspaceFolder}/.venv/bin/python",
5152
}
5253
]
53-
}
54+
}

Gemini.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
**A2A specification:** https://a2a-protocol.org/latest/specification/
2+
3+
## Project frameworks
4+
- uv as package manager
5+
6+
## How to run all tests
7+
1. If dependencies are not installed install them using following command
8+
```
9+
uv sync --all-extras
10+
```
11+
12+
2. Run tests
13+
```
14+
uv run pytest
15+
```
16+
17+
## Other instructions
18+
1. Whenever writing python code, write types as well.
19+
2. After making the changes run ruff to check and fix the formatting issues
20+
```
21+
uv run ruff check --fix
22+
```
23+
3. Run mypy type checkers to check for type errors
24+
```
25+
uv run mypy
26+
```
27+
4. Run the unit tests to make sure that none of the unit tests are broken.

error_handlers.py

Whitespace-only changes.

pyproject.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@ authors = [{ name = "Google LLC", email = "[email protected]" }]
88
requires-python = ">=3.10"
99
keywords = ["A2A", "A2A SDK", "A2A Protocol", "Agent2Agent", "Agent 2 Agent"]
1010
dependencies = [
11-
"fastapi>=0.115.2",
11+
"fastapi>=0.116.1",
1212
"httpx>=0.28.1",
1313
"httpx-sse>=0.4.0",
1414
"pydantic>=2.11.3",
1515
"sse-starlette",
16-
"starlette"
16+
"starlette",
17+
"protobuf==5.29.5",
18+
"google-api-core>=1.26.0",
1719
]
1820

1921
classifiers = [
@@ -35,7 +37,7 @@ mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"]
3537
sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"]
3638
sql = ["sqlalchemy[asyncio,postgresql-asyncpg,aiomysql,aiosqlite]>=2.0.0"]
3739
encryption = ["cryptography>=43.0.0"]
38-
grpc = ["grpcio>=1.60", "grpcio-tools>=1.60", "grpcio_reflection>=1.7.0", "protobuf==5.29.5", "google-api-core>=1.26.0"]
40+
grpc = ["grpcio>=1.60", "grpcio-tools>=1.60", "grpcio_reflection>=1.7.0"]
3941
telemetry = ["opentelemetry-api>=1.33.0", "opentelemetry-sdk>=1.33.0"]
4042

4143
[project.urls]
@@ -90,6 +92,7 @@ dev = [
9092
"pyupgrade",
9193
"autoflake",
9294
"no_implicit_optional",
95+
"trio",
9396
]
9497

9598
[[tool.uv.index]]

src/a2a/client/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,24 @@
77
CredentialService,
88
InMemoryContextCredentialStore,
99
)
10-
from a2a.client.client import A2ACardResolver, A2AClient
10+
from a2a.client.card_resolver import A2ACardResolver
11+
from a2a.client.client import Client, ClientConfig, ClientEvent, Consumer
12+
from a2a.client.client_factory import ClientFactory, minimal_agent_card
1113
from a2a.client.errors import (
1214
A2AClientError,
1315
A2AClientHTTPError,
1416
A2AClientJSONError,
1517
A2AClientTimeoutError,
1618
)
1719
from a2a.client.helpers import create_text_message_object
20+
from a2a.client.legacy import A2AClient
1821
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
1922

2023

2124
logger = logging.getLogger(__name__)
2225

2326
try:
24-
from a2a.client.grpc_client import A2AGrpcClient # type: ignore
27+
from a2a.client.legacy_grpc import A2AGrpcClient # type: ignore
2528
except ImportError as e:
2629
_original_error = e
2730
logger.debug(
@@ -48,9 +51,15 @@ def __init__(self, *args, **kwargs):
4851
'A2AClientTimeoutError',
4952
'A2AGrpcClient',
5053
'AuthInterceptor',
54+
'Client',
5155
'ClientCallContext',
5256
'ClientCallInterceptor',
57+
'ClientConfig',
58+
'ClientEvent',
59+
'ClientFactory',
60+
'Consumer',
5361
'CredentialService',
5462
'InMemoryContextCredentialStore',
5563
'create_text_message_object',
64+
'minimal_agent_card',
5665
]

src/a2a/client/base_client.py

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
from collections.abc import AsyncIterator
2+
3+
from a2a.client.client import (
4+
Client,
5+
ClientCallContext,
6+
ClientConfig,
7+
ClientEvent,
8+
Consumer,
9+
)
10+
from a2a.client.client_task_manager import ClientTaskManager
11+
from a2a.client.errors import A2AClientInvalidStateError
12+
from a2a.client.middleware import ClientCallInterceptor
13+
from a2a.client.transports.base import ClientTransport
14+
from a2a.types import (
15+
AgentCard,
16+
GetTaskPushNotificationConfigParams,
17+
Message,
18+
MessageSendConfiguration,
19+
MessageSendParams,
20+
Task,
21+
TaskArtifactUpdateEvent,
22+
TaskIdParams,
23+
TaskPushNotificationConfig,
24+
TaskQueryParams,
25+
TaskStatusUpdateEvent,
26+
)
27+
28+
29+
class BaseClient(Client):
30+
"""Base implementation of the A2A client, containing transport-independent logic."""
31+
32+
def __init__(
33+
self,
34+
card: AgentCard,
35+
config: ClientConfig,
36+
transport: ClientTransport,
37+
consumers: list[Consumer],
38+
middleware: list[ClientCallInterceptor],
39+
):
40+
super().__init__(consumers, middleware)
41+
self._card = card
42+
self._config = config
43+
self._transport = transport
44+
45+
async def send_message(
46+
self,
47+
request: Message,
48+
*,
49+
context: ClientCallContext | None = None,
50+
) -> AsyncIterator[ClientEvent | Message]:
51+
"""Sends a message to the agent.
52+
53+
This method handles both streaming and non-streaming (polling) interactions
54+
based on the client configuration and agent capabilities. It will yield
55+
events as they are received from the agent.
56+
57+
Args:
58+
request: The message to send to the agent.
59+
context: The client call context.
60+
61+
Yields:
62+
An async iterator of `ClientEvent` or a final `Message` response.
63+
"""
64+
config = MessageSendConfiguration(
65+
accepted_output_modes=self._config.accepted_output_modes,
66+
blocking=not self._config.polling,
67+
push_notification_config=(
68+
self._config.push_notification_configs[0]
69+
if self._config.push_notification_configs
70+
else None
71+
),
72+
)
73+
params = MessageSendParams(message=request, configuration=config)
74+
75+
if not self._config.streaming or not self._card.capabilities.streaming:
76+
response = await self._transport.send_message(
77+
params, context=context
78+
)
79+
result = (
80+
(response, None) if isinstance(response, Task) else response
81+
)
82+
await self.consume(result, self._card)
83+
yield result
84+
return
85+
86+
tracker = ClientTaskManager()
87+
stream = self._transport.send_message_streaming(params, context=context)
88+
89+
first_event = await anext(stream)
90+
# The response from a server may be either exactly one Message or a
91+
# series of Task updates. Separate out the first message for special
92+
# case handling, which allows us to simplify further stream processing.
93+
if isinstance(first_event, Message):
94+
await self.consume(first_event, self._card)
95+
yield first_event
96+
return
97+
98+
yield await self._process_response(tracker, first_event)
99+
100+
async for event in stream:
101+
yield await self._process_response(tracker, event)
102+
103+
async def _process_response(
104+
self,
105+
tracker: ClientTaskManager,
106+
event: Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent,
107+
) -> ClientEvent:
108+
if isinstance(event, Message):
109+
raise A2AClientInvalidStateError(
110+
'received a streamed Message from server after first response; this is not supported'
111+
)
112+
await tracker.process(event)
113+
task = tracker.get_task_or_raise()
114+
update = None if isinstance(event, Task) else event
115+
client_event = (task, update)
116+
await self.consume(client_event, self._card)
117+
return client_event
118+
119+
async def get_task(
120+
self,
121+
request: TaskQueryParams,
122+
*,
123+
context: ClientCallContext | None = None,
124+
) -> Task:
125+
"""Retrieves the current state and history of a specific task.
126+
127+
Args:
128+
request: The `TaskQueryParams` object specifying the task ID.
129+
context: The client call context.
130+
131+
Returns:
132+
A `Task` object representing the current state of the task.
133+
"""
134+
return await self._transport.get_task(request, context=context)
135+
136+
async def cancel_task(
137+
self,
138+
request: TaskIdParams,
139+
*,
140+
context: ClientCallContext | None = None,
141+
) -> Task:
142+
"""Requests the agent to cancel a specific task.
143+
144+
Args:
145+
request: The `TaskIdParams` object specifying the task ID.
146+
context: The client call context.
147+
148+
Returns:
149+
A `Task` object containing the updated task status.
150+
"""
151+
return await self._transport.cancel_task(request, context=context)
152+
153+
async def set_task_callback(
154+
self,
155+
request: TaskPushNotificationConfig,
156+
*,
157+
context: ClientCallContext | None = None,
158+
) -> TaskPushNotificationConfig:
159+
"""Sets or updates the push notification configuration for a specific task.
160+
161+
Args:
162+
request: The `TaskPushNotificationConfig` object with the new configuration.
163+
context: The client call context.
164+
165+
Returns:
166+
The created or updated `TaskPushNotificationConfig` object.
167+
"""
168+
return await self._transport.set_task_callback(request, context=context)
169+
170+
async def get_task_callback(
171+
self,
172+
request: GetTaskPushNotificationConfigParams,
173+
*,
174+
context: ClientCallContext | None = None,
175+
) -> TaskPushNotificationConfig:
176+
"""Retrieves the push notification configuration for a specific task.
177+
178+
Args:
179+
request: The `GetTaskPushNotificationConfigParams` object specifying the task.
180+
context: The client call context.
181+
182+
Returns:
183+
A `TaskPushNotificationConfig` object containing the configuration.
184+
"""
185+
return await self._transport.get_task_callback(request, context=context)
186+
187+
async def resubscribe(
188+
self,
189+
request: TaskIdParams,
190+
*,
191+
context: ClientCallContext | None = None,
192+
) -> AsyncIterator[ClientEvent]:
193+
"""Resubscribes to a task's event stream.
194+
195+
This is only available if both the client and server support streaming.
196+
197+
Args:
198+
request: Parameters to identify the task to resubscribe to.
199+
context: The client call context.
200+
201+
Yields:
202+
An async iterator of `ClientEvent` objects.
203+
204+
Raises:
205+
NotImplementedError: If streaming is not supported by the client or server.
206+
"""
207+
if not self._config.streaming or not self._card.capabilities.streaming:
208+
raise NotImplementedError(
209+
'client and/or server do not support resubscription.'
210+
)
211+
212+
tracker = ClientTaskManager()
213+
# Note: resubscribe can only be called on an existing task. As such,
214+
# we should never see Message updates, despite the typing of the service
215+
# definition indicating it may be possible.
216+
async for event in self._transport.resubscribe(
217+
request, context=context
218+
):
219+
yield await self._process_response(tracker, event)
220+
221+
async def get_card(
222+
self, *, context: ClientCallContext | None = None
223+
) -> AgentCard:
224+
"""Retrieves the agent's card.
225+
226+
This will fetch the authenticated card if necessary and update the
227+
client's internal state with the new card.
228+
229+
Args:
230+
context: The client call context.
231+
232+
Returns:
233+
The `AgentCard` for the agent.
234+
"""
235+
card = await self._transport.get_card(context=context)
236+
self._card = card
237+
return card
238+
239+
async def close(self) -> None:
240+
"""Closes the underlying transport."""
241+
await self._transport.close()

0 commit comments

Comments
 (0)