Skip to content

Commit db22434

Browse files
authored
feat: storage & serialization cleanup (#227)
Refactor SQL file loading and storage utilities to improve integration with PyArrow and serializers.
1 parent e3a579c commit db22434

File tree

18 files changed

+930
-165
lines changed

18 files changed

+930
-165
lines changed

sqlspec/builder/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@
4040
MathExpression,
4141
StringExpression,
4242
)
43-
from sqlspec.builder._factory import SQLFactory, sql
43+
from sqlspec.builder._factory import (
44+
SQLFactory,
45+
build_copy_from_statement,
46+
build_copy_statement,
47+
build_copy_to_statement,
48+
sql,
49+
)
4450
from sqlspec.builder._insert import Insert
4551
from sqlspec.builder._join import JoinBuilder
4652
from sqlspec.builder._merge import Merge
@@ -127,6 +133,9 @@
127133
"UpdateTableClauseMixin",
128134
"WhereClauseMixin",
129135
"WindowFunctionBuilder",
136+
"build_copy_from_statement",
137+
"build_copy_statement",
138+
"build_copy_to_statement",
130139
"extract_expression",
131140
"parse_column_expression",
132141
"parse_condition_expression",

sqlspec/builder/_factory.py

Lines changed: 147 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
"""
55

66
import logging
7-
from typing import TYPE_CHECKING, Any, Union
7+
from collections.abc import Mapping, Sequence
8+
from typing import TYPE_CHECKING, Any, Union, cast
89

910
import sqlglot
1011
from sqlglot import exp
@@ -46,6 +47,8 @@
4647
from sqlspec.exceptions import SQLBuilderError
4748

4849
if TYPE_CHECKING:
50+
from collections.abc import Mapping, Sequence
51+
4952
from sqlspec.builder._expression_wrappers import ExpressionWrapper
5053

5154

@@ -73,6 +76,9 @@
7376
"Truncate",
7477
"Update",
7578
"WindowFunctionBuilder",
79+
"build_copy_from_statement",
80+
"build_copy_statement",
81+
"build_copy_to_statement",
7682
"sql",
7783
)
7884

@@ -108,6 +114,96 @@
108114
}
109115

110116

117+
def _normalize_copy_dialect(dialect: DialectType | None) -> str:
118+
if dialect is None:
119+
return "postgres"
120+
if isinstance(dialect, str):
121+
return dialect
122+
return str(dialect)
123+
124+
125+
def _to_copy_schema(table: str, columns: "Sequence[str] | None") -> exp.Expression:
126+
base = exp.table_(table)
127+
if not columns:
128+
return base
129+
column_nodes = [exp.column(column_name) for column_name in columns]
130+
return exp.Schema(this=base, expressions=column_nodes)
131+
132+
133+
def _build_copy_expression(
134+
*, direction: str, table: str, location: str, columns: "Sequence[str] | None", options: "Mapping[str, Any] | None"
135+
) -> exp.Copy:
136+
copy_args: dict[str, Any] = {"this": _to_copy_schema(table, columns), "files": [exp.Literal.string(location)]}
137+
138+
if direction == "from":
139+
copy_args["kind"] = True
140+
elif direction == "to":
141+
copy_args["kind"] = False
142+
143+
if options:
144+
params: list[exp.CopyParameter] = []
145+
for key, value in options.items():
146+
identifier = exp.Var(this=str(key).upper())
147+
value_expression: exp.Expression
148+
if isinstance(value, bool):
149+
value_expression = exp.Boolean(this=value)
150+
elif value is None:
151+
value_expression = exp.null()
152+
elif isinstance(value, (int, float)):
153+
value_expression = exp.Literal.number(value)
154+
elif isinstance(value, (list, tuple)):
155+
elements = [exp.Literal.string(str(item)) for item in value]
156+
value_expression = exp.Array(expressions=elements)
157+
else:
158+
value_expression = exp.Literal.string(str(value))
159+
params.append(exp.CopyParameter(this=identifier, expression=value_expression))
160+
copy_args["params"] = params
161+
162+
return exp.Copy(**copy_args)
163+
164+
165+
def build_copy_statement(
166+
*,
167+
direction: str,
168+
table: str,
169+
location: str,
170+
columns: "Sequence[str] | None" = None,
171+
options: "Mapping[str, Any] | None" = None,
172+
dialect: DialectType | None = None,
173+
) -> SQL:
174+
expression = _build_copy_expression(
175+
direction=direction, table=table, location=location, columns=columns, options=options
176+
)
177+
rendered = expression.sql(dialect=_normalize_copy_dialect(dialect))
178+
return SQL(rendered)
179+
180+
181+
def build_copy_from_statement(
182+
table: str,
183+
source: str,
184+
*,
185+
columns: "Sequence[str] | None" = None,
186+
options: "Mapping[str, Any] | None" = None,
187+
dialect: DialectType | None = None,
188+
) -> SQL:
189+
return build_copy_statement(
190+
direction="from", table=table, location=source, columns=columns, options=options, dialect=dialect
191+
)
192+
193+
194+
def build_copy_to_statement(
195+
table: str,
196+
target: str,
197+
*,
198+
columns: "Sequence[str] | None" = None,
199+
options: "Mapping[str, Any] | None" = None,
200+
dialect: DialectType | None = None,
201+
) -> SQL:
202+
return build_copy_statement(
203+
direction="to", table=table, location=target, columns=columns, options=options, dialect=dialect
204+
)
205+
206+
111207
class SQLFactory:
112208
"""Factory for creating SQL builders and column expressions."""
113209

@@ -479,6 +575,56 @@ def comment_on(self, dialect: DialectType = None) -> "CommentOn":
479575
"""
480576
return CommentOn(dialect=dialect or self.dialect)
481577

578+
def copy_from(
579+
self,
580+
table: str,
581+
source: str,
582+
*,
583+
columns: "Sequence[str] | None" = None,
584+
options: "Mapping[str, Any] | None" = None,
585+
dialect: DialectType | None = None,
586+
) -> SQL:
587+
"""Build a COPY ... FROM statement."""
588+
589+
effective_dialect = dialect or self.dialect
590+
return build_copy_from_statement(table, source, columns=columns, options=options, dialect=effective_dialect)
591+
592+
def copy_to(
593+
self,
594+
table: str,
595+
target: str,
596+
*,
597+
columns: "Sequence[str] | None" = None,
598+
options: "Mapping[str, Any] | None" = None,
599+
dialect: DialectType | None = None,
600+
) -> SQL:
601+
"""Build a COPY ... TO statement."""
602+
603+
effective_dialect = dialect or self.dialect
604+
return build_copy_to_statement(table, target, columns=columns, options=options, dialect=effective_dialect)
605+
606+
def copy(
607+
self,
608+
table: str,
609+
*,
610+
source: str | None = None,
611+
target: str | None = None,
612+
columns: "Sequence[str] | None" = None,
613+
options: "Mapping[str, Any] | None" = None,
614+
dialect: DialectType | None = None,
615+
) -> SQL:
616+
"""Build a COPY statement, inferring direction from provided arguments."""
617+
618+
if (source is None and target is None) or (source is not None and target is not None):
619+
msg = "Provide either 'source' or 'target' (but not both) to sql.copy()."
620+
raise SQLBuilderError(msg)
621+
622+
if source is not None:
623+
return self.copy_from(table, source, columns=columns, options=options, dialect=dialect)
624+
625+
target_value = cast("str", target)
626+
return self.copy_to(table, target_value, columns=columns, options=options, dialect=dialect)
627+
482628
@staticmethod
483629
def _looks_like_sql(candidate: str, expected_type: str | None = None) -> bool:
484630
"""Determine if a string looks like SQL.

sqlspec/loader.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
from urllib.parse import unquote, urlparse
1414

1515
from sqlspec.core import SQL, get_cache, get_cache_config
16-
from sqlspec.exceptions import SQLFileNotFoundError, SQLFileParseError, StorageOperationFailedError
16+
from sqlspec.exceptions import (
17+
FileNotFoundInStorageError,
18+
SQLFileNotFoundError,
19+
SQLFileParseError,
20+
StorageOperationFailedError,
21+
)
1722
from sqlspec.storage.registry import storage_registry as default_storage_registry
1823
from sqlspec.utils.correlation import CorrelationContext
1924
from sqlspec.utils.logging import get_logger
@@ -259,9 +264,11 @@ def _read_file_content(self, path: str | Path) -> str:
259264
return backend.read_text(path_str, encoding=self.encoding)
260265
except KeyError as e:
261266
raise SQLFileNotFoundError(path_str) from e
267+
except FileNotFoundInStorageError as e:
268+
raise SQLFileNotFoundError(path_str) from e
269+
except FileNotFoundError as e:
270+
raise SQLFileNotFoundError(path_str) from e
262271
except StorageOperationFailedError as e:
263-
if "not found" in str(e).lower() or "no such file" in str(e).lower():
264-
raise SQLFileNotFoundError(path_str) from e
265272
raise SQLFileParseError(path_str, path_str, e) from e
266273
except Exception as e:
267274
raise SQLFileParseError(path_str, path_str, e) from e

sqlspec/storage/_utils.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,44 @@
11
"""Shared utilities for storage backends."""
22

33
from pathlib import Path
4+
from typing import Any, Final
45

5-
__all__ = ("resolve_storage_path",)
6+
from sqlspec.utils.module_loader import ensure_pyarrow
7+
8+
FILE_PROTOCOL: Final[str] = "file"
9+
FILE_SCHEME_PREFIX: Final[str] = "file://"
10+
11+
__all__ = ("FILE_PROTOCOL", "FILE_SCHEME_PREFIX", "import_pyarrow", "import_pyarrow_parquet", "resolve_storage_path")
12+
13+
14+
def import_pyarrow() -> "Any":
15+
"""Import PyArrow with optional dependency guard.
16+
17+
Returns:
18+
PyArrow module.
19+
"""
20+
21+
ensure_pyarrow()
22+
import pyarrow as pa
23+
24+
return pa
25+
26+
27+
def import_pyarrow_parquet() -> "Any":
28+
"""Import PyArrow parquet module with optional dependency guard.
29+
30+
Returns:
31+
PyArrow parquet module.
32+
"""
33+
34+
ensure_pyarrow()
35+
import pyarrow.parquet as pq
36+
37+
return pq
638

739

840
def resolve_storage_path(
9-
path: "str | Path", base_path: str = "", protocol: str = "file", strip_file_scheme: bool = True
41+
path: "str | Path", base_path: str = "", protocol: str = FILE_PROTOCOL, strip_file_scheme: bool = True
1042
) -> str:
1143
"""Resolve path relative to base_path with protocol-specific handling.
1244
@@ -43,10 +75,10 @@ def resolve_storage_path(
4375

4476
path_str = str(path)
4577

46-
if strip_file_scheme and path_str.startswith("file://"):
47-
path_str = path_str.removeprefix("file://")
78+
if strip_file_scheme and path_str.startswith(FILE_SCHEME_PREFIX):
79+
path_str = path_str.removeprefix(FILE_SCHEME_PREFIX)
4880

49-
if protocol == "file":
81+
if protocol == FILE_PROTOCOL:
5082
path_obj = Path(path_str)
5183

5284
if path_obj.is_absolute():

0 commit comments

Comments
 (0)