Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ maintainers = [{ name = "Litestar Developers", email = "[email protected]" }]
name = "sqlspec"
readme = "README.md"
requires-python = ">=3.9, <4.0"
version = "0.16.2"
version = "0.16.3"

[project.urls]
Discord = "https://discord.gg/litestar"
Expand Down
103 changes: 101 additions & 2 deletions sqlspec/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@

if TYPE_CHECKING:
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from pathlib import Path

from sqlspec.core.statement import SQL
from sqlspec.loader import SQLFileLoader
from sqlspec.typing import ConnectionT, PoolT


Expand All @@ -38,13 +41,14 @@
class SQLSpec:
"""Configuration manager and registry for database connections and pools."""

__slots__ = ("_configs", "_instance_cache_config")
__slots__ = ("_configs", "_instance_cache_config", "_sql_loader")

def __init__(self) -> None:
def __init__(self, *, loader: "Optional[SQLFileLoader]" = None) -> None:
self._configs: dict[Any, DatabaseConfigProtocol[Any, Any, Any]] = {}
# Register sync cleanup only for sync resources
atexit.register(self._cleanup_sync_pools)
self._instance_cache_config: Optional[CacheConfig] = None
self._sql_loader: Optional[SQLFileLoader] = loader

@staticmethod
def _get_config_name(obj: Any) -> str:
Expand Down Expand Up @@ -591,3 +595,98 @@ def configure_cache(
else current_config.optimized_cache_enabled,
)
)

# SQL File Loading Integration

def _ensure_sql_loader(self) -> "SQLFileLoader":
"""Ensure SQL loader is initialized lazily."""
if self._sql_loader is None:
# Import here to avoid circular imports
from sqlspec.loader import SQLFileLoader

self._sql_loader = SQLFileLoader()
return self._sql_loader

def load_sql_files(self, *paths: "Union[str, Path]") -> None:
"""Load SQL files from paths or directories.

Args:
*paths: One or more file paths or directory paths to load.
"""
loader = self._ensure_sql_loader()
loader.load_sql(*paths)
logger.debug("Loaded SQL files: %s", paths)

def add_named_sql(self, name: str, sql: str, dialect: "Optional[str]" = None) -> None:
"""Add a named SQL query directly.

Args:
name: Name for the SQL query.
sql: Raw SQL content.
dialect: Optional dialect for the SQL statement.
"""
loader = self._ensure_sql_loader()
loader.add_named_sql(name, sql, dialect)
logger.debug("Added named SQL: %s", name)

def get_sql(self, name: str) -> "SQL":
"""Get a SQL object by name.

Args:
name: Name of the statement (from -- name: in SQL file).
Hyphens in names are converted to underscores.

Returns:
SQL object ready for execution.
"""
loader = self._ensure_sql_loader()
return loader.get_sql(name)

def list_sql_queries(self) -> "list[str]":
"""List all available query names.

Returns:
Sorted list of query names.
"""
if self._sql_loader is None:
return []
return self._sql_loader.list_queries()

def has_sql_query(self, name: str) -> bool:
"""Check if a SQL query exists.

Args:
name: Query name to check.

Returns:
True if query exists.
"""
if self._sql_loader is None:
return False
return self._sql_loader.has_query(name)

def clear_sql_cache(self) -> None:
"""Clear the SQL file cache."""
if self._sql_loader is not None:
self._sql_loader.clear_cache()
logger.debug("Cleared SQL cache")

def reload_sql_files(self) -> None:
"""Reload all SQL files.

Note: This clears the cache and requires calling load_sql_files again.
"""
if self._sql_loader is not None:
# Clear cache to force reload
self._sql_loader.clear_cache()
logger.debug("Cleared SQL cache for reload")

def get_sql_files(self) -> "list[str]":
"""Get list of loaded SQL files.

Returns:
Sorted list of file paths.
"""
if self._sql_loader is None:
return []
return self._sql_loader.list_files()
68 changes: 15 additions & 53 deletions sqlspec/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
from typing import Any, Optional, Union

from sqlspec.core.cache import CacheKey, get_cache_config, get_default_cache
from sqlspec.core.parameters import ParameterStyleConfig, ParameterValidator
from sqlspec.core.statement import SQL, StatementConfig
from sqlspec.core.statement import SQL
from sqlspec.exceptions import (
MissingDependencyError,
SQLFileNotFoundError,
Expand All @@ -34,7 +33,7 @@
# Matches: -- name: query_name (supports hyphens and special suffixes)
# We capture the name plus any trailing special characters
QUERY_NAME_PATTERN = re.compile(r"^\s*--\s*name\s*:\s*([\w-]+[^\w\s]*)\s*$", re.MULTILINE | re.IGNORECASE)
TRIM_SPECIAL_CHARS = re.compile(r"[^\w-]")
TRIM_SPECIAL_CHARS = re.compile(r"[^\w.-]")

# Matches: -- dialect: dialect_name (optional dialect specification)
DIALECT_PATTERN = re.compile(r"^\s*--\s*dialect\s*:\s*(?P<dialect>[a-zA-Z0-9_]+)\s*$", re.IGNORECASE | re.MULTILINE)
Expand Down Expand Up @@ -581,8 +580,11 @@ def add_named_sql(self, name: str, sql: str, dialect: "Optional[str]" = None) ->
Raises:
ValueError: If query name already exists.
"""
if name in self._queries:
existing_source = self._query_to_file.get(name, "<directly added>")
# Normalize the name for consistency with file-loaded queries
normalized_name = _normalize_query_name(name)

if normalized_name in self._queries:
existing_source = self._query_to_file.get(normalized_name, "<directly added>")
msg = f"Query name '{name}' already exists (source: {existing_source})"
raise ValueError(msg)

Expand All @@ -599,21 +601,16 @@ def add_named_sql(self, name: str, sql: str, dialect: "Optional[str]" = None) ->
else:
dialect = normalized_dialect

statement = NamedStatement(name=name, sql=sql.strip(), dialect=dialect, start_line=0)
self._queries[name] = statement
self._query_to_file[name] = "<directly added>"
statement = NamedStatement(name=normalized_name, sql=sql.strip(), dialect=dialect, start_line=0)
self._queries[normalized_name] = statement
self._query_to_file[normalized_name] = "<directly added>"

def get_sql(
self, name: str, parameters: "Optional[Any]" = None, dialect: "Optional[str]" = None, **kwargs: "Any"
) -> "SQL":
"""Get a SQL object by statement name with dialect support.
def get_sql(self, name: str) -> "SQL":
"""Get a SQL object by statement name.

Args:
name: Name of the statement (from -- name: in SQL file).
Hyphens in names are converted to underscores.
parameters: Parameters for the SQL statement.
dialect: Optional dialect override.
**kwargs: Additional parameters to pass to the SQL object.

Returns:
SQL object ready for execution.
Expand All @@ -640,46 +637,11 @@ def get_sql(
raise SQLFileNotFoundError(name, path=f"Statement '{name}' not found. Available statements: {available}")

parsed_statement = self._queries[safe_name]

effective_dialect = dialect or parsed_statement.dialect

if dialect is not None:
normalized_dialect = _normalize_dialect(dialect)
if normalized_dialect not in SUPPORTED_DIALECTS:
suggestions = _get_dialect_suggestions(normalized_dialect)
warning_msg = f"Unknown dialect '{dialect}'"
if suggestions:
warning_msg += f". Did you mean: {', '.join(suggestions)}?"
warning_msg += f". Supported dialects: {', '.join(sorted(SUPPORTED_DIALECTS))}. Using dialect as-is."
logger.warning(warning_msg)
effective_dialect = dialect.lower()
else:
effective_dialect = normalized_dialect

sql_kwargs = dict(kwargs)
if parameters is not None:
sql_kwargs["parameters"] = parameters

sqlglot_dialect = None
if effective_dialect:
sqlglot_dialect = _normalize_dialect_for_sqlglot(effective_dialect)

if not effective_dialect and "statement_config" not in sql_kwargs:
validator = ParameterValidator()
param_info = validator.extract_parameters(parsed_statement.sql)
if param_info:
styles = {p.style for p in param_info}
if styles:
detected_style = next(iter(styles))
sql_kwargs["statement_config"] = StatementConfig(
parameter_config=ParameterStyleConfig(
default_parameter_style=detected_style,
supported_parameter_styles=styles,
preserve_parameter_format=True,
)
)
if parsed_statement.dialect:
sqlglot_dialect = _normalize_dialect_for_sqlglot(parsed_statement.dialect)

return SQL(parsed_statement.sql, dialect=sqlglot_dialect, **sql_kwargs)
return SQL(parsed_statement.sql, dialect=sqlglot_dialect)

def get_file(self, path: Union[str, Path]) -> "Optional[SQLFile]":
"""Get a loaded SQLFile object by path.
Expand Down
12 changes: 12 additions & 0 deletions tests/fixtures/asset_maintenance.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
-- name: asset_maintenance_alert
-- Get a list of maintenances that are happening between 2 dates and insert the alert to be sent into the database, returns inserted data
with inserted_data as (
insert into alert_users (user_id, asset_maintenance_id, alert_definition_id)
select responsible_id, id, (select id from alert_definition where name = 'maintenances_today') from asset_maintenance
where planned_date_start is not null
and planned_date_start between :date_start and :date_end
and cancelled = False ON CONFLICT ON CONSTRAINT unique_alert DO NOTHING
returning *)
select inserted_data.*, to_jsonb(users.*) as user
from inserted_data
left join users on users.id = inserted_data.user_id;
1 change: 1 addition & 0 deletions tests/fixtures/oracle.ddl.sql
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
-- name: oracle-sysdba-ddl
-- Oracle 23AI Database Schema for Coffee Recommendation System
-- This script creates all necessary tables with Oracle 23AI features

Expand Down
Loading