diff --git a/sqlspec/adapters/aiosqlite/pool.py b/sqlspec/adapters/aiosqlite/pool.py index 20eaf28b7..dd6b53e7f 100644 --- a/sqlspec/adapters/aiosqlite/pool.py +++ b/sqlspec/adapters/aiosqlite/pool.py @@ -12,7 +12,6 @@ from sqlspec.utils.logging import get_logger if TYPE_CHECKING: - import threading from collections.abc import AsyncGenerator from sqlspec.adapters.aiosqlite._types import AiosqliteConnection @@ -38,7 +37,7 @@ class AiosqliteConnectTimeoutError(SQLSpecError): class AiosqlitePoolConnection: """Wrapper for database connections in the pool.""" - __slots__ = ("_closed", "connection", "id", "idle_since") + __slots__ = ("_closed", "_healthy", "connection", "id", "idle_since") def __init__(self, connection: "AiosqliteConnection") -> None: """Initialize pool connection wrapper. @@ -50,6 +49,7 @@ def __init__(self, connection: "AiosqliteConnection") -> None: self.connection = connection self.idle_since: float | None = None self._closed = False + self._healthy = True @property def idle_time(self) -> float: @@ -71,6 +71,15 @@ def is_closed(self) -> bool: """ return self._closed + @property + def is_healthy(self) -> bool: + """Check if connection was healthy on last check. + + Returns: + True if connection is presumed healthy + """ + return self._healthy and not self._closed + def mark_as_in_use(self) -> None: """Mark connection as in use.""" self.idle_since = None @@ -79,6 +88,10 @@ def mark_as_idle(self) -> None: """Mark connection as idle.""" self.idle_since = time.time() + def mark_unhealthy(self) -> None: + """Mark connection as unhealthy.""" + self._healthy = False + async def is_alive(self) -> bool: """Check if connection is alive and functional. @@ -86,12 +99,15 @@ async def is_alive(self) -> bool: True if connection is healthy """ if self._closed: + self._healthy = False return False try: await self.connection.execute("SELECT 1") except Exception: + self._healthy = False return False else: + self._healthy = True return True async def reset(self) -> None: @@ -102,11 +118,7 @@ async def reset(self) -> None: await self.connection.rollback() async def close(self) -> None: - """Close the connection. - - Since we use daemon threads, the connection will be terminated - when the process exits even if close fails. - """ + """Close the connection.""" if self._closed: return try: @@ -127,41 +139,49 @@ class AiosqliteConnectionPool: "_connect_timeout", "_connection_parameters", "_connection_registry", + "_health_check_interval", "_idle_timeout", "_lock_instance", + "_min_size", "_operation_timeout", "_pool_size", "_queue_instance", - "_tracked_threads", "_wal_initialized", + "_warmed", ) def __init__( self, connection_parameters: "dict[str, Any]", pool_size: int = 5, + min_size: int = 0, connect_timeout: float = 30.0, idle_timeout: float = 24 * 60 * 60, operation_timeout: float = 10.0, + health_check_interval: float = 30.0, ) -> None: """Initialize connection pool. Args: connection_parameters: SQLite connection parameters pool_size: Maximum number of connections in the pool + min_size: Minimum connections to pre-create (pool warming) connect_timeout: Maximum time to wait for connection acquisition idle_timeout: Maximum time a connection can remain idle operation_timeout: Maximum time for connection operations + health_check_interval: Seconds of idle time before running health check """ self._connection_parameters = connection_parameters self._pool_size = pool_size + self._min_size = min(min_size, pool_size) self._connect_timeout = connect_timeout self._idle_timeout = idle_timeout self._operation_timeout = operation_timeout + self._health_check_interval = health_check_interval self._connection_registry: dict[str, AiosqlitePoolConnection] = {} - self._tracked_threads: set[threading.Thread | AiosqliteConnection] = set() self._wal_initialized = False + self._warmed = False self._queue_instance: asyncio.Queue[AiosqlitePoolConnection] | None = None self._lock_instance: asyncio.Lock | None = None @@ -215,23 +235,13 @@ def checked_out(self) -> int: return len(self._connection_registry) return len(self._connection_registry) - self._queue.qsize() - def _track_aiosqlite_thread(self, connection: "AiosqliteConnection") -> None: - """Track the background thread associated with an aiosqlite connection. - - Args: - connection: The aiosqlite connection whose thread to track - """ - self._tracked_threads.add(connection) - async def _create_connection(self) -> AiosqlitePoolConnection: """Create a new connection. Returns: New pool connection instance """ - connection = aiosqlite.connect(**self._connection_parameters) - connection.daemon = True - connection = await connection + connection = await aiosqlite.connect(**self._connection_parameters) database_path = str(self._connection_parameters.get("database", "")) is_shared_cache = "cache=shared" in database_path @@ -266,7 +276,6 @@ async def _create_connection(self) -> AiosqlitePoolConnection: pool_connection = AiosqlitePoolConnection(connection) pool_connection.mark_as_idle() - self._track_aiosqlite_thread(connection) async with self._lock: self._connection_registry[pool_connection.id] = pool_connection @@ -277,6 +286,10 @@ async def _create_connection(self) -> AiosqlitePoolConnection: async def _claim_if_healthy(self, connection: AiosqlitePoolConnection) -> bool: """Check if connection is healthy and claim it. + Uses passive health checks: connections idle less than health_check_interval + are assumed healthy based on their last known state. Active health checks + (SELECT 1) are only performed on long-idle connections. + Args: connection: Connection to check and claim @@ -288,15 +301,25 @@ async def _claim_if_healthy(self, connection: AiosqlitePoolConnection) -> bool: await self._retire_connection(connection) return False - try: - await asyncio.wait_for(connection.is_alive(), timeout=self._operation_timeout) - except asyncio.TimeoutError: - logger.debug("Connection %s health check timed out, retiring", connection.id) + if not connection.is_healthy: + logger.debug("Connection %s marked unhealthy, retiring", connection.id) await self._retire_connection(connection) return False - else: - connection.mark_as_in_use() - return True + + if connection.idle_time > self._health_check_interval: + try: + is_alive = await asyncio.wait_for(connection.is_alive(), timeout=self._operation_timeout) + if not is_alive: + logger.debug("Connection %s failed health check, retiring", connection.id) + await self._retire_connection(connection) + return False + except asyncio.TimeoutError: + logger.debug("Connection %s health check timed out, retiring", connection.id) + await self._retire_connection(connection) + return False + + connection.mark_as_in_use() + return True async def _retire_connection(self, connection: AiosqlitePoolConnection) -> None: """Retire a connection from the pool. @@ -363,6 +386,31 @@ async def _wait_for_healthy_connection(self) -> AiosqlitePoolConnection: with suppress(asyncio.CancelledError): await task + async def _warm_pool(self) -> None: + """Pre-create minimum connections for pool warming. + + Creates connections up to min_size to avoid cold-start latency + on first requests. + """ + if self._warmed or self._min_size <= 0: + return + + self._warmed = True + connections_needed = self._min_size - len(self._connection_registry) + + if connections_needed <= 0: + return + + logger.debug("Warming pool with %d connections", connections_needed) + tasks = [self._create_connection() for _ in range(connections_needed)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + for result in results: + if isinstance(result, AiosqlitePoolConnection): + self._queue.put_nowait(result) + elif isinstance(result, Exception): + logger.warning("Failed to create warm connection: %s", result) + async def _get_connection(self) -> AiosqlitePoolConnection: """Run the three-phase connection acquisition cycle. @@ -376,6 +424,9 @@ async def _get_connection(self) -> AiosqlitePoolConnection: msg = "Cannot acquire connection from closed pool" raise AiosqlitePoolClosedError(msg) + if not self._warmed and self._min_size > 0: + await self._warm_pool() + while not self._queue.empty(): connection = self._queue.get_nowait() if await self._claim_if_healthy(connection): @@ -387,38 +438,6 @@ async def _get_connection(self) -> AiosqlitePoolConnection: return await self._wait_for_healthy_connection() - async def _wait_for_threads_to_terminate(self, timeout: float = 1.0) -> None: - """Wait for all tracked aiosqlite connection threads to terminate. - - Args: - timeout: Maximum time to wait for thread termination in seconds - """ - if not self._tracked_threads: - return - - logger.debug("Waiting for %d aiosqlite connection threads to terminate...", len(self._tracked_threads)) - start_time = time.time() - - dead_threads = {t for t in self._tracked_threads if not t.is_alive()} - self._tracked_threads -= dead_threads - - if not self._tracked_threads: - logger.debug("All aiosqlite connection threads already terminated") - return - - while self._tracked_threads and (time.time() - start_time) < timeout: - await asyncio.sleep(0.05) - dead_threads = {t for t in self._tracked_threads if not t.is_alive()} - self._tracked_threads -= dead_threads - - remaining_threads = len(self._tracked_threads) - elapsed = time.time() - start_time - - if remaining_threads > 0: - logger.debug("%d aiosqlite threads still running after %.2fs", remaining_threads, elapsed) - else: - logger.debug("All aiosqlite connection threads terminated in %.2fs", elapsed) - async def acquire(self) -> AiosqlitePoolConnection: """Acquire a connection from the pool. @@ -459,6 +478,7 @@ async def release(self, connection: AiosqlitePoolConnection) -> None: logger.debug("Released connection back to pool: %s", connection.id) except Exception as e: logger.warning("Failed to reset connection %s during release: %s", connection.id, e) + connection.mark_unhealthy() await self._retire_connection(connection) @asynccontextmanager @@ -496,5 +516,4 @@ async def close(self) -> None: if isinstance(result, Exception): logger.warning("Error closing connection %s: %s", connections[i].id, result) - await self._wait_for_threads_to_terminate(timeout=1.0) logger.debug("Aiosqlite connection pool closed") diff --git a/sqlspec/adapters/asyncmy/config.py b/sqlspec/adapters/asyncmy/config.py index c53a947eb..531fd06af 100644 --- a/sqlspec/adapters/asyncmy/config.py +++ b/sqlspec/adapters/asyncmy/config.py @@ -19,7 +19,6 @@ ) from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs from sqlspec.utils.config_normalization import apply_pool_deprecations, normalize_connection_config -from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json if TYPE_CHECKING: @@ -34,8 +33,6 @@ __all__ = ("AsyncmyConfig", "AsyncmyConnectionParams", "AsyncmyDriverFeatures", "AsyncmyPoolParams") -logger = get_logger("adapters.asyncmy") - class AsyncmyConnectionParams(TypedDict): """Asyncmy connection parameters.""" diff --git a/sqlspec/adapters/asyncpg/config.py b/sqlspec/adapters/asyncpg/config.py index bfbf25bb1..07980a282 100644 --- a/sqlspec/adapters/asyncpg/config.py +++ b/sqlspec/adapters/asyncpg/config.py @@ -23,7 +23,6 @@ from sqlspec.exceptions import ImproperConfigurationError, MissingDependencyError from sqlspec.typing import ALLOYDB_CONNECTOR_INSTALLED, CLOUD_SQL_CONNECTOR_INSTALLED, PGVECTOR_INSTALLED from sqlspec.utils.config_normalization import apply_pool_deprecations, normalize_connection_config -from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json if TYPE_CHECKING: @@ -36,8 +35,6 @@ __all__ = ("AsyncpgConfig", "AsyncpgConnectionConfig", "AsyncpgDriverFeatures", "AsyncpgPoolConfig") -logger = get_logger("adapters.asyncpg") - class AsyncpgConnectionConfig(TypedDict): """TypedDict for AsyncPG connection parameters.""" diff --git a/sqlspec/adapters/bigquery/config.py b/sqlspec/adapters/bigquery/config.py index 383b166b0..fad1aa0ef 100644 --- a/sqlspec/adapters/bigquery/config.py +++ b/sqlspec/adapters/bigquery/config.py @@ -18,7 +18,6 @@ from sqlspec.observability import ObservabilityConfig from sqlspec.typing import Empty from sqlspec.utils.config_normalization import normalize_connection_config -from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import to_json if TYPE_CHECKING: @@ -31,9 +30,6 @@ from sqlspec.core import StatementConfig -logger = get_logger("adapters.bigquery") - - class BigQueryConnectionParams(TypedDict): """Standard BigQuery connection parameters. diff --git a/sqlspec/adapters/duckdb/pool.py b/sqlspec/adapters/duckdb/pool.py index cd9c2e86a..edd5e79ab 100644 --- a/sqlspec/adapters/duckdb/pool.py +++ b/sqlspec/adapters/duckdb/pool.py @@ -20,6 +20,7 @@ DEFAULT_MAX_POOL: Final[int] = 4 POOL_TIMEOUT: Final[float] = 30.0 POOL_RECYCLE: Final[int] = 86400 +HEALTH_CHECK_INTERVAL: Final[float] = 30.0 __all__ = ("DuckDBConnectionPool",) @@ -41,6 +42,7 @@ class DuckDBConnectionPool: "_created_connections", "_extension_flags", "_extensions", + "_health_check_interval", "_lock", "_on_connection_create", "_recycle", @@ -52,6 +54,7 @@ def __init__( self, connection_config: "dict[str, Any]", pool_recycle_seconds: int = POOL_RECYCLE, + health_check_interval: float = HEALTH_CHECK_INTERVAL, extensions: "list[dict[str, Any]] | None" = None, extension_flags: "dict[str, Any] | None" = None, secrets: "list[dict[str, Any]] | None" = None, @@ -63,6 +66,7 @@ def __init__( Args: connection_config: DuckDB connection configuration pool_recycle_seconds: Connection recycle time in seconds + health_check_interval: Seconds of idle time before running health check extensions: List of extensions to install/load extension_flags: Connection-level SET statements applied after creation secrets: List of secrets to create @@ -71,6 +75,7 @@ def __init__( """ self._connection_config = connection_config self._recycle = pool_recycle_seconds + self._health_check_interval = health_check_interval self._extensions = extensions or [] self._extension_flags = extension_flags or {} self._secrets = secrets or [] @@ -191,13 +196,26 @@ def _get_thread_connection(self) -> DuckDBConnection: if not hasattr(self._thread_local, "connection"): self._thread_local.connection = self._create_connection() self._thread_local.created_at = time.time() + self._thread_local.last_used = time.time() + return cast("DuckDBConnection", self._thread_local.connection) if self._recycle > 0 and time.time() - self._thread_local.created_at > self._recycle: with suppress(Exception): self._thread_local.connection.close() self._thread_local.connection = self._create_connection() self._thread_local.created_at = time.time() + self._thread_local.last_used = time.time() + return cast("DuckDBConnection", self._thread_local.connection) + idle_time = time.time() - getattr(self._thread_local, "last_used", 0) + if idle_time > self._health_check_interval and not self._is_connection_alive(self._thread_local.connection): + logger.debug("DuckDB connection failed health check after %.1fs idle, recreating", idle_time) + with suppress(Exception): + self._thread_local.connection.close() + self._thread_local.connection = self._create_connection() + self._thread_local.created_at = time.time() + + self._thread_local.last_used = time.time() return cast("DuckDBConnection", self._thread_local.connection) def _close_thread_connection(self) -> None: @@ -208,6 +226,8 @@ def _close_thread_connection(self) -> None: del self._thread_local.connection if hasattr(self._thread_local, "created_at"): del self._thread_local.created_at + if hasattr(self._thread_local, "last_used"): + del self._thread_local.last_used def _is_connection_alive(self, connection: DuckDBConnection) -> bool: """Check if a connection is still alive and usable. diff --git a/sqlspec/adapters/oracledb/config.py b/sqlspec/adapters/oracledb/config.py index 0741757c3..78b2731ff 100644 --- a/sqlspec/adapters/oracledb/config.py +++ b/sqlspec/adapters/oracledb/config.py @@ -28,7 +28,6 @@ from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs, SyncDatabaseConfig from sqlspec.typing import NUMPY_INSTALLED from sqlspec.utils.config_normalization import apply_pool_deprecations, normalize_connection_config -from sqlspec.utils.logging import get_logger if TYPE_CHECKING: from collections.abc import AsyncGenerator, Callable, Generator @@ -46,8 +45,6 @@ "OracleSyncConfig", ) -logger = get_logger("adapters.oracledb") - class OracleConnectionParams(TypedDict): """OracleDB connection parameters.""" diff --git a/sqlspec/adapters/psqlpy/config.py b/sqlspec/adapters/psqlpy/config.py index 8705faec5..166367c38 100644 --- a/sqlspec/adapters/psqlpy/config.py +++ b/sqlspec/adapters/psqlpy/config.py @@ -18,16 +18,12 @@ from sqlspec.core import StatementConfig from sqlspec.typing import PGVECTOR_INSTALLED from sqlspec.utils.config_normalization import apply_pool_deprecations, normalize_connection_config -from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import to_json if TYPE_CHECKING: from collections.abc import Callable -logger = get_logger("adapters.psqlpy") - - class PsqlpyConnectionParams(TypedDict): """Psqlpy connection parameters.""" diff --git a/sqlspec/adapters/psycopg/config.py b/sqlspec/adapters/psycopg/config.py index 0a19fe79a..d9f5c710d 100644 --- a/sqlspec/adapters/psycopg/config.py +++ b/sqlspec/adapters/psycopg/config.py @@ -23,7 +23,6 @@ from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs, SyncDatabaseConfig from sqlspec.typing import PGVECTOR_INSTALLED from sqlspec.utils.config_normalization import apply_pool_deprecations, normalize_connection_config -from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import to_json if TYPE_CHECKING: @@ -32,9 +31,6 @@ from sqlspec.core import StatementConfig -logger = get_logger("adapters.psycopg") - - class PsycopgConnectionParams(TypedDict): """Psycopg connection parameters.""" diff --git a/sqlspec/adapters/sqlite/pool.py b/sqlspec/adapters/sqlite/pool.py index 571383c81..b7f7f3683 100644 --- a/sqlspec/adapters/sqlite/pool.py +++ b/sqlspec/adapters/sqlite/pool.py @@ -3,12 +3,14 @@ import contextlib import sqlite3 import threading +import time from contextlib import contextmanager from typing import TYPE_CHECKING, Any, TypedDict, cast from typing_extensions import NotRequired from sqlspec.adapters.sqlite._types import SqliteConnection +from sqlspec.utils.logging import get_logger if TYPE_CHECKING: from collections.abc import Generator @@ -29,6 +31,8 @@ class SqliteConnectionParams(TypedDict): __all__ = ("SqliteConnectionPool",) +logger = get_logger(__name__) + class SqliteConnectionPool: """Thread-local connection manager for SQLite. @@ -38,16 +42,29 @@ class SqliteConnectionPool: efficient than a traditional pool for SQLite's constraints. """ - __slots__ = ("_connection_parameters", "_enable_optimizations", "_thread_local") + __slots__ = ( + "_connection_parameters", + "_enable_optimizations", + "_health_check_interval", + "_recycle_seconds", + "_thread_local", + ) def __init__( - self, connection_parameters: "dict[str, Any]", enable_optimizations: bool = True, **kwargs: Any + self, + connection_parameters: "dict[str, Any]", + enable_optimizations: bool = True, + recycle_seconds: int = 86400, + health_check_interval: float = 30.0, + **kwargs: Any, ) -> None: """Initialize the thread-local connection manager. Args: connection_parameters: SQLite connection parameters enable_optimizations: Whether to apply performance PRAGMAs + recycle_seconds: Connection recycle time in seconds (default 24h) + health_check_interval: Seconds of idle time before running health check **kwargs: Ignored pool parameters for compatibility """ if "check_same_thread" not in connection_parameters: @@ -55,6 +72,8 @@ def __init__( self._connection_parameters = connection_parameters self._thread_local = threading.local() self._enable_optimizations = enable_optimizations + self._recycle_seconds = recycle_seconds + self._health_check_interval = health_check_interval def _create_connection(self) -> SqliteConnection: """Create a new SQLite connection with optimizations.""" @@ -62,35 +81,75 @@ def _create_connection(self) -> SqliteConnection: if self._enable_optimizations: database = self._connection_parameters.get("database", ":memory:") - is_memory = database == ":memory:" or "mode=memory" in database - - if not is_memory: - connection.execute("PRAGMA journal_mode = DELETE") + is_memory = database == ":memory:" or "mode=memory" in str(database) + + if is_memory: + connection.execute("PRAGMA journal_mode = MEMORY") + connection.execute("PRAGMA synchronous = OFF") + connection.execute("PRAGMA temp_store = MEMORY") + else: + connection.execute("PRAGMA journal_mode = WAL") + connection.execute("PRAGMA synchronous = NORMAL") connection.execute("PRAGMA busy_timeout = 5000") - connection.execute("PRAGMA optimize") connection.execute("PRAGMA foreign_keys = ON") - connection.execute("PRAGMA synchronous = NORMAL") + logger.debug("Created SQLite connection for thread %s", threading.current_thread().name) return connection # type: ignore[no-any-return] + def _is_connection_alive(self, connection: SqliteConnection) -> bool: + """Check if a connection is still alive and usable. + + Args: + connection: Connection to check + + Returns: + True if connection is alive, False otherwise + """ + try: + connection.execute("SELECT 1") + except Exception: + return False + return True + def _get_thread_connection(self) -> SqliteConnection: """Get or create a connection for the current thread.""" - try: + if not hasattr(self._thread_local, "connection"): + self._thread_local.connection = self._create_connection() + self._thread_local.created_at = time.time() + self._thread_local.last_used = time.time() return cast("SqliteConnection", self._thread_local.connection) - except AttributeError: - connection = self._create_connection() - self._thread_local.connection = connection - return connection + + if self._recycle_seconds > 0 and time.time() - self._thread_local.created_at > self._recycle_seconds: + logger.debug("SQLite connection exceeded recycle time, recreating") + with contextlib.suppress(Exception): + self._thread_local.connection.close() + self._thread_local.connection = self._create_connection() + self._thread_local.created_at = time.time() + self._thread_local.last_used = time.time() + return cast("SqliteConnection", self._thread_local.connection) + + idle_time = time.time() - getattr(self._thread_local, "last_used", 0) + if idle_time > self._health_check_interval and not self._is_connection_alive(self._thread_local.connection): + logger.debug("SQLite connection failed health check after %.1fs idle, recreating", idle_time) + with contextlib.suppress(Exception): + self._thread_local.connection.close() + self._thread_local.connection = self._create_connection() + self._thread_local.created_at = time.time() + + self._thread_local.last_used = time.time() + return cast("SqliteConnection", self._thread_local.connection) def _close_thread_connection(self) -> None: """Close the connection for the current thread.""" - try: - connection = self._thread_local.connection - connection.close() + if hasattr(self._thread_local, "connection"): + with contextlib.suppress(Exception): + self._thread_local.connection.close() del self._thread_local.connection - except AttributeError: - pass + if hasattr(self._thread_local, "created_at"): + del self._thread_local.created_at + if hasattr(self._thread_local, "last_used"): + del self._thread_local.last_used @contextmanager def get_connection(self) -> "Generator[SqliteConnection, None, None]":