From 9f593527faa9884c65da224b9325c28946f732ee Mon Sep 17 00:00:00 2001
From: H0TB0X420
Date: Tue, 7 Oct 2025 17:14:29 -0400
Subject: [PATCH] Add PyCapsule Protocol support for Arrow inputs

- Add Protocol types for Arrow PyCapsule Interface
- Update schema parameters to accept any Arrow-compatible library
- Move pyarrow to TYPE_CHECKING (optional at runtime)
---
 pyproject.toml               |  5 +++-
 python/datafusion/context.py | 45 ++++++++++++++++++++++-----------
 2 files changed, 33 insertions(+), 17 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index edecc4588..bb8fe9e96 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,9 +43,12 @@ classifiers = [
     "Programming Language :: Python",
     "Programming Language :: Rust",
 ]
-dependencies = ["pyarrow>=11.0.0", "typing-extensions;python_version<'3.13'"]
+dependencies = ["typing-extensions;python_version<'3.13'"]
 dynamic = ["version"]
 
+[project.optional-dependencies]
+pyarrow = ["pyarrow>=11.0.0"]
+
 [project.urls]
 homepage = "https://datafusion.apache.org/python"
 documentation = "https://datafusion.apache.org/python"
diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index b6e728b51..47c0cfd83 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -27,8 +27,6 @@
 except ImportError:
     from typing_extensions import deprecated  # Python 3.12
 
-import pyarrow as pa
-
 from datafusion.catalog import Catalog, CatalogProvider, Table
 from datafusion.dataframe import DataFrame
 from datafusion.expr import SortKey, sort_list_to_raw_sort_list
@@ -47,10 +45,21 @@
     import pandas as pd
     import polars as pl  # type: ignore[import]
+    import pyarrow as pa  # Optional: only needed for type hints
 
     from datafusion.plan import ExecutionPlan, LogicalPlan
 
 
+class ArrowSchemaExportable(Protocol):
+    """Type hint for object exporting Arrow Schema via Arrow PyCapsule Interface.
+
+    https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
+    """
+
+    def __arrow_c_schema__(self) -> object:  # noqa: D105
+        ...
+
+
 class ArrowStreamExportable(Protocol):
     """Type hint for object exporting Arrow C Stream via Arrow PyCapsule Interface.
@@ -59,7 +68,8 @@ class ArrowStreamExportable(Protocol):
 
     def __arrow_c_stream__(  # noqa: D105
         self, requested_schema: object | None = None
-    ) -> object: ...
+    ) -> object:
+        ...
 
 
 class ArrowArrayExportable(Protocol):
@@ -70,7 +80,8 @@ class ArrowArrayExportable(Protocol):
 
     def __arrow_c_array__(  # noqa: D105
         self, requested_schema: object | None = None
-    ) -> tuple[object, object]: ...
+    ) -> tuple[object, object]:
+        ...
 
 
 class TableProviderExportable(Protocol):
@@ -79,7 +90,8 @@ class TableProviderExportable(Protocol):
     """Type hint for object that has __datafusion_table_provider__ magic method.
 
     https://datafusion.apache.org/python/user-guide/io/table_provider.html
     """
 
-    def __datafusion_table_provider__(self) -> object: ...  # noqa: D105
+    def __datafusion_table_provider__(self) -> object:  # noqa: D105
+        ...
 
 
 class CatalogProviderExportable(Protocol):
@@ -88,7 +100,8 @@ class CatalogProviderExportable(Protocol):
     """Type hint for object that has __datafusion_catalog_provider__ magic method.
 
     https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html
     """
 
-    def __datafusion_catalog_provider__(self) -> object: ...  # noqa: D105
+    def __datafusion_catalog_provider__(self) -> object:  # noqa: D105
+        ...
 
 
 class SessionConfig:
@@ -554,7 +567,7 @@ def register_listing_table(
         path: str | pathlib.Path,
         table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_extension: str = ".parquet",
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
         file_sort_order: Sequence[Sequence[SortKey]] | None = None,
     ) -> None:
         """Register multiple files as a single table.
@@ -623,7 +636,7 @@ def create_dataframe(
         self,
         partitions: list[list[pa.RecordBatch]],
         name: str | None = None,
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
     ) -> DataFrame:
         """Create and return a dataframe using the provided partitions.
@@ -806,7 +819,7 @@ def register_parquet(
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
         file_sort_order: Sequence[Sequence[SortKey]] | None = None,
     ) -> None:
         """Register a Parquet file as a table.
@@ -848,7 +861,7 @@ def register_csv(
         self,
         name: str,
         path: str | pathlib.Path | list[str | pathlib.Path],
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
         has_header: bool = True,
         delimiter: str = ",",
         schema_infer_max_records: int = 1000,
@@ -891,7 +904,7 @@ def register_json(
         self,
         name: str,
         path: str | pathlib.Path,
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
         table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
@@ -930,7 +943,7 @@ def register_avro(
         self,
         name: str,
         path: str | pathlib.Path,
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
         file_extension: str = ".avro",
         table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
     ) -> None:
@@ -1005,7 +1018,7 @@ def session_id(self) -> str:
     def read_json(
         self,
         path: str | pathlib.Path,
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
         table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
@@ -1043,7 +1056,7 @@ def read_json(
     def read_csv(
         self,
         path: str | pathlib.Path | list[str] | list[pathlib.Path],
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
         has_header: bool = True,
         delimiter: str = ",",
         schema_infer_max_records: int = 1000,
@@ -1097,7 +1110,7 @@ def read_parquet(
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
         file_sort_order: Sequence[Sequence[SortKey]] | None = None,
     ) -> DataFrame:
         """Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`.
@@ -1141,7 +1154,7 @@ def read_parquet(
     def read_avro(
         self,
         path: str | pathlib.Path,
-        schema: pa.Schema | None = None,
+        schema: ArrowSchemaExportable | None = None,
         file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_extension: str = ".avro",
     ) -> DataFrame:
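
Note on the packaging change: with pyarrow moved to [project.optional-dependencies],
it is no longer installed by default. Assuming the published package name is
datafusion (the [project] name field is not shown in this excerpt), users who still
want pyarrow pulled in automatically would install the extra declared above:

    pip install "datafusion[pyarrow]"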
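
Usage sketch (illustrative, not part of the patch): the schema parameters above are
now typed against ArrowSchemaExportable, so any object implementing
__arrow_c_schema__ satisfies them structurally. In the sketch below the wrapper
class, column names, and CSV path are hypothetical; pyarrow.Schema only gained
__arrow_c_schema__ in pyarrow 14.0; and whether the runtime conversion accepts such
objects depends on the Rust bindings, which this excerpt does not show.

    import pyarrow as pa

    from datafusion import SessionContext


    class WrappedSchema:
        """Hypothetical wrapper; satisfies ArrowSchemaExportable without inheritance."""

        def __init__(self, inner: pa.Schema) -> None:
            self._inner = inner

        def __arrow_c_schema__(self) -> object:
            # Delegate the PyCapsule export to the wrapped pyarrow schema.
            return self._inner.__arrow_c_schema__()


    ctx = SessionContext()
    schema = WrappedSchema(pa.schema([("id", pa.int64()), ("name", pa.string())]))

    # Any schema-capsule exporter is accepted by the annotation, not just pa.Schema.
    ctx.register_csv("users", "users.csv", schema=schema)

The structural (duck-typed) Protocol is what lets objects from other Arrow-capable
libraries flow in without datafusion importing those libraries at runtime.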