diff --git a/AGENTS.md b/AGENTS.md
index f86847da7..8d8ffc9fc 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -510,6 +510,53 @@ class AdapterConfig(AsyncDatabaseConfig):  # or SyncDatabaseConfig
 - v0.33.0+: `pool_config` → `connection_config`, `pool_instance` → `connection_instance`
 
+### Parameter Deprecation Pattern
+
+For backwards-compatible parameter renames in configuration classes:
+
+```python
+def __init__(
+    self,
+    *,
+    new_param: dict[str, Any] | None = None,
+    **kwargs: Any,  # Capture old parameter names
+) -> None:
+    from sqlspec.utils.deprecation import warn_deprecation
+
+    if "old_param" in kwargs:
+        warn_deprecation(
+            version="0.33.0",
+            deprecated_name="old_param",
+            kind="parameter",
+            removal_in="0.34.0",
+            alternative="new_param",
+            info="Parameter renamed for consistency across pooled and non-pooled adapters",
+        )
+        if new_param is None:
+            new_param = kwargs.pop("old_param")
+        else:
+            kwargs.pop("old_param")  # Discard if new param provided
+
+    # Continue with initialization using new_param
+```
+
+**Use this pattern when:**
+
+- Renaming configuration parameters for consistency
+- Maintaining backwards compatibility during a migration period
+- Providing clear deprecation warnings for users
+
+**Key principles:**
+
+- Use `**kwargs` to capture old parameter names without changing the signature
+- Import `warn_deprecation` inside the function to avoid circular imports
+- The new parameter takes precedence when both old and new are provided
+- Use `kwargs.pop()` to remove handled parameters and avoid `**kwargs` passing issues
+- Provide a clear migration path (version, alternative, removal timeline)
+- Set a removal timeline (typically the next minor or major version)
+
+**Reference implementation:** `sqlspec/config.py` (lines 920-1517, all 4 base config classes)
+
 ### Error Handling
 
 - Custom exceptions inherit from `SQLSpecError` in `sqlspec/exceptions.py`
diff --git a/docs/migration-guides/v0.33.0.md b/docs/migration-guides/v0.33.0.md
new file mode 100644
index 000000000..193bb5794
--- /dev/null
+++ b/docs/migration-guides/v0.33.0.md
@@ -0,0 +1,151 @@
+# Migration Guide: v0.33.0
+
+## Configuration Parameter Renaming
+
+SQLSpec v0.33.0 renames configuration parameters for consistency across all database adapters.
+
+### Changed Parameters
+
+| Old Name | New Name | Affected Configs |
+|----------|----------|------------------|
+| `pool_config` | `connection_config` | All adapters |
+| `pool_instance` | `connection_instance` | All adapters |
+
+### Migration Steps
+
+Update your configuration instantiation to use the new parameter names:
+
+```python
+# Before (deprecated, will be removed in v0.34.0)
+from sqlspec.adapters.asyncpg import AsyncpgConfig
+
+config = AsyncpgConfig(
+    pool_config={"dsn": "postgresql://localhost/mydb"},
+    pool_instance=existing_pool,
+)
+```
+
+```python
+# After (recommended)
+from sqlspec.adapters.asyncpg import AsyncpgConfig
+
+config = AsyncpgConfig(
+    connection_config={"dsn": "postgresql://localhost/mydb"},
+    connection_instance=existing_pool,
+)
+```
+
+### Rationale
+
+The new parameter names accurately reflect usage across both:
+
+- **Pooled adapters** (AsyncPG, DuckDB, Psycopg, etc.) - where the parameters configure connection pools
+- **Non-pooled adapters** (BigQuery, ADBC, Spanner) - where the parameters configure individual connections
+
+This eliminates conceptual confusion and provides a consistent API across all adapters.
+
+### Backwards Compatibility
+
+**Deprecation Period**: v0.33.0 - v0.33.x
+
+Old parameter names continue to work with deprecation warnings. You will see warnings like:
+
+```
+DeprecationWarning: Use of deprecated parameter 'pool_config'.
+Deprecated in SQLSpec 0.33.0.
+This parameter will be removed in 0.34.0.
+Use 'connection_config' instead.
+Parameter renamed for consistency across pooled and non-pooled adapters.
+```
+
+**Breaking Change**: v0.34.0
+
+Old parameter names will be completely removed in v0.34.0. Update your code during the deprecation period to avoid breakage.
+
+### Affected Adapters
+
+All database adapter configurations are affected:
+
+- **Async Pooled**: AsyncPG, Asyncmy, Aiosqlite, Psqlpy, Psycopg (async), Spanner (async), Oracle (async)
+- **Sync Pooled**: DuckDB, SQLite, Psycopg (sync), Spanner (sync), Oracle (sync)
+- **Non-pooled**: BigQuery, ADBC
+
+### Type Checking
+
+Type checkers (mypy, pyright) and editor autocompletion will not recognize the old parameter names. This is intentional to encourage migration to the new names.
+
+### Examples
+
+#### AsyncPG (Pooled)
+
+```python
+from sqlspec.adapters.asyncpg import AsyncpgConfig
+
+# Old
+config = AsyncpgConfig(
+    pool_config={"dsn": "postgresql://localhost/db", "min_size": 5, "max_size": 10}
+)
+
+# New
+config = AsyncpgConfig(
+    connection_config={"dsn": "postgresql://localhost/db", "min_size": 5, "max_size": 10}
+)
+```
+
+#### BigQuery (Non-pooled)
+
+```python
+from sqlspec.adapters.bigquery import BigQueryConfig
+
+# Old
+config = BigQueryConfig(
+    pool_config={"project": "my-project"}
+)
+
+# New
+config = BigQueryConfig(
+    connection_config={"project": "my-project"}
+)
+```
+
+#### Pre-created Pool Instance
+
+```python
+import asyncpg
+from sqlspec.adapters.asyncpg import AsyncpgConfig
+
+pool = await asyncpg.create_pool(dsn="postgresql://localhost/db")
+
+# Old
+config = AsyncpgConfig(pool_instance=pool)
+
+# New
+config = AsyncpgConfig(connection_instance=pool)
+```
+
+### Search and Replace
+
+For quick migration across your codebase:
+
+```bash
+# Find all occurrences
+grep -r "pool_config" . --include="*.py"
+grep -r "pool_instance" . --include="*.py"
+
+# Replace (GNU sed)
+find . -name "*.py" -exec sed -i 's/pool_config=/connection_config=/g' {} +
+find . -name "*.py" -exec sed -i 's/pool_instance=/connection_instance=/g' {} +
+```
+
+Review changes carefully after automated replacement to ensure correctness.
+
+## Questions?
+
+If you encounter issues during migration:
+
+1. Check that you're using SQLSpec v0.33.0 or later
+2. Verify deprecation warnings appear (this ensures the old names are still recognized)
+3. Update to new parameter names when you see warnings
+4. 
Test your application thoroughly after migration + +Report migration issues at: https://github.com/litestar-org/sqlspec/issues diff --git a/sqlspec/adapters/adbc/config.py b/sqlspec/adapters/adbc/config.py index 5838f366b..23a368ab8 100644 --- a/sqlspec/adapters/adbc/config.py +++ b/sqlspec/adapters/adbc/config.py @@ -1,9 +1,8 @@ """ADBC database configuration.""" -import logging from collections.abc import Callable from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, ClassVar, TypedDict +from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast from typing_extensions import NotRequired @@ -12,6 +11,8 @@ from sqlspec.config import ExtensionConfigs, NoPoolSyncConfig from sqlspec.core import StatementConfig from sqlspec.exceptions import ImproperConfigurationError +from sqlspec.utils.config_normalization import normalize_connection_config +from sqlspec.utils.logging import get_logger from sqlspec.utils.module_loader import import_string from sqlspec.utils.serializers import to_json @@ -23,7 +24,52 @@ from sqlspec.observability import ObservabilityConfig -logger = logging.getLogger("sqlspec.adapters.adbc") +logger = get_logger("adapters.adbc") + +_DRIVER_ALIASES: dict[str, str] = { + "sqlite": "adbc_driver_sqlite.dbapi.connect", + "sqlite3": "adbc_driver_sqlite.dbapi.connect", + "duckdb": "adbc_driver_duckdb.dbapi.connect", + "postgres": "adbc_driver_postgresql.dbapi.connect", + "postgresql": "adbc_driver_postgresql.dbapi.connect", + "pg": "adbc_driver_postgresql.dbapi.connect", + "snowflake": "adbc_driver_snowflake.dbapi.connect", + "sf": "adbc_driver_snowflake.dbapi.connect", + "bigquery": "adbc_driver_bigquery.dbapi.connect", + "bq": "adbc_driver_bigquery.dbapi.connect", + "flightsql": "adbc_driver_flightsql.dbapi.connect", + "grpc": "adbc_driver_flightsql.dbapi.connect", +} + +_URI_PREFIX_DRIVER: tuple[tuple[str, str], ...] = ( + ("postgresql://", "adbc_driver_postgresql.dbapi.connect"), + ("postgres://", "adbc_driver_postgresql.dbapi.connect"), + ("sqlite://", "adbc_driver_sqlite.dbapi.connect"), + ("duckdb://", "adbc_driver_duckdb.dbapi.connect"), + ("grpc://", "adbc_driver_flightsql.dbapi.connect"), + ("snowflake://", "adbc_driver_snowflake.dbapi.connect"), + ("bigquery://", "adbc_driver_bigquery.dbapi.connect"), +) + +_DRIVER_PATH_KEYWORDS_TO_DIALECT: tuple[tuple[str, str], ...] = ( + ("postgresql", "postgres"), + ("sqlite", "sqlite"), + ("duckdb", "duckdb"), + ("bigquery", "bigquery"), + ("snowflake", "snowflake"), + ("flightsql", "sqlite"), + ("grpc", "sqlite"), +) + +_PARAMETER_STYLES_BY_KEYWORD: tuple[tuple[str, tuple[tuple[str, ...], str]], ...] = ( + ("postgresql", (("numeric",), "numeric")), + ("sqlite", (("qmark", "named_colon"), "qmark")), + ("duckdb", (("qmark", "numeric"), "qmark")), + ("bigquery", (("named_at",), "named_at")), + ("snowflake", (("qmark", "numeric"), "qmark")), +) + +_BIGQUERY_DB_KWARGS_FIELDS: tuple[str, ...] = ("project_id", "dataset_id", "token") class AdbcConnectionParams(TypedDict): @@ -121,6 +167,7 @@ def __init__( bind_key: str | None = None, extension_config: "ExtensionConfigs | None" = None, observability_config: "ObservabilityConfig | None" = None, + **kwargs: Any, ) -> None: """Initialize configuration. 
@@ -133,15 +180,9 @@ def __init__( bind_key: Optional unique identifier for this configuration extension_config: Extension-specific configuration (e.g., Litestar plugin settings) observability_config: Adapter-level observability overrides for lifecycle hooks and observers + **kwargs: Additional keyword arguments passed to the base configuration. """ - if connection_config is None: - connection_config = {} - extras = connection_config.pop("extra", {}) - if not isinstance(extras, dict): - msg = "The 'extra' field in connection_config must be a dictionary." - raise ImproperConfigurationError(msg) - self.connection_config: dict[str, Any] = dict(connection_config) - self.connection_config.update(extras) + self.connection_config = normalize_connection_config(connection_config) if statement_config is None: detected_dialect = str(self._get_dialect() or "sqlite") @@ -154,18 +195,7 @@ def __init__( processed_driver_features.setdefault("arrow_extension_types", True) if json_serializer is not None: - parameter_config = statement_config.parameter_config - previous_list_converter = parameter_config.type_coercion_map.get(list) - previous_tuple_converter = parameter_config.type_coercion_map.get(tuple) - updated_parameter_config = parameter_config.with_json_serializers(json_serializer) - updated_map = dict(updated_parameter_config.type_coercion_map) - if previous_list_converter is not None: - updated_map[list] = previous_list_converter - if previous_tuple_converter is not None: - updated_map[tuple] = previous_tuple_converter - statement_config = statement_config.replace( - parameter_config=updated_parameter_config.replace(type_coercion_map=updated_map) - ) + statement_config = _apply_json_serializer_to_statement_config(statement_config, json_serializer) super().__init__( connection_config=self.connection_config, @@ -176,6 +206,7 @@ def __init__( bind_key=bind_key, extension_config=extension_config, observability_config=observability_config, + **kwargs, ) def _resolve_driver_name(self) -> str: @@ -188,47 +219,16 @@ def _resolve_driver_name(self) -> str: uri = self.connection_config.get("uri") if isinstance(driver_name, str): - driver_aliases = { - "sqlite": "adbc_driver_sqlite.dbapi.connect", - "sqlite3": "adbc_driver_sqlite.dbapi.connect", - "adbc_driver_sqlite": "adbc_driver_sqlite.dbapi.connect", - "duckdb": "adbc_driver_duckdb.dbapi.connect", - "adbc_driver_duckdb": "adbc_driver_duckdb.dbapi.connect", - "postgres": "adbc_driver_postgresql.dbapi.connect", - "postgresql": "adbc_driver_postgresql.dbapi.connect", - "pg": "adbc_driver_postgresql.dbapi.connect", - "adbc_driver_postgresql": "adbc_driver_postgresql.dbapi.connect", - "snowflake": "adbc_driver_snowflake.dbapi.connect", - "sf": "adbc_driver_snowflake.dbapi.connect", - "adbc_driver_snowflake": "adbc_driver_snowflake.dbapi.connect", - "bigquery": "adbc_driver_bigquery.dbapi.connect", - "bq": "adbc_driver_bigquery.dbapi.connect", - "adbc_driver_bigquery": "adbc_driver_bigquery.dbapi.connect", - "flightsql": "adbc_driver_flightsql.dbapi.connect", - "adbc_driver_flightsql": "adbc_driver_flightsql.dbapi.connect", - "grpc": "adbc_driver_flightsql.dbapi.connect", - } - - resolved_driver = driver_aliases.get(driver_name, driver_name) - - if not resolved_driver.endswith(".dbapi.connect"): - resolved_driver = f"{resolved_driver}.dbapi.connect" - - return resolved_driver + lowered_driver = driver_name.lower() + alias = _DRIVER_ALIASES.get(lowered_driver) + if alias is not None: + return alias + return _normalize_driver_path(driver_name) if isinstance(uri, 
str): - if uri.startswith(("postgresql://", "postgres://")): - return "adbc_driver_postgresql.dbapi.connect" - if uri.startswith("sqlite://"): - return "adbc_driver_sqlite.dbapi.connect" - if uri.startswith("duckdb://"): - return "adbc_driver_duckdb.dbapi.connect" - if uri.startswith("grpc://"): - return "adbc_driver_flightsql.dbapi.connect" - if uri.startswith("snowflake://"): - return "adbc_driver_snowflake.dbapi.connect" - if uri.startswith("bigquery://"): - return "adbc_driver_bigquery.dbapi.connect" + resolved = _driver_from_uri(uri) + if resolved is not None: + return resolved return "adbc_driver_sqlite.dbapi.connect" @@ -246,26 +246,14 @@ def _get_connect_func(self) -> Callable[..., AdbcConnection]: try: connect_func = import_string(driver_path) except ImportError as e: - # Only add .dbapi.connect if it's not already there - if not driver_path.endswith(".dbapi.connect"): - driver_path_with_suffix = f"{driver_path}.dbapi.connect" - else: - driver_path_with_suffix = driver_path - try: - connect_func = import_string(driver_path_with_suffix) - except ImportError as e2: - msg = ( - f"Failed to import connect function from '{driver_path}' or " - f"'{driver_path_with_suffix}'. Is the driver installed? " - f"Original errors: {e} / {e2}" - ) - raise ImproperConfigurationError(msg) from e2 + msg = f"Failed to import connect function from '{driver_path}'. Is the driver installed? Error: {e}" + raise ImproperConfigurationError(msg) from e if not callable(connect_func): msg = f"The path '{driver_path}' did not resolve to a callable function." raise ImproperConfigurationError(msg) - return connect_func # type: ignore[no-any-return] + return cast("Callable[..., AdbcConnection]", connect_func) def _get_dialect(self) -> "DialectType": """Get the SQL dialect type based on the driver. @@ -273,21 +261,8 @@ def _get_dialect(self) -> "DialectType": Returns: The SQL dialect type for the driver. 
""" - try: - driver_path = self._resolve_driver_name() - except ImproperConfigurationError: - return None - - dialect_map = { - "postgres": "postgres", - "sqlite": "sqlite", - "duckdb": "duckdb", - "bigquery": "bigquery", - "snowflake": "snowflake", - "flightsql": "sqlite", - "grpc": "sqlite", - } - for keyword, dialect in dialect_map.items(): + driver_path = self._resolve_driver_name() + for keyword, dialect in _DRIVER_PATH_KEYWORDS_TO_DIALECT: if keyword in driver_path: return dialect return None @@ -300,19 +275,12 @@ def _get_parameter_styles(self) -> tuple[tuple[str, ...], str]: """ try: driver_path = self._resolve_driver_name() - if "postgresql" in driver_path: - return (("numeric",), "numeric") - if "sqlite" in driver_path: - return (("qmark", "named_colon"), "qmark") - if "duckdb" in driver_path: - return (("qmark", "numeric"), "qmark") - if "bigquery" in driver_path: - return (("named_at",), "named_at") - if "snowflake" in driver_path: - return (("qmark", "numeric"), "qmark") - - except Exception: - logger.debug("Error resolving parameter styles, using defaults") + for keyword, styles in _PARAMETER_STYLES_BY_KEYWORD: + if keyword in driver_path: + return styles + + except Exception: # pylint: disable=broad-exception-caught + logger.debug("Error resolving parameter styles, using defaults", exc_info=True) return (("qmark",), "qmark") def create_connection(self) -> AdbcConnection: @@ -389,36 +357,34 @@ def _get_connection_config_dict(self) -> dict[str, Any]: """ config = dict(self.connection_config) - if "driver_name" in config: - driver_name = config["driver_name"] - - if "uri" in config: - uri = config["uri"] - - if driver_name in {"sqlite", "sqlite3", "adbc_driver_sqlite"} and uri.startswith("sqlite://"): # pyright: ignore - config["uri"] = uri[9:] # pyright: ignore - - elif driver_name in {"duckdb", "adbc_driver_duckdb"} and uri.startswith("duckdb://"): # pyright: ignore - config["path"] = uri[9:] # pyright: ignore - config.pop("uri", None) - - if driver_name in {"bigquery", "bq", "adbc_driver_bigquery"}: - bigquery_parameters = ["project_id", "dataset_id", "token"] - db_kwargs = config.get("db_kwargs", {}) - - for param in bigquery_parameters: - if param in config and param != "db_kwargs": - db_kwargs[param] = config.pop(param) # pyright: ignore - - if db_kwargs: - config["db_kwargs"] = db_kwargs - - elif "db_kwargs" in config and driver_name not in {"bigquery", "bq", "adbc_driver_bigquery"}: - db_kwargs = config.pop("db_kwargs") - if isinstance(db_kwargs, dict): - config.update(db_kwargs) - - config.pop("driver_name", None) + driver_name = config.get("driver_name") + uri = config.get("uri") + driver_kind: str | None = None + if isinstance(driver_name, str): + driver_kind = _driver_kind_from_driver_name(driver_name) + if driver_kind is None and isinstance(uri, str): + driver_kind = _driver_kind_from_uri(uri) + + if isinstance(uri, str) and driver_kind == "sqlite" and uri.startswith("sqlite://"): + config["uri"] = uri[9:] + if isinstance(uri, str) and driver_kind == "duckdb" and uri.startswith("duckdb://"): + config["path"] = uri[9:] + config.pop("uri", None) + + if isinstance(driver_name, str) and driver_kind == "bigquery": + db_kwargs = config.get("db_kwargs") + db_kwargs_dict: dict[str, Any] = dict(db_kwargs) if isinstance(db_kwargs, dict) else {} + for param in _BIGQUERY_DB_KWARGS_FIELDS: + if param in config: + db_kwargs_dict[param] = config.pop(param) + if db_kwargs_dict: + config["db_kwargs"] = db_kwargs_dict + elif isinstance(driver_name, str) and "db_kwargs" in config and 
driver_kind != "bigquery": + db_kwargs = config.pop("db_kwargs") + if isinstance(db_kwargs, dict): + config.update(db_kwargs) + + config.pop("driver_name", None) return config @@ -437,3 +403,96 @@ def get_signature_namespace(self) -> "dict[str, Any]": "AdbcExceptionHandler": AdbcExceptionHandler, }) return namespace + + +def _apply_json_serializer_to_statement_config( + statement_config: "StatementConfig", json_serializer: "Callable[[Any], str]" +) -> "StatementConfig": + """Apply a JSON serializer to statement config while preserving list/tuple converters. + + Args: + statement_config: Base statement configuration to update. + json_serializer: JSON serializer function. + + Returns: + Updated statement configuration. + """ + parameter_config = statement_config.parameter_config + previous_list_converter = parameter_config.type_coercion_map.get(list) + previous_tuple_converter = parameter_config.type_coercion_map.get(tuple) + + updated_parameter_config = parameter_config.with_json_serializers(json_serializer) + updated_map = dict(updated_parameter_config.type_coercion_map) + + if previous_list_converter is not None: + updated_map[list] = previous_list_converter + if previous_tuple_converter is not None: + updated_map[tuple] = previous_tuple_converter + + return statement_config.replace(parameter_config=updated_parameter_config.replace(type_coercion_map=updated_map)) + + +def _normalize_driver_path(driver_name: str) -> str: + """Normalize a driver name to an importable connect function path. + + Args: + driver_name: Driver name or dotted import path. + + Returns: + A dotted path to a driver connect function. + """ + stripped = driver_name.strip() + if stripped.endswith(".dbapi.connect"): + return stripped + if stripped.endswith(".dbapi"): + return f"{stripped}.connect" + if "." in stripped: + return stripped + return f"{stripped}.dbapi.connect" + + +def _driver_from_uri(uri: str) -> str | None: + """Resolve a default driver connect path from a URI. + + Args: + uri: Connection URI. + + Returns: + Dotted connect function path if a scheme matches, otherwise None. + """ + for prefix, driver_path in _URI_PREFIX_DRIVER: + if uri.startswith(prefix): + return driver_path + return None + + +def _driver_kind_from_driver_name(driver_name: str) -> str | None: + """Return a canonical driver kind based on driver name content. + + Args: + driver_name: Driver name or dotted path. + + Returns: + Canonical driver kind string or None. + """ + resolved = _DRIVER_ALIASES.get(driver_name.lower(), driver_name) + lowered = resolved.lower() + for keyword, _dialect in _DRIVER_PATH_KEYWORDS_TO_DIALECT: + if keyword in lowered: + return keyword + return None + + +def _driver_kind_from_uri(uri: str) -> str | None: + """Return a canonical driver kind based on URI scheme. + + Args: + uri: Connection URI. + + Returns: + Canonical driver kind string or None. 
+ """ + for prefix, driver_path in _URI_PREFIX_DRIVER: + if uri.startswith(prefix): + return _driver_kind_from_driver_name(driver_path) + return None diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index 8dd2e1de3..3ddc4d573 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -568,7 +568,6 @@ def _execute_script(self, cursor: "Cursor", statement: "SQL") -> "ExecutionResul last_rowcount = cursor.rowcount except Exception: self._handle_postgres_rollback(cursor) - logger.exception("Script execution failed") raise return self.create_execution_result( diff --git a/sqlspec/adapters/aiosqlite/config.py b/sqlspec/adapters/aiosqlite/config.py index 5ec7a6bc0..3f699ca8a 100644 --- a/sqlspec/adapters/aiosqlite/config.py +++ b/sqlspec/adapters/aiosqlite/config.py @@ -1,6 +1,5 @@ """Aiosqlite database configuration.""" -import logging from contextlib import asynccontextmanager from typing import TYPE_CHECKING, Any, ClassVar, TypedDict @@ -21,6 +20,8 @@ ) from sqlspec.adapters.sqlite._type_handlers import register_type_handlers from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs +from sqlspec.utils.config_normalization import normalize_connection_config +from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json if TYPE_CHECKING: @@ -31,7 +32,7 @@ __all__ = ("AiosqliteConfig", "AiosqliteConnectionParams", "AiosqliteDriverFeatures", "AiosqlitePoolParams") -logger = logging.getLogger(__name__) +logger = get_logger("adapters.aiosqlite") class AiosqliteConnectionParams(TypedDict): @@ -97,6 +98,7 @@ def __init__( bind_key: "str | None" = None, extension_config: "ExtensionConfigs | None" = None, observability_config: "ObservabilityConfig | None" = None, + **kwargs: Any, ) -> None: """Initialize AioSQLite configuration. @@ -109,8 +111,9 @@ def __init__( bind_key: Optional unique identifier for this configuration. extension_config: Extension-specific configuration (e.g., Litestar plugin settings) observability_config: Adapter-level observability overrides for lifecycle hooks and observers + **kwargs: Additional keyword arguments passed to the base configuration. """ - config_dict = dict(connection_config) if connection_config else {} + config_dict: dict[str, Any] = dict(connection_config) if connection_config else {} if "database" not in config_dict or config_dict["database"] == ":memory:": config_dict["database"] = "file::memory:?cache=shared" @@ -125,6 +128,8 @@ def __init__( ) config_dict["uri"] = True + config_dict = normalize_connection_config(config_dict) + processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {} processed_driver_features.setdefault("enable_custom_adapters", True) json_serializer = processed_driver_features.setdefault("json_serializer", to_json) @@ -146,6 +151,7 @@ def __init__( bind_key=bind_key, extension_config=extension_config, observability_config=observability_config, + **kwargs, ) def _get_pool_config_dict(self) -> "dict[str, Any]": @@ -154,10 +160,7 @@ def _get_pool_config_dict(self) -> "dict[str, Any]": Returns: Dictionary with pool parameters, filtering out None values. 
""" - config: dict[str, Any] = dict(self.connection_config) - extras = config.pop("extra", {}) - config.update(extras) - return {k: v for k, v in config.items() if v is not None} + return {k: v for k, v in self.connection_config.items() if v is not None} def _get_connection_config_dict(self) -> "dict[str, Any]": """Get connection configuration as plain dict for pool creation. diff --git a/sqlspec/adapters/aiosqlite/pool.py b/sqlspec/adapters/aiosqlite/pool.py index 35874ca5c..20eaf28b7 100644 --- a/sqlspec/adapters/aiosqlite/pool.py +++ b/sqlspec/adapters/aiosqlite/pool.py @@ -1,7 +1,6 @@ """Multi-connection pool for aiosqlite.""" import asyncio -import logging import time import uuid from contextlib import asynccontextmanager, suppress @@ -10,6 +9,7 @@ import aiosqlite from sqlspec.exceptions import SQLSpecError +from sqlspec.utils.logging import get_logger if TYPE_CHECKING: import threading @@ -24,7 +24,7 @@ "AiosqlitePoolConnection", ) -logger = logging.getLogger(__name__) +logger = get_logger(__name__) class AiosqlitePoolClosedError(SQLSpecError): diff --git a/sqlspec/adapters/asyncmy/config.py b/sqlspec/adapters/asyncmy/config.py index d5f94b0a8..c53a947eb 100644 --- a/sqlspec/adapters/asyncmy/config.py +++ b/sqlspec/adapters/asyncmy/config.py @@ -1,6 +1,5 @@ """Asyncmy database configuration.""" -import logging from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from typing import TYPE_CHECKING, Any, ClassVar, TypedDict @@ -19,6 +18,8 @@ build_asyncmy_statement_config, ) from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs +from sqlspec.utils.config_normalization import apply_pool_deprecations, normalize_connection_config +from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json if TYPE_CHECKING: @@ -33,7 +34,7 @@ __all__ = ("AsyncmyConfig", "AsyncmyConnectionParams", "AsyncmyDriverFeatures", "AsyncmyPoolParams") -logger = logging.getLogger(__name__) +logger = get_logger("adapters.asyncmy") class AsyncmyConnectionParams(TypedDict): @@ -107,6 +108,7 @@ def __init__( bind_key: "str | None" = None, extension_config: "ExtensionConfigs | None" = None, observability_config: "ObservabilityConfig | None" = None, + **kwargs: Any, ) -> None: """Initialize Asyncmy configuration. 
@@ -119,11 +121,13 @@ def __init__( bind_key: Optional unique identifier for this configuration extension_config: Extension-specific configuration (e.g., Litestar plugin settings) observability_config: Adapter-level observability overrides for lifecycle hooks and observers + **kwargs: Additional keyword arguments (handles deprecated pool_config/pool_instance) """ - processed_connection_config: dict[str, Any] = dict(connection_config) if connection_config else {} - if "extra" in processed_connection_config: - extras = processed_connection_config.pop("extra") - processed_connection_config.update(extras) + connection_config, connection_instance = apply_pool_deprecations( + kwargs=kwargs, connection_config=connection_config, connection_instance=connection_instance + ) + + processed_connection_config = normalize_connection_config(connection_config) processed_connection_config.setdefault("host", "localhost") processed_connection_config.setdefault("port", 3306) @@ -145,6 +149,7 @@ def __init__( bind_key=bind_key, extension_config=extension_config, observability_config=observability_config, + **kwargs, ) async def _create_pool(self) -> "AsyncmyPool": # pyright: ignore diff --git a/sqlspec/adapters/asyncmy/driver.py b/sqlspec/adapters/asyncmy/driver.py index d3f229255..f759c0262 100644 --- a/sqlspec/adapters/asyncmy/driver.py +++ b/sqlspec/adapters/asyncmy/driver.py @@ -4,7 +4,6 @@ type coercion, error handling, and transaction management. """ -import logging from typing import TYPE_CHECKING, Any, Final, cast import asyncmy.errors # pyright: ignore @@ -32,6 +31,7 @@ TransactionError, UniqueViolationError, ) +from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json if TYPE_CHECKING: @@ -57,7 +57,7 @@ "build_asyncmy_statement_config", ) -logger = logging.getLogger(__name__) +logger = get_logger(__name__) json_type_value = ( ASYNC_MY_FIELD_TYPE.JSON if ASYNC_MY_FIELD_TYPE is not None and hasattr(ASYNC_MY_FIELD_TYPE, "JSON") else None diff --git a/sqlspec/adapters/asyncpg/_type_handlers.py b/sqlspec/adapters/asyncpg/_type_handlers.py index 9e3c05f6a..a54dabe0f 100644 --- a/sqlspec/adapters/asyncpg/_type_handlers.py +++ b/sqlspec/adapters/asyncpg/_type_handlers.py @@ -5,10 +5,10 @@ and optional vector type support. 
""" -import logging from typing import TYPE_CHECKING, Any from sqlspec.typing import PGVECTOR_INSTALLED +from sqlspec.utils.logging import get_logger if TYPE_CHECKING: from collections.abc import Callable @@ -17,7 +17,7 @@ __all__ = ("register_json_codecs", "register_pgvector_support") -logger = logging.getLogger(__name__) +logger = get_logger(__name__) def _is_missing_vector_error(error: Exception) -> bool: diff --git a/sqlspec/adapters/asyncpg/config.py b/sqlspec/adapters/asyncpg/config.py index a93d0a67f..bfbf25bb1 100644 --- a/sqlspec/adapters/asyncpg/config.py +++ b/sqlspec/adapters/asyncpg/config.py @@ -1,6 +1,5 @@ """AsyncPG database configuration with direct field-based configuration.""" -import logging from collections.abc import Callable from contextlib import asynccontextmanager from typing import TYPE_CHECKING, Any, ClassVar, TypedDict @@ -21,8 +20,10 @@ build_asyncpg_statement_config, ) from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs -from sqlspec.exceptions import ImproperConfigurationError +from sqlspec.exceptions import ImproperConfigurationError, MissingDependencyError from sqlspec.typing import ALLOYDB_CONNECTOR_INSTALLED, CLOUD_SQL_CONNECTOR_INSTALLED, PGVECTOR_INSTALLED +from sqlspec.utils.config_normalization import apply_pool_deprecations, normalize_connection_config +from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json if TYPE_CHECKING: @@ -35,7 +36,7 @@ __all__ = ("AsyncpgConfig", "AsyncpgConnectionConfig", "AsyncpgDriverFeatures", "AsyncpgPoolConfig") -logger = logging.getLogger("sqlspec") +logger = get_logger("adapters.asyncpg") class AsyncpgConnectionConfig(TypedDict): @@ -157,6 +158,7 @@ def __init__( bind_key: "str | None" = None, extension_config: "ExtensionConfigs | None" = None, observability_config: "ObservabilityConfig | None" = None, + **kwargs: Any, ) -> None: """Initialize AsyncPG configuration. @@ -169,7 +171,12 @@ def __init__( bind_key: Optional unique identifier for this configuration extension_config: Extension-specific configuration (e.g., Litestar plugin settings) observability_config: Adapter-level observability overrides for lifecycle hooks and observers + **kwargs: Additional keyword arguments (handles deprecated pool_config/pool_instance) """ + connection_config, connection_instance = apply_pool_deprecations( + kwargs=kwargs, connection_config=connection_config, connection_instance=connection_instance + ) + features_dict: dict[str, Any] = dict(driver_features) if driver_features else {} serializer = features_dict.setdefault("json_serializer", to_json) @@ -184,7 +191,7 @@ def __init__( ) super().__init__( - connection_config=dict(connection_config) if connection_config else {}, + connection_config=normalize_connection_config(connection_config), connection_instance=connection_instance, migration_config=migration_config, statement_config=base_statement_config, @@ -192,6 +199,7 @@ def __init__( bind_key=bind_key, extension_config=extension_config, observability_config=observability_config, + **kwargs, ) self._cloud_sql_connector: Any | None = None @@ -204,42 +212,51 @@ def _validate_connector_config(self) -> None: Raises: ImproperConfigurationError: If configuration is invalid. + MissingDependencyError: If required connector packages are not installed. 
""" enable_cloud_sql = self.driver_features.get("enable_cloud_sql", False) enable_alloydb = self.driver_features.get("enable_alloydb", False) - if enable_cloud_sql and enable_alloydb: - msg = "Cannot enable both Cloud SQL and AlloyDB connectors simultaneously. Use separate configs for each database." - raise ImproperConfigurationError(msg) - - if enable_cloud_sql: - if not CLOUD_SQL_CONNECTOR_INSTALLED: - msg = "cloud-sql-python-connector package not installed. Install with: pip install cloud-sql-python-connector" - raise ImproperConfigurationError(msg) - - instance = self.driver_features.get("cloud_sql_instance") - if not instance: - msg = "cloud_sql_instance required when enable_cloud_sql is True. Format: 'project:region:instance'" - raise ImproperConfigurationError(msg) - - cloud_sql_instance_parts_expected = 2 - if instance.count(":") != cloud_sql_instance_parts_expected: - msg = f"Invalid Cloud SQL instance format: {instance}. Expected format: 'project:region:instance'" - raise ImproperConfigurationError(msg) - - elif enable_alloydb: - if not ALLOYDB_CONNECTOR_INSTALLED: - msg = "cloud-alloydb-python-connector package not installed. Install with: pip install cloud-alloydb-python-connector" - raise ImproperConfigurationError(msg) - - instance_uri = self.driver_features.get("alloydb_instance_uri") - if not instance_uri: - msg = "alloydb_instance_uri required when enable_alloydb is True. Format: 'projects/PROJECT/locations/REGION/clusters/CLUSTER/instances/INSTANCE'" - raise ImproperConfigurationError(msg) - - if not instance_uri.startswith("projects/"): - msg = f"Invalid AlloyDB instance URI format: {instance_uri}. Expected format: 'projects/PROJECT/locations/REGION/clusters/CLUSTER/instances/INSTANCE'" + match (enable_cloud_sql, enable_alloydb): + case (True, True): + msg = ( + "Cannot enable both Cloud SQL and AlloyDB connectors simultaneously. " + "Use separate configs for each database." + ) raise ImproperConfigurationError(msg) + case (False, False): + return + case (True, False): + if not CLOUD_SQL_CONNECTOR_INSTALLED: + raise MissingDependencyError(package="cloud-sql-python-connector", install_package="cloud-sql") + + instance = self.driver_features.get("cloud_sql_instance") + if not instance: + msg = "cloud_sql_instance required when enable_cloud_sql is True. Format: 'project:region:instance'" + raise ImproperConfigurationError(msg) + + cloud_sql_instance_parts_expected = 2 + if instance.count(":") != cloud_sql_instance_parts_expected: + msg = f"Invalid Cloud SQL instance format: {instance}. Expected format: 'project:region:instance'" + raise ImproperConfigurationError(msg) + case (False, True): + if not ALLOYDB_CONNECTOR_INSTALLED: + raise MissingDependencyError(package="google-cloud-alloydb-connector", install_package="alloydb") + + instance_uri = self.driver_features.get("alloydb_instance_uri") + if not instance_uri: + msg = ( + "alloydb_instance_uri required when enable_alloydb is True. " + "Format: 'projects/PROJECT/locations/REGION/clusters/CLUSTER/instances/INSTANCE'" + ) + raise ImproperConfigurationError(msg) + + if not instance_uri.startswith("projects/"): + msg = ( + f"Invalid AlloyDB instance URI format: {instance_uri}. Expected format: " + "'projects/PROJECT/locations/REGION/clusters/CLUSTER/instances/INSTANCE'" + ) + raise ImproperConfigurationError(msg) def _get_pool_config_dict(self) -> "dict[str, Any]": """Get pool configuration as plain dict for external library. 
@@ -247,10 +264,7 @@ def _get_pool_config_dict(self) -> "dict[str, Any]": Returns: Dictionary with pool parameters, filtering out None values. """ - config: dict[str, Any] = dict(self.connection_config) - extras = config.pop("extra", {}) - config.update(extras) - return {k: v for k, v in config.items() if v is not None} + return {k: v for k, v in self.connection_config.items() if v is not None} def _setup_cloud_sql_connector(self, config: "dict[str, Any]") -> None: """Setup Cloud SQL connector and configure pool for connection factory pattern. diff --git a/sqlspec/adapters/bigquery/config.py b/sqlspec/adapters/bigquery/config.py index 3ac5cc766..383b166b0 100644 --- a/sqlspec/adapters/bigquery/config.py +++ b/sqlspec/adapters/bigquery/config.py @@ -1,7 +1,6 @@ """BigQuery database configuration.""" import contextlib -import logging from typing import TYPE_CHECKING, Any, ClassVar, TypedDict from google.cloud.bigquery import LoadJobConfig, QueryJobConfig @@ -18,6 +17,8 @@ from sqlspec.exceptions import ImproperConfigurationError from sqlspec.observability import ObservabilityConfig from sqlspec.typing import Empty +from sqlspec.utils.config_normalization import normalize_connection_config +from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import to_json if TYPE_CHECKING: @@ -30,7 +31,7 @@ from sqlspec.core import StatementConfig -logger = logging.getLogger(__name__) +logger = get_logger("adapters.bigquery") class BigQueryConnectionParams(TypedDict): @@ -124,6 +125,7 @@ def __init__( bind_key: "str | None" = None, extension_config: "ExtensionConfigs | None" = None, observability_config: "ObservabilityConfig | None" = None, + **kwargs: Any, ) -> None: """Initialize BigQuery configuration. @@ -136,12 +138,10 @@ def __init__( bind_key: Optional unique identifier for this configuration extension_config: Extension-specific configuration (e.g., Litestar plugin settings) observability_config: Adapter-level observability overrides for lifecycle hooks and observers + **kwargs: Additional keyword arguments passed to the base configuration. 
""" - self.connection_config: dict[str, Any] = dict(connection_config) if connection_config else {} - if "extra" in self.connection_config: - extras = self.connection_config.pop("extra") - self.connection_config.update(extras) + self.connection_config = normalize_connection_config(connection_config) processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {} user_connection_hook = processed_driver_features.pop("on_connection_create", None) @@ -176,6 +176,7 @@ def _wrap_hook(context: dict[str, Any]) -> None: bind_key=bind_key, extension_config=extension_config, observability_config=local_observability, + **kwargs, ) self.driver_features = processed_driver_features diff --git a/sqlspec/adapters/bigquery/driver.py b/sqlspec/adapters/bigquery/driver.py index 3cf831254..4af5d4ea0 100644 --- a/sqlspec/adapters/bigquery/driver.py +++ b/sqlspec/adapters/bigquery/driver.py @@ -7,7 +7,6 @@ import datetime import io -import logging import os from collections.abc import Callable from decimal import Decimal @@ -42,6 +41,7 @@ StorageCapabilityError, UniqueViolationError, ) +from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import to_json if TYPE_CHECKING: @@ -60,7 +60,7 @@ ) from sqlspec.typing import ArrowReturnFormat, StatementParameters -logger = logging.getLogger(__name__) +logger = get_logger(__name__) __all__ = ( "BigQueryCursor", diff --git a/sqlspec/adapters/duckdb/config.py b/sqlspec/adapters/duckdb/config.py index 102ad7c33..6161c534e 100644 --- a/sqlspec/adapters/duckdb/config.py +++ b/sqlspec/adapters/duckdb/config.py @@ -16,6 +16,7 @@ from sqlspec.adapters.duckdb.pool import DuckDBConnectionPool from sqlspec.config import ExtensionConfigs, SyncDatabaseConfig from sqlspec.observability import ObservabilityConfig +from sqlspec.utils.config_normalization import normalize_connection_config from sqlspec.utils.serializers import to_json if TYPE_CHECKING: @@ -215,6 +216,7 @@ def __init__( bind_key: "str | None" = None, extension_config: "ExtensionConfigs | None" = None, observability_config: "ObservabilityConfig | None" = None, + **kwargs: Any, ) -> None: """Initialize DuckDB configuration. @@ -228,18 +230,18 @@ def __init__( bind_key: Optional unique identifier for this configuration extension_config: Extension-specific configuration (e.g., Litestar plugin settings) observability_config: Adapter-level observability overrides for lifecycle hooks and observers + **kwargs: Additional keyword arguments passed to the base configuration. 
""" - if connection_config is None: - connection_config = {} - connection_config.setdefault("database", ":memory:shared_db") + processed_connection_config = normalize_connection_config(connection_config) + processed_connection_config.setdefault("database", ":memory:shared_db") - if connection_config.get("database") in {":memory:", ""}: - connection_config["database"] = ":memory:shared_db" + if processed_connection_config.get("database") in {":memory:", ""}: + processed_connection_config["database"] = ":memory:shared_db" extension_flags: dict[str, Any] = {} - for key in tuple(connection_config.keys()): + for key in tuple(processed_connection_config.keys()): if key in EXTENSION_FLAG_KEYS: - extension_flags[key] = connection_config.pop(key) # type: ignore[misc] + extension_flags[key] = processed_connection_config.pop(key) processed_features: dict[str, Any] = dict(driver_features) if driver_features else {} user_connection_hook = cast( @@ -271,13 +273,14 @@ def _wrap_lifecycle_hook(context: dict[str, Any]) -> None: super().__init__( bind_key=bind_key, - connection_config=dict(connection_config), + connection_config=processed_connection_config, connection_instance=connection_instance, migration_config=migration_config, statement_config=base_statement_config, driver_features=processed_features, extension_config=extension_config, observability_config=local_observability, + **kwargs, ) def _get_connection_config_dict(self) -> "dict[str, Any]": diff --git a/sqlspec/adapters/duckdb/pool.py b/sqlspec/adapters/duckdb/pool.py index 5a40c874e..cd9c2e86a 100644 --- a/sqlspec/adapters/duckdb/pool.py +++ b/sqlspec/adapters/duckdb/pool.py @@ -1,6 +1,5 @@ """DuckDB connection pool with thread-local connections.""" -import logging import threading import time from contextlib import contextmanager, suppress @@ -9,12 +8,13 @@ import duckdb from sqlspec.adapters.duckdb._types import DuckDBConnection +from sqlspec.utils.logging import get_logger if TYPE_CHECKING: from collections.abc import Callable, Generator -logger = logging.getLogger(__name__) +logger = get_logger(__name__) DEFAULT_MIN_POOL: Final[int] = 1 DEFAULT_MAX_POOL: Final[int] = 4 diff --git a/sqlspec/adapters/oracledb/_numpy_handlers.py b/sqlspec/adapters/oracledb/_numpy_handlers.py index 164b509b5..7c8456747 100644 --- a/sqlspec/adapters/oracledb/_numpy_handlers.py +++ b/sqlspec/adapters/oracledb/_numpy_handlers.py @@ -5,10 +5,10 @@ """ import array -import logging from typing import TYPE_CHECKING, Any from sqlspec.typing import NUMPY_INSTALLED +from sqlspec.utils.logging import get_logger if TYPE_CHECKING: from oracledb import AsyncConnection, AsyncCursor, Connection, Cursor @@ -22,7 +22,7 @@ ) -logger = logging.getLogger(__name__) +logger = get_logger(__name__) DTYPE_TO_ARRAY_CODE: dict[str, str] = {"float64": "d", "float32": "f", "uint8": "B", "int8": "b"} diff --git a/sqlspec/adapters/oracledb/_uuid_handlers.py b/sqlspec/adapters/oracledb/_uuid_handlers.py index d2e789bad..567d64a85 100644 --- a/sqlspec/adapters/oracledb/_uuid_handlers.py +++ b/sqlspec/adapters/oracledb/_uuid_handlers.py @@ -4,17 +4,18 @@ via connection type handlers. Uses stdlib uuid (no external dependencies). 
""" -import logging import uuid from typing import TYPE_CHECKING, Any +from sqlspec.utils.logging import get_logger + if TYPE_CHECKING: from oracledb import AsyncConnection, AsyncCursor, Connection, Cursor __all__ = ("register_uuid_handlers", "uuid_converter_in", "uuid_converter_out") -logger = logging.getLogger(__name__) +logger = get_logger(__name__) UUID_BINARY_SIZE = 16 diff --git a/sqlspec/adapters/oracledb/config.py b/sqlspec/adapters/oracledb/config.py index 703d51877..0741757c3 100644 --- a/sqlspec/adapters/oracledb/config.py +++ b/sqlspec/adapters/oracledb/config.py @@ -1,7 +1,6 @@ """OracleDB database configuration with direct field-based configuration.""" import contextlib -import logging from contextlib import asynccontextmanager from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast @@ -28,6 +27,8 @@ from sqlspec.adapters.oracledb.migrations import OracleAsyncMigrationTracker, OracleSyncMigrationTracker from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs, SyncDatabaseConfig from sqlspec.typing import NUMPY_INSTALLED +from sqlspec.utils.config_normalization import apply_pool_deprecations, normalize_connection_config +from sqlspec.utils.logging import get_logger if TYPE_CHECKING: from collections.abc import AsyncGenerator, Callable, Generator @@ -45,7 +46,7 @@ "OracleSyncConfig", ) -logger = logging.getLogger(__name__) +logger = get_logger("adapters.oracledb") class OracleConnectionParams(TypedDict): @@ -136,6 +137,7 @@ def __init__( driver_features: "OracleDriverFeatures | dict[str, Any] | None" = None, bind_key: "str | None" = None, extension_config: "ExtensionConfigs | None" = None, + **kwargs: Any, ) -> None: """Initialize Oracle synchronous configuration. @@ -147,12 +149,13 @@ def __init__( driver_features: Optional driver feature configuration (TypedDict or dict). bind_key: Optional unique identifier for this configuration. extension_config: Extension-specific configuration (e.g., Litestar plugin settings). + **kwargs: Additional keyword arguments (handles deprecated pool_config/pool_instance). """ + connection_config, connection_instance = apply_pool_deprecations( + kwargs=kwargs, connection_config=connection_config, connection_instance=connection_instance + ) - processed_connection_config: dict[str, Any] = dict(connection_config) if connection_config else {} - if "extra" in processed_connection_config: - extras = processed_connection_config.pop("extra") - processed_connection_config.update(extras) + processed_connection_config = normalize_connection_config(connection_config) statement_config = statement_config or oracledb_statement_config processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {} @@ -168,6 +171,7 @@ def __init__( driver_features=processed_driver_features, bind_key=bind_key, extension_config=extension_config, + **kwargs, ) def _create_pool(self) -> "OracleSyncConnectionPool": @@ -313,6 +317,7 @@ def __init__( driver_features: "OracleDriverFeatures | dict[str, Any] | None" = None, bind_key: "str | None" = None, extension_config: "ExtensionConfigs | None" = None, + **kwargs: Any, ) -> None: """Initialize Oracle asynchronous configuration. @@ -324,12 +329,13 @@ def __init__( driver_features: Optional driver feature configuration (TypedDict or dict). bind_key: Optional unique identifier for this configuration. extension_config: Extension-specific configuration (e.g., Litestar plugin settings). + **kwargs: Additional keyword arguments (handles deprecated pool_config/pool_instance). 
""" + connection_config, connection_instance = apply_pool_deprecations( + kwargs=kwargs, connection_config=connection_config, connection_instance=connection_instance + ) - processed_connection_config: dict[str, Any] = dict(connection_config) if connection_config else {} - if "extra" in processed_connection_config: - extras = processed_connection_config.pop("extra") - processed_connection_config.update(extras) + processed_connection_config = normalize_connection_config(connection_config) processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {} processed_driver_features.setdefault("enable_numpy_vectors", NUMPY_INSTALLED) @@ -344,6 +350,7 @@ def __init__( driver_features=processed_driver_features, bind_key=bind_key, extension_config=extension_config, + **kwargs, ) async def _create_pool(self) -> "OracleAsyncConnectionPool": diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py index caef2e18f..8523cb998 100644 --- a/sqlspec/adapters/oracledb/driver.py +++ b/sqlspec/adapters/oracledb/driver.py @@ -45,7 +45,7 @@ TransactionError, UniqueViolationError, ) -from sqlspec.utils.logging import log_with_context +from sqlspec.utils.logging import get_logger, log_with_context from sqlspec.utils.module_loader import ensure_pyarrow from sqlspec.utils.serializers import to_json @@ -84,7 +84,7 @@ def prepare_statement( def _get_compiled_sql(self, statement: SQL, statement_config: StatementConfig) -> "tuple[str, Any]": ... -logger = logging.getLogger(__name__) +logger = get_logger(__name__) # Oracle-specific constants LARGE_STRING_THRESHOLD = 4000 # Threshold for large string parameters to avoid ORA-01704 diff --git a/sqlspec/adapters/psqlpy/_type_handlers.py b/sqlspec/adapters/psqlpy/_type_handlers.py index 0329a855e..410d0f7e4 100644 --- a/sqlspec/adapters/psqlpy/_type_handlers.py +++ b/sqlspec/adapters/psqlpy/_type_handlers.py @@ -10,10 +10,10 @@ custom type handlers on pool initialization. 
""" -import logging from typing import TYPE_CHECKING from sqlspec.typing import PGVECTOR_INSTALLED +from sqlspec.utils.logging import get_logger if TYPE_CHECKING: from psqlpy import Connection @@ -21,7 +21,7 @@ __all__ = ("register_pgvector",) -logger = logging.getLogger(__name__) +logger = get_logger(__name__) def register_pgvector(connection: "Connection") -> None: diff --git a/sqlspec/adapters/psqlpy/config.py b/sqlspec/adapters/psqlpy/config.py index be7217af7..8705faec5 100644 --- a/sqlspec/adapters/psqlpy/config.py +++ b/sqlspec/adapters/psqlpy/config.py @@ -1,6 +1,5 @@ """Psqlpy database configuration.""" -import logging from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast @@ -18,13 +17,15 @@ from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs from sqlspec.core import StatementConfig from sqlspec.typing import PGVECTOR_INSTALLED +from sqlspec.utils.config_normalization import apply_pool_deprecations, normalize_connection_config +from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import to_json if TYPE_CHECKING: from collections.abc import Callable -logger = logging.getLogger("sqlspec.adapters.psqlpy") +logger = get_logger("adapters.psqlpy") class PsqlpyConnectionParams(TypedDict): @@ -120,6 +121,7 @@ def __init__( driver_features: "PsqlpyDriverFeatures | dict[str, Any] | None" = None, bind_key: str | None = None, extension_config: "ExtensionConfigs | None" = None, + **kwargs: Any, ) -> None: """Initialize Psqlpy configuration. @@ -131,11 +133,13 @@ def __init__( driver_features: Driver feature configuration (TypedDict or dict). bind_key: Optional unique identifier for this configuration. extension_config: Extension-specific configuration (e.g., Litestar plugin settings). + **kwargs: Additional keyword arguments (handles deprecated pool_config/pool_instance). 
""" - processed_connection_config: dict[str, Any] = dict(connection_config) if connection_config else {} - if "extra" in processed_connection_config: - extras = processed_connection_config.pop("extra") - processed_connection_config.update(extras) + connection_config, connection_instance = apply_pool_deprecations( + kwargs=kwargs, connection_config=connection_config, connection_instance=connection_instance + ) + + processed_connection_config = normalize_connection_config(connection_config) processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {} serializer = processed_driver_features.get("json_serializer") @@ -151,6 +155,7 @@ def __init__( driver_features=processed_driver_features, bind_key=bind_key, extension_config=extension_config, + **kwargs, ) def _get_pool_config_dict(self) -> dict[str, Any]: @@ -163,31 +168,15 @@ def _get_pool_config_dict(self) -> dict[str, Any]: async def _create_pool(self) -> "ConnectionPool": """Create the actual async connection pool.""" - logger.info("Creating psqlpy connection pool", extra={"adapter": "psqlpy"}) - - try: - config = self._get_pool_config_dict() - - pool = ConnectionPool(**config) - logger.info("Psqlpy connection pool created successfully", extra={"adapter": "psqlpy"}) - except Exception as e: - logger.exception("Failed to create psqlpy connection pool", extra={"adapter": "psqlpy", "error": str(e)}) - raise - return pool + config = self._get_pool_config_dict() + return ConnectionPool(**config) async def _close_pool(self) -> None: """Close the actual async connection pool.""" if not self.connection_instance: return - logger.info("Closing psqlpy connection pool", extra={"adapter": "psqlpy"}) - - try: - self.connection_instance.close() - logger.info("Psqlpy connection pool closed successfully", extra={"adapter": "psqlpy"}) - except Exception as e: - logger.exception("Failed to close psqlpy connection pool", extra={"adapter": "psqlpy", "error": str(e)}) - raise + self.connection_instance.close() async def close_pool(self) -> None: """Close the connection pool.""" diff --git a/sqlspec/adapters/psycopg/_type_handlers.py b/sqlspec/adapters/psycopg/_type_handlers.py index c25ae199d..77fb22c9f 100644 --- a/sqlspec/adapters/psycopg/_type_handlers.py +++ b/sqlspec/adapters/psycopg/_type_handlers.py @@ -4,12 +4,12 @@ via pgvector-python library. Supports both sync and async connections. 
""" -import logging from typing import TYPE_CHECKING, Any from psycopg import ProgrammingError, errors from sqlspec.typing import NUMPY_INSTALLED, PGVECTOR_INSTALLED +from sqlspec.utils.logging import get_logger if TYPE_CHECKING: from psycopg import AsyncConnection, Connection @@ -17,7 +17,7 @@ __all__ = ("register_pgvector_async", "register_pgvector_sync") -logger = logging.getLogger(__name__) +logger = get_logger(__name__) def _is_missing_vector_error(error: Exception) -> bool: diff --git a/sqlspec/adapters/psycopg/config.py b/sqlspec/adapters/psycopg/config.py index 7924ae417..0a19fe79a 100644 --- a/sqlspec/adapters/psycopg/config.py +++ b/sqlspec/adapters/psycopg/config.py @@ -1,7 +1,6 @@ """Psycopg database configuration with direct field-based configuration.""" import contextlib -import logging from contextlib import asynccontextmanager from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast @@ -23,6 +22,8 @@ ) from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs, SyncDatabaseConfig from sqlspec.typing import PGVECTOR_INSTALLED +from sqlspec.utils.config_normalization import apply_pool_deprecations, normalize_connection_config +from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import to_json if TYPE_CHECKING: @@ -31,7 +32,7 @@ from sqlspec.core import StatementConfig -logger = logging.getLogger("sqlspec.adapters.psycopg") +logger = get_logger("adapters.psycopg") class PsycopgConnectionParams(TypedDict): @@ -120,6 +121,7 @@ def __init__( driver_features: "dict[str, Any] | None" = None, bind_key: "str | None" = None, extension_config: "ExtensionConfigs | None" = None, + **kwargs: Any, ) -> None: """Initialize Psycopg synchronous configuration. @@ -131,11 +133,13 @@ def __init__( driver_features: Optional driver feature configuration bind_key: Optional unique identifier for this configuration extension_config: Extension-specific configuration (e.g., Litestar plugin settings) + **kwargs: Additional keyword arguments (handles deprecated pool_config/pool_instance) """ - processed_connection_config: dict[str, Any] = dict(connection_config) if connection_config else {} - if "extra" in processed_connection_config: - extras = processed_connection_config.pop("extra") - processed_connection_config.update(extras) + connection_config, connection_instance = apply_pool_deprecations( + kwargs=kwargs, connection_config=connection_config, connection_instance=connection_instance + ) + + processed_connection_config = normalize_connection_config(connection_config) processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {} serializer = cast("Callable[[Any], str]", processed_driver_features.get("json_serializer", to_json)) @@ -150,68 +154,54 @@ def __init__( driver_features=processed_driver_features, bind_key=bind_key, extension_config=extension_config, + **kwargs, ) def _create_pool(self) -> "ConnectionPool": """Create the actual connection pool.""" - logger.info("Creating Psycopg connection pool", extra={"adapter": "psycopg"}) + all_config = dict(self.connection_config) - try: - all_config = dict(self.connection_config) - - pool_parameters = { - "min_size": all_config.pop("min_size", 4), - "max_size": all_config.pop("max_size", None), - "name": all_config.pop("name", None), - "timeout": all_config.pop("timeout", 30.0), - "max_waiting": all_config.pop("max_waiting", 0), - "max_lifetime": all_config.pop("max_lifetime", 3600.0), - "max_idle": all_config.pop("max_idle", 600.0), - "reconnect_timeout": 
all_config.pop("reconnect_timeout", 300.0), - "num_workers": all_config.pop("num_workers", 3), - } - - autocommit_setting = all_config.get("autocommit") - - def configure_connection(conn: "PsycopgSyncConnection") -> None: - conn.row_factory = dict_row - if autocommit_setting is not None: - conn.autocommit = autocommit_setting + pool_parameters = { + "min_size": all_config.pop("min_size", 4), + "max_size": all_config.pop("max_size", None), + "name": all_config.pop("name", None), + "timeout": all_config.pop("timeout", 30.0), + "max_waiting": all_config.pop("max_waiting", 0), + "max_lifetime": all_config.pop("max_lifetime", 3600.0), + "max_idle": all_config.pop("max_idle", 600.0), + "reconnect_timeout": all_config.pop("reconnect_timeout", 300.0), + "num_workers": all_config.pop("num_workers", 3), + } - if self.driver_features.get("enable_pgvector", False): - register_pgvector_sync(conn) + autocommit_setting = all_config.get("autocommit") - pool_parameters["configure"] = all_config.pop("configure", configure_connection) + def configure_connection(conn: "PsycopgSyncConnection") -> None: + conn.row_factory = dict_row + if autocommit_setting is not None: + conn.autocommit = autocommit_setting - pool_parameters = {k: v for k, v in pool_parameters.items() if v is not None} + if self.driver_features.get("enable_pgvector", False): + register_pgvector_sync(conn) - conninfo = all_config.pop("conninfo", None) - if conninfo: - pool = ConnectionPool(conninfo, open=True, **pool_parameters) - else: - kwargs = all_config.pop("kwargs", {}) - all_config.update(kwargs) - pool = ConnectionPool("", kwargs=all_config, open=True, **pool_parameters) + pool_parameters["configure"] = all_config.pop("configure", configure_connection) - logger.info("Psycopg connection pool created successfully", extra={"adapter": "psycopg"}) - except Exception as e: - logger.exception("Failed to create Psycopg connection pool", extra={"adapter": "psycopg", "error": str(e)}) - raise - return pool + pool_parameters = {k: v for k, v in pool_parameters.items() if v is not None} + + conninfo = all_config.pop("conninfo", None) + if conninfo: + return ConnectionPool(conninfo, open=True, **pool_parameters) + + kwargs = all_config.pop("kwargs", {}) + all_config.update(kwargs) + return ConnectionPool("", kwargs=all_config, open=True, **pool_parameters) def _close_pool(self) -> None: """Close the actual connection pool.""" if not self.connection_instance: return - logger.info("Closing Psycopg connection pool", extra={"adapter": "psycopg"}) - try: self.connection_instance.close() - logger.info("Psycopg connection pool closed successfully", extra={"adapter": "psycopg"}) - except Exception as e: - logger.exception("Failed to close Psycopg connection pool", extra={"adapter": "psycopg", "error": str(e)}) - raise finally: self.connection_instance = None @@ -319,6 +309,7 @@ def __init__( driver_features: "dict[str, Any] | None" = None, bind_key: "str | None" = None, extension_config: "ExtensionConfigs | None" = None, + **kwargs: Any, ) -> None: """Initialize Psycopg asynchronous configuration. 
@@ -330,11 +321,13 @@ def __init__( driver_features: Optional driver feature configuration bind_key: Optional unique identifier for this configuration extension_config: Extension-specific configuration (e.g., Litestar plugin settings) + **kwargs: Additional keyword arguments (handles deprecated pool_config/pool_instance) """ - processed_connection_config: dict[str, Any] = dict(connection_config) if connection_config else {} - if "extra" in processed_connection_config: - extras = processed_connection_config.pop("extra") - processed_connection_config.update(extras) + connection_config, connection_instance = apply_pool_deprecations( + kwargs=kwargs, connection_config=connection_config, connection_instance=connection_instance + ) + + processed_connection_config = normalize_connection_config(connection_config) processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {} serializer = cast("Callable[[Any], str]", processed_driver_features.get("json_serializer", to_json)) @@ -349,6 +342,7 @@ def __init__( driver_features=processed_driver_features, bind_key=bind_key, extension_config=extension_config, + **kwargs, ) async def _create_pool(self) -> "AsyncConnectionPool": diff --git a/sqlspec/adapters/spanner/config.py b/sqlspec/adapters/spanner/config.py index d0192b0be..0e17c34be 100644 --- a/sqlspec/adapters/spanner/config.py +++ b/sqlspec/adapters/spanner/config.py @@ -12,6 +12,7 @@ from sqlspec.adapters.spanner.driver import SpannerSyncDriver, spanner_statement_config from sqlspec.config import SyncDatabaseConfig from sqlspec.exceptions import ImproperConfigurationError +from sqlspec.utils.config_normalization import apply_pool_deprecations, normalize_connection_config from sqlspec.utils.serializers import from_json, to_json if TYPE_CHECKING: @@ -78,8 +79,13 @@ def __init__( bind_key: "str | None" = None, extension_config: "ExtensionConfigs | None" = None, observability_config: "ObservabilityConfig | None" = None, + **kwargs: Any, ) -> None: - self.connection_config = dict(connection_config) if connection_config else {} + connection_config, connection_instance = apply_pool_deprecations( + kwargs=kwargs, connection_config=connection_config, connection_instance=connection_instance + ) + + self.connection_config = normalize_connection_config(connection_config) self.connection_config.setdefault("min_sessions", 1) self.connection_config.setdefault("max_sessions", 10) @@ -101,6 +107,7 @@ def __init__( bind_key=bind_key, extension_config=extension_config, observability_config=observability_config, + **kwargs, ) self._client: Client | None = None diff --git a/sqlspec/adapters/sqlite/_type_handlers.py b/sqlspec/adapters/sqlite/_type_handlers.py index 55b2c58a1..be7dddb70 100644 --- a/sqlspec/adapters/sqlite/_type_handlers.py +++ b/sqlspec/adapters/sqlite/_type_handlers.py @@ -5,16 +5,17 @@ via SqliteDriverFeatures configuration. """ -import logging import sqlite3 from typing import TYPE_CHECKING, Any +from sqlspec.utils.logging import get_logger + if TYPE_CHECKING: from collections.abc import Callable __all__ = ("json_adapter", "json_converter", "register_type_handlers", "unregister_type_handlers") -logger = logging.getLogger(__name__) +logger = get_logger(__name__) DEFAULT_JSON_TYPE = "JSON" @@ -65,16 +66,12 @@ def register_type_handlers( json_serializer: Optional custom JSON serializer (e.g., orjson.dumps). json_deserializer: Optional custom JSON deserializer (e.g., orjson.loads). 
""" - try: - sqlite3.register_adapter(dict, lambda v: json_adapter(v, json_serializer)) - sqlite3.register_adapter(list, lambda v: json_adapter(v, json_serializer)) + sqlite3.register_adapter(dict, lambda v: json_adapter(v, json_serializer)) + sqlite3.register_adapter(list, lambda v: json_adapter(v, json_serializer)) - sqlite3.register_converter(DEFAULT_JSON_TYPE, lambda v: json_converter(v, json_deserializer)) + sqlite3.register_converter(DEFAULT_JSON_TYPE, lambda v: json_converter(v, json_deserializer)) - logger.debug("Registered SQLite custom type handlers (JSON dict/list adapters)") - except Exception: - logger.exception("Failed to register SQLite type handlers") - raise + logger.debug("Registered SQLite custom type handlers (JSON dict/list adapters)") def unregister_type_handlers() -> None: diff --git a/sqlspec/adapters/sqlite/config.py b/sqlspec/adapters/sqlite/config.py index d07fbe7fa..6d87e72fd 100644 --- a/sqlspec/adapters/sqlite/config.py +++ b/sqlspec/adapters/sqlite/config.py @@ -1,9 +1,8 @@ """SQLite database configuration with thread-local connections.""" -import logging import uuid from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast +from typing import TYPE_CHECKING, Any, ClassVar, TypedDict from typing_extensions import NotRequired @@ -12,9 +11,10 @@ from sqlspec.adapters.sqlite.driver import SqliteCursor, SqliteDriver, SqliteExceptionHandler, sqlite_statement_config from sqlspec.adapters.sqlite.pool import SqliteConnectionPool from sqlspec.config import ExtensionConfigs, SyncDatabaseConfig +from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json -logger = logging.getLogger(__name__) +logger = get_logger("adapters.sqlite") if TYPE_CHECKING: from collections.abc import Callable, Generator @@ -80,6 +80,7 @@ def __init__( bind_key: "str | None" = None, extension_config: "ExtensionConfigs | None" = None, observability_config: "ObservabilityConfig | None" = None, + **kwargs: Any, ) -> None: """Initialize SQLite configuration. @@ -92,21 +93,21 @@ def __init__( bind_key: Optional bind key for the configuration extension_config: Extension-specific configuration (e.g., Litestar plugin settings) observability_config: Adapter-level observability overrides for lifecycle hooks and observers + **kwargs: Additional keyword arguments passed to the base configuration. """ - if connection_config is None: - connection_config = {} - if "database" not in connection_config or connection_config["database"] == ":memory:": - connection_config["database"] = f"file:memory_{uuid.uuid4().hex}?mode=memory&cache=private" - connection_config["uri"] = True - elif "database" in connection_config: - database_path = str(connection_config["database"]) - if database_path.startswith("file:") and not connection_config.get("uri"): + config_dict: dict[str, Any] = dict(connection_config) if connection_config else {} + if "database" not in config_dict or config_dict["database"] == ":memory:": + config_dict["database"] = f"file:memory_{uuid.uuid4().hex}?mode=memory&cache=private" + config_dict["uri"] = True + elif "database" in config_dict: + database_path = str(config_dict["database"]) + if database_path.startswith("file:") and not config_dict.get("uri"): logger.debug( "Database URI detected (%s) but uri=True not set. 
" "Auto-enabling URI mode to prevent physical file creation.", database_path, ) - connection_config["uri"] = True + config_dict["uri"] = True processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {} processed_driver_features.setdefault("enable_custom_adapters", True) @@ -123,12 +124,13 @@ def __init__( super().__init__( bind_key=bind_key, connection_instance=connection_instance, - connection_config=cast("dict[str, Any]", connection_config), + connection_config=config_dict, migration_config=migration_config, statement_config=base_statement_config, driver_features=processed_driver_features, extension_config=extension_config, observability_config=observability_config, + **kwargs, ) def _get_connection_config_dict(self) -> "dict[str, Any]": diff --git a/sqlspec/builder/_base.py b/sqlspec/builder/_base.py index 3613f5954..7eef60620 100644 --- a/sqlspec/builder/_base.py +++ b/sqlspec/builder/_base.py @@ -554,7 +554,6 @@ def build(self, dialect: DialectType = None) -> "SafeQuery": sql_string = str(final_expression) except Exception as e: err_msg = f"Error generating SQL from expression: {e!s}" - logger.exception("SQL generation failed") self._raise_sql_builder_error(err_msg, e) return SafeQuery(sql=sql_string, parameters=self._parameters.copy(), dialect=dialect or self.dialect) diff --git a/sqlspec/builder/_factory.py b/sqlspec/builder/_factory.py index 92934b8b0..88f88d9e9 100644 --- a/sqlspec/builder/_factory.py +++ b/sqlspec/builder/_factory.py @@ -3,6 +3,7 @@ Provides statement builders (select, insert, update, etc.) and column expressions. """ +import hashlib import logging from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Union, cast @@ -45,6 +46,7 @@ from sqlspec.builder._update import Update from sqlspec.core import SQL from sqlspec.exceptions import SQLBuilderError +from sqlspec.utils.logging import get_logger if TYPE_CHECKING: from collections.abc import Mapping, Sequence @@ -82,7 +84,7 @@ "sql", ) -logger = logging.getLogger("sqlspec") +logger = get_logger("builder.factory") MIN_SQL_LIKE_STRING_LENGTH = 6 MIN_DECODE_ARGS = 2 @@ -114,6 +116,11 @@ } +def _fingerprint_sql(sql: str) -> str: + digest = hashlib.sha256(sql.encode("utf-8", errors="replace")).hexdigest() + return digest[:12] + + def _normalize_copy_dialect(dialect: DialectType | None) -> str: if dialect is None: return "postgres" @@ -219,9 +226,18 @@ def detect_sql_type(cls, sql: str, dialect: DialectType = None) -> str: return str(parsed_expr.this).upper() return command_type except SQLGlotParseError: - logger.debug("Failed to parse SQL for type detection: %s", sql[:100]) - except (ValueError, TypeError, AttributeError) as e: - logger.warning("Unexpected error during SQL type detection for '%s...': %s", sql[:50], e) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Failed to parse SQL for type detection", + extra={"sql_length": len(sql), "sql_hash": _fingerprint_sql(sql)}, + ) + except (ValueError, TypeError, AttributeError): + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Unexpected error during SQL type detection", + exc_info=True, + extra={"sql_length": len(sql), "sql_hash": _fingerprint_sql(sql)}, + ) return "UNKNOWN" def __init__(self, dialect: DialectType = None) -> None: @@ -659,13 +675,23 @@ def _populate_insert_from_sql(self, builder: "Insert", sql_string: str) -> "Inse return builder if isinstance(parsed_expr, exp.Select): - logger.info("Detected SELECT statement for INSERT - may need target table specification") + 
logger.debug( + "Detected SELECT statement for INSERT; builder requires explicit target table", + extra={"builder": "insert"}, + ) return builder - logger.warning("Cannot create INSERT from %s statement", type(parsed_expr).__name__) + logger.debug( + "Cannot create INSERT from parsed statement type", + extra={"builder": "insert", "parsed_type": type(parsed_expr).__name__}, + ) - except Exception as e: - logger.warning("Failed to parse INSERT SQL, falling back to traditional mode: %s", e) + except Exception: + logger.debug( + "Failed to parse INSERT SQL; falling back to traditional mode", + exc_info=True, + extra={"builder": "insert"}, + ) return builder def _populate_select_from_sql(self, builder: "Select", sql_string: str) -> "Select": @@ -677,10 +703,17 @@ def _populate_select_from_sql(self, builder: "Select", sql_string: str) -> "Sele builder.set_expression(parsed_expr) return builder - logger.warning("Cannot create SELECT from %s statement", type(parsed_expr).__name__) + logger.debug( + "Cannot create SELECT from parsed statement type", + extra={"builder": "select", "parsed_type": type(parsed_expr).__name__}, + ) - except Exception as e: - logger.warning("Failed to parse SELECT SQL, falling back to traditional mode: %s", e) + except Exception: + logger.debug( + "Failed to parse SELECT SQL; falling back to traditional mode", + exc_info=True, + extra={"builder": "select"}, + ) return builder def _populate_update_from_sql(self, builder: "Update", sql_string: str) -> "Update": @@ -692,10 +725,17 @@ def _populate_update_from_sql(self, builder: "Update", sql_string: str) -> "Upda builder.set_expression(parsed_expr) return builder - logger.warning("Cannot create UPDATE from %s statement", type(parsed_expr).__name__) + logger.debug( + "Cannot create UPDATE from parsed statement type", + extra={"builder": "update", "parsed_type": type(parsed_expr).__name__}, + ) - except Exception as e: - logger.warning("Failed to parse UPDATE SQL, falling back to traditional mode: %s", e) + except Exception: + logger.debug( + "Failed to parse UPDATE SQL; falling back to traditional mode", + exc_info=True, + extra={"builder": "update"}, + ) return builder def _populate_delete_from_sql(self, builder: "Delete", sql_string: str) -> "Delete": @@ -707,10 +747,17 @@ def _populate_delete_from_sql(self, builder: "Delete", sql_string: str) -> "Dele builder.set_expression(parsed_expr) return builder - logger.warning("Cannot create DELETE from %s statement", type(parsed_expr).__name__) + logger.debug( + "Cannot create DELETE from parsed statement type", + extra={"builder": "delete", "parsed_type": type(parsed_expr).__name__}, + ) - except Exception as e: - logger.warning("Failed to parse DELETE SQL, falling back to traditional mode: %s", e) + except Exception: + logger.debug( + "Failed to parse DELETE SQL; falling back to traditional mode", + exc_info=True, + extra={"builder": "delete"}, + ) return builder def _populate_merge_from_sql(self, builder: "Merge", sql_string: str) -> "Merge": @@ -722,10 +769,15 @@ def _populate_merge_from_sql(self, builder: "Merge", sql_string: str) -> "Merge" builder.set_expression(parsed_expr) return builder - logger.warning("Cannot create MERGE from %s statement", type(parsed_expr).__name__) + logger.debug( + "Cannot create MERGE from parsed statement type", + extra={"builder": "merge", "parsed_type": type(parsed_expr).__name__}, + ) - except Exception as e: - logger.warning("Failed to parse MERGE SQL, falling back to traditional mode: %s", e) + except Exception: + logger.debug( + "Failed to 
parse MERGE SQL; falling back to traditional mode", exc_info=True, extra={"builder": "merge"} + ) return builder def column(self, name: str, table: str | None = None) -> Column: diff --git a/sqlspec/config.py b/sqlspec/config.py index ae50b7a00..e9a1941c0 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -1218,6 +1218,7 @@ def __init__( bind_key: "str | None" = None, extension_config: "ExtensionConfigs | None" = None, observability_config: "ObservabilityConfig | None" = None, + **kwargs: Any, ) -> None: self.bind_key = bind_key self.connection_instance = connection_instance @@ -1391,6 +1392,7 @@ def __init__( bind_key: "str | None" = None, extension_config: "ExtensionConfigs | None" = None, observability_config: "ObservabilityConfig | None" = None, + **kwargs: Any, ) -> None: self.bind_key = bind_key self.connection_instance = connection_instance diff --git a/sqlspec/core/compiler.py b/sqlspec/core/compiler.py index adbb62b42..c244cf5d4 100644 --- a/sqlspec/core/compiler.py +++ b/sqlspec/core/compiler.py @@ -395,7 +395,7 @@ def _compile_uncached(self, sql: str, parameters: Any, is_many: bool = False) -> except sqlspec.exceptions.SQLSpecError: raise except Exception as e: - logger.warning("Compilation failed, using fallback: %s", e) + logger.debug("Compilation failed, using fallback: %s", e) return CompiledSQL( compiled_sql=sql, execution_parameters=parameters, diff --git a/sqlspec/core/splitter.py b/sqlspec/core/splitter.py index 7512b9cfa..af6280388 100644 --- a/sqlspec/core/splitter.py +++ b/sqlspec/core/splitter.py @@ -13,6 +13,7 @@ MySQL, SQLite, DuckDB, and BigQuery. """ +import logging import re import threading from abc import ABC, abstractmethod @@ -39,6 +40,9 @@ logger = get_logger("sqlspec.core.splitter") +_TOKENIZE_DEBUG_SAMPLE_LIMIT: Final[int] = 3 +_TOKENIZE_SNIPPET_LENGTH: Final[int] = 20 + DEFAULT_PATTERN_CACHE_SIZE: Final = 1000 DEFAULT_RESULT_CACHE_SIZE: Final = 5000 DEFAULT_CACHE_TTL: Final = 3600 @@ -703,6 +707,9 @@ def _tokenize(self, sql: str) -> Generator[Token, None, None]: pos = 0 line = 1 line_start = 0 + unmatched_count = 0 + first_unmatched_pos: int | None = None + first_unmatched_snippet: str | None = None while pos < len(sql): matched = False @@ -740,9 +747,24 @@ def _tokenize(self, sql: str) -> Generator[Token, None, None]: break if not matched: - logger.error("Failed to tokenize at position %d: %s", pos, sql[pos : pos + 20]) + if unmatched_count == 0: + first_unmatched_pos = pos + first_unmatched_snippet = sql[pos : pos + _TOKENIZE_SNIPPET_LENGTH] + unmatched_count += 1 + if unmatched_count <= _TOKENIZE_DEBUG_SAMPLE_LIMIT and logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Failed to tokenize at position %d: %s", pos, sql[pos : pos + _TOKENIZE_SNIPPET_LENGTH] + ) pos += 1 + if unmatched_count and logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Tokenization skipped %d unmatched characters (first at %s: %s)", + unmatched_count, + first_unmatched_pos, + first_unmatched_snippet, + ) + def split(self, sql: str) -> list[str]: """Split SQL script into individual statements. 
diff --git a/sqlspec/core/statement.py b/sqlspec/core/statement.py index 47f532735..7cc2fc3b1 100644 --- a/sqlspec/core/statement.py +++ b/sqlspec/core/statement.py @@ -528,7 +528,7 @@ def copy( return new_sql def _handle_compile_failure(self, error: Exception) -> ProcessedState: - logger.warning("Processing failed, using fallback: %s", error) + logger.debug("Processing failed, using fallback: %s", error) return ProcessedState( compiled_sql=self._raw_sql, execution_parameters=self._named_parameters or self._positional_parameters, diff --git a/sqlspec/extensions/aiosql/adapter.py b/sqlspec/extensions/aiosql/adapter.py index 37bb1508f..c6f30a7af 100644 --- a/sqlspec/extensions/aiosql/adapter.py +++ b/sqlspec/extensions/aiosql/adapter.py @@ -5,7 +5,6 @@ from files using aiosql while using SQLSpec's features for execution and type mapping. """ -import logging from collections.abc import AsyncGenerator, Generator from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager from typing import Any, ClassVar, Generic, TypeVar @@ -13,9 +12,10 @@ from sqlspec.core import SQL, SQLResult, StatementConfig from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase from sqlspec.typing import AiosqlAsyncProtocol, AiosqlParamType, AiosqlSQLOperationType, AiosqlSyncProtocol +from sqlspec.utils.logging import get_logger from sqlspec.utils.module_loader import ensure_aiosql -logger = logging.getLogger("sqlspec.extensions.aiosql") +logger = get_logger("extensions.aiosql") __all__ = ("AiosqlAsyncAdapter", "AiosqlSyncAdapter") diff --git a/sqlspec/extensions/litestar/plugin.py b/sqlspec/extensions/litestar/plugin.py index dfc720420..b862bd107 100644 --- a/sqlspec/extensions/litestar/plugin.py +++ b/sqlspec/extensions/litestar/plugin.py @@ -147,6 +147,7 @@ async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> No if not header_value: header_value = CorrelationContext.generate() + previous_correlation_id = CorrelationContext.get() CorrelationContext.set(header_value) set_sqlspec_scope_state(scope, CORRELATION_STATE_KEY, header_value) try: @@ -154,7 +155,7 @@ async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> No finally: with suppress(KeyError): delete_sqlspec_scope_state(scope, CORRELATION_STATE_KEY) - CorrelationContext.clear() + CorrelationContext.set(previous_correlation_id) @dataclass diff --git a/sqlspec/loader.py b/sqlspec/loader.py index 07dbcac42..1f483dc01 100644 --- a/sqlspec/loader.py +++ b/sqlspec/loader.py @@ -383,6 +383,7 @@ def load_sql(self, *paths: str | Path) -> None: error: Exception | None = None start_time = time.perf_counter() path_count = len(paths) + previous_correlation_id = CorrelationContext.get() if runtime is not None: runtime.increment_metric("loader.load.invocations") runtime.increment_metric("loader.paths.requested", path_count) @@ -391,8 +392,6 @@ def load_sql(self, *paths: str | Path) -> None: attributes={"sqlspec.loader.path_count": path_count, "sqlspec.loader.encoding": self.encoding}, ) - correlation_id = CorrelationContext.get() - try: for path in paths: path_str = str(path) @@ -409,16 +408,6 @@ def load_sql(self, *paths: str | Path) -> None: except Exception as exc: error = exc - duration = time.perf_counter() - start_time - logger.exception( - "Failed to load SQL files after %.3fms", - duration * 1000, - extra={ - "error_type": type(exc).__name__, - "duration_ms": duration * 1000, - "correlation_id": correlation_id, - }, - ) if runtime is not None: 
runtime.increment_metric("loader.load.errors") raise @@ -428,7 +417,7 @@ def load_sql(self, *paths: str | Path) -> None: runtime.record_metric("loader.last_load_ms", duration_ms) runtime.increment_metric("loader.load.duration_ms", duration_ms) runtime.end_span(span, error=error) - CorrelationContext.clear() + CorrelationContext.set(previous_correlation_id) def _load_directory(self, dir_path: Path) -> None: """Load all SQL files from a directory. @@ -690,22 +679,10 @@ def get_sql(self, name: str) -> "SQL": Raises: SQLFileNotFoundError: If statement name not found. """ - correlation_id = CorrelationContext.get() - safe_name = _normalize_query_name(name) if safe_name not in self._queries: available = ", ".join(sorted(self._queries.keys())) if self._queries else "none" - logger.error( - "Statement not found: %s", - name, - extra={ - "statement_name": name, - "safe_name": safe_name, - "available_statements": len(self._queries), - "correlation_id": correlation_id, - }, - ) raise SQLFileNotFoundError(name, path=f"Statement '{name}' not found. Available statements: {available}") parsed_statement = self._queries[safe_name] @@ -713,4 +690,9 @@ def get_sql(self, name: str) -> "SQL": if parsed_statement.dialect: sqlglot_dialect = _normalize_dialect(parsed_statement.dialect) - return SQL(parsed_statement.sql, dialect=sqlglot_dialect) + sql = SQL(parsed_statement.sql, dialect=sqlglot_dialect) + try: + sql.compile() + except Exception as exc: + raise SQLFileParseError(name=name, path="", original_error=exc) from exc + return sql diff --git a/sqlspec/migrations/commands.py b/sqlspec/migrations/commands.py index 8d21f54aa..e0d9c9671 100644 --- a/sqlspec/migrations/commands.py +++ b/sqlspec/migrations/commands.py @@ -630,7 +630,6 @@ def fix(self, dry_run: bool = False, update_database: bool = True, yes: bool = F console.print("[green]✓ Conversion complete![/]") except Exception as e: - logger.exception("Fix command failed") console.print(f"[red]✗ Error: {e}[/]") fixer.rollback() console.print("[yellow]Restored files from backup[/]") @@ -1117,7 +1116,6 @@ async def fix(self, dry_run: bool = False, update_database: bool = True, yes: bo console.print("[green]✓ Conversion complete![/]") except Exception as e: - logger.exception("Fix command failed") console.print(f"[red]✗ Error: {e}[/]") fixer.rollback() console.print("[yellow]Restored files from backup[/]") diff --git a/sqlspec/migrations/fix.py b/sqlspec/migrations/fix.py index 2dc72f85c..8b4358f4d 100644 --- a/sqlspec/migrations/fix.py +++ b/sqlspec/migrations/fix.py @@ -5,16 +5,17 @@ uses timestamps and production uses sequential numbers. 
""" -import logging import re import shutil from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path +from sqlspec.utils.logging import get_logger + __all__ = ("MigrationFixer", "MigrationRename") -logger = logging.getLogger(__name__) +logger = get_logger(__name__) @dataclass diff --git a/sqlspec/migrations/utils.py b/sqlspec/migrations/utils.py index 681b91ca7..32e223c6c 100644 --- a/sqlspec/migrations/utils.py +++ b/sqlspec/migrations/utils.py @@ -2,7 +2,6 @@ import importlib import inspect -import logging import os import subprocess from datetime import datetime, timezone @@ -10,6 +9,7 @@ from typing import TYPE_CHECKING, Any, cast from sqlspec.migrations.templates import MigrationTemplateSettings, TemplateValidationError, build_template_settings +from sqlspec.utils.logging import get_logger from sqlspec.utils.text import slugify if TYPE_CHECKING: @@ -20,7 +20,7 @@ __all__ = ("create_migration_file", "drop_all", "get_author") -logger = logging.getLogger(__name__) +logger = get_logger(__name__) def create_migration_file( diff --git a/sqlspec/observability/_observer.py b/sqlspec/observability/_observer.py index 08b9f5eb8..6aa4476f9 100644 --- a/sqlspec/observability/_observer.py +++ b/sqlspec/observability/_observer.py @@ -1,5 +1,6 @@ """Statement observer primitives for SQL execution events.""" +import logging from collections.abc import Callable from time import time from typing import Any @@ -11,6 +12,9 @@ logger = get_logger("sqlspec.observability") +_LOG_SQL_MAX_CHARS = 2000 +_LOG_PARAMETERS_MAX_ITEMS = 100 + StatementObserver = Callable[["StatementEvent"], None] @@ -140,7 +144,83 @@ def format_statement_event(event: StatementEvent) -> str: def default_statement_observer(event: StatementEvent) -> None: """Log statement execution payload when no custom observer is supplied.""" - logger.info(format_statement_event(event), extra={"correlation_id": event.correlation_id}) + sql_preview, sql_truncated, sql_length = _truncate_text(event.sql, max_chars=_LOG_SQL_MAX_CHARS) + sql_preview = sql_preview.replace("\n", " ").strip() + + extra: dict[str, Any] = { + "driver": event.driver, + "adapter": event.adapter, + "bind_key": event.bind_key, + "operation": event.operation, + "execution_mode": event.execution_mode, + "is_many": event.is_many, + "is_script": event.is_script, + "rows_affected": event.rows_affected, + "duration_s": event.duration_s, + "started_at": event.started_at, + "correlation_id": event.correlation_id, + "storage_backend": event.storage_backend, + "sql": sql_preview, + "sql_length": sql_length, + "sql_truncated": sql_truncated, + } + + params_summary = _summarize_parameters(event.parameters) + if params_summary: + extra.update(params_summary) + + if logger.isEnabledFor(logging.DEBUG): + params, params_truncated = _maybe_truncate_parameters(event.parameters, max_items=_LOG_PARAMETERS_MAX_ITEMS) + if params_truncated: + extra["parameters_truncated"] = True + extra["parameters"] = params + + rows_label = event.rows_affected if event.rows_affected is not None else "unknown" + logger.info( + "[%s] %s duration=%.3fms rows=%s sql=%s", + event.driver, + event.operation, + event.duration_s * 1000, + rows_label, + sql_preview, + extra=extra, + ) + + +def _truncate_text(value: str, *, max_chars: int) -> tuple[str, bool, int]: + length = len(value) + if length <= max_chars: + return value, False, length + return value[:max_chars], True, length + + +def _summarize_parameters(parameters: Any) -> dict[str, Any]: + if parameters is None: + return 
{"parameters_type": None, "parameters_size": None} + if isinstance(parameters, dict): + return {"parameters_type": "dict", "parameters_size": len(parameters)} + if isinstance(parameters, list): + return {"parameters_type": "list", "parameters_size": len(parameters)} + if isinstance(parameters, tuple): + return {"parameters_type": "tuple", "parameters_size": len(parameters)} + return {"parameters_type": type(parameters).__name__, "parameters_size": None} + + +def _maybe_truncate_parameters(parameters: Any, *, max_items: int) -> tuple[Any, bool]: + if isinstance(parameters, dict): + if len(parameters) <= max_items: + return parameters, False + truncated = dict(list(parameters.items())[:max_items]) + return truncated, True + if isinstance(parameters, list): + if len(parameters) <= max_items: + return parameters, False + return parameters[:max_items], True + if isinstance(parameters, tuple): + if len(parameters) <= max_items: + return parameters, False + return parameters[:max_items], True + return parameters, False def create_event( diff --git a/sqlspec/observability/_runtime.py b/sqlspec/observability/_runtime.py index e14bfd569..a34203a91 100644 --- a/sqlspec/observability/_runtime.py +++ b/sqlspec/observability/_runtime.py @@ -1,5 +1,6 @@ """Runtime helpers that bundle lifecycle, observer, and span orchestration.""" +import hashlib import re from typing import TYPE_CHECKING, Any, cast @@ -243,13 +244,23 @@ def emit_statement_event( def start_query_span(self, sql: str, operation: str, driver: str) -> Any: """Start a query span with runtime metadata.""" + sql_hash = _hash_sql(sql) + connection_info = {"sqlspec.statement.hash": sql_hash, "sqlspec.statement.length": len(sql)} + sql_payload = "" + if self.config.print_sql: + sql_payload = self._redact_sql(sql) + sql_payload, truncated = _truncate_text(sql_payload, max_chars=4096) + if truncated: + connection_info["sqlspec.statement.truncated"] = True + correlation_id = CorrelationContext.get() return self.span_manager.start_query_span( driver=driver, adapter=self.config_name, bind_key=self.bind_key, - sql=sql, + sql=sql_payload, operation=operation, + connection_info=connection_info, correlation_id=correlation_id, ) @@ -378,4 +389,14 @@ def _mask_parameters(value: Any, allow_list: set[str]) -> Any: return "***" +def _truncate_text(value: str, *, max_chars: int) -> tuple[str, bool]: + if len(value) <= max_chars: + return value, False + return value[:max_chars], True + + +def _hash_sql(sql: str) -> str: + return hashlib.sha256(sql.encode("utf-8")).hexdigest()[:16] + + __all__ = ("ObservabilityRuntime",) diff --git a/sqlspec/observability/_spans.py b/sqlspec/observability/_spans.py index cf137e1c1..898a8ae10 100644 --- a/sqlspec/observability/_spans.py +++ b/sqlspec/observability/_spans.py @@ -10,6 +10,32 @@ logger = get_logger("sqlspec.observability.spans") +_DB_SYSTEM_MAP: tuple[tuple[str, str], ...] 
= ( + ("asyncpg", "postgresql"), + ("psycopg", "postgresql"), + ("psqlpy", "postgresql"), + ("postgres", "postgresql"), + ("asyncmy", "mysql"), + ("mysql", "mysql"), + ("mariadb", "mysql"), + ("aiosqlite", "sqlite"), + ("sqlite", "sqlite"), + ("duckdb", "duckdb"), + ("bigquery", "bigquery"), + ("spanner", "spanner"), + ("oracle", "oracle"), + ("oracledb", "oracle"), + ("adbc", "adbc"), +) + + +def _resolve_db_system(adapter: str) -> str: + normalized = adapter.lower() + for needle, system in _DB_SYSTEM_MAP: + if needle in normalized: + return system + return "other_sql" + class SpanManager: """Lazy OpenTelemetry span manager with graceful degradation.""" @@ -61,11 +87,12 @@ def start_query_span( if not self._enabled: return None attributes: dict[str, Any] = { - "db.system": adapter.lower(), + "db.system": _resolve_db_system(adapter), "db.operation": operation, - "db.statement": sql, "sqlspec.driver": driver, } + if sql: + attributes["db.statement"] = sql if bind_key: attributes["sqlspec.bind_key"] = bind_key if storage_backend: diff --git a/sqlspec/protocols.py b/sqlspec/protocols.py index 30d029ce5..95418226f 100644 --- a/sqlspec/protocols.py +++ b/sqlspec/protocols.py @@ -5,7 +5,7 @@ """ from collections.abc import Callable, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, Protocol, overload, runtime_checkable from typing_extensions import Self @@ -335,6 +335,68 @@ def stream_arrow_async(self, pattern: str, **kwargs: Any) -> "AsyncIterator[Arro msg = "Async arrow streaming not implemented" raise NotImplementedError(msg) + @property + def supports_signing(self) -> bool: + """Whether this backend supports URL signing. + + Returns: + True if the backend supports generating signed URLs, False otherwise. + Only S3, GCS, and Azure backends via obstore support signing. + """ + return False + + @overload + def sign_sync(self, paths: str, expires_in: int = 3600, for_upload: bool = False) -> str: ... + + @overload + def sign_sync(self, paths: list[str], expires_in: int = 3600, for_upload: bool = False) -> list[str]: ... + + def sign_sync( + self, paths: "str | list[str]", expires_in: int = 3600, for_upload: bool = False + ) -> "str | list[str]": + """Generate signed URL(s) for object(s). + + Args: + paths: Single object path or list of paths to sign. + expires_in: URL expiration time in seconds (default: 3600, max: 604800 = 7 days). + for_upload: Whether the URL is for upload (PUT) vs download (GET). + + Returns: + Single signed URL string if paths is a string, or list of signed URLs + if paths is a list. Preserves input type for convenience. + + Raises: + NotImplementedError: If the backend does not support URL signing. + """ + msg = "URL signing not supported by this backend" + raise NotImplementedError(msg) + + @overload + async def sign_async(self, paths: str, expires_in: int = 3600, for_upload: bool = False) -> str: ... + + @overload + async def sign_async(self, paths: list[str], expires_in: int = 3600, for_upload: bool = False) -> list[str]: ... + + async def sign_async( + self, paths: "str | list[str]", expires_in: int = 3600, for_upload: bool = False + ) -> "str | list[str]": + """Generate signed URL(s) asynchronously. + + Args: + paths: Single object path or list of paths to sign. + expires_in: URL expiration time in seconds (default: 3600, max: 604800 = 7 days). + for_upload: Whether the URL is for upload (PUT) vs download (GET). 
+ + Returns: + Single signed URL string if paths is a string, or list of signed URLs + if paths is a list. Preserves input type for convenience. + + Raises: + NotImplementedError: If the backend does not support URL signing. + """ + msg = "URL signing not supported by this backend" + raise NotImplementedError(msg) + @runtime_checkable class HasSQLGlotExpressionProtocol(Protocol): diff --git a/sqlspec/storage/backends/fsspec.py b/sqlspec/storage/backends/fsspec.py index 3f0b0faed..fe2977660 100644 --- a/sqlspec/storage/backends/fsspec.py +++ b/sqlspec/storage/backends/fsspec.py @@ -1,12 +1,12 @@ # pyright: reportPrivateUsage=false -import logging from pathlib import Path -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, cast, overload from mypy_extensions import mypyc_attr from sqlspec.storage._utils import import_pyarrow_parquet, resolve_storage_path from sqlspec.storage.errors import execute_sync_storage_operation +from sqlspec.utils.logging import get_logger from sqlspec.utils.module_loader import ensure_fsspec from sqlspec.utils.sync_tools import async_ @@ -17,7 +17,7 @@ __all__ = ("FSSpecBackend",) -logger = logging.getLogger(__name__) +logger = get_logger(__name__) class _ArrowStreamer: @@ -307,10 +307,38 @@ def get_metadata(self, path: str | Path, **kwargs: Any) -> dict[str, Any]: "type": info.type, } - def sign(self, path: str, expires_in: int = 3600, for_upload: bool = False) -> str: - """Generate a signed URL for the file.""" - resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=False) - return f"{self._fs_uri}{resolved_path}" + @property + def supports_signing(self) -> bool: + """Whether this backend supports URL signing. + + FSSpec backends do not support URL signing. Use ObStoreBackend + for S3, GCS, or Azure if you need signed URLs. + + Returns: + Always False for fsspec backends. + """ + return False + + @overload + def sign_sync(self, paths: str, expires_in: int = 3600, for_upload: bool = False) -> str: ... + + @overload + def sign_sync(self, paths: list[str], expires_in: int = 3600, for_upload: bool = False) -> list[str]: ... + + def sign_sync( + self, paths: "str | list[str]", expires_in: int = 3600, for_upload: bool = False + ) -> "str | list[str]": + """Generate signed URL(s). + + Raises: + NotImplementedError: fsspec backends do not support URL signing. + Use obstore backend for S3, GCS, or Azure if you need signed URLs. + """ + msg = ( + f"URL signing is not supported for fsspec backend (protocol: {self.protocol}). " + "For S3, GCS, or Azure signed URLs, use ObStoreBackend instead." + ) + raise NotImplementedError(msg) def _stream_file_batches(self, obj_path: str | Path) -> "Iterator[ArrowRecordBatch]": pq = import_pyarrow_parquet() @@ -385,9 +413,21 @@ async def get_metadata_async(self, path: str | Path, **kwargs: Any) -> dict[str, """Get object metadata from storage asynchronously.""" return await async_(self.get_metadata)(path, **kwargs) - async def sign_async(self, path: str, expires_in: int = 3600, for_upload: bool = False) -> str: - """Generate a signed URL asynchronously.""" - return await async_(self.sign)(path, expires_in, for_upload) + @overload + async def sign_async(self, paths: str, expires_in: int = 3600, for_upload: bool = False) -> str: ... + + @overload + async def sign_async(self, paths: list[str], expires_in: int = 3600, for_upload: bool = False) -> list[str]: ... 
+ + async def sign_async( + self, paths: "str | list[str]", expires_in: int = 3600, for_upload: bool = False + ) -> "str | list[str]": + """Generate signed URL(s) asynchronously. + + Raises: + NotImplementedError: fsspec backends do not support URL signing. + """ + return await async_(self.sign_sync)(paths, expires_in, for_upload) # type: ignore[arg-type] async def read_arrow_async(self, path: str | Path, **kwargs: Any) -> "ArrowTable": """Read Arrow table from storage asynchronously.""" diff --git a/sqlspec/storage/backends/local.py b/sqlspec/storage/backends/local.py index 88567e403..4ee8a1146 100644 --- a/sqlspec/storage/backends/local.py +++ b/sqlspec/storage/backends/local.py @@ -8,7 +8,7 @@ from collections.abc import AsyncIterator, Iterator from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, cast, overload from urllib.parse import unquote, urlparse from mypy_extensions import mypyc_attr @@ -304,9 +304,35 @@ def stream_arrow(self, pattern: str, **kwargs: Any) -> Iterator["ArrowRecordBatc ) yield from parquet_file.iter_batches() # pyright: ignore - def sign(self, path: "str | Path", expires_in: int = 3600, for_upload: bool = False) -> str: - """Generate a signed URL (returns file:// URI for local files).""" - return self._resolve_path(path).as_uri() + @property + def supports_signing(self) -> bool: + """Whether this backend supports URL signing. + + Local file storage does not support URL signing. + Local files are accessed directly via file:// URIs. + + Returns: + Always False for local storage. + """ + return False + + @overload + def sign_sync(self, paths: str, expires_in: int = 3600, for_upload: bool = False) -> str: ... + + @overload + def sign_sync(self, paths: list[str], expires_in: int = 3600, for_upload: bool = False) -> list[str]: ... + + def sign_sync( + self, paths: "str | list[str]", expires_in: int = 3600, for_upload: bool = False + ) -> "str | list[str]": + """Generate signed URL(s). + + Raises: + NotImplementedError: Local file storage does not require URL signing. + Local files are accessed directly via file:// URIs. + """ + msg = "URL signing is not applicable to local file storage. Use file:// URIs directly." + raise NotImplementedError(msg) # Async methods using sync_tools.async_ async def read_bytes_async(self, path: "str | Path", **kwargs: Any) -> bytes: @@ -372,6 +398,18 @@ def stream_arrow_async(self, pattern: str, **kwargs: Any) -> AsyncIterator["Arro """ return _LocalArrowIterator(self.stream_arrow(pattern, **kwargs)) - async def sign_async(self, path: "str | Path", expires_in: int = 3600, for_upload: bool = False) -> str: - """Generate a signed URL asynchronously (returns file:// URI for local files).""" - return await async_(self.sign)(path, expires_in, for_upload) + @overload + async def sign_async(self, paths: str, expires_in: int = 3600, for_upload: bool = False) -> str: ... + + @overload + async def sign_async(self, paths: list[str], expires_in: int = 3600, for_upload: bool = False) -> list[str]: ... + + async def sign_async( + self, paths: "str | list[str]", expires_in: int = 3600, for_upload: bool = False + ) -> "str | list[str]": + """Generate signed URL(s) asynchronously. + + Raises: + NotImplementedError: Local file storage does not require URL signing. 
+ """ + return await async_(self.sign_sync)(paths, expires_in, for_upload) # type: ignore[arg-type] diff --git a/sqlspec/storage/backends/obstore.py b/sqlspec/storage/backends/obstore.py index 0e825a79c..6f4fdb48e 100644 --- a/sqlspec/storage/backends/obstore.py +++ b/sqlspec/storage/backends/obstore.py @@ -6,12 +6,11 @@ import fnmatch import io -import logging import re from collections.abc import AsyncIterator, Iterator from functools import partial from pathlib import Path, PurePosixPath -from typing import Any, Final, cast +from typing import Any, Final, cast, overload from urllib.parse import urlparse from mypy_extensions import mypyc_attr @@ -20,12 +19,13 @@ from sqlspec.storage._utils import import_pyarrow, import_pyarrow_parquet, resolve_storage_path from sqlspec.storage.errors import execute_sync_storage_operation from sqlspec.typing import ArrowRecordBatch, ArrowTable +from sqlspec.utils.logging import get_logger from sqlspec.utils.module_loader import ensure_obstore from sqlspec.utils.sync_tools import async_ __all__ = ("ObStoreBackend",) -logger = logging.getLogger(__name__) +logger = get_logger(__name__) class _AsyncArrowIterator: @@ -431,10 +431,80 @@ def stream_arrow(self, pattern: str, **kwargs: Any) -> Iterator[ArrowRecordBatch ) yield from parquet_file.iter_batches() - def sign(self, path: str, expires_in: int = 3600, for_upload: bool = False) -> str: - """Generate a signed URL for the object.""" - resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) - return f"{self.store_uri}/{resolved_path}" + @property + def supports_signing(self) -> bool: + """Whether this backend supports URL signing. + + Only S3, GCS, and Azure backends support pre-signed URLs. + Local file storage does not support URL signing. + + Returns: + True if the protocol supports signing, False otherwise. + """ + signable_protocols = {"s3", "gs", "gcs", "az", "azure"} + return self.protocol in signable_protocols + + @overload + def sign_sync(self, paths: str, expires_in: int = 3600, for_upload: bool = False) -> str: ... + + @overload + def sign_sync(self, paths: list[str], expires_in: int = 3600, for_upload: bool = False) -> list[str]: ... + + def sign_sync( + self, paths: "str | list[str]", expires_in: int = 3600, for_upload: bool = False + ) -> "str | list[str]": + """Generate signed URL(s) for the object(s). + + Args: + paths: Single object path or list of paths to sign. + expires_in: URL expiration time in seconds (default: 3600, max: 604800 = 7 days). + for_upload: Whether the URL is for upload (PUT) vs download (GET). + + Returns: + Single signed URL string if paths is a string, or list of signed URLs + if paths is a list. Preserves input type for convenience. + + Raises: + NotImplementedError: If the backend protocol does not support signing. + ValueError: If expires_in exceeds maximum (604800 seconds). + """ + import obstore as obs + + signable_protocols = {"s3", "gs", "gcs", "az", "azure"} + if self.protocol not in signable_protocols: + msg = ( + f"URL signing is not supported for protocol '{self.protocol}'. " + f"Only S3, GCS, and Azure backends support pre-signed URLs." 
+ ) + raise NotImplementedError(msg) + + max_expires = 604800 # 7 days max per obstore/object_store limits + if expires_in > max_expires: + msg = f"expires_in cannot exceed {max_expires} seconds (7 days), got {expires_in}" + raise ValueError(msg) + + from datetime import timedelta + + method = "PUT" if for_upload else "GET" + expires_delta = timedelta(seconds=expires_in) + + if isinstance(paths, str): + path_list = [paths] + is_single = True + else: + path_list = list(paths) + is_single = False + + resolved_paths = [ + resolve_storage_path(p, self.base_path, self.protocol, strip_file_scheme=True) for p in path_list + ] + + try: + signed_urls: list[str] = obs.sign(self.store, method, resolved_paths, expires_delta) # type: ignore[call-overload] + return signed_urls[0] if is_single else signed_urls + except Exception as exc: + msg = f"Failed to generate signed URL(s) for {resolved_paths}" + raise StorageOperationFailedError(msg) from exc async def read_bytes_async(self, path: "str | Path", **kwargs: Any) -> bytes: # pyright: ignore[reportUnusedParameter] """Read bytes from storage asynchronously.""" @@ -574,7 +644,64 @@ def stream_arrow_async(self, pattern: str, **kwargs: Any) -> AsyncIterator[Arrow resolved_pattern = resolve_storage_path(pattern, self.base_path, self.protocol, strip_file_scheme=True) return _AsyncArrowIterator(self, resolved_pattern, **kwargs) - async def sign_async(self, path: str, expires_in: int = 3600, for_upload: bool = False) -> str: - """Generate a signed URL asynchronously.""" - resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) - return f"{self.store_uri}/{resolved_path}" + @overload + async def sign_async(self, paths: str, expires_in: int = 3600, for_upload: bool = False) -> str: ... + + @overload + async def sign_async(self, paths: list[str], expires_in: int = 3600, for_upload: bool = False) -> list[str]: ... + + async def sign_async( + self, paths: "str | list[str]", expires_in: int = 3600, for_upload: bool = False + ) -> "str | list[str]": + """Generate signed URL(s) asynchronously. + + Args: + paths: Single object path or list of paths to sign. + expires_in: URL expiration time in seconds (default: 3600, max: 604800 = 7 days). + for_upload: Whether the URL is for upload (PUT) vs download (GET). + + Returns: + Single signed URL string if paths is a string, or list of signed URLs + if paths is a list. Preserves input type for convenience. + + Raises: + NotImplementedError: If the backend protocol does not support signing. + ValueError: If expires_in exceeds maximum (604800 seconds). + """ + import obstore as obs + + signable_protocols = {"s3", "gs", "gcs", "az", "azure"} + if self.protocol not in signable_protocols: + msg = ( + f"URL signing is not supported for protocol '{self.protocol}'. " + f"Only S3, GCS, and Azure backends support pre-signed URLs." 
+ ) + raise NotImplementedError(msg) + + max_expires = 604800 # 7 days max per obstore/object_store limits + if expires_in > max_expires: + msg = f"expires_in cannot exceed {max_expires} seconds (7 days), got {expires_in}" + raise ValueError(msg) + + from datetime import timedelta + + method = "PUT" if for_upload else "GET" + expires_delta = timedelta(seconds=expires_in) + + if isinstance(paths, str): + path_list = [paths] + is_single = True + else: + path_list = list(paths) + is_single = False + + resolved_paths = [ + resolve_storage_path(p, self.base_path, self.protocol, strip_file_scheme=True) for p in path_list + ] + + try: + signed_urls: list[str] = await obs.sign_async(self.store, method, resolved_paths, expires_delta) # type: ignore[call-overload] + return signed_urls[0] if is_single else signed_urls + except Exception as exc: + msg = f"Failed to generate signed URL(s) for {resolved_paths}" + raise StorageOperationFailedError(msg) from exc diff --git a/sqlspec/storage/errors.py b/sqlspec/storage/errors.py index c1b81bd20..8e06b7b44 100644 --- a/sqlspec/storage/errors.py +++ b/sqlspec/storage/errors.py @@ -1,10 +1,10 @@ """Storage error normalization helpers.""" import errno -import logging from typing import TYPE_CHECKING, Any, TypeVar from sqlspec.exceptions import FileNotFoundInStorageError, StorageOperationFailedError +from sqlspec.utils.logging import get_logger if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Mapping @@ -12,7 +12,7 @@ __all__ = ("StorageError", "execute_async_storage_operation", "execute_sync_storage_operation", "raise_storage_error") -logger = logging.getLogger(__name__) +logger = get_logger(__name__) T = TypeVar("T") diff --git a/sqlspec/storage/registry.py b/sqlspec/storage/registry.py index bc594b5bc..303d6194a 100644 --- a/sqlspec/storage/registry.py +++ b/sqlspec/storage/registry.py @@ -5,7 +5,6 @@ scheme-based routing, and named aliases for common configurations. """ -import logging import re from pathlib import Path from typing import Any, Final, cast @@ -15,11 +14,12 @@ from sqlspec.exceptions import ImproperConfigurationError, MissingDependencyError from sqlspec.protocols import ObjectStoreProtocol from sqlspec.typing import FSSPEC_INSTALLED, OBSTORE_INSTALLED +from sqlspec.utils.logging import get_logger from sqlspec.utils.type_guards import is_local_path __all__ = ("StorageRegistry", "storage_registry") -logger = logging.getLogger(__name__) +logger = get_logger(__name__) SCHEME_REGEX: Final = re.compile(r"([a-zA-Z0-9+.-]+)://") diff --git a/sqlspec/utils/config_normalization.py b/sqlspec/utils/config_normalization.py new file mode 100644 index 000000000..a04ca20e8 --- /dev/null +++ b/sqlspec/utils/config_normalization.py @@ -0,0 +1,105 @@ +"""Configuration normalization helpers. + +These utilities are used by adapter config modules to keep connection configuration handling +consistent across pooled and non-pooled adapters. 
+""" + +from typing import TYPE_CHECKING, Any + +from sqlspec.exceptions import ImproperConfigurationError +from sqlspec.utils.deprecation import warn_deprecation + +if TYPE_CHECKING: + from collections.abc import Mapping + +__all__ = ("apply_pool_deprecations", "normalize_connection_config") + +_POOL_DEPRECATION_INFO = "Parameter renamed for consistency across pooled and non-pooled adapters" + + +def apply_pool_deprecations( + *, + kwargs: dict[str, Any], + connection_config: "Any | None", + connection_instance: "Any | None", + version: str = "0.33.0", + removal_in: str = "0.34.0", +) -> tuple["Any | None", "Any | None"]: + """Apply deprecated pool_config/pool_instance arguments. + + Several adapters historically accepted ``pool_config`` and ``pool_instance``. SQLSpec standardized + these to ``connection_config`` and ``connection_instance``. This helper preserves the prior + behavior without repeating the same deprecation handling blocks in every adapter config. + + Args: + kwargs: Keyword arguments passed to the adapter config constructor (mutated in-place). + connection_config: Current connection_config value. + connection_instance: Current connection_instance value. + version: Version the parameters were deprecated in. + removal_in: Version the parameters are scheduled for removal. + + Returns: + Updated (connection_config, connection_instance). + """ + if "pool_config" in kwargs: + warn_deprecation( + version=version, + deprecated_name="pool_config", + kind="parameter", + removal_in=removal_in, + alternative="connection_config", + info=_POOL_DEPRECATION_INFO, + stacklevel=3, + ) + if connection_config is None: + connection_config = kwargs.pop("pool_config") + else: + kwargs.pop("pool_config") + + if "pool_instance" in kwargs: + warn_deprecation( + version=version, + deprecated_name="pool_instance", + kind="parameter", + removal_in=removal_in, + alternative="connection_instance", + info=_POOL_DEPRECATION_INFO, + stacklevel=3, + ) + if connection_instance is None: + connection_instance = kwargs.pop("pool_instance") + else: + kwargs.pop("pool_instance") + + return connection_config, connection_instance + + +def normalize_connection_config( + connection_config: "Mapping[str, Any] | None", *, extra_key: str = "extra" +) -> dict[str, Any]: + """Normalize an adapter connection_config dictionary. + + This function: + - Copies the provided mapping into a new dict. + - Merges any nested dict stored under ``extra_key`` into the top-level config. + - Ensures the extra mapping is a dictionary (or None). + + Args: + connection_config: Raw connection configuration mapping. + extra_key: Key holding additional keyword arguments to merge. + + Returns: + Normalized connection configuration. + + Raises: + ImproperConfigurationError: If ``extra_key`` exists but is not a dictionary. + """ + normalized: dict[str, Any] = dict(connection_config) if connection_config else {} + extras = normalized.pop(extra_key, {}) + if extras is None: + return normalized + if not isinstance(extras, dict): + msg = f"The '{extra_key}' field in connection_config must be a dictionary." 
+ raise ImproperConfigurationError(msg) + normalized.update(extras) + return normalized diff --git a/sqlspec/utils/correlation.py b/sqlspec/utils/correlation.py index cbdfc182c..e19608dd8 100644 --- a/sqlspec/utils/correlation.py +++ b/sqlspec/utils/correlation.py @@ -13,6 +13,8 @@ __all__ = ("CorrelationContext", "correlation_context", "get_correlation_adapter") +correlation_id_var: "ContextVar[str | None]" = ContextVar("sqlspec_correlation_id", default=None) + class CorrelationContext: """Context manager for correlation ID tracking. @@ -21,7 +23,7 @@ class CorrelationContext: across async and sync operations. """ - _correlation_id: ClassVar[ContextVar[str | None]] = ContextVar("sqlspec_correlation_id", default=None) + _correlation_id: ClassVar["ContextVar[str | None]"] = correlation_id_var @classmethod def get(cls) -> str | None: diff --git a/sqlspec/utils/deprecation.py b/sqlspec/utils/deprecation.py index eb3e11a87..46f2c0b74 100644 --- a/sqlspec/utils/deprecation.py +++ b/sqlspec/utils/deprecation.py @@ -29,6 +29,7 @@ def warn_deprecation( alternative: str | None = None, info: str | None = None, pending: bool = False, + stacklevel: int = 2, ) -> None: """Warn about a call to a deprecated function. @@ -40,6 +41,7 @@ def warn_deprecation( info: Additional information pending: Use :class:`warnings.PendingDeprecationWarning` instead of :class:`warnings.DeprecationWarning` kind: Type of the deprecated thing + stacklevel: Warning stacklevel to report the correct caller site. """ parts = [] @@ -67,7 +69,7 @@ def warn_deprecation( text = ". ".join(parts) # pyright: ignore[reportUnknownArgumentType] warning_class = PendingDeprecationWarning if pending else DeprecationWarning - warn(text, warning_class, stacklevel=2) + warn(text, warning_class, stacklevel=stacklevel) def deprecated( diff --git a/sqlspec/utils/logging.py b/sqlspec/utils/logging.py index 1297cd284..7cb11bcfa 100644 --- a/sqlspec/utils/logging.py +++ b/sqlspec/utils/logging.py @@ -6,11 +6,15 @@ """ import logging -from contextvars import ContextVar from logging import LogRecord -from typing import Any +from typing import TYPE_CHECKING, Any from sqlspec._serialization import encode_json +from sqlspec.utils.correlation import CorrelationContext +from sqlspec.utils.correlation import correlation_id_var as _correlation_id_var + +if TYPE_CHECKING: + from contextvars import ContextVar __all__ = ( "SqlglotCommandFallbackFilter", @@ -22,7 +26,14 @@ "suppress_erroneous_sqlglot_log_messages", ) -correlation_id_var: "ContextVar[str | None]" = ContextVar("correlation_id", default=None) +_BASE_RECORD_KEYS = set( + logging.LogRecord( + name="sqlspec", level=logging.INFO, pathname="(unknown file)", lineno=0, msg="", args=(), exc_info=None + ).__dict__.keys() +) +_BASE_RECORD_KEYS.update({"message", "asctime"}) + +correlation_id_var: "ContextVar[str | None]" = _correlation_id_var def set_correlation_id(correlation_id: "str | None") -> None: @@ -31,7 +42,7 @@ def set_correlation_id(correlation_id: "str | None") -> None: Args: correlation_id: The correlation ID to set, or None to clear """ - correlation_id_var.set(correlation_id) + CorrelationContext.set(correlation_id) def get_correlation_id() -> "str | None": @@ -40,7 +51,7 @@ def get_correlation_id() -> "str | None": Returns: The current correlation ID or None if not set """ - return correlation_id_var.get() + return CorrelationContext.get() class StructuredFormatter(logging.Formatter): @@ -65,12 +76,21 @@ def format(self, record: LogRecord) -> str: "line": record.lineno, } - if correlation_id 
:= get_correlation_id(): + correlation_id = getattr(record, "correlation_id", None) or get_correlation_id() + if correlation_id: log_entry["correlation_id"] = correlation_id if hasattr(record, "extra_fields"): log_entry.update(record.extra_fields) # pyright: ignore + extras = { + key: value + for key, value in record.__dict__.items() + if key not in _BASE_RECORD_KEYS and key not in {"extra_fields", "correlation_id"} + } + if extras: + log_entry.update(extras) + if record.exc_info: log_entry["exception"] = self.formatException(record.exc_info) @@ -154,9 +174,7 @@ def log_with_context(logger: logging.Logger, level: int, message: str, **extra_f message: Log message **extra_fields: Additional fields to include in structured logs """ - record = logger.makeRecord(logger.name, level, "(unknown file)", 0, message, (), None) - record.extra_fields = extra_fields - logger.handle(record) + logger.log(level, message, extra={"extra_fields": extra_fields}, stacklevel=2) def suppress_erroneous_sqlglot_log_messages() -> None: diff --git a/sqlspec/utils/schema.py b/sqlspec/utils/schema.py index 0fa1d4f86..6c8ebc37a 100644 --- a/sqlspec/utils/schema.py +++ b/sqlspec/utils/schema.py @@ -1,7 +1,6 @@ """Schema transformation utilities for converting data to various schema types.""" import datetime -import logging from collections.abc import Callable, Sequence from enum import Enum from functools import lru_cache, partial @@ -23,6 +22,7 @@ get_type_adapter, ) from sqlspec.utils.data_transformation import transform_dict_keys +from sqlspec.utils.logging import get_logger from sqlspec.utils.text import camelize, kebabize, pascalize from sqlspec.utils.type_guards import ( get_msgspec_rename_config, @@ -45,7 +45,7 @@ DataT = TypeVar("DataT", default=dict[str, Any]) -logger = logging.getLogger(__name__) +logger = get_logger(__name__) _DATETIME_TYPES: Final[set[type]] = {datetime.datetime, datetime.date, datetime.time} diff --git a/sqlspec/utils/version.py b/sqlspec/utils/version.py index 8b065e4f9..adecda714 100644 --- a/sqlspec/utils/version.py +++ b/sqlspec/utils/version.py @@ -4,13 +4,14 @@ (0001) and timestamp-based (20251011120000) formats with type-safe comparison. 
""" -import logging import re from dataclasses import dataclass from datetime import datetime, timezone from enum import Enum from typing import Any +from sqlspec.utils.logging import get_logger + __all__ = ( "MigrationVersion", "VersionType", @@ -23,7 +24,7 @@ "parse_version", ) -logger = logging.getLogger(__name__) +logger = get_logger(__name__) SEQUENTIAL_PATTERN = re.compile(r"^(?!\d{14}$)\d+$") TIMESTAMP_PATTERN = re.compile(r"^(\d{14})$") diff --git a/tests/integration/test_extensions/test_litestar/test_correlation_middleware.py b/tests/integration/test_extensions/test_litestar/test_correlation_middleware.py index 02a8c3ffe..380256cc9 100644 --- a/tests/integration/test_extensions/test_litestar/test_correlation_middleware.py +++ b/tests/integration/test_extensions/test_litestar/test_correlation_middleware.py @@ -10,6 +10,16 @@ from sqlspec.utils.correlation import CorrelationContext +def setup_function() -> None: + """Clear correlation context before each test to prevent pollution.""" + CorrelationContext.clear() + + +def teardown_function() -> None: + """Clear correlation context after each test to prevent pollution.""" + CorrelationContext.clear() + + @get("/correlation") async def correlation_handler() -> dict[str, str | None]: return {"correlation_id": CorrelationContext.get()} diff --git a/tests/integration/test_storage/test_storage_integration.py b/tests/integration/test_storage/test_storage_integration.py index 274f50b9d..f2fcba5f5 100644 --- a/tests/integration/test_storage/test_storage_integration.py +++ b/tests/integration/test_storage/test_storage_integration.py @@ -135,16 +135,17 @@ def test_local_store_listing_operations(local_test_setup: Path) -> None: @pytest.mark.xdist_group("storage") -def test_local_store_url_signing(local_test_setup: Path) -> None: - """Test LocalStore URL signing functionality.""" +def test_local_store_url_signing_not_supported(local_test_setup: Path) -> None: + """Test LocalStore URL signing raises NotImplementedError.""" from sqlspec.storage.backends.local import LocalStore store = LocalStore(str(local_test_setup)) - # Test sign method - signed_url = store.sign("test.txt", expires_in=3600) - assert signed_url.startswith("file://") - assert "test.txt" in signed_url + # Local storage does not support URL signing + assert store.supports_signing is False + + with pytest.raises(NotImplementedError, match="URL signing is not applicable"): + store.sign_sync("test.txt", expires_in=3600) @pytest.mark.xdist_group("storage") @@ -533,10 +534,14 @@ def test_backend_consistency(request: pytest.FixtureRequest, backend_name: str) # Test exists consistency assert backend.exists(test_path) - # Test URL signing consistency (all should return some form of URL) - signed_url = backend.sign(test_path, expires_in=3600) - assert isinstance(signed_url, str) - assert len(signed_url) > 0 + # Test URL signing consistency (only for backends that support signing) + if backend.supports_signing: + signed_url = backend.sign_sync(test_path, expires_in=3600) + assert isinstance(signed_url, str) + assert len(signed_url) > 0 + else: + with pytest.raises(NotImplementedError): + backend.sign_sync(test_path, expires_in=3600) @pytest.mark.xdist_group("storage") diff --git a/tests/unit/test_adapters/test_adbc/test_adbc_config_normalization.py b/tests/unit/test_adapters/test_adbc/test_adbc_config_normalization.py new file mode 100644 index 000000000..81718a114 --- /dev/null +++ b/tests/unit/test_adapters/test_adbc/test_adbc_config_normalization.py @@ -0,0 +1,104 @@ +"""Unit tests for ADBC 
config normalization helpers.""" + +from typing import Any, cast + +from sqlspec.adapters.adbc import AdbcConfig + + +def _resolve_driver_name(config: AdbcConfig) -> str: + """Call the internal driver-name resolver without triggering pyright private usage.""" + return cast("str", cast("Any", config)._resolve_driver_name()) + + +def _get_connection_config_dict(config: AdbcConfig) -> dict[str, Any]: + """Call the internal connection-config builder without triggering pyright private usage.""" + return cast("dict[str, Any]", cast("Any", config)._get_connection_config_dict()) + + +def test_resolve_driver_name_alias_to_connect_path() -> None: + """Resolve short driver aliases to concrete connect paths.""" + config = AdbcConfig(connection_config={"driver_name": "sqlite"}) + assert _resolve_driver_name(config) == "adbc_driver_sqlite.dbapi.connect" + + +def test_resolve_driver_name_module_name_appends_suffix() -> None: + """Append .dbapi.connect for bare driver module names.""" + config = AdbcConfig(connection_config={"driver_name": "adbc_driver_sqlite"}) + assert _resolve_driver_name(config) == "adbc_driver_sqlite.dbapi.connect" + + +def test_resolve_driver_name_dbapi_suffix_appends_connect() -> None: + """Append .connect when driver_name ends in .dbapi.""" + config = AdbcConfig(connection_config={"driver_name": "adbc_driver_sqlite.dbapi"}) + assert _resolve_driver_name(config) == "adbc_driver_sqlite.dbapi.connect" + + +def test_resolve_driver_name_custom_dotted_path_is_left_unchanged() -> None: + """Treat dotted driver_name values as full import paths.""" + config = AdbcConfig(connection_config={"driver_name": "my.custom.connect"}) + assert _resolve_driver_name(config) == "my.custom.connect" + + +def test_resolve_driver_name_custom_bare_name_appends_suffix() -> None: + """Preserve historical behavior for bare custom driver names.""" + config = AdbcConfig(connection_config={"driver_name": "my_custom_driver"}) + assert _resolve_driver_name(config) == "my_custom_driver.dbapi.connect" + + +def test_resolve_driver_name_from_uri() -> None: + """Detect driver from URI scheme when driver_name is absent.""" + config = AdbcConfig(connection_config={"uri": "postgresql://example.invalid/db"}) + assert _resolve_driver_name(config) == "adbc_driver_postgresql.dbapi.connect" + + +def test_connection_config_dict_strips_sqlite_scheme() -> None: + """Strip sqlite:// from URI when using the sqlite driver.""" + config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": "sqlite:///tmp.db"}) + resolved = _get_connection_config_dict(config) + assert resolved.get("uri") == "/tmp.db" + assert "driver_name" not in resolved + + +def test_connection_config_dict_converts_duckdb_uri_to_path() -> None: + """Convert duckdb:// URI to a path parameter for DuckDB.""" + config = AdbcConfig(connection_config={"driver_name": "duckdb", "uri": "duckdb:///tmp.db"}) + resolved = _get_connection_config_dict(config) + assert resolved.get("path") == "/tmp.db" + assert "uri" not in resolved + assert "driver_name" not in resolved + + +def test_connection_config_dict_moves_bigquery_fields_into_db_kwargs() -> None: + """Move BigQuery configuration fields into db_kwargs.""" + config = AdbcConfig( + connection_config={ + "driver_name": "bigquery", + "project_id": "test-project", + "dataset_id": "test-dataset", + "token": "token", + } + ) + resolved = _get_connection_config_dict(config) + assert "driver_name" not in resolved + assert "project_id" not in resolved + assert "dataset_id" not in resolved + assert "token" not in resolved + 
assert resolved["db_kwargs"]["project_id"] == "test-project" + assert resolved["db_kwargs"]["dataset_id"] == "test-dataset" + assert resolved["db_kwargs"]["token"] == "token" + + +def test_connection_config_dict_moves_bigquery_fields_for_bq_alias() -> None: + """Move BigQuery fields into db_kwargs when using the bq alias.""" + config = AdbcConfig(connection_config={"driver_name": "bq", "project_id": "p", "dataset_id": "d"}) + resolved = _get_connection_config_dict(config) + assert resolved["db_kwargs"]["project_id"] == "p" + assert resolved["db_kwargs"]["dataset_id"] == "d" + + +def test_connection_config_dict_flattens_db_kwargs_for_non_bigquery() -> None: + """Flatten db_kwargs into top-level for non-BigQuery drivers.""" + config = AdbcConfig(connection_config={"driver_name": "postgres", "db_kwargs": {"foo": "bar"}}) + resolved = _get_connection_config_dict(config) + assert "db_kwargs" not in resolved + assert resolved["foo"] == "bar" diff --git a/tests/unit/test_adapters/test_asyncpg/test_cloud_connectors.py b/tests/unit/test_adapters/test_asyncpg/test_cloud_connectors.py index b9109ae4f..5da994c56 100644 --- a/tests/unit/test_adapters/test_asyncpg/test_cloud_connectors.py +++ b/tests/unit/test_adapters/test_asyncpg/test_cloud_connectors.py @@ -8,7 +8,7 @@ import pytest from sqlspec.adapters.asyncpg.config import AsyncpgConfig -from sqlspec.exceptions import ImproperConfigurationError +from sqlspec.exceptions import ImproperConfigurationError, MissingDependencyError @pytest.fixture(autouse=True) @@ -81,7 +81,7 @@ def test_mutual_exclusion_both_enabled_raises_error() -> None: def test_cloud_sql_missing_package_raises_error() -> None: """Enabling Cloud SQL without package installed should raise error.""" - with pytest.raises(ImproperConfigurationError, match="cloud-sql-python-connector package not installed"): + with pytest.raises(MissingDependencyError, match="cloud-sql-python-connector"): AsyncpgConfig( connection_config={"dsn": "postgresql://localhost/test"}, driver_features={"enable_cloud_sql": True, "cloud_sql_instance": "project:region:instance"}, @@ -91,7 +91,7 @@ def test_cloud_sql_missing_package_raises_error() -> None: def test_alloydb_missing_package_raises_error() -> None: """Enabling AlloyDB without package installed should raise error.""" with patch("sqlspec.adapters.asyncpg.config.ALLOYDB_CONNECTOR_INSTALLED", False): - with pytest.raises(ImproperConfigurationError, match="cloud-alloydb-python-connector package not installed"): + with pytest.raises(MissingDependencyError, match="google-cloud-alloydb-connector"): AsyncpgConfig( connection_config={"dsn": "postgresql://localhost/test"}, driver_features={ diff --git a/tests/unit/test_config_deprecation.py b/tests/unit/test_config_deprecation.py new file mode 100644 index 000000000..d9664086e --- /dev/null +++ b/tests/unit/test_config_deprecation.py @@ -0,0 +1,243 @@ +"""Tests for config parameter deprecation (pool_config → connection_config, pool_instance → connection_instance). + +Only adapters that previously supported pooling are tested here: +- asyncpg (async pooled) +- psycopg (sync/async pooled) +- psqlpy (async pooled) +- asyncmy (async pooled) +- oracledb (sync/async pooled) +- spanner (sync pooled) + +Non-pooled adapters (sqlite, duckdb, aiosqlite, adbc, bigquery) never had pool_config/pool_instance. 
+""" + +import warnings +from typing import Any + +import pytest + +from sqlspec.adapters.asyncmy import AsyncmyConfig +from sqlspec.adapters.asyncpg import AsyncpgConfig +from sqlspec.adapters.oracledb import OracleAsyncConfig, OracleSyncConfig +from sqlspec.adapters.psqlpy import PsqlpyConfig +from sqlspec.adapters.psycopg import PsycopgAsyncConfig, PsycopgSyncConfig +from sqlspec.adapters.spanner import SpannerSyncConfig + + +def test_pool_config_deprecated_psycopg_sync() -> None: + """Test pool_config parameter triggers deprecation warning (sync pooled adapter).""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + config = PsycopgSyncConfig(pool_config={"conninfo": "postgresql://localhost/test"}) + + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "pool_config" in str(w[0].message) + assert "connection_config" in str(w[0].message) + assert "0.34.0" in str(w[0].message) + assert config.connection_config["conninfo"] == "postgresql://localhost/test" + + +def test_pool_config_deprecated_asyncpg() -> None: + """Test pool_config parameter triggers deprecation warning (async pooled adapter).""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + config = AsyncpgConfig(pool_config={"dsn": "postgresql://localhost/test"}) + + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "pool_config" in str(w[0].message) + assert "connection_config" in str(w[0].message) + assert config.connection_config["dsn"] == "postgresql://localhost/test" + + +def test_pool_config_deprecated_oracledb() -> None: + """Test pool_config parameter triggers deprecation warning (Oracle pooled adapter).""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + config = OracleSyncConfig(pool_config={"user": "test", "password": "test"}) + + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "pool_config" in str(w[0].message) + assert config.connection_config["user"] == "test" + + +def test_pool_instance_deprecated() -> None: + """Test pool_instance parameter triggers deprecation warning.""" + mock_pool: Any = object() + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + config = PsycopgSyncConfig(pool_instance=mock_pool) + + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "pool_instance" in str(w[0].message) + assert "connection_instance" in str(w[0].message) + assert config.connection_instance is mock_pool + + +def test_new_parameter_takes_precedence() -> None: + """Test new parameter wins when both old and new provided.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + config = PsycopgSyncConfig( + pool_config={"conninfo": "postgresql://localhost/old"}, + connection_config={"conninfo": "postgresql://localhost/new"}, + ) + + # Should get warning for pool_config but still use connection_config + assert len(w) == 1 + assert "pool_config" in str(w[0].message) + assert config.connection_config["conninfo"] == "postgresql://localhost/new" + + +def test_no_warning_when_using_new_params() -> None: + """Test no warning when only new parameters used.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + PsycopgSyncConfig(connection_config={"conninfo": "postgresql://localhost/test"}) + + assert len(w) == 0 + + +def test_both_deprecated_params() -> None: + """Test both deprecated parameters 
together.""" + mock_pool: Any = object() + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + config = PsycopgSyncConfig(pool_config={"conninfo": "postgresql://localhost/test"}, pool_instance=mock_pool) + + assert len(w) == 2 + warning_messages = [str(warning.message) for warning in w] + assert any("pool_config" in msg for msg in warning_messages) + assert any("pool_instance" in msg for msg in warning_messages) + assert config.connection_instance is mock_pool + + +@pytest.mark.parametrize( + "adapter_class", + [ + AsyncpgConfig, + AsyncmyConfig, + PsycopgSyncConfig, + PsycopgAsyncConfig, + OracleSyncConfig, + OracleAsyncConfig, + PsqlpyConfig, + SpannerSyncConfig, + ], +) +def test_all_pooled_adapters_handle_deprecated_params(adapter_class: type) -> None: + """Parametrized test ensuring all pooled adapters support deprecated parameter names. + + Only adapters that previously supported pool_config/pool_instance are tested: + - asyncpg, asyncmy, psycopg (sync/async), oracledb (sync/async), psqlpy, spanner + """ + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + adapter_class(pool_config={}) + + assert len(w) >= 1 + assert any("pool_config" in str(warning.message) for warning in w) + assert all(issubclass(warning.category, DeprecationWarning) for warning in w) + + +def test_deprecation_message_format() -> None: + """Test deprecation warning has correct format with all required information.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + PsycopgSyncConfig(pool_config={"conninfo": "postgresql://localhost/test"}) + + assert len(w) == 1 + message = str(w[0].message) + + # Check all required elements are present + assert "pool_config" in message # Old parameter name + assert "connection_config" in message # New parameter name + assert "0.33.0" in message # Deprecated in version + assert "0.34.0" in message # Removal version + assert "consistency" in message.lower() # Rationale info + + +def test_connection_instance_precedence() -> None: + """Test connection_instance takes precedence over pool_instance.""" + old_instance: Any = object() + new_instance: Any = object() + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + config = PsycopgSyncConfig(pool_instance=old_instance, connection_instance=new_instance) + + # Should get warning for pool_instance but use connection_instance + assert len(w) == 1 + assert "pool_instance" in str(w[0].message) + assert config.connection_instance is new_instance + assert config.connection_instance is not old_instance + + +def test_empty_pool_config_deprecated() -> None: + """Test empty pool_config dict triggers deprecation warning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + PsycopgSyncConfig(pool_config={}) + + assert len(w) == 1 + assert "pool_config" in str(w[0].message) + + +def test_none_pool_instance_deprecated() -> None: + """Test explicitly passing None for pool_instance triggers warning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + config = PsycopgSyncConfig(pool_instance=None) + + assert len(w) == 1 + assert "pool_instance" in str(w[0].message) + assert config.connection_instance is None + + +def test_mixed_old_new_both_params() -> None: + """Test when one old param and one new param provided together.""" + mock_pool: Any = object() + + with warnings.catch_warnings(record=True) as w: + 
warnings.simplefilter("always") + + config = PsycopgSyncConfig( + pool_config={"conninfo": "postgresql://localhost/test"}, connection_instance=mock_pool + ) + + # Should only warn about pool_config + assert len(w) == 1 + assert "pool_config" in str(w[0].message) + assert config.connection_instance is mock_pool + + +def test_warning_stack_level() -> None: + """Test deprecation warnings are DeprecationWarning category.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + # Create config with deprecated param + PsycopgSyncConfig(pool_config={"conninfo": "postgresql://localhost/test"}) + + assert len(w) == 1 + # Warning should be DeprecationWarning type + assert issubclass(w[0].category, DeprecationWarning) + # Warning is raised from config.py (stacklevel=2 in warn_deprecation) + assert "config.py" in w[0].filename diff --git a/tests/unit/test_extensions/test_litestar/test_correlation_middleware.py b/tests/unit/test_extensions/test_litestar/test_correlation_middleware.py new file mode 100644 index 000000000..9b610a50e --- /dev/null +++ b/tests/unit/test_extensions/test_litestar/test_correlation_middleware.py @@ -0,0 +1,30 @@ +"""Tests for Litestar correlation middleware behavior.""" + +from typing import Any, cast + +from sqlspec.extensions.litestar.plugin import CorrelationMiddleware +from sqlspec.utils.correlation import CorrelationContext + + +async def test_litestar_correlation_middleware_restores_previous_correlation_id() -> None: + CorrelationContext.set("outer") + seen: dict[str, Any] = {} + + async def app(_scope: Any, _receive: Any, _send: Any) -> None: + seen["cid"] = CorrelationContext.get() + + middleware = CorrelationMiddleware(app, headers=("x-request-id",)) + scope = {"type": "http", "headers": [(b"x-request-id", b"inner")]} + + async def receive() -> Any: + return {"type": "http.request"} + + async def send(_message: Any) -> None: + return None + + try: + await middleware(cast("Any", scope), cast("Any", receive), cast("Any", send)) + assert seen["cid"] == "inner" + assert CorrelationContext.get() == "outer" + finally: + CorrelationContext.clear() diff --git a/tests/unit/test_loader/test_correlation_preserved.py b/tests/unit/test_loader/test_correlation_preserved.py new file mode 100644 index 000000000..6a1dffefe --- /dev/null +++ b/tests/unit/test_loader/test_correlation_preserved.py @@ -0,0 +1,19 @@ +"""Tests for correlation context handling in SQL loader.""" + +from pathlib import Path + +from sqlspec.loader import SQLFileLoader +from sqlspec.utils.correlation import CorrelationContext + + +def test_loader_does_not_clear_correlation_context(tmp_path: Path) -> None: + path = tmp_path / "queries.sql" + path.write_text("-- name: ping\nSELECT 1;\n", encoding="utf-8") + + CorrelationContext.set("outer") + try: + loader = SQLFileLoader() + loader.load_sql(path) + assert CorrelationContext.get() == "outer" + finally: + CorrelationContext.clear() diff --git a/tests/unit/test_loader/test_sql_file_loader.py b/tests/unit/test_loader/test_sql_file_loader.py index 0c9dc0b33..4ea5d8cc1 100644 --- a/tests/unit/test_loader/test_sql_file_loader.py +++ b/tests/unit/test_loader/test_sql_file_loader.py @@ -205,8 +205,12 @@ def test_parse_normalize_query_names() -> None: assert "update_user_email" in statements -def test_get_sql_parses_expression_when_missing() -> None: - """SQL objects from get_sql should carry parsed expressions for count queries.""" +def test_get_sql_eagerly_compiles_expression() -> None: + """SQL objects from get_sql should have 
expressions eagerly compiled. + + This ensures that SQL objects from get_sql() can be used with pagination + and count queries without additional compile() calls (fixes issue #283). + """ loader = SQLFileLoader() content = """ @@ -219,11 +223,9 @@ def test_get_sql_parses_expression_when_missing() -> None: sql_obj = loader.get_sql("list_users") - assert sql_obj.expression is None - - sql_obj.compile() - + # get_sql() now eagerly compiles the SQL, so expression is populated assert sql_obj.expression is not None + assert sql_obj.expression.key == "select" def test_parse_skips_files_without_named_statements() -> None: diff --git a/tests/unit/test_logging_utils.py b/tests/unit/test_logging_utils.py new file mode 100644 index 000000000..0f23beb55 --- /dev/null +++ b/tests/unit/test_logging_utils.py @@ -0,0 +1,51 @@ +"""Unit tests for sqlspec logging utilities.""" + +import io +import json +import logging + +from sqlspec.utils.correlation import CorrelationContext +from sqlspec.utils.logging import StructuredFormatter, get_logger, log_with_context + + +def test_structured_formatter_includes_logging_extra_fields() -> None: + stream = io.StringIO() + handler = logging.StreamHandler(stream) + handler.setFormatter(StructuredFormatter()) + + logger = get_logger("tests.logging.extra") + logger.setLevel(logging.INFO) + logger.propagate = False + logger.addHandler(handler) + + try: + with CorrelationContext.context("cid-123"): + logger.info("hello", extra={"foo": "bar"}) + finally: + logger.removeHandler(handler) + + payload = json.loads(stream.getvalue().strip()) + assert payload["message"] == "hello" + assert payload["foo"] == "bar" + assert payload["correlation_id"] == "cid-123" + + +def test_log_with_context_preserves_source_location_and_fields() -> None: + stream = io.StringIO() + handler = logging.StreamHandler(stream) + handler.setFormatter(StructuredFormatter()) + + logger = get_logger("tests.logging.context") + logger.setLevel(logging.INFO) + logger.propagate = False + logger.addHandler(handler) + + try: + log_with_context(logger, logging.INFO, "event.test", driver="Dummy") + finally: + logger.removeHandler(handler) + + payload = json.loads(stream.getvalue().strip()) + assert payload["message"] == "event.test" + assert payload["driver"] == "Dummy" + assert payload["line"] != 0 diff --git a/tests/unit/test_observability.py b/tests/unit/test_observability.py index 25cb44dc8..b3d15f8f4 100644 --- a/tests/unit/test_observability.py +++ b/tests/unit/test_observability.py @@ -342,6 +342,28 @@ def test_driver_dispatch_records_query_span() -> None: assert span_manager.finished[0].closed is True +def test_runtime_query_span_omits_sql_unless_print_sql_enabled() -> None: + """Query spans should only include SQL when print_sql is enabled.""" + + span_manager = _FakeSpanManager() + runtime = ObservabilityRuntime(ObservabilityConfig(print_sql=False), config_name="DummyAdapter") + runtime.span_manager = cast(Any, span_manager) + + runtime.start_query_span("SELECT 1", "SELECT", "DummyDriver") + + assert span_manager.started[0].attributes["sql"] == "" + assert span_manager.started[0].attributes["connection_info"]["sqlspec.statement.hash"] + assert span_manager.started[0].attributes["connection_info"]["sqlspec.statement.length"] == len("SELECT 1") + + span_manager_enabled = _FakeSpanManager() + runtime_enabled = ObservabilityRuntime(ObservabilityConfig(print_sql=True), config_name="DummyAdapter") + runtime_enabled.span_manager = cast(Any, span_manager_enabled) + + runtime_enabled.start_query_span("SELECT 1", 
"SELECT", "DummyDriver") + + assert span_manager_enabled.started[0].attributes["sql"] == "SELECT 1" + + def test_storage_span_records_telemetry_attributes() -> None: """Storage spans should capture telemetry attributes when ending.""" diff --git a/tests/unit/test_observability_statement_logging.py b/tests/unit/test_observability_statement_logging.py new file mode 100644 index 000000000..120966cbb --- /dev/null +++ b/tests/unit/test_observability_statement_logging.py @@ -0,0 +1,67 @@ +"""Unit tests for default statement logging.""" + +import logging + +from sqlspec.observability._observer import create_event, default_statement_observer + + +def test_default_statement_observer_info_excludes_parameters(caplog) -> None: + caplog.set_level(logging.INFO, logger="sqlspec.observability") + + event = create_event( + sql="SELECT 1", + parameters={"a": 1}, + driver="DummyDriver", + adapter="DummyAdapter", + bind_key=None, + operation="SELECT", + execution_mode=None, + is_many=False, + is_script=False, + rows_affected=1, + duration_s=0.001, + correlation_id="cid-1", + storage_backend=None, + started_at=0.0, + ) + + default_statement_observer(event) + + record = caplog.records[-1] + assert record.sql == "SELECT 1" + assert record.sql_truncated is False + assert record.parameters_type == "dict" + assert record.parameters_size == 1 + assert not hasattr(record, "parameters") + + +def test_default_statement_observer_debug_includes_parameters_and_truncates(caplog) -> None: + caplog.set_level(logging.DEBUG, logger="sqlspec.observability") + + long_sql = "SELECT " + ("x" * 5000) + parameters = list(range(101)) + event = create_event( + sql=long_sql, + parameters=parameters, + driver="DummyDriver", + adapter="DummyAdapter", + bind_key=None, + operation="SELECT", + execution_mode=None, + is_many=False, + is_script=False, + rows_affected=1, + duration_s=0.001, + correlation_id="cid-2", + storage_backend=None, + started_at=0.0, + ) + + default_statement_observer(event) + + record = caplog.records[-1] + assert record.sql_truncated is True + assert len(record.sql) == 2000 + assert record.parameters_truncated is True + assert isinstance(record.parameters, list) + assert len(record.parameters) == 100 diff --git a/tests/unit/test_storage/test_fsspec_backend.py b/tests/unit/test_storage/test_fsspec_backend.py index 2121b004e..d41e60c33 100644 --- a/tests/unit/test_storage/test_fsspec_backend.py +++ b/tests/unit/test_storage/test_fsspec_backend.py @@ -242,15 +242,18 @@ def test_stream_arrow(tmp_path: Path) -> None: @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") def test_sign_returns_uri(tmp_path: Path) -> None: - """Test sign returns URI for files.""" + """Test sign_sync raises NotImplementedError for fsspec backends.""" from sqlspec.storage.backends.fsspec import FSSpecBackend store = FSSpecBackend("file", base_path=str(tmp_path)) store.write_text("test.txt", "content") - signed_url = store.sign("test.txt") - assert "test.txt" in signed_url + # FSSpec backends do not support URL signing + assert store.supports_signing is False + + with pytest.raises(NotImplementedError, match="URL signing is not supported for fsspec backend"): + store.sign_sync("test.txt") def test_fsspec_not_installed() -> None: @@ -437,16 +440,16 @@ async def test_async_stream_arrow(tmp_path: Path) -> None: @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -async def test_async_sign(tmp_path: Path) -> None: - """Test async sign returns URI for local files.""" +async def 
test_async_sign_raises_not_implemented(tmp_path: Path) -> None: + """Test async sign_async raises NotImplementedError for fsspec backends.""" from sqlspec.storage.backends.fsspec import FSSpecBackend store = FSSpecBackend("file", base_path=str(tmp_path)) await store.write_text_async("async_test.txt", "content") - signed_url = await store.sign_async("async_test.txt") - assert "async_test.txt" in signed_url + with pytest.raises(NotImplementedError, match="URL signing is not supported for fsspec backend"): + await store.sign_async("async_test.txt") def test_fsspec_operations_without_fsspec() -> None: diff --git a/tests/unit/test_storage/test_local_store.py b/tests/unit/test_storage/test_local_store.py index 84f99b45d..540c0ba0c 100644 --- a/tests/unit/test_storage/test_local_store.py +++ b/tests/unit/test_storage/test_local_store.py @@ -220,26 +220,17 @@ def test_stream_arrow(tmp_path: Path) -> None: assert reconstructed.equals(table) -def test_sign_returns_file_uri(tmp_path: Path) -> None: - """Test sign returns file:// URI for local files.""" +def test_sign_sync_raises_not_implemented(tmp_path: Path) -> None: + """Test sign_sync raises NotImplementedError for local files.""" store = LocalStore(str(tmp_path)) store.write_text("test.txt", "content") - signed_url = store.sign("test.txt") - assert signed_url.startswith("file://") - assert "test.txt" in signed_url + # Local storage does not support URL signing + assert store.supports_signing is False - -def test_sign_with_options(tmp_path: Path) -> None: - """Test sign with expires_in and for_upload options.""" - store = LocalStore(str(tmp_path)) - - store.write_text("test.txt", "content") - - # Options are ignored for local files but should not error - signed_url = store.sign("test.txt", expires_in=7200, for_upload=True) - assert signed_url.startswith("file://") + with pytest.raises(NotImplementedError, match="URL signing is not applicable"): + store.sign_sync("test.txt") def test_resolve_path_absolute(tmp_path: Path) -> None: @@ -428,15 +419,14 @@ async def test_async_stream_arrow(tmp_path: Path) -> None: assert reconstructed.equals(table) -async def test_async_sign(tmp_path: Path) -> None: - """Test async sign returns file:// URI for local files.""" +async def test_async_sign_raises_not_implemented(tmp_path: Path) -> None: + """Test sign_async raises NotImplementedError for local files.""" store = LocalStore(str(tmp_path)) await store.write_text_async("async_test.txt", "content") - signed_url = await store.sign_async("async_test.txt") - assert signed_url.startswith("file://") - assert "async_test.txt" in signed_url + with pytest.raises(NotImplementedError, match="URL signing is not applicable"): + await store.sign_async("async_test.txt") def test_arrow_operations_without_pyarrow(tmp_path: Path) -> None: diff --git a/tests/unit/test_storage/test_obstore_backend.py b/tests/unit/test_storage/test_obstore_backend.py index ce61febc8..77018d124 100644 --- a/tests/unit/test_storage/test_obstore_backend.py +++ b/tests/unit/test_storage/test_obstore_backend.py @@ -232,16 +232,19 @@ def test_stream_arrow(tmp_path: Path) -> None: @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -def test_sign_returns_uri(tmp_path: Path) -> None: - """Test sign returns URI for files.""" +def test_sign_raises_not_implemented_for_local_files(tmp_path: Path) -> None: + """Test sign_sync raises NotImplementedError for local file protocol.""" from sqlspec.storage.backends.obstore import ObStoreBackend store = ObStoreBackend(f"file://{tmp_path}") 
store.write_text("test.txt", "content") - signed_url = store.sign("test.txt") - assert "test.txt" in signed_url + # Local file protocol does not support URL signing + assert store.supports_signing is False + + with pytest.raises(NotImplementedError, match="URL signing is not supported for protocol 'file'"): + store.sign_sync("test.txt") def test_obstore_not_installed() -> None: @@ -425,16 +428,17 @@ async def test_async_stream_arrow(tmp_path: Path) -> None: @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -async def test_async_sign(tmp_path: Path) -> None: - """Test async sign returns URI for files.""" +async def test_async_sign_raises_not_implemented_for_local_files(tmp_path: Path) -> None: + """Test sign_async raises NotImplementedError for local file protocol.""" from sqlspec.storage.backends.obstore import ObStoreBackend store = ObStoreBackend(f"file://{tmp_path}") await store.write_text_async("async_test.txt", "content") - signed_url = await store.sign_async("async_test.txt") - assert "async_test.txt" in signed_url + # Local file protocol does not support URL signing + with pytest.raises(NotImplementedError, match="URL signing is not supported for protocol 'file'"): + await store.sign_async("async_test.txt") def test_obstore_operations_without_obstore() -> None: diff --git a/tests/unit/test_utils/test_logging.py b/tests/unit/test_utils/test_logging.py index a55bd7fd3..f03a5b194 100644 --- a/tests/unit/test_utils/test_logging.py +++ b/tests/unit/test_utils/test_logging.py @@ -33,6 +33,11 @@ def setup_function() -> None: correlation_id_var.set(None) +def teardown_function() -> None: + """Clear correlation ID after each test to prevent pollution.""" + correlation_id_var.set(None) + + def test_correlation_id_initial_state() -> None: """Test that initial correlation ID is None.""" assert get_correlation_id() is None