Skip to content

Commit ff97ab1

Browse files
authored
feat: add Arrow type system foundation for select_to_arrow() (#154)
Add foundational type system and utilities for Apache Arrow integration.
1 parent 4ffbae9 commit ff97ab1

25 files changed

+1082
-89
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ lint = [
104104
"types-protobuf",
105105
"asyncpg-stubs",
106106
"pyarrow-stubs",
107+
"pandas-stubs",
107108

108109
]
109110
test = [

sqlspec/_typing.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,18 +377,101 @@ def slice(self, offset: int = 0, length: "int | None" = None) -> Any:
377377
return None
378378

379379

380+
@runtime_checkable
381+
class ArrowSchemaProtocol(Protocol):
382+
"""Typed shim for pyarrow.Schema."""
383+
384+
def field(self, i: int) -> Any:
385+
"""Get field by index."""
386+
...
387+
388+
@property
389+
def names(self) -> "list[str]":
390+
"""Get list of field names."""
391+
...
392+
393+
def __len__(self) -> int:
394+
"""Get number of fields."""
395+
return 0
396+
397+
398+
@runtime_checkable
399+
class ArrowRecordBatchReaderProtocol(Protocol):
400+
"""Typed shim for pyarrow.RecordBatchReader."""
401+
402+
def read_all(self) -> Any:
403+
"""Read all batches into a table."""
404+
...
405+
406+
def read_next_batch(self) -> Any:
407+
"""Read next batch."""
408+
...
409+
410+
def __iter__(self) -> "Iterable[Any]":
411+
"""Iterate over batches."""
412+
...
413+
414+
380415
try:
381416
from pyarrow import RecordBatch as ArrowRecordBatch
417+
from pyarrow import RecordBatchReader as ArrowRecordBatchReader
418+
from pyarrow import Schema as ArrowSchema
382419
from pyarrow import Table as ArrowTable
383420

384421
PYARROW_INSTALLED = True
385422
except ImportError:
386423
ArrowTable = ArrowTableResult # type: ignore[assignment,misc]
387424
ArrowRecordBatch = ArrowRecordBatchResult # type: ignore[assignment,misc]
425+
ArrowSchema = ArrowSchemaProtocol # type: ignore[assignment,misc]
426+
ArrowRecordBatchReader = ArrowRecordBatchReaderProtocol # type: ignore[assignment,misc]
388427

389428
PYARROW_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
390429

391430

431+
@runtime_checkable
432+
class PandasDataFrameProtocol(Protocol):
433+
"""Typed shim for pandas.DataFrame."""
434+
435+
def __len__(self) -> int:
436+
"""Get number of rows."""
437+
...
438+
439+
def __getitem__(self, key: Any) -> Any:
440+
"""Get column or row."""
441+
...
442+
443+
444+
@runtime_checkable
445+
class PolarsDataFrameProtocol(Protocol):
446+
"""Typed shim for polars.DataFrame."""
447+
448+
def __len__(self) -> int:
449+
"""Get number of rows."""
450+
...
451+
452+
def __getitem__(self, key: Any) -> Any:
453+
"""Get column or row."""
454+
...
455+
456+
457+
try:
458+
from pandas import DataFrame as PandasDataFrame
459+
460+
PANDAS_INSTALLED = True
461+
except ImportError:
462+
PandasDataFrame = PandasDataFrameProtocol # type: ignore[assignment,misc]
463+
PANDAS_INSTALLED = False
464+
465+
466+
try:
467+
from polars import DataFrame as PolarsDataFrame
468+
469+
POLARS_INSTALLED = True
470+
except ImportError:
471+
PolarsDataFrame = PolarsDataFrameProtocol # type: ignore[assignment,misc]
472+
POLARS_INSTALLED = False
473+
474+
392475
@runtime_checkable
393476
class NumpyArrayStub(Protocol):
394477
"""Protocol stub for numpy.ndarray when numpy is not installed.
@@ -639,7 +722,9 @@ async def insert_returning(self, conn: Any, query_name: str, sql: str, parameter
639722
"OBSTORE_INSTALLED",
640723
"OPENTELEMETRY_INSTALLED",
641724
"ORJSON_INSTALLED",
725+
"PANDAS_INSTALLED",
642726
"PGVECTOR_INSTALLED",
727+
"POLARS_INSTALLED",
643728
"PROMETHEUS_INSTALLED",
644729
"PYARROW_INSTALLED",
645730
"PYDANTIC_INSTALLED",
@@ -650,7 +735,9 @@ async def insert_returning(self, conn: Any, query_name: str, sql: str, parameter
650735
"AiosqlSQLOperationType",
651736
"AiosqlSyncProtocol",
652737
"ArrowRecordBatch",
738+
"ArrowRecordBatchReader",
653739
"ArrowRecordBatchResult",
740+
"ArrowSchema",
654741
"ArrowTable",
655742
"ArrowTableResult",
656743
"AttrsInstance",
@@ -670,6 +757,10 @@ async def insert_returning(self, conn: Any, query_name: str, sql: str, parameter
670757
"Histogram",
671758
"NumpyArray",
672759
"NumpyArrayStub",
760+
"PandasDataFrame",
761+
"PandasDataFrameProtocol",
762+
"PolarsDataFrame",
763+
"PolarsDataFrameProtocol",
673764
"Span",
674765
"Status",
675766
"StatusCode",

sqlspec/adapters/adbc/driver.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
DataError,
2424
ForeignKeyViolationError,
2525
IntegrityError,
26-
MissingDependencyError,
2726
NotNullViolationError,
2827
SQLParsingError,
2928
SQLSpecError,
@@ -507,18 +506,6 @@ def __init__(
507506
self.dialect = statement_config.dialect
508507
self._data_dictionary: SyncDataDictionaryBase | None = None
509508

510-
@staticmethod
511-
def _ensure_pyarrow_installed() -> None:
512-
"""Ensure PyArrow is installed.
513-
514-
Raises:
515-
MissingDependencyError: If PyArrow is not installed
516-
"""
517-
from sqlspec.typing import PYARROW_INSTALLED
518-
519-
if not PYARROW_INSTALLED:
520-
raise MissingDependencyError(package="pyarrow", install_package="arrow")
521-
522509
@staticmethod
523510
def _get_dialect(connection: "AdbcConnection") -> str:
524511
"""Detect database dialect from connection information.

sqlspec/core/result.py

Lines changed: 133 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616
from typing_extensions import TypeVar
1717

1818
from sqlspec.core.compiler import OperationType
19+
from sqlspec.utils.module_loader import ensure_pandas, ensure_polars, ensure_pyarrow
1920
from sqlspec.utils.schema import to_schema
2021

2122
if TYPE_CHECKING:
2223
from collections.abc import Iterator
2324

2425
from sqlspec.core.statement import SQL
25-
from sqlspec.typing import SchemaT
26+
from sqlspec.typing import ArrowTable, PandasDataFrame, PolarsDataFrame, SchemaT
2627

2728

2829
__all__ = ("ArrowResult", "SQLResult", "StatementResult")
@@ -618,18 +619,27 @@ def is_success(self) -> bool:
618619
"""
619620
return self.data is not None
620621

621-
def get_data(self) -> Any:
622+
def get_data(self) -> "ArrowTable":
622623
"""Get the Apache Arrow Table from the result.
623624
624625
Returns:
625626
The Arrow table containing the result data.
626627
627628
Raises:
628629
ValueError: If no Arrow table is available.
630+
TypeError: If data is not an Arrow Table.
629631
"""
630632
if self.data is None:
631633
msg = "No Arrow table available for this result"
632634
raise ValueError(msg)
635+
636+
ensure_pyarrow()
637+
638+
import pyarrow as pa
639+
640+
if not isinstance(self.data, pa.Table):
641+
msg = f"Expected an Arrow Table, but got {type(self.data).__name__}"
642+
raise TypeError(msg)
633643
return self.data
634644

635645
@property
@@ -680,6 +690,127 @@ def num_columns(self) -> int:
680690

681691
return cast("int", self.data.num_columns)
682692

693+
def to_pandas(self) -> "PandasDataFrame":
694+
"""Convert Arrow data to pandas DataFrame.
695+
696+
Returns:
697+
pandas DataFrame containing the result data.
698+
699+
Raises:
700+
ValueError: If no Arrow table is available.
701+
702+
Examples:
703+
>>> result = session.select_to_arrow("SELECT * FROM users")
704+
>>> df = result.to_pandas()
705+
>>> print(df.head())
706+
"""
707+
if self.data is None:
708+
msg = "No Arrow table available"
709+
raise ValueError(msg)
710+
711+
ensure_pandas()
712+
713+
import pandas as pd
714+
715+
result = self.data.to_pandas()
716+
if not isinstance(result, pd.DataFrame):
717+
msg = f"Expected a pandas DataFrame, but got {type(result).__name__}"
718+
raise TypeError(msg)
719+
return result
720+
721+
def to_polars(self) -> "PolarsDataFrame":
722+
"""Convert Arrow data to Polars DataFrame.
723+
724+
Returns:
725+
Polars DataFrame containing the result data.
726+
727+
Raises:
728+
ValueError: If no Arrow table is available.
729+
730+
Examples:
731+
>>> result = session.select_to_arrow("SELECT * FROM users")
732+
>>> df = result.to_polars()
733+
>>> print(df.head())
734+
"""
735+
if self.data is None:
736+
msg = "No Arrow table available"
737+
raise ValueError(msg)
738+
739+
ensure_polars()
740+
741+
import polars as pl
742+
743+
result = pl.from_arrow(self.data)
744+
if not isinstance(result, pl.DataFrame):
745+
msg = f"Expected a Polars DataFrame, but got {type(result).__name__}"
746+
raise TypeError(msg)
747+
return result
748+
749+
def to_dict(self) -> "list[dict[str, Any]]":
750+
"""Convert Arrow data to list of dictionaries.
751+
752+
Returns:
753+
List of dictionaries, one per row.
754+
755+
Raises:
756+
ValueError: If no Arrow table is available.
757+
758+
Examples:
759+
>>> result = session.select_to_arrow(
760+
... "SELECT id, name FROM users"
761+
... )
762+
>>> rows = result.to_dict()
763+
>>> print(rows[0])
764+
{'id': 1, 'name': 'Alice'}
765+
"""
766+
if self.data is None:
767+
msg = "No Arrow table available"
768+
raise ValueError(msg)
769+
770+
return cast("list[dict[str, Any]]", self.data.to_pylist())
771+
772+
def __len__(self) -> int:
773+
"""Return number of rows in the Arrow table.
774+
775+
Returns:
776+
Number of rows.
777+
778+
Raises:
779+
ValueError: If no Arrow table is available.
780+
781+
Examples:
782+
>>> result = session.select_to_arrow("SELECT * FROM users")
783+
>>> print(len(result))
784+
100
785+
"""
786+
if self.data is None:
787+
msg = "No Arrow table available"
788+
raise ValueError(msg)
789+
790+
return cast("int", self.data.num_rows)
791+
792+
def __iter__(self) -> "Iterator[dict[str, Any]]":
793+
"""Iterate over rows as dictionaries.
794+
795+
Yields:
796+
Dictionary for each row.
797+
798+
Raises:
799+
ValueError: If no Arrow table is available.
800+
801+
Examples:
802+
>>> result = session.select_to_arrow(
803+
... "SELECT id, name FROM users"
804+
... )
805+
>>> for row in result:
806+
... print(row["name"])
807+
"""
808+
if self.data is None:
809+
msg = "No Arrow table available"
810+
raise ValueError(msg)
811+
812+
yield from self.data.to_pylist()
813+
683814

684815
def create_sql_result(
685816
statement: "SQL",

sqlspec/extensions/aiosql/adapter.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,8 @@
1313
from sqlspec.core.result import SQLResult
1414
from sqlspec.core.statement import SQL, StatementConfig
1515
from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
16-
from sqlspec.exceptions import MissingDependencyError
17-
from sqlspec.typing import (
18-
AIOSQL_INSTALLED,
19-
AiosqlAsyncProtocol,
20-
AiosqlParamType,
21-
AiosqlSQLOperationType,
22-
AiosqlSyncProtocol,
23-
)
16+
from sqlspec.typing import AiosqlAsyncProtocol, AiosqlParamType, AiosqlSQLOperationType, AiosqlSyncProtocol
17+
from sqlspec.utils.module_loader import ensure_aiosql
2418

2519
logger = logging.getLogger("sqlspec.extensions.aiosql")
2620

@@ -58,12 +52,6 @@ def fetchone(self) -> Any | None:
5852
return rows[0] if rows else None
5953

6054

61-
def _check_aiosql_available() -> None:
62-
if not AIOSQL_INSTALLED:
63-
msg = "aiosql"
64-
raise MissingDependencyError(msg, "aiosql")
65-
66-
6755
def _normalize_dialect(dialect: "str | Any | None") -> str:
6856
"""Normalize dialect name for SQLGlot compatibility.
6957
@@ -105,7 +93,7 @@ def __init__(self, driver: DriverT) -> None:
10593
Args:
10694
driver: SQLSpec driver to use for execution.
10795
"""
108-
_check_aiosql_available()
96+
ensure_aiosql()
10997
self.driver: DriverT = driver
11098

11199
def process_sql(self, query_name: str, op_type: "AiosqlSQLOperationType", sql: str) -> str:

0 commit comments

Comments
 (0)