Skip to content

Commit d743fc6

Browse files
committed
Unit tests and update lock file
1 parent 3d465d5 commit d743fc6

File tree

4 files changed

+313
-39
lines changed

4 files changed

+313
-39
lines changed

.github/workflows/unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
run: |
4444
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
4545
- name: Install dependencies
46-
run: uv sync --dev --extra sql
46+
run: uv sync --dev --extra sql --extra encryption
4747
- name: Run tests and check coverage
4848
run: uv run pytest --cov=a2a --cov-report=xml --cov-fail-under=89
4949
- name: Show coverage summary in log

src/a2a/server/tasks/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
from a2a.server.tasks.base_push_notification_sender import (
44
BasePushNotificationSender,
55
)
6+
from a2a.server.tasks.database_push_notification_config_store import (
7+
DatabasePushNotificationConfigStore,
8+
)
69
from a2a.server.tasks.database_task_store import DatabaseTaskStore
710
from a2a.server.tasks.inmemory_push_notification_config_store import (
811
InMemoryPushNotificationConfigStore,
@@ -20,6 +23,7 @@
2023

2124
__all__ = [
2225
'BasePushNotificationSender',
26+
'DatabasePushNotificationConfigStore',
2327
'DatabaseTaskStore',
2428
'InMemoryPushNotificationConfigStore',
2529
'InMemoryTaskStore',

tests/server/tasks/test_database_push_notification_config_store.py

Lines changed: 190 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,21 @@
66
import pytest_asyncio
77

88
from _pytest.mark.structures import ParameterSet
9-
# from sqlalchemy.ext.asyncio import (
10-
# AsyncSession,
11-
# async_sessionmaker,
12-
# create_async_engine,
13-
# )
14-
# from sqlalchemy import select
9+
from sqlalchemy.ext.asyncio import (
10+
async_sessionmaker,
11+
create_async_engine,
12+
)
13+
from sqlalchemy import select
1514

1615
# Skip entire test module if SQLAlchemy is not installed
1716
pytest.importorskip('sqlalchemy', reason='Database tests require SQLAlchemy')
17+
pytest.importorskip(
18+
'cryptography',
19+
reason='Database tests require Cryptography. Install extra encryption',
20+
)
1821

1922
# Now safe to import SQLAlchemy-dependent modules
2023
from sqlalchemy.inspection import inspect
21-
from sqlalchemy.ext.asyncio import create_async_engine
2224
from a2a.server.models import (
2325
Base,
2426
PushNotificationConfigModel,
@@ -289,33 +291,184 @@ async def test_delete_info_not_found(
289291
await db_store_parameterized.delete_info('task-1', 'non-existent-config')
290292

291293

292-
# @pytest.mark.asyncio
293-
# async def test_data_is_encrypted_in_db(
294-
# db_store_parameterized: DatabasePushNotificationConfigStore,
295-
# ):
296-
# """Verify that the data stored in the database is actually encrypted."""
297-
# task_id = 'encrypted-task'
298-
# config = PushNotificationConfig(
299-
# id='config-1', url='http://secret.url', token='secret-token'
300-
# )
301-
# plain_json = config.model_dump_json()
302-
303-
# await db_store_parameterized.set_info(task_id, config)
304-
305-
# # Directly query the database to inspect the raw data
306-
# async_session = async_sessionmaker(
307-
# db_store_parameterized.engine, expire_on_commit=False
308-
# )
309-
# async with async_session() as session:
310-
# stmt = select(PushNotificationConfigModel).where(
311-
# PushNotificationConfigModel.task_id == task_id
312-
# )
313-
# result = await session.execute(stmt)
314-
# db_model = result.scalar_one()
315-
316-
# assert db_model.config_data != plain_json.encode('utf-8')
317-
318-
# fernet = db_store_parameterized._fernet
319-
320-
# decrypted_data = fernet.decrypt(db_model.config_data) # type: ignore
321-
# assert decrypted_data.decode('utf-8') == plain_json
294+
@pytest.mark.asyncio
295+
async def test_data_is_encrypted_in_db(
296+
db_store_parameterized: DatabasePushNotificationConfigStore,
297+
):
298+
"""Verify that the data stored in the database is actually encrypted."""
299+
task_id = 'encrypted-task'
300+
config = PushNotificationConfig(
301+
id='config-1', url='http://secret.url', token='secret-token'
302+
)
303+
plain_json = config.model_dump_json()
304+
305+
await db_store_parameterized.set_info(task_id, config)
306+
307+
# Directly query the database to inspect the raw data
308+
async_session = async_sessionmaker(
309+
db_store_parameterized.engine, expire_on_commit=False
310+
)
311+
async with async_session() as session:
312+
stmt = select(PushNotificationConfigModel).where(
313+
PushNotificationConfigModel.task_id == task_id
314+
)
315+
result = await session.execute(stmt)
316+
db_model = result.scalar_one()
317+
318+
assert db_model.config_data != plain_json.encode('utf-8')
319+
320+
fernet = db_store_parameterized._fernet
321+
322+
decrypted_data = fernet.decrypt(db_model.config_data) # type: ignore
323+
assert decrypted_data.decode('utf-8') == plain_json
324+
325+
326+
@pytest.mark.asyncio
327+
async def test_decryption_error_with_wrong_key(
328+
db_store_parameterized: DatabasePushNotificationConfigStore,
329+
):
330+
"""Test that using the wrong key to decrypt raises a ValueError."""
331+
# 1. Store with one key
332+
333+
task_id = 'wrong-key-task'
334+
config = PushNotificationConfig(id='config-1', url='http://secret.url')
335+
await db_store_parameterized.set_info(task_id, config)
336+
337+
# 2. Try to read with a different key
338+
# Directly query the database to inspect the raw data
339+
wrong_key = Fernet.generate_key()
340+
store2 = DatabasePushNotificationConfigStore(
341+
db_store_parameterized.engine, encryption_key=wrong_key
342+
)
343+
344+
retrieved_configs = await store2.get_info(task_id)
345+
assert retrieved_configs == []
346+
347+
# _from_orm should raise a ValueError
348+
async_session = async_sessionmaker(
349+
db_store_parameterized.engine, expire_on_commit=False
350+
)
351+
async with async_session() as session:
352+
db_model = await session.get(
353+
PushNotificationConfigModel, (task_id, 'config-1')
354+
)
355+
356+
with pytest.raises(ValueError) as exc_info:
357+
store2._from_orm(db_model) # type: ignore
358+
359+
360+
@pytest.mark.asyncio
361+
async def test_decryption_error_with_no_key(
362+
db_store_parameterized: DatabasePushNotificationConfigStore,
363+
):
364+
"""Test that using the wrong key to decrypt raises a ValueError."""
365+
# 1. Store with one key
366+
367+
task_id = 'wrong-key-task'
368+
config = PushNotificationConfig(id='config-1', url='http://secret.url')
369+
await db_store_parameterized.set_info(task_id, config)
370+
371+
# 2. Try to read with no key set
372+
# Directly query the database to inspect the raw data
373+
store2 = DatabasePushNotificationConfigStore(db_store_parameterized.engine)
374+
375+
retrieved_configs = await store2.get_info(task_id)
376+
assert retrieved_configs == []
377+
378+
# _from_orm should raise a ValueError
379+
async_session = async_sessionmaker(
380+
db_store_parameterized.engine, expire_on_commit=False
381+
)
382+
async with async_session() as session:
383+
db_model = await session.get(
384+
PushNotificationConfigModel, (task_id, 'config-1')
385+
)
386+
387+
with pytest.raises(ValueError) as exc_info:
388+
store2._from_orm(db_model) # type: ignore
389+
390+
391+
@pytest.mark.asyncio
392+
async def test_custom_table_name(
393+
db_store_parameterized: DatabasePushNotificationConfigStore,
394+
):
395+
"""Test that the store works correctly with a custom table name."""
396+
table_name = 'my_custom_push_configs'
397+
398+
task_id = 'custom-table-task'
399+
config = PushNotificationConfig(id='config-1', url='http://custom.url')
400+
401+
# This will create the table on first use
402+
await db_store_parameterized.set_info(task_id, config)
403+
retrieved_configs = await db_store_parameterized.get_info(task_id)
404+
405+
assert len(retrieved_configs) == 1
406+
assert retrieved_configs[0] == config
407+
408+
# Verify the custom table exists and has data
409+
async with db_store_parameterized.engine.connect() as conn:
410+
result = await conn.execute(
411+
select(db_store_parameterized.config_model).where(
412+
db_store_parameterized.config_model.task_id == task_id
413+
)
414+
)
415+
assert result.scalar_one_or_none() is not None
416+
417+
418+
@pytest.mark.asyncio
419+
async def test_set_and_get_info_multiple_configs_no_key(
420+
db_store_parameterized: DatabasePushNotificationConfigStore,
421+
):
422+
"""Test setting and retrieving multiple configurations for a single task."""
423+
424+
store = DatabasePushNotificationConfigStore(
425+
engine=db_store_parameterized.engine,
426+
create_table=False,
427+
encryption_key=None, # No encryption key
428+
)
429+
await store.initialize()
430+
431+
task_id = 'task-1'
432+
config1 = PushNotificationConfig(id='config-1', url='http://example.com/1')
433+
config2 = PushNotificationConfig(id='config-2', url='http://example.com/2')
434+
435+
await db_store_parameterized.set_info(task_id, config1)
436+
await db_store_parameterized.set_info(task_id, config2)
437+
retrieved_configs = await db_store_parameterized.get_info(task_id)
438+
439+
assert len(retrieved_configs) == 2
440+
assert config1 in retrieved_configs
441+
assert config2 in retrieved_configs
442+
443+
444+
@pytest.mark.asyncio
445+
async def test_data_is_not_encrypted_in_db_if_no_key_is_set(
446+
db_store_parameterized: DatabasePushNotificationConfigStore,
447+
):
448+
"""Test data is not encrypted when no encryption key is set."""
449+
450+
store = DatabasePushNotificationConfigStore(
451+
engine=db_store_parameterized.engine,
452+
create_table=False,
453+
encryption_key=None, # No encryption key
454+
)
455+
await store.initialize()
456+
457+
task_id = 'task-1'
458+
config = PushNotificationConfig(id='config-1', url='http://example.com/1')
459+
plain_json = config.model_dump_json()
460+
461+
await store.set_info(task_id, config)
462+
463+
# Directly query the database to inspect the raw data
464+
async_session = async_sessionmaker(
465+
db_store_parameterized.engine, expire_on_commit=False
466+
)
467+
async with async_session() as session:
468+
stmt = select(PushNotificationConfigModel).where(
469+
PushNotificationConfigModel.task_id == task_id
470+
)
471+
result = await session.execute(stmt)
472+
db_model = result.scalar_one()
473+
474+
assert db_model.config_data == plain_json.encode('utf-8')

0 commit comments

Comments
 (0)