Skip to content

Commit 28813e2

Browse files
authored
feat: loader enhancements (#54)
Introduce improvements to the SQL loader, including a simplified interface for loading SQL files and executing queries.
1 parent 49cf181 commit 28813e2

File tree

9 files changed

+1200
-99
lines changed

9 files changed

+1200
-99
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ maintainers = [{ name = "Litestar Developers", email = "[email protected]" }]
1313
name = "sqlspec"
1414
readme = "README.md"
1515
requires-python = ">=3.9, <4.0"
16-
version = "0.16.2"
16+
version = "0.16.3"
1717

1818
[project.urls]
1919
Discord = "https://discord.gg/litestar"

sqlspec/base.py

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@
2626

2727
if TYPE_CHECKING:
2828
from contextlib import AbstractAsyncContextManager, AbstractContextManager
29+
from pathlib import Path
2930

31+
from sqlspec.core.statement import SQL
32+
from sqlspec.loader import SQLFileLoader
3033
from sqlspec.typing import ConnectionT, PoolT
3134

3235

@@ -38,13 +41,14 @@
3841
class SQLSpec:
3942
"""Configuration manager and registry for database connections and pools."""
4043

41-
__slots__ = ("_configs", "_instance_cache_config")
44+
__slots__ = ("_configs", "_instance_cache_config", "_sql_loader")
4245

43-
def __init__(self) -> None:
46+
def __init__(self, *, loader: "Optional[SQLFileLoader]" = None) -> None:
4447
self._configs: dict[Any, DatabaseConfigProtocol[Any, Any, Any]] = {}
4548
# Register sync cleanup only for sync resources
4649
atexit.register(self._cleanup_sync_pools)
4750
self._instance_cache_config: Optional[CacheConfig] = None
51+
self._sql_loader: Optional[SQLFileLoader] = loader
4852

4953
@staticmethod
5054
def _get_config_name(obj: Any) -> str:
@@ -591,3 +595,98 @@ def configure_cache(
591595
else current_config.optimized_cache_enabled,
592596
)
593597
)
598+
599+
# SQL File Loading Integration
600+
601+
def _ensure_sql_loader(self) -> "SQLFileLoader":
602+
"""Ensure SQL loader is initialized lazily."""
603+
if self._sql_loader is None:
604+
# Import here to avoid circular imports
605+
from sqlspec.loader import SQLFileLoader
606+
607+
self._sql_loader = SQLFileLoader()
608+
return self._sql_loader
609+
610+
def load_sql_files(self, *paths: "Union[str, Path]") -> None:
611+
"""Load SQL files from paths or directories.
612+
613+
Args:
614+
*paths: One or more file paths or directory paths to load.
615+
"""
616+
loader = self._ensure_sql_loader()
617+
loader.load_sql(*paths)
618+
logger.debug("Loaded SQL files: %s", paths)
619+
620+
def add_named_sql(self, name: str, sql: str, dialect: "Optional[str]" = None) -> None:
621+
"""Add a named SQL query directly.
622+
623+
Args:
624+
name: Name for the SQL query.
625+
sql: Raw SQL content.
626+
dialect: Optional dialect for the SQL statement.
627+
"""
628+
loader = self._ensure_sql_loader()
629+
loader.add_named_sql(name, sql, dialect)
630+
logger.debug("Added named SQL: %s", name)
631+
632+
def get_sql(self, name: str) -> "SQL":
633+
"""Get a SQL object by name.
634+
635+
Args:
636+
name: Name of the statement (from -- name: in SQL file).
637+
Hyphens in names are converted to underscores.
638+
639+
Returns:
640+
SQL object ready for execution.
641+
"""
642+
loader = self._ensure_sql_loader()
643+
return loader.get_sql(name)
644+
645+
def list_sql_queries(self) -> "list[str]":
646+
"""List all available query names.
647+
648+
Returns:
649+
Sorted list of query names.
650+
"""
651+
if self._sql_loader is None:
652+
return []
653+
return self._sql_loader.list_queries()
654+
655+
def has_sql_query(self, name: str) -> bool:
656+
"""Check if a SQL query exists.
657+
658+
Args:
659+
name: Query name to check.
660+
661+
Returns:
662+
True if query exists.
663+
"""
664+
if self._sql_loader is None:
665+
return False
666+
return self._sql_loader.has_query(name)
667+
668+
def clear_sql_cache(self) -> None:
669+
"""Clear the SQL file cache."""
670+
if self._sql_loader is not None:
671+
self._sql_loader.clear_cache()
672+
logger.debug("Cleared SQL cache")
673+
674+
def reload_sql_files(self) -> None:
675+
"""Reload all SQL files.
676+
677+
Note: This clears the cache and requires calling load_sql_files again.
678+
"""
679+
if self._sql_loader is not None:
680+
# Clear cache to force reload
681+
self._sql_loader.clear_cache()
682+
logger.debug("Cleared SQL cache for reload")
683+
684+
def get_sql_files(self) -> "list[str]":
685+
"""Get list of loaded SQL files.
686+
687+
Returns:
688+
Sorted list of file paths.
689+
"""
690+
if self._sql_loader is None:
691+
return []
692+
return self._sql_loader.list_files()

sqlspec/loader.py

Lines changed: 15 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
from typing import Any, Optional, Union
1515

1616
from sqlspec.core.cache import CacheKey, get_cache_config, get_default_cache
17-
from sqlspec.core.parameters import ParameterStyleConfig, ParameterValidator
18-
from sqlspec.core.statement import SQL, StatementConfig
17+
from sqlspec.core.statement import SQL
1918
from sqlspec.exceptions import (
2019
MissingDependencyError,
2120
SQLFileNotFoundError,
@@ -34,7 +33,7 @@
3433
# Matches: -- name: query_name (supports hyphens and special suffixes)
3534
# We capture the name plus any trailing special characters
3635
QUERY_NAME_PATTERN = re.compile(r"^\s*--\s*name\s*:\s*([\w-]+[^\w\s]*)\s*$", re.MULTILINE | re.IGNORECASE)
37-
TRIM_SPECIAL_CHARS = re.compile(r"[^\w-]")
36+
TRIM_SPECIAL_CHARS = re.compile(r"[^\w.-]")
3837

3938
# Matches: -- dialect: dialect_name (optional dialect specification)
4039
DIALECT_PATTERN = re.compile(r"^\s*--\s*dialect\s*:\s*(?P<dialect>[a-zA-Z0-9_]+)\s*$", re.IGNORECASE | re.MULTILINE)
@@ -581,8 +580,11 @@ def add_named_sql(self, name: str, sql: str, dialect: "Optional[str]" = None) ->
581580
Raises:
582581
ValueError: If query name already exists.
583582
"""
584-
if name in self._queries:
585-
existing_source = self._query_to_file.get(name, "<directly added>")
583+
# Normalize the name for consistency with file-loaded queries
584+
normalized_name = _normalize_query_name(name)
585+
586+
if normalized_name in self._queries:
587+
existing_source = self._query_to_file.get(normalized_name, "<directly added>")
586588
msg = f"Query name '{name}' already exists (source: {existing_source})"
587589
raise ValueError(msg)
588590

@@ -599,21 +601,16 @@ def add_named_sql(self, name: str, sql: str, dialect: "Optional[str]" = None) ->
599601
else:
600602
dialect = normalized_dialect
601603

602-
statement = NamedStatement(name=name, sql=sql.strip(), dialect=dialect, start_line=0)
603-
self._queries[name] = statement
604-
self._query_to_file[name] = "<directly added>"
604+
statement = NamedStatement(name=normalized_name, sql=sql.strip(), dialect=dialect, start_line=0)
605+
self._queries[normalized_name] = statement
606+
self._query_to_file[normalized_name] = "<directly added>"
605607

606-
def get_sql(
607-
self, name: str, parameters: "Optional[Any]" = None, dialect: "Optional[str]" = None, **kwargs: "Any"
608-
) -> "SQL":
609-
"""Get a SQL object by statement name with dialect support.
608+
def get_sql(self, name: str) -> "SQL":
609+
"""Get a SQL object by statement name.
610610
611611
Args:
612612
name: Name of the statement (from -- name: in SQL file).
613613
Hyphens in names are converted to underscores.
614-
parameters: Parameters for the SQL statement.
615-
dialect: Optional dialect override.
616-
**kwargs: Additional parameters to pass to the SQL object.
617614
618615
Returns:
619616
SQL object ready for execution.
@@ -640,46 +637,11 @@ def get_sql(
640637
raise SQLFileNotFoundError(name, path=f"Statement '{name}' not found. Available statements: {available}")
641638

642639
parsed_statement = self._queries[safe_name]
643-
644-
effective_dialect = dialect or parsed_statement.dialect
645-
646-
if dialect is not None:
647-
normalized_dialect = _normalize_dialect(dialect)
648-
if normalized_dialect not in SUPPORTED_DIALECTS:
649-
suggestions = _get_dialect_suggestions(normalized_dialect)
650-
warning_msg = f"Unknown dialect '{dialect}'"
651-
if suggestions:
652-
warning_msg += f". Did you mean: {', '.join(suggestions)}?"
653-
warning_msg += f". Supported dialects: {', '.join(sorted(SUPPORTED_DIALECTS))}. Using dialect as-is."
654-
logger.warning(warning_msg)
655-
effective_dialect = dialect.lower()
656-
else:
657-
effective_dialect = normalized_dialect
658-
659-
sql_kwargs = dict(kwargs)
660-
if parameters is not None:
661-
sql_kwargs["parameters"] = parameters
662-
663640
sqlglot_dialect = None
664-
if effective_dialect:
665-
sqlglot_dialect = _normalize_dialect_for_sqlglot(effective_dialect)
666-
667-
if not effective_dialect and "statement_config" not in sql_kwargs:
668-
validator = ParameterValidator()
669-
param_info = validator.extract_parameters(parsed_statement.sql)
670-
if param_info:
671-
styles = {p.style for p in param_info}
672-
if styles:
673-
detected_style = next(iter(styles))
674-
sql_kwargs["statement_config"] = StatementConfig(
675-
parameter_config=ParameterStyleConfig(
676-
default_parameter_style=detected_style,
677-
supported_parameter_styles=styles,
678-
preserve_parameter_format=True,
679-
)
680-
)
641+
if parsed_statement.dialect:
642+
sqlglot_dialect = _normalize_dialect_for_sqlglot(parsed_statement.dialect)
681643

682-
return SQL(parsed_statement.sql, dialect=sqlglot_dialect, **sql_kwargs)
644+
return SQL(parsed_statement.sql, dialect=sqlglot_dialect)
683645

684646
def get_file(self, path: Union[str, Path]) -> "Optional[SQLFile]":
685647
"""Get a loaded SQLFile object by path.
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
-- name: asset_maintenance_alert
2+
-- Get a list of maintenances that are happening between 2 dates and insert the alert to be sent into the database, returns inserted data
3+
with inserted_data as (
4+
insert into alert_users (user_id, asset_maintenance_id, alert_definition_id)
5+
select responsible_id, id, (select id from alert_definition where name = 'maintenances_today') from asset_maintenance
6+
where planned_date_start is not null
7+
and planned_date_start between :date_start and :date_end
8+
and cancelled = False ON CONFLICT ON CONSTRAINT unique_alert DO NOTHING
9+
returning *)
10+
select inserted_data.*, to_jsonb(users.*) as user
11+
from inserted_data
12+
left join users on users.id = inserted_data.user_id;

tests/fixtures/oracle.ddl.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
-- name: oracle-sysdba-ddl
12
-- Oracle 23AI Database Schema for Coffee Recommendation System
23
-- This script creates all necessary tables with Oracle 23AI features
34

0 commit comments

Comments
 (0)