diff --git a/docs/source/conf.py b/docs/source/conf.py index 28db17d35..18d5f1232 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -83,6 +83,9 @@ def autoapi_skip_member_fn(app, what, name, obj, skip, options) -> bool: # noqa # Duplicate modules (skip module-level docs to avoid duplication) ("module", "datafusion.col"), ("module", "datafusion.udf"), + # Private variables causing duplicate documentation + ("data", "datafusion.utils._PYARROW_DATASET_TYPES"), + ("variable", "datafusion.utils._PYARROW_DATASET_TYPES"), # Deprecated ("class", "datafusion.substrait.serde"), ("class", "datafusion.substrait.plan"), @@ -91,9 +94,28 @@ def autoapi_skip_member_fn(app, what, name, obj, skip, options) -> bool: # noqa ("method", "datafusion.context.SessionContext.tables"), ("method", "datafusion.dataframe.DataFrame.unnest_column"), ] + # Explicitly skip certain members listed above. These are either + # re-exports, duplicate module-level documentation, deprecated + # API surfaces, or private variables that would otherwise appear + # in the generated docs and cause confusing duplication. + # Keeping this explicit list avoids surprising entries in the + # AutoAPI output and gives us a single place to opt-out items + # when we intentionally hide them from the docs. if (what, name) in skip_contents: skip = True + # Skip private module-level names (those whose final component + # starts with an underscore) when AutoAPI is rendering data or + # variable entries. Many internal module-level constants are + # implementation details (for example private pyarrow dataset type + # mappings) that would otherwise be emitted as top-level "data" + # or "variable" docs. Filtering them here avoids noisy, + # duplicate, or implementation-specific entries in the public + # documentation while still allowing public members and types to + # be documented normally. + if name.split(".")[-1].startswith("_") and what in ("data", "variable"): + skip = True + return skip diff --git a/docs/source/contributor-guide/ffi.rst b/docs/source/contributor-guide/ffi.rst index e201db71e..e8abde00d 100644 --- a/docs/source/contributor-guide/ffi.rst +++ b/docs/source/contributor-guide/ffi.rst @@ -34,7 +34,7 @@ as performant as possible and to utilize the features of DataFusion, you may dec your source in Rust and then expose it through `PyO3 `_ as a Python library. At first glance, it may appear the best way to do this is to add the ``datafusion-python`` -crate as a dependency, provide a ``PyTable``, and then to register it with the +crate as a dependency, produce a DataFusion table in Rust, and then register it with the ``SessionContext``. Unfortunately, this will not work. When you produce your code as a Python library and it needs to interact with the DataFusion diff --git a/docs/source/user-guide/data-sources.rst b/docs/source/user-guide/data-sources.rst index a9b119b93..bedbabffb 100644 --- a/docs/source/user-guide/data-sources.rst +++ b/docs/source/user-guide/data-sources.rst @@ -152,13 +152,26 @@ as Delta Lake. This will require a recent version of .. 
code-block:: python from deltalake import DeltaTable + from datafusion import Table delta_table = DeltaTable("path_to_table") - ctx.register_table_provider("my_delta_table", delta_table) + table = Table.from_capsule(delta_table.__datafusion_table_provider__()) + ctx.register_table("my_delta_table", table) df = ctx.table("my_delta_table") df.show() -On older versions of ``deltalake`` (prior to 0.22) you can use the +Objects that implement ``__datafusion_table_provider__`` are supported directly by +:py:meth:`~datafusion.context.SessionContext.register_table`, making it easy to +work with custom table providers from Python libraries such as Delta Lake. + +.. note:: + + :py:meth:`~datafusion.context.SessionContext.register_table_provider` is + deprecated. Use + :py:meth:`~datafusion.context.SessionContext.register_table` with a + :py:class:`~datafusion.Table` instead. + +On older versions of ``deltalake`` (prior to 0.22) you can use the `Arrow DataSet `_ interface to import to DataFusion, but this does not support features such as filter push down which can lead to a significant performance difference. diff --git a/docs/source/user-guide/io/table_provider.rst b/docs/source/user-guide/io/table_provider.rst index bd1d6b80f..0dfc07c3b 100644 --- a/docs/source/user-guide/io/table_provider.rst +++ b/docs/source/user-guide/io/table_provider.rst @@ -39,20 +39,47 @@ A complete example can be found in the `examples folder PyResult> { let name = CString::new("datafusion_table_provider").unwrap(); - let provider = Arc::new(self.clone()) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - let provider = FFI_TableProvider::new(Arc::new(provider), false); + let provider = Arc::new(self.clone()); + let provider = FFI_TableProvider::new(provider, false, None); PyCapsule::new_bound(py, provider, Some(name.clone())) } } -Once you have this library available, in python you can register your table provider -to the ``SessionContext``. +Once you have this library available, you can construct a +:py:class:`~datafusion.Table` in Python and register it with the +``SessionContext``. Tables can be created either from the PyCapsule exposed by your +Rust provider or from an existing :py:class:`~datafusion.dataframe.DataFrame`. +Call the provider's ``__datafusion_table_provider__()`` method to obtain the capsule +before constructing a ``Table``. The ``Table.from_view()`` helper is +deprecated; instead use ``Table.from_dataframe()`` or ``DataFrame.into_view()``. + +.. note:: + + :py:meth:`~datafusion.context.SessionContext.register_table_provider` is + deprecated. Use + :py:meth:`~datafusion.context.SessionContext.register_table` with the + resulting :py:class:`~datafusion.Table` instead. .. code-block:: python + from datafusion import SessionContext, Table + + ctx = SessionContext() provider = MyTableProvider() - ctx.register_table_provider("my_table", provider) - ctx.table("my_table").show() + capsule = provider.__datafusion_table_provider__() + capsule_table = Table.from_capsule(capsule) + + df = ctx.from_pydict({"a": [1]}) + view_table = Table.from_dataframe(df) + # or: view_table = df.into_view() + + ctx.register_table("capsule_table", capsule_table) + ctx.register_table("view_table", view_table) + + ctx.table("capsule_table").show() + ctx.table("view_table").show() + +Both ``Table.from_capsule()`` and ``Table.from_dataframe()`` create +table providers that can be registered with the SessionContext using ``register_table()``. 
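+
+Because :py:meth:`~datafusion.context.SessionContext.register_table` also accepts any
+object that exposes ``__datafusion_table_provider__``, the explicit ``Table.from_capsule()``
+step can be skipped if you prefer. The snippet below is a minimal sketch that reuses the
+``MyTableProvider`` class from the Rust example above; the table name is arbitrary.
+
+.. code-block:: python
+
+    from datafusion import SessionContext
+
+    ctx = SessionContext()
+
+    # The provider object itself exports __datafusion_table_provider__, so it can be
+    # passed straight to register_table without wrapping it in a Table first.
+    ctx.register_table("direct_table", MyTableProvider())
+    ctx.table("direct_table").show()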
diff --git a/examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py b/examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py index 7ea6b295c..b2b0480db 100644 --- a/examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py +++ b/examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py @@ -45,33 +45,33 @@ def test_ffi_aggregate_register(): result = ctx.sql("select my_custom_sum(a) from test_table group by b").collect() - assert len(result) == 2 + assert result assert result[0].num_columns == 1 - result = [r.column(0) for r in result] - expected = [ - pa.array([3], type=pa.int64()), - pa.array([3], type=pa.int64()), - ] + # Normalizing table registration in _normalize_table_provider feeds the Rust layer + # an actual TableProvider, so collect() emits the grouped rows in a single record batch + # instead of two separate batches. + aggregates = pa.concat_arrays([batch.column(0) for batch in result]) - assert result == expected + assert len(aggregates) == 2 + assert aggregates.to_pylist() == [3, 3] def test_ffi_aggregate_call_directly(): ctx = setup_context_with_table() my_udaf = udaf(MySumUDF()) - + result = ( ctx.table("test_table").aggregate([col("b")], [my_udaf(col("a"))]).collect() ) - assert len(result) == 2 + # Normalizing table registration in _normalize_table_provider feeds the Rust layer + # an actual TableProvider, so collect() emits the grouped rows in a single record batch + # instead of two separate batches. + assert result assert result[0].num_columns == 2 - result = [r.column(1) for r in result] - expected = [ - pa.array([3], type=pa.int64()), - pa.array([3], type=pa.int64()), - ] + aggregates = pa.concat_arrays([batch.column(1) for batch in result]) - assert result == expected + assert len(aggregates) == 2 + assert aggregates.to_pylist() == [3, 3] diff --git a/examples/datafusion-ffi-example/python/tests/_test_table_function.py b/examples/datafusion-ffi-example/python/tests/_test_table_function.py index f3c56a90a..4b8b21454 100644 --- a/examples/datafusion-ffi-example/python/tests/_test_table_function.py +++ b/examples/datafusion-ffi-example/python/tests/_test_table_function.py @@ -53,7 +53,7 @@ def test_ffi_table_function_call_directly(): table_udtf = udtf(table_func, "my_table_func") my_table = table_udtf() - ctx.register_table_provider("t", my_table) + ctx.register_table("t", my_table) result = ctx.table("t").collect() assert len(result) == 2 diff --git a/examples/datafusion-ffi-example/python/tests/_test_table_provider.py b/examples/datafusion-ffi-example/python/tests/_test_table_provider.py index 6b24da06c..01a961e0e 100644 --- a/examples/datafusion-ffi-example/python/tests/_test_table_provider.py +++ b/examples/datafusion-ffi-example/python/tests/_test_table_provider.py @@ -18,14 +18,14 @@ from __future__ import annotations import pyarrow as pa -from datafusion import SessionContext +from datafusion import SessionContext, Table from datafusion_ffi_example import MyTableProvider def test_table_loading(): ctx = SessionContext() table = MyTableProvider(3, 2, 4) - ctx.register_table_provider("t", table) + ctx.register_table("t", Table.from_capsule(table.__datafusion_table_provider__())) result = ctx.table("t").collect() assert len(result) == 4 diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index e9d2dba75..66565a4db 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -28,17 +28,16 @@ try: import importlib.metadata as importlib_metadata except ImportError: - import 
importlib_metadata + import importlib_metadata # type: ignore[import] +# Public submodules from . import functions, object_store, substrait, unparser # The following imports are okay to remain as opaque to the user. -from ._internal import Config +from ._internal import EXPECTED_PROVIDER_MSG, Config from .catalog import Catalog, Database, Table from .col import col, column -from .common import ( - DFSchema, -) +from .common import DFSchema from .context import ( RuntimeEnvBuilder, SessionConfig, @@ -47,10 +46,7 @@ ) from .dataframe import DataFrame, ParquetColumnOptions, ParquetWriterOptions from .dataframe_formatter import configure_formatter -from .expr import ( - Expr, - WindowFrame, -) +from .expr import Expr, WindowFrame from .io import read_avro, read_csv, read_json, read_parquet from .plan import ExecutionPlan, LogicalPlan from .record_batch import RecordBatch, RecordBatchStream @@ -69,6 +65,7 @@ __version__ = importlib_metadata.version(__name__) __all__ = [ + "EXPECTED_PROVIDER_MSG", "Accumulator", "AggregateUDF", "Catalog", diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index 536b3a790..bd3300dab 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -19,14 +19,19 @@ from __future__ import annotations +import warnings from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Protocol +from typing import TYPE_CHECKING, Any, Protocol import datafusion._internal as df_internal +from datafusion._internal import EXPECTED_PROVIDER_MSG +from datafusion.utils import _normalize_table_provider if TYPE_CHECKING: import pyarrow as pa + from datafusion.context import TableProviderExportable + try: from warnings import deprecated # Python 3.13+ except ImportError: @@ -82,7 +87,11 @@ def database(self, name: str = "public") -> Schema: """Returns the database with the given ``name`` from this catalog.""" return self.schema(name) - def register_schema(self, name, schema) -> Schema | None: + def register_schema( + self, + name: str, + schema: Schema | SchemaProvider | SchemaProviderExportable, + ) -> Schema | None: """Register a schema with this catalog.""" if isinstance(schema, Schema): return self.catalog.register_schema(name, schema._raw_schema) @@ -122,11 +131,16 @@ def table(self, name: str) -> Table: """Return the table with the given ``name`` from this schema.""" return Table(self._raw_schema.table(name)) - def register_table(self, name, table) -> None: - """Register a table provider in this schema.""" - if isinstance(table, Table): - return self._raw_schema.register_table(name, table.table) - return self._raw_schema.register_table(name, table) + def register_table( + self, name: str, table: Table | TableProviderExportable | Any + ) -> None: + """Register a table or table provider in this schema. + + Objects implementing ``__datafusion_table_provider__`` are also supported + and treated as table provider instances. + """ + provider = _normalize_table_provider(table) + return self._raw_schema.register_table(name, provider) def deregister_table(self, name: str) -> None: """Deregister a table provider from this schema.""" @@ -138,31 +152,101 @@ class Database(Schema): """See `Schema`.""" +_InternalRawTable = df_internal.catalog.RawTable +_InternalTableProvider = df_internal.TableProvider + +# Keep in sync with ``datafusion._internal.TableProvider.from_view``. 
+_FROM_VIEW_WARN_STACKLEVEL = 2 + + class Table: - """DataFusion table.""" + """DataFusion table or table provider wrapper.""" - def __init__(self, table: df_internal.catalog.RawTable) -> None: - """This constructor is not typically called by the end user.""" - self.table = table + __slots__ = ("_table",) + + def __init__( + self, + table: _InternalRawTable | _InternalTableProvider | Table, + ) -> None: + """Wrap a low level table or table provider.""" + if isinstance(table, Table): + table = table.table + + if not isinstance(table, (_InternalRawTable, _InternalTableProvider)): + raise TypeError(EXPECTED_PROVIDER_MSG) + + self._table = table + + def __getattribute__(self, name: str) -> Any: + """Restrict provider-specific helpers to compatible tables.""" + if name == "__datafusion_table_provider__": + table = object.__getattribute__(self, "_table") + if not hasattr(table, "__datafusion_table_provider__"): + raise AttributeError(name) + return object.__getattribute__(self, name) def __repr__(self) -> str: """Print a string representation of the table.""" - return self.table.__repr__() + return repr(self._table) - @staticmethod - def from_dataset(dataset: pa.dataset.Dataset) -> Table: - """Turn a pyarrow Dataset into a Table.""" - return Table(df_internal.catalog.RawTable.from_dataset(dataset)) + @property + def table(self) -> _InternalRawTable | _InternalTableProvider: + """Return the wrapped low level table object.""" + return self._table + + @classmethod + def from_dataset(cls, dataset: pa.dataset.Dataset) -> Table: + """Turn a :mod:`pyarrow.dataset` ``Dataset`` into a :class:`Table`.""" + return cls(_InternalRawTable.from_dataset(dataset)) + + @classmethod + def from_capsule(cls, capsule: Any) -> Table: + """Create a :class:`Table` from a PyCapsule exported provider.""" + provider = _InternalTableProvider.from_capsule(capsule) + return cls(provider) + + @classmethod + def from_dataframe(cls, df: Any) -> Table: + """Create a :class:`Table` from tabular data.""" + from datafusion.dataframe import DataFrame as DataFrameWrapper + + dataframe = df if isinstance(df, DataFrameWrapper) else DataFrameWrapper(df) + return dataframe.into_view() + + @classmethod + def from_view(cls, df: Any) -> Table: + """Deprecated helper for constructing tables from views.""" + from datafusion.dataframe import DataFrame as DataFrameWrapper + + if isinstance(df, DataFrameWrapper): + df = df.df + + provider = _InternalTableProvider.from_view(df) + warnings.warn( + "Table.from_view is deprecated; use DataFrame.into_view or " + "Table.from_dataframe instead.", + category=DeprecationWarning, + stacklevel=_FROM_VIEW_WARN_STACKLEVEL, + ) + return cls(provider) @property def schema(self) -> pa.Schema: """Returns the schema associated with this table.""" - return self.table.schema + return self._table.schema @property def kind(self) -> str: """Returns the kind of table.""" - return self.table.kind + return self._table.kind + + def __datafusion_table_provider__(self) -> Any: + """Expose the wrapped provider for FFI integrations.""" + exporter = getattr(self._table, "__datafusion_table_provider__", None) + if exporter is None: + msg = "Underlying object does not export __datafusion_table_provider__()" + raise AttributeError(msg) + return exporter() class CatalogProvider(ABC): @@ -219,14 +303,19 @@ def table(self, name: str) -> Table | None: """Retrieve a specific table from this schema.""" ... - def register_table(self, name: str, table: Table) -> None: # noqa: B027 - """Add a table from this schema. 
+ def register_table( # noqa: B027 + self, name: str, table: Table | TableProviderExportable | Any + ) -> None: + """Add a table to this schema. This method is optional. If your schema provides a fixed list of tables, you do not need to implement this method. + + Objects implementing ``__datafusion_table_provider__`` are also supported + and treated as table provider instances. """ - def deregister_table(self, name, cascade: bool) -> None: # noqa: B027 + def deregister_table(self, name: str, cascade: bool) -> None: # noqa: B027 """Remove a table from this schema. This method is optional. If your schema provides a fixed list of tables, you do diff --git a/python/datafusion/context.py b/python/datafusion/context.py index b6e728b51..4f1c18663 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -29,11 +29,11 @@ import pyarrow as pa -from datafusion.catalog import Catalog, CatalogProvider, Table +from datafusion.catalog import Catalog from datafusion.dataframe import DataFrame -from datafusion.expr import SortKey, sort_list_to_raw_sort_list +from datafusion.expr import sort_list_to_raw_sort_list from datafusion.record_batch import RecordBatchStream -from datafusion.user_defined import AggregateUDF, ScalarUDF, TableFunction, WindowUDF +from datafusion.utils import _normalize_table_provider from ._internal import RuntimeEnvBuilder as RuntimeEnvBuilderInternal from ._internal import SessionConfig as SessionConfigInternal @@ -48,7 +48,15 @@ import pandas as pd import polars as pl # type: ignore[import] + from datafusion.catalog import CatalogProvider, Table + from datafusion.expr import SortKey from datafusion.plan import ExecutionPlan, LogicalPlan + from datafusion.user_defined import ( + AggregateUDF, + ScalarUDF, + TableFunction, + WindowUDF, + ) class ArrowStreamExportable(Protocol): @@ -733,7 +741,7 @@ def from_polars(self, data: pl.DataFrame, name: str | None = None) -> DataFrame: # https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116 # is the discussion on how we arrived at adding register_view def register_view(self, name: str, df: DataFrame) -> None: - """Register a :py:class: `~datafusion.detaframe.DataFrame` as a view. + """Register a :py:class:`~datafusion.dataframe.DataFrame` as a view. Args: name (str): The name to register the view under. @@ -742,16 +750,28 @@ def register_view(self, name: str, df: DataFrame) -> None: view = df.into_view() self.ctx.register_table(name, view) - def register_table(self, name: str, table: Table) -> None: - """Register a :py:class: `~datafusion.catalog.Table` as a table. + def register_table( + self, name: str, table: Table | TableProviderExportable | Any + ) -> None: + """Register a :py:class:`~datafusion.Table` with this context. - The registered table can be referenced from SQL statement executed against. + The registered table can be referenced from SQL statements executed against + this context. + + Plain :py:class:`~datafusion.dataframe.DataFrame` objects are not supported; + convert them first with :meth:`datafusion.dataframe.DataFrame.into_view` or + :meth:`datafusion.Table.from_dataframe`. + + Objects implementing ``__datafusion_table_provider__`` are also supported + and treated as table provider instances. Args: name: Name of the resultant table. - table: DataFusion table to add to the session context. + table: DataFusion :class:`Table` or any object implementing + ``__datafusion_table_provider__`` to add to the session context. 
""" - self.ctx.register_table(name, table.table) + provider = _normalize_table_provider(table) + self.ctx.register_table(name, provider) def deregister_table(self, name: str) -> None: """Remove a table from the session.""" @@ -771,14 +791,21 @@ def register_catalog_provider( self.ctx.register_catalog_provider(name, provider) def register_table_provider( - self, name: str, provider: TableProviderExportable + self, name: str, provider: Table | TableProviderExportable | Any ) -> None: """Register a table provider. - This table provider must have a method called ``__datafusion_table_provider__`` - which returns a PyCapsule that exposes a ``FFI_TableProvider``. + Deprecated: use :meth:`register_table` instead. + + Objects implementing ``__datafusion_table_provider__`` are also supported + and treated as table provider instances. """ - self.ctx.register_table_provider(name, provider) + warnings.warn( + "register_table_provider is deprecated; use register_table", + DeprecationWarning, + stacklevel=2, + ) + self.register_table(name, provider) def register_udtf(self, func: TableFunction) -> None: """Register a user defined table function.""" diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 68e6fe5a8..7834ceffd 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -60,6 +60,8 @@ import polars as pl import pyarrow as pa + from datafusion.catalog import Table + from enum import Enum @@ -313,9 +315,40 @@ def __init__(self, df: DataFrameInternal) -> None: """ self.df = df - def into_view(self) -> pa.Table: - """Convert DataFrame as a ViewTable which can be used in register_table.""" - return self.df.into_view() + def into_view(self) -> Table: + """Convert ``DataFrame`` into a :class:`~datafusion.Table` for registration. + + This is the preferred way to obtain a view for + :py:meth:`~datafusion.context.SessionContext.register_table` for several + reasons: + + 1. **Direct API**: Most efficient path - directly calls the underlying Rust + ``DataFrame.into_view()`` method without intermediate delegations. + 2. **Clear semantics**: The ``into_`` prefix follows Rust conventions, + indicating conversion from one type to another. + 3. **Canonical method**: Other approaches like ``Table.from_dataframe`` + delegate to this method internally, making this the single source of truth. + 4. **Deprecated alternatives**: The older ``Table.from_view`` helper + is deprecated and issues warnings when used. + + ``datafusion.Table.from_dataframe`` calls this method under the hood, + and the older ``Table.from_view`` helper is deprecated. + + The ``DataFrame`` remains valid after conversion, so it can still be used for + additional queries alongside the returned view. + + Examples: + >>> from datafusion import SessionContext + >>> ctx = SessionContext() + >>> df = ctx.sql("SELECT 1 AS value") + >>> view = df.into_view() + >>> ctx.register_table("values_view", view) + >>> df.collect() # The DataFrame is still usable + >>> ctx.sql("SELECT value FROM values_view").collect() + """ + from datafusion.catalog import Table as _Table + + return _Table(self.df.into_view()) def __getitem__(self, key: str | list[str]) -> DataFrame: """Return a new :py:class`DataFrame` with the specified column or columns. 
diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 5d1180bd1..82e30a78c 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -25,14 +25,12 @@ import typing as _typing from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Optional, Sequence -import pyarrow as pa - try: from warnings import deprecated # Python 3.13+ except ImportError: from typing_extensions import deprecated # Python 3.12 -from datafusion.common import NullTreatment +import pyarrow as pa from ._internal import expr as expr_internal from ._internal import functions as functions_internal @@ -40,8 +38,11 @@ if TYPE_CHECKING: from collections.abc import Sequence - # Type-only imports - from datafusion.common import DataTypeMap, RexType + from datafusion.common import ( # type: ignore[import] + DataTypeMap, + NullTreatment, + RexType, + ) from datafusion.plan import LogicalPlan diff --git a/python/datafusion/io.py b/python/datafusion/io.py index 551e20a6f..67dbc730f 100644 --- a/python/datafusion/io.py +++ b/python/datafusion/io.py @@ -22,13 +22,13 @@ from typing import TYPE_CHECKING from datafusion.context import SessionContext -from datafusion.dataframe import DataFrame if TYPE_CHECKING: import pathlib import pyarrow as pa + from datafusion.dataframe import DataFrame from datafusion.expr import Expr diff --git a/python/datafusion/utils.py b/python/datafusion/utils.py new file mode 100644 index 000000000..eb3e3d626 --- /dev/null +++ b/python/datafusion/utils.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Miscellaneous helper utilities for DataFusion's Python bindings.""" + +from __future__ import annotations + +from importlib import import_module, util +from typing import TYPE_CHECKING, Any + +from datafusion._internal import EXPECTED_PROVIDER_MSG + +_PYARROW_DATASET_TYPES: tuple[type[Any], ...] 
+_dataset_spec = util.find_spec("pyarrow.dataset") +if _dataset_spec is None: # pragma: no cover - optional dependency at runtime + _PYARROW_DATASET_TYPES = () +else: # pragma: no cover - exercised in environments with pyarrow installed + _dataset_module = import_module("pyarrow.dataset") + dataset_base = getattr(_dataset_module, "Dataset", None) + dataset_types: set[type[Any]] = set() + if isinstance(dataset_base, type): + dataset_types.add(dataset_base) + for value in vars(_dataset_module).values(): + if isinstance(value, type) and issubclass(value, dataset_base): + dataset_types.add(value) + _PYARROW_DATASET_TYPES = tuple(dataset_types) + +if TYPE_CHECKING: # pragma: no cover - imported for typing only + from datafusion.catalog import Table + from datafusion.context import TableProviderExportable + + +def _normalize_table_provider( + table: Table | TableProviderExportable | Any, +) -> Any: + """Return the underlying provider for supported table inputs. + + Args: + table: A :class:`~datafusion.Table`, object exporting a DataFusion table + provider via ``__datafusion_table_provider__``, or compatible + :mod:`pyarrow.dataset` implementation. + + Returns: + The object expected by the Rust bindings for table registration. + + Raises: + TypeError: If ``table`` is not a supported table provider input. + """ + from datafusion.catalog import Table as _Table + + if isinstance(table, _Table): + return table.table + + if _PYARROW_DATASET_TYPES and isinstance(table, _PYARROW_DATASET_TYPES): + return table + + provider_factory = getattr(table, "__datafusion_table_provider__", None) + if callable(provider_factory): + return table + + raise TypeError(EXPECTED_PROVIDER_MSG) diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py index 1f9ecbfc3..fd91d6677 100644 --- a/python/tests/test_catalog.py +++ b/python/tests/test_catalog.py @@ -20,7 +20,7 @@ import pyarrow as pa import pyarrow.dataset as ds import pytest -from datafusion import SessionContext, Table +from datafusion import EXPECTED_PROVIDER_MSG, SessionContext, Table # Note we take in `database` as a variable even though we don't use @@ -164,6 +164,38 @@ def test_python_table_provider(ctx: SessionContext): assert schema.table_names() == {"table4"} +def test_schema_register_table_with_pyarrow_dataset(ctx: SessionContext): + schema = ctx.catalog().schema() + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + dataset = ds.dataset([batch]) + table_name = "pa_dataset" + + try: + schema.register_table(table_name, dataset) + assert table_name in schema.table_names() + + result = ctx.sql(f"SELECT a, b FROM {table_name}").collect() + + assert len(result) == 1 + assert result[0].column(0) == pa.array([1, 2, 3]) + assert result[0].column(1) == pa.array([4, 5, 6]) + finally: + schema.deregister_table(table_name) + + +def test_schema_register_table_with_dataframe_errors(ctx: SessionContext): + schema = ctx.catalog().schema() + df = ctx.from_pydict({"a": [1]}) + + with pytest.raises(Exception) as exc_info: + schema.register_table("bad", df) + + assert str(exc_info.value) == EXPECTED_PROVIDER_MSG + + def test_in_end_to_end_python_providers(ctx: SessionContext): """Test registering all python providers and running a query against them.""" diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 6dbcc0d5e..243178797 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -22,11 +22,13 @@ import pyarrow.dataset as ds import pytest from datafusion 
import ( + EXPECTED_PROVIDER_MSG, DataFrame, RuntimeEnvBuilder, SessionConfig, SessionContext, SQLOptions, + Table, column, literal, ) @@ -330,6 +332,73 @@ def test_deregister_table(ctx, database): assert public.names() == {"csv1", "csv2"} +def test_register_table_from_dataframe_into_view(ctx): + df = ctx.from_pydict({"a": [1, 2]}) + table = df.into_view() + assert isinstance(table, Table) + ctx.register_table("view_tbl", table) + result = ctx.sql("SELECT * FROM view_tbl").collect() + assert [b.to_pydict() for b in result] == [{"a": [1, 2]}] + + +def test_table_from_capsule(ctx): + df = ctx.from_pydict({"a": [1, 2]}) + table = df.into_view() + capsule = table.__datafusion_table_provider__() + table2 = Table.from_capsule(capsule) + assert isinstance(table2, Table) + ctx.register_table("capsule_tbl", table2) + result = ctx.sql("SELECT * FROM capsule_tbl").collect() + assert [b.to_pydict() for b in result] == [{"a": [1, 2]}] + + +def test_table_from_dataframe(ctx): + df = ctx.from_pydict({"a": [1, 2]}) + table = Table.from_dataframe(df) + assert isinstance(table, Table) + ctx.register_table("from_dataframe_tbl", table) + result = ctx.sql("SELECT * FROM from_dataframe_tbl").collect() + assert [b.to_pydict() for b in result] == [{"a": [1, 2]}] + + +def test_table_from_dataframe_internal(ctx): + df = ctx.from_pydict({"a": [1, 2]}) + table = Table.from_dataframe(df.df) + assert isinstance(table, Table) + ctx.register_table("from_internal_dataframe_tbl", table) + result = ctx.sql("SELECT * FROM from_internal_dataframe_tbl").collect() + assert [b.to_pydict() for b in result] == [{"a": [1, 2]}] + + +def test_register_table_capsule_direct(ctx): + df = ctx.from_pydict({"a": [1, 2]}) + provider = df.into_view() + + class CapsuleProvider: + def __init__(self, inner): + self._inner = inner + + def __datafusion_table_provider__(self): + return self._inner.__datafusion_table_provider__() + + ctx.register_table("capsule_direct_tbl", CapsuleProvider(provider)) + result = ctx.sql("SELECT * FROM capsule_direct_tbl").collect() + assert [b.to_pydict() for b in result] == [{"a": [1, 2]}] + + +def test_table_from_capsule_invalid(): + with pytest.raises(RuntimeError): + Table.from_capsule(object()) + + +def test_register_table_with_dataframe_errors(ctx): + df = ctx.from_pydict({"a": [1]}) + with pytest.raises(TypeError) as exc_info: + ctx.register_table("bad", df) + + assert str(exc_info.value) == EXPECTED_PROVIDER_MSG + + def test_register_dataset(ctx): # create a RecordBatch and register it as a pyarrow.dataset.Dataset batch = pa.RecordBatch.from_arrays( diff --git a/python/tests/test_wrapper_coverage.py b/python/tests/test_wrapper_coverage.py index f484cb282..5df454d1d 100644 --- a/python/tests/test_wrapper_coverage.py +++ b/python/tests/test_wrapper_coverage.py @@ -21,6 +21,8 @@ import datafusion.substrait import pytest +IGNORED_EXPORTS = {"TableProvider"} + # EnumType introduced in 3.11. 3.10 and prior it was called EnumMeta. 
try: from enum import EnumType @@ -28,7 +30,29 @@ from enum import EnumMeta as EnumType -def missing_exports(internal_obj, wrapped_obj) -> None: # noqa: C901 +def _check_enum_exports(internal_obj, wrapped_obj) -> None: + """Check that all enum values are present in wrapped object.""" + expected_values = [v for v in dir(internal_obj) if not v.startswith("__")] + for value in expected_values: + assert value in dir(wrapped_obj) + + +def _check_list_attribute(internal_attr, wrapped_attr) -> None: + """Check that list attributes match between internal and wrapped objects.""" + assert isinstance(wrapped_attr, list) + + # We have cases like __all__ that are a list and we want to be certain that + # every value in the list in the internal object is also in the wrapper list + for val in internal_attr: + if isinstance(val, str) and val in IGNORED_EXPORTS: + continue + if isinstance(val, str) and val.startswith("Raw"): + assert val[3:] in wrapped_attr + else: + assert val in wrapped_attr + + +def missing_exports(internal_obj, wrapped_obj) -> None: """ Identify if any of the rust exposted structs or functions do not have wrappers. @@ -36,13 +60,12 @@ def missing_exports(internal_obj, wrapped_obj) -> None: # noqa: C901 - Raw* classes: Internal implementation details that shouldn't be exposed - _global_ctx: Internal implementation detail - __self__, __class__, __repr__: Python special attributes + - TableProvider: Superseded by the public ``Table`` API in Python """ # Special case enums - EnumType overrides a some of the internal functions, # so check all of the values exist and move on if isinstance(wrapped_obj, EnumType): - expected_values = [v for v in dir(internal_obj) if not v.startswith("__")] - for value in expected_values: - assert value in dir(wrapped_obj) + _check_enum_exports(internal_obj, wrapped_obj) return if "__repr__" in internal_obj.__dict__ and "__repr__" not in wrapped_obj.__dict__: @@ -50,6 +73,10 @@ def missing_exports(internal_obj, wrapped_obj) -> None: # noqa: C901 for internal_attr_name in dir(internal_obj): wrapped_attr_name = internal_attr_name.removeprefix("Raw") + + if wrapped_attr_name in IGNORED_EXPORTS: + continue + assert wrapped_attr_name in dir(wrapped_obj) internal_attr = getattr(internal_obj, internal_attr_name) @@ -66,15 +93,7 @@ def missing_exports(internal_obj, wrapped_obj) -> None: # noqa: C901 continue if isinstance(internal_attr, list): - assert isinstance(wrapped_attr, list) - - # We have cases like __all__ that are a list and we want to be certain that - # every value in the list in the internal object is also in the wrapper list - for val in internal_attr: - if isinstance(val, str) and val.startswith("Raw"): - assert val[3:] in wrapped_attr - else: - assert val in wrapped_attr + _check_list_attribute(internal_attr, wrapped_attr) elif hasattr(internal_attr, "__dict__"): # Check all submodules recursively missing_exports(internal_attr, wrapped_attr) diff --git a/src/catalog.rs b/src/catalog.rs index 17d4ec3b8..03e6408cd 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -17,7 +17,10 @@ use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult}; -use crate::utils::{validate_pycapsule, wait_for_future}; +use crate::table::PyTableProvider; +use crate::utils::{ + coerce_table_provider, table_provider_from_pycapsule, validate_pycapsule, wait_for_future, +}; use async_trait::async_trait; use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider}; use datafusion::common::DataFusionError; 
@@ -27,7 +30,6 @@ use datafusion::{ datasource::{TableProvider, TableType}, }; use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider}; -use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::exceptions::PyKeyError; use pyo3::prelude::*; use pyo3::types::PyCapsule; @@ -196,26 +198,7 @@ impl PySchema { } fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> { - let provider = if table_provider.hasattr("__datafusion_table_provider__")? { - let capsule = table_provider - .getattr("__datafusion_table_provider__")? - .call0()?; - let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_table_provider")?; - - let provider = unsafe { capsule.reference::() }; - let provider: ForeignTableProvider = provider.into(); - Arc::new(provider) as Arc - } else { - match table_provider.extract::() { - Ok(py_table) => py_table.table, - Err(_) => { - let py = table_provider.py(); - let provider = Dataset::new(&table_provider, py)?; - Arc::new(provider) as Arc - } - } - }; + let provider = coerce_table_provider(&table_provider).map_err(PyErr::from)?; let _ = self .schema @@ -304,15 +287,8 @@ impl RustWrappedPySchemaProvider { return Ok(None); } - if py_table.hasattr("__datafusion_table_provider__")? { - let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; - let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_table_provider")?; - - let provider = unsafe { capsule.reference::() }; - let provider: ForeignTableProvider = provider.into(); - - Ok(Some(Arc::new(provider) as Arc)) + if let Some(provider) = table_provider_from_pycapsule(&py_table)? { + Ok(Some(provider)) } else { if let Ok(inner_table) = py_table.getattr("table") { if let Ok(inner_table) = inner_table.extract::() { @@ -320,6 +296,10 @@ impl RustWrappedPySchemaProvider { } } + if let Ok(py_provider) = py_table.extract::() { + return Ok(Some(py_provider.into_inner())); + } + match py_table.extract::() { Ok(py_table) => Ok(Some(py_table.table)), Err(_) => { diff --git a/src/context.rs b/src/context.rs index 36133a33d..d2c9b1c98 100644 --- a/src/context.rs +++ b/src/context.rs @@ -45,7 +45,9 @@ use crate::udaf::PyAggregateUDF; use crate::udf::PyScalarUDF; use crate::udtf::PyTableFunction; use crate::udwf::PyWindowUDF; -use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future}; +use crate::utils::{ + coerce_table_provider, get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future, +}; use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; @@ -71,7 +73,6 @@ use datafusion::prelude::{ AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions, }; use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider}; -use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType}; use pyo3::IntoPyObjectExt; use tokio::task::JoinHandle; @@ -417,12 +418,7 @@ impl PySessionContext { .with_listing_options(options) .with_schema(resolved_schema); let table = ListingTable::try_new(config)?; - self.register_table( - name, - &PyTable { - table: Arc::new(table), - }, - )?; + self.ctx.register_table(name, Arc::new(table))?; Ok(()) } @@ -607,8 +603,14 @@ impl PySessionContext { Ok(df) } - pub fn 
register_table(&mut self, name: &str, table: &PyTable) -> PyDataFusionResult<()> { - self.ctx.register_table(name, table.table())?; + pub fn register_table( + &mut self, + name: &str, + table_provider: Bound<'_, PyAny>, + ) -> PyDataFusionResult<()> { + let provider = coerce_table_provider(&table_provider)?; + + self.ctx.register_table(name, provider)?; Ok(()) } @@ -651,23 +653,8 @@ impl PySessionContext { name: &str, provider: Bound<'_, PyAny>, ) -> PyDataFusionResult<()> { - if provider.hasattr("__datafusion_table_provider__")? { - let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; - let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_table_provider")?; - - let provider = unsafe { capsule.reference::() }; - let provider: ForeignTableProvider = provider.into(); - - let _ = self.ctx.register_table(name, Arc::new(provider))?; - - Ok(()) - } else { - Err(crate::errors::PyDataFusionError::Common( - "__datafusion_table_provider__ does not exist on Table Provider object." - .to_string(), - )) - } + // Deprecated: use `register_table` instead + self.register_table(name, provider) } pub fn register_record_batches( diff --git a/src/dataframe.rs b/src/dataframe.rs index 5882acf76..17900afda 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -31,12 +31,10 @@ use datafusion::arrow::util::pretty; use datafusion::common::UnnestOptions; use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, TableParquetOptions}; use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; -use datafusion::datasource::TableProvider; use datafusion::error::DataFusionError; use datafusion::execution::SendableRecordBatchStream; use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}; use datafusion::prelude::*; -use datafusion_ffi::table_provider::FFI_TableProvider; use futures::{StreamExt, TryStreamExt}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; @@ -44,12 +42,12 @@ use pyo3::pybacked::PyBackedStr; use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods}; use tokio::task::JoinHandle; -use crate::catalog::PyTable; use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError}; use crate::expr::sort_expr::to_sort_expressions; use crate::physical_plan::PyExecutionPlan; use crate::record_batch::PyRecordBatchStream; use crate::sql::logical::PyLogicalPlan; +pub use crate::table::PyTableProvider; use crate::utils::{ get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future, }; @@ -58,40 +56,6 @@ use crate::{ expr::{sort_expr::PySortExpr, PyExpr}, }; -// https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116 -// - we have not decided on the table_provider approach yet -// this is an interim implementation -#[pyclass(name = "TableProvider", module = "datafusion")] -pub struct PyTableProvider { - provider: Arc, -} - -impl PyTableProvider { - pub fn new(provider: Arc) -> Self { - Self { provider } - } - - pub fn as_table(&self) -> PyTable { - let table_provider: Arc = self.provider.clone(); - PyTable::new(table_provider) - } -} - -#[pymethods] -impl PyTableProvider { - fn __datafusion_table_provider__<'py>( - &self, - py: Python<'py>, - ) -> PyResult> { - let name = CString::new("datafusion_table_provider").unwrap(); - - let runtime = get_tokio_runtime().0.handle().clone(); - let provider = FFI_TableProvider::new(Arc::clone(&self.provider), false, Some(runtime)); - - PyCapsule::new(py, provider, Some(name.clone())) 
-    }
-}
-
 /// Configuration for DataFrame display formatting
 #[derive(Debug, Clone)]
 pub struct FormatterConfig {
@@ -302,6 +266,11 @@ impl PyDataFrame {
         }
     }
 
+    /// Return a clone of the inner Arc for crate-local callers.
+    pub(crate) fn inner_df(&self) -> Arc<DataFrame> {
+        Arc::clone(&self.df)
+    }
+
     fn prepare_repr_string(&mut self, py: Python, as_html: bool) -> PyDataFusionResult<String> {
         // Get the Python formatter and config
         let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?;
@@ -427,22 +396,18 @@ impl PyDataFrame {
         PyArrowType(self.df.schema().into())
     }
 
-    /// Convert this DataFrame into a Table that can be used in register_table
+    /// Convert this DataFrame into a Table Provider that can be used in register_table
     /// By convention, into_... methods consume self and return the new object.
     /// Disabling the clippy lint, so we can use &self
     /// because we're working with Python bindings
     /// where objects are shared
-    /// https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
-    /// - we have not decided on the table_provider approach yet
     #[allow(clippy::wrong_self_convention)]
-    fn into_view(&self) -> PyDataFusionResult<PyTable> {
+    pub fn into_view(&self) -> PyDataFusionResult<PyTableProvider> {
         // Call the underlying Rust DataFrame::into_view method.
         // Note that the Rust method consumes self; here we clone the inner Arc
-        // so that we don’t invalidate this PyDataFrame.
+        // so that we don't invalidate this PyDataFrame.
         let table_provider = self.df.as_ref().clone().into_view();
-        let table_provider = PyTableProvider::new(table_provider);
-
-        Ok(table_provider.as_table())
+        Ok(PyTableProvider::new(table_provider))
     }
 
     #[pyo3(signature = (*args))]
diff --git a/src/lib.rs b/src/lib.rs
index 29d3f41da..d45992db2 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -52,6 +52,7 @@ pub mod pyarrow_util;
 mod record_batch;
 pub mod sql;
 pub mod store;
+pub mod table;
 pub mod unparser;
 
 #[cfg(feature = "substrait")]
@@ -80,6 +81,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
     // Initialize logging
     pyo3_log::init();
 
+    m.add("EXPECTED_PROVIDER_MSG", crate::utils::EXPECTED_PROVIDER_MSG)?;
+
     // Register the python classes
     m.add_class::()?;
     m.add_class::()?;
@@ -97,6 +100,7 @@
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
+    m.add_class::()?;
 
     let catalog = PyModule::new(py, "catalog")?;
     catalog::init_module(&catalog)?;
diff --git a/src/table.rs b/src/table.rs
new file mode 100644
index 000000000..29476e473
--- /dev/null
+++ b/src/table.rs
@@ -0,0 +1,120 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::ffi::CString;
+use std::sync::Arc;
+
+use datafusion::datasource::TableProvider;
+use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
+use pyo3::exceptions::PyDeprecationWarning;
+use pyo3::prelude::*;
+use pyo3::types::{PyCapsule, PyDict};
+
+use crate::catalog::PyTable;
+use crate::dataframe::PyDataFrame;
+use crate::errors::{py_datafusion_err, PyDataFusionResult};
+use crate::utils::{get_tokio_runtime, validate_pycapsule};
+
+/// Represents a table provider that can be registered with DataFusion
+#[pyclass(name = "TableProvider", module = "datafusion")]
+#[derive(Clone)]
+pub struct PyTableProvider {
+    pub(crate) provider: Arc<dyn TableProvider>,
+}
+
+impl PyTableProvider {
+    pub(crate) fn new(provider: Arc<dyn TableProvider>) -> Self {
+        Self { provider }
+    }
+
+    /// Return a `PyTable` wrapper around this provider.
+    ///
+    /// Historically callers chained `as_table().table()` to access the
+    /// underlying [`Arc<dyn TableProvider>`]. Prefer [`as_arc`] or
+    /// [`into_inner`] for direct access instead.
+    pub fn as_table(&self) -> PyTable {
+        PyTable::new(Arc::clone(&self.provider))
+    }
+
+    /// Return a clone of the inner [`TableProvider`].
+    pub fn as_arc(&self) -> Arc<dyn TableProvider> {
+        Arc::clone(&self.provider)
+    }
+
+    /// Consume this wrapper and return the inner [`TableProvider`].
+    pub fn into_inner(self) -> Arc<dyn TableProvider> {
+        self.provider
+    }
+}
+
+#[pymethods]
+impl PyTableProvider {
+    /// Create a `TableProvider` from a PyCapsule containing an FFI pointer
+    #[staticmethod]
+    pub fn from_capsule(capsule: Bound<'_, PyAny>) -> PyResult<Self> {
+        let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
+        validate_pycapsule(capsule, "datafusion_table_provider")?;
+
+        let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
+        let provider: ForeignTableProvider = provider.into();
+
+        Ok(Self::new(Arc::new(provider)))
+    }
+
+    /// Create a `TableProvider` from a `DataFrame`.
+    ///
+    /// This method simply delegates to `DataFrame.into_view`.
+    #[staticmethod]
+    pub fn from_dataframe(df: &PyDataFrame) -> Self {
+        // Clone the inner DataFrame and convert it into a view TableProvider.
+        // `into_view` consumes a DataFrame, so clone the underlying DataFrame
+        Self::new(df.inner_df().as_ref().clone().into_view())
+    }
+
+    /// Create a `TableProvider` from a `DataFrame` by converting it into a view.
+    ///
+    /// Deprecated: prefer `DataFrame.into_view` or
+    /// `Table.from_dataframe` instead.
+    #[staticmethod]
+    pub fn from_view(py: Python<'_>, df: &PyDataFrame) -> PyDataFusionResult<Self> {
+        let kwargs = PyDict::new(py);
+        // Keep stack level consistent with python/datafusion/table_provider.py
+        kwargs.set_item("stacklevel", 2)?;
+        py.import("warnings")?.call_method(
+            "warn",
+            (
+                "PyTableProvider.from_view() is deprecated; use DataFrame.into_view() or Table.from_dataframe() instead.",
+                py.get_type::<PyDeprecationWarning>(),
+            ),
+            Some(&kwargs),
+        )?;
+        Ok(Self::from_dataframe(df))
+    }
+
+    fn __datafusion_table_provider__<'py>(
+        &self,
+        py: Python<'py>,
+    ) -> PyResult<Bound<'py, PyCapsule>> {
+        let name = CString::new("datafusion_table_provider").unwrap();
+
+        let runtime = get_tokio_runtime().0.handle().clone();
+        let provider: Arc<dyn TableProvider> = self.provider.clone();
+        let provider = FFI_TableProvider::new(provider, false, Some(runtime));
+
+        PyCapsule::new(py, provider, Some(name.clone()))
+    }
+}
diff --git a/src/udtf.rs b/src/udtf.rs
index db16d6c05..311d67e24 100644
--- a/src/udtf.rs
+++ b/src/udtf.rs
@@ -18,14 +18,13 @@
 use pyo3::prelude::*;
 use std::sync::Arc;
 
-use crate::dataframe::PyTableProvider;
 use crate::errors::{py_datafusion_err, to_datafusion_err};
 use crate::expr::PyExpr;
-use crate::utils::validate_pycapsule;
+use crate::table::PyTableProvider;
+use crate::utils::{table_provider_from_pycapsule, validate_pycapsule};
 use datafusion::catalog::{TableFunctionImpl, TableProvider};
 use datafusion::error::Result as DataFusionResult;
 use datafusion::logical_expr::Expr;
-use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
 use datafusion_ffi::udtf::{FFI_TableFunction, ForeignTableFunction};
 use pyo3::exceptions::PyNotImplementedError;
 use pyo3::types::{PyCapsule, PyTuple};
@@ -99,20 +98,11 @@ fn call_python_table_function(
         let provider_obj = func.call1(py, py_args)?;
         let provider = provider_obj.bind(py);
 
-        if provider.hasattr("__datafusion_table_provider__")? {
-            let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
-            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
-            validate_pycapsule(capsule, "datafusion_table_provider")?;
-
-            let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
-            let provider: ForeignTableProvider = provider.into();
-
-            Ok(Arc::new(provider) as Arc<dyn TableProvider>)
-        } else {
-            Err(PyNotImplementedError::new_err(
+        table_provider_from_pycapsule(provider)?.ok_or_else(|| {
+            PyNotImplementedError::new_err(
                 "__datafusion_table_provider__ does not exist on Table Provider object.",
-            ))
-        }
+            )
+        })
     })
     .map_err(to_datafusion_err)
 }
diff --git a/src/utils.rs b/src/utils.rs
index 3b30de5de..6d5755f88 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -16,17 +16,29 @@
 // under the License.
 
 use crate::{
+    catalog::PyTable,
     common::data_type::PyScalarValue,
+    dataframe::PyDataFrame,
+    dataset::Dataset,
     errors::{PyDataFusionError, PyDataFusionResult},
+    table::PyTableProvider,
     TokioRuntime,
 };
 use datafusion::{
-    common::ScalarValue, execution::context::SessionContext, logical_expr::Volatility,
+    common::ScalarValue, datasource::TableProvider, execution::context::SessionContext,
+    logical_expr::Volatility,
 };
 use pyo3::prelude::*;
 use pyo3::{exceptions::PyValueError, types::PyCapsule};
-use std::{future::Future, sync::OnceLock, time::Duration};
+use std::{
+    future::Future,
+    sync::{Arc, OnceLock},
+    time::Duration,
+};
 use tokio::{runtime::Runtime, time::sleep};
+
+pub(crate) const EXPECTED_PROVIDER_MSG: &str =
+    "Expected a Table. Convert DataFrames with \"DataFrame.into_view()\" or \"Table.from_dataframe()\".";
 
 /// Utility to get the Tokio Runtime from Python
 #[inline]
 pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
@@ -91,7 +103,7 @@ pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {
         "volatile" => Volatility::Volatile,
         value => {
             return Err(PyDataFusionError::Common(format!(
-                "Unsupportad volatility type: `{value}`, supported \
+                "Unsupported volatility type: `{value}`, supported \
                 values are: immutable, stable and volatile."
             )))
         }
@@ -101,9 +113,9 @@ pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
     let capsule_name = capsule.name()?;
     if capsule_name.is_none() {
-        return Err(PyValueError::new_err(
-            "Expected schema PyCapsule to have name set.",
-        ));
+        return Err(PyValueError::new_err(format!(
+            "Expected {name} PyCapsule to have name set."
+        )));
     }
 
     let capsule_name = capsule_name.unwrap().to_str()?;
@@ -116,6 +128,40 @@
     Ok(())
 }
 
+pub(crate) fn table_provider_from_pycapsule(
+    obj: &Bound<'_, PyAny>,
+) -> PyResult<Option<Arc<dyn TableProvider>>> {
+    if obj.hasattr("__datafusion_table_provider__")? {
+        let capsule = obj.getattr("__datafusion_table_provider__")?.call0()?;
+        let provider = PyTableProvider::from_capsule(capsule)?;
+        Ok(Some(provider.into_inner()))
+    } else {
+        Ok(None)
+    }
+}
+
+pub(crate) fn coerce_table_provider(
+    obj: &Bound<'_, PyAny>,
+) -> PyDataFusionResult<Arc<dyn TableProvider>> {
+    if let Ok(py_table) = obj.extract::<PyTable>() {
+        Ok(py_table.table())
+    } else if let Ok(py_provider) = obj.extract::<PyTableProvider>() {
+        Ok(py_provider.into_inner())
+    } else if obj.is_instance_of::<PyDataFrame>()
+        || obj
+            .getattr("df")
+            .is_ok_and(|inner| inner.is_instance_of::<PyDataFrame>())
+    {
+        Err(PyDataFusionError::Common(EXPECTED_PROVIDER_MSG.to_string()))
+    } else if let Some(provider) = table_provider_from_pycapsule(obj)? {
+        Ok(provider)
+    } else {
+        let py = obj.py();
+        let provider = Dataset::new(obj, py)?;
+        Ok(Arc::new(provider) as Arc<dyn TableProvider>)
+    }
+}
+
 pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult<ScalarValue> {
     // convert Python object to PyScalarValue to ScalarValue