Skip to content

Commit a75d190

Browse files
authored
fix: init consistency (#61)
Ensure all configurations have consistent init methods.
1 parent 28ec463 commit a75d190

File tree

10 files changed

+52
-34
lines changed

10 files changed

+52
-34
lines changed

sqlspec/adapters/adbc/config.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,17 @@ def __init__(
7979
self,
8080
*,
8181
connection_config: Optional[Union[AdbcConnectionParams, dict[str, Any]]] = None,
82-
statement_config: Optional[StatementConfig] = None,
8382
migration_config: Optional[dict[str, Any]] = None,
83+
statement_config: Optional[StatementConfig] = None,
84+
driver_features: Optional[dict[str, Any]] = None,
8485
) -> None:
8586
"""Initialize ADBC configuration.
8687
8788
Args:
8889
connection_config: Connection configuration parameters
89-
statement_config: Default SQL statement configuration
9090
migration_config: Migration configuration
91+
statement_config: Default SQL statement configuration
92+
driver_features: Driver feature configuration
9193
"""
9294
if connection_config is None:
9395
connection_config = {}
@@ -106,7 +108,7 @@ def __init__(
106108
connection_config=self.connection_config,
107109
migration_config=migration_config,
108110
statement_config=statement_config,
109-
driver_features={},
111+
driver_features=driver_features or {},
110112
)
111113

112114
def _resolve_driver_name(self) -> str:

sqlspec/adapters/aiosqlite/config.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(
6161
pool_instance: "Optional[AiosqliteConnectionPool]" = None,
6262
migration_config: "Optional[dict[str, Any]]" = None,
6363
statement_config: "Optional[StatementConfig]" = None,
64-
**kwargs: Any,
64+
driver_features: "Optional[dict[str, Any]]" = None,
6565
) -> None:
6666
"""Initialize AioSQLite configuration.
6767
@@ -70,10 +70,9 @@ def __init__(
7070
pool_instance: Optional pre-configured connection pool instance.
7171
migration_config: Optional migration configuration.
7272
statement_config: Optional statement configuration.
73-
**kwargs: Additional connection parameters that override pool_config.
73+
driver_features: Optional driver feature configuration.
7474
"""
7575
config_dict = dict(pool_config) if pool_config else {}
76-
config_dict.update(kwargs) # Allow kwargs to override pool_config values
7776

7877
# Handle memory database URI conversion - test expectation is different than sqlite pattern
7978
if "database" not in config_dict or config_dict["database"] == ":memory:":
@@ -85,7 +84,7 @@ def __init__(
8584
pool_instance=pool_instance,
8685
migration_config=migration_config,
8786
statement_config=statement_config or aiosqlite_statement_config,
88-
driver_features={},
87+
driver_features=driver_features or {},
8988
)
9089

9190
def _get_pool_config_dict(self) -> "dict[str, Any]":

sqlspec/adapters/asyncmy/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def __init__(
7070
pool_instance: "Optional[Pool]" = None,
7171
migration_config: Optional[dict[str, Any]] = None,
7272
statement_config: "Optional[StatementConfig]" = None,
73+
driver_features: "Optional[dict[str, Any]]" = None,
7374
) -> None:
7475
"""Initialize Asyncmy configuration.
7576
@@ -78,6 +79,7 @@ def __init__(
7879
pool_instance: Existing pool instance to use
7980
migration_config: Migration configuration
8081
statement_config: Statement configuration override
82+
driver_features: Optional driver feature configuration
8183
"""
8284
processed_pool_config: dict[str, Any] = dict(pool_config) if pool_config else {}
8385
if "extra" in processed_pool_config:
@@ -97,7 +99,7 @@ def __init__(
9799
pool_instance=pool_instance,
98100
migration_config=migration_config,
99101
statement_config=statement_config,
100-
driver_features={},
102+
driver_features=driver_features or {},
101103
)
102104

103105
async def _create_pool(self) -> "Pool": # pyright: ignore

sqlspec/adapters/bigquery/config.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class BigQueryDriverFeatures(TypedDict, total=False):
7171
Only non-standard BigQuery client parameters that are SQLSpec-specific extensions.
7272
"""
7373

74+
connection_instance: NotRequired["BigQueryConnection"]
7475
on_job_start: NotRequired["Callable[[str], None]"]
7576
on_job_complete: NotRequired["Callable[[str, Any], None]"]
7677
on_connection_create: NotRequired["Callable[[Any], None]"]
@@ -93,7 +94,6 @@ class BigQueryConfig(NoPoolSyncConfig[BigQueryConnection, BigQueryDriver]):
9394
def __init__(
9495
self,
9596
*,
96-
connection_instance: "Optional[BigQueryConnection]" = None,
9797
connection_config: "Optional[Union[BigQueryConnectionParams, dict[str, Any]]]" = None,
9898
migration_config: Optional[dict[str, Any]] = None,
9999
statement_config: "Optional[StatementConfig]" = None,
@@ -103,10 +103,10 @@ def __init__(
103103
104104
Args:
105105
connection_config: Standard connection configuration parameters
106-
connection_instance: Existing connection instance to use
107106
migration_config: Migration configuration
108107
statement_config: Statement configuration override
109-
driver_features: BigQuery-specific driver features and configurations
108+
driver_features: BigQuery-specific driver features and configurations.
109+
Can include 'connection_instance' to reuse an existing BigQuery connection.
110110
111111
Example:
112112
>>> # Basic BigQuery connection
@@ -143,19 +143,16 @@ def __init__(
143143
... )
144144
"""
145145

146-
# Store connection instance
147-
self._connection_instance = connection_instance
148-
149-
# Setup configuration following DuckDB pattern
150146
self.connection_config: dict[str, Any] = dict(connection_config) if connection_config else {}
151147
if "extra" in self.connection_config:
152148
extras = self.connection_config.pop("extra")
153149
self.connection_config.update(extras)
154150

155-
# Setup driver features
156151
self.driver_features: dict[str, Any] = dict(driver_features) if driver_features else {}
157152

158-
# Setup default job config if not provided
153+
# Initialize connection instance cache (for performance optimization)
154+
self._connection_instance: Optional[BigQueryConnection] = self.driver_features.get("connection_instance")
155+
159156
if "default_query_job_config" not in self.connection_config:
160157
self._setup_default_job_config()
161158

@@ -237,8 +234,8 @@ def create_connection(self) -> BigQueryConnection:
237234
if on_connection_create:
238235
on_connection_create(connection)
239236

237+
# Cache the connection for reuse (BigQuery connections are expensive)
240238
self._connection_instance = connection
241-
242239
except Exception as e:
243240
project = self.connection_config.get("project", "Unknown")
244241
msg = f"Could not configure BigQuery connection for project '{project}'. Error: {e}"

sqlspec/adapters/duckdb/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ def __init__(
149149
self,
150150
*,
151151
pool_config: "Optional[Union[DuckDBPoolParams, dict[str, Any]]]" = None,
152-
migration_config: Optional[dict[str, Any]] = None,
153152
pool_instance: "Optional[DuckDBConnectionPool]" = None,
153+
migration_config: Optional[dict[str, Any]] = None,
154154
statement_config: "Optional[StatementConfig]" = None,
155155
driver_features: "Optional[Union[DuckDBDriverFeatures, dict[str, Any]]]" = None,
156156
) -> None:

sqlspec/adapters/oracledb/config.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,18 +83,20 @@ class OracleSyncConfig(SyncDatabaseConfig[OracleSyncConnection, "ConnectionPool"
8383
def __init__(
8484
self,
8585
*,
86-
pool_instance: "Optional[ConnectionPool]" = None,
8786
pool_config: "Optional[Union[OraclePoolParams, dict[str, Any]]]" = None,
88-
statement_config: "Optional[StatementConfig]" = None,
87+
pool_instance: "Optional[ConnectionPool]" = None,
8988
migration_config: Optional[dict[str, Any]] = None,
89+
statement_config: "Optional[StatementConfig]" = None,
90+
driver_features: "Optional[dict[str, Any]]" = None,
9091
) -> None:
9192
"""Initialize Oracle synchronous configuration.
9293
9394
Args:
9495
pool_config: Pool configuration parameters
9596
pool_instance: Existing pool instance to use
96-
statement_config: Default SQL statement configuration
9797
migration_config: Migration configuration
98+
statement_config: Default SQL statement configuration
99+
driver_features: Optional driver feature configuration
98100
"""
99101
# Store the pool config as a dict and extract/merge extras
100102
processed_pool_config: dict[str, Any] = dict(pool_config) if pool_config else {}
@@ -107,6 +109,7 @@ def __init__(
107109
pool_instance=pool_instance,
108110
migration_config=migration_config,
109111
statement_config=statement_config,
112+
driver_features=driver_features or {},
110113
)
111114

112115
def _create_pool(self) -> "ConnectionPool":
@@ -208,16 +211,18 @@ def __init__(
208211
*,
209212
pool_config: "Optional[Union[OraclePoolParams, dict[str, Any]]]" = None,
210213
pool_instance: "Optional[AsyncConnectionPool]" = None,
211-
statement_config: "Optional[StatementConfig]" = None,
212214
migration_config: Optional[dict[str, Any]] = None,
215+
statement_config: "Optional[StatementConfig]" = None,
216+
driver_features: "Optional[dict[str, Any]]" = None,
213217
) -> None:
214218
"""Initialize Oracle asynchronous configuration.
215219
216220
Args:
217221
pool_config: Pool configuration parameters
218222
pool_instance: Existing pool instance to use
219-
statement_config: Default SQL statement configuration
220223
migration_config: Migration configuration
224+
statement_config: Default SQL statement configuration
225+
driver_features: Optional driver feature configuration
221226
"""
222227
# Store the pool config as a dict and extract/merge extras
223228
processed_pool_config: dict[str, Any] = dict(pool_config) if pool_config else {}
@@ -230,6 +235,7 @@ def __init__(
230235
pool_instance=pool_instance,
231236
migration_config=migration_config,
232237
statement_config=statement_config or oracledb_statement_config,
238+
driver_features=driver_features or {},
233239
)
234240

235241
async def _create_pool(self) -> "AsyncConnectionPool":

sqlspec/adapters/psqlpy/config.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,17 +86,19 @@ def __init__(
8686
self,
8787
*,
8888
pool_config: Optional[Union[PsqlpyPoolParams, dict[str, Any]]] = None,
89-
statement_config: Optional[StatementConfig] = None,
9089
pool_instance: Optional[ConnectionPool] = None,
9190
migration_config: Optional[dict[str, Any]] = None,
91+
statement_config: Optional[StatementConfig] = None,
92+
driver_features: Optional[dict[str, Any]] = None,
9293
) -> None:
9394
"""Initialize Psqlpy asynchronous configuration.
9495
9596
Args:
9697
pool_config: Pool configuration parameters (TypedDict or dict)
9798
pool_instance: Existing connection pool instance to use
98-
statement_config: Default SQL statement configuration
9999
migration_config: Migration configuration
100+
statement_config: Default SQL statement configuration
101+
driver_features: Optional driver feature configuration
100102
"""
101103
processed_pool_config: dict[str, Any] = dict(pool_config) if pool_config else {}
102104
if "extra" in processed_pool_config:
@@ -107,6 +109,7 @@ def __init__(
107109
pool_instance=pool_instance,
108110
migration_config=migration_config,
109111
statement_config=statement_config or psqlpy_statement_config,
112+
driver_features=driver_features or {},
110113
)
111114

112115
def _get_pool_config_dict(self) -> dict[str, Any]:

sqlspec/adapters/psycopg/config.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,18 @@ def __init__(
8585
*,
8686
pool_config: "Optional[Union[PsycopgPoolParams, dict[str, Any]]]" = None,
8787
pool_instance: Optional["ConnectionPool"] = None,
88-
statement_config: "Optional[StatementConfig]" = None,
8988
migration_config: Optional[dict[str, Any]] = None,
89+
statement_config: "Optional[StatementConfig]" = None,
90+
driver_features: "Optional[dict[str, Any]]" = None,
9091
) -> None:
9192
"""Initialize Psycopg synchronous configuration.
9293
9394
Args:
9495
pool_config: Pool configuration parameters (TypedDict or dict)
9596
pool_instance: Existing pool instance to use
96-
statement_config: Default SQL statement configuration
9797
migration_config: Migration configuration
98-
98+
statement_config: Default SQL statement configuration
99+
driver_features: Optional driver feature configuration
99100
"""
100101
processed_pool_config: dict[str, Any] = dict(pool_config) if pool_config else {}
101102
if "extra" in processed_pool_config:
@@ -107,6 +108,7 @@ def __init__(
107108
pool_instance=pool_instance,
108109
migration_config=migration_config,
109110
statement_config=statement_config or psycopg_statement_config,
111+
driver_features=driver_features or {},
110112
)
111113

112114
def _create_pool(self) -> "ConnectionPool":
@@ -268,14 +270,16 @@ def __init__(
268270
pool_instance: "Optional[AsyncConnectionPool]" = None,
269271
migration_config: "Optional[dict[str, Any]]" = None,
270272
statement_config: "Optional[StatementConfig]" = None,
273+
driver_features: "Optional[dict[str, Any]]" = None,
271274
) -> None:
272275
"""Initialize Psycopg asynchronous configuration.
273276
274277
Args:
275278
pool_config: Pool configuration parameters (TypedDict or dict)
276279
pool_instance: Existing pool instance to use
277-
statement_config: Default SQL statement configuration
278280
migration_config: Migration configuration
281+
statement_config: Default SQL statement configuration
282+
driver_features: Optional driver feature configuration
279283
"""
280284
processed_pool_config: dict[str, Any] = dict(pool_config) if pool_config else {}
281285
if "extra" in processed_pool_config:
@@ -287,6 +291,7 @@ def __init__(
287291
pool_instance=pool_instance,
288292
migration_config=migration_config,
289293
statement_config=statement_config or psycopg_statement_config,
294+
driver_features=driver_features or {},
290295
)
291296

292297
async def _create_pool(self) -> "AsyncConnectionPool":

sqlspec/adapters/sqlite/config.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,18 @@ def __init__(
4444
*,
4545
pool_config: "Optional[Union[SqliteConnectionParams, dict[str, Any]]]" = None,
4646
pool_instance: "Optional[SqliteConnectionPool]" = None,
47-
statement_config: "Optional[StatementConfig]" = None,
4847
migration_config: "Optional[dict[str, Any]]" = None,
48+
statement_config: "Optional[StatementConfig]" = None,
49+
driver_features: "Optional[dict[str, Any]]" = None,
4950
) -> None:
5051
"""Initialize SQLite configuration.
5152
5253
Args:
5354
pool_config: Configuration parameters including connection settings
5455
pool_instance: Pre-created pool instance
55-
statement_config: Default SQL statement configuration
5656
migration_config: Migration configuration
57+
statement_config: Default SQL statement configuration
58+
driver_features: Optional driver feature configuration
5759
"""
5860
if pool_config is None:
5961
pool_config = {}
@@ -66,7 +68,7 @@ def __init__(
6668
pool_config=cast("dict[str, Any]", pool_config),
6769
migration_config=migration_config,
6870
statement_config=statement_config or sqlite_statement_config,
69-
driver_features={},
71+
driver_features=driver_features or {},
7072
)
7173

7274
def _get_connection_config_dict(self) -> "dict[str, Any]":

tests/integration/test_adapters/test_aiosqlite/test_connection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,9 @@ async def test_config_with_kwargs_override() -> None:
177177
pool_config = {"database": "base.db", "timeout": 5.0}
178178

179179
unique_db = f"file:override_{uuid4().hex}.db?mode=memory&cache=shared"
180-
config = AiosqliteConfig(pool_config=pool_config, database=unique_db, timeout=15.0)
180+
# Override pool_config with specific test values
181+
test_pool_config = {**pool_config, "database": unique_db, "timeout": 15.0}
182+
config = AiosqliteConfig(pool_config=test_pool_config)
181183

182184
try:
183185
connection_config = config._get_connection_config_dict()

0 commit comments

Comments
 (0)