Skip to content

Commit 758e80e

Browse files
committed
Refactor project to support multiple database backends
This commit introduces optional dependencies for PostgreSQL, MySQL, and SQLite support in the A2A SDK. Key changes include: - Addition of optional dependencies in `pyproject.toml` for database drivers. - Updates to example applications to demonstrate installation and usage with different databases. - Refactoring of task management to utilize a generic `DatabaseTaskStore` for improved flexibility. - Enhancements to the README for clearer instructions on database support. These changes enhance the SDK's versatility, allowing users to choose their preferred database backend.
1 parent 7aeef85 commit 758e80e

File tree

11 files changed

+1101
-744
lines changed

11 files changed

+1101
-744
lines changed

README.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,21 @@ When you're working within a uv project or a virtual environment managed by uv,
3232
uv add a2a-sdk
3333
```
3434

35+
To install with database support:
36+
```bash
37+
# PostgreSQL support
38+
uv add "a2a-sdk[postgresql]"
39+
40+
# MySQL support
41+
uv add "a2a-sdk[mysql]"
42+
43+
# SQLite support
44+
uv add "a2a-sdk[sqlite]"
45+
46+
# All database drivers
47+
uv add "a2a-sdk[sql]"
48+
```
49+
3550
### Using `pip`
3651

3752
If you prefer to use pip, the standard Python package installer, you can install `a2a-sdk` as follows
@@ -40,6 +55,21 @@ If you prefer to use pip, the standard Python package installer, you can install
4055
pip install a2a-sdk
4156
```
4257

58+
To install with database support:
59+
```bash
60+
# PostgreSQL support
61+
pip install "a2a-sdk[postgresql]"
62+
63+
# MySQL support
64+
pip install "a2a-sdk[mysql]"
65+
66+
# SQLite support
67+
pip install "a2a-sdk[sqlite]"
68+
69+
# All database drivers
70+
pip install "a2a-sdk[sql]"
71+
```
72+
4373
## Examples
4474

4575
### [Helloworld Example](https://github.com/google/a2a-python/tree/main/examples/helloworld)

examples/google_adk/birthday_planner/__main__.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111

1212
from a2a.server.apps import A2AStarletteApplication
1313
from a2a.server.request_handlers import DefaultRequestHandler
14-
from a2a.server.tasks import InMemoryTaskStore, DatabaseTaskStore # MODIFIED
14+
from a2a.server.tasks import DatabaseTaskStore, InMemoryTaskStore
1515
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
16+
17+
1618
# os is already imported
1719

1820
load_dotenv()
@@ -66,14 +68,16 @@ def main(host: str, port: int, calendar_agent: str):
6668
skills=[skill],
6769
)
6870

69-
database_url = os.environ.get("DATABASE_URL")
71+
database_url = os.environ.get('DATABASE_URL')
7072
task_store_instance: InMemoryTaskStore | DatabaseTaskStore
7173

7274
if database_url:
73-
print(f"Using DatabaseTaskStore with URL: {database_url} in {__file__}")
74-
task_store_instance = DatabaseTaskStore(db_url=database_url, create_table=True)
75+
print(f'Using DatabaseTaskStore with URL: {database_url} in {__file__}')
76+
task_store_instance = DatabaseTaskStore(
77+
db_url=database_url, create_table=True
78+
)
7579
else:
76-
print(f"DATABASE_URL not set in {__file__}, using InMemoryTaskStore.")
80+
print(f'DATABASE_URL not set in {__file__}, using InMemoryTaskStore.')
7781
task_store_instance = InMemoryTaskStore()
7882

7983
request_handler = DefaultRequestHandler(

examples/google_adk/calendar_agent/__main__.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@
2525

2626
from a2a.server.apps import A2AStarletteApplication
2727
from a2a.server.request_handlers import DefaultRequestHandler
28-
from a2a.server.tasks import InMemoryTaskStore, DatabaseTaskStore # MODIFIED
28+
from a2a.server.tasks import DatabaseTaskStore, InMemoryTaskStore
2929
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
30+
31+
3032
# os is already imported
3133

3234
load_dotenv()
@@ -67,10 +69,12 @@ def main(host: str, port: int):
6769
skills=[skill],
6870
)
6971

70-
adk_agent = asyncio.run(create_agent(
71-
client_id=os.getenv('GOOGLE_CLIENT_ID'),
72-
client_secret=os.getenv('GOOGLE_CLIENT_SECRET'),
73-
))
72+
adk_agent = asyncio.run(
73+
create_agent(
74+
client_id=os.getenv('GOOGLE_CLIENT_ID'),
75+
client_secret=os.getenv('GOOGLE_CLIENT_SECRET'),
76+
)
77+
)
7478
runner = Runner(
7579
app_name=agent_card.name,
7680
agent=adk_agent,
@@ -86,14 +90,16 @@ async def handle_auth(request: Request) -> PlainTextResponse:
8690
)
8791
return PlainTextResponse('Authentication successful.')
8892

89-
database_url = os.environ.get("DATABASE_URL")
93+
database_url = os.environ.get('DATABASE_URL')
9094
task_store_instance: InMemoryTaskStore | DatabaseTaskStore
9195

9296
if database_url:
93-
print(f"Using DatabaseTaskStore with URL: {database_url} in {__file__}")
94-
task_store_instance = DatabaseTaskStore(db_url=database_url, create_table=True)
97+
print(f'Using DatabaseTaskStore with URL: {database_url} in {__file__}')
98+
task_store_instance = DatabaseTaskStore(
99+
db_url=database_url, create_table=True
100+
)
95101
else:
96-
print(f"DATABASE_URL not set in {__file__}, using InMemoryTaskStore.")
102+
print(f'DATABASE_URL not set in {__file__}, using InMemoryTaskStore.')
97103
task_store_instance = InMemoryTaskStore()
98104

99105
request_handler = DefaultRequestHandler(

examples/helloworld/__main__.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,15 @@
4646
# It includes the additional 'extended_skill'
4747
specific_extended_agent_card = public_agent_card.model_copy(
4848
update={
49-
'name': 'Hello World Agent - Extended Edition', # Different name for clarity
49+
'name': 'Hello World Agent - Extended Edition', # Different name for clarity
5050
'description': 'The full-featured hello world agent for authenticated users.',
51-
'version': '1.0.1', # Could even be a different version
51+
'version': '1.0.1', # Could even be a different version
5252
# Capabilities and other fields like url, defaultInputModes, defaultOutputModes,
5353
# supportsAuthenticatedExtendedCard are inherited from public_agent_card unless specified here.
54-
'skills': [skill, extended_skill], # Both skills for the extended card
54+
'skills': [
55+
skill,
56+
extended_skill,
57+
], # Both skills for the extended card
5558
}
5659
)
5760

@@ -60,9 +63,11 @@
6063
task_store=InMemoryTaskStore(),
6164
)
6265

63-
server = A2AStarletteApplication(agent_card=public_agent_card,
64-
http_handler=request_handler,
65-
extended_agent_card=specific_extended_agent_card)
66+
server = A2AStarletteApplication(
67+
agent_card=public_agent_card,
68+
http_handler=request_handler,
69+
extended_agent_card=specific_extended_agent_card,
70+
)
6671
import uvicorn
6772

68-
uvicorn.run(server.build(), host='0.0.0.0', port=9999)
73+
uvicorn.run(server.build(), host='0.0.0.0', port=9999)

examples/langgraph/__main__.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,14 @@
1010

1111
from a2a.server.apps import A2AStarletteApplication
1212
from a2a.server.request_handlers import DefaultRequestHandler
13-
from a2a.server.tasks import InMemoryPushNotifier, InMemoryTaskStore, DatabaseTaskStore # MODIFIED
13+
from a2a.server.tasks import (
14+
DatabaseTaskStore,
15+
InMemoryPushNotifier,
16+
InMemoryTaskStore,
17+
)
1418
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
19+
20+
1521
# os is already imported
1622

1723
load_dotenv()
@@ -25,22 +31,24 @@ def main(host: str, port: int):
2531
print('GOOGLE_API_KEY environment variable not set.')
2632
sys.exit(1)
2733

28-
client = httpx.AsyncClient() # This is for the push_notifier
34+
client = httpx.AsyncClient() # This is for the push_notifier
2935

30-
database_url = os.environ.get("DATABASE_URL")
36+
database_url = os.environ.get('DATABASE_URL')
3137
task_store_instance: InMemoryTaskStore | DatabaseTaskStore
3238

3339
if database_url:
34-
print(f"Using DatabaseTaskStore with URL: {database_url} in {__file__}")
35-
task_store_instance = DatabaseTaskStore(db_url=database_url, create_table=True)
40+
print(f'Using DatabaseTaskStore with URL: {database_url} in {__file__}')
41+
task_store_instance = DatabaseTaskStore(
42+
db_url=database_url, create_table=True
43+
)
3644
else:
37-
print(f"DATABASE_URL not set in {__file__}, using InMemoryTaskStore.")
45+
print(f'DATABASE_URL not set in {__file__}, using InMemoryTaskStore.')
3846
task_store_instance = InMemoryTaskStore()
3947

4048
request_handler = DefaultRequestHandler(
4149
agent_executor=CurrencyAgentExecutor(),
4250
task_store=task_store_instance,
43-
push_notifier=InMemoryPushNotifier(client), # Preserving push_notifier
51+
push_notifier=InMemoryPushNotifier(client), # Preserving push_notifier
4452
)
4553

4654
server = A2AStarletteApplication(

pyproject.toml

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,11 @@ authors = [{ name = "Google LLC", email = "[email protected]" }]
88
requires-python = ">=3.10"
99
keywords = ["A2A", "A2A SDK", "A2A Protocol", "Agent2Agent"]
1010
dependencies = [
11-
"asyncpg>=0.30.0",
1211
"httpx>=0.28.1",
1312
"httpx-sse>=0.4.0",
1413
"opentelemetry-api>=1.33.0",
1514
"opentelemetry-sdk>=1.33.0",
1615
"pydantic>=2.11.3",
17-
"sqlalchemy>=2.0.0", # Added SQLAlchemy
18-
"aiosqlite>=0.19.0", # Added aiosqlite for SQLite async
19-
"aiomysql>=0.2.0", # Added aiomysql for MySQL async
2016
"sse-starlette>=2.3.3",
2117
"starlette>=0.46.2",
2218
"typing-extensions>=4.13.2",
@@ -35,6 +31,26 @@ classifiers = [
3531
"License :: OSI Approved :: Apache Software License",
3632
]
3733

34+
[project.optional-dependencies]
35+
postgresql = [
36+
"sqlalchemy>=2.0.0",
37+
"asyncpg>=0.30.0",
38+
]
39+
mysql = [
40+
"sqlalchemy>=2.0.0",
41+
"aiomysql>=0.2.0",
42+
]
43+
sqlite = [
44+
"sqlalchemy>=2.0.0",
45+
"aiosqlite>=0.19.0",
46+
]
47+
sql = [
48+
"sqlalchemy>=2.0.0",
49+
"asyncpg>=0.30.0",
50+
"aiomysql>=0.2.0",
51+
"aiosqlite>=0.19.0",
52+
]
53+
3854
[project.urls]
3955
homepage = "https://google.github.io/A2A/"
4056
repository = "https://github.com/google/a2a-python"
@@ -79,7 +95,9 @@ members = [
7995
dev = [
8096
"asyncpg-stubs>=0.30.1",
8197
"datamodel-code-generator>=0.30.0",
98+
"greenlet>=3.2.2",
8299
"mypy>=1.15.0",
100+
"nox>=2025.5.1",
83101
"pytest>=8.3.5",
84102
"pytest-asyncio>=0.26.0",
85103
"pytest-cov>=6.1.1",

src/a2a/server/models.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,35 @@
1-
from sqlalchemy import Column, String, JSON
2-
from sqlalchemy.ext.declarative import declarative_base
3-
from sqlalchemy.dialects.postgresql import JSONB # For PostgreSQL specific JSON type, can be generic JSON too
1+
try:
2+
from sqlalchemy import JSON, Column, String
3+
from sqlalchemy.orm import declarative_base
4+
except ImportError as e:
5+
raise ImportError(
6+
'Database models require SQLAlchemy. '
7+
'Install with one of: '
8+
"'pip install a2a-sdk[postgresql]', "
9+
"'pip install a2a-sdk[mysql]', "
10+
"'pip install a2a-sdk[sqlite]', "
11+
"or 'pip install a2a-sdk[sql]'"
12+
) from e
13+
414

515
Base = declarative_base()
616

17+
718
class TaskModel(Base):
8-
__tablename__ = "tasks"
19+
__tablename__ = 'tasks'
920

1021
id = Column(String, primary_key=True, index=True)
1122
contextId = Column(String, nullable=False)
1223
kind = Column(String, nullable=False, default='task')
13-
14-
# Storing Pydantic models as JSONB for flexibility
15-
# SQLAlchemy's JSON type is generally fine, JSONB is a PostgreSQL optimization
16-
# For broader compatibility, we might stick to JSON or use a custom type if needed.
17-
status = Column(JSONB) # Stores TaskStatus as JSON
18-
artifacts = Column(JSONB, nullable=True) # Stores list[Artifact] as JSON
19-
history = Column(JSONB, nullable=True) # Stores list[Message] as JSON
20-
metadata = Column(JSONB, nullable=True) # Stores dict[str, Any] as JSON
24+
25+
# Using generic JSON type for database-agnostic storage
26+
# This works with PostgreSQL, MySQL, SQLite, and other databases
27+
status = Column(JSON) # Stores TaskStatus as JSON
28+
artifacts = Column(JSON, nullable=True) # Stores list[Artifact] as JSON
29+
history = Column(JSON, nullable=True) # Stores list[Message] as JSON
30+
task_metadata = Column(
31+
JSON, nullable=True, name='metadata'
32+
) # Stores dict[str, Any] as JSON
2133

2234
def __repr__(self):
2335
return f"<TaskModel(id='{self.id}', contextId='{self.contextId}', status='{self.status}')>"

src/a2a/server/tasks/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Components for managing tasks within the A2A server."""
22

3+
from a2a.server.tasks.database_task_store import DatabaseTaskStore
34
from a2a.server.tasks.inmemory_push_notifier import InMemoryPushNotifier
45
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
5-
from a2a.server.tasks.postgresql_task_store import PostgreSQLTaskStore
66
from a2a.server.tasks.push_notifier import PushNotifier
77
from a2a.server.tasks.result_aggregator import ResultAggregator
88
from a2a.server.tasks.task_manager import TaskManager
@@ -11,9 +11,9 @@
1111

1212

1313
__all__ = [
14+
'DatabaseTaskStore',
1415
'InMemoryPushNotifier',
1516
'InMemoryTaskStore',
16-
'PostgreSQLTaskStore',
1717
'PushNotifier',
1818
'ResultAggregator',
1919
'TaskManager',

0 commit comments

Comments
 (0)