|
16 | 16 | from typing_extensions import TypeVar |
17 | 17 |
|
18 | 18 | from sqlspec.core.compiler import OperationType |
| 19 | +from sqlspec.utils.module_loader import ensure_pandas, ensure_polars, ensure_pyarrow |
19 | 20 | from sqlspec.utils.schema import to_schema |
20 | 21 |
|
21 | 22 | if TYPE_CHECKING: |
22 | 23 | from collections.abc import Iterator |
23 | 24 |
|
24 | 25 | from sqlspec.core.statement import SQL |
25 | | - from sqlspec.typing import SchemaT |
| 26 | + from sqlspec.typing import ArrowTable, PandasDataFrame, PolarsDataFrame, SchemaT |
26 | 27 |
|
27 | 28 |
|
28 | 29 | __all__ = ("ArrowResult", "SQLResult", "StatementResult") |
@@ -618,18 +619,27 @@ def is_success(self) -> bool: |
618 | 619 | """ |
619 | 620 | return self.data is not None |
620 | 621 |
|
621 | | - def get_data(self) -> Any: |
| 622 | + def get_data(self) -> "ArrowTable": |
622 | 623 | """Get the Apache Arrow Table from the result. |
623 | 624 |
|
624 | 625 | Returns: |
625 | 626 | The Arrow table containing the result data. |
626 | 627 |
|
627 | 628 | Raises: |
628 | 629 | ValueError: If no Arrow table is available. |
| 630 | + TypeError: If data is not an Arrow Table. |
629 | 631 | """ |
630 | 632 | if self.data is None: |
631 | 633 | msg = "No Arrow table available for this result" |
632 | 634 | raise ValueError(msg) |
| 635 | + |
| 636 | + ensure_pyarrow() |
| 637 | + |
| 638 | + import pyarrow as pa |
| 639 | + |
| 640 | + if not isinstance(self.data, pa.Table): |
| 641 | + msg = f"Expected an Arrow Table, but got {type(self.data).__name__}" |
| 642 | + raise TypeError(msg) |
633 | 643 | return self.data |
634 | 644 |
|
635 | 645 | @property |
@@ -680,6 +690,127 @@ def num_columns(self) -> int: |
680 | 690 |
|
681 | 691 | return cast("int", self.data.num_columns) |
682 | 692 |
|
| 693 | + def to_pandas(self) -> "PandasDataFrame": |
| 694 | + """Convert Arrow data to pandas DataFrame. |
| 695 | +
|
| 696 | + Returns: |
| 697 | + pandas DataFrame containing the result data. |
| 698 | +
|
| 699 | + Raises: |
| 700 | + ValueError: If no Arrow table is available. |
| 701 | +
|
| 702 | + Examples: |
| 703 | + >>> result = session.select_to_arrow("SELECT * FROM users") |
| 704 | + >>> df = result.to_pandas() |
| 705 | + >>> print(df.head()) |
| 706 | + """ |
| 707 | + if self.data is None: |
| 708 | + msg = "No Arrow table available" |
| 709 | + raise ValueError(msg) |
| 710 | + |
| 711 | + ensure_pandas() |
| 712 | + |
| 713 | + import pandas as pd |
| 714 | + |
| 715 | + result = self.data.to_pandas() |
| 716 | + if not isinstance(result, pd.DataFrame): |
| 717 | + msg = f"Expected a pandas DataFrame, but got {type(result).__name__}" |
| 718 | + raise TypeError(msg) |
| 719 | + return result |
| 720 | + |
| 721 | + def to_polars(self) -> "PolarsDataFrame": |
| 722 | + """Convert Arrow data to Polars DataFrame. |
| 723 | +
|
| 724 | + Returns: |
| 725 | + Polars DataFrame containing the result data. |
| 726 | +
|
| 727 | + Raises: |
| 728 | + ValueError: If no Arrow table is available. |
| 729 | +
|
| 730 | + Examples: |
| 731 | + >>> result = session.select_to_arrow("SELECT * FROM users") |
| 732 | + >>> df = result.to_polars() |
| 733 | + >>> print(df.head()) |
| 734 | + """ |
| 735 | + if self.data is None: |
| 736 | + msg = "No Arrow table available" |
| 737 | + raise ValueError(msg) |
| 738 | + |
| 739 | + ensure_polars() |
| 740 | + |
| 741 | + import polars as pl |
| 742 | + |
| 743 | + result = pl.from_arrow(self.data) |
| 744 | + if not isinstance(result, pl.DataFrame): |
| 745 | + msg = f"Expected a Polars DataFrame, but got {type(result).__name__}" |
| 746 | + raise TypeError(msg) |
| 747 | + return result |
| 748 | + |
| 749 | + def to_dict(self) -> "list[dict[str, Any]]": |
| 750 | + """Convert Arrow data to list of dictionaries. |
| 751 | +
|
| 752 | + Returns: |
| 753 | + List of dictionaries, one per row. |
| 754 | +
|
| 755 | + Raises: |
| 756 | + ValueError: If no Arrow table is available. |
| 757 | +
|
| 758 | + Examples: |
| 759 | + >>> result = session.select_to_arrow( |
| 760 | + ... "SELECT id, name FROM users" |
| 761 | + ... ) |
| 762 | + >>> rows = result.to_dict() |
| 763 | + >>> print(rows[0]) |
| 764 | + {'id': 1, 'name': 'Alice'} |
| 765 | + """ |
| 766 | + if self.data is None: |
| 767 | + msg = "No Arrow table available" |
| 768 | + raise ValueError(msg) |
| 769 | + |
| 770 | + return cast("list[dict[str, Any]]", self.data.to_pylist()) |
| 771 | + |
| 772 | + def __len__(self) -> int: |
| 773 | + """Return number of rows in the Arrow table. |
| 774 | +
|
| 775 | + Returns: |
| 776 | + Number of rows. |
| 777 | +
|
| 778 | + Raises: |
| 779 | + ValueError: If no Arrow table is available. |
| 780 | +
|
| 781 | + Examples: |
| 782 | + >>> result = session.select_to_arrow("SELECT * FROM users") |
| 783 | + >>> print(len(result)) |
| 784 | + 100 |
| 785 | + """ |
| 786 | + if self.data is None: |
| 787 | + msg = "No Arrow table available" |
| 788 | + raise ValueError(msg) |
| 789 | + |
| 790 | + return cast("int", self.data.num_rows) |
| 791 | + |
| 792 | + def __iter__(self) -> "Iterator[dict[str, Any]]": |
| 793 | + """Iterate over rows as dictionaries. |
| 794 | +
|
| 795 | + Yields: |
| 796 | + Dictionary for each row. |
| 797 | +
|
| 798 | + Raises: |
| 799 | + ValueError: If no Arrow table is available. |
| 800 | +
|
| 801 | + Examples: |
| 802 | + >>> result = session.select_to_arrow( |
| 803 | + ... "SELECT id, name FROM users" |
| 804 | + ... ) |
| 805 | + >>> for row in result: |
| 806 | + ... print(row["name"]) |
| 807 | + """ |
| 808 | + if self.data is None: |
| 809 | + msg = "No Arrow table available" |
| 810 | + raise ValueError(msg) |
| 811 | + |
| 812 | + yield from self.data.to_pylist() |
| 813 | + |
683 | 814 |
|
684 | 815 | def create_sql_result( |
685 | 816 | statement: "SQL", |
|
0 commit comments