Skip to content

Commit 0e77460

Browse files
committed
Merge branch 'fix/logging-and-string-formatting' of github.com:ahoblitz/a2a-python into fix/logging-and-string-formatting
2 parents c7b39d5 + 7d0c034 commit 0e77460

File tree

13 files changed

+536
-130
lines changed

13 files changed

+536
-130
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ __pycache__
99
test_venv/
1010
coverage.xml
1111
.nox
12-
spec.json
12+
spec.json

.ruff.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ ignore = [
3030
"ANN003",
3131
"ANN401",
3232
"TRY003",
33-
"G004",
3433
"TRY201",
3534
"FIX002",
3635
]

src/a2a/grpc/a2a_pb2.py

Lines changed: 83 additions & 83 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/a2a/grpc/a2a_pb2.pyi

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class SendMessageConfiguration(_message.Message):
5555
push_notification: PushNotificationConfig
5656
history_length: int
5757
blocking: bool
58-
def __init__(self, accepted_output_modes: _Optional[_Iterable[str]] = ..., push_notification: _Optional[_Union[PushNotificationConfig, _Mapping]] = ..., history_length: _Optional[int] = ..., blocking: bool = ...) -> None: ...
58+
def __init__(self, accepted_output_modes: _Optional[_Iterable[str]] = ..., push_notification: _Optional[_Union[PushNotificationConfig, _Mapping]] = ..., history_length: _Optional[int] = ..., blocking: _Optional[bool] = ...) -> None: ...
5959

6060
class Task(_message.Message):
6161
__slots__ = ("id", "context_id", "status", "artifacts", "history", "metadata")
@@ -157,7 +157,7 @@ class TaskStatusUpdateEvent(_message.Message):
157157
status: TaskStatus
158158
final: bool
159159
metadata: _struct_pb2.Struct
160-
def __init__(self, task_id: _Optional[str] = ..., context_id: _Optional[str] = ..., status: _Optional[_Union[TaskStatus, _Mapping]] = ..., final: bool = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
160+
def __init__(self, task_id: _Optional[str] = ..., context_id: _Optional[str] = ..., status: _Optional[_Union[TaskStatus, _Mapping]] = ..., final: _Optional[bool] = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
161161

162162
class TaskArtifactUpdateEvent(_message.Message):
163163
__slots__ = ("task_id", "context_id", "artifact", "append", "last_chunk", "metadata")
@@ -173,7 +173,7 @@ class TaskArtifactUpdateEvent(_message.Message):
173173
append: bool
174174
last_chunk: bool
175175
metadata: _struct_pb2.Struct
176-
def __init__(self, task_id: _Optional[str] = ..., context_id: _Optional[str] = ..., artifact: _Optional[_Union[Artifact, _Mapping]] = ..., append: bool = ..., last_chunk: bool = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
176+
def __init__(self, task_id: _Optional[str] = ..., context_id: _Optional[str] = ..., artifact: _Optional[_Union[Artifact, _Mapping]] = ..., append: _Optional[bool] = ..., last_chunk: _Optional[bool] = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
177177

178178
class PushNotificationConfig(_message.Message):
179179
__slots__ = ("id", "url", "token", "authentication")
@@ -204,7 +204,7 @@ class AgentInterface(_message.Message):
204204
def __init__(self, url: _Optional[str] = ..., transport: _Optional[str] = ...) -> None: ...
205205

206206
class AgentCard(_message.Message):
207-
__slots__ = ("protocol_version", "name", "description", "url", "preferred_transport", "additional_interfaces", "provider", "version", "documentation_url", "capabilities", "security_schemes", "security", "default_input_modes", "default_output_modes", "skills", "supports_authenticated_extended_card", "signatures")
207+
__slots__ = ("protocol_version", "name", "description", "url", "preferred_transport", "additional_interfaces", "provider", "version", "documentation_url", "capabilities", "security_schemes", "security", "default_input_modes", "default_output_modes", "skills", "supports_authenticated_extended_card", "signatures", "icon_url")
208208
class SecuritySchemesEntry(_message.Message):
209209
__slots__ = ("key", "value")
210210
KEY_FIELD_NUMBER: _ClassVar[int]
@@ -229,6 +229,7 @@ class AgentCard(_message.Message):
229229
SKILLS_FIELD_NUMBER: _ClassVar[int]
230230
SUPPORTS_AUTHENTICATED_EXTENDED_CARD_FIELD_NUMBER: _ClassVar[int]
231231
SIGNATURES_FIELD_NUMBER: _ClassVar[int]
232+
ICON_URL_FIELD_NUMBER: _ClassVar[int]
232233
protocol_version: str
233234
name: str
234235
description: str
@@ -246,7 +247,8 @@ class AgentCard(_message.Message):
246247
skills: _containers.RepeatedCompositeFieldContainer[AgentSkill]
247248
supports_authenticated_extended_card: bool
248249
signatures: _containers.RepeatedCompositeFieldContainer[AgentCardSignature]
249-
def __init__(self, protocol_version: _Optional[str] = ..., name: _Optional[str] = ..., description: _Optional[str] = ..., url: _Optional[str] = ..., preferred_transport: _Optional[str] = ..., additional_interfaces: _Optional[_Iterable[_Union[AgentInterface, _Mapping]]] = ..., provider: _Optional[_Union[AgentProvider, _Mapping]] = ..., version: _Optional[str] = ..., documentation_url: _Optional[str] = ..., capabilities: _Optional[_Union[AgentCapabilities, _Mapping]] = ..., security_schemes: _Optional[_Mapping[str, SecurityScheme]] = ..., security: _Optional[_Iterable[_Union[Security, _Mapping]]] = ..., default_input_modes: _Optional[_Iterable[str]] = ..., default_output_modes: _Optional[_Iterable[str]] = ..., skills: _Optional[_Iterable[_Union[AgentSkill, _Mapping]]] = ..., supports_authenticated_extended_card: bool = ..., signatures: _Optional[_Iterable[_Union[AgentCardSignature, _Mapping]]] = ...) -> None: ...
250+
icon_url: str
251+
def __init__(self, protocol_version: _Optional[str] = ..., name: _Optional[str] = ..., description: _Optional[str] = ..., url: _Optional[str] = ..., preferred_transport: _Optional[str] = ..., additional_interfaces: _Optional[_Iterable[_Union[AgentInterface, _Mapping]]] = ..., provider: _Optional[_Union[AgentProvider, _Mapping]] = ..., version: _Optional[str] = ..., documentation_url: _Optional[str] = ..., capabilities: _Optional[_Union[AgentCapabilities, _Mapping]] = ..., security_schemes: _Optional[_Mapping[str, SecurityScheme]] = ..., security: _Optional[_Iterable[_Union[Security, _Mapping]]] = ..., default_input_modes: _Optional[_Iterable[str]] = ..., default_output_modes: _Optional[_Iterable[str]] = ..., skills: _Optional[_Iterable[_Union[AgentSkill, _Mapping]]] = ..., supports_authenticated_extended_card: _Optional[bool] = ..., signatures: _Optional[_Iterable[_Union[AgentCardSignature, _Mapping]]] = ..., icon_url: _Optional[str] = ...) -> None: ...
250252

251253
class AgentProvider(_message.Message):
252254
__slots__ = ("url", "organization")
@@ -264,7 +266,7 @@ class AgentCapabilities(_message.Message):
264266
streaming: bool
265267
push_notifications: bool
266268
extensions: _containers.RepeatedCompositeFieldContainer[AgentExtension]
267-
def __init__(self, streaming: bool = ..., push_notifications: bool = ..., extensions: _Optional[_Iterable[_Union[AgentExtension, _Mapping]]] = ...) -> None: ...
269+
def __init__(self, streaming: _Optional[bool] = ..., push_notifications: _Optional[bool] = ..., extensions: _Optional[_Iterable[_Union[AgentExtension, _Mapping]]] = ...) -> None: ...
268270

269271
class AgentExtension(_message.Message):
270272
__slots__ = ("uri", "description", "required", "params")
@@ -276,7 +278,7 @@ class AgentExtension(_message.Message):
276278
description: str
277279
required: bool
278280
params: _struct_pb2.Struct
279-
def __init__(self, uri: _Optional[str] = ..., description: _Optional[str] = ..., required: bool = ..., params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
281+
def __init__(self, uri: _Optional[str] = ..., description: _Optional[str] = ..., required: _Optional[bool] = ..., params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
280282

281283
class AgentSkill(_message.Message):
282284
__slots__ = ("id", "name", "description", "tags", "examples", "input_modes", "output_modes", "security")
@@ -486,12 +488,14 @@ class SendMessageRequest(_message.Message):
486488
def __init__(self, request: _Optional[_Union[Message, _Mapping]] = ..., configuration: _Optional[_Union[SendMessageConfiguration, _Mapping]] = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
487489

488490
class GetTaskRequest(_message.Message):
489-
__slots__ = ("name", "history_length")
491+
__slots__ = ("name", "history_length", "metadata")
490492
NAME_FIELD_NUMBER: _ClassVar[int]
491493
HISTORY_LENGTH_FIELD_NUMBER: _ClassVar[int]
494+
METADATA_FIELD_NUMBER: _ClassVar[int]
492495
name: str
493496
history_length: int
494-
def __init__(self, name: _Optional[str] = ..., history_length: _Optional[int] = ...) -> None: ...
497+
metadata: _struct_pb2.Struct
498+
def __init__(self, name: _Optional[str] = ..., history_length: _Optional[int] = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
495499

496500
class CancelTaskRequest(_message.Message):
497501
__slots__ = ("name",)

src/a2a/server/events/event_consumer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ async def consume_all(self) -> AsyncGenerator[Event]:
125125
# other part is waiting for an event or a closed queue.
126126
if is_final_event:
127127
logger.debug('Stopping event consumption in consume_all.')
128-
await self.queue.close()
128+
await self.queue.close(True)
129129
yield event
130130
break
131131
yield event
@@ -135,7 +135,7 @@ async def consume_all(self) -> AsyncGenerator[Event]:
135135
except asyncio.TimeoutError: # pyright: ignore [reportUnusedExcept]
136136
# This class was made an alias of build-in TimeoutError after 3.11
137137
continue
138-
except QueueClosed:
138+
except (QueueClosed, asyncio.QueueEmpty):
139139
# Confirm that the queue is closed, e.g. we aren't on
140140
# python 3.12 and get a queue empty error on an open queue
141141
if self.queue.is_closed():

src/a2a/server/events/event_queue.py

Lines changed: 77 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,12 @@ async def dequeue_event(self, no_wait: bool = False) -> Event:
9090
asyncio.QueueShutDown: If the queue has been closed and is empty.
9191
"""
9292
async with self._lock:
93-
if self._is_closed and self.queue.empty():
93+
if (
94+
sys.version_info < (3, 13)
95+
and self._is_closed
96+
and self.queue.empty()
97+
):
98+
# On 3.13+, skip early raise; await self.queue.get() will raise QueueShutDown after shutdown()
9499
logger.warning('Queue is closed. Event will not be dequeued.')
95100
raise asyncio.QueueEmpty('Queue is closed.')
96101

@@ -127,25 +132,38 @@ def tap(self) -> 'EventQueue':
127132
self._children.append(queue)
128133
return queue
129134

130-
async def close(self) -> None:
131-
"""Closes the queue for future push events.
135+
async def close(self, immediate: bool = False) -> None:
136+
"""Closes the queue for future push events and also closes all child queues.
137+
138+
Once closed, no new events can be enqueued. For Python 3.13+, this will trigger
139+
`asyncio.QueueShutDown` when the queue is empty and a consumer tries to dequeue.
140+
For lower versions, the queue will be marked as closed and optionally cleared.
141+
142+
Args:
143+
immediate (bool):
144+
- True: Immediately closes the queue and clears all unprocessed events without waiting for them to be consumed. This is suitable for scenarios where you need to forcefully interrupt and quickly release resources.
145+
- False (default): Gracefully closes the queue, waiting for all queued events to be processed (i.e., the queue is drained) before closing. This is suitable when you want to ensure all events are handled.
132146
133-
Once closed, `dequeue_event` will eventually raise `asyncio.QueueShutDown`
134-
when the queue is empty. Also closes all child queues.
135147
"""
136148
logger.debug('Closing EventQueue.')
137149
async with self._lock:
138150
# If already closed, just return.
139-
if self._is_closed:
151+
if self._is_closed and not immediate:
140152
return
141-
self._is_closed = True
153+
if not self._is_closed:
154+
self._is_closed = True
142155
# If using python 3.13 or higher, use the shutdown method
143156
if sys.version_info >= (3, 13):
144-
self.queue.shutdown()
157+
self.queue.shutdown(immediate)
145158
for child in self._children:
146-
await child.close()
159+
await child.close(immediate)
147160
# Otherwise, join the queue
148161
else:
162+
if immediate:
163+
await self.clear_events(True)
164+
for child in self._children:
165+
await child.close(immediate)
166+
return
149167
tasks = [asyncio.create_task(self.queue.join())]
150168
tasks.extend(
151169
asyncio.create_task(child.close()) for child in self._children
@@ -155,3 +173,53 @@ async def close(self) -> None:
155173
def is_closed(self) -> bool:
156174
"""Checks if the queue is closed."""
157175
return self._is_closed
176+
177+
async def clear_events(self, clear_child_queues: bool = True) -> None:
178+
"""Clears all events from the current queue and optionally all child queues.
179+
180+
This method removes all pending events from the queue without processing them.
181+
Child queues can be optionally cleared based on the clear_child_queues parameter.
182+
183+
Args:
184+
clear_child_queues: If True (default), clear all child queues as well.
185+
If False, only clear the current queue, leaving child queues untouched.
186+
"""
187+
logger.debug('Clearing all events from EventQueue and child queues.')
188+
189+
# Clear all events from the queue, even if closed
190+
cleared_count = 0
191+
async with self._lock:
192+
try:
193+
while True:
194+
event = self.queue.get_nowait()
195+
logger.debug(
196+
f'Discarding unprocessed event of type: {type(event)}, content: {event}'
197+
)
198+
self.queue.task_done()
199+
cleared_count += 1
200+
except asyncio.QueueEmpty:
201+
pass
202+
except Exception as e:
203+
# Handle Python 3.13+ QueueShutDown
204+
if (
205+
sys.version_info >= (3, 13)
206+
and type(e).__name__ == 'QueueShutDown'
207+
):
208+
pass
209+
else:
210+
raise
211+
212+
if cleared_count > 0:
213+
logger.debug(
214+
f'Cleared {cleared_count} unprocessed events from EventQueue.'
215+
)
216+
217+
# Clear all child queues (lock released before awaiting child tasks)
218+
if clear_child_queues and self._children:
219+
child_tasks = [
220+
asyncio.create_task(child.clear_events())
221+
for child in self._children
222+
]
223+
224+
if child_tasks:
225+
await asyncio.gather(*child_tasks, return_exceptions=True)

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,11 +288,19 @@ async def on_message_send(
288288

289289
interrupted_or_non_blocking = False
290290
try:
291+
# Create async callback for push notifications
292+
async def push_notification_callback() -> None:
293+
await self._send_push_notification_if_needed(
294+
task_id, result_aggregator
295+
)
296+
291297
(
292298
result,
293299
interrupted_or_non_blocking,
294300
) = await result_aggregator.consume_and_break_on_interrupt(
295-
consumer, blocking=blocking
301+
consumer,
302+
blocking=blocking,
303+
event_callback=push_notification_callback,
296304
)
297305
if not result:
298306
raise ServerError(error=InternalError()) # noqa: TRY301

src/a2a/server/tasks/result_aggregator.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import logging
33

4-
from collections.abc import AsyncGenerator, AsyncIterator
4+
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
55

66
from a2a.server.events import Event, EventConsumer
77
from a2a.server.tasks.task_manager import TaskManager
@@ -24,7 +24,10 @@ class ResultAggregator:
2424
Task object and emit that Task object.
2525
"""
2626

27-
def __init__(self, task_manager: TaskManager):
27+
def __init__(
28+
self,
29+
task_manager: TaskManager,
30+
) -> None:
2831
"""Initializes the ResultAggregator.
2932
3033
Args:
@@ -92,7 +95,10 @@ async def consume_all(
9295
return await self.task_manager.get_task()
9396

9497
async def consume_and_break_on_interrupt(
95-
self, consumer: EventConsumer, blocking: bool = True
98+
self,
99+
consumer: EventConsumer,
100+
blocking: bool = True,
101+
event_callback: Callable[[], Awaitable[None]] | None = None,
96102
) -> tuple[Task | Message | None, bool]:
97103
"""Processes the event stream until completion or an interruptable state is encountered.
98104
@@ -105,6 +111,9 @@ async def consume_and_break_on_interrupt(
105111
consumer: The `EventConsumer` to read events from.
106112
blocking: If `False`, the method returns as soon as a task/message
107113
is available. If `True`, it waits for a terminal state.
114+
event_callback: Optional async callback function to be called after each event
115+
is processed in the background continuation.
116+
Mainly used for push notifications currently.
108117
109118
Returns:
110119
A tuple containing:
@@ -150,13 +159,17 @@ async def consume_and_break_on_interrupt(
150159
if should_interrupt:
151160
# Continue consuming the rest of the events in the background.
152161
# TODO: We should track all outstanding tasks to ensure they eventually complete.
153-
asyncio.create_task(self._continue_consuming(event_stream)) # noqa: RUF006
162+
asyncio.create_task( # noqa: RUF006
163+
self._continue_consuming(event_stream, event_callback)
164+
)
154165
interrupted = True
155166
break
156167
return await self.task_manager.get_task(), interrupted
157168

158169
async def _continue_consuming(
159-
self, event_stream: AsyncIterator[Event]
170+
self,
171+
event_stream: AsyncIterator[Event],
172+
event_callback: Callable[[], Awaitable[None]] | None = None,
160173
) -> None:
161174
"""Continues processing an event stream in a background task.
162175
@@ -165,6 +178,9 @@ async def _continue_consuming(
165178
166179
Args:
167180
event_stream: The remaining `AsyncIterator` of events from the consumer.
181+
event_callback: Optional async callback function to be called after each event is processed.
168182
"""
169183
async for event in event_stream:
170184
await self.task_manager.process(event)
185+
if event_callback:
186+
await event_callback()

0 commit comments

Comments
 (0)