Skip to content

Commit b772090

Browse files
committed
feat: Support for Database based Push Config Store
1 parent 9d6cb68 commit b772090

File tree

4 files changed

+316
-2
lines changed

4 files changed

+316
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"]
4141
mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"]
4242
sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"]
4343
sql = ["sqlalchemy[asyncio,postgresql-asyncpg,aiomysql,aiosqlite]>=2.0.0"]
44+
encryption = ["cryptography>=43.0.0"]
4445

4546
[project.urls]
4647
homepage = "https://a2aproject.github.io/A2A/"

src/a2a/server/models.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def override(func): # noqa: ANN001, ANN201
1616

1717

1818
try:
19-
from sqlalchemy import JSON, Dialect, String
19+
from sqlalchemy import JSON, Dialect, LargeBinary, String
2020
from sqlalchemy.orm import (
2121
DeclarativeBase,
2222
Mapped,
@@ -208,3 +208,58 @@ class TaskModel(TaskMixin, Base):
208208
"""Default task model with standard table name."""
209209

210210
__tablename__ = 'tasks'
211+
212+
213+
# PushNotificationConfigMixin that can be used with any table name
214+
class PushNotificationConfigMixin:
215+
"""Mixin providing standard push notification config columns."""
216+
217+
task_id: Mapped[str] = mapped_column(String(36), primary_key=True)
218+
config_id: Mapped[str] = mapped_column(String(255), primary_key=True)
219+
config_data: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
220+
221+
@override
222+
def __repr__(self) -> str:
223+
"""Return a string representation of the push notification config."""
224+
repr_template = '<{CLS}(task_id="{TID}", config_id="{CID}")>'
225+
return repr_template.format(
226+
CLS=self.__class__.__name__,
227+
TID=self.task_id,
228+
CID=self.config_id,
229+
)
230+
231+
232+
def create_push_notification_config_model(
233+
table_name: str = 'push_notification_configs',
234+
base: type[DeclarativeBase] = Base,
235+
) -> type:
236+
"""Create a PushNotificationConfigModel class with a configurable table name."""
237+
238+
class PushNotificationConfigModel(PushNotificationConfigMixin, base):
239+
__tablename__ = table_name
240+
241+
@override
242+
def __repr__(self) -> str:
243+
"""Return a string representation of the push notification config."""
244+
repr_template = '<PushNotificationConfigModel[{TABLE}](task_id="{TID}", config_id="{CID}")>'
245+
return repr_template.format(
246+
TABLE=table_name,
247+
TID=self.task_id,
248+
CID=self.config_id,
249+
)
250+
251+
PushNotificationConfigModel.__name__ = (
252+
f'PushNotificationConfigModel_{table_name}'
253+
)
254+
PushNotificationConfigModel.__qualname__ = (
255+
f'PushNotificationConfigModel_{table_name}'
256+
)
257+
258+
return PushNotificationConfigModel
259+
260+
261+
# Default PushNotificationConfigModel for backward compatibility
262+
class PushNotificationConfigModel(PushNotificationConfigMixin, Base):
263+
"""Default push notification config model with standard table name."""
264+
265+
__tablename__ = 'push_notification_configs'

src/a2a/server/tasks/base_push_notification_sender.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,13 @@ async def _dispatch_notification(
5252
) -> bool:
5353
url = push_info.url
5454
try:
55+
headers = None
56+
if push_info.token:
57+
headers = {'X-A2A-Notification-Token': push_info.token}
5558
response = await self._client.post(
56-
url, json=task.model_dump(mode='json', exclude_none=True)
59+
url,
60+
json=task.model_dump(mode='json', exclude_none=True),
61+
headers=headers,
5762
)
5863
response.raise_for_status()
5964
logger.info(
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
import json
2+
import logging
3+
4+
from typing import TYPE_CHECKING
5+
6+
7+
try:
8+
from sqlalchemy import (
9+
delete,
10+
select,
11+
)
12+
from sqlalchemy.ext.asyncio import (
13+
AsyncEngine,
14+
AsyncSession,
15+
async_sessionmaker,
16+
)
17+
except ImportError as e:
18+
raise ImportError(
19+
'DatabasePushNotificationConfigStore requires SQLAlchemy and a database driver. '
20+
'Install with one of: '
21+
"'pip install a2a-sdk[postgresql]', "
22+
"'pip install a2a-sdk[mysql]', "
23+
"'pip install a2a-sdk[sqlite]', "
24+
"or 'pip install a2a-sdk[sql]'"
25+
) from e
26+
27+
from a2a.server.models import (
28+
Base,
29+
PushNotificationConfigModel,
30+
create_push_notification_config_model,
31+
)
32+
from a2a.server.tasks.push_notification_config_store import (
33+
PushNotificationConfigStore,
34+
)
35+
from a2a.types import PushNotificationConfig
36+
37+
38+
if TYPE_CHECKING:
39+
from cryptography.fernet import Fernet
40+
41+
42+
logger = logging.getLogger(__name__)
43+
44+
45+
class DatabasePushNotificationConfigStore(PushNotificationConfigStore):
46+
"""SQLAlchemy-based implementation of PushNotificationConfigStore.
47+
48+
Stores push notification configurations in a database supported by SQLAlchemy.
49+
"""
50+
51+
engine: AsyncEngine
52+
async_session_maker: async_sessionmaker[AsyncSession]
53+
create_table: bool
54+
_initialized: bool
55+
config_model: type[PushNotificationConfigModel]
56+
_fernet: 'Fernet | None'
57+
58+
def __init__(
59+
self,
60+
engine: AsyncEngine,
61+
create_table: bool = True,
62+
table_name: str = 'push_notification_configs',
63+
encryption_key: str | bytes | None = None,
64+
) -> None:
65+
"""Initializes the DatabasePushNotificationConfigStore.
66+
67+
Args:
68+
engine: An existing SQLAlchemy AsyncEngine to be used by the store.
69+
create_table: If true, create the table on initialization.
70+
table_name: Name of the database table. Defaults to 'push_notification_configs'.
71+
encryption_key: A key for encrypting sensitive configuration data.
72+
If provided, `config_data` will be encrypted in the database.
73+
The key must be a URL-safe base64-encoded 32-byte key.
74+
"""
75+
logger.debug(
76+
f'Initializing DatabasePushNotificationConfigStore with existing engine, table: {table_name}'
77+
)
78+
self.engine = engine
79+
self.async_session_maker = async_sessionmaker(
80+
self.engine, expire_on_commit=False
81+
)
82+
self.create_table = create_table
83+
self._initialized = False
84+
self.config_model = (
85+
PushNotificationConfigModel
86+
if table_name == 'push_notification_configs'
87+
else create_push_notification_config_model(table_name)
88+
)
89+
self._fernet = None
90+
91+
if encryption_key:
92+
try:
93+
from cryptography.fernet import Fernet # noqa: PLC0415
94+
except ImportError as e:
95+
raise ImportError(
96+
"DatabasePushNotificationConfigStore with encryption requires the 'cryptography' "
97+
'library. Install with: '
98+
"'pip install a2a-sdk[encryption]'"
99+
) from e
100+
101+
if isinstance(encryption_key, str):
102+
encryption_key = encryption_key.encode('utf-8')
103+
self._fernet = Fernet(encryption_key)
104+
logger.debug(
105+
'Encryption enabled for push notification config store.'
106+
)
107+
108+
async def initialize(self) -> None:
109+
"""Initialize the database and create the table if needed."""
110+
if self._initialized:
111+
return
112+
113+
logger.debug(
114+
'Initializing database schema for push notification configs...'
115+
)
116+
if self.create_table:
117+
async with self.engine.begin() as conn:
118+
await conn.run_sync(Base.metadata.create_all)
119+
self._initialized = True
120+
logger.debug(
121+
'Database schema for push notification configs initialized.'
122+
)
123+
124+
async def _ensure_initialized(self) -> None:
125+
"""Ensure the database connection is initialized."""
126+
if not self._initialized:
127+
await self.initialize()
128+
129+
def _to_orm(
130+
self, task_id: str, config: PushNotificationConfig
131+
) -> PushNotificationConfigModel:
132+
"""Maps a Pydantic PushNotificationConfig to a SQLAlchemy model instance.
133+
134+
The config data is serialized to JSON bytes, and encrypted if a key is configured.
135+
"""
136+
json_payload = config.model_dump_json().encode('utf-8')
137+
138+
if self._fernet:
139+
data_to_store = self._fernet.encrypt(json_payload)
140+
else:
141+
data_to_store = json_payload
142+
143+
return self.config_model(
144+
task_id=task_id,
145+
config_id=config.id,
146+
config_data=data_to_store,
147+
)
148+
149+
def _from_orm(
150+
self, model_instance: PushNotificationConfigModel
151+
) -> PushNotificationConfig:
152+
"""Maps a SQLAlchemy model instance to a Pydantic PushNotificationConfig.
153+
154+
Handles decryption if a key is configured.
155+
"""
156+
payload = model_instance.config_data
157+
158+
if self._fernet:
159+
from cryptography.fernet import InvalidToken # noqa: PLC0415
160+
161+
try:
162+
decrypted_payload = self._fernet.decrypt(payload)
163+
return PushNotificationConfig.model_validate_json(
164+
decrypted_payload
165+
)
166+
except InvalidToken:
167+
# This could be unencrypted data if encryption was enabled after data was stored.
168+
# We'll fall through and try to parse it as plain JSON.
169+
logger.debug(
170+
'Could not decrypt config for task %s, config %s. '
171+
'Attempting to parse as unencrypted JSON.',
172+
model_instance.task_id,
173+
model_instance.config_id,
174+
)
175+
176+
# If no fernet or if decryption failed, try to parse as plain JSON.
177+
try:
178+
return PushNotificationConfig.model_validate_json(payload)
179+
except json.JSONDecodeError as e:
180+
if self._fernet:
181+
raise ValueError(
182+
'Failed to decrypt data; incorrect key or corrupted data.'
183+
) from e
184+
raise ValueError(
185+
'Failed to parse data; it may be encrypted but no key is configured.'
186+
) from e
187+
188+
async def set_info(
189+
self, task_id: str, notification_config: PushNotificationConfig
190+
) -> None:
191+
"""Sets or updates the push notification configuration for a task."""
192+
await self._ensure_initialized()
193+
194+
config_to_save = notification_config.model_copy()
195+
if config_to_save.id is None:
196+
config_to_save.id = task_id
197+
198+
db_config = self._to_orm(task_id, config_to_save)
199+
async with self.async_session_maker.begin() as session:
200+
await session.merge(db_config)
201+
logger.debug(
202+
f'Push notification config for task {task_id} with config id {config_to_save.id} saved/updated.'
203+
)
204+
205+
async def get_info(self, task_id: str) -> list[PushNotificationConfig]:
206+
"""Retrieves all push notification configurations for a task."""
207+
await self._ensure_initialized()
208+
async with self.async_session_maker() as session:
209+
stmt = select(self.config_model).where(
210+
self.config_model.task_id == task_id
211+
)
212+
result = await session.execute(stmt)
213+
models = result.scalars().all()
214+
215+
configs = []
216+
for model in models:
217+
try:
218+
configs.append(self._from_orm(model))
219+
except ValueError as e:
220+
logger.error(
221+
'Could not deserialize push notification config for task %s, config %s: %s',
222+
model.task_id,
223+
model.config_id,
224+
e,
225+
)
226+
return configs
227+
228+
async def delete_info(
229+
self, task_id: str, config_id: str | None = None
230+
) -> None:
231+
"""Deletes push notification configurations for a task.
232+
233+
If config_id is provided, only that specific configuration is deleted.
234+
If config_id is None, all configurations for the task are deleted.
235+
"""
236+
await self._ensure_initialized()
237+
async with self.async_session_maker.begin() as session:
238+
stmt = delete(self.config_model).where(
239+
self.config_model.task_id == task_id
240+
)
241+
if config_id is not None:
242+
stmt = stmt.where(self.config_model.config_id == config_id)
243+
244+
result = await session.execute(stmt)
245+
246+
if result.rowcount > 0:
247+
logger.info(
248+
f'Deleted {result.rowcount} push notification config(s) for task {task_id}.'
249+
)
250+
else:
251+
logger.warning(
252+
f'Attempted to delete non-existent push notification config for task {task_id} with config_id: {config_id}'
253+
)

0 commit comments

Comments
 (0)