|
1 | | -from typing import Annotated, Generator |
| 1 | +from typing import Annotated, Generator, AsyncGenerator |
2 | 2 |
|
3 | 3 | from fastapi import Depends |
4 | 4 | from loguru import logger |
5 | 5 | from sqlalchemy.engine import Engine |
| 6 | +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine |
| 7 | +from sqlmodel.ext.asyncio.session import AsyncSession |
| 8 | +from sqlalchemy.orm import sessionmaker |
6 | 9 | from sqlalchemy.pool import StaticPool |
7 | 10 | from sqlmodel import Session, SQLModel, create_engine, MetaData |
8 | 11 | from sqlalchemy import text, schema |
9 | 12 |
|
10 | 13 | from config import Settings, get_settings |
11 | 14 |
|
12 | 15 | _engine: Engine | None = None |
| 16 | +_async_engine: AsyncEngine | None = None |
13 | 17 | _metadata: MetaData | None = None |
| 18 | +_async_sessionmaker: sessionmaker[AsyncSession] | None = None |
14 | 19 |
|
15 | 20 |
|
16 | 21 | def get_metadata( |
@@ -62,48 +67,134 @@ def get_engine(_settings: Settings | None = None) -> Engine: |
62 | 67 | if _settings is None: |
63 | 68 | _settings = get_settings() |
64 | 69 |
|
65 | | - # Configure engine based on database type |
66 | | - engine_args = {"echo": True} |
| 70 | + # Get the base URL template and credentials |
| 71 | + url = _settings.database_url_template |
67 | 72 |
|
68 | | - # Add SQLite-specific settings for in-memory database |
69 | | - if _settings.database_url.startswith("sqlite"): |
| 73 | + try: |
| 74 | + from sqlalchemy.engine import make_url |
| 75 | + |
| 76 | + # Just for validation purposes |
| 77 | + make_url(url) |
| 78 | + except ImportError: |
| 79 | + logger.warning("Could not import sqlalchemy.engine.make_url") |
| 80 | + |
| 81 | + engine_args: dict = {"echo": True} |
| 82 | + if url.startswith("sqlite"): |
70 | 83 | engine_args.update( |
71 | 84 | {"connect_args": {"check_same_thread": False}, "poolclass": StaticPool} |
72 | 85 | ) |
73 | 86 |
|
74 | | - _engine = create_engine(_settings.database_url, **engine_args) |
| 87 | + _engine = create_engine(url, **engine_args) |
75 | 88 |
|
76 | | - # Enable WAL mode if configured |
77 | | - if _settings.sqlite_wal_mode: |
| 89 | + if _settings.sqlite_wal_mode and url.startswith("sqlite"): |
78 | 90 | with _engine.connect() as conn: |
79 | | - # https://www.sqlite.org/pragma.html |
80 | 91 | conn.execute(text("PRAGMA journal_mode=WAL")) |
81 | | - # conn.execute(text("PRAGMA synchronous=OFF")) |
82 | 92 | logger.info("SQLite WAL mode enabled") |
83 | 93 |
|
84 | | - # Initialize schema if using SQLModel, with a schema if not sqlite |
85 | | - if _settings.database_type == "sqlmodel": |
86 | | - SQLModel.metadata.schema = _settings.get_table_schema |
87 | | - with _engine.connect() as conn: |
88 | | - """ conn.execution_options = { |
89 | | - "schema_translate_map": {None: _settings.get_table_schema} |
90 | | - } """ |
91 | | - if not _settings.database_url.startswith( |
92 | | - "sqlite" |
93 | | - ) and not conn.dialect.has_schema(conn, _settings.get_table_schema): |
94 | | - logger.warning( |
95 | | - f"Schema '{_settings.get_table_schema}' not found in database. Creating..." |
96 | | - ) |
97 | | - conn.execute(schema.CreateSchema(_settings.get_table_schema)) |
98 | | - conn.commit() |
99 | | - |
100 | | - if not _settings.migrate_database: |
101 | | - SQLModel.metadata.create_all(conn) |
102 | | - conn.commit() |
103 | | - logger.info("Database tables created successfully") |
104 | | - else: |
105 | | - logger.info( |
106 | | - "Database tables already exist or migration is configured, skipping creation" |
107 | | - ) |
| 94 | + # Set schema for SQLModel metadata |
| 95 | + SQLModel.metadata.schema = _settings.get_table_schema |
| 96 | + with _engine.connect() as conn: |
| 97 | + if not url.startswith("sqlite") and not conn.dialect.has_schema( |
| 98 | + conn, |
| 99 | + _settings.get_table_schema, # type: ignore[arg-type] |
| 100 | + ): |
| 101 | + logger.warning( |
| 102 | + f"Schema '{_settings.get_table_schema}' not found in database. Creating..." |
| 103 | + ) |
| 104 | + conn.execute(schema.CreateSchema(_settings.get_table_schema)) |
| 105 | + conn.commit() |
| 106 | + |
| 107 | + if _settings.create_tables: |
| 108 | + SQLModel.metadata.create_all(conn) |
| 109 | + conn.commit() |
| 110 | + logger.info("Database tables created successfully") |
| 111 | + else: |
| 112 | + logger.info( |
| 113 | + "Database tables already exist or migration is configured, skipping creation" |
| 114 | + ) |
108 | 115 |
|
109 | 116 | return _engine |
| 117 | + |
| 118 | + |
| 119 | +def get_async_engine(_settings: Settings | None = None) -> AsyncEngine: |
| 120 | + """Get or create async SQLModel engine instance.""" |
| 121 | + global _async_engine, _async_sessionmaker |
| 122 | + |
| 123 | + if _async_engine is not None: |
| 124 | + return _async_engine |
| 125 | + |
| 126 | + if _settings is None: |
| 127 | + _settings = get_settings() |
| 128 | + |
| 129 | + url = _settings.database_url |
| 130 | + engine_args: dict = {"echo": True} |
| 131 | + |
| 132 | + # SQLite: use the aiosqlite driver for async support |
| 133 | + if url.startswith("sqlite") and "+aiosqlite" not in url: |
| 134 | + url = url.replace("sqlite://", "sqlite+aiosqlite://", 1) |
| 135 | + engine_args.update( |
| 136 | + {"connect_args": {"check_same_thread": False}, "poolclass": StaticPool} |
| 137 | + ) |
| 138 | + else: |
| 139 | + # PostgreSQL: ensure an asyncpg driver if scheme is postgres/postgresql without a +driver suffix |
| 140 | + try: |
| 141 | + scheme, rest = url.split("://", 1) |
| 142 | + except ValueError: |
| 143 | + scheme = url |
| 144 | + rest = "" |
| 145 | + |
| 146 | + # Handle SSL mode for asyncpg |
| 147 | + connect_args = {} |
| 148 | + if "?" in rest: |
| 149 | + base_url, query_string = rest.split("?", 1) |
| 150 | + params = {} |
| 151 | + for param in query_string.split("&"): |
| 152 | + if "=" in param: |
| 153 | + key, value = param.split("=", 1) |
| 154 | + params[key] = value |
| 155 | + |
| 156 | + # Remove sslmode from URL and add as connect_args for asyncpg |
| 157 | + if "sslmode" in params: |
| 158 | + sslmode = params.pop("sslmode") |
| 159 | + if sslmode == "disable": |
| 160 | + connect_args["ssl"] = False |
| 161 | + elif sslmode in ("require", "verify-ca", "verify-full"): |
| 162 | + connect_args["ssl"] = True |
| 163 | + |
| 164 | + # Rebuild the URL without sslmode |
| 165 | + new_query = "&".join([f"{k}={v}" for k, v in params.items()]) |
| 166 | + if new_query: |
| 167 | + rest = f"{base_url}?{new_query}" |
| 168 | + else: |
| 169 | + rest = base_url |
| 170 | + |
| 171 | + if connect_args: |
| 172 | + engine_args["connect_args"] = connect_args |
| 173 | + |
| 174 | + if scheme in ("postgres", "postgresql") and "+" not in scheme: |
| 175 | + url = f"postgresql+asyncpg://{rest}" |
| 176 | + |
| 177 | + engine_args: dict = {"echo": True} |
| 178 | + if url.startswith("sqlite"): |
| 179 | + engine_args.update( |
| 180 | + {"connect_args": {"check_same_thread": False}, "poolclass": StaticPool} |
| 181 | + ) |
| 182 | + |
| 183 | + _async_engine = create_async_engine(url, **engine_args) |
| 184 | + _async_sessionmaker = sessionmaker( |
| 185 | + _async_engine, class_=AsyncSession, expire_on_commit=False |
| 186 | + ) |
| 187 | + return _async_engine |
| 188 | + |
| 189 | + |
| 190 | +async def get_async_session( |
| 191 | + settings: Annotated[Settings, Depends(get_settings)], |
| 192 | +) -> AsyncGenerator[AsyncSession, None]: |
| 193 | + # Initialize the async engine if it doesn't exist yet |
| 194 | + get_async_engine(settings) |
| 195 | + assert _async_sessionmaker is not None |
| 196 | + async with _async_sessionmaker() as session: |
| 197 | + yield session |
| 198 | + |
| 199 | + |
| 200 | +AsyncSessionDep = Annotated[AsyncSession, Depends(get_async_session)] |
0 commit comments