Skip to content

Commit 088d25c

Browse files
authored
Merge branch 'main' into md-auth
2 parents 8700c28 + 39307f1 commit 088d25c

File tree

11 files changed

+1056
-714
lines changed

11 files changed

+1056
-714
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ pip install a2a-sdk
5959
uv run test_client.py
6060
```
6161

62+
3. You can validate your agent using the agent inspector. Follow the instructions at the [a2a-inspector](https://github.com/google-a2a/a2a-inspector) repo.
63+
6264
You can also find more Python samples [here](https://github.com/google-a2a/a2a-samples/tree/main/samples/python) and JavaScript samples [here](https://github.com/google-a2a/a2a-samples/tree/main/samples/js).
6365

6466
## License

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ authors = [{ name = "Google LLC", email = "[email protected]" }]
88
requires-python = ">=3.10"
99
keywords = ["A2A", "A2A SDK", "A2A Protocol", "Agent2Agent"]
1010
dependencies = [
11-
"fastapi>=0.115.12",
11+
"fastapi>=0.115.2",
1212
"httpx>=0.28.1",
1313
"httpx-sse>=0.4.0",
1414
"google-api-core>=1.26.0",
1515
"opentelemetry-api>=1.33.0",
1616
"opentelemetry-sdk>=1.33.0",
1717
"pydantic>=2.11.3",
18-
"sse-starlette>=2.3.3",
19-
"starlette>=0.46.2",
18+
"sse-starlette",
19+
"starlette",
2020
"grpcio>=1.60",
2121
"grpcio-tools>=1.60",
2222
"grpcio_reflection>=1.7.0",

src/a2a/server/apps/jsonrpc/fastapi_app.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,25 +49,21 @@ def __init__(
4949
context_builder=context_builder,
5050
)
5151

52-
def build(
52+
def add_routes_to_app(
5353
self,
54+
app: FastAPI,
5455
agent_card_url: str = '/.well-known/agent.json',
5556
rpc_url: str = '/',
5657
extended_agent_card_url: str = '/agent/authenticatedExtendedCard',
57-
**kwargs: Any,
58-
) -> FastAPI:
59-
"""Builds and returns the FastAPI application instance.
58+
) -> None:
59+
"""Adds the routes to the FastAPI application.
6060
6161
Args:
62+
app: The FastAPI application to add the routes to.
6263
agent_card_url: The URL for the agent card endpoint.
6364
rpc_url: The URL for the A2A JSON-RPC endpoint.
6465
extended_agent_card_url: The URL for the authenticated extended agent card endpoint.
65-
**kwargs: Additional keyword arguments to pass to the FastAPI constructor.
66-
67-
Returns:
68-
A configured FastAPI application instance.
6966
"""
70-
app = FastAPI(**kwargs)
7167

7268
@app.post(rpc_url)
7369
async def handle_a2a_request(request: Request) -> Response:
@@ -85,4 +81,28 @@ async def get_extended_agent_card(request: Request) -> Response:
8581
request
8682
)
8783

84+
def build(
85+
self,
86+
agent_card_url: str = '/.well-known/agent.json',
87+
rpc_url: str = '/',
88+
extended_agent_card_url: str = '/agent/authenticatedExtendedCard',
89+
**kwargs: Any,
90+
) -> FastAPI:
91+
"""Builds and returns the FastAPI application instance.
92+
93+
Args:
94+
agent_card_url: The URL for the agent card endpoint.
95+
rpc_url: The URL for the A2A JSON-RPC endpoint.
96+
extended_agent_card_url: The URL for the authenticated extended agent card endpoint.
97+
**kwargs: Additional keyword arguments to pass to the FastAPI constructor.
98+
99+
Returns:
100+
A configured FastAPI application instance.
101+
"""
102+
app = FastAPI(**kwargs)
103+
104+
self.add_routes_to_app(
105+
app, agent_card_url, rpc_url, extended_agent_card_url
106+
)
107+
88108
return app

src/a2a/server/apps/jsonrpc/starlette_app.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,28 @@ def routes(
9292
)
9393
return app_routes
9494

95+
def add_routes_to_app(
96+
self,
97+
app: Starlette,
98+
agent_card_url: str = '/.well-known/agent.json',
99+
rpc_url: str = '/',
100+
extended_agent_card_url: str = '/agent/authenticatedExtendedCard',
101+
) -> None:
102+
"""Adds the routes to the Starlette application.
103+
104+
Args:
105+
app: The Starlette application to add the routes to.
106+
agent_card_url: The URL path for the agent card endpoint.
107+
rpc_url: The URL path for the A2A JSON-RPC endpoint (POST requests).
108+
extended_agent_card_url: The URL for the authenticated extended agent card endpoint.
109+
"""
110+
routes = self.routes(
111+
agent_card_url=agent_card_url,
112+
rpc_url=rpc_url,
113+
extended_agent_card_url=extended_agent_card_url,
114+
)
115+
app.routes.extend(routes)
116+
95117
def build(
96118
self,
97119
agent_card_url: str = '/.well-known/agent.json',
@@ -110,14 +132,10 @@ def build(
110132
Returns:
111133
A configured Starlette application instance.
112134
"""
113-
app_routes = self.routes(
114-
agent_card_url=agent_card_url,
115-
rpc_url=rpc_url,
116-
extended_agent_card_url=extended_agent_card_url,
135+
app = Starlette(**kwargs)
136+
137+
self.add_routes_to_app(
138+
app, agent_card_url, rpc_url, extended_agent_card_url
117139
)
118-
if 'routes' in kwargs:
119-
kwargs['routes'].extend(app_routes)
120-
else:
121-
kwargs['routes'] = app_routes
122140

123-
return Starlette(**kwargs)
141+
return app

src/a2a/server/events/event_consumer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ async def consume_all(self) -> AsyncGenerator[Event]:
130130
except TimeoutError:
131131
# continue polling until there is a final event
132132
continue
133+
except asyncio.TimeoutError:
134+
# This class was made an alias of build-in TimeoutError after 3.11
135+
continue
133136
except QueueClosed:
134137
# Confirm that the queue is closed, e.g. we aren't on
135138
# python 3.12 and get a queue empty error on an open queue

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from a2a.types import (
3030
GetTaskPushNotificationConfigParams,
3131
InternalError,
32+
InvalidParamsError,
3233
Message,
3334
MessageSendConfiguration,
3435
MessageSendParams,
@@ -38,6 +39,7 @@
3839
TaskNotFoundError,
3940
TaskPushNotificationConfig,
4041
TaskQueryParams,
42+
TaskState,
4143
UnsupportedOperationError,
4244
)
4345
from a2a.utils.errors import ServerError
@@ -46,6 +48,12 @@
4648

4749
logger = logging.getLogger(__name__)
4850

51+
TERMINAL_TASK_STATES = {
52+
TaskState.completed,
53+
TaskState.canceled,
54+
TaskState.failed,
55+
TaskState.rejected,
56+
}
4957

5058
@trace_class(kind=SpanKind.SERVER)
5159
class DefaultRequestHandler(RequestHandler):
@@ -178,6 +186,13 @@ async def on_message_send(
178186
)
179187
task: Task | None = await task_manager.get_task()
180188
if task:
189+
if task.status.state in TERMINAL_TASK_STATES:
190+
raise ServerError(
191+
error=InvalidParamsError(
192+
message=f'Task {task.id} is in terminal state: {task.status.state}'
193+
)
194+
)
195+
181196
task = task_manager.update_with_message(params.message, task)
182197
if self.should_add_push_info(params):
183198
assert isinstance(self._push_notifier, PushNotifier)
@@ -264,8 +279,14 @@ async def on_message_send_stream(
264279
task: Task | None = await task_manager.get_task()
265280

266281
if task:
267-
task = task_manager.update_with_message(params.message, task)
282+
if task.status.state in TERMINAL_TASK_STATES:
283+
raise ServerError(
284+
error=InvalidParamsError(
285+
message=f'Task {task.id} is in terminal state: {task.status.state}'
286+
)
287+
)
268288

289+
task = task_manager.update_with_message(params.message, task)
269290
if self.should_add_push_info(params):
270291
assert isinstance(self._push_notifier, PushNotifier)
271292
assert isinstance(
@@ -413,6 +434,13 @@ async def on_resubscribe_to_task(
413434
if not task:
414435
raise ServerError(error=TaskNotFoundError())
415436

437+
if task.status.state in TERMINAL_TASK_STATES:
438+
raise ServerError(
439+
error=InvalidParamsError(
440+
message=f'Task {task.id} is in terminal state: {task.status.state}'
441+
)
442+
)
443+
416444
task_manager = TaskManager(
417445
task_id=task.id,
418446
context_id=task.contextId,

src/a2a/server/tasks/task_updater.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,14 @@ async def update_status(
6565
)
6666
)
6767

68-
async def add_artifact(
68+
async def add_artifact( # noqa: PLR0913
6969
self,
7070
parts: list[Part],
7171
artifact_id: str | None = None,
7272
name: str | None = None,
7373
metadata: dict[str, Any] | None = None,
74+
append: bool | None = None,
75+
last_chunk: bool | None = None,
7476
) -> None:
7577
"""Adds an artifact chunk to the task and publishes a `TaskArtifactUpdateEvent`.
7678
@@ -79,6 +81,8 @@ async def add_artifact(
7981
artifact_id: The ID of the artifact. A new UUID is generated if not provided.
8082
name: Optional name for the artifact.
8183
metadata: Optional metadata for the artifact.
84+
append: Optional boolean indicating if this chunk appends to a previous one.
85+
last_chunk: Optional boolean indicating if this is the last chunk.
8286
"""
8387
if not artifact_id:
8488
artifact_id = str(uuid.uuid4())
@@ -93,6 +97,8 @@ async def add_artifact(
9397
parts=parts,
9498
metadata=metadata,
9599
),
100+
append=append,
101+
lastChunk=last_chunk
96102
)
97103
)
98104

@@ -128,6 +134,30 @@ async def start_work(self, message: Message | None = None) -> None:
128134
message=message,
129135
)
130136

137+
async def cancel(self, message: Message | None = None) -> None:
138+
"""Marks the task as cancelled and publishes a finalstatus update."""
139+
await self.update_status(
140+
TaskState.canceled, message=message, final=True
141+
)
142+
143+
async def requires_input(
144+
self, message: Message | None = None, final: bool = False
145+
) -> None:
146+
"""Marks the task as input required and publishes a status update."""
147+
await self.update_status(
148+
TaskState.input_required,
149+
message=message,
150+
final=final,
151+
)
152+
153+
async def requires_auth(
154+
self, message: Message | None = None, final: bool = False
155+
) -> None:
156+
"""Marks the task as auth required and publishes a status update."""
157+
await self.update_status(
158+
TaskState.auth_required, message=message, final=final
159+
)
160+
131161
def new_agent_message(
132162
self,
133163
parts: list[Part],

src/a2a/utils/proto_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
# Regexp patterns for matching
1717
_TASK_NAME_MATCH = r'tasks/(\w+)'
18-
_TASK_PUSH_CONFIG_NAME_MATCH = r'tasks/(\w+)/pushNotifications/(\w+)'
18+
_TASK_PUSH_CONFIG_NAME_MATCH = r'tasks/(\w+)/pushNotificationConfigs/(\w+)'
1919

2020

2121
class ToProto:
@@ -252,7 +252,7 @@ def task_push_notification_config(
252252
cls, config: types.TaskPushNotificationConfig
253253
) -> a2a_pb2.TaskPushNotificationConfig:
254254
return a2a_pb2.TaskPushNotificationConfig(
255-
name=f'tasks/{config.taskId}/pushNotifications/{config.taskId}',
255+
name=f'tasks/{config.taskId}/pushNotificationConfigs/{config.taskId}',
256256
push_notification_config=cls.push_notification_config(
257257
config.pushNotificationConfig,
258258
),

0 commit comments

Comments
 (0)