Skip to content

Commit 433bb5e

Browse files
committed
feat: postgresql task store
1 parent 5613ce4 commit 433bb5e

File tree

4 files changed

+379
-0
lines changed

4 files changed

+379
-0
lines changed

docker/postgres/docker-compose.yml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
version: "3.8"
2+
3+
services:
4+
postgres:
5+
image: postgres:15-alpine
6+
ports:
7+
- "5432:5432"
8+
environment:
9+
- POSTGRES_USER=postgres
10+
- POSTGRES_PASSWORD=postgres
11+
- POSTGRES_DB=a2a_test
12+
volumes:
13+
- postgres_data:/var/lib/postgresql/data
14+
- ./docker/postgres/init.sql:/docker-entrypoint-initdb.d/init.sql
15+
networks:
16+
- a2a-network
17+
healthcheck:
18+
test: ["CMD-SHELL", "pg_isready -U postgres"]
19+
interval: 5s
20+
timeout: 5s
21+
retries: 5
22+
23+
volumes:
24+
postgres_data:
25+
26+
networks:
27+
a2a-network:
28+
driver: bridge

docker/postgres/init.sql

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
-- Create a dedicated user for the application
2+
CREATE USER a2a WITH PASSWORD 'a2a_password';
3+
4+
-- Create the tasks database
5+
CREATE DATABASE a2a_tasks;
6+
7+
GRANT ALL PRIVILEGES ON DATABASE a2a_test TO a2a;
8+
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import json
2+
import logging
3+
4+
import asyncpg
5+
6+
from a2a.server.tasks.task_store import TaskStore
7+
from a2a.types import Task
8+
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class PostgreSQLTaskStore(TaskStore):
14+
"""PostgreSQL implementation of TaskStore.
15+
16+
Stores task objects in a PostgreSQL database.
17+
"""
18+
19+
def __init__(
20+
self,
21+
url: str,
22+
table_name: str = 'tasks',
23+
create_table: bool = True,
24+
) -> None:
25+
"""Initializes the PostgreSQLTaskStore.
26+
27+
Args:
28+
url: PostgreSQL connection string in the format:
29+
postgresql://username:password@hostname:port/database
30+
table_name: The name of the table to store tasks in
31+
create_table: Whether to create the table if it doesn't exist
32+
"""
33+
logger.debug('Initializing PostgreSQLTaskStore')
34+
self.url = url
35+
self.table_name = table_name
36+
self.create_table = create_table
37+
38+
self.pool: asyncpg.Pool | None = None
39+
40+
async def initialize(self) -> None:
41+
"""Initialize the database connection pool and create the table
42+
if needed.
43+
"""
44+
if self.pool is not None:
45+
return
46+
47+
logger.debug('Creating connection pool')
48+
self.pool = await asyncpg.create_pool(self.url)
49+
50+
if self.create_table:
51+
async with self.pool.acquire() as conn:
52+
logger.debug('Creating tasks table if not exists')
53+
await conn.execute(
54+
f"""
55+
CREATE TABLE IF NOT EXISTS {self.table_name} (
56+
id TEXT PRIMARY KEY,
57+
data JSONB NOT NULL
58+
59+
)
60+
"""
61+
)
62+
63+
async def close(self) -> None:
64+
"""Close the database connection pool."""
65+
if self.pool is not None:
66+
await self.pool.close()
67+
self.pool = None
68+
69+
async def save(self, task: Task) -> None:
70+
"""Saves or updates a task in the PostgreSQL store."""
71+
await self._ensure_initialized()
72+
73+
assert self.pool is not None
74+
async with self.pool.acquire() as conn, conn.transaction():
75+
task_json = task.model_dump()
76+
77+
await conn.execute(
78+
f"""
79+
INSERT INTO {self.table_name} (id, data)
80+
VALUES ($1, $2)
81+
ON CONFLICT (id) DO UPDATE
82+
SET data = $2
83+
""",
84+
task.id,
85+
json.dumps(task_json),
86+
)
87+
88+
logger.info('Task %s saved successfully.', task.id)
89+
90+
async def get(self, task_id: str) -> Task | None:
91+
"""Retrieves a task from the PostgreSQL store by ID."""
92+
await self._ensure_initialized()
93+
94+
assert self.pool is not None
95+
async with self.pool.acquire() as conn, conn.transaction():
96+
logger.debug('Attempting to get task with id: %s', task_id)
97+
98+
row = await conn.fetchrow(
99+
f'SELECT data FROM {self.table_name} WHERE id = $1',
100+
task_id,
101+
)
102+
103+
if row:
104+
task_json = json.loads(row['data'])
105+
task = Task.model_validate(task_json)
106+
logger.debug('Task %s retrieved successfully.', task_id)
107+
return task
108+
109+
logger.debug('Task %s not found in store.', task_id)
110+
return None
111+
112+
async def delete(self, task_id: str) -> None:
113+
"""Deletes a task from the PostgreSQL store by ID."""
114+
await self._ensure_initialized()
115+
116+
assert self.pool is not None
117+
async with self.pool.acquire() as conn, conn.transaction():
118+
logger.debug('Attempting to delete task with id: %s', task_id)
119+
120+
result = await conn.execute(
121+
f'DELETE FROM {self.table_name} WHERE id = $1',
122+
task_id,
123+
)
124+
125+
if result.split()[-1] != '0': # Check if rows were affected
126+
logger.info('Task %s deleted successfully.', task_id)
127+
else:
128+
logger.warning(
129+
'Attempted to delete nonexistent task with id: %s',
130+
task_id,
131+
)
132+
133+
async def _ensure_initialized(self) -> None:
134+
"""Ensure the database connection is initialized."""
135+
if self.pool is None:
136+
await self.initialize()
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
import json
2+
import os
3+
from typing import AsyncGenerator
4+
5+
import pytest
6+
import pytest_asyncio
7+
8+
from a2a.server.tasks.postgresql_task_store import PostgreSQLTaskStore
9+
from a2a.types import Task, TaskState, TaskStatus
10+
11+
# Use a proper Task object instead of a dict for the minimal task
12+
task_status = TaskStatus(state=TaskState.submitted)
13+
MINIMAL_TASK_OBJ = Task(
14+
id='task-abc',
15+
contextId='session-xyz',
16+
status=task_status,
17+
kind='task',
18+
)
19+
20+
21+
# Get PostgreSQL connection string from environment or use a default for testing
22+
POSTGRES_TEST_DSN = os.environ.get(
23+
'POSTGRES_TEST_DSN',
24+
'postgresql://postgres:postgres@localhost:5432/a2a_test',
25+
)
26+
27+
28+
@pytest_asyncio.fixture
29+
async def postgres_store() -> AsyncGenerator[PostgreSQLTaskStore, None]:
30+
"""Fixture that provides a PostgreSQLTaskStore connected to a real database.
31+
32+
This fixture requires a running PostgreSQL instance
33+
"""
34+
35+
store = PostgreSQLTaskStore(POSTGRES_TEST_DSN)
36+
await store.initialize()
37+
38+
# Clean up any test data that might be left from previous runs
39+
if store.pool is not None:
40+
async with store.pool.acquire() as conn:
41+
await conn.execute(
42+
f"DELETE FROM {store.table_name} WHERE id LIKE 'test-%'"
43+
)
44+
45+
yield store
46+
await store.close()
47+
48+
49+
@pytest.mark.asyncio
50+
async def test_initialize_creates_table(
51+
postgres_store: PostgreSQLTaskStore,
52+
) -> None:
53+
"""Test that initialize creates the table if it doesn't exist."""
54+
await postgres_store.initialize()
55+
56+
# Verify the pool was created
57+
assert postgres_store.pool is not None
58+
59+
# Verify the table creation query was executed
60+
async with postgres_store.pool.acquire() as conn:
61+
async with conn.transaction():
62+
exists = await conn.fetchval(
63+
f"SELECT EXISTS (SELECT 1 FROM pg_tables WHERE tablename = '{postgres_store.table_name}')"
64+
)
65+
assert exists
66+
67+
68+
@pytest.mark.asyncio
69+
async def test_save_task(postgres_store: PostgreSQLTaskStore) -> None:
70+
"""Test saving a task to the PostgreSQL store."""
71+
# Use the pre-created Task object to avoid serialization issues
72+
task = MINIMAL_TASK_OBJ
73+
await postgres_store.save(task)
74+
assert postgres_store.pool is not None
75+
76+
# Verify the insert query was executed
77+
async with postgres_store.pool.acquire() as conn:
78+
async with conn.transaction():
79+
row = await conn.fetchrow(
80+
f'SELECT data FROM {postgres_store.table_name} WHERE id = $1',
81+
task.id,
82+
)
83+
assert row is not None
84+
# Convert the task to a dictionary with proper enum handling
85+
86+
# Parse the JSON string from the database
87+
db_dict = (
88+
json.loads(row['data'])
89+
if isinstance(row['data'], str)
90+
else row['data']
91+
)
92+
assert db_dict == task.model_dump()
93+
94+
95+
@pytest.mark.asyncio
96+
async def test_get_task(postgres_store: PostgreSQLTaskStore) -> None:
97+
"""Test retrieving a task from the PostgreSQL store."""
98+
retrieved_task = await postgres_store.get(MINIMAL_TASK_OBJ.id)
99+
100+
# Verify the task was correctly reconstructed
101+
assert retrieved_task is not None
102+
assert retrieved_task.id == MINIMAL_TASK_OBJ.id
103+
assert retrieved_task.contextId == MINIMAL_TASK_OBJ.contextId
104+
105+
106+
@pytest.mark.asyncio
107+
async def test_get_nonexistent_task(
108+
postgres_store: PostgreSQLTaskStore,
109+
) -> None:
110+
"""Test retrieving a nonexistent task."""
111+
112+
retrieved_task = await postgres_store.get('nonexistent')
113+
114+
# Verify None was returned
115+
assert retrieved_task is None
116+
117+
118+
@pytest.mark.asyncio
119+
async def test_delete_task(
120+
postgres_store: PostgreSQLTaskStore,
121+
) -> None:
122+
"""Test deleting a task from the PostgreSQL store."""
123+
await postgres_store.initialize()
124+
await postgres_store.delete(MINIMAL_TASK_OBJ.id)
125+
126+
127+
@pytest.mark.asyncio
128+
async def test_delete_nonexistent_task(
129+
postgres_store: PostgreSQLTaskStore,
130+
) -> None:
131+
"""Test deleting a nonexistent task."""
132+
await postgres_store.initialize()
133+
await postgres_store.delete('nonexistent')
134+
135+
136+
@pytest.mark.asyncio
137+
async def test_close_connection_pool(
138+
postgres_store: PostgreSQLTaskStore,
139+
) -> None:
140+
"""Test closing the database connection pool."""
141+
await postgres_store.close()
142+
assert postgres_store.pool is None
143+
144+
145+
@pytest.mark.asyncio
146+
async def test_save_and_get_task(
147+
postgres_store: PostgreSQLTaskStore,
148+
) -> None:
149+
"""Test for saving and retrieving a task from a real PostgreSQL database."""
150+
# Create a unique test task
151+
test_task = Task(
152+
id='test-1',
153+
contextId='test-session-1',
154+
status=TaskStatus(state=TaskState.submitted),
155+
kind='task',
156+
)
157+
158+
# Save the task
159+
await postgres_store.save(test_task)
160+
161+
# Retrieve the task
162+
retrieved_task = await postgres_store.get(test_task.id)
163+
164+
# Verify task was retrieved correctly
165+
assert retrieved_task is not None
166+
assert retrieved_task.id == test_task.id
167+
assert retrieved_task.contextId == test_task.contextId
168+
assert retrieved_task.status.state == test_task.status.state
169+
170+
# Clean up
171+
await postgres_store.delete(test_task.id)
172+
173+
# Verify deletion
174+
deleted_task = await postgres_store.get(test_task.id)
175+
assert deleted_task is None
176+
177+
178+
@pytest.mark.asyncio
179+
async def test_update_task(
180+
postgres_store: PostgreSQLTaskStore,
181+
) -> None:
182+
"""Test for updating a task in a real PostgreSQL database."""
183+
# Create a test task
184+
test_task = Task(
185+
id='test-2',
186+
contextId='test-session-2',
187+
status=TaskStatus(state=TaskState.submitted),
188+
kind='task',
189+
)
190+
191+
# Save the task
192+
await postgres_store.save(test_task)
193+
194+
# Update the task
195+
updated_task = test_task.model_copy(deep=True)
196+
updated_task.status.state = TaskState.completed
197+
await postgres_store.save(updated_task)
198+
199+
# Retrieve the updated task
200+
retrieved_task = await postgres_store.get(test_task.id)
201+
202+
# Verify the update was successful
203+
assert retrieved_task is not None
204+
assert retrieved_task.status.state == TaskState.completed
205+
206+
# Clean up
207+
await postgres_store.delete(test_task.id)

0 commit comments

Comments
 (0)