Skip to content

Commit 36f7393

Browse files
authored
fix(migrations): various migration bugfixes and increased test coverage (#57)
Improve migration tests and fix issues related to tracking and execution. Enhance connection pool management by adding a method to close the pool and ensure proper execution in running loops.
1 parent 2ade6a1 commit 36f7393

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+6913
-1461
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ repos:
1717
- id: mixed-line-ending
1818
- id: trailing-whitespace
1919
- repo: https://github.com/charliermarsh/ruff-pre-commit
20-
rev: "v0.12.8"
20+
rev: "v0.12.9"
2121
hooks:
2222
- id: ruff
2323
args: ["--fix"]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ maintainers = [{ name = "Litestar Developers", email = "[email protected]" }]
1313
name = "sqlspec"
1414
readme = "README.md"
1515
requires-python = ">=3.9, <4.0"
16-
version = "0.17.1"
16+
version = "0.18.0"
1717

1818
[project.urls]
1919
Discord = "https://discord.gg/litestar"

sqlspec/_sql.py

Lines changed: 24 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@
4242
)
4343
from sqlspec.builder.mixins._join_operations import JoinBuilder
4444
from sqlspec.builder.mixins._select_operations import Case, SubqueryBuilder, WindowFunctionBuilder
45+
from sqlspec.core.statement import SQL
4546
from sqlspec.exceptions import SQLBuilderError
4647

4748
if TYPE_CHECKING:
4849
from sqlspec.builder._expression_wrappers import ExpressionWrapper
49-
from sqlspec.core.statement import SQL
5050

5151

5252
__all__ = (
@@ -285,9 +285,7 @@ def create_table(self, table_name: str, dialect: DialectType = None) -> "CreateT
285285
Returns:
286286
CreateTable builder instance
287287
"""
288-
builder = CreateTable(table_name)
289-
builder.dialect = dialect or self.dialect
290-
return builder
288+
return CreateTable(table_name, dialect=dialect or self.dialect)
291289

292290
def create_table_as_select(self, dialect: DialectType = None) -> "CreateTableAsSelect":
293291
"""Create a CREATE TABLE AS SELECT builder.
@@ -298,35 +296,31 @@ def create_table_as_select(self, dialect: DialectType = None) -> "CreateTableAsS
298296
Returns:
299297
CreateTableAsSelect builder instance
300298
"""
301-
builder = CreateTableAsSelect()
302-
builder.dialect = dialect or self.dialect
303-
return builder
299+
return CreateTableAsSelect(dialect=dialect or self.dialect)
304300

305-
def create_view(self, dialect: DialectType = None) -> "CreateView":
301+
def create_view(self, view_name: str, dialect: DialectType = None) -> "CreateView":
306302
"""Create a CREATE VIEW builder.
307303
308304
Args:
305+
view_name: Name of the view to create
309306
dialect: Optional SQL dialect
310307
311308
Returns:
312309
CreateView builder instance
313310
"""
314-
builder = CreateView()
315-
builder.dialect = dialect or self.dialect
316-
return builder
311+
return CreateView(view_name, dialect=dialect or self.dialect)
317312

318-
def create_materialized_view(self, dialect: DialectType = None) -> "CreateMaterializedView":
313+
def create_materialized_view(self, view_name: str, dialect: DialectType = None) -> "CreateMaterializedView":
319314
"""Create a CREATE MATERIALIZED VIEW builder.
320315
321316
Args:
317+
view_name: Name of the materialized view to create
322318
dialect: Optional SQL dialect
323319
324320
Returns:
325321
CreateMaterializedView builder instance
326322
"""
327-
builder = CreateMaterializedView()
328-
builder.dialect = dialect or self.dialect
329-
return builder
323+
return CreateMaterializedView(view_name, dialect=dialect or self.dialect)
330324

331325
def create_index(self, index_name: str, dialect: DialectType = None) -> "CreateIndex":
332326
"""Create a CREATE INDEX builder.
@@ -340,18 +334,17 @@ def create_index(self, index_name: str, dialect: DialectType = None) -> "CreateI
340334
"""
341335
return CreateIndex(index_name, dialect=dialect or self.dialect)
342336

343-
def create_schema(self, dialect: DialectType = None) -> "CreateSchema":
337+
def create_schema(self, schema_name: str, dialect: DialectType = None) -> "CreateSchema":
344338
"""Create a CREATE SCHEMA builder.
345339
346340
Args:
341+
schema_name: Name of the schema to create
347342
dialect: Optional SQL dialect
348343
349344
Returns:
350345
CreateSchema builder instance
351346
"""
352-
builder = CreateSchema()
353-
builder.dialect = dialect or self.dialect
354-
return builder
347+
return CreateSchema(schema_name, dialect=dialect or self.dialect)
355348

356349
def drop_table(self, table_name: str, dialect: DialectType = None) -> "DropTable":
357350
"""Create a DROP TABLE builder.
@@ -365,16 +358,17 @@ def drop_table(self, table_name: str, dialect: DialectType = None) -> "DropTable
365358
"""
366359
return DropTable(table_name, dialect=dialect or self.dialect)
367360

368-
def drop_view(self, dialect: DialectType = None) -> "DropView":
361+
def drop_view(self, view_name: str, dialect: DialectType = None) -> "DropView":
369362
"""Create a DROP VIEW builder.
370363
371364
Args:
365+
view_name: Name of the view to drop
372366
dialect: Optional SQL dialect
373367
374368
Returns:
375369
DropView builder instance
376370
"""
377-
return DropView(dialect=dialect or self.dialect)
371+
return DropView(view_name, dialect=dialect or self.dialect)
378372

379373
def drop_index(self, index_name: str, dialect: DialectType = None) -> "DropIndex":
380374
"""Create a DROP INDEX builder.
@@ -388,16 +382,17 @@ def drop_index(self, index_name: str, dialect: DialectType = None) -> "DropIndex
388382
"""
389383
return DropIndex(index_name, dialect=dialect or self.dialect)
390384

391-
def drop_schema(self, dialect: DialectType = None) -> "DropSchema":
385+
def drop_schema(self, schema_name: str, dialect: DialectType = None) -> "DropSchema":
392386
"""Create a DROP SCHEMA builder.
393387
394388
Args:
389+
schema_name: Name of the schema to drop
395390
dialect: Optional SQL dialect
396391
397392
Returns:
398393
DropSchema builder instance
399394
"""
400-
return DropSchema(dialect=dialect or self.dialect)
395+
return DropSchema(schema_name, dialect=dialect or self.dialect)
401396

402397
def alter_table(self, table_name: str, dialect: DialectType = None) -> "AlterTable":
403398
"""Create an ALTER TABLE builder.
@@ -409,22 +404,19 @@ def alter_table(self, table_name: str, dialect: DialectType = None) -> "AlterTab
409404
Returns:
410405
AlterTable builder instance
411406
"""
412-
builder = AlterTable(table_name)
413-
builder.dialect = dialect or self.dialect
414-
return builder
407+
return AlterTable(table_name, dialect=dialect or self.dialect)
415408

416-
def rename_table(self, dialect: DialectType = None) -> "RenameTable":
409+
def rename_table(self, old_name: str, dialect: DialectType = None) -> "RenameTable":
417410
"""Create a RENAME TABLE builder.
418411
419412
Args:
413+
old_name: Current name of the table
420414
dialect: Optional SQL dialect
421415
422416
Returns:
423417
RenameTable builder instance
424418
"""
425-
builder = RenameTable()
426-
builder.dialect = dialect or self.dialect
427-
return builder
419+
return RenameTable(old_name, dialect=dialect or self.dialect)
428420

429421
def comment_on(self, dialect: DialectType = None) -> "CommentOn":
430422
"""Create a COMMENT ON builder.
@@ -435,9 +427,7 @@ def comment_on(self, dialect: DialectType = None) -> "CommentOn":
435427
Returns:
436428
CommentOn builder instance
437429
"""
438-
builder = CommentOn()
439-
builder.dialect = dialect or self.dialect
440-
return builder
430+
return CommentOn(dialect=dialect or self.dialect)
441431

442432
# ===================
443433
# SQL Analysis Helpers
@@ -746,7 +736,6 @@ def raw(sql_fragment: str, **parameters: Any) -> "Union[exp.Expression, SQL]":
746736
raise SQLBuilderError(msg) from e
747737

748738
# New behavior - return SQL statement with parameters
749-
from sqlspec.core.statement import SQL
750739

751740
return SQL(sql_fragment, parameters)
752741

@@ -1331,9 +1320,7 @@ def truncate(self, table_name: str) -> "Truncate":
13311320
)
13321321
```
13331322
"""
1334-
builder = Truncate(dialect=self.dialect)
1335-
builder._table_name = table_name
1336-
return builder
1323+
return Truncate(table_name, dialect=self.dialect)
13371324

13381325
# ===================
13391326
# Case Expressions

sqlspec/adapters/adbc/driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
"postgres": (ParameterStyle.NUMERIC, [ParameterStyle.NUMERIC]),
4949
"postgresql": (ParameterStyle.NUMERIC, [ParameterStyle.NUMERIC]),
5050
"bigquery": (ParameterStyle.NAMED_AT, [ParameterStyle.NAMED_AT]),
51-
"sqlite": (ParameterStyle.QMARK, [ParameterStyle.QMARK, ParameterStyle.NAMED_COLON]),
51+
"sqlite": (ParameterStyle.QMARK, [ParameterStyle.QMARK]),
5252
"duckdb": (ParameterStyle.QMARK, [ParameterStyle.QMARK, ParameterStyle.NUMERIC, ParameterStyle.NAMED_DOLLAR]),
5353
"mysql": (ParameterStyle.POSITIONAL_PYFORMAT, [ParameterStyle.POSITIONAL_PYFORMAT, ParameterStyle.NAMED_PYFORMAT]),
5454
"snowflake": (ParameterStyle.QMARK, [ParameterStyle.QMARK, ParameterStyle.NUMERIC]),

sqlspec/adapters/aiosqlite/pool.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,15 +123,15 @@ class AiosqliteConnectionPool:
123123
"""Multi-connection pool for aiosqlite with proper shutdown handling."""
124124

125125
__slots__ = (
126-
"_closed_event",
126+
"_closed_event_instance",
127127
"_connect_timeout",
128128
"_connection_parameters",
129129
"_connection_registry",
130130
"_idle_timeout",
131-
"_lock",
131+
"_lock_instance",
132132
"_operation_timeout",
133133
"_pool_size",
134-
"_queue",
134+
"_queue_instance",
135135
"_tracked_threads",
136136
"_wal_initialized",
137137
)
@@ -159,21 +159,44 @@ def __init__(
159159
self._idle_timeout = idle_timeout
160160
self._operation_timeout = operation_timeout
161161

162-
self._queue: asyncio.Queue[AiosqlitePoolConnection] = asyncio.Queue(maxsize=pool_size)
163162
self._connection_registry: dict[str, AiosqlitePoolConnection] = {}
164-
self._lock = asyncio.Lock()
165-
self._closed_event = asyncio.Event()
166163
self._tracked_threads: set[Union[threading.Thread, AiosqliteConnection]] = set()
167164
self._wal_initialized = False
168165

166+
# Lazy initialization for Python 3.9 compatibility (asyncio objects can't be created without event loop)
167+
self._queue_instance: Optional[asyncio.Queue[AiosqlitePoolConnection]] = None
168+
self._lock_instance: Optional[asyncio.Lock] = None
169+
self._closed_event_instance: Optional[asyncio.Event] = None
170+
171+
@property
172+
def _queue(self) -> "asyncio.Queue[AiosqlitePoolConnection]":
173+
"""Lazy initialization of asyncio.Queue for Python 3.9 compatibility."""
174+
if self._queue_instance is None:
175+
self._queue_instance = asyncio.Queue(maxsize=self._pool_size)
176+
return self._queue_instance
177+
178+
@property
179+
def _lock(self) -> asyncio.Lock:
180+
"""Lazy initialization of asyncio.Lock for Python 3.9 compatibility."""
181+
if self._lock_instance is None:
182+
self._lock_instance = asyncio.Lock()
183+
return self._lock_instance
184+
185+
@property
186+
def _closed_event(self) -> asyncio.Event:
187+
"""Lazy initialization of asyncio.Event for Python 3.9 compatibility."""
188+
if self._closed_event_instance is None:
189+
self._closed_event_instance = asyncio.Event()
190+
return self._closed_event_instance
191+
169192
@property
170193
def is_closed(self) -> bool:
171194
"""Check if pool is closed.
172195
173196
Returns:
174197
True if pool is closed
175198
"""
176-
return self._closed_event.is_set()
199+
return self._closed_event_instance is not None and self._closed_event.is_set()
177200

178201
def size(self) -> int:
179202
"""Get total number of connections in pool.
@@ -189,6 +212,8 @@ def checked_out(self) -> int:
189212
Returns:
190213
Number of connections currently in use
191214
"""
215+
if self._queue_instance is None:
216+
return len(self._connection_registry)
192217
return len(self._connection_registry) - self._queue.qsize()
193218

194219
def _track_aiosqlite_thread(self, connection: "AiosqliteConnection") -> None:

sqlspec/adapters/asyncmy/config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,11 @@ async def _create_pool(self) -> "Pool": # pyright: ignore
107107
async def _close_pool(self) -> None:
108108
"""Close the actual async connection pool."""
109109
if self.pool_instance:
110-
await self.pool_instance.close()
110+
self.pool_instance.close()
111+
112+
async def close_pool(self) -> None:
113+
"""Close the connection pool."""
114+
await self._close_pool()
111115

112116
async def create_connection(self) -> AsyncmyConnection: # pyright: ignore
113117
"""Create a single async connection (not from pool).

sqlspec/adapters/asyncpg/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ async def _close_pool(self) -> None:
144144
if self.pool_instance:
145145
await self.pool_instance.close()
146146

147+
async def close_pool(self) -> None:
148+
"""Close the connection pool."""
149+
await self._close_pool()
150+
147151
async def create_connection(self) -> "AsyncpgConnection":
148152
"""Create a single async connection from the pool.
149153

sqlspec/adapters/duckdb/driver.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,7 @@
4545
default_parameter_style=ParameterStyle.QMARK,
4646
supported_parameter_styles={ParameterStyle.QMARK, ParameterStyle.NUMERIC, ParameterStyle.NAMED_DOLLAR},
4747
default_execution_parameter_style=ParameterStyle.QMARK,
48-
supported_execution_parameter_styles={
49-
ParameterStyle.QMARK,
50-
ParameterStyle.NUMERIC,
51-
ParameterStyle.NAMED_DOLLAR,
52-
},
48+
supported_execution_parameter_styles={ParameterStyle.QMARK, ParameterStyle.NUMERIC},
5349
type_coercion_map={},
5450
has_native_list_expansion=True,
5551
needs_static_script_compilation=False,

sqlspec/adapters/oracledb/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,10 @@ async def _close_pool(self) -> None:
239239
if self.pool_instance:
240240
await self.pool_instance.close()
241241

242+
async def close_pool(self) -> None:
243+
"""Close the connection pool."""
244+
await self._close_pool()
245+
242246
async def create_connection(self) -> OracleAsyncConnection:
243247
"""Create a single async connection (not from pool).
244248

0 commit comments

Comments
 (0)