Skip to content

Commit 5206809

Browse files
authored
feat: add native Arrow support for Oracle adapter (#159)
Implements native Apache Arrow support for Oracle using `fetch_df_all()`.
1 parent f819ff0 commit 5206809

File tree

2 files changed

+450
-2
lines changed

2 files changed

+450
-2
lines changed

sqlspec/adapters/oracledb/driver.py

Lines changed: 183 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from sqlspec.adapters.oracledb.type_converter import OracleTypeConverter
1414
from sqlspec.core.cache import get_cache_config
1515
from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig
16-
from sqlspec.core.statement import StatementConfig
16+
from sqlspec.core.result import create_arrow_result
17+
from sqlspec.core.statement import SQL, StatementConfig
1718
from sqlspec.driver import (
1819
AsyncDataDictionaryBase,
1920
AsyncDriverAdapterBase,
@@ -33,14 +34,18 @@
3334
TransactionError,
3435
UniqueViolationError,
3536
)
37+
from sqlspec.utils.module_loader import ensure_pyarrow
3638
from sqlspec.utils.serializers import to_json
3739

3840
if TYPE_CHECKING:
3941
from contextlib import AbstractAsyncContextManager, AbstractContextManager
4042

43+
from sqlspec.builder import QueryBuilder
44+
from sqlspec.core import StatementFilter
4145
from sqlspec.core.result import SQLResult
42-
from sqlspec.core.statement import SQL
46+
from sqlspec.core.statement import Statement
4347
from sqlspec.driver._common import ExecutionResult
48+
from sqlspec.typing import StatementParameters
4449

4550
logger = logging.getLogger(__name__)
4651

@@ -587,6 +592,94 @@ def commit(self) -> None:
587592
msg = f"Failed to commit Oracle transaction: {e}"
588593
raise SQLSpecError(msg) from e
589594

595+
def select_to_arrow(
596+
self,
597+
statement: "Statement | QueryBuilder",
598+
/,
599+
*parameters: "StatementParameters | StatementFilter",
600+
statement_config: "StatementConfig | None" = None,
601+
return_format: str = "table",
602+
native_only: bool = False,
603+
batch_size: int | None = None,
604+
arrow_schema: Any = None,
605+
**kwargs: Any,
606+
) -> "Any":
607+
"""Execute query and return results as Apache Arrow format using Oracle native support.
608+
609+
This implementation uses Oracle's native fetch_df_all() method which returns
610+
an OracleDataFrame with Arrow PyCapsule interface, providing zero-copy data
611+
transfer and 5-10x performance improvement over dict conversion.
612+
613+
Args:
614+
statement: SQL query string, Statement, or QueryBuilder
615+
*parameters: Query parameters (same format as execute()/select())
616+
statement_config: Optional statement configuration override
617+
return_format: "table" for pyarrow.Table (default), "batches" for RecordBatch
618+
native_only: If False, use base conversion path instead of native (default: False uses native)
619+
batch_size: Rows per batch when using "batches" format
620+
arrow_schema: Optional pyarrow.Schema for type casting
621+
**kwargs: Additional keyword arguments
622+
623+
Returns:
624+
ArrowResult containing pyarrow.Table or RecordBatch
625+
626+
Examples:
627+
>>> result = driver.select_to_arrow(
628+
... "SELECT * FROM users WHERE age > :1", (18,)
629+
... )
630+
>>> df = result.to_pandas()
631+
>>> print(df.head())
632+
"""
633+
# Check pyarrow is available
634+
ensure_pyarrow()
635+
636+
# If native_only=False explicitly passed, use base conversion path
637+
if native_only is False:
638+
return super().select_to_arrow(
639+
statement,
640+
*parameters,
641+
statement_config=statement_config,
642+
return_format=return_format,
643+
native_only=native_only,
644+
batch_size=batch_size,
645+
arrow_schema=arrow_schema,
646+
**kwargs,
647+
)
648+
649+
import pyarrow as pa
650+
651+
# Prepare statement with parameters
652+
config = statement_config or self.statement_config
653+
prepared_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs)
654+
sql, prepared_parameters = self._get_compiled_sql(prepared_statement, config)
655+
656+
# Use Oracle's native fetch_df_all() for zero-copy Arrow transfer
657+
oracle_df = self.connection.fetch_df_all(
658+
statement=sql, parameters=prepared_parameters or [], arraysize=batch_size or 1000
659+
)
660+
661+
# Convert OracleDataFrame to PyArrow Table using PyCapsule interface
662+
arrow_table = pa.table(oracle_df)
663+
664+
# Apply schema casting if provided
665+
if arrow_schema is not None:
666+
if not isinstance(arrow_schema, pa.Schema):
667+
msg = f"arrow_schema must be a pyarrow.Schema, got {type(arrow_schema).__name__}"
668+
raise TypeError(msg)
669+
arrow_table = arrow_table.cast(arrow_schema)
670+
671+
# Convert to batches if requested
672+
if return_format == "batches":
673+
batches = arrow_table.to_batches()
674+
arrow_data: Any = batches[0] if batches else pa.RecordBatch.from_pydict({})
675+
else:
676+
arrow_data = arrow_table
677+
678+
# Get row count
679+
rows_affected = len(arrow_table)
680+
681+
return create_arrow_result(statement=prepared_statement, data=arrow_data, rows_affected=rows_affected)
682+
590683
@property
591684
def data_dictionary(self) -> "SyncDataDictionaryBase":
592685
"""Get the data dictionary for this driver.
@@ -783,6 +876,94 @@ async def commit(self) -> None:
783876
msg = f"Failed to commit Oracle transaction: {e}"
784877
raise SQLSpecError(msg) from e
785878

879+
async def select_to_arrow(
880+
self,
881+
statement: "Statement | QueryBuilder",
882+
/,
883+
*parameters: "StatementParameters | StatementFilter",
884+
statement_config: "StatementConfig | None" = None,
885+
return_format: str = "table",
886+
native_only: bool = False,
887+
batch_size: int | None = None,
888+
arrow_schema: Any = None,
889+
**kwargs: Any,
890+
) -> "Any":
891+
"""Execute query and return results as Apache Arrow format using Oracle native support.
892+
893+
This implementation uses Oracle's native fetch_df_all() method which returns
894+
an OracleDataFrame with Arrow PyCapsule interface, providing zero-copy data
895+
transfer and 5-10x performance improvement over dict conversion.
896+
897+
Args:
898+
statement: SQL query string, Statement, or QueryBuilder
899+
*parameters: Query parameters (same format as execute()/select())
900+
statement_config: Optional statement configuration override
901+
return_format: "table" for pyarrow.Table (default), "batches" for RecordBatch
902+
native_only: If False, use base conversion path instead of native (default: False uses native)
903+
batch_size: Rows per batch when using "batches" format
904+
arrow_schema: Optional pyarrow.Schema for type casting
905+
**kwargs: Additional keyword arguments
906+
907+
Returns:
908+
ArrowResult containing pyarrow.Table or RecordBatch
909+
910+
Examples:
911+
>>> result = await driver.select_to_arrow(
912+
... "SELECT * FROM users WHERE age > :1", (18,)
913+
... )
914+
>>> df = result.to_pandas()
915+
>>> print(df.head())
916+
"""
917+
# Check pyarrow is available
918+
ensure_pyarrow()
919+
920+
# If native_only=False explicitly passed, use base conversion path
921+
if native_only is False:
922+
return await super().select_to_arrow(
923+
statement,
924+
*parameters,
925+
statement_config=statement_config,
926+
return_format=return_format,
927+
native_only=native_only,
928+
batch_size=batch_size,
929+
arrow_schema=arrow_schema,
930+
**kwargs,
931+
)
932+
933+
import pyarrow as pa
934+
935+
# Prepare statement with parameters
936+
config = statement_config or self.statement_config
937+
prepared_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs)
938+
sql, prepared_parameters = self._get_compiled_sql(prepared_statement, config)
939+
940+
# Use Oracle's native fetch_df_all() for zero-copy Arrow transfer
941+
oracle_df = await self.connection.fetch_df_all(
942+
statement=sql, parameters=prepared_parameters or [], arraysize=batch_size or 1000
943+
)
944+
945+
# Convert OracleDataFrame to PyArrow Table using PyCapsule interface
946+
arrow_table = pa.table(oracle_df)
947+
948+
# Apply schema casting if provided
949+
if arrow_schema is not None:
950+
if not isinstance(arrow_schema, pa.Schema):
951+
msg = f"arrow_schema must be a pyarrow.Schema, got {type(arrow_schema).__name__}"
952+
raise TypeError(msg)
953+
arrow_table = arrow_table.cast(arrow_schema)
954+
955+
# Convert to batches if requested
956+
if return_format == "batches":
957+
batches = arrow_table.to_batches()
958+
arrow_data: Any = batches[0] if batches else pa.RecordBatch.from_pydict({})
959+
else:
960+
arrow_data = arrow_table
961+
962+
# Get row count
963+
rows_affected = len(arrow_table)
964+
965+
return create_arrow_result(statement=prepared_statement, data=arrow_data, rows_affected=rows_affected)
966+
786967
@property
787968
def data_dictionary(self) -> "AsyncDataDictionaryBase":
788969
"""Get the data dictionary for this driver.

0 commit comments

Comments
 (0)