Skip to content

Commit 3dd79b9

Browse files
authored
feat: add select_to_arrow() to base driver classes (#155)
Add `select_to_arrow()` method to base driver classes for Apache Arrow result support.
1 parent ff97ab1 commit 3dd79b9

File tree

2 files changed

+184
-0
lines changed

2 files changed

+184
-0
lines changed

sqlspec/driver/_async.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import TYPE_CHECKING, Any, Final, TypeVar, overload
55

66
from sqlspec.core import SQL, Statement
7+
from sqlspec.core.result import create_arrow_result
78
from sqlspec.driver._common import (
89
CommonDriverAttributesMixin,
910
DataDictionaryMixin,
@@ -12,7 +13,10 @@
1213
handle_single_row_error,
1314
)
1415
from sqlspec.driver.mixins import SQLTranslatorMixin
16+
from sqlspec.exceptions import ImproperConfigurationError
17+
from sqlspec.utils.arrow_helpers import convert_dict_to_arrow
1518
from sqlspec.utils.logging import get_logger
19+
from sqlspec.utils.module_loader import ensure_pyarrow
1620

1721
if TYPE_CHECKING:
1822
from collections.abc import Sequence
@@ -341,6 +345,91 @@ async def select(
341345
result = await self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
342346
return result.get_data(schema_type=schema_type)
343347

348+
async def select_to_arrow(
349+
self,
350+
statement: "Statement | QueryBuilder",
351+
/,
352+
*parameters: "StatementParameters | StatementFilter",
353+
statement_config: "StatementConfig | None" = None,
354+
return_format: str = "table",
355+
native_only: bool = False,
356+
batch_size: int | None = None,
357+
arrow_schema: Any = None,
358+
**kwargs: Any,
359+
) -> "Any":
360+
"""Execute query and return results as Apache Arrow format (async).
361+
362+
This base implementation uses the conversion path: execute() → dict → Arrow.
363+
Adapters with native Arrow support (ADBC, DuckDB, BigQuery) override this
364+
method to use zero-copy native paths for 5-10x performance improvement.
365+
366+
Args:
367+
statement: SQL query string, Statement, or QueryBuilder
368+
*parameters: Query parameters (same format as execute()/select())
369+
statement_config: Optional statement configuration override
370+
return_format: "table" for pyarrow.Table (default), "reader" for RecordBatchReader,
371+
"batches" for iterator of RecordBatches
372+
native_only: If True, raise error if native Arrow unavailable (default: False)
373+
batch_size: Rows per batch for "batches" format (default: None = all rows)
374+
arrow_schema: Optional pyarrow.Schema for type casting
375+
**kwargs: Additional keyword arguments
376+
377+
Returns:
378+
ArrowResult containing pyarrow.Table, RecordBatchReader, or RecordBatches
379+
380+
Raises:
381+
ImproperConfigurationError: If native_only=True and adapter doesn't support native Arrow
382+
383+
Examples:
384+
>>> result = await driver.select_to_arrow(
385+
... "SELECT * FROM users WHERE age > ?", 18
386+
... )
387+
>>> df = result.to_pandas()
388+
>>> print(df.head())
389+
390+
>>> # Force native Arrow path (raises error if unavailable)
391+
>>> result = await driver.select_to_arrow(
392+
... "SELECT * FROM users", native_only=True
393+
... )
394+
"""
395+
# Check pyarrow is available
396+
ensure_pyarrow()
397+
398+
# Check if native_only requested but not supported
399+
if native_only:
400+
msg = (
401+
f"Adapter '{self.__class__.__name__}' does not support native Arrow results. "
402+
f"Use native_only=False to allow conversion path, or switch to an adapter "
403+
f"with native Arrow support (ADBC, DuckDB, BigQuery)."
404+
)
405+
raise ImproperConfigurationError(msg)
406+
407+
# Execute query using standard path
408+
result = await self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
409+
410+
# Convert dict results to Arrow
411+
arrow_data = convert_dict_to_arrow(
412+
result.data,
413+
return_format=return_format, # type: ignore[arg-type]
414+
batch_size=batch_size,
415+
)
416+
if arrow_schema is not None:
417+
import pyarrow as pa
418+
419+
if not isinstance(arrow_schema, pa.Schema):
420+
msg = f"arrow_schema must be a pyarrow.Schema, got {type(arrow_schema).__name__}"
421+
raise TypeError(msg)
422+
423+
arrow_data = arrow_data.cast(arrow_schema)
424+
return create_arrow_result(
425+
statement=result.statement,
426+
data=arrow_data,
427+
rows_affected=result.rows_affected,
428+
last_inserted_id=result.last_inserted_id,
429+
execution_time=result.execution_time,
430+
metadata=result.metadata,
431+
)
432+
344433
async def select_value(
345434
self,
346435
statement: "Statement | QueryBuilder",

sqlspec/driver/_sync.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import TYPE_CHECKING, Any, Final, TypeVar, overload
55

66
from sqlspec.core import SQL
7+
from sqlspec.core.result import create_arrow_result
78
from sqlspec.driver._common import (
89
CommonDriverAttributesMixin,
910
DataDictionaryMixin,
@@ -12,7 +13,10 @@
1213
handle_single_row_error,
1314
)
1415
from sqlspec.driver.mixins import SQLTranslatorMixin
16+
from sqlspec.exceptions import ImproperConfigurationError
17+
from sqlspec.utils.arrow_helpers import convert_dict_to_arrow
1518
from sqlspec.utils.logging import get_logger
19+
from sqlspec.utils.module_loader import ensure_pyarrow
1620

1721
if TYPE_CHECKING:
1822
from collections.abc import Sequence
@@ -341,6 +345,97 @@ def select(
341345
result = self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
342346
return result.get_data(schema_type=schema_type)
343347

348+
def select_to_arrow(
349+
self,
350+
statement: "Statement | QueryBuilder",
351+
/,
352+
*parameters: "StatementParameters | StatementFilter",
353+
statement_config: "StatementConfig | None" = None,
354+
return_format: str = "table",
355+
native_only: bool = False,
356+
batch_size: int | None = None,
357+
arrow_schema: Any = None,
358+
**kwargs: Any,
359+
) -> "Any":
360+
"""Execute query and return results as Apache Arrow format.
361+
362+
This base implementation uses the conversion path: execute() → dict → Arrow.
363+
Adapters with native Arrow support (ADBC, DuckDB, BigQuery) override this
364+
method to use zero-copy native paths for 5-10x performance improvement.
365+
366+
Args:
367+
statement: SQL query string, Statement, or QueryBuilder
368+
*parameters: Query parameters (same format as execute()/select())
369+
statement_config: Optional statement configuration override
370+
return_format: "table" for pyarrow.Table (default), "reader" for RecordBatchReader,
371+
"batches" for iterator of RecordBatches
372+
native_only: If True, raise error if native Arrow unavailable (default: False)
373+
batch_size: Rows per batch for "batches" format (default: None = all rows)
374+
arrow_schema: Optional pyarrow.Schema for type casting
375+
**kwargs: Additional keyword arguments
376+
377+
Returns:
378+
ArrowResult containing pyarrow.Table, RecordBatchReader, or RecordBatches
379+
380+
Raises:
381+
MissingDependencyError: If pyarrow not installed
382+
ImproperConfigurationError: If native_only=True and adapter doesn't support native Arrow
383+
SQLExecutionError: If query execution fails
384+
385+
Examples:
386+
>>> result = driver.select_to_arrow(
387+
... "SELECT * FROM users WHERE age > ?", 18
388+
... )
389+
>>> df = result.to_pandas()
390+
>>> print(df.head())
391+
392+
>>> # Force native Arrow path (raises error if unavailable)
393+
>>> result = driver.select_to_arrow(
394+
... "SELECT * FROM users", native_only=True
395+
... )
396+
"""
397+
# Check pyarrow is available
398+
ensure_pyarrow()
399+
400+
# Check if native_only requested but not supported
401+
if native_only:
402+
msg = (
403+
f"Adapter '{self.__class__.__name__}' does not support native Arrow results. "
404+
f"Use native_only=False to allow conversion path, or switch to an adapter "
405+
f"with native Arrow support (ADBC, DuckDB, BigQuery)."
406+
)
407+
raise ImproperConfigurationError(msg)
408+
409+
# Execute query using standard path
410+
result = self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
411+
412+
# Convert dict results to Arrow
413+
arrow_data = convert_dict_to_arrow(
414+
result.data,
415+
return_format=return_format, # type: ignore[arg-type]
416+
batch_size=batch_size,
417+
)
418+
419+
# Apply schema casting if requested
420+
if arrow_schema is not None:
421+
import pyarrow as pa
422+
423+
if not isinstance(arrow_schema, pa.Schema):
424+
msg = f"arrow_schema must be a pyarrow.Schema, got {type(arrow_schema).__name__}"
425+
raise TypeError(msg)
426+
427+
arrow_data = arrow_data.cast(arrow_schema)
428+
429+
# Create ArrowResult
430+
return create_arrow_result(
431+
statement=result.statement,
432+
data=arrow_data,
433+
rows_affected=result.rows_affected,
434+
last_inserted_id=result.last_inserted_id,
435+
execution_time=result.execution_time,
436+
metadata=result.metadata,
437+
)
438+
344439
def select_value(
345440
self,
346441
statement: "Statement | QueryBuilder",

0 commit comments

Comments
 (0)