Skip to content

Commit d9f0ff6

Browse files
committed
PushNotificationConfig DB backed model
1 parent 88a9f38 commit d9f0ff6

File tree

3 files changed

+828
-616
lines changed

3 files changed

+828
-616
lines changed

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,22 @@ classifiers = [
3232

3333
[project.optional-dependencies]
3434
postgresql = [
35+
"cryptography>=45.0.3",
3536
"sqlalchemy>=2.0.0",
3637
"asyncpg>=0.30.0",
3738
]
3839
mysql = [
40+
"cryptography>=45.0.3",
3941
"sqlalchemy>=2.0.0",
4042
"aiomysql>=0.2.0",
4143
]
4244
sqlite = [
45+
"cryptography>=45.0.3",
4346
"sqlalchemy>=2.0.0",
4447
"aiosqlite>=0.19.0",
4548
]
4649
sql = [
50+
"cryptography>=45.0.3",
4751
"sqlalchemy>=2.0.0",
4852
"asyncpg>=0.30.0",
4953
"aiomysql>=0.2.0",

src/a2a/server/models.py

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ def override(func): # noqa: D103
1111

1212
from pydantic import BaseModel
1313

14-
from a2a.types import Artifact, Message, TaskStatus
15-
14+
from a2a.types import Artifact, Message, TaskStatus, PushNotificationAuthenticationInfo
1615

1716
try:
18-
from sqlalchemy import JSON, Dialect, String
17+
from cryptography.fernet import Fernet
18+
from sqlalchemy import JSON, Dialect, String, LargeBinary
1919
from sqlalchemy.orm import (
2020
DeclarativeBase,
2121
Mapped,
@@ -36,6 +36,12 @@ def override(func): # noqa: D103
3636

3737
T = TypeVar('T', bound=BaseModel)
3838

39+
_ENCRYPTION_KEY: bytes | None = None
40+
41+
def set_model_encryption_key(key: bytes) -> None:
42+
"""Sets the encryption key used for encrypting model data in the database."""
43+
global _ENCRYPTION_KEY
44+
_ENCRYPTION_KEY = key
3945

4046
class PydanticType(TypeDecorator[T], Generic[T]):
4147
"""SQLAlchemy type that handles Pydantic model serialization."""
@@ -67,6 +73,42 @@ def process_result_value(
6773
return None
6874
return self.pydantic_type.model_validate(value)
6975

76+
class EncryptedPydanticType(TypeDecorator[T], Generic[T]):
77+
"""SQLAlchemy type that handles Pydantic model serialization with encryption."""
78+
79+
impl = LargeBinary # Store encrypted data as binary
80+
cache_ok = True
81+
82+
def __init__(self, pydantic_type: type[T], **kwargs: Any):
83+
super().__init__(**kwargs)
84+
if _ENCRYPTION_KEY is None:
85+
raise RuntimeError(
86+
"Encryption key not set for models. "
87+
"Call a2a.server.models.set_model_encryption_key(key) before model definition or usage."
88+
)
89+
self.pydantic_type = pydantic_type
90+
self.fernet = Fernet(_ENCRYPTION_KEY)
91+
92+
@override
93+
def process_bind_param(
94+
self, value: T | None, dialect: Dialect
95+
) -> bytes | None:
96+
if value is None:
97+
return None
98+
# Pydantic model to JSON string, then encode to bytes for encryption
99+
json_string = value.model_dump_json()
100+
encrypted_data = self.fernet.encrypt(json_string.encode('utf-8'))
101+
return encrypted_data
102+
103+
@override
104+
def process_result_value(
105+
self, value: bytes | None, dialect: Dialect
106+
) -> T | None:
107+
if value is None:
108+
return None
109+
# Decrypt bytes to JSON string, then parse with Pydantic
110+
decrypted_json_string = self.fernet.decrypt(value).decode('utf-8')
111+
return self.pydantic_type.model_validate_json(decrypted_json_string)
70112

71113
class PydanticListType(TypeDecorator[list[T]], Generic[T]):
72114
"""SQLAlchemy type that handles lists of Pydantic models."""
@@ -166,7 +208,7 @@ def create_task_model(
166208
TaskModel = create_task_model('tasks', MyBase)
167209
"""
168210

169-
class TaskModel(TaskMixin, base):
211+
class TaskModel(TaskMixin, base): # type: ignore
170212
__tablename__ = table_name
171213

172214
@override
@@ -192,3 +234,48 @@ class TaskModel(TaskMixin, Base):
192234
"""Default task model with standard table name."""
193235

194236
__tablename__ = 'tasks'
237+
238+
239+
class PushNotificationConfigMixin:
240+
"""Mixin providing standard columns for push notification configuration."""
241+
id: Mapped[str] = mapped_column(String, primary_key=True, index=True)
242+
task_id: Mapped[str] = mapped_column(String, nullable=False)
243+
url: Mapped[str] = mapped_column(String, nullable=False)
244+
token: Mapped[str | None] = mapped_column(String, nullable=True)
245+
authentication: Mapped[PushNotificationAuthenticationInfo | None] = mapped_column(
246+
EncryptedPydanticType(PushNotificationAuthenticationInfo), nullable=True
247+
)
248+
249+
@override
250+
def __repr__(self) -> str:
251+
"""Return a string representation of the push notification config."""
252+
repr_template = (
253+
'<{CLS}(id="{ID}", task_id="{TASK_ID}", url="{URL}")>'
254+
)
255+
return repr_template.format(
256+
CLS=self.__class__.__name__,
257+
ID=self.id,
258+
TASK_ID=self.task_id,
259+
URL=self.url,
260+
)
261+
262+
def create_push_notification_config_model(
263+
table_name: str = 'push_notification_configs',
264+
base: type[DeclarativeBase] = Base
265+
) -> type:
266+
"""Create a PushNotificationConfigModel class with a configurable table name.
267+
268+
Args:
269+
table_name: Name of the database table. Defaults to 'push_notification_configs'.
270+
base_cls: Base declarative class to use. Defaults to the SDK's Base class.
271+
272+
Returns:
273+
PushNotificationConfigModel class with the specified table name.
274+
"""
275+
276+
class PushNotificationConfigModel(PushNotificationConfigMixin, base): # type: ignore
277+
__tablename__ = table_name
278+
279+
PushNotificationConfigModel.__name__ = f'PushNotificationConfigModel_{table_name}'
280+
PushNotificationConfigModel.__qualname__ = f'PushNotificationConfigModel_{table_name}'
281+
return PushNotificationConfigModel

0 commit comments

Comments
 (0)