@@ -11,11 +11,11 @@ def override(func): # noqa: D103
1111
1212from pydantic import BaseModel
1313
14- from a2a .types import Artifact , Message , TaskStatus
15-
14+ from a2a .types import Artifact , Message , TaskStatus , PushNotificationAuthenticationInfo
1615
1716try :
18- from sqlalchemy import JSON , Dialect , String
17+ from cryptography .fernet import Fernet
18+ from sqlalchemy import JSON , Dialect , String , LargeBinary
1919 from sqlalchemy .orm import (
2020 DeclarativeBase ,
2121 Mapped ,
@@ -36,6 +36,12 @@ def override(func): # noqa: D103
3636
3737T = TypeVar ('T' , bound = BaseModel )
3838
39+ _ENCRYPTION_KEY : bytes | None = None
40+
41+ def set_model_encryption_key (key : bytes ) -> None :
42+ """Sets the encryption key used for encrypting model data in the database."""
43+ global _ENCRYPTION_KEY
44+ _ENCRYPTION_KEY = key
3945
4046class PydanticType (TypeDecorator [T ], Generic [T ]):
4147 """SQLAlchemy type that handles Pydantic model serialization."""
@@ -67,6 +73,42 @@ def process_result_value(
6773 return None
6874 return self .pydantic_type .model_validate (value )
6975
76+ class EncryptedPydanticType (TypeDecorator [T ], Generic [T ]):
77+ """SQLAlchemy type that handles Pydantic model serialization with encryption."""
78+
79+ impl = LargeBinary # Store encrypted data as binary
80+ cache_ok = True
81+
82+ def __init__ (self , pydantic_type : type [T ], ** kwargs : Any ):
83+ super ().__init__ (** kwargs )
84+ if _ENCRYPTION_KEY is None :
85+ raise RuntimeError (
86+ "Encryption key not set for models. "
87+ "Call a2a.server.models.set_model_encryption_key(key) before model definition or usage."
88+ )
89+ self .pydantic_type = pydantic_type
90+ self .fernet = Fernet (_ENCRYPTION_KEY )
91+
92+ @override
93+ def process_bind_param (
94+ self , value : T | None , dialect : Dialect
95+ ) -> bytes | None :
96+ if value is None :
97+ return None
98+ # Pydantic model to JSON string, then encode to bytes for encryption
99+ json_string = value .model_dump_json ()
100+ encrypted_data = self .fernet .encrypt (json_string .encode ('utf-8' ))
101+ return encrypted_data
102+
103+ @override
104+ def process_result_value (
105+ self , value : bytes | None , dialect : Dialect
106+ ) -> T | None :
107+ if value is None :
108+ return None
109+ # Decrypt bytes to JSON string, then parse with Pydantic
110+ decrypted_json_string = self .fernet .decrypt (value ).decode ('utf-8' )
111+ return self .pydantic_type .model_validate_json (decrypted_json_string )
70112
71113class PydanticListType (TypeDecorator [list [T ]], Generic [T ]):
72114 """SQLAlchemy type that handles lists of Pydantic models."""
@@ -166,7 +208,7 @@ def create_task_model(
166208 TaskModel = create_task_model('tasks', MyBase)
167209 """
168210
169- class TaskModel (TaskMixin , base ):
211+ class TaskModel (TaskMixin , base ): # type: ignore
170212 __tablename__ = table_name
171213
172214 @override
@@ -192,3 +234,48 @@ class TaskModel(TaskMixin, Base):
192234 """Default task model with standard table name."""
193235
194236 __tablename__ = 'tasks'
237+
238+
239+ class PushNotificationConfigMixin :
240+ """Mixin providing standard columns for push notification configuration."""
241+ id : Mapped [str ] = mapped_column (String , primary_key = True , index = True )
242+ task_id : Mapped [str ] = mapped_column (String , nullable = False )
243+ url : Mapped [str ] = mapped_column (String , nullable = False )
244+ token : Mapped [str | None ] = mapped_column (String , nullable = True )
245+ authentication : Mapped [PushNotificationAuthenticationInfo | None ] = mapped_column (
246+ EncryptedPydanticType (PushNotificationAuthenticationInfo ), nullable = True
247+ )
248+
249+ @override
250+ def __repr__ (self ) -> str :
251+ """Return a string representation of the push notification config."""
252+ repr_template = (
253+ '<{CLS}(id="{ID}", task_id="{TASK_ID}", url="{URL}")>'
254+ )
255+ return repr_template .format (
256+ CLS = self .__class__ .__name__ ,
257+ ID = self .id ,
258+ TASK_ID = self .task_id ,
259+ URL = self .url ,
260+ )
261+
262+ def create_push_notification_config_model (
263+ table_name : str = 'push_notification_configs' ,
264+ base : type [DeclarativeBase ] = Base
265+ ) -> type :
266+ """Create a PushNotificationConfigModel class with a configurable table name.
267+
268+ Args:
269+ table_name: Name of the database table. Defaults to 'push_notification_configs'.
270+ base_cls: Base declarative class to use. Defaults to the SDK's Base class.
271+
272+ Returns:
273+ PushNotificationConfigModel class with the specified table name.
274+ """
275+
276+ class PushNotificationConfigModel (PushNotificationConfigMixin , base ): # type: ignore
277+ __tablename__ = table_name
278+
279+ PushNotificationConfigModel .__name__ = f'PushNotificationConfigModel_{ table_name } '
280+ PushNotificationConfigModel .__qualname__ = f'PushNotificationConfigModel_{ table_name } '
281+ return PushNotificationConfigModel
0 commit comments