Skip to content

Commit f63189f

Browse files
committed
Miscellaneous fixes
Change-Id: Ic545b230c7dd153e105c855b68c43fab4a010fa3
1 parent 9b0f264 commit f63189f

File tree

20 files changed

+395
-780
lines changed

20 files changed

+395
-780
lines changed
Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,18 @@
1-
import asyncio
2-
3-
from collections.abc import AsyncGenerator
4-
from typing import Any
5-
from uuid import uuid4
6-
71
from typing_extensions import override
2+
83
from a2a.server.agent_execution import AgentExecutor, RequestContext
9-
from a2a.server.events import Event, EventQueue
10-
from a2a.utils import new_agent_text_message
4+
from a2a.server.events import EventQueue
115
from a2a.types import (
12-
Message,
13-
Part,
14-
Role,
15-
SendMessageRequest,
16-
SendStreamingMessageRequest,
176
Task,
18-
TextPart,
197
)
8+
from a2a.utils import new_agent_text_message
9+
2010

2111
class HelloWorldAgent:
2212
"""Hello World Agent."""
2313

2414
async def invoke(self) -> str:
25-
return 'Hello World'
15+
return 'Hello World'
2616

2717

2818
class HelloWorldAgentExecutor(AgentExecutor):
@@ -41,5 +31,7 @@ async def execute(
4131
event_queue.enqueue_event(new_agent_text_message(result))
4232

4333
@override
44-
async def cancel(self, request: RequestContext, event_queue: EventQueue) -> Task | None:
45-
raise Exception("cancel not supported")
34+
async def cancel(
35+
self, request: RequestContext, event_queue: EventQueue
36+
) -> Task | None:
37+
raise Exception('cancel not supported')
Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,16 @@
11
from abc import ABC, abstractmethod
2+
23
from a2a.server.agent_execution.context import RequestContext
3-
from a2a.server.events.event_queue import Event, EventQueue
4-
from a2a.types import (
5-
Task,
6-
TaskStatusUpdateEvent,
7-
TaskArtifactUpdateEvent,
8-
Message
9-
)
4+
from a2a.server.events.event_queue import EventQueue
105

11-
class AgentExecutor(ABC):
126

13-
@abstractmethod
14-
async def execute(
15-
self, context: RequestContext, event_queue: EventQueue
16-
):
17-
pass
7+
class AgentExecutor(ABC):
8+
"""Agent Executor interface."""
189

19-
@abstractmethod
20-
async def cancel(
21-
self, context: RequestContext, event_queue: EventQueue):
22-
pass
10+
@abstractmethod
11+
async def execute(self, context: RequestContext, event_queue: EventQueue):
12+
pass
2313

14+
@abstractmethod
15+
async def cancel(self, context: RequestContext, event_queue: EventQueue):
16+
pass
Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,38 @@
11
import uuid
2-
from a2a.utils.errors import ServerError
2+
33
from a2a.types import (
4-
Task,
4+
InvalidParamsError,
55
Message,
66
MessageSendParams,
7-
TextPart
7+
Task,
8+
TextPart,
89
)
10+
from a2a.utils.errors import ServerError
11+
912

1013
class RequestContext:
14+
"""Request Context."""
1115

1216
def __init__(
1317
self,
1418
request: MessageSendParams | None = None,
1519
task_id: str | None = None,
1620
context_id: str | None = None,
1721
task: Task | None = None,
18-
related_tasks: list[Task] = []
22+
related_tasks: list[Task] = [],
1923
):
2024
self._params = request
2125
self._task_id = task_id
2226
self._context_id = context_id
2327
self._current_task = task
24-
self._related_tasks: related_tasks
28+
self._related_tasks = related_tasks
2529
# If the task id and context id were provided, make sure they
2630
# match the request. Otherwise, create them
2731
if self._params:
2832
if task_id:
2933
self._params.message.taskId = task_id
3034
if task and task.id != task_id:
31-
raise ServerError(
32-
InvalidParamsError(message='bad task id')
33-
)
35+
raise ServerError(InvalidParamsError(message='bad task id'))
3436
else:
3537
self._check_or_generate_task_id()
3638
if context_id:
@@ -43,6 +45,9 @@ def __init__(
4345
self._check_or_generate_context_id()
4446

4547
def get_user_input(self, delimiter='\n') -> str:
48+
if not self._params:
49+
return ''
50+
4651
parts = []
4752
for part in self._params.message.parts:
4853
if isinstance(part.root, TextPart):
@@ -61,29 +66,35 @@ def related_tasks(self) -> list[Task]:
6166
return self._related_tasks
6267

6368
@property
64-
def current_task(self) -> Task:
69+
def current_task(self) -> Task | None:
6570
return self._current_task
6671

6772
@current_task.setter
68-
def current_task(self, task: Task):
73+
def current_task(self, task: Task) -> None:
6974
self._current_task = task
7075

7176
@property
72-
def task_id(self) -> str:
77+
def task_id(self) -> str | None:
7378
return self._task_id
7479

7580
@property
76-
def context_id(self) -> str:
81+
def context_id(self) -> str | None:
7782
return self._context_id
7883

79-
def _check_or_generate_task_id(self):
80-
if not self._task_id and not self._params.message.taskId:
81-
self._params.message.taskId = str(uuid.uuid4())
82-
if self._params.message.taskId:
83-
self._task_id = self._params.message.taskId
84-
85-
def _check_or_generate_context_id(self):
86-
if not self._context_id and not self._params.message.contextId:
87-
self._params.message.taskId = str(uuid.uuid4())
88-
if self._params.message.contextId:
89-
self._context_id = self._params.message.contextId
84+
def _check_or_generate_task_id(self) -> None:
85+
if not self._params:
86+
return
87+
88+
if not self._task_id and not self._params.message.taskId:
89+
self._params.message.taskId = str(uuid.uuid4())
90+
if self._params.message.taskId:
91+
self._task_id = self._params.message.taskId
92+
93+
def _check_or_generate_context_id(self) -> None:
94+
if not self._params:
95+
return
96+
97+
if not self._context_id and not self._params.message.contextId:
98+
self._params.message.contextId = str(uuid.uuid4())
99+
if self._params.message.contextId:
100+
self._context_id = self._params.message.contextId

src/a2a/server/apps/http_app.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
from abc import ABC, abstractmethod
2+
from typing import Any
3+
24
from starlette.applications import Starlette
35

6+
47
class HttpApp(ABC):
8+
"""A2A Server application interface."""
59

6-
@abstractmethod
7-
def build(self, **kwargs) -> Starlette:
8-
pass
10+
@abstractmethod
11+
def build(self, **kwargs: Any) -> Starlette:
12+
pass

src/a2a/server/apps/starlette_app.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
from starlette.responses import JSONResponse, Response
1313
from starlette.routing import Route
1414

15-
from a2a.utils.errors import MethodNotImplementedError
16-
from a2a.server.request_handlers.jsonrpc_handler import RequestHandler, JSONRPCHandler
17-
from a2a.server.events.event_queue import Event
15+
from a2a.server.request_handlers.jsonrpc_handler import (
16+
JSONRPCHandler,
17+
RequestHandler,
18+
)
1819
from a2a.types import (
1920
A2AError,
2021
A2ARequest,
@@ -35,6 +36,7 @@
3536
TaskResubscriptionRequest,
3637
UnsupportedOperationError,
3738
)
39+
from a2a.utils.errors import MethodNotImplementedError
3840

3941

4042
logger = logging.getLogger(__name__)
@@ -47,17 +49,17 @@ class A2AStarletteApplication:
4749
handler methods, and manages response generation including Server-Sent Events (SSE).
4850
"""
4951

50-
def __init__(
51-
self, agent_card: AgentCard, http_handler: RequestHandler
52-
):
52+
def __init__(self, agent_card: AgentCard, http_handler: RequestHandler):
5353
"""Initializes the A2AApplication.
5454
5555
Args:
5656
agent_card: The AgentCard describing the agent's capabilities.
5757
http_handler: The handler instance responsible for processing A2A requests via http.
5858
"""
5959
self.agent_card = agent_card
60-
self.handler = JSONRPCHandler(agent_card=agent_card, request_handler=http_handler)
60+
self.handler = JSONRPCHandler(
61+
agent_card=agent_card, request_handler=http_handler
62+
)
6163

6264
def _generate_error_response(
6365
self, request_id: str | int | None, error: JSONRPCError | A2AError
@@ -154,7 +156,7 @@ async def _process_streaming_request(
154156
):
155157
handler_result = self.handler.on_message_send_stream(request_obj)
156158
elif isinstance(request_obj, TaskResubscriptionRequest):
157-
handler_result = self.handler.on_resubscribe(request_obj)
159+
handler_result = self.handler.on_resubscribe_to_task(request_obj)
158160

159161
return self._create_response(handler_result)
160162

@@ -173,15 +175,17 @@ async def _process_non_streaming_request(
173175
case SendMessageRequest():
174176
handler_result = await self.handler.on_message_send(request_obj)
175177
case CancelTaskRequest():
176-
handler_result = await self.handler.on_cancel(request_obj)
178+
handler_result = await self.handler.on_cancel_task(request_obj)
177179
case GetTaskRequest():
178180
handler_result = await self.handler.on_get_task(request_obj)
179181
case SetTaskPushNotificationConfigRequest():
180-
handler_result = await self.handler.on_set_push_notification(
181-
request_obj)
182+
handler_result = await self.handler.set_push_notification(
183+
request_obj
184+
)
182185
case GetTaskPushNotificationConfigRequest():
183-
handler_result = await self.handler.on_get_push_notification(
184-
request_obj)
186+
handler_result = await self.handler.get_push_notification(
187+
request_obj
188+
)
185189
case _:
186190
logger.error(
187191
f'Unhandled validated request type: {type(request_obj)}'
@@ -218,7 +222,7 @@ def _create_response(
218222
if isinstance(handler_result, AsyncGenerator):
219223
# Result is a stream of SendStreamingMessageResponse objects
220224
async def event_generator(
221-
stream: AsyncGenerator[Event, None],
225+
stream: AsyncGenerator[SendStreamingMessageResponse, None],
222226
) -> AsyncGenerator[dict[str, str], None]:
223227
async for item in stream:
224228
yield {'data': item.root.model_dump_json(exclude_none=True)}

src/a2a/server/events/event_consumer.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,16 @@
22
import logging
33

44
from collections.abc import AsyncGenerator
5-
from a2a.utils.errors import ServerError
5+
66
from a2a.server.events.event_queue import Event, EventQueue
77
from a2a.types import (
8-
A2AError,
98
InternalError,
10-
JSONRPCError,
119
Message,
1210
Task,
13-
TaskArtifactUpdateEvent,
14-
TaskStatusUpdateEvent,
1511
TaskState,
12+
TaskStatusUpdateEvent,
1613
)
14+
from a2a.utils.errors import ServerError
1715

1816

1917
logger = logging.getLogger(__name__)
@@ -35,13 +33,14 @@ async def consume_one(self) -> Event:
3533
logger.warning('Event queue was empty in consume_one.')
3634
raise ServerError(
3735
InternalError(message='Agent did not return any response')
38-
)
36+
) from None
37+
38+
logger.debug(f'Dequeued event of type: {type(event)} in consume_one.')
3939

4040
self.queue.task_done()
4141

4242
return event
4343

44-
4544
async def consume_all(self) -> AsyncGenerator[Event]:
4645
"""Consume all the generated streaming events from the agent."""
4746
logger.debug('Starting to consume all events from the queue.')
@@ -53,17 +52,22 @@ async def consume_all(self) -> AsyncGenerator[Event]:
5352
)
5453
yield event
5554
self.queue.task_done()
56-
logger.debug('Marked task as done in event queue in consume_all')
57-
55+
logger.debug(
56+
'Marked task as done in event queue in consume_all'
57+
)
5858

5959
is_final_event = (
60-
(isinstance(event, TaskStatusUpdateEvent) and event.final) or
61-
isinstance(event, Message) or
62-
(isinstance(event, Task) and
63-
event.status.state in (
64-
TaskState.completed, TaskState.canceled, TaskState.failed
65-
)
66-
)
60+
(isinstance(event, TaskStatusUpdateEvent) and event.final)
61+
or isinstance(event, Message)
62+
or (
63+
isinstance(event, Task)
64+
and event.status.state
65+
in (
66+
TaskState.completed,
67+
TaskState.canceled,
68+
TaskState.failed,
69+
)
70+
)
6771
)
6872

6973
if is_final_event:

src/a2a/server/events/event_queue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import logging
3+
34
from typing import Any
4-
from pydantic import RootModel
55

66
from a2a.types import (
77
A2AError,

0 commit comments

Comments
 (0)