|
| 1 | +import json |
| 2 | +import os |
| 3 | +from typing import AsyncGenerator |
| 4 | + |
| 5 | +import pytest |
| 6 | +import pytest_asyncio |
| 7 | + |
| 8 | +from a2a.server.tasks.postgresql_task_store import PostgreSQLTaskStore |
| 9 | +from a2a.types import Task, TaskState, TaskStatus |
| 10 | + |
| 11 | +# Use a proper Task object instead of a dict for the minimal task |
| 12 | +task_status = TaskStatus(state=TaskState.submitted) |
| 13 | +MINIMAL_TASK_OBJ = Task( |
| 14 | + id='task-abc', |
| 15 | + contextId='session-xyz', |
| 16 | + status=task_status, |
| 17 | + kind='task', |
| 18 | +) |
| 19 | + |
| 20 | + |
| 21 | +# Get PostgreSQL connection string from environment or use a default for testing |
| 22 | +POSTGRES_TEST_DSN = os.environ.get( |
| 23 | + 'POSTGRES_TEST_DSN', |
| 24 | + 'postgresql://postgres:postgres@localhost:5432/a2a_test', |
| 25 | +) |
| 26 | + |
| 27 | + |
| 28 | +@pytest_asyncio.fixture |
| 29 | +async def postgres_store() -> AsyncGenerator[PostgreSQLTaskStore, None]: |
| 30 | + """Fixture that provides a PostgreSQLTaskStore connected to a real database. |
| 31 | +
|
| 32 | + This fixture requires a running PostgreSQL instance |
| 33 | + """ |
| 34 | + |
| 35 | + store = PostgreSQLTaskStore(POSTGRES_TEST_DSN) |
| 36 | + await store.initialize() |
| 37 | + |
| 38 | + # Clean up any test data that might be left from previous runs |
| 39 | + if store.pool is not None: |
| 40 | + async with store.pool.acquire() as conn: |
| 41 | + await conn.execute( |
| 42 | + f"DELETE FROM {store.table_name} WHERE id LIKE 'test-%'" |
| 43 | + ) |
| 44 | + |
| 45 | + yield store |
| 46 | + await store.close() |
| 47 | + |
| 48 | + |
| 49 | +@pytest.mark.asyncio |
| 50 | +async def test_initialize_creates_table( |
| 51 | + postgres_store: PostgreSQLTaskStore, |
| 52 | +) -> None: |
| 53 | + """Test that initialize creates the table if it doesn't exist.""" |
| 54 | + await postgres_store.initialize() |
| 55 | + |
| 56 | + # Verify the pool was created |
| 57 | + assert postgres_store.pool is not None |
| 58 | + |
| 59 | + # Verify the table creation query was executed |
| 60 | + async with postgres_store.pool.acquire() as conn: |
| 61 | + async with conn.transaction(): |
| 62 | + exists = await conn.fetchval( |
| 63 | + f"SELECT EXISTS (SELECT 1 FROM pg_tables WHERE tablename = '{postgres_store.table_name}')" |
| 64 | + ) |
| 65 | + assert exists |
| 66 | + |
| 67 | + |
| 68 | +@pytest.mark.asyncio |
| 69 | +async def test_save_task(postgres_store: PostgreSQLTaskStore) -> None: |
| 70 | + """Test saving a task to the PostgreSQL store.""" |
| 71 | + # Use the pre-created Task object to avoid serialization issues |
| 72 | + task = MINIMAL_TASK_OBJ |
| 73 | + await postgres_store.save(task) |
| 74 | + assert postgres_store.pool is not None |
| 75 | + |
| 76 | + # Verify the insert query was executed |
| 77 | + async with postgres_store.pool.acquire() as conn: |
| 78 | + async with conn.transaction(): |
| 79 | + row = await conn.fetchrow( |
| 80 | + f'SELECT data FROM {postgres_store.table_name} WHERE id = $1', |
| 81 | + task.id, |
| 82 | + ) |
| 83 | + assert row is not None |
| 84 | + # Convert the task to a dictionary with proper enum handling |
| 85 | + |
| 86 | + # Parse the JSON string from the database |
| 87 | + db_dict = ( |
| 88 | + json.loads(row['data']) |
| 89 | + if isinstance(row['data'], str) |
| 90 | + else row['data'] |
| 91 | + ) |
| 92 | + assert db_dict == task.model_dump() |
| 93 | + |
| 94 | + |
| 95 | +@pytest.mark.asyncio |
| 96 | +async def test_get_task(postgres_store: PostgreSQLTaskStore) -> None: |
| 97 | + """Test retrieving a task from the PostgreSQL store.""" |
| 98 | + retrieved_task = await postgres_store.get(MINIMAL_TASK_OBJ.id) |
| 99 | + |
| 100 | + # Verify the task was correctly reconstructed |
| 101 | + assert retrieved_task is not None |
| 102 | + assert retrieved_task.id == MINIMAL_TASK_OBJ.id |
| 103 | + assert retrieved_task.contextId == MINIMAL_TASK_OBJ.contextId |
| 104 | + |
| 105 | + |
| 106 | +@pytest.mark.asyncio |
| 107 | +async def test_get_nonexistent_task( |
| 108 | + postgres_store: PostgreSQLTaskStore, |
| 109 | +) -> None: |
| 110 | + """Test retrieving a nonexistent task.""" |
| 111 | + |
| 112 | + retrieved_task = await postgres_store.get('nonexistent') |
| 113 | + |
| 114 | + # Verify None was returned |
| 115 | + assert retrieved_task is None |
| 116 | + |
| 117 | + |
| 118 | +@pytest.mark.asyncio |
| 119 | +async def test_delete_task( |
| 120 | + postgres_store: PostgreSQLTaskStore, |
| 121 | +) -> None: |
| 122 | + """Test deleting a task from the PostgreSQL store.""" |
| 123 | + await postgres_store.initialize() |
| 124 | + await postgres_store.delete(MINIMAL_TASK_OBJ.id) |
| 125 | + |
| 126 | + |
| 127 | +@pytest.mark.asyncio |
| 128 | +async def test_delete_nonexistent_task( |
| 129 | + postgres_store: PostgreSQLTaskStore, |
| 130 | +) -> None: |
| 131 | + """Test deleting a nonexistent task.""" |
| 132 | + await postgres_store.initialize() |
| 133 | + await postgres_store.delete('nonexistent') |
| 134 | + |
| 135 | + |
| 136 | +@pytest.mark.asyncio |
| 137 | +async def test_close_connection_pool( |
| 138 | + postgres_store: PostgreSQLTaskStore, |
| 139 | +) -> None: |
| 140 | + """Test closing the database connection pool.""" |
| 141 | + await postgres_store.close() |
| 142 | + assert postgres_store.pool is None |
| 143 | + |
| 144 | + |
| 145 | +@pytest.mark.asyncio |
| 146 | +async def test_save_and_get_task( |
| 147 | + postgres_store: PostgreSQLTaskStore, |
| 148 | +) -> None: |
| 149 | + """Test for saving and retrieving a task from a real PostgreSQL database.""" |
| 150 | + # Create a unique test task |
| 151 | + test_task = Task( |
| 152 | + id='test-1', |
| 153 | + contextId='test-session-1', |
| 154 | + status=TaskStatus(state=TaskState.submitted), |
| 155 | + kind='task', |
| 156 | + ) |
| 157 | + |
| 158 | + # Save the task |
| 159 | + await postgres_store.save(test_task) |
| 160 | + |
| 161 | + # Retrieve the task |
| 162 | + retrieved_task = await postgres_store.get(test_task.id) |
| 163 | + |
| 164 | + # Verify task was retrieved correctly |
| 165 | + assert retrieved_task is not None |
| 166 | + assert retrieved_task.id == test_task.id |
| 167 | + assert retrieved_task.contextId == test_task.contextId |
| 168 | + assert retrieved_task.status.state == test_task.status.state |
| 169 | + |
| 170 | + # Clean up |
| 171 | + await postgres_store.delete(test_task.id) |
| 172 | + |
| 173 | + # Verify deletion |
| 174 | + deleted_task = await postgres_store.get(test_task.id) |
| 175 | + assert deleted_task is None |
| 176 | + |
| 177 | + |
| 178 | +@pytest.mark.asyncio |
| 179 | +async def test_update_task( |
| 180 | + postgres_store: PostgreSQLTaskStore, |
| 181 | +) -> None: |
| 182 | + """Test for updating a task in a real PostgreSQL database.""" |
| 183 | + # Create a test task |
| 184 | + test_task = Task( |
| 185 | + id='test-2', |
| 186 | + contextId='test-session-2', |
| 187 | + status=TaskStatus(state=TaskState.submitted), |
| 188 | + kind='task', |
| 189 | + ) |
| 190 | + |
| 191 | + # Save the task |
| 192 | + await postgres_store.save(test_task) |
| 193 | + |
| 194 | + # Update the task |
| 195 | + updated_task = test_task.model_copy(deep=True) |
| 196 | + updated_task.status.state = TaskState.completed |
| 197 | + await postgres_store.save(updated_task) |
| 198 | + |
| 199 | + # Retrieve the updated task |
| 200 | + retrieved_task = await postgres_store.get(test_task.id) |
| 201 | + |
| 202 | + # Verify the update was successful |
| 203 | + assert retrieved_task is not None |
| 204 | + assert retrieved_task.status.state == TaskState.completed |
| 205 | + |
| 206 | + # Clean up |
| 207 | + await postgres_store.delete(test_task.id) |
0 commit comments