Skip to content

Commit f6f0a75

Browse files
authored
fix: handle aiosqlite 0.22.0 Connection no longer inheriting from Thread (#288)
Fixes test failures in CI caused by aiosqlite 0.22.0 changing its `Connection` class to no longer inherit from `threading.Thread`. Improve thread local pooling of SQLite and DuckDB.
1 parent 4c8ac52 commit f6f0a75

File tree

9 files changed

+177
-100
lines changed

9 files changed

+177
-100
lines changed

sqlspec/adapters/aiosqlite/pool.py

Lines changed: 80 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from sqlspec.utils.logging import get_logger
1313

1414
if TYPE_CHECKING:
15-
import threading
1615
from collections.abc import AsyncGenerator
1716

1817
from sqlspec.adapters.aiosqlite._types import AiosqliteConnection
@@ -38,7 +37,7 @@ class AiosqliteConnectTimeoutError(SQLSpecError):
3837
class AiosqlitePoolConnection:
3938
"""Wrapper for database connections in the pool."""
4039

41-
__slots__ = ("_closed", "connection", "id", "idle_since")
40+
__slots__ = ("_closed", "_healthy", "connection", "id", "idle_since")
4241

4342
def __init__(self, connection: "AiosqliteConnection") -> None:
4443
"""Initialize pool connection wrapper.
@@ -50,6 +49,7 @@ def __init__(self, connection: "AiosqliteConnection") -> None:
5049
self.connection = connection
5150
self.idle_since: float | None = None
5251
self._closed = False
52+
self._healthy = True
5353

5454
@property
5555
def idle_time(self) -> float:
@@ -71,6 +71,15 @@ def is_closed(self) -> bool:
7171
"""
7272
return self._closed
7373

74+
@property
75+
def is_healthy(self) -> bool:
76+
"""Check if connection was healthy on last check.
77+
78+
Returns:
79+
True if connection is presumed healthy
80+
"""
81+
return self._healthy and not self._closed
82+
7483
def mark_as_in_use(self) -> None:
7584
"""Mark connection as in use."""
7685
self.idle_since = None
@@ -79,19 +88,26 @@ def mark_as_idle(self) -> None:
7988
"""Mark connection as idle."""
8089
self.idle_since = time.time()
8190

91+
def mark_unhealthy(self) -> None:
92+
"""Mark connection as unhealthy."""
93+
self._healthy = False
94+
8295
async def is_alive(self) -> bool:
8396
"""Check if connection is alive and functional.
8497
8598
Returns:
8699
True if connection is healthy
87100
"""
88101
if self._closed:
102+
self._healthy = False
89103
return False
90104
try:
91105
await self.connection.execute("SELECT 1")
92106
except Exception:
107+
self._healthy = False
93108
return False
94109
else:
110+
self._healthy = True
95111
return True
96112

97113
async def reset(self) -> None:
@@ -102,11 +118,7 @@ async def reset(self) -> None:
102118
await self.connection.rollback()
103119

104120
async def close(self) -> None:
105-
"""Close the connection.
106-
107-
Since we use daemon threads, the connection will be terminated
108-
when the process exits even if close fails.
109-
"""
121+
"""Close the connection."""
110122
if self._closed:
111123
return
112124
try:
@@ -127,41 +139,49 @@ class AiosqliteConnectionPool:
127139
"_connect_timeout",
128140
"_connection_parameters",
129141
"_connection_registry",
142+
"_health_check_interval",
130143
"_idle_timeout",
131144
"_lock_instance",
145+
"_min_size",
132146
"_operation_timeout",
133147
"_pool_size",
134148
"_queue_instance",
135-
"_tracked_threads",
136149
"_wal_initialized",
150+
"_warmed",
137151
)
138152

139153
def __init__(
140154
self,
141155
connection_parameters: "dict[str, Any]",
142156
pool_size: int = 5,
157+
min_size: int = 0,
143158
connect_timeout: float = 30.0,
144159
idle_timeout: float = 24 * 60 * 60,
145160
operation_timeout: float = 10.0,
161+
health_check_interval: float = 30.0,
146162
) -> None:
147163
"""Initialize connection pool.
148164
149165
Args:
150166
connection_parameters: SQLite connection parameters
151167
pool_size: Maximum number of connections in the pool
168+
min_size: Minimum connections to pre-create (pool warming)
152169
connect_timeout: Maximum time to wait for connection acquisition
153170
idle_timeout: Maximum time a connection can remain idle
154171
operation_timeout: Maximum time for connection operations
172+
health_check_interval: Seconds of idle time before running health check
155173
"""
156174
self._connection_parameters = connection_parameters
157175
self._pool_size = pool_size
176+
self._min_size = min(min_size, pool_size)
158177
self._connect_timeout = connect_timeout
159178
self._idle_timeout = idle_timeout
160179
self._operation_timeout = operation_timeout
180+
self._health_check_interval = health_check_interval
161181

162182
self._connection_registry: dict[str, AiosqlitePoolConnection] = {}
163-
self._tracked_threads: set[threading.Thread | AiosqliteConnection] = set()
164183
self._wal_initialized = False
184+
self._warmed = False
165185

166186
self._queue_instance: asyncio.Queue[AiosqlitePoolConnection] | None = None
167187
self._lock_instance: asyncio.Lock | None = None
@@ -215,23 +235,13 @@ def checked_out(self) -> int:
215235
return len(self._connection_registry)
216236
return len(self._connection_registry) - self._queue.qsize()
217237

218-
def _track_aiosqlite_thread(self, connection: "AiosqliteConnection") -> None:
219-
"""Track the background thread associated with an aiosqlite connection.
220-
221-
Args:
222-
connection: The aiosqlite connection whose thread to track
223-
"""
224-
self._tracked_threads.add(connection)
225-
226238
async def _create_connection(self) -> AiosqlitePoolConnection:
227239
"""Create a new connection.
228240
229241
Returns:
230242
New pool connection instance
231243
"""
232-
connection = aiosqlite.connect(**self._connection_parameters)
233-
connection.daemon = True
234-
connection = await connection
244+
connection = await aiosqlite.connect(**self._connection_parameters)
235245

236246
database_path = str(self._connection_parameters.get("database", ""))
237247
is_shared_cache = "cache=shared" in database_path
@@ -266,7 +276,6 @@ async def _create_connection(self) -> AiosqlitePoolConnection:
266276

267277
pool_connection = AiosqlitePoolConnection(connection)
268278
pool_connection.mark_as_idle()
269-
self._track_aiosqlite_thread(connection)
270279

271280
async with self._lock:
272281
self._connection_registry[pool_connection.id] = pool_connection
@@ -277,6 +286,10 @@ async def _create_connection(self) -> AiosqlitePoolConnection:
277286
async def _claim_if_healthy(self, connection: AiosqlitePoolConnection) -> bool:
278287
"""Check if connection is healthy and claim it.
279288
289+
Uses passive health checks: connections idle less than health_check_interval
290+
are assumed healthy based on their last known state. Active health checks
291+
(SELECT 1) are only performed on long-idle connections.
292+
280293
Args:
281294
connection: Connection to check and claim
282295
@@ -288,15 +301,25 @@ async def _claim_if_healthy(self, connection: AiosqlitePoolConnection) -> bool:
288301
await self._retire_connection(connection)
289302
return False
290303

291-
try:
292-
await asyncio.wait_for(connection.is_alive(), timeout=self._operation_timeout)
293-
except asyncio.TimeoutError:
294-
logger.debug("Connection %s health check timed out, retiring", connection.id)
304+
if not connection.is_healthy:
305+
logger.debug("Connection %s marked unhealthy, retiring", connection.id)
295306
await self._retire_connection(connection)
296307
return False
297-
else:
298-
connection.mark_as_in_use()
299-
return True
308+
309+
if connection.idle_time > self._health_check_interval:
310+
try:
311+
is_alive = await asyncio.wait_for(connection.is_alive(), timeout=self._operation_timeout)
312+
if not is_alive:
313+
logger.debug("Connection %s failed health check, retiring", connection.id)
314+
await self._retire_connection(connection)
315+
return False
316+
except asyncio.TimeoutError:
317+
logger.debug("Connection %s health check timed out, retiring", connection.id)
318+
await self._retire_connection(connection)
319+
return False
320+
321+
connection.mark_as_in_use()
322+
return True
300323

301324
async def _retire_connection(self, connection: AiosqlitePoolConnection) -> None:
302325
"""Retire a connection from the pool.
@@ -363,6 +386,31 @@ async def _wait_for_healthy_connection(self) -> AiosqlitePoolConnection:
363386
with suppress(asyncio.CancelledError):
364387
await task
365388

389+
async def _warm_pool(self) -> None:
390+
"""Pre-create minimum connections for pool warming.
391+
392+
Creates connections up to min_size to avoid cold-start latency
393+
on first requests.
394+
"""
395+
if self._warmed or self._min_size <= 0:
396+
return
397+
398+
self._warmed = True
399+
connections_needed = self._min_size - len(self._connection_registry)
400+
401+
if connections_needed <= 0:
402+
return
403+
404+
logger.debug("Warming pool with %d connections", connections_needed)
405+
tasks = [self._create_connection() for _ in range(connections_needed)]
406+
results = await asyncio.gather(*tasks, return_exceptions=True)
407+
408+
for result in results:
409+
if isinstance(result, AiosqlitePoolConnection):
410+
self._queue.put_nowait(result)
411+
elif isinstance(result, Exception):
412+
logger.warning("Failed to create warm connection: %s", result)
413+
366414
async def _get_connection(self) -> AiosqlitePoolConnection:
367415
"""Run the three-phase connection acquisition cycle.
368416
@@ -376,6 +424,9 @@ async def _get_connection(self) -> AiosqlitePoolConnection:
376424
msg = "Cannot acquire connection from closed pool"
377425
raise AiosqlitePoolClosedError(msg)
378426

427+
if not self._warmed and self._min_size > 0:
428+
await self._warm_pool()
429+
379430
while not self._queue.empty():
380431
connection = self._queue.get_nowait()
381432
if await self._claim_if_healthy(connection):
@@ -387,38 +438,6 @@ async def _get_connection(self) -> AiosqlitePoolConnection:
387438

388439
return await self._wait_for_healthy_connection()
389440

390-
async def _wait_for_threads_to_terminate(self, timeout: float = 1.0) -> None:
391-
"""Wait for all tracked aiosqlite connection threads to terminate.
392-
393-
Args:
394-
timeout: Maximum time to wait for thread termination in seconds
395-
"""
396-
if not self._tracked_threads:
397-
return
398-
399-
logger.debug("Waiting for %d aiosqlite connection threads to terminate...", len(self._tracked_threads))
400-
start_time = time.time()
401-
402-
dead_threads = {t for t in self._tracked_threads if not t.is_alive()}
403-
self._tracked_threads -= dead_threads
404-
405-
if not self._tracked_threads:
406-
logger.debug("All aiosqlite connection threads already terminated")
407-
return
408-
409-
while self._tracked_threads and (time.time() - start_time) < timeout:
410-
await asyncio.sleep(0.05)
411-
dead_threads = {t for t in self._tracked_threads if not t.is_alive()}
412-
self._tracked_threads -= dead_threads
413-
414-
remaining_threads = len(self._tracked_threads)
415-
elapsed = time.time() - start_time
416-
417-
if remaining_threads > 0:
418-
logger.debug("%d aiosqlite threads still running after %.2fs", remaining_threads, elapsed)
419-
else:
420-
logger.debug("All aiosqlite connection threads terminated in %.2fs", elapsed)
421-
422441
async def acquire(self) -> AiosqlitePoolConnection:
423442
"""Acquire a connection from the pool.
424443
@@ -459,6 +478,7 @@ async def release(self, connection: AiosqlitePoolConnection) -> None:
459478
logger.debug("Released connection back to pool: %s", connection.id)
460479
except Exception as e:
461480
logger.warning("Failed to reset connection %s during release: %s", connection.id, e)
481+
connection.mark_unhealthy()
462482
await self._retire_connection(connection)
463483

464484
@asynccontextmanager
@@ -496,5 +516,4 @@ async def close(self) -> None:
496516
if isinstance(result, Exception):
497517
logger.warning("Error closing connection %s: %s", connections[i].id, result)
498518

499-
await self._wait_for_threads_to_terminate(timeout=1.0)
500519
logger.debug("Aiosqlite connection pool closed")

sqlspec/adapters/asyncmy/config.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
)
2020
from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs
2121
from sqlspec.utils.config_normalization import apply_pool_deprecations, normalize_connection_config
22-
from sqlspec.utils.logging import get_logger
2322
from sqlspec.utils.serializers import from_json, to_json
2423

2524
if TYPE_CHECKING:
@@ -34,8 +33,6 @@
3433

3534
__all__ = ("AsyncmyConfig", "AsyncmyConnectionParams", "AsyncmyDriverFeatures", "AsyncmyPoolParams")
3635

37-
logger = get_logger("adapters.asyncmy")
38-
3936

4037
class AsyncmyConnectionParams(TypedDict):
4138
"""Asyncmy connection parameters."""

sqlspec/adapters/asyncpg/config.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from sqlspec.exceptions import ImproperConfigurationError, MissingDependencyError
2424
from sqlspec.typing import ALLOYDB_CONNECTOR_INSTALLED, CLOUD_SQL_CONNECTOR_INSTALLED, PGVECTOR_INSTALLED
2525
from sqlspec.utils.config_normalization import apply_pool_deprecations, normalize_connection_config
26-
from sqlspec.utils.logging import get_logger
2726
from sqlspec.utils.serializers import from_json, to_json
2827

2928
if TYPE_CHECKING:
@@ -36,8 +35,6 @@
3635

3736
__all__ = ("AsyncpgConfig", "AsyncpgConnectionConfig", "AsyncpgDriverFeatures", "AsyncpgPoolConfig")
3837

39-
logger = get_logger("adapters.asyncpg")
40-
4138

4239
class AsyncpgConnectionConfig(TypedDict):
4340
"""TypedDict for AsyncPG connection parameters."""

sqlspec/adapters/bigquery/config.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from sqlspec.observability import ObservabilityConfig
1919
from sqlspec.typing import Empty
2020
from sqlspec.utils.config_normalization import normalize_connection_config
21-
from sqlspec.utils.logging import get_logger
2221
from sqlspec.utils.serializers import to_json
2322

2423
if TYPE_CHECKING:
@@ -31,9 +30,6 @@
3130
from sqlspec.core import StatementConfig
3231

3332

34-
logger = get_logger("adapters.bigquery")
35-
36-
3733
class BigQueryConnectionParams(TypedDict):
3834
"""Standard BigQuery connection parameters.
3935

0 commit comments

Comments
 (0)