Commit 9f59352

Add PyCapsule Protocol support for Arrow inputs
- Add Protocol types for Arrow PyCapsule Interface
- Update schema parameters to accept any Arrow-compatible library
- Move pyarrow to TYPE_CHECKING (optional at runtime)
1 parent f08d5b0 commit 9f59352
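
Background for the change: the Arrow PyCapsule Interface (linked from the new protocol docstrings below) lets libraries exchange schemas, arrays, and streams as PyCapsule objects without sharing a hard dependency. A minimal sketch of the schema handshake, assuming pyarrow >= 14 (the version that added the dunder to pa.Schema):

```python
import pyarrow as pa

schema = pa.schema([("x", pa.int32())])

# Exporters hand out a PyCapsule named "arrow_schema"...
capsule = schema.__arrow_c_schema__()
print(type(capsule))  # <class 'PyCapsule'>

# ...and consumers rebuild from any exporter: pa.schema() accepts objects
# implementing __arrow_c_schema__, regardless of which library created them.
roundtripped = pa.schema(schema)
assert roundtripped.equals(schema)
```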

File tree: 2 files changed (+34, -18 lines)

pyproject.toml

Lines changed: 4 additions & 1 deletion
@@ -43,9 +43,12 @@ classifiers = [
     "Programming Language :: Python",
     "Programming Language :: Rust",
 ]
-dependencies = ["pyarrow>=11.0.0", "typing-extensions;python_version<'3.13'"]
+dependencies = ["typing-extensions;python_version<'3.13'"]
 dynamic = ["version"]
 
+[project.optional-dependencies]
+pyarrow = ["pyarrow>=11.0.0"]
+
 [project.urls]
 homepage = "https://datafusion.apache.org/python"
 documentation = "https://datafusion.apache.org/python"
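
With pyarrow out of the required dependencies, users who need it install the new extra (pip install 'datafusion[pyarrow]'), and runtime code has to guard its imports. A sketch of the usual guard pattern, not necessarily the exact code in this repo:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import pyarrow as pa  # seen only by type checkers, never imported at runtime


def schema_names(schema: "pa.Schema") -> list[str]:
    """Hypothetical helper showing the lazy-import guard."""
    try:
        import pyarrow  # noqa: F401  # runtime use still needs the extra installed
    except ImportError as exc:
        msg = "pyarrow is required here: pip install 'datafusion[pyarrow]'"
        raise ImportError(msg) from exc
    return schema.names
```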

python/datafusion/context.py

Lines changed: 30 additions & 17 deletions
@@ -27,8 +27,6 @@
 except ImportError:
     from typing_extensions import deprecated  # Python 3.12
 
-import pyarrow as pa
-
 from datafusion.catalog import Catalog, CatalogProvider, Table
 from datafusion.dataframe import DataFrame
 from datafusion.expr import SortKey, sort_list_to_raw_sort_list
@@ -47,10 +45,21 @@
 
     import pandas as pd
     import polars as pl  # type: ignore[import]
+    import pyarrow as pa  # Optional: only needed for type hints
 
     from datafusion.plan import ExecutionPlan, LogicalPlan
 
 
+class ArrowSchemaExportable(Protocol):
+    """Type hint for object exporting Arrow Schema via Arrow PyCapsule Interface.
+
+    https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
+    """
+
+    def __arrow_c_schema__(self) -> object:  # noqa: D105
+        ...
+
+
 class ArrowStreamExportable(Protocol):
     """Type hint for object exporting Arrow C Stream via Arrow PyCapsule Interface.
 
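
ArrowSchemaExportable is a plain Protocol, so conformance is structural: any object with a __arrow_c_schema__ method satisfies it, with no inheritance or registration. A minimal sketch; SchemaHolder and its fields are hypothetical:

```python
import pyarrow as pa


class SchemaHolder:
    """Hypothetical wrapper; satisfies ArrowSchemaExportable by delegation."""

    def __init__(self, fields: list[tuple[str, pa.DataType]]) -> None:
        self._schema = pa.schema(fields)

    def __arrow_c_schema__(self) -> object:
        # Export the wrapped schema as an "arrow_schema" PyCapsule.
        return self._schema.__arrow_c_schema__()


holder = SchemaHolder([("id", pa.int64()), ("value", pa.float64())])
# `holder` is now valid anywhere the new signatures accept
# `schema: ArrowSchemaExportable | None`.
```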
@@ -59,7 +68,8 @@ class ArrowStreamExportable(Protocol):
 
     def __arrow_c_stream__(  # noqa: D105
         self, requested_schema: object | None = None
-    ) -> object: ...
+    ) -> object:
+        ...
 
 
 class ArrowArrayExportable(Protocol):
@@ -70,7 +80,8 @@ class ArrowArrayExportable(Protocol):
 
     def __arrow_c_array__(  # noqa: D105
         self, requested_schema: object | None = None
-    ) -> tuple[object, object]: ...
+    ) -> tuple[object, object]:
+        ...
 
 
 class TableProviderExportable(Protocol):
@@ -79,7 +90,8 @@ class TableProviderExportable(Protocol):
     https://datafusion.apache.org/python/user-guide/io/table_provider.html
     """
 
-    def __datafusion_table_provider__(self) -> object: ...  # noqa: D105
+    def __datafusion_table_provider__(self) -> object:  # noqa: D105
+        ...
 
 
 class CatalogProviderExportable(Protocol):
@@ -88,7 +100,8 @@ class CatalogProviderExportable(Protocol):
     https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html
     """
 
-    def __datafusion_catalog_provider__(self) -> object: ...  # noqa: D105
+    def __datafusion_catalog_provider__(self) -> object:  # noqa: D105
+        ...
 
 
 class SessionConfig:
@@ -554,7 +567,7 @@ def register_listing_table(
         path: str | pathlib.Path,
         table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_extension: str = ".parquet",
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
         file_sort_order: Sequence[Sequence[SortKey]] | None = None,
     ) -> None:
         """Register multiple files as a single table.
@@ -623,7 +636,7 @@ def create_dataframe(
         self,
         partitions: list[list[pa.RecordBatch]],
         name: str | None = None,
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
     ) -> DataFrame:
         """Create and return a dataframe using the provided partitions.
 
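
create_dataframe keeps its pyarrow RecordBatch partitions, but the schema parameter now only needs an object exporting __arrow_c_schema__. A short usage sketch (the data is made up):

```python
import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()
batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]})

# batch.schema is a pa.Schema; it exports __arrow_c_schema__ and therefore
# satisfies the widened ArrowSchemaExportable annotation.
df = ctx.create_dataframe([[batch]], name="t", schema=batch.schema)
```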
@@ -806,7 +819,7 @@ def register_parquet(
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
         file_sort_order: Sequence[Sequence[SortKey]] | None = None,
     ) -> None:
         """Register a Parquet file as a table.
@@ -848,7 +861,7 @@ def register_csv(
         self,
         name: str,
         path: str | pathlib.Path | list[str | pathlib.Path],
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
         has_header: bool = True,
         delimiter: str = ",",
         schema_infer_max_records: int = 1000,
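
All the register_* methods share the widened annotation, so callers can pass a pyarrow schema or any other exporter. For example (the CSV path is a stand-in):

```python
import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()
schema = pa.schema([("id", pa.int64()), ("name", pa.string())])

# "people.csv" is a stand-in path; any ArrowSchemaExportable works as `schema`.
ctx.register_csv("people", "people.csv", schema=schema)
df = ctx.sql("SELECT id, name FROM people")
```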
@@ -891,7 +904,7 @@ def register_json(
         self,
         name: str,
         path: str | pathlib.Path,
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
         table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
@@ -930,7 +943,7 @@ def register_avro(
         self,
         name: str,
         path: str | pathlib.Path,
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
         file_extension: str = ".avro",
         table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
     ) -> None:
@@ -1005,7 +1018,7 @@ def session_id(self) -> str:
     def read_json(
         self,
         path: str | pathlib.Path,
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
         table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
@@ -1043,7 +1056,7 @@ def read_json(
     def read_csv(
         self,
         path: str | pathlib.Path | list[str] | list[pathlib.Path],
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
         has_header: bool = True,
         delimiter: str = ",",
         schema_infer_max_records: int = 1000,
@@ -1097,7 +1110,7 @@ def read_parquet(
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
         file_sort_order: Sequence[Sequence[SortKey]] | None = None,
     ) -> DataFrame:
         """Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`.
@@ -1141,7 +1154,7 @@ def read_parquet(
     def read_avro(
         self,
         path: str | pathlib.Path,
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
         file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_extension: str = ".avro",
     ) -> DataFrame:
@@ -1230,4 +1243,4 @@ def _convert_table_partition_cols(
             stacklevel=2,
         )
 
-    return converted_table_partition_cols
+    return converted_table_partition_cols
