Skip to content

Commit c32f8e7

Browse files
google-labs-jules[bot]zboyles
authored andcommitted
Refactor TaskStore to be database-agnostic using SQLAlchemy
This commit replaces the PostgreSQL-specific TaskStore with a generic `DatabaseTaskStore` that leverages SQLAlchemy for database interactions. This change allows your A2A server to support multiple database backends, including SQLite, PostgreSQL, and MySQL. Key changes include: - Definition of a SQLAlchemy model `TaskModel` for storing task data. - Implementation of `DatabaseTaskStore` that uses the `TaskStore` interface and SQLAlchemy for CRUD operations. - Update of example application configurations to use `DatabaseTaskStore` when a `DATABASE_URL` environment variable is provided, defaulting to `InMemoryTaskStore` otherwise. - Creation of parameterized unit tests for `DatabaseTaskStore`, designed to run against SQLite, PostgreSQL, and MySQL to ensure compatibility. - Removal of the old `PostgreSQLTaskStore` and its specific tests. - Addition of necessary dependencies: `sqlalchemy`, `aiosqlite`, `aiomysql`. The new implementation makes the task persistence layer more flexible and extensible, allowing you to choose a database backend that best suits your needs.
1 parent 46d7a78 commit c32f8e7

File tree

10 files changed

+447
-355
lines changed

10 files changed

+447
-355
lines changed

examples/google_adk/birthday_planner/__main__.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111

1212
from a2a.server.apps import A2AStarletteApplication
1313
from a2a.server.request_handlers import DefaultRequestHandler
14-
from a2a.server.tasks import InMemoryTaskStore
14+
from a2a.server.tasks import InMemoryTaskStore, DatabaseTaskStore # MODIFIED
1515
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
16-
16+
# os is already imported
1717

1818
load_dotenv()
1919

@@ -65,8 +65,19 @@ def main(host: str, port: int, calendar_agent: str):
6565
capabilities=AgentCapabilities(streaming=True),
6666
skills=[skill],
6767
)
68+
69+
database_url = os.environ.get("DATABASE_URL")
70+
task_store_instance: InMemoryTaskStore | DatabaseTaskStore
71+
72+
if database_url:
73+
print(f"Using DatabaseTaskStore with URL: {database_url} in {__file__}")
74+
task_store_instance = DatabaseTaskStore(db_url=database_url, create_table=True)
75+
else:
76+
print(f"DATABASE_URL not set in {__file__}, using InMemoryTaskStore.")
77+
task_store_instance = InMemoryTaskStore()
78+
6879
request_handler = DefaultRequestHandler(
69-
agent_executor=agent_executor, task_store=InMemoryTaskStore()
80+
agent_executor=agent_executor, task_store=task_store_instance
7081
)
7182
app = A2AStarletteApplication(agent_card, request_handler)
7283
uvicorn.run(app.build(), host=host, port=port)

examples/google_adk/calendar_agent/__main__.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525

2626
from a2a.server.apps import A2AStarletteApplication
2727
from a2a.server.request_handlers import DefaultRequestHandler
28-
from a2a.server.tasks import InMemoryTaskStore
28+
from a2a.server.tasks import InMemoryTaskStore, DatabaseTaskStore # MODIFIED
2929
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
30-
30+
# os is already imported
3131

3232
load_dotenv()
3333

@@ -86,8 +86,18 @@ async def handle_auth(request: Request) -> PlainTextResponse:
8686
)
8787
return PlainTextResponse('Authentication successful.')
8888

89+
database_url = os.environ.get("DATABASE_URL")
90+
task_store_instance: InMemoryTaskStore | DatabaseTaskStore
91+
92+
if database_url:
93+
print(f"Using DatabaseTaskStore with URL: {database_url} in {__file__}")
94+
task_store_instance = DatabaseTaskStore(db_url=database_url, create_table=True)
95+
else:
96+
print(f"DATABASE_URL not set in {__file__}, using InMemoryTaskStore.")
97+
task_store_instance = InMemoryTaskStore()
98+
8999
request_handler = DefaultRequestHandler(
90-
agent_executor=agent_executor, task_store=InMemoryTaskStore()
100+
agent_executor=agent_executor, task_store=task_store_instance
91101
)
92102

93103
a2a_app = A2AStarletteApplication(

examples/helloworld/__main__.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
HelloWorldAgentExecutor, # type: ignore[import-untyped]
33
)
44

5+
import os
6+
57
from a2a.server.apps import A2AStarletteApplication
68
from a2a.server.request_handlers import DefaultRequestHandler
79
from a2a.server.tasks import InMemoryTaskStore
@@ -55,9 +57,24 @@
5557
}
5658
)
5759

60+
database_url = os.environ.get("DATABASE_URL")
61+
task_store_instance: InMemoryTaskStore | DatabaseTaskStore
62+
63+
if database_url:
64+
print(f"Using DatabaseTaskStore with URL: {database_url}")
65+
# For this example, we assume create_table=True is desired for the DatabaseTaskStore.
66+
# In a production scenario, schema management might be handled separately (e.g., migrations).
67+
task_store_instance = DatabaseTaskStore(db_url=database_url, create_table=True)
68+
# Note: DatabaseTaskStore.initialize() is async and should ideally be called
69+
# during async app startup (e.g., Starlette's on_startup).
70+
# Here, we rely on its internal _ensure_initialized() called by its methods.
71+
else:
72+
print("DATABASE_URL not set, using InMemoryTaskStore.")
73+
task_store_instance = InMemoryTaskStore()
74+
5875
request_handler = DefaultRequestHandler(
5976
agent_executor=HelloWorldAgentExecutor(),
60-
task_store=InMemoryTaskStore(),
77+
task_store=task_store_instance,
6178
)
6279

6380
server = A2AStarletteApplication(agent_card=public_agent_card,

examples/langgraph/__main__.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010

1111
from a2a.server.apps import A2AStarletteApplication
1212
from a2a.server.request_handlers import DefaultRequestHandler
13-
from a2a.server.tasks import InMemoryPushNotifier, InMemoryTaskStore
13+
from a2a.server.tasks import InMemoryPushNotifier, InMemoryTaskStore, DatabaseTaskStore # MODIFIED
1414
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
15-
15+
# os is already imported
1616

1717
load_dotenv()
1818

@@ -25,11 +25,22 @@ def main(host: str, port: int):
2525
print('GOOGLE_API_KEY environment variable not set.')
2626
sys.exit(1)
2727

28-
client = httpx.AsyncClient()
28+
client = httpx.AsyncClient() # This is for the push_notifier
29+
30+
database_url = os.environ.get("DATABASE_URL")
31+
task_store_instance: InMemoryTaskStore | DatabaseTaskStore
32+
33+
if database_url:
34+
print(f"Using DatabaseTaskStore with URL: {database_url} in {__file__}")
35+
task_store_instance = DatabaseTaskStore(db_url=database_url, create_table=True)
36+
else:
37+
print(f"DATABASE_URL not set in {__file__}, using InMemoryTaskStore.")
38+
task_store_instance = InMemoryTaskStore()
39+
2940
request_handler = DefaultRequestHandler(
3041
agent_executor=CurrencyAgentExecutor(),
31-
task_store=InMemoryTaskStore(),
32-
push_notifier=InMemoryPushNotifier(client),
42+
task_store=task_store_instance,
43+
push_notifier=InMemoryPushNotifier(client), # Preserving push_notifier
3344
)
3445

3546
server = A2AStarletteApplication(

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ dependencies = [
1414
"opentelemetry-api>=1.33.0",
1515
"opentelemetry-sdk>=1.33.0",
1616
"pydantic>=2.11.3",
17+
"sqlalchemy>=2.0.0", # Added SQLAlchemy
18+
"aiosqlite>=0.19.0", # Added aiosqlite for SQLite async
19+
"aiomysql>=0.2.0", # Added aiomysql for MySQL async
1720
"sse-starlette>=2.3.3",
1821
"starlette>=0.46.2",
1922
"typing-extensions>=4.13.2",

src/a2a/server/models.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from sqlalchemy import Column, String, JSON
2+
from sqlalchemy.ext.declarative import declarative_base
3+
from sqlalchemy.dialects.postgresql import JSONB # For PostgreSQL specific JSON type, can be generic JSON too
4+
5+
Base = declarative_base()
6+
7+
class TaskModel(Base):
8+
__tablename__ = "tasks"
9+
10+
id = Column(String, primary_key=True, index=True)
11+
contextId = Column(String, nullable=False)
12+
kind = Column(String, nullable=False, default='task')
13+
14+
# Storing Pydantic models as JSONB for flexibility
15+
# SQLAlchemy's JSON type is generally fine, JSONB is a PostgreSQL optimization
16+
# For broader compatibility, we might stick to JSON or use a custom type if needed.
17+
status = Column(JSONB) # Stores TaskStatus as JSON
18+
artifacts = Column(JSONB, nullable=True) # Stores list[Artifact] as JSON
19+
history = Column(JSONB, nullable=True) # Stores list[Message] as JSON
20+
metadata = Column(JSONB, nullable=True) # Stores dict[str, Any] as JSON
21+
22+
def __repr__(self):
23+
return f"<TaskModel(id='{self.id}', contextId='{self.contextId}', status='{self.status}')>"
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import json
2+
import logging
3+
4+
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
5+
from sqlalchemy.orm import sessionmaker
6+
from sqlalchemy import select, delete, update
7+
8+
from a2a.server.models import Base, TaskModel # TaskModel is our SQLAlchemy model
9+
from a2a.server.tasks.task_store import TaskStore
10+
from a2a.types import Task, TaskStatus # Task is the Pydantic model
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
class DatabaseTaskStore(TaskStore):
16+
"""
17+
SQLAlchemy-based implementation of TaskStore.
18+
Stores task objects in a database supported by SQLAlchemy.
19+
"""
20+
21+
def __init__(
22+
self,
23+
db_url: str,
24+
create_table: bool = True,
25+
) -> None:
26+
"""
27+
Initializes the DatabaseTaskStore.
28+
29+
Args:
30+
db_url: Database connection string.
31+
create_table: If true, create tasks table on initialization.
32+
"""
33+
logger.debug(f"Initializing DatabaseTaskStore with DB URL: {db_url}")
34+
self.engine = create_async_engine(db_url, echo=False) # Set echo=True for SQL logging
35+
self.async_session_maker = sessionmaker(
36+
self.engine, class_=AsyncSession, expire_on_commit=False
37+
)
38+
self.create_table = create_table
39+
self._initialized = False
40+
41+
async def initialize(self) -> None:
42+
"""
43+
Initialize the database and create the table if needed.
44+
"""
45+
if self._initialized:
46+
return
47+
48+
logger.debug("Initializing database schema...")
49+
if self.create_table:
50+
async with self.engine.begin() as conn:
51+
# This will create the 'tasks' table based on TaskModel's definition
52+
await conn.run_sync(Base.metadata.create_all)
53+
self._initialized = True
54+
logger.debug("Database schema initialized.")
55+
56+
async def close(self) -> None:
57+
"""Close the database connection engine."""
58+
if self.engine:
59+
logger.debug("Closing database engine.")
60+
await self.engine.dispose()
61+
self._initialized = False # Reset initialization status
62+
63+
async def _ensure_initialized(self) -> None:
64+
"""Ensure the database connection is initialized."""
65+
if not self._initialized:
66+
await self.initialize()
67+
68+
async def save(self, task: Task) -> None:
69+
"""Saves or updates a task in the database."""
70+
await self._ensure_initialized()
71+
72+
task_data = task.model_dump() # Converts Pydantic Task to dict
73+
74+
async with self.async_session_maker() as session:
75+
async with session.begin():
76+
stmt_select = select(TaskModel).where(TaskModel.id == task.id)
77+
result = await session.execute(stmt_select)
78+
existing_task_model = result.scalar_one_or_none()
79+
80+
if existing_task_model:
81+
logger.debug(f"Updating task {task.id} in the database.")
82+
update_data = {
83+
"contextId": task_data["contextId"],
84+
"kind": task_data["kind"],
85+
"status": task_data["status"], # Already a dict from model_dump
86+
"artifacts": task_data.get("artifacts"), # Already a list of dicts
87+
"history": task_data.get("history"), # Already a list of dicts
88+
"metadata": task_data.get("metadata"), # Already a dict
89+
}
90+
stmt_update = update(TaskModel).where(TaskModel.id == task.id).values(**update_data)
91+
await session.execute(stmt_update)
92+
else:
93+
logger.debug(f"Saving new task {task.id} to the database.")
94+
# Filter task_data to include only columns present in TaskModel
95+
task_model_columns = {col.name for col in TaskModel.__table__.columns}
96+
filtered_task_data = {k: v for k, v in task_data.items() if k in task_model_columns}
97+
new_task_model = TaskModel(**filtered_task_data)
98+
session.add(new_task_model)
99+
100+
await session.commit()
101+
logger.info(f"Task {task.id} saved successfully.")
102+
103+
async def get(self, task_id: str) -> Task | None:
104+
"""Retrieves a task from the database by ID."""
105+
await self._ensure_initialized()
106+
107+
async with self.async_session_maker() as session:
108+
stmt = select(TaskModel).where(TaskModel.id == task_id)
109+
result = await session.execute(stmt)
110+
task_model = result.scalar_one_or_none()
111+
112+
if task_model:
113+
task_data_from_db = {
114+
column.name: getattr(task_model, column.name)
115+
for column in task_model.__table__.columns
116+
}
117+
# Pydantic's model_validate will parse the nested dicts/lists from JSON
118+
task = Task.model_validate(task_data_from_db)
119+
logger.debug(f"Task {task_id} retrieved successfully.")
120+
return task
121+
122+
logger.debug(f"Task {task_id} not found in store.")
123+
return None
124+
125+
async def delete(self, task_id: str) -> None:
126+
"""Deletes a task from the database by ID."""
127+
await self._ensure_initialized()
128+
129+
async with self.async_session_maker() as session:
130+
async with session.begin():
131+
stmt = delete(TaskModel).where(TaskModel.id == task_id)
132+
result = await session.execute(stmt)
133+
await session.commit()
134+
135+
if result.rowcount > 0:
136+
logger.info(f"Task {task_id} deleted successfully.")
137+
else:
138+
logger.warning(
139+
f"Attempted to delete nonexistent task with id: {task_id}"
140+
)

0 commit comments

Comments
 (0)