2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -17,7 +17,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.12.10"
rev: "v0.12.11"
hooks:
- id: ruff
args: ["--fix"]
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -13,7 +13,7 @@ maintainers = [{ name = "Litestar Developers", email = "[email protected]" }]
name = "sqlspec"
readme = "README.md"
requires-python = ">=3.9, <4.0"
version = "0.21.1"
version = "0.22.0"

[project.urls]
Discord = "https://discord.gg/litestar"
@@ -83,6 +83,7 @@ doc = [
]
extras = [
"adbc_driver_manager",
"fsspec[s3]",
"pgvector",
"pyarrow",
"polars",
@@ -341,6 +342,7 @@ module = [
"sqlglot.*",
"pgvector",
"pgvector.*",
"minio",
]

[[tool.mypy.overrides]]
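The new fsspec[s3] extra pulls in fsspec's S3 support (s3fs), and the minio entry extends the mypy ignore list, presumably for an untyped S3-compatible client used in tests. A small sketch of the kind of access the extra enables; the bucket and key are placeholders and credentials are expected to come from the environment:

```python
# Sketch only: read a SQL file over fsspec's s3 protocol (enabled by "fsspec[s3]").
# Bucket and key are placeholders; credentials come from the environment.
import fsspec

fs = fsspec.filesystem("s3")
with fs.open("my-bucket/queries/users.sql", "rt", encoding="utf-8") as f:
    print(f.read())
```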
8 changes: 4 additions & 4 deletions sqlspec/base.py
@@ -64,7 +64,7 @@ def _cleanup_sync_pools(self) -> None:
config.close_pool()
cleaned_count += 1
except Exception as e:
logger.warning("Failed to clean up sync pool for config %s: %s", config_type.__name__, e)
logger.debug("Failed to clean up sync pool for config %s: %s", config_type.__name__, e)

if cleaned_count > 0:
logger.debug("Sync pool cleanup completed. Cleaned %d pools.", cleaned_count)
@@ -87,14 +87,14 @@ async def close_all_pools(self) -> None:
else:
sync_configs.append((config_type, config))
except Exception as e:
logger.warning("Failed to prepare cleanup for config %s: %s", config_type.__name__, e)
logger.debug("Failed to prepare cleanup for config %s: %s", config_type.__name__, e)

if cleanup_tasks:
try:
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
logger.debug("Async pool cleanup completed. Cleaned %d pools.", len(cleanup_tasks))
except Exception as e:
logger.warning("Failed to complete async pool cleanup: %s", e)
logger.debug("Failed to complete async pool cleanup: %s", e)

for _config_type, config in sync_configs:
config.close_pool()
@@ -129,7 +129,7 @@ def add_config(self, config: "Union[SyncConfigT, AsyncConfigT]") -> "type[Union[
"""
config_type = type(config)
if config_type in self._configs:
logger.warning("Configuration for %s already exists. Overwriting.", config_type.__name__)
logger.debug("Configuration for %s already exists. Overwriting.", config_type.__name__)
self._configs[config_type] = config
return config_type

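Pool-cleanup failures and config overwrites now log at DEBUG instead of WARNING, so they are silent under the default log level. A minimal sketch for surfacing them again while debugging, assuming sqlspec's loggers live under a "sqlspec" prefix in the standard logging hierarchy:

```python
# Sketch only: re-enable the demoted messages during debugging.
# The "sqlspec" logger name is an assumption about the logging hierarchy.
import logging

logging.basicConfig(level=logging.INFO)  # keep third-party noise at INFO
logging.getLogger("sqlspec").setLevel(logging.DEBUG)
```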
133 changes: 65 additions & 68 deletions sqlspec/loader.py
@@ -10,18 +10,15 @@
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any, Final, Optional, Union
from urllib.parse import unquote, urlparse

from sqlspec.core.cache import CacheKey, get_cache_config, get_default_cache
from sqlspec.core.statement import SQL
from sqlspec.exceptions import (
MissingDependencyError,
SQLFileNotFoundError,
SQLFileParseError,
StorageOperationFailedError,
)
from sqlspec.exceptions import SQLFileNotFoundError, SQLFileParseError, StorageOperationFailedError
from sqlspec.storage.registry import storage_registry as default_storage_registry
from sqlspec.utils.correlation import CorrelationContext
from sqlspec.utils.logging import get_logger
from sqlspec.utils.text import slugify

if TYPE_CHECKING:
from sqlspec.storage.registry import StorageRegistry
@@ -54,13 +51,25 @@
def _normalize_query_name(name: str) -> str:
"""Normalize query name to be a valid Python identifier.

Convert hyphens to underscores, preserve dots for namespacing,
and remove invalid characters.

Args:
name: Raw query name from SQL file.

Returns:
Normalized query name suitable as Python identifier.
"""
return TRIM_SPECIAL_CHARS.sub("", name).replace("-", "_")
# Handle namespace parts separately to preserve dots
parts = name.split(".")
normalized_parts = []

for part in parts:
# Use slugify with underscore separator and remove any remaining invalid chars
normalized_part = slugify(part, separator="_")
normalized_parts.append(normalized_part)

return ".".join(normalized_parts)


def _normalize_dialect(dialect: str) -> str:
@@ -76,19 +85,6 @@ def _normalize_dialect(dialect: str) -> str:
return DIALECT_ALIASES.get(normalized, normalized)


def _normalize_dialect_for_sqlglot(dialect: str) -> str:
"""Normalize dialect name for SQLGlot compatibility.

Args:
dialect: Dialect name from SQL file or parameter.

Returns:
SQLGlot-compatible dialect name.
"""
normalized = dialect.lower().strip()
return DIALECT_ALIASES.get(normalized, normalized)


class NamedStatement:
"""Represents a parsed SQL statement with metadata.

@@ -218,8 +214,7 @@ def _calculate_file_checksum(self, path: Union[str, Path]) -> str:
SQLFileParseError: If file cannot be read.
"""
try:
content = self._read_file_content(path)
return hashlib.md5(content.encode(), usedforsecurity=False).hexdigest()
return hashlib.md5(self._read_file_content(path).encode(), usedforsecurity=False).hexdigest()
except Exception as e:
raise SQLFileParseError(str(path), str(path), e) from e

@@ -253,19 +248,22 @@ def _read_file_content(self, path: Union[str, Path]) -> str:
SQLFileNotFoundError: If file does not exist.
SQLFileParseError: If file cannot be read or parsed.
"""

path_str = str(path)

try:
backend = self.storage_registry.get(path)
# For file:// URIs, extract just the filename for the backend call
if path_str.startswith("file://"):
parsed = urlparse(path_str)
file_path = unquote(parsed.path)
# Handle Windows paths (file:///C:/path)
if file_path and len(file_path) > 2 and file_path[2] == ":": # noqa: PLR2004
file_path = file_path[1:] # Remove leading slash for Windows
filename = Path(file_path).name
return backend.read_text(filename, encoding=self.encoding)
return backend.read_text(path_str, encoding=self.encoding)
except KeyError as e:
raise SQLFileNotFoundError(path_str) from e
except MissingDependencyError:
try:
return path.read_text(encoding=self.encoding) # type: ignore[union-attr]
except FileNotFoundError as e:
raise SQLFileNotFoundError(path_str) from e
except StorageOperationFailedError as e:
if "not found" in str(e).lower() or "no such file" in str(e).lower():
raise SQLFileNotFoundError(path_str) from e
@@ -419,8 +417,7 @@ def _load_directory(self, dir_path: Path) -> int:
for file_path in sql_files:
relative_path = file_path.relative_to(dir_path)
namespace_parts = relative_path.parent.parts
namespace = ".".join(namespace_parts) if namespace_parts else None
self._load_single_file(file_path, namespace)
self._load_single_file(file_path, ".".join(namespace_parts) if namespace_parts else None)
return len(sql_files)

def _load_single_file(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
@@ -533,44 +530,6 @@ def add_named_sql(self, name: str, sql: str, dialect: "Optional[str]" = None) ->
self._queries[normalized_name] = statement
self._query_to_file[normalized_name] = "<directly added>"

def get_sql(self, name: str) -> "SQL":
"""Get a SQL object by statement name.

Args:
name: Name of the statement (from -- name: in SQL file).
Hyphens in names are converted to underscores.

Returns:
SQL object ready for execution.

Raises:
SQLFileNotFoundError: If statement name not found.
"""
correlation_id = CorrelationContext.get()

safe_name = _normalize_query_name(name)

if safe_name not in self._queries:
available = ", ".join(sorted(self._queries.keys())) if self._queries else "none"
logger.error(
"Statement not found: %s",
name,
extra={
"statement_name": name,
"safe_name": safe_name,
"available_statements": len(self._queries),
"correlation_id": correlation_id,
},
)
raise SQLFileNotFoundError(name, path=f"Statement '{name}' not found. Available statements: {available}")

parsed_statement = self._queries[safe_name]
sqlglot_dialect = None
if parsed_statement.dialect:
sqlglot_dialect = _normalize_dialect_for_sqlglot(parsed_statement.dialect)

return SQL(parsed_statement.sql, dialect=sqlglot_dialect)

def get_file(self, path: Union[str, Path]) -> "Optional[SQLFile]":
"""Get a loaded SQLFile object by path.

@@ -659,3 +618,41 @@ def get_query_text(self, name: str) -> str:
if safe_name not in self._queries:
raise SQLFileNotFoundError(name)
return self._queries[safe_name].sql

def get_sql(self, name: str) -> "SQL":
"""Get a SQL object by statement name.

Args:
name: Name of the statement (from -- name: in SQL file).
Hyphens in names are converted to underscores.

Returns:
SQL object ready for execution.

Raises:
SQLFileNotFoundError: If statement name not found.
"""
correlation_id = CorrelationContext.get()

safe_name = _normalize_query_name(name)

if safe_name not in self._queries:
available = ", ".join(sorted(self._queries.keys())) if self._queries else "none"
logger.error(
"Statement not found: %s",
name,
extra={
"statement_name": name,
"safe_name": safe_name,
"available_statements": len(self._queries),
"correlation_id": correlation_id,
},
)
raise SQLFileNotFoundError(name, path=f"Statement '{name}' not found. Available statements: {available}")

parsed_statement = self._queries[safe_name]
sqlglot_dialect = None
if parsed_statement.dialect:
sqlglot_dialect = _normalize_dialect(parsed_statement.dialect)

return SQL(parsed_statement.sql, dialect=sqlglot_dialect)
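Putting the relocated get_sql together with add_named_sql, usage looks roughly like the sketch below. The SQLFileLoader class name is an assumption, since the class definition sits outside the visible hunks; add_named_sql and get_sql are taken from the diff:

```python
# Sketch only: register a named statement and retrieve it as a SQL object.
# The SQLFileLoader class name is assumed; add_named_sql/get_sql are from the diff.
from sqlspec.loader import SQLFileLoader

loader = SQLFileLoader()
loader.add_named_sql("get-user-by-id", "SELECT * FROM users WHERE id = :id")
stmt = loader.get_sql("get-user-by-id")  # hyphens normalize to get_user_by_id
print(stmt)
```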
8 changes: 3 additions & 5 deletions sqlspec/protocols.py
@@ -4,7 +4,7 @@
and runtime isinstance() checks.
"""

from typing import TYPE_CHECKING, Any, ClassVar, Optional, Protocol, Union, runtime_checkable
from typing import TYPE_CHECKING, Any, Optional, Protocol, Union, runtime_checkable

from typing_extensions import Self

@@ -14,7 +14,6 @@

from sqlglot import exp

from sqlspec.storage.capabilities import StorageCapabilities
from sqlspec.typing import ArrowRecordBatch, ArrowTable

__all__ = (
@@ -194,9 +193,8 @@ class ObjectStoreItemProtocol(Protocol):
class ObjectStoreProtocol(Protocol):
"""Protocol for object storage operations."""

capabilities: ClassVar["StorageCapabilities"]

protocol: str
backend_type: str

def __init__(self, uri: str, **kwargs: Any) -> None:
return
@@ -330,7 +328,7 @@ async def write_arrow_async(self, path: "Union[str, Path]", table: "ArrowTable",
msg = "Async arrow writing not implemented"
raise NotImplementedError(msg)

async def stream_arrow_async(self, pattern: str, **kwargs: Any) -> "AsyncIterator[ArrowRecordBatch]":
def stream_arrow_async(self, pattern: str, **kwargs: Any) -> "AsyncIterator[ArrowRecordBatch]":
"""Async stream Arrow record batches from matching objects."""
msg = "Async arrow streaming not implemented"
raise NotImplementedError(msg)
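With the ClassVar capabilities requirement dropped, a conforming backend now advertises plain protocol and backend_type strings alongside its store methods. A bare-bones sketch; only those two attributes and read_text (which the loader calls) are grounded in the diff, the rest is an illustrative placeholder rather than the full protocol surface:

```python
# Sketch only: the attribute surface a backend now carries. Only `protocol`,
# `backend_type`, and `read_text` are grounded in the diff; the rest is placeholder.
from typing import Any


class InMemoryStore:
    protocol = "memory"
    backend_type = "memory"

    def __init__(self, uri: str, **kwargs: Any) -> None:
        self._objects: dict[str, str] = {}

    def write_text(self, path: str, data: str, encoding: str = "utf-8") -> None:
        self._objects[path] = data

    def read_text(self, path: str, encoding: str = "utf-8") -> str:
        return self._objects[path]
```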
14 changes: 2 additions & 12 deletions sqlspec/storage/__init__.py
@@ -8,16 +8,6 @@
- Capability-based backend selection
"""

from sqlspec.protocols import ObjectStoreProtocol
from sqlspec.storage.capabilities import HasStorageCapabilities, StorageCapabilities
from sqlspec.storage.registry import StorageRegistry
from sqlspec.storage.registry import StorageRegistry, storage_registry

storage_registry = StorageRegistry()

__all__ = (
"HasStorageCapabilities",
"ObjectStoreProtocol",
"StorageCapabilities",
"StorageRegistry",
"storage_registry",
)
__all__ = ("StorageRegistry", "storage_registry")
1 change: 1 addition & 0 deletions sqlspec/storage/backends/__init__.py
@@ -0,0 +1 @@
"""Storage backends."""