|
4 | 4 | from typing import TYPE_CHECKING, Any, Final, TypeVar, overload |
5 | 5 |
|
6 | 6 | from sqlspec.core import SQL |
| 7 | +from sqlspec.core.result import create_arrow_result |
7 | 8 | from sqlspec.driver._common import ( |
8 | 9 | CommonDriverAttributesMixin, |
9 | 10 | DataDictionaryMixin, |
|
12 | 13 | handle_single_row_error, |
13 | 14 | ) |
14 | 15 | from sqlspec.driver.mixins import SQLTranslatorMixin |
| 16 | +from sqlspec.exceptions import ImproperConfigurationError |
| 17 | +from sqlspec.utils.arrow_helpers import convert_dict_to_arrow |
15 | 18 | from sqlspec.utils.logging import get_logger |
| 19 | +from sqlspec.utils.module_loader import ensure_pyarrow |
16 | 20 |
|
17 | 21 | if TYPE_CHECKING: |
18 | 22 | from collections.abc import Sequence |
@@ -341,6 +345,97 @@ def select( |
341 | 345 | result = self.execute(statement, *parameters, statement_config=statement_config, **kwargs) |
342 | 346 | return result.get_data(schema_type=schema_type) |
343 | 347 |
|
def select_to_arrow(
    self,
    statement: "Statement | QueryBuilder",
    /,
    *parameters: "StatementParameters | StatementFilter",
    statement_config: "StatementConfig | None" = None,
    return_format: str = "table",
    native_only: bool = False,
    batch_size: int | None = None,
    arrow_schema: Any = None,
    **kwargs: Any,
) -> "Any":
    """Execute query and return results as Apache Arrow format.

    This base implementation uses the conversion path: execute() → dict → Arrow.
    Adapters with native Arrow support (ADBC, DuckDB, BigQuery) override this
    method to use zero-copy native paths for 5-10x performance improvement.

    All arguments are validated before the query is executed, so an invalid
    ``arrow_schema`` or an unsupported ``native_only=True`` request fails
    without touching the database.

    Args:
        statement: SQL query string, Statement, or QueryBuilder
        *parameters: Query parameters (same format as execute()/select())
        statement_config: Optional statement configuration override
        return_format: "table" for pyarrow.Table (default), "reader" for RecordBatchReader,
            "batches" for iterator of RecordBatches
        native_only: If True, raise error if native Arrow unavailable (default: False).
            This base implementation has no native path, so it always raises
            when this flag is set.
        batch_size: Rows per batch for "batches" format (default: None = all rows)
        arrow_schema: Optional pyarrow.Schema for type casting
        **kwargs: Additional keyword arguments forwarded to execute()

    Returns:
        ArrowResult containing pyarrow.Table, RecordBatchReader, or RecordBatches

    Raises:
        MissingDependencyError: If pyarrow not installed
        ImproperConfigurationError: If native_only=True and adapter doesn't support native Arrow
        TypeError: If arrow_schema is provided but is not a pyarrow.Schema
        SQLExecutionError: If query execution fails

    Examples:
        >>> result = driver.select_to_arrow(
        ...     "SELECT * FROM users WHERE age > ?", 18
        ... )
        >>> df = result.to_pandas()
        >>> print(df.head())

        >>> # Force native Arrow path (raises error if unavailable)
        >>> result = driver.select_to_arrow(
        ...     "SELECT * FROM users", native_only=True
        ... )
    """
    # pyarrow is an optional dependency; check for it before anything else.
    ensure_pyarrow()

    # The base driver only offers the dict -> Arrow conversion path, so a
    # native-only request can never be satisfied here.
    if native_only:
        msg = (
            f"Adapter '{self.__class__.__name__}' does not support native Arrow results. "
            f"Use native_only=False to allow conversion path, or switch to an adapter "
            f"with native Arrow support (ADBC, DuckDB, BigQuery)."
        )
        raise ImproperConfigurationError(msg)

    # Validate the requested schema up front so a bad argument is rejected
    # before the query is executed rather than after.
    if arrow_schema is not None:
        import pyarrow as pa

        if not isinstance(arrow_schema, pa.Schema):
            msg = f"arrow_schema must be a pyarrow.Schema, got {type(arrow_schema).__name__}"
            raise TypeError(msg)

    # Execute query using standard path
    result = self.execute(statement, *parameters, statement_config=statement_config, **kwargs)

    # Convert dict results to Arrow
    arrow_data = convert_dict_to_arrow(
        result.data,
        return_format=return_format,  # type: ignore[arg-type]
        batch_size=batch_size,
    )

    # Apply schema casting if requested.
    # NOTE(review): assumes the object returned for every return_format
    # exposes .cast() — confirm for the "batches" format.
    if arrow_schema is not None:
        arrow_data = arrow_data.cast(arrow_schema)

    # Wrap the Arrow payload together with the original execution metadata.
    return create_arrow_result(
        statement=result.statement,
        data=arrow_data,
        rows_affected=result.rows_affected,
        last_inserted_id=result.last_inserted_id,
        execution_time=result.execution_time,
        metadata=result.metadata,
    )
| 438 | + |
344 | 439 | def select_value( |
345 | 440 | self, |
346 | 441 | statement: "Statement | QueryBuilder", |
|
0 commit comments