Skip to content

Commit 5c8dac4

Browse files
committed
added sqlalchemy plumbing
1 parent 06d20e7 commit 5c8dac4

File tree

4 files changed

+287
-0
lines changed

4 files changed

+287
-0
lines changed

autogpt_platform/backend/.env.default

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME
1717
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
1818
PRISMA_SCHEMA="postgres/schema.prisma"
1919

20+
# SQLAlchemy Configuration (for gradual migration from Prisma)
21+
SQLALCHEMY_POOL_SIZE=10
22+
SQLALCHEMY_MAX_OVERFLOW=5
23+
SQLALCHEMY_POOL_TIMEOUT=30
24+
SQLALCHEMY_CONNECT_TIMEOUT=10
25+
SQLALCHEMY_ECHO=false
26+
2027
## ===== REQUIRED SERVICE CREDENTIALS ===== ##
2128
# Redis Configuration
2229
REDIS_HOST=localhost
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
"""
2+
SQLAlchemy infrastructure for AutoGPT Platform.
3+
4+
This module provides:
5+
1. Async engine creation with connection pooling
6+
2. Session factory for dependency injection
7+
3. Database lifecycle management
8+
"""
9+
10+
import logging
11+
import re
12+
from typing import AsyncGenerator
13+
14+
from sqlalchemy.ext.asyncio import (
15+
AsyncEngine,
16+
AsyncSession,
17+
async_sessionmaker,
18+
create_async_engine,
19+
)
20+
from sqlalchemy.pool import QueuePool
21+
22+
from backend.util.settings import Config
23+
24+
logger = logging.getLogger(__name__)
25+
26+
# ============================================================================
27+
# CONFIGURATION
28+
# ============================================================================
29+
30+
31+
def get_database_url() -> str:
32+
"""
33+
Extract database URL from environment and convert to async format.
34+
35+
Prisma URL: postgresql://user:pass@host:port/db?schema=platform
36+
Async URL: postgresql+asyncpg://user:pass@host:port/db
37+
38+
Returns the async-compatible URL without schema parameter (handled separately).
39+
"""
40+
prisma_url = Config().database_url
41+
42+
# Replace postgresql:// with postgresql+asyncpg://
43+
async_url = prisma_url.replace("postgresql://", "postgresql+asyncpg://")
44+
45+
# Remove schema parameter (we'll handle via MetaData)
46+
async_url = re.sub(r"\?schema=\w+", "", async_url)
47+
48+
# Remove any remaining query parameters that might conflict
49+
async_url = re.sub(r"&schema=\w+", "", async_url)
50+
51+
return async_url
52+
53+
54+
def get_database_schema() -> str:
55+
"""
56+
Extract schema name from DATABASE_URL query parameter.
57+
58+
Returns 'platform' by default (matches Prisma configuration).
59+
"""
60+
prisma_url = Config().database_url
61+
match = re.search(r"schema=(\w+)", prisma_url)
62+
return match.group(1) if match else "platform"
63+
64+
65+
# ============================================================================
66+
# ENGINE CREATION
67+
# ============================================================================
68+
69+
70+
def create_engine() -> AsyncEngine:
    """
    Create the async SQLAlchemy engine with connection pooling.

    Call this ONCE per process at startup. The engine is long-lived and safe
    to share across tasks.

    Connection pool configuration (from Config):
    - pool_size: number of persistent connections (default: 10)
    - max_overflow: additional connections when the pool is exhausted (default: 5)
    - pool_timeout: seconds to wait for a connection (default: 30)
    - pool_pre_ping: test connections before use (prevents stale connections)

    Total max connections = pool_size + max_overflow.

    Returns:
        The configured AsyncEngine, bound to the Prisma schema via search_path.
    """
    url = get_database_url()
    config = Config()
    schema = get_database_schema()

    engine = create_async_engine(
        url,
        # Connection pool configuration.
        # NOTE: do not pass ``poolclass=QueuePool`` here — the plain QueuePool
        # is incompatible with asyncio engines (SQLAlchemy raises an
        # ArgumentError). create_async_engine already defaults to the
        # async-safe AsyncAdaptedQueuePool.
        pool_size=config.sqlalchemy_pool_size,  # Persistent connections
        max_overflow=config.sqlalchemy_max_overflow,  # Burst capacity
        pool_timeout=config.sqlalchemy_pool_timeout,  # Wait time for connection
        pool_pre_ping=True,  # Validate connections before use
        echo=config.sqlalchemy_echo,  # Log SQL statements (dev/debug only)
        # Connection arguments (passed straight through to asyncpg.connect)
        connect_args={
            "server_settings": {
                # Route unqualified table names to the Prisma-managed schema.
                "search_path": schema,
            },
            "timeout": config.sqlalchemy_connect_timeout,  # Connection timeout
        },
    )

    logger.info(
        f"SQLAlchemy engine created: pool_size={config.sqlalchemy_pool_size}, "
        f"max_overflow={config.sqlalchemy_max_overflow}, "
        f"schema={schema}"
    )

    return engine
115+
116+
117+
# ============================================================================
118+
# SESSION FACTORY
119+
# ============================================================================
120+
121+
122+
def create_session_factory(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
    """
    Create the session factory used to build AsyncSession instances on demand.

    The factory is configured once, then used to create sessions per request.
    Each session represents a single database transaction.

    Args:
        engine: The async engine (with connection pool) to bind sessions to.

    Returns:
        Session factory that creates properly configured AsyncSession instances.
    """
    # NOTE: Session's ``autocommit`` parameter was removed in SQLAlchemy 2.0
    # (which ``async_sessionmaker`` requires), so it must not be passed here;
    # sessions always run in explicit-transaction mode.
    return async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,  # Keep ORM objects usable after commit
        autoflush=False,  # Manual control over when to flush
    )
142+
143+
144+
# ============================================================================
145+
# DEPENDENCY INJECTION FOR FASTAPI
146+
# ============================================================================
147+
148+
# Process-wide handles, populated by initialize() during app startup.
_engine: AsyncEngine | None = None
_session_factory: async_sessionmaker[AsyncSession] | None = None


def initialize(engine: AsyncEngine) -> None:
    """
    Bind the process-wide engine and build the session factory from it.

    Invoked from the FastAPI lifespan startup hook.

    Args:
        engine: The async engine this process should use.
    """
    global _engine, _session_factory
    _engine, _session_factory = engine, create_session_factory(engine)
    logger.info("SQLAlchemy session factory initialized")
166+
167+
168+
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """
    FastAPI dependency that provides a database session.

    Usage in routes:
        @router.get("/users/{user_id}")
        async def get_user(
            user_id: int,
            session: AsyncSession = Depends(get_session)
        ):
            result = await session.execute(select(User).where(User.id == user_id))
            return result.scalar_one_or_none()

    Usage outside FastAPI (e.g. DatabaseManager RPC methods) — note that this
    is an async *generator*, so it cannot be used directly with ``async with``;
    wrap it first:
        session_ctx = contextlib.asynccontextmanager(get_session)
        async with session_ctx() as session:
            result = await session.execute(select(User).where(User.id == user_id))
            return result.scalar_one_or_none()

    Lifecycle:
        1. A session is created (borrows a connection from the pool)
        2. The session is yielded to the route handler / caller
        3. On normal completion, pending changes are committed
        4. On exception, the transaction is rolled back and the error re-raised
        5. The ``async with`` block always closes the session, returning the
           connection to the pool (even on error)

    Raises:
        RuntimeError: if initialize() has not been called for this process.
    """
    if _session_factory is None:
        raise RuntimeError(
            "SQLAlchemy not initialized. Call initialize() in lifespan context."
        )

    # ``async with`` guarantees the session is closed (and its connection
    # returned to the pool) on every exit path, so no explicit close() is
    # needed in a finally block.
    async with _session_factory() as session:
        try:
            yield session  # Inject into route handler or context manager
            # If we get here, the caller succeeded - commit pending changes.
            await session.commit()
        except Exception:
            # Error occurred - roll back, then propagate the original error.
            await session.rollback()
            raise
219+
220+
221+
async def dispose() -> None:
    """
    Dispose of the engine, closing every pooled connection gracefully.

    Invoked from the FastAPI lifespan shutdown hook. Safe to call when the
    engine was never initialized (it simply does nothing).
    """
    global _engine, _session_factory

    if _engine is None:
        return

    logger.info("Disposing SQLAlchemy engine...")
    await _engine.dispose()
    _engine = None
    _session_factory = None
    logger.info("SQLAlchemy engine disposed")

autogpt_platform/backend/backend/util/settings.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ def updated_fields(self):
6565
class Config(UpdateTrackingModel["Config"], BaseSettings):
6666
"""Config for the server."""
6767

68+
database_url: str = Field(
69+
default="",
70+
description="PostgreSQL database connection URL. "
71+
"Format: postgresql://user:pass@host:port/db?schema=platform&connect_timeout=60",
72+
)
73+
6874
num_graph_workers: int = Field(
6975
default=10,
7076
ge=1,
@@ -267,6 +273,44 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
267273
description="The pool size for the scheduler database connection pool",
268274
)
269275

276+
# SQLAlchemy Configuration
277+
sqlalchemy_pool_size: int = Field(
278+
default=10,
279+
ge=1,
280+
le=100,
281+
description="Number of persistent connections in the SQLAlchemy pool. "
282+
"Guidelines: REST API (high traffic) 10-20, Background workers 3-5. "
283+
"Total across all services should not exceed PostgreSQL max_connections (default: 100).",
284+
)
285+
286+
sqlalchemy_max_overflow: int = Field(
287+
default=5,
288+
ge=0,
289+
le=50,
290+
description="Additional connections beyond pool_size when pool is exhausted. "
291+
"Total max connections = pool_size + max_overflow.",
292+
)
293+
294+
sqlalchemy_pool_timeout: int = Field(
295+
default=30,
296+
ge=1,
297+
le=300,
298+
description="Seconds to wait for available connection before raising error. "
299+
"If all connections are busy and max_overflow is reached, requests wait this long before failing.",
300+
)
301+
302+
sqlalchemy_connect_timeout: int = Field(
303+
default=10,
304+
ge=1,
305+
le=60,
306+
description="Seconds to wait when establishing new connection to PostgreSQL.",
307+
)
308+
309+
sqlalchemy_echo: bool = Field(
310+
default=False,
311+
description="Whether to log all SQL statements. Useful for debugging but very verbose. Should be False in production.",
312+
)
313+
270314
rabbitmq_host: str = Field(
271315
default="localhost",
272316
description="The host for the RabbitMQ server",

autogpt_platform/backend/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ aiohttp = "^3.10.0"
1414
aiodns = "^3.5.0"
1515
anthropic = "^0.59.0"
1616
apscheduler = "^3.11.1"
17+
asyncpg = "^0.29.0"
1718
autogpt-libs = { path = "../autogpt_libs", develop = true }
1819
bleach = { extras = ["css"], version = "^6.2.0" }
1920
click = "^8.2.0"

0 commit comments

Comments
 (0)