|
6 | 6 | import pytest_asyncio |
7 | 7 |
|
8 | 8 | from _pytest.mark.structures import ParameterSet |
9 | | -# from sqlalchemy.ext.asyncio import ( |
10 | | -# AsyncSession, |
11 | | -# async_sessionmaker, |
12 | | -# create_async_engine, |
13 | | -# ) |
14 | | -# from sqlalchemy import select |
| 9 | +from sqlalchemy.ext.asyncio import ( |
| 10 | + async_sessionmaker, |
| 11 | + create_async_engine, |
| 12 | +) |
| 13 | +from sqlalchemy import select |
15 | 14 |
|
16 | 15 | # Skip entire test module if SQLAlchemy is not installed |
17 | 16 | pytest.importorskip('sqlalchemy', reason='Database tests require SQLAlchemy') |
| 17 | +pytest.importorskip( |
| 18 | + 'cryptography', |
| 19 | + reason='Database tests require Cryptography. Install extra encryption', |
| 20 | +) |
18 | 21 |
|
19 | 22 | # Now safe to import SQLAlchemy-dependent modules |
20 | 23 | from sqlalchemy.inspection import inspect |
21 | | -from sqlalchemy.ext.asyncio import create_async_engine |
22 | 24 | from a2a.server.models import ( |
23 | 25 | Base, |
24 | 26 | PushNotificationConfigModel, |
@@ -289,33 +291,184 @@ async def test_delete_info_not_found( |
289 | 291 | await db_store_parameterized.delete_info('task-1', 'non-existent-config') |
290 | 292 |
|
291 | 293 |
|
292 | | -# @pytest.mark.asyncio |
293 | | -# async def test_data_is_encrypted_in_db( |
294 | | -# db_store_parameterized: DatabasePushNotificationConfigStore, |
295 | | -# ): |
296 | | -# """Verify that the data stored in the database is actually encrypted.""" |
297 | | -# task_id = 'encrypted-task' |
298 | | -# config = PushNotificationConfig( |
299 | | -# id='config-1', url='http://secret.url', token='secret-token' |
300 | | -# ) |
301 | | -# plain_json = config.model_dump_json() |
302 | | - |
303 | | -# await db_store_parameterized.set_info(task_id, config) |
304 | | - |
305 | | -# # Directly query the database to inspect the raw data |
306 | | -# async_session = async_sessionmaker( |
307 | | -# db_store_parameterized.engine, expire_on_commit=False |
308 | | -# ) |
309 | | -# async with async_session() as session: |
310 | | -# stmt = select(PushNotificationConfigModel).where( |
311 | | -# PushNotificationConfigModel.task_id == task_id |
312 | | -# ) |
313 | | -# result = await session.execute(stmt) |
314 | | -# db_model = result.scalar_one() |
315 | | - |
316 | | -# assert db_model.config_data != plain_json.encode('utf-8') |
317 | | - |
318 | | -# fernet = db_store_parameterized._fernet |
319 | | - |
320 | | -# decrypted_data = fernet.decrypt(db_model.config_data) # type: ignore |
321 | | -# assert decrypted_data.decode('utf-8') == plain_json |
| 294 | +@pytest.mark.asyncio |
| 295 | +async def test_data_is_encrypted_in_db( |
| 296 | + db_store_parameterized: DatabasePushNotificationConfigStore, |
| 297 | +): |
| 298 | + """Verify that the data stored in the database is actually encrypted.""" |
| 299 | + task_id = 'encrypted-task' |
| 300 | + config = PushNotificationConfig( |
| 301 | + id='config-1', url='http://secret.url', token='secret-token' |
| 302 | + ) |
| 303 | + plain_json = config.model_dump_json() |
| 304 | + |
| 305 | + await db_store_parameterized.set_info(task_id, config) |
| 306 | + |
| 307 | + # Directly query the database to inspect the raw data |
| 308 | + async_session = async_sessionmaker( |
| 309 | + db_store_parameterized.engine, expire_on_commit=False |
| 310 | + ) |
| 311 | + async with async_session() as session: |
| 312 | + stmt = select(PushNotificationConfigModel).where( |
| 313 | + PushNotificationConfigModel.task_id == task_id |
| 314 | + ) |
| 315 | + result = await session.execute(stmt) |
| 316 | + db_model = result.scalar_one() |
| 317 | + |
| 318 | + assert db_model.config_data != plain_json.encode('utf-8') |
| 319 | + |
| 320 | + fernet = db_store_parameterized._fernet |
| 321 | + |
| 322 | + decrypted_data = fernet.decrypt(db_model.config_data) # type: ignore |
| 323 | + assert decrypted_data.decode('utf-8') == plain_json |
| 324 | + |
| 325 | + |
| 326 | +@pytest.mark.asyncio |
| 327 | +async def test_decryption_error_with_wrong_key( |
| 328 | + db_store_parameterized: DatabasePushNotificationConfigStore, |
| 329 | +): |
| 330 | + """Test that using the wrong key to decrypt raises a ValueError.""" |
| 331 | + # 1. Store with one key |
| 332 | + |
| 333 | + task_id = 'wrong-key-task' |
| 334 | + config = PushNotificationConfig(id='config-1', url='http://secret.url') |
| 335 | + await db_store_parameterized.set_info(task_id, config) |
| 336 | + |
| 337 | + # 2. Try to read with a different key |
| 338 | + # Directly query the database to inspect the raw data |
| 339 | + wrong_key = Fernet.generate_key() |
| 340 | + store2 = DatabasePushNotificationConfigStore( |
| 341 | + db_store_parameterized.engine, encryption_key=wrong_key |
| 342 | + ) |
| 343 | + |
| 344 | + retrieved_configs = await store2.get_info(task_id) |
| 345 | + assert retrieved_configs == [] |
| 346 | + |
| 347 | + # _from_orm should raise a ValueError |
| 348 | + async_session = async_sessionmaker( |
| 349 | + db_store_parameterized.engine, expire_on_commit=False |
| 350 | + ) |
| 351 | + async with async_session() as session: |
| 352 | + db_model = await session.get( |
| 353 | + PushNotificationConfigModel, (task_id, 'config-1') |
| 354 | + ) |
| 355 | + |
| 356 | + with pytest.raises(ValueError) as exc_info: |
| 357 | + store2._from_orm(db_model) # type: ignore |
| 358 | + |
| 359 | + |
| 360 | +@pytest.mark.asyncio |
| 361 | +async def test_decryption_error_with_no_key( |
| 362 | + db_store_parameterized: DatabasePushNotificationConfigStore, |
| 363 | +): |
| 364 | + """Test that using the wrong key to decrypt raises a ValueError.""" |
| 365 | + # 1. Store with one key |
| 366 | + |
| 367 | + task_id = 'wrong-key-task' |
| 368 | + config = PushNotificationConfig(id='config-1', url='http://secret.url') |
| 369 | + await db_store_parameterized.set_info(task_id, config) |
| 370 | + |
| 371 | + # 2. Try to read with no key set |
| 372 | + # Directly query the database to inspect the raw data |
| 373 | + store2 = DatabasePushNotificationConfigStore(db_store_parameterized.engine) |
| 374 | + |
| 375 | + retrieved_configs = await store2.get_info(task_id) |
| 376 | + assert retrieved_configs == [] |
| 377 | + |
| 378 | + # _from_orm should raise a ValueError |
| 379 | + async_session = async_sessionmaker( |
| 380 | + db_store_parameterized.engine, expire_on_commit=False |
| 381 | + ) |
| 382 | + async with async_session() as session: |
| 383 | + db_model = await session.get( |
| 384 | + PushNotificationConfigModel, (task_id, 'config-1') |
| 385 | + ) |
| 386 | + |
| 387 | + with pytest.raises(ValueError) as exc_info: |
| 388 | + store2._from_orm(db_model) # type: ignore |
| 389 | + |
| 390 | + |
| 391 | +@pytest.mark.asyncio |
| 392 | +async def test_custom_table_name( |
| 393 | + db_store_parameterized: DatabasePushNotificationConfigStore, |
| 394 | +): |
| 395 | + """Test that the store works correctly with a custom table name.""" |
| 396 | + table_name = 'my_custom_push_configs' |
| 397 | + |
| 398 | + task_id = 'custom-table-task' |
| 399 | + config = PushNotificationConfig(id='config-1', url='http://custom.url') |
| 400 | + |
| 401 | + # This will create the table on first use |
| 402 | + await db_store_parameterized.set_info(task_id, config) |
| 403 | + retrieved_configs = await db_store_parameterized.get_info(task_id) |
| 404 | + |
| 405 | + assert len(retrieved_configs) == 1 |
| 406 | + assert retrieved_configs[0] == config |
| 407 | + |
| 408 | + # Verify the custom table exists and has data |
| 409 | + async with db_store_parameterized.engine.connect() as conn: |
| 410 | + result = await conn.execute( |
| 411 | + select(db_store_parameterized.config_model).where( |
| 412 | + db_store_parameterized.config_model.task_id == task_id |
| 413 | + ) |
| 414 | + ) |
| 415 | + assert result.scalar_one_or_none() is not None |
| 416 | + |
| 417 | + |
| 418 | +@pytest.mark.asyncio |
| 419 | +async def test_set_and_get_info_multiple_configs_no_key( |
| 420 | + db_store_parameterized: DatabasePushNotificationConfigStore, |
| 421 | +): |
| 422 | + """Test setting and retrieving multiple configurations for a single task.""" |
| 423 | + |
| 424 | + store = DatabasePushNotificationConfigStore( |
| 425 | + engine=db_store_parameterized.engine, |
| 426 | + create_table=False, |
| 427 | + encryption_key=None, # No encryption key |
| 428 | + ) |
| 429 | + await store.initialize() |
| 430 | + |
| 431 | + task_id = 'task-1' |
| 432 | + config1 = PushNotificationConfig(id='config-1', url='http://example.com/1') |
| 433 | + config2 = PushNotificationConfig(id='config-2', url='http://example.com/2') |
| 434 | + |
| 435 | + await db_store_parameterized.set_info(task_id, config1) |
| 436 | + await db_store_parameterized.set_info(task_id, config2) |
| 437 | + retrieved_configs = await db_store_parameterized.get_info(task_id) |
| 438 | + |
| 439 | + assert len(retrieved_configs) == 2 |
| 440 | + assert config1 in retrieved_configs |
| 441 | + assert config2 in retrieved_configs |
| 442 | + |
| 443 | + |
| 444 | +@pytest.mark.asyncio |
| 445 | +async def test_data_is_not_encrypted_in_db_if_no_key_is_set( |
| 446 | + db_store_parameterized: DatabasePushNotificationConfigStore, |
| 447 | +): |
| 448 | + """Test data is not encrypted when no encryption key is set.""" |
| 449 | + |
| 450 | + store = DatabasePushNotificationConfigStore( |
| 451 | + engine=db_store_parameterized.engine, |
| 452 | + create_table=False, |
| 453 | + encryption_key=None, # No encryption key |
| 454 | + ) |
| 455 | + await store.initialize() |
| 456 | + |
| 457 | + task_id = 'task-1' |
| 458 | + config = PushNotificationConfig(id='config-1', url='http://example.com/1') |
| 459 | + plain_json = config.model_dump_json() |
| 460 | + |
| 461 | + await store.set_info(task_id, config) |
| 462 | + |
| 463 | + # Directly query the database to inspect the raw data |
| 464 | + async_session = async_sessionmaker( |
| 465 | + db_store_parameterized.engine, expire_on_commit=False |
| 466 | + ) |
| 467 | + async with async_session() as session: |
| 468 | + stmt = select(PushNotificationConfigModel).where( |
| 469 | + PushNotificationConfigModel.task_id == task_id |
| 470 | + ) |
| 471 | + result = await session.execute(stmt) |
| 472 | + db_model = result.scalar_one() |
| 473 | + |
| 474 | + assert db_model.config_data == plain_json.encode('utf-8') |
0 commit comments