Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 80 additions & 61 deletions sqlspec/adapters/aiosqlite/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from sqlspec.utils.logging import get_logger

if TYPE_CHECKING:
import threading
from collections.abc import AsyncGenerator

from sqlspec.adapters.aiosqlite._types import AiosqliteConnection
Expand All @@ -38,7 +37,7 @@ class AiosqliteConnectTimeoutError(SQLSpecError):
class AiosqlitePoolConnection:
"""Wrapper for database connections in the pool."""

__slots__ = ("_closed", "connection", "id", "idle_since")
__slots__ = ("_closed", "_healthy", "connection", "id", "idle_since")

def __init__(self, connection: "AiosqliteConnection") -> None:
"""Initialize pool connection wrapper.
Expand All @@ -50,6 +49,7 @@ def __init__(self, connection: "AiosqliteConnection") -> None:
self.connection = connection
self.idle_since: float | None = None
self._closed = False
self._healthy = True

@property
def idle_time(self) -> float:
Expand All @@ -71,6 +71,15 @@ def is_closed(self) -> bool:
"""
return self._closed

@property
def is_healthy(self) -> bool:
"""Check if connection was healthy on last check.

Returns:
True if connection is presumed healthy
"""
return self._healthy and not self._closed

def mark_as_in_use(self) -> None:
"""Mark connection as in use."""
self.idle_since = None
Expand All @@ -79,19 +88,26 @@ def mark_as_idle(self) -> None:
"""Mark connection as idle."""
self.idle_since = time.time()

def mark_unhealthy(self) -> None:
"""Mark connection as unhealthy."""
self._healthy = False

async def is_alive(self) -> bool:
"""Check if connection is alive and functional.

Returns:
True if connection is healthy
"""
if self._closed:
self._healthy = False
return False
try:
await self.connection.execute("SELECT 1")
except Exception:
self._healthy = False
return False
else:
self._healthy = True
return True

async def reset(self) -> None:
Expand All @@ -102,11 +118,7 @@ async def reset(self) -> None:
await self.connection.rollback()

async def close(self) -> None:
"""Close the connection.

Since we use daemon threads, the connection will be terminated
when the process exits even if close fails.
"""
"""Close the connection."""
if self._closed:
return
try:
Expand All @@ -127,41 +139,49 @@ class AiosqliteConnectionPool:
"_connect_timeout",
"_connection_parameters",
"_connection_registry",
"_health_check_interval",
"_idle_timeout",
"_lock_instance",
"_min_size",
"_operation_timeout",
"_pool_size",
"_queue_instance",
"_tracked_threads",
"_wal_initialized",
"_warmed",
)

def __init__(
self,
connection_parameters: "dict[str, Any]",
pool_size: int = 5,
min_size: int = 0,
connect_timeout: float = 30.0,
idle_timeout: float = 24 * 60 * 60,
operation_timeout: float = 10.0,
health_check_interval: float = 30.0,
) -> None:
"""Initialize connection pool.

Args:
connection_parameters: SQLite connection parameters
pool_size: Maximum number of connections in the pool
min_size: Minimum connections to pre-create (pool warming)
connect_timeout: Maximum time to wait for connection acquisition
idle_timeout: Maximum time a connection can remain idle
operation_timeout: Maximum time for connection operations
health_check_interval: Seconds of idle time before running health check
"""
self._connection_parameters = connection_parameters
self._pool_size = pool_size
self._min_size = min(min_size, pool_size)
self._connect_timeout = connect_timeout
self._idle_timeout = idle_timeout
self._operation_timeout = operation_timeout
self._health_check_interval = health_check_interval

self._connection_registry: dict[str, AiosqlitePoolConnection] = {}
self._tracked_threads: set[threading.Thread | AiosqliteConnection] = set()
self._wal_initialized = False
self._warmed = False

self._queue_instance: asyncio.Queue[AiosqlitePoolConnection] | None = None
self._lock_instance: asyncio.Lock | None = None
Expand Down Expand Up @@ -215,23 +235,13 @@ def checked_out(self) -> int:
return len(self._connection_registry)
return len(self._connection_registry) - self._queue.qsize()

def _track_aiosqlite_thread(self, connection: "AiosqliteConnection") -> None:
"""Track the background thread associated with an aiosqlite connection.

Args:
connection: The aiosqlite connection whose thread to track
"""
self._tracked_threads.add(connection)

async def _create_connection(self) -> AiosqlitePoolConnection:
"""Create a new connection.

Returns:
New pool connection instance
"""
connection = aiosqlite.connect(**self._connection_parameters)
connection.daemon = True
connection = await connection
connection = await aiosqlite.connect(**self._connection_parameters)

database_path = str(self._connection_parameters.get("database", ""))
is_shared_cache = "cache=shared" in database_path
Expand Down Expand Up @@ -266,7 +276,6 @@ async def _create_connection(self) -> AiosqlitePoolConnection:

pool_connection = AiosqlitePoolConnection(connection)
pool_connection.mark_as_idle()
self._track_aiosqlite_thread(connection)

async with self._lock:
self._connection_registry[pool_connection.id] = pool_connection
Expand All @@ -277,6 +286,10 @@ async def _create_connection(self) -> AiosqlitePoolConnection:
async def _claim_if_healthy(self, connection: AiosqlitePoolConnection) -> bool:
"""Check if connection is healthy and claim it.

Uses passive health checks: connections idle less than health_check_interval
are assumed healthy based on their last known state. Active health checks
(SELECT 1) are only performed on long-idle connections.

Args:
connection: Connection to check and claim

Expand All @@ -288,15 +301,25 @@ async def _claim_if_healthy(self, connection: AiosqlitePoolConnection) -> bool:
await self._retire_connection(connection)
return False

try:
await asyncio.wait_for(connection.is_alive(), timeout=self._operation_timeout)
except asyncio.TimeoutError:
logger.debug("Connection %s health check timed out, retiring", connection.id)
if not connection.is_healthy:
logger.debug("Connection %s marked unhealthy, retiring", connection.id)
await self._retire_connection(connection)
return False
else:
connection.mark_as_in_use()
return True

if connection.idle_time > self._health_check_interval:
try:
is_alive = await asyncio.wait_for(connection.is_alive(), timeout=self._operation_timeout)
if not is_alive:
logger.debug("Connection %s failed health check, retiring", connection.id)
await self._retire_connection(connection)
return False
except asyncio.TimeoutError:
logger.debug("Connection %s health check timed out, retiring", connection.id)
await self._retire_connection(connection)
return False

connection.mark_as_in_use()
return True

async def _retire_connection(self, connection: AiosqlitePoolConnection) -> None:
"""Retire a connection from the pool.
Expand Down Expand Up @@ -363,6 +386,31 @@ async def _wait_for_healthy_connection(self) -> AiosqlitePoolConnection:
with suppress(asyncio.CancelledError):
await task

async def _warm_pool(self) -> None:
"""Pre-create minimum connections for pool warming.

Creates connections up to min_size to avoid cold-start latency
on first requests.
"""
if self._warmed or self._min_size <= 0:
return

self._warmed = True
connections_needed = self._min_size - len(self._connection_registry)

if connections_needed <= 0:
return

logger.debug("Warming pool with %d connections", connections_needed)
tasks = [self._create_connection() for _ in range(connections_needed)]
results = await asyncio.gather(*tasks, return_exceptions=True)

for result in results:
if isinstance(result, AiosqlitePoolConnection):
self._queue.put_nowait(result)
elif isinstance(result, Exception):
logger.warning("Failed to create warm connection: %s", result)

async def _get_connection(self) -> AiosqlitePoolConnection:
"""Run the three-phase connection acquisition cycle.

Expand All @@ -376,6 +424,9 @@ async def _get_connection(self) -> AiosqlitePoolConnection:
msg = "Cannot acquire connection from closed pool"
raise AiosqlitePoolClosedError(msg)

if not self._warmed and self._min_size > 0:
await self._warm_pool()

while not self._queue.empty():
connection = self._queue.get_nowait()
if await self._claim_if_healthy(connection):
Expand All @@ -387,38 +438,6 @@ async def _get_connection(self) -> AiosqlitePoolConnection:

return await self._wait_for_healthy_connection()

async def _wait_for_threads_to_terminate(self, timeout: float = 1.0) -> None:
"""Wait for all tracked aiosqlite connection threads to terminate.

Args:
timeout: Maximum time to wait for thread termination in seconds
"""
if not self._tracked_threads:
return

logger.debug("Waiting for %d aiosqlite connection threads to terminate...", len(self._tracked_threads))
start_time = time.time()

dead_threads = {t for t in self._tracked_threads if not t.is_alive()}
self._tracked_threads -= dead_threads

if not self._tracked_threads:
logger.debug("All aiosqlite connection threads already terminated")
return

while self._tracked_threads and (time.time() - start_time) < timeout:
await asyncio.sleep(0.05)
dead_threads = {t for t in self._tracked_threads if not t.is_alive()}
self._tracked_threads -= dead_threads

remaining_threads = len(self._tracked_threads)
elapsed = time.time() - start_time

if remaining_threads > 0:
logger.debug("%d aiosqlite threads still running after %.2fs", remaining_threads, elapsed)
else:
logger.debug("All aiosqlite connection threads terminated in %.2fs", elapsed)

async def acquire(self) -> AiosqlitePoolConnection:
"""Acquire a connection from the pool.

Expand Down Expand Up @@ -459,6 +478,7 @@ async def release(self, connection: AiosqlitePoolConnection) -> None:
logger.debug("Released connection back to pool: %s", connection.id)
except Exception as e:
logger.warning("Failed to reset connection %s during release: %s", connection.id, e)
connection.mark_unhealthy()
await self._retire_connection(connection)

@asynccontextmanager
Expand Down Expand Up @@ -496,5 +516,4 @@ async def close(self) -> None:
if isinstance(result, Exception):
logger.warning("Error closing connection %s: %s", connections[i].id, result)

await self._wait_for_threads_to_terminate(timeout=1.0)
logger.debug("Aiosqlite connection pool closed")
3 changes: 0 additions & 3 deletions sqlspec/adapters/asyncmy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)
from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs
from sqlspec.utils.config_normalization import apply_pool_deprecations, normalize_connection_config
from sqlspec.utils.logging import get_logger
from sqlspec.utils.serializers import from_json, to_json

if TYPE_CHECKING:
Expand All @@ -34,8 +33,6 @@

__all__ = ("AsyncmyConfig", "AsyncmyConnectionParams", "AsyncmyDriverFeatures", "AsyncmyPoolParams")

logger = get_logger("adapters.asyncmy")


class AsyncmyConnectionParams(TypedDict):
"""Asyncmy connection parameters."""
Expand Down
3 changes: 0 additions & 3 deletions sqlspec/adapters/asyncpg/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from sqlspec.exceptions import ImproperConfigurationError, MissingDependencyError
from sqlspec.typing import ALLOYDB_CONNECTOR_INSTALLED, CLOUD_SQL_CONNECTOR_INSTALLED, PGVECTOR_INSTALLED
from sqlspec.utils.config_normalization import apply_pool_deprecations, normalize_connection_config
from sqlspec.utils.logging import get_logger
from sqlspec.utils.serializers import from_json, to_json

if TYPE_CHECKING:
Expand All @@ -36,8 +35,6 @@

__all__ = ("AsyncpgConfig", "AsyncpgConnectionConfig", "AsyncpgDriverFeatures", "AsyncpgPoolConfig")

logger = get_logger("adapters.asyncpg")


class AsyncpgConnectionConfig(TypedDict):
"""TypedDict for AsyncPG connection parameters."""
Expand Down
4 changes: 0 additions & 4 deletions sqlspec/adapters/bigquery/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from sqlspec.observability import ObservabilityConfig
from sqlspec.typing import Empty
from sqlspec.utils.config_normalization import normalize_connection_config
from sqlspec.utils.logging import get_logger
from sqlspec.utils.serializers import to_json

if TYPE_CHECKING:
Expand All @@ -31,9 +30,6 @@
from sqlspec.core import StatementConfig


logger = get_logger("adapters.bigquery")


class BigQueryConnectionParams(TypedDict):
"""Standard BigQuery connection parameters.

Expand Down
Loading