Skip to content

Commit 7871379

Browse files
committed
Fix for custom table names
1 parent d743fc6 commit 7871379

File tree

3 files changed

+162
-33
lines changed

3 files changed

+162
-33
lines changed

src/a2a/server/tasks/database_push_notification_config_store.py

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33

44
from typing import TYPE_CHECKING
55

6+
from pydantic import ValidationError
7+
68

79
try:
810
from sqlalchemy import (
11+
Table,
912
delete,
1013
select,
1114
)
@@ -14,6 +17,7 @@
1417
AsyncSession,
1518
async_sessionmaker,
1619
)
20+
from sqlalchemy.orm import class_mapper
1721
except ImportError as e:
1822
raise ImportError(
1923
'DatabasePushNotificationConfigStore requires SQLAlchemy and a database driver. '
@@ -115,7 +119,13 @@ async def initialize(self) -> None:
115119
)
116120
if self.create_table:
117121
async with self.engine.begin() as conn:
118-
await conn.run_sync(Base.metadata.create_all)
122+
mapper = class_mapper(self.config_model)
123+
tables_to_create = [
124+
table for table in mapper.tables if isinstance(table, Table)
125+
]
126+
await conn.run_sync(
127+
Base.metadata.create_all, tables=tables_to_create
128+
)
119129
self._initialized = True
120130
logger.debug(
121131
'Database schema for push notification configs initialized.'
@@ -151,7 +161,7 @@ def _from_orm(
151161
) -> PushNotificationConfig:
152162
"""Maps a SQLAlchemy model instance to a Pydantic PushNotificationConfig.
153163
154-
Handles decryption if a key is configured.
164+
Handles decryption if a key is configured, with a fallback to plain JSON.
155165
"""
156166
payload = model_instance.config_data
157167

@@ -163,26 +173,51 @@ def _from_orm(
163173
return PushNotificationConfig.model_validate_json(
164174
decrypted_payload
165175
)
176+
except (json.JSONDecodeError, ValidationError) as e:
177+
logger.error(
178+
'Failed to parse decrypted push notification config for task %s, config %s. '
179+
'Data is corrupted or not valid JSON after decryption.',
180+
model_instance.task_id,
181+
model_instance.config_id,
182+
)
183+
raise ValueError(
184+
'Failed to parse decrypted push notification config data'
185+
) from e
166186
except InvalidToken:
167-
# This could be unencrypted data if encryption was enabled after data was stored.
168-
# We'll fall through and try to parse it as plain JSON.
169-
logger.debug(
170-
'Could not decrypt config for task %s, config %s. '
171-
'Attempting to parse as unencrypted JSON.',
187+
# Decryption failed. This could be because the data is not encrypted.
188+
# We'll log a warning and try to parse it as plain JSON as a fallback.
189+
logger.warning(
190+
'Failed to decrypt push notification config for task %s, config %s. '
191+
'Attempting to parse as unencrypted JSON. '
192+
'This may indicate an incorrect encryption key or unencrypted data in the database.',
172193
model_instance.task_id,
173194
model_instance.config_id,
174195
)
196+
# Fall through to the unencrypted parsing logic below.
175197

176-
# If no fernet or if decryption failed, try to parse as plain JSON.
198+
# Try to parse as plain JSON.
177199
try:
178200
return PushNotificationConfig.model_validate_json(payload)
179-
except json.JSONDecodeError as e:
201+
except (json.JSONDecodeError, ValidationError) as e:
180202
if self._fernet:
181-
raise ValueError(
182-
'Failed to decrypt data; incorrect key or corrupted data.'
183-
) from e
203+
logger.error(
204+
'Failed to parse push notification config for task %s, config %s. '
205+
'Decryption failed and the data is not valid JSON. '
206+
'This likely indicates the data is corrupted or encrypted with a different key.',
207+
model_instance.task_id,
208+
model_instance.config_id,
209+
)
210+
else:
211+
# if no key is configured and the payload is not valid JSON.
212+
logger.error(
213+
'Failed to parse push notification config for task %s, config %s. '
214+
'Data is not valid JSON and no encryption key is configured.',
215+
model_instance.task_id,
216+
model_instance.config_id,
217+
)
184218
raise ValueError(
185-
'Failed to parse data; it may be encrypted but no key is configured.'
219+
'Failed to parse push notification config data. '
220+
'Data is not valid JSON, or it is encrypted with the wrong key.'
186221
) from e
187222

188223
async def set_info(

src/a2a/server/tasks/database_task_store.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33

44
try:
5-
from sqlalchemy import delete, select
5+
from sqlalchemy import Table, delete, select
66
from sqlalchemy.ext.asyncio import (
77
AsyncEngine,
88
AsyncSession,
99
async_sessionmaker,
1010
)
11+
from sqlalchemy.orm import class_mapper
1112
except ImportError as e:
1213
raise ImportError(
1314
'DatabaseTaskStore requires SQLAlchemy and a database driver. '
@@ -75,8 +76,13 @@ async def initialize(self) -> None:
7576
logger.debug('Initializing database schema...')
7677
if self.create_table:
7778
async with self.engine.begin() as conn:
78-
# This will create the 'tasks' table based on TaskModel's definition
79-
await conn.run_sync(Base.metadata.create_all)
79+
mapper = class_mapper(self.task_model)
80+
tables_to_create = [
81+
table for table in mapper.tables if isinstance(table, Table)
82+
]
83+
await conn.run_sync(
84+
Base.metadata.create_all, tables=tables_to_create
85+
)
8086
self._initialized = True
8187
logger.debug('Database schema initialized.')
8288

tests/server/tasks/test_database_push_notification_config_store.py

Lines changed: 105 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -394,25 +394,47 @@ async def test_custom_table_name(
394394
):
395395
"""Test that the store works correctly with a custom table name."""
396396
table_name = 'my_custom_push_configs'
397+
engine = db_store_parameterized.engine
398+
custom_store = None
399+
try:
400+
# Use a new store with a custom table name
401+
custom_store = DatabasePushNotificationConfigStore(
402+
engine=engine,
403+
create_table=True,
404+
table_name=table_name,
405+
encryption_key=Fernet.generate_key(),
406+
)
397407

398-
task_id = 'custom-table-task'
399-
config = PushNotificationConfig(id='config-1', url='http://custom.url')
408+
task_id = 'custom-table-task'
409+
config = PushNotificationConfig(id='config-1', url='http://custom.url')
400410

401-
# This will create the table on first use
402-
await db_store_parameterized.set_info(task_id, config)
403-
retrieved_configs = await db_store_parameterized.get_info(task_id)
411+
# This will create the table on first use
412+
await custom_store.set_info(task_id, config)
413+
retrieved_configs = await custom_store.get_info(task_id)
404414

405-
assert len(retrieved_configs) == 1
406-
assert retrieved_configs[0] == config
415+
assert len(retrieved_configs) == 1
416+
assert retrieved_configs[0] == config
407417

408-
# Verify the custom table exists and has data
409-
async with db_store_parameterized.engine.connect() as conn:
410-
result = await conn.execute(
411-
select(db_store_parameterized.config_model).where(
412-
db_store_parameterized.config_model.task_id == task_id
418+
# Verify the custom table exists and has data
419+
async with custom_store.engine.connect() as conn:
420+
421+
def has_table_sync(sync_conn):
422+
inspector = inspect(sync_conn)
423+
return inspector.has_table(table_name)
424+
425+
assert await conn.run_sync(has_table_sync)
426+
427+
result = await conn.execute(
428+
select(custom_store.config_model).where(
429+
custom_store.config_model.task_id == task_id
430+
)
413431
)
414-
)
415-
assert result.scalar_one_or_none() is not None
432+
assert result.scalar_one_or_none() is not None
433+
finally:
434+
if custom_store:
435+
# Clean up the dynamically created table from the metadata
436+
# to prevent errors in subsequent parameterized test runs.
437+
Base.metadata.remove(custom_store.config_model.__table__) # type: ignore
416438

417439

418440
@pytest.mark.asyncio
@@ -432,9 +454,9 @@ async def test_set_and_get_info_multiple_configs_no_key(
432454
config1 = PushNotificationConfig(id='config-1', url='http://example.com/1')
433455
config2 = PushNotificationConfig(id='config-2', url='http://example.com/2')
434456

435-
await db_store_parameterized.set_info(task_id, config1)
436-
await db_store_parameterized.set_info(task_id, config2)
437-
retrieved_configs = await db_store_parameterized.get_info(task_id)
457+
await store.set_info(task_id, config1)
458+
await store.set_info(task_id, config2)
459+
retrieved_configs = await store.get_info(task_id)
438460

439461
assert len(retrieved_configs) == 2
440462
assert config1 in retrieved_configs
@@ -472,3 +494,69 @@ async def test_data_is_not_encrypted_in_db_if_no_key_is_set(
472494
db_model = result.scalar_one()
473495

474496
assert db_model.config_data == plain_json.encode('utf-8')
497+
498+
499+
@pytest.mark.asyncio
500+
async def test_decryption_fallback_for_unencrypted_data(
501+
db_store_parameterized: DatabasePushNotificationConfigStore,
502+
):
503+
"""Test reading unencrypted data with an encryption-enabled store."""
504+
# 1. Store unencrypted data using a new store instance without a key
505+
unencrypted_store = DatabasePushNotificationConfigStore(
506+
engine=db_store_parameterized.engine,
507+
create_table=False, # Table already exists from fixture
508+
encryption_key=None,
509+
)
510+
await unencrypted_store.initialize()
511+
512+
task_id = 'mixed-encryption-task'
513+
config = PushNotificationConfig(id='config-1', url='http://plain.url')
514+
await unencrypted_store.set_info(task_id, config)
515+
516+
# 2. Try to read with the encryption-enabled store from the fixture
517+
retrieved_configs = await db_store_parameterized.get_info(task_id)
518+
519+
# Should fall back to parsing as plain JSON and not fail
520+
assert len(retrieved_configs) == 1
521+
assert retrieved_configs[0] == config
522+
523+
524+
@pytest.mark.asyncio
525+
async def test_parsing_error_after_successful_decryption(
526+
db_store_parameterized: DatabasePushNotificationConfigStore,
527+
):
528+
"""Test that a parsing error after successful decryption is handled."""
529+
530+
task_id = 'corrupted-data-task'
531+
config_id = 'config-1'
532+
533+
# 1. Encrypt data that is NOT valid JSON
534+
fernet = Fernet(Fernet.generate_key())
535+
corrupted_payload = b'this is not valid json'
536+
encrypted_data = fernet.encrypt(corrupted_payload)
537+
538+
# 2. Manually insert this corrupted data into the DB
539+
async_session = async_sessionmaker(
540+
db_store_parameterized.engine, expire_on_commit=False
541+
)
542+
async with async_session() as session:
543+
db_model = PushNotificationConfigModel(
544+
task_id=task_id,
545+
config_id=config_id,
546+
config_data=encrypted_data,
547+
)
548+
session.add(db_model)
549+
await session.commit()
550+
551+
# 3. get_info should log an error and return an empty list
552+
retrieved_configs = await db_store_parameterized.get_info(task_id)
553+
assert retrieved_configs == []
554+
555+
# 4. _from_orm should raise a ValueError
556+
async with async_session() as session:
557+
db_model_retrieved = await session.get(
558+
PushNotificationConfigModel, (task_id, config_id)
559+
)
560+
561+
with pytest.raises(ValueError) as exc_info:
562+
db_store_parameterized._from_orm(db_model_retrieved) # type: ignore

0 commit comments

Comments
 (0)