Skip to content

Commit 99b6b61

Browse files
committed
fix: update storage backend
1 parent bec4d37 commit 99b6b61

File tree

20 files changed

+3425
-675
lines changed

20 files changed

+3425
-675
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ repos:
1717
- id: mixed-line-ending
1818
- id: trailing-whitespace
1919
- repo: https://github.com/charliermarsh/ruff-pre-commit
20-
rev: "v0.12.10"
20+
rev: "v0.12.11"
2121
hooks:
2222
- id: ruff
2323
args: ["--fix"]

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -83,6 +83,7 @@ doc = [
8383
]
8484
extras = [
8585
"adbc_driver_manager",
86+
"fsspec[s3]",
8687
"pgvector",
8788
"pyarrow",
8889
"polars",

sqlspec/base.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,7 @@ def _cleanup_sync_pools(self) -> None:
6464
config.close_pool()
6565
cleaned_count += 1
6666
except Exception as e:
67-
logger.warning("Failed to clean up sync pool for config %s: %s", config_type.__name__, e)
67+
logger.debug("Failed to clean up sync pool for config %s: %s", config_type.__name__, e)
6868

6969
if cleaned_count > 0:
7070
logger.debug("Sync pool cleanup completed. Cleaned %d pools.", cleaned_count)
@@ -87,14 +87,14 @@ async def close_all_pools(self) -> None:
8787
else:
8888
sync_configs.append((config_type, config))
8989
except Exception as e:
90-
logger.warning("Failed to prepare cleanup for config %s: %s", config_type.__name__, e)
90+
logger.debug("Failed to prepare cleanup for config %s: %s", config_type.__name__, e)
9191

9292
if cleanup_tasks:
9393
try:
9494
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
9595
logger.debug("Async pool cleanup completed. Cleaned %d pools.", len(cleanup_tasks))
9696
except Exception as e:
97-
logger.warning("Failed to complete async pool cleanup: %s", e)
97+
logger.debug("Failed to complete async pool cleanup: %s", e)
9898

9999
for _config_type, config in sync_configs:
100100
config.close_pool()
@@ -129,7 +129,7 @@ def add_config(self, config: "Union[SyncConfigT, AsyncConfigT]") -> "type[Union[
129129
"""
130130
config_type = type(config)
131131
if config_type in self._configs:
132-
logger.warning("Configuration for %s already exists. Overwriting.", config_type.__name__)
132+
logger.debug("Configuration for %s already exists. Overwriting.", config_type.__name__)
133133
self._configs[config_type] = config
134134
return config_type
135135

sqlspec/loader.py

Lines changed: 43 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -13,15 +13,11 @@
1313

1414
from sqlspec.core.cache import CacheKey, get_cache_config, get_default_cache
1515
from sqlspec.core.statement import SQL
16-
from sqlspec.exceptions import (
17-
MissingDependencyError,
18-
SQLFileNotFoundError,
19-
SQLFileParseError,
20-
StorageOperationFailedError,
21-
)
16+
from sqlspec.exceptions import SQLFileNotFoundError, SQLFileParseError, StorageOperationFailedError
2217
from sqlspec.storage.registry import storage_registry as default_storage_registry
2318
from sqlspec.utils.correlation import CorrelationContext
2419
from sqlspec.utils.logging import get_logger
20+
from sqlspec.utils.text import slugify
2521

2622
if TYPE_CHECKING:
2723
from sqlspec.storage.registry import StorageRegistry
@@ -60,7 +56,7 @@ def _normalize_query_name(name: str) -> str:
6056
Returns:
6157
Normalized query name suitable as Python identifier.
6258
"""
63-
return TRIM_SPECIAL_CHARS.sub("", name).replace("-", "_")
59+
return slugify(name, separator="_")
6460

6561

6662
def _normalize_dialect(dialect: str) -> str:
@@ -76,19 +72,6 @@ def _normalize_dialect(dialect: str) -> str:
7672
return DIALECT_ALIASES.get(normalized, normalized)
7773

7874

79-
def _normalize_dialect_for_sqlglot(dialect: str) -> str:
80-
"""Normalize dialect name for SQLGlot compatibility.
81-
82-
Args:
83-
dialect: Dialect name from SQL file or parameter.
84-
85-
Returns:
86-
SQLGlot-compatible dialect name.
87-
"""
88-
normalized = dialect.lower().strip()
89-
return DIALECT_ALIASES.get(normalized, normalized)
90-
91-
9275
class NamedStatement:
9376
"""Represents a parsed SQL statement with metadata.
9477
@@ -218,8 +201,7 @@ def _calculate_file_checksum(self, path: Union[str, Path]) -> str:
218201
SQLFileParseError: If file cannot be read.
219202
"""
220203
try:
221-
content = self._read_file_content(path)
222-
return hashlib.md5(content.encode(), usedforsecurity=False).hexdigest()
204+
return hashlib.md5(self._read_file_content(path).encode(), usedforsecurity=False).hexdigest()
223205
except Exception as e:
224206
raise SQLFileParseError(str(path), str(path), e) from e
225207

@@ -253,19 +235,13 @@ def _read_file_content(self, path: Union[str, Path]) -> str:
253235
SQLFileNotFoundError: If file does not exist.
254236
SQLFileParseError: If file cannot be read or parsed.
255237
"""
256-
257238
path_str = str(path)
258239

259240
try:
260241
backend = self.storage_registry.get(path)
261242
return backend.read_text(path_str, encoding=self.encoding)
262243
except KeyError as e:
263244
raise SQLFileNotFoundError(path_str) from e
264-
except MissingDependencyError:
265-
try:
266-
return path.read_text(encoding=self.encoding) # type: ignore[union-attr]
267-
except FileNotFoundError as e:
268-
raise SQLFileNotFoundError(path_str) from e
269245
except StorageOperationFailedError as e:
270246
if "not found" in str(e).lower() or "no such file" in str(e).lower():
271247
raise SQLFileNotFoundError(path_str) from e
@@ -419,8 +395,7 @@ def _load_directory(self, dir_path: Path) -> int:
419395
for file_path in sql_files:
420396
relative_path = file_path.relative_to(dir_path)
421397
namespace_parts = relative_path.parent.parts
422-
namespace = ".".join(namespace_parts) if namespace_parts else None
423-
self._load_single_file(file_path, namespace)
398+
self._load_single_file(file_path, ".".join(namespace_parts) if namespace_parts else None)
424399
return len(sql_files)
425400

426401
def _load_single_file(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
@@ -533,44 +508,6 @@ def add_named_sql(self, name: str, sql: str, dialect: "Optional[str]" = None) ->
533508
self._queries[normalized_name] = statement
534509
self._query_to_file[normalized_name] = "<directly added>"
535510

536-
def get_sql(self, name: str) -> "SQL":
537-
"""Get a SQL object by statement name.
538-
539-
Args:
540-
name: Name of the statement (from -- name: in SQL file).
541-
Hyphens in names are converted to underscores.
542-
543-
Returns:
544-
SQL object ready for execution.
545-
546-
Raises:
547-
SQLFileNotFoundError: If statement name not found.
548-
"""
549-
correlation_id = CorrelationContext.get()
550-
551-
safe_name = _normalize_query_name(name)
552-
553-
if safe_name not in self._queries:
554-
available = ", ".join(sorted(self._queries.keys())) if self._queries else "none"
555-
logger.error(
556-
"Statement not found: %s",
557-
name,
558-
extra={
559-
"statement_name": name,
560-
"safe_name": safe_name,
561-
"available_statements": len(self._queries),
562-
"correlation_id": correlation_id,
563-
},
564-
)
565-
raise SQLFileNotFoundError(name, path=f"Statement '{name}' not found. Available statements: {available}")
566-
567-
parsed_statement = self._queries[safe_name]
568-
sqlglot_dialect = None
569-
if parsed_statement.dialect:
570-
sqlglot_dialect = _normalize_dialect_for_sqlglot(parsed_statement.dialect)
571-
572-
return SQL(parsed_statement.sql, dialect=sqlglot_dialect)
573-
574511
def get_file(self, path: Union[str, Path]) -> "Optional[SQLFile]":
575512
"""Get a loaded SQLFile object by path.
576513
@@ -659,3 +596,41 @@ def get_query_text(self, name: str) -> str:
659596
if safe_name not in self._queries:
660597
raise SQLFileNotFoundError(name)
661598
return self._queries[safe_name].sql
599+
600+
def get_sql(self, name: str) -> "SQL":
601+
"""Get a SQL object by statement name.
602+
603+
Args:
604+
name: Name of the statement (from -- name: in SQL file).
605+
Hyphens in names are converted to underscores.
606+
607+
Returns:
608+
SQL object ready for execution.
609+
610+
Raises:
611+
SQLFileNotFoundError: If statement name not found.
612+
"""
613+
correlation_id = CorrelationContext.get()
614+
615+
safe_name = _normalize_query_name(name)
616+
617+
if safe_name not in self._queries:
618+
available = ", ".join(sorted(self._queries.keys())) if self._queries else "none"
619+
logger.error(
620+
"Statement not found: %s",
621+
name,
622+
extra={
623+
"statement_name": name,
624+
"safe_name": safe_name,
625+
"available_statements": len(self._queries),
626+
"correlation_id": correlation_id,
627+
},
628+
)
629+
raise SQLFileNotFoundError(name, path=f"Statement '{name}' not found. Available statements: {available}")
630+
631+
parsed_statement = self._queries[safe_name]
632+
sqlglot_dialect = None
633+
if parsed_statement.dialect:
634+
sqlglot_dialect = _normalize_dialect(parsed_statement.dialect)
635+
636+
return SQL(parsed_statement.sql, dialect=sqlglot_dialect)

sqlspec/protocols.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
and runtime isinstance() checks.
55
"""
66

7-
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Protocol, Union, runtime_checkable
7+
from typing import TYPE_CHECKING, Any, Optional, Protocol, Union, runtime_checkable
88

99
from typing_extensions import Self
1010

@@ -14,7 +14,6 @@
1414

1515
from sqlglot import exp
1616

17-
from sqlspec.storage.capabilities import StorageCapabilities
1817
from sqlspec.typing import ArrowRecordBatch, ArrowTable
1918

2019
__all__ = (
@@ -194,9 +193,8 @@ class ObjectStoreItemProtocol(Protocol):
194193
class ObjectStoreProtocol(Protocol):
195194
"""Protocol for object storage operations."""
196195

197-
capabilities: ClassVar["StorageCapabilities"]
198-
199196
protocol: str
197+
backend_type: str
200198

201199
def __init__(self, uri: str, **kwargs: Any) -> None:
202200
return
@@ -330,7 +328,7 @@ async def write_arrow_async(self, path: "Union[str, Path]", table: "ArrowTable",
330328
msg = "Async arrow writing not implemented"
331329
raise NotImplementedError(msg)
332330

333-
async def stream_arrow_async(self, pattern: str, **kwargs: Any) -> "AsyncIterator[ArrowRecordBatch]":
331+
def stream_arrow_async(self, pattern: str, **kwargs: Any) -> "AsyncIterator[ArrowRecordBatch]":
334332
"""Async stream Arrow record batches from matching objects."""
335333
msg = "Async arrow streaming not implemented"
336334
raise NotImplementedError(msg)

sqlspec/storage/__init__.py

Lines changed: 2 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -8,16 +8,6 @@
88
- Capability-based backend selection
99
"""
1010

11-
from sqlspec.protocols import ObjectStoreProtocol
12-
from sqlspec.storage.capabilities import HasStorageCapabilities, StorageCapabilities
13-
from sqlspec.storage.registry import StorageRegistry
11+
from sqlspec.storage.registry import StorageRegistry, storage_registry
1412

15-
storage_registry = StorageRegistry()
16-
17-
__all__ = (
18-
"HasStorageCapabilities",
19-
"ObjectStoreProtocol",
20-
"StorageCapabilities",
21-
"StorageRegistry",
22-
"storage_registry",
23-
)
13+
__all__ = ("StorageRegistry", "storage_registry")

0 commit comments

Comments
 (0)