diff --git a/docs/source/user-guide/io/table_provider.rst b/docs/source/user-guide/io/table_provider.rst index bd1d6b80f..419715126 100644 --- a/docs/source/user-guide/io/table_provider.rst +++ b/docs/source/user-guide/io/table_provider.rst @@ -56,3 +56,13 @@ to the ``SessionContext``. ctx.register_table_provider("my_table", provider) ctx.table("my_table").show() + +If you already have a provider instance you can also use +``SessionContext.read_table`` to obtain a :class:`~datafusion.DataFrame` +directly without registering it first: + +.. code-block:: python + + provider = MyTableProvider() + df = ctx.read_table(provider) + df.show() diff --git a/examples/datafusion-ffi-example/Cargo.toml b/examples/datafusion-ffi-example/Cargo.toml index 647f6c51e..fc6e68c45 100644 --- a/examples/datafusion-ffi-example/Cargo.toml +++ b/examples/datafusion-ffi-example/Cargo.toml @@ -23,7 +23,12 @@ edition = "2021" [dependencies] datafusion = { version = "49.0.2" } datafusion-ffi = { version = "49.0.2" } -pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py39"] } +datafusion-python = { path = "../../" } +pyo3 = { version = "0.25", features = [ + "extension-module", + "abi3", + "abi3-py39", +] } arrow = { version = "55.0.0" } arrow-array = { version = "55.0.0" } arrow-schema = { version = "55.0.0" } diff --git a/examples/datafusion-ffi-example/src/table_provider.rs b/examples/datafusion-ffi-example/src/table_provider.rs index e884585b5..b3c3b47a6 100644 --- a/examples/datafusion-ffi-example/src/table_provider.rs +++ b/examples/datafusion-ffi-example/src/table_provider.rs @@ -20,6 +20,7 @@ use arrow_schema::{DataType, Field, Schema}; use datafusion::catalog::MemTable; use datafusion::error::{DataFusionError, Result as DataFusionResult}; use datafusion_ffi::table_provider::FFI_TableProvider; +use datafusion_python::utils::table_provider_capsule_name; use pyo3::exceptions::PyRuntimeError; use pyo3::types::PyCapsule; use pyo3::{pyclass, pymethods, Bound, PyResult, Python}; @@ -91,13 +92,11 @@ impl MyTableProvider { &self, py: Python<'py>, ) -> PyResult> { - let name = cr"datafusion_table_provider".into(); - let provider = self .create_table() .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; let provider = FFI_TableProvider::new(Arc::new(provider), false, None); - PyCapsule::new(py, provider, Some(name)) + PyCapsule::new(py, provider, Some(table_provider_capsule_name().to_owned())) } } diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index 536b3a790..15ec7ca80 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -149,6 +149,14 @@ def __repr__(self) -> str: """Print a string representation of the table.""" return self.table.__repr__() + def __datafusion_table_provider__(self) -> object: + """Expose the internal DataFusion table provider PyCapsule. + + This forwards the call to the underlying Rust-backed RawTable so the + object can be used as a TableProviderExportable by the FFI layer. + """ + return self.table.__datafusion_table_provider__() + @staticmethod def from_dataset(dataset: pa.dataset.Dataset) -> Table: """Turn a pyarrow Dataset into a Table.""" diff --git a/python/datafusion/context.py b/python/datafusion/context.py index b6e728b51..6cd0c7ce8 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -20,7 +20,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol, Union try: from warnings import deprecated # Python 3.13+ @@ -82,6 +82,13 @@ class TableProviderExportable(Protocol): def __datafusion_table_provider__(self) -> object: ... # noqa: D105 +# Type alias for objects accepted by read_table +# Use typing.Union here (instead of PEP 604 `|`) because this alias is +# evaluated at import time and must work on Python 3.9 where PEP 604 +# syntax is not supported for runtime expressions. +TableLike = Union[Table, TableProviderExportable] + + class CatalogProviderExportable(Protocol): """Type hint for object that has __datafusion_catalog_provider__ PyCapsule. @@ -1163,14 +1170,21 @@ def read_avro( self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension) ) - def read_table(self, table: Table) -> DataFrame: + def read_table(self, table: TableLike) -> DataFrame: """Creates a :py:class:`~datafusion.dataframe.DataFrame` from a table. - For a :py:class:`~datafusion.catalog.Table` such as a - :py:class:`~datafusion.catalog.ListingTable`, create a - :py:class:`~datafusion.dataframe.DataFrame`. + Args: + table: Either a :py:class:`~datafusion.catalog.Table` (such as a + :py:class:`~datafusion.catalog.ListingTable`) or an object that + implements ``__datafusion_table_provider__`` and returns a + PyCapsule describing a custom table provider. + + Returns: + A :py:class:`~datafusion.dataframe.DataFrame` backed by the + provided table provider. """ - return DataFrame(self.ctx.read_table(table.table)) + provider = table.table if isinstance(table, Table) else table + return DataFrame(self.ctx.read_table(provider)) def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream: """Execute the ``plan`` and return the results.""" diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py index 1f9ecbfc3..2c3f7012f 100644 --- a/python/tests/test_catalog.py +++ b/python/tests/test_catalog.py @@ -164,6 +164,29 @@ def test_python_table_provider(ctx: SessionContext): assert schema.table_names() == {"table4"} +def test_register_raw_table_without_capsule(ctx: SessionContext, database, monkeypatch): + schema = ctx.catalog().schema("public") + raw_table = schema.table("csv").table + + def fail(*args, **kwargs): + msg = "RawTable capsule path should not be invoked" + raise AssertionError(msg) + + monkeypatch.setattr(type(raw_table), "__datafusion_table_provider__", fail) + + schema.register_table("csv_copy", raw_table) + + # Restore the original implementation to avoid interfering with later assertions + monkeypatch.undo() + + batches = ctx.sql("select count(*) from csv_copy").collect() + + assert len(batches) == 1 + assert batches[0].column(0) == pa.array([4]) + + schema.deregister_table("csv_copy") + + def test_in_end_to_end_python_providers(ctx: SessionContext): """Test registering all python providers and running a query against them.""" diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 6dbcc0d5e..a14ecd795 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -17,6 +17,7 @@ import datetime as dt import gzip import pathlib +from uuid import uuid4 import pyarrow as pa import pyarrow.dataset as ds @@ -113,6 +114,26 @@ def test_register_record_batches(ctx): assert result[0].column(1) == pa.array([-3, -3, -3]) +def test_read_table_accepts_table_provider(ctx): + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2]), pa.array(["x", "y"])], + names=["value", "label"], + ) + + ctx.register_record_batches("capsule_provider", [[batch]]) + + table = ctx.catalog().schema().table("capsule_provider") + provider = table.table + + expected = pa.Table.from_batches([batch]) + + provider_result = pa.Table.from_batches(ctx.read_table(provider).collect()) + assert provider_result.equals(expected) + + table_result = pa.Table.from_batches(ctx.read_table(table).collect()) + assert table_result.equals(expected) + + def test_create_dataframe_registers_unique_table_name(ctx): # create a RecordBatch and register it as memtable batch = pa.RecordBatch.from_arrays( @@ -484,8 +505,6 @@ def test_table_exist(ctx): def test_table_not_found(ctx): - from uuid import uuid4 - with pytest.raises(KeyError): ctx.table(f"not-found-{uuid4()}") diff --git a/src/catalog.rs b/src/catalog.rs index 17d4ec3b8..6764083ee 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -17,7 +17,10 @@ use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult}; -use crate::utils::{validate_pycapsule, wait_for_future}; +use crate::utils::{ + get_tokio_runtime, table_provider_capsule_name, try_table_provider_from_object, + validate_pycapsule, wait_for_future, +}; use async_trait::async_trait; use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider}; use datafusion::common::DataFusionError; @@ -27,7 +30,7 @@ use datafusion::{ datasource::{TableProvider, TableType}, }; use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider}; -use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; +use datafusion_ffi::table_provider::FFI_TableProvider; use pyo3::exceptions::PyKeyError; use pyo3::prelude::*; use pyo3::types::PyCapsule; @@ -196,25 +199,14 @@ impl PySchema { } fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> { - let provider = if table_provider.hasattr("__datafusion_table_provider__")? { - let capsule = table_provider - .getattr("__datafusion_table_provider__")? - .call0()?; - let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_table_provider")?; - - let provider = unsafe { capsule.reference::() }; - let provider: ForeignTableProvider = provider.into(); - Arc::new(provider) as Arc + let provider = if let Ok(py_table) = table_provider.extract::() { + py_table.table + } else if let Some(provider) = try_table_provider_from_object(&table_provider)? { + provider } else { - match table_provider.extract::() { - Ok(py_table) => py_table.table, - Err(_) => { - let py = table_provider.py(); - let provider = Dataset::new(&table_provider, py)?; - Arc::new(provider) as Arc - } - } + let py = table_provider.py(); + let provider = Dataset::new(&table_provider, py)?; + Arc::new(provider) as Arc }; let _ = self @@ -261,6 +253,19 @@ impl PyTable { } } + fn __datafusion_table_provider__<'py>( + &self, + py: Python<'py>, + ) -> PyResult> { + let runtime = get_tokio_runtime().0.handle().clone(); + + let provider = Arc::clone(&self.table); + let provider: Arc = provider; + let provider = FFI_TableProvider::new(provider, false, Some(runtime)); + + PyCapsule::new(py, provider, Some(table_provider_capsule_name().to_owned())) + } + fn __repr__(&self) -> PyResult { let kind = self.kind(); Ok(format!("Table(kind={kind})")) @@ -304,15 +309,12 @@ impl RustWrappedPySchemaProvider { return Ok(None); } - if py_table.hasattr("__datafusion_table_provider__")? { - let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; - let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_table_provider")?; - - let provider = unsafe { capsule.reference::() }; - let provider: ForeignTableProvider = provider.into(); + if let Ok(inner_table) = py_table.extract::() { + return Ok(Some(inner_table.table)); + } - Ok(Some(Arc::new(provider) as Arc)) + if let Some(provider) = try_table_provider_from_object(&py_table)? { + Ok(Some(provider)) } else { if let Ok(inner_table) = py_table.getattr("table") { if let Ok(inner_table) = inner_table.extract::() { @@ -320,13 +322,8 @@ impl RustWrappedPySchemaProvider { } } - match py_table.extract::() { - Ok(py_table) => Ok(Some(py_table.table)), - Err(_) => { - let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?; - Ok(Some(Arc::new(ds) as Arc)) - } - } + let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?; + Ok(Some(Arc::new(ds) as Arc)) } }) } diff --git a/src/context.rs b/src/context.rs index 36133a33d..039066398 100644 --- a/src/context.rs +++ b/src/context.rs @@ -45,7 +45,10 @@ use crate::udaf::PyAggregateUDF; use crate::udf::PyScalarUDF; use crate::udtf::PyTableFunction; use crate::udwf::PyWindowUDF; -use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future}; +use crate::utils::{ + get_global_ctx, get_tokio_runtime, try_table_provider_from_object, validate_pycapsule, + wait_for_future, +}; use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; @@ -71,7 +74,6 @@ use datafusion::prelude::{ AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions, }; use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider}; -use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType}; use pyo3::IntoPyObjectExt; use tokio::task::JoinHandle; @@ -651,15 +653,8 @@ impl PySessionContext { name: &str, provider: Bound<'_, PyAny>, ) -> PyDataFusionResult<()> { - if provider.hasattr("__datafusion_table_provider__")? { - let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; - let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_table_provider")?; - - let provider = unsafe { capsule.reference::() }; - let provider: ForeignTableProvider = provider.into(); - - let _ = self.ctx.register_table(name, Arc::new(provider))?; + if let Some(provider) = try_table_provider_from_object(&provider)? { + let _ = self.ctx.register_table(name, provider)?; Ok(()) } else { @@ -1102,9 +1097,23 @@ impl PySessionContext { Ok(PyDataFrame::new(df)) } - pub fn read_table(&self, table: &PyTable) -> PyDataFusionResult { - let df = self.ctx.read_table(table.table())?; - Ok(PyDataFrame::new(df)) + pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult { + if let Ok(py_table) = table.extract::() { + // RawTable values returned from DataFusion (e.g. ctx.catalog().schema().table(...).table) + // should keep using this native path to avoid an unnecessary FFI round-trip. + let df = self.ctx.read_table(py_table.table())?; + return Ok(PyDataFrame::new(df)); + } + + if let Some(provider) = try_table_provider_from_object(&table)? { + let df = self.ctx.read_table(provider)?; + Ok(PyDataFrame::new(df)) + } else { + Err(crate::errors::PyDataFusionError::Common( + "Object must be a datafusion.Table or expose __datafusion_table_provider__()." + .to_string(), + )) + } } fn __repr__(&self) -> PyResult { diff --git a/src/dataframe.rs b/src/dataframe.rs index 5882acf76..0ba34a4fd 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -51,7 +51,8 @@ use crate::physical_plan::PyExecutionPlan; use crate::record_batch::PyRecordBatchStream; use crate::sql::logical::PyLogicalPlan; use crate::utils::{ - get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future, + get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, table_provider_capsule_name, + validate_pycapsule, wait_for_future, }; use crate::{ errors::PyDataFusionResult, @@ -83,12 +84,10 @@ impl PyTableProvider { &self, py: Python<'py>, ) -> PyResult> { - let name = CString::new("datafusion_table_provider").unwrap(); - let runtime = get_tokio_runtime().0.handle().clone(); let provider = FFI_TableProvider::new(Arc::clone(&self.provider), false, Some(runtime)); - PyCapsule::new(py, provider, Some(name.clone())) + PyCapsule::new(py, provider, Some(table_provider_capsule_name().to_owned())) } } diff --git a/src/udtf.rs b/src/udtf.rs index db16d6c05..f07185e0a 100644 --- a/src/udtf.rs +++ b/src/udtf.rs @@ -21,11 +21,10 @@ use std::sync::Arc; use crate::dataframe::PyTableProvider; use crate::errors::{py_datafusion_err, to_datafusion_err}; use crate::expr::PyExpr; -use crate::utils::validate_pycapsule; +use crate::utils::{try_table_provider_from_object, validate_pycapsule}; use datafusion::catalog::{TableFunctionImpl, TableProvider}; use datafusion::error::Result as DataFusionResult; use datafusion::logical_expr::Expr; -use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use datafusion_ffi::udtf::{FFI_TableFunction, ForeignTableFunction}; use pyo3::exceptions::PyNotImplementedError; use pyo3::types::{PyCapsule, PyTuple}; @@ -99,15 +98,8 @@ fn call_python_table_function( let provider_obj = func.call1(py, py_args)?; let provider = provider_obj.bind(py); - if provider.hasattr("__datafusion_table_provider__")? { - let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; - let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_table_provider")?; - - let provider = unsafe { capsule.reference::() }; - let provider: ForeignTableProvider = provider.into(); - - Ok(Arc::new(provider) as Arc) + if let Some(provider) = try_table_provider_from_object(provider)? { + Ok(provider) } else { Err(PyNotImplementedError::new_err( "__datafusion_table_provider__ does not exist on Table Provider object.", diff --git a/src/utils.rs b/src/utils.rs index 3b30de5de..03f189a4c 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -17,16 +17,36 @@ use crate::{ common::data_type::PyScalarValue, - errors::{PyDataFusionError, PyDataFusionResult}, + errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult}, TokioRuntime, }; use datafusion::{ - common::ScalarValue, execution::context::SessionContext, logical_expr::Volatility, + common::ScalarValue, datasource::TableProvider, execution::context::SessionContext, + logical_expr::Volatility, }; +use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::prelude::*; use pyo3::{exceptions::PyValueError, types::PyCapsule}; -use std::{future::Future, sync::OnceLock, time::Duration}; +use std::ffi::CString; +use std::{ + ffi::CStr, + future::Future, + sync::{Arc, OnceLock}, + time::Duration, +}; use tokio::{runtime::Runtime, time::sleep}; + +pub const TABLE_PROVIDER_CAPSULE_NAME_STR: &str = "datafusion_table_provider"; +/// Return a static CStr for the PyCapsule name. +/// +/// We create this lazily from a `CString` to avoid unsafe const +/// initialization from a byte literal and to satisfy compiler lints. +pub fn table_provider_capsule_name() -> &'static CStr { + static NAME: OnceLock = OnceLock::new(); + NAME.get_or_init(|| CString::new(TABLE_PROVIDER_CAPSULE_NAME_STR).unwrap()) + .as_c_str() +} + /// Utility to get the Tokio Runtime from Python #[inline] pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime { @@ -116,6 +136,28 @@ pub(crate) fn validate_pycapsule(capsule: &Bound, name: &str) -> PyRe Ok(()) } +pub(crate) fn foreign_table_provider_from_capsule( + capsule: &Bound, +) -> PyResult { + validate_pycapsule(capsule, TABLE_PROVIDER_CAPSULE_NAME_STR)?; + Ok(unsafe { capsule.reference::() }.into()) +} + +pub(crate) fn try_table_provider_from_object( + provider: &Bound<'_, PyAny>, +) -> PyResult>> { + if !provider.hasattr("__datafusion_table_provider__")? { + return Ok(None); + } + + let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + let provider = foreign_table_provider_from_capsule(capsule)?; + let provider: Arc = Arc::new(provider); + + Ok(Some(provider)) +} + pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult { // convert Python object to PyScalarValue to ScalarValue