|
6 | 6 | """ |
7 | 7 |
|
8 | 8 | import logging |
| 9 | +import re |
9 | 10 | from typing import Any |
| 11 | +from urllib.parse import urlparse |
10 | 12 |
|
11 | 13 | from a2a.server.apps import A2AFastAPIApplication |
12 | 14 | from a2a.server.request_handlers import DefaultRequestHandler |
|
28 | 30 | logger = logging.getLogger(__name__) |
29 | 31 |
|
30 | 32 |
|
| 33 | +def _normalize_db_url(url: str) -> str: |
| 34 | + """Ensure a database URL uses an async driver for SQLAlchemy. |
| 35 | +
|
| 36 | + ADK's DatabaseSessionService requires ``create_async_engine``, so any |
| 37 | + synchronous PostgreSQL scheme must be replaced with ``postgresql+asyncpg``. |
| 38 | + """ |
| 39 | + scheme, remainder = url.split("://", 1) |
| 40 | + normalized_scheme = scheme.lower() |
| 41 | + |
| 42 | + sync_postgres_schemes = { |
| 43 | + "postgres", |
| 44 | + "postgresql", |
| 45 | + "postgresql+psycopg", |
| 46 | + "postgresql+psycopg2", |
| 47 | + } |
| 48 | + if normalized_scheme in sync_postgres_schemes: |
| 49 | + return f"postgresql+asyncpg://{remainder}" |
| 50 | + |
| 51 | + return url |
| 52 | + |
| 53 | + |
31 | 54 | def _get_session_service() -> Any: |
32 | | - """Get the appropriate session service based on configuration. |
| 55 | + """Get the appropriate session service based on SESSION_BACKEND setting. |
| 56 | +
|
| 57 | + Uses SESSION_BACKEND to determine the session storage: |
| 58 | + - ``"memory"``: InMemorySessionService (default, no persistence) |
| 59 | + - ``"database"``: DatabaseSessionService (requires SESSION_DATABASE_URL) |
33 | 60 |
|
34 | | - For production, uses DatabaseSessionService which persists sessions to PostgreSQL. |
35 | | - For development, uses InMemorySessionService. |
| 61 | + When SESSION_BACKEND is ``"database"``, failures are raised immediately |
| 62 | + rather than silently falling back to in-memory, so misconfigurations are |
| 63 | + caught at startup. |
36 | 64 |
|
37 | 65 | Security Note: |
38 | | - SESSION_DATABASE_URL must be explicitly set to use database persistence. |
39 | | - This prevents accidental use of the marketplace database (DATABASE_URL) |
40 | | - for session storage, ensuring: |
41 | | - - Agents only have access to session data, not marketplace/auth data |
42 | | - - Compromised agents can't access DCR credentials or order information |
43 | | - - Different retention policies can be applied to sessions vs. marketplace data |
| 66 | + SESSION_DATABASE_URL should point to a separate database from |
| 67 | + DATABASE_URL to ensure agents only access session data, not |
| 68 | + marketplace/auth data. |
44 | 69 |
|
45 | 70 | Returns: |
46 | 71 | Session service instance (DatabaseSessionService or InMemorySessionService). |
47 | 72 | """ |
48 | 73 | settings = get_settings() |
49 | 74 |
|
50 | | - # Only use database session service if SESSION_DATABASE_URL is explicitly set |
51 | | - # Do NOT fall back to DATABASE_URL to avoid mixing session and marketplace data |
52 | | - session_db_url = settings.session_database_url |
| 75 | + if settings.session_backend == "database": |
| 76 | + from google.adk.sessions import DatabaseSessionService |
| 77 | + |
| 78 | + # SESSION_DATABASE_URL is guaranteed non-empty by the model validator |
| 79 | + db_url = _normalize_db_url(settings.session_database_url) |
| 80 | + |
| 81 | + # Log which database is being used (without credentials) |
| 82 | + parsed = urlparse(db_url) |
| 83 | + db_host = parsed.hostname or parsed.query or "local" |
| 84 | + logger.info( |
| 85 | + "Using DatabaseSessionService for session persistence (host=%s)", |
| 86 | + db_host, |
| 87 | + ) |
53 | 88 |
|
54 | | - # Use database session service for production (non-SQLite databases) |
55 | | - if session_db_url and not session_db_url.startswith("sqlite"): |
56 | 89 | try: |
57 | | - from google.adk.sessions import DatabaseSessionService |
58 | | - |
59 | | - # ADK's DatabaseSessionService uses synchronous SQLAlchemy, |
60 | | - # so we need to convert the async URL to sync format |
61 | | - db_url = session_db_url |
62 | | - if "postgresql+asyncpg" in db_url: |
63 | | - # Convert asyncpg URL to sync psycopg2 format |
64 | | - db_url = db_url.replace("postgresql+asyncpg", "postgresql+psycopg2") |
65 | | - elif "postgresql+aiopg" in db_url: |
66 | | - db_url = db_url.replace("postgresql+aiopg", "postgresql+psycopg2") |
67 | | - |
68 | | - # Log which database is being used (without credentials) |
69 | | - db_host = db_url.split("@")[-1].split("/")[0] if "@" in db_url else "local" |
70 | | - logger.info( |
71 | | - "Using DatabaseSessionService for session persistence (host=%s)", |
72 | | - db_host, |
73 | | - ) |
74 | 90 | return DatabaseSessionService(db_url=db_url) |
75 | | - except ImportError as e: |
76 | | - logger.warning( |
77 | | - "DatabaseSessionService not available (%s), falling back to InMemorySessionService", |
78 | | - e, |
79 | | - ) |
80 | 91 | except Exception as e: |
81 | | - logger.warning( |
82 | | - "Failed to initialize DatabaseSessionService (%s), " |
83 | | - "falling back to InMemorySessionService", |
84 | | - e, |
85 | | - ) |
| 92 | + # Sanitize error message to avoid leaking credentials from URLs |
| 93 | + sanitized_msg = re.sub(r"://[^@]+@", "://***@", str(e)) |
| 94 | + raise RuntimeError( |
| 95 | + f"Failed to initialize DatabaseSessionService: {sanitized_msg}" |
| 96 | + ) from None |
86 | 97 |
|
87 | 98 | logger.info("Using InMemorySessionService for session management") |
88 | 99 | return InMemorySessionService() # type: ignore[no-untyped-call] |
|
0 commit comments