Skip to content

Commit 1958747

Browse files
committed
feat: Push notification changes
1 parent 8c6560f commit 1958747

File tree

10 files changed

+237
-75
lines changed

10 files changed

+237
-75
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
import logging
3-
import uuid
43

54
from collections.abc import AsyncGenerator
65
from typing import cast
@@ -21,13 +20,13 @@
2120
)
2221
from a2a.server.request_handlers.request_handler import RequestHandler
2322
from a2a.server.tasks import (
24-
PushNotifier,
2523
ResultAggregator,
24+
PushNotificationConfigStore,
25+
PushNotificationSender,
2626
TaskManager,
2727
TaskStore,
2828
)
2929
from a2a.types import (
30-
GetTaskPushNotificationConfigParams,
3130
InternalError,
3231
Message,
3332
MessageSendConfiguration,
@@ -39,6 +38,7 @@
3938
TaskPushNotificationConfig,
4039
TaskQueryParams,
4140
UnsupportedOperationError,
41+
GetTaskPushNotificationConfigParams,
4242
)
4343
from a2a.utils.errors import ServerError
4444
from a2a.utils.telemetry import SpanKind, trace_class
@@ -63,7 +63,8 @@ def __init__(
6363
agent_executor: AgentExecutor,
6464
task_store: TaskStore,
6565
queue_manager: QueueManager | None = None,
66-
push_notifier: PushNotifier | None = None,
66+
push_config_store: PushNotificationConfigStore | None = None,
67+
push_sender: PushNotificationSender | None = None,
6768
request_context_builder: RequestContextBuilder | None = None,
6869
) -> None:
6970
"""Initializes the DefaultRequestHandler.
@@ -72,14 +73,16 @@ def __init__(
7273
agent_executor: The `AgentExecutor` instance to run agent logic.
7374
task_store: The `TaskStore` instance to manage task persistence.
7475
queue_manager: The `QueueManager` instance to manage event queues. Defaults to `InMemoryQueueManager`.
75-
push_notifier: The `PushNotifier` instance for sending push notifications. Defaults to None.
76+
push_config_store: The `PushNotificationConfigStore` instance for managing push notification configurations. Defaults to None.
77+
push_sender: The `PushNotificationSender` instance for sending push notifications. Defaults to None.
7678
request_context_builder: The `RequestContextBuilder` instance used
7779
to build request contexts. Defaults to `SimpleRequestContextBuilder`.
7880
"""
7981
self.agent_executor = agent_executor
8082
self.task_store = task_store
8183
self._queue_manager = queue_manager or InMemoryQueueManager()
82-
self._push_notifier = push_notifier
84+
self._push_config_store = push_config_store
85+
self._push_sender = push_sender
8386
self._request_context_builder = (
8487
request_context_builder
8588
or SimpleRequestContextBuilder(
@@ -180,15 +183,15 @@ async def on_message_send(
180183
if task:
181184
task = task_manager.update_with_message(params.message, task)
182185
if self.should_add_push_info(params):
183-
assert isinstance(self._push_notifier, PushNotifier)
186+
assert self._push_config_store is not None
184187
assert isinstance(
185188
params.configuration, MessageSendConfiguration
186189
)
187190
assert isinstance(
188191
params.configuration.pushNotificationConfig,
189192
PushNotificationConfig,
190193
)
191-
await self._push_notifier.set_info(
194+
await self._push_config_store.set_info(
192195
task.id, params.configuration.pushNotificationConfig
193196
)
194197
request_context = await self._request_context_builder.build(
@@ -199,7 +202,7 @@ async def on_message_send(
199202
context=context,
200203
)
201204

202-
task_id = cast('str', request_context.task_id)
205+
task_id = cast(str, request_context.task_id)
203206
# Always assign a task ID. We may not actually upgrade to a task, but
204207
# dictating the task ID at this layer is useful for tracking running
205208
# agents.
@@ -237,7 +240,7 @@ async def on_message_send(
237240
finally:
238241
if interrupted:
239242
# TODO: Track this disconnected cleanup task.
240-
asyncio.create_task( # noqa: RUF006
243+
asyncio.create_task(
241244
self._cleanup_producer(producer_task, task_id)
242245
)
243246
else:
@@ -267,15 +270,15 @@ async def on_message_send_stream(
267270
task = task_manager.update_with_message(params.message, task)
268271

269272
if self.should_add_push_info(params):
270-
assert isinstance(self._push_notifier, PushNotifier)
273+
assert self._push_config_store is not None
271274
assert isinstance(
272275
params.configuration, MessageSendConfiguration
273276
)
274277
assert isinstance(
275278
params.configuration.pushNotificationConfig,
276279
PushNotificationConfig,
277280
)
278-
await self._push_notifier.set_info(
281+
await self._push_config_store.set_info(
279282
task.id, params.configuration.pushNotificationConfig
280283
)
281284
else:
@@ -289,7 +292,7 @@ async def on_message_send_stream(
289292
context=context,
290293
)
291294

292-
task_id = cast('str', request_context.task_id)
295+
task_id = cast(str, request_context.task_id)
293296
queue = await self._queue_manager.create_or_tap(task_id)
294297
producer_task = asyncio.create_task(
295298
self._run_event_stream(
@@ -315,19 +318,19 @@ async def on_message_send_stream(
315318
)
316319

317320
if (
318-
self._push_notifier
321+
self._push_config_store # Check if store is available for config
319322
and params.configuration
320323
and params.configuration.pushNotificationConfig
321324
):
322-
await self._push_notifier.set_info(
325+
await self._push_config_store.set_info(
323326
task_id,
324327
params.configuration.pushNotificationConfig,
325328
)
326329

327-
if self._push_notifier and task_id:
330+
if self._push_sender and task_id: # Check if sender is available
328331
latest_task = await result_aggregator.current_result
329332
if isinstance(latest_task, Task):
330-
await self._push_notifier.send_notification(latest_task)
333+
await self._push_sender.send_notification(latest_task)
331334
yield event
332335
finally:
333336
await self._cleanup_producer(producer_task, task_id)
@@ -359,16 +362,14 @@ async def on_set_task_push_notification_config(
359362
360363
Requires a `PushNotifier` to be configured.
361364
"""
362-
if not self._push_notifier:
365+
if not self._push_config_store:
363366
raise ServerError(error=UnsupportedOperationError())
364367

365368
task: Task | None = await self.task_store.get(params.taskId)
366369
if not task:
367370
raise ServerError(error=TaskNotFoundError())
368371

369-
# Generate a unique id for the notification
370-
params.pushNotificationConfig.id = str(uuid.uuid4())
371-
await self._push_notifier.set_info(
372+
await self._push_config_store.set_info(
372373
params.taskId,
373374
params.pushNotificationConfig,
374375
)
@@ -384,14 +385,14 @@ async def on_get_task_push_notification_config(
384385
385386
Requires a `PushNotifier` to be configured.
386387
"""
387-
if not self._push_notifier:
388+
if not self._push_config_store:
388389
raise ServerError(error=UnsupportedOperationError())
389390

390391
task: Task | None = await self.task_store.get(params.id)
391392
if not task:
392393
raise ServerError(error=TaskNotFoundError())
393394

394-
push_notification_config = await self._push_notifier.get_info(params.id)
395+
push_notification_config = await self._push_config_store.get_info(params.id)
395396
if not push_notification_config:
396397
raise ServerError(error=InternalError())
397398

@@ -431,9 +432,8 @@ async def on_resubscribe_to_task(
431432
yield event
432433

433434
def should_add_push_info(self, params: MessageSendParams) -> bool:
434-
"""Determines if push notification info should be set for a task."""
435435
return bool(
436-
self._push_notifier
436+
self._push_config_store
437437
and params.configuration
438438
and params.configuration.pushNotificationConfig
439439
)

src/a2a/server/request_handlers/jsonrpc_handler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
TaskPushNotificationConfig,
3535
TaskResubscriptionRequest,
3636
TaskStatusUpdateEvent,
37+
GetTaskPushNotificationConfigParams
3738
)
3839
from a2a.utils.errors import ServerError
3940
from a2a.utils.helpers import validate

src/a2a/server/tasks/__init__.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
"""Components for managing tasks within the A2A server."""
22

3-
from a2a.server.tasks.inmemory_push_notifier import InMemoryPushNotifier
3+
from a2a.server.tasks.base_push_notification_sender import BasePushNotificationSender
4+
from a2a.server.tasks.inmemory_push_notification_config_store import InMemoryPushNotificationConfigStore
45
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
5-
from a2a.server.tasks.push_notifier import PushNotifier
6+
from a2a.server.tasks.push_notification_config_store import PushNotificationConfigStore
7+
from a2a.server.tasks.push_notification_sender import PushNotificationSender
68
from a2a.server.tasks.result_aggregator import ResultAggregator
79
from a2a.server.tasks.task_manager import TaskManager
810
from a2a.server.tasks.task_store import TaskStore
911
from a2a.server.tasks.task_updater import TaskUpdater
1012

11-
1213
__all__ = [
13-
'InMemoryPushNotifier',
14+
'BasePushNotificationSender',
15+
'InMemoryPushNotificationConfigStore',
1416
'InMemoryTaskStore',
15-
'PushNotifier',
17+
'PushNotificationConfigStore',
18+
'PushNotificationSender',
1619
'ResultAggregator',
1720
'TaskManager',
1821
'TaskStore',
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import logging
2+
3+
import httpx
4+
5+
from a2a.server.tasks.push_notification_config_store import PushNotificationConfigStore
6+
from a2a.server.tasks.push_notification_sender import PushNotificationSender
7+
from a2a.types import Task
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class BasePushNotificationSender(PushNotificationSender):
13+
"""Base implementation of PushNotificationSender interface.
14+
"""
15+
16+
def __init__(self, httpx_client: httpx.AsyncClient, config_store: PushNotificationConfigStore) -> None:
17+
"""Initializes the BasePushNotificationSender.
18+
19+
Args:
20+
httpx_client: An async HTTP client instance to send notifications.
21+
config_store: A PushNotificationConfigStore instance to retrieve configurations.
22+
"""
23+
self._client = httpx_client
24+
self._config_store = config_store
25+
26+
async def send_notification(self, task: Task) -> None:
27+
"""Sends a push notification for a task if configuration exists."""
28+
push_info = await self._config_store.get_info(task.id)
29+
if not push_info:
30+
return
31+
url = push_info.url
32+
33+
try:
34+
response = await self._client.post(
35+
url, json=task.model_dump(mode='json', exclude_none=True)
36+
)
37+
response.raise_for_status()
38+
logger.info(f'Push-notification sent for URL: {url}')
39+
except Exception as e:
40+
logger.error(f'Error sending push-notification: {e}')
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import asyncio
2+
import logging
3+
4+
from a2a.server.tasks.push_notification_config_store import PushNotificationConfigStore
5+
from a2a.types import PushNotificationConfig
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
class InMemoryPushNotificationConfigStore(PushNotificationConfigStore):
11+
"""In-memory implementation of PushNotificationConfigStore interface.
12+
13+
Stores push notification configurations in memory and uses an httpx client
14+
to send notifications.
15+
"""
16+
def __init__(self) -> None:
17+
"""Initializes the InMemoryPushNotifier.
18+
19+
Args:
20+
httpx_client: An async HTTP client instance to send notifications.
21+
"""
22+
self.lock = asyncio.Lock()
23+
self._push_notification_infos: dict[str, PushNotificationConfig] = {}
24+
25+
async def set_info(
26+
self, task_id: str, notification_config: PushNotificationConfig
27+
):
28+
"""Sets or updates the push notification configuration for a task in memory."""
29+
async with self.lock:
30+
self._push_notification_infos[task_id] = notification_config
31+
32+
async def get_info(self, task_id: str) -> PushNotificationConfig | None:
33+
"""Retrieves the push notification configuration for a task from memory."""
34+
async with self.lock:
35+
return self._push_notification_infos.get(task_id)
36+
37+
38+
39+
async def delete_info(self, task_id: str):
40+
"""Deletes the push notification configuration for a task from memory."""
41+
async with self.lock:
42+
if task_id in self._push_notification_infos:
43+
del self._push_notification_infos[task_id]
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from abc import ABC, abstractmethod
2+
3+
from a2a.types import PushNotificationConfig
4+
5+
6+
class PushNotificationConfigStore(ABC):
7+
"""Interface for storing and retrieving push notification configurations for tasks."""
8+
9+
@abstractmethod
10+
async def set_info(self, task_id: str, notification_config: PushNotificationConfig):
11+
"""Sets or updates the push notification configuration for a task."""
12+
13+
@abstractmethod
14+
async def get_info(self, task_id: str) -> PushNotificationConfig | None:
15+
"""Retrieves the push notification configuration for a task."""
16+
17+
@abstractmethod
18+
async def delete_info(self, task_id: str):
19+
"""Deletes the push notification configuration for a task."""
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from abc import ABC, abstractmethod
2+
3+
from a2a.types import Task
4+
5+
6+
class PushNotificationSender(ABC):
7+
"""Interface for sending push notifications for tasks."""
8+
9+
@abstractmethod
10+
async def send_notification(self, task: Task) -> None:
11+
"""Sends a push notification containing the latest task state."""

0 commit comments

Comments
 (0)