From 15d7d791c5f423e2300f40b2d06494f89ac49460 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 14:23:13 +0800 Subject: [PATCH 01/12] feat: enhance read_table to accept custom table providers and update documentation --- docs/source/user-guide/io/table_provider.rst | 10 ++++++++ python/datafusion/context.py | 19 ++++++++++---- python/tests/test_context.py | 25 +++++++++++++++++-- src/catalog.rs | 20 ++++++++++++++- src/context.rs | 26 +++++++++++++++++--- 5 files changed, 89 insertions(+), 11 deletions(-) diff --git a/docs/source/user-guide/io/table_provider.rst b/docs/source/user-guide/io/table_provider.rst index bd1d6b80f..419715126 100644 --- a/docs/source/user-guide/io/table_provider.rst +++ b/docs/source/user-guide/io/table_provider.rst @@ -56,3 +56,13 @@ to the ``SessionContext``. ctx.register_table_provider("my_table", provider) ctx.table("my_table").show() + +If you already have a provider instance you can also use +``SessionContext.read_table`` to obtain a :class:`~datafusion.DataFrame` +directly without registering it first: + +.. code-block:: python + + provider = MyTableProvider() + df = ctx.read_table(provider) + df.show() diff --git a/python/datafusion/context.py b/python/datafusion/context.py index b6e728b51..f88f41f8d 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -1163,14 +1163,23 @@ def read_avro( self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension) ) - def read_table(self, table: Table) -> DataFrame: + def read_table( + self, table: Table | TableProviderExportable + ) -> DataFrame: """Creates a :py:class:`~datafusion.dataframe.DataFrame` from a table. - For a :py:class:`~datafusion.catalog.Table` such as a - :py:class:`~datafusion.catalog.ListingTable`, create a - :py:class:`~datafusion.dataframe.DataFrame`. + Args: + table: Either a :py:class:`~datafusion.catalog.Table` (such as a + :py:class:`~datafusion.catalog.ListingTable`) or an object that + implements ``__datafusion_table_provider__`` and returns a + PyCapsule describing a custom table provider. + + Returns: + A :py:class:`~datafusion.dataframe.DataFrame` backed by the + provided table provider. """ - return DataFrame(self.ctx.read_table(table.table)) + provider = table.table if isinstance(table, Table) else table + return DataFrame(self.ctx.read_table(provider)) def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream: """Execute the ``plan`` and return the results.""" diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 6dbcc0d5e..6b1d96b69 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -17,6 +17,7 @@ import datetime as dt import gzip import pathlib +from uuid import uuid4 import pyarrow as pa import pyarrow.dataset as ds @@ -113,6 +114,28 @@ def test_register_record_batches(ctx): assert result[0].column(1) == pa.array([-3, -3, -3]) +def test_read_table_accepts_table_provider(ctx): + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2]), pa.array(["x", "y"])], + names=["value", "label"], + ) + + ctx.register_record_batches("capsule_provider", [[batch]]) + + table = ctx.catalog().schema().table("capsule_provider") + provider = table.table + + expected = pa.Table.from_batches([batch]) + + provider_result = pa.Table.from_batches( + ctx.read_table(provider).collect() + ) + assert provider_result.equals(expected) + + table_result = pa.Table.from_batches(ctx.read_table(table).collect()) + assert table_result.equals(expected) + + def test_create_dataframe_registers_unique_table_name(ctx): # create a RecordBatch and register it as memtable batch = pa.RecordBatch.from_arrays( @@ -484,8 +507,6 @@ def test_table_exist(ctx): def test_table_not_found(ctx): - from uuid import uuid4 - with pytest.raises(KeyError): ctx.table(f"not-found-{uuid4()}") diff --git a/src/catalog.rs b/src/catalog.rs index 17d4ec3b8..643d640c0 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -17,7 +17,7 @@ use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult}; -use crate::utils::{validate_pycapsule, wait_for_future}; +use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future}; use async_trait::async_trait; use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider}; use datafusion::common::DataFusionError; @@ -34,6 +34,7 @@ use pyo3::types::PyCapsule; use pyo3::IntoPyObjectExt; use std::any::Any; use std::collections::HashSet; +use std::ffi::CString; use std::sync::Arc; #[pyclass(name = "RawCatalog", module = "datafusion.catalog", subclass)] @@ -261,6 +262,23 @@ impl PyTable { } } + fn __datafusion_table_provider__<'py>( + &self, + py: Python<'py>, + ) -> PyResult> { + let name = CString::new("datafusion_table_provider").unwrap(); + let runtime = get_tokio_runtime().0.handle().clone(); + + let provider = Arc::clone(&self.table); + let provider_ptr = Arc::into_raw(provider); + let provider: Arc = + unsafe { Arc::from_raw(provider_ptr as *const (dyn TableProvider + Send)) }; + + let provider = FFI_TableProvider::new(provider, false, Some(runtime)); + + PyCapsule::new(py, provider, Some(name.clone())) + } + fn __repr__(&self) -> PyResult { let kind = self.kind(); Ok(format!("Table(kind={kind})")) diff --git a/src/context.rs b/src/context.rs index 36133a33d..f15d1238c 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1102,9 +1102,29 @@ impl PySessionContext { Ok(PyDataFrame::new(df)) } - pub fn read_table(&self, table: &PyTable) -> PyDataFusionResult { - let df = self.ctx.read_table(table.table())?; - Ok(PyDataFrame::new(df)) + pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult { + if table.hasattr("__datafusion_table_provider__")? { + let capsule = table.getattr("__datafusion_table_provider__")?.call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_table_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignTableProvider = provider.into(); + + let df = self.ctx.read_table(Arc::new(provider))?; + Ok(PyDataFrame::new(df)) + } else { + match table.extract::() { + Ok(py_table) => { + let df = self.ctx.read_table(py_table.table())?; + Ok(PyDataFrame::new(df)) + } + Err(_) => Err(crate::errors::PyDataFusionError::Common( + "Object must be a datafusion.Table or expose __datafusion_table_provider__()." + .to_string(), + )), + } + } } fn __repr__(&self) -> PyResult { From f9bac4e353f76455bf6adb3507c836fe6045eea0 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 14:52:06 +0800 Subject: [PATCH 02/12] feat: add __datafusion_table_provider__ method to Table class for FFI compatibility --- python/datafusion/catalog.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index 536b3a790..1995de73c 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -149,6 +149,14 @@ def __repr__(self) -> str: """Print a string representation of the table.""" return self.table.__repr__() + def __datafusion_table_provider__(self) -> object: # noqa: D105 + """Expose the internal DataFusion table provider PyCapsule. + + This forwards the call to the underlying Rust-backed RawTable so the + object can be used as a TableProviderExportable by the FFI layer. + """ + return self.table.__datafusion_table_provider__() + @staticmethod def from_dataset(dataset: pa.dataset.Dataset) -> Table: """Turn a pyarrow Dataset into a Table.""" From 1c6c7a5079c0140c965566104757d80037855c73 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 15:37:59 +0800 Subject: [PATCH 03/12] fix: correct attribute access for __datafusion_table_provider__ in RustWrappedPySchemaProvider --- src/catalog.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/catalog.rs b/src/catalog.rs index 643d640c0..f80f9701c 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -323,7 +323,7 @@ impl RustWrappedPySchemaProvider { } if py_table.hasattr("__datafusion_table_provider__")? { - let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; + let capsule = py_table.getattr("__datafusion_table_provider__")?.call0()?; let capsule = capsule.downcast::().map_err(py_datafusion_err)?; validate_pycapsule(capsule, "datafusion_table_provider")?; From 3013f299c9a847e0409ef33b6019976a3ae55c65 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 15:58:34 +0800 Subject: [PATCH 04/12] fix: streamline read_table method to prioritize native table access and improve error handling --- src/context.rs | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/context.rs b/src/context.rs index f15d1238c..9801db01e 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1103,6 +1103,13 @@ impl PySessionContext { } pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult { + if let Ok(py_table) = table.extract::() { + // RawTable values returned from DataFusion (e.g. ctx.catalog().schema().table(...).table) + // should keep using this native path to avoid an unnecessary FFI round-trip. + let df = self.ctx.read_table(py_table.table())?; + return Ok(PyDataFrame::new(df)); + } + if table.hasattr("__datafusion_table_provider__")? { let capsule = table.getattr("__datafusion_table_provider__")?.call0()?; let capsule = capsule.downcast::().map_err(py_datafusion_err)?; @@ -1114,16 +1121,10 @@ impl PySessionContext { let df = self.ctx.read_table(Arc::new(provider))?; Ok(PyDataFrame::new(df)) } else { - match table.extract::() { - Ok(py_table) => { - let df = self.ctx.read_table(py_table.table())?; - Ok(PyDataFrame::new(df)) - } - Err(_) => Err(crate::errors::PyDataFusionError::Common( - "Object must be a datafusion.Table or expose __datafusion_table_provider__()." - .to_string(), - )), - } + Err(crate::errors::PyDataFusionError::Common( + "Object must be a datafusion.Table or expose __datafusion_table_provider__()." + .to_string(), + )) } } From 69e7e606fd6e4bde48f9e75b470185401ff8ce5b Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 17:02:14 +0800 Subject: [PATCH 05/12] fix: improve table registration logic to handle raw tables and enhance error handling --- python/tests/test_catalog.py | 22 ++++++++++++++++++++++ src/catalog.rs | 28 ++++++++++++---------------- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py index 1f9ecbfc3..e44097c5a 100644 --- a/python/tests/test_catalog.py +++ b/python/tests/test_catalog.py @@ -164,6 +164,28 @@ def test_python_table_provider(ctx: SessionContext): assert schema.table_names() == {"table4"} +def test_register_raw_table_without_capsule(ctx: SessionContext, database, monkeypatch): + schema = ctx.catalog().schema("public") + raw_table = schema.table("csv").table + + def fail(*args, **kwargs): + raise AssertionError("RawTable capsule path should not be invoked") + + monkeypatch.setattr(type(raw_table), "__datafusion_table_provider__", fail) + + schema.register_table("csv_copy", raw_table) + + # Restore the original implementation to avoid interfering with later assertions + monkeypatch.undo() + + batches = ctx.sql("select count(*) from csv_copy").collect() + + assert len(batches) == 1 + assert batches[0].column(0) == pa.array([4]) + + schema.deregister_table("csv_copy") + + def test_in_end_to_end_python_providers(ctx: SessionContext): """Test registering all python providers and running a query against them.""" diff --git a/src/catalog.rs b/src/catalog.rs index f80f9701c..f2dcb3859 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -197,7 +197,9 @@ impl PySchema { } fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> { - let provider = if table_provider.hasattr("__datafusion_table_provider__")? { + let provider = if let Ok(py_table) = table_provider.extract::() { + py_table.table + } else if table_provider.hasattr("__datafusion_table_provider__")? { let capsule = table_provider .getattr("__datafusion_table_provider__")? .call0()?; @@ -208,14 +210,9 @@ impl PySchema { let provider: ForeignTableProvider = provider.into(); Arc::new(provider) as Arc } else { - match table_provider.extract::() { - Ok(py_table) => py_table.table, - Err(_) => { - let py = table_provider.py(); - let provider = Dataset::new(&table_provider, py)?; - Arc::new(provider) as Arc - } - } + let py = table_provider.py(); + let provider = Dataset::new(&table_provider, py)?; + Arc::new(provider) as Arc }; let _ = self @@ -322,6 +319,10 @@ impl RustWrappedPySchemaProvider { return Ok(None); } + if let Ok(inner_table) = py_table.extract::() { + return Ok(Some(inner_table.table)); + } + if py_table.hasattr("__datafusion_table_provider__")? { let capsule = py_table.getattr("__datafusion_table_provider__")?.call0()?; let capsule = capsule.downcast::().map_err(py_datafusion_err)?; @@ -338,13 +339,8 @@ impl RustWrappedPySchemaProvider { } } - match py_table.extract::() { - Ok(py_table) => Ok(Some(py_table.table)), - Err(_) => { - let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?; - Ok(Some(Arc::new(ds) as Arc)) - } - } + let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?; + Ok(Some(Arc::new(ds) as Arc)) } }) } From c71659ce5c3b64e145a8ab0545a787eeb4c2e7ae Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 18:25:51 +0800 Subject: [PATCH 06/12] fix: simplify provider initialization in PyTable implementation --- src/catalog.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/catalog.rs b/src/catalog.rs index f2dcb3859..6b8e8b1a7 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -267,10 +267,7 @@ impl PyTable { let runtime = get_tokio_runtime().0.handle().clone(); let provider = Arc::clone(&self.table); - let provider_ptr = Arc::into_raw(provider); - let provider: Arc = - unsafe { Arc::from_raw(provider_ptr as *const (dyn TableProvider + Send)) }; - + let provider: Arc = provider; let provider = FFI_TableProvider::new(provider, false, Some(runtime)); PyCapsule::new(py, provider, Some(name.clone())) From 23bd49d21dbfdbba4cf7ae70aa27dc527cfad778 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 18:59:40 +0800 Subject: [PATCH 07/12] fix: refactor foreign table provider handling for improved clarity and safety --- src/catalog.rs | 22 ++++++++++------------ src/context.rs | 22 ++++++++++------------ src/udtf.rs | 11 ++++------- src/utils.rs | 8 ++++++++ 4 files changed, 32 insertions(+), 31 deletions(-) diff --git a/src/catalog.rs b/src/catalog.rs index 6b8e8b1a7..3574bdd17 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -17,7 +17,9 @@ use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult}; -use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future}; +use crate::utils::{ + foreign_table_provider_from_capsule, get_tokio_runtime, validate_pycapsule, wait_for_future, +}; use async_trait::async_trait; use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider}; use datafusion::common::DataFusionError; @@ -27,7 +29,7 @@ use datafusion::{ datasource::{TableProvider, TableType}, }; use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider}; -use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; +use datafusion_ffi::table_provider::FFI_TableProvider; use pyo3::exceptions::PyKeyError; use pyo3::prelude::*; use pyo3::types::PyCapsule; @@ -204,11 +206,9 @@ impl PySchema { .getattr("__datafusion_table_provider__")? .call0()?; let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_table_provider")?; - - let provider = unsafe { capsule.reference::() }; - let provider: ForeignTableProvider = provider.into(); - Arc::new(provider) as Arc + let provider = foreign_table_provider_from_capsule(capsule)?; + let provider: Arc = Arc::new(provider); + provider } else { let py = table_provider.py(); let provider = Dataset::new(&table_provider, py)?; @@ -323,12 +323,10 @@ impl RustWrappedPySchemaProvider { if py_table.hasattr("__datafusion_table_provider__")? { let capsule = py_table.getattr("__datafusion_table_provider__")?.call0()?; let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_table_provider")?; - - let provider = unsafe { capsule.reference::() }; - let provider: ForeignTableProvider = provider.into(); + let provider = foreign_table_provider_from_capsule(capsule)?; + let provider: Arc = Arc::new(provider); - Ok(Some(Arc::new(provider) as Arc)) + Ok(Some(provider)) } else { if let Ok(inner_table) = py_table.getattr("table") { if let Ok(inner_table) = inner_table.extract::() { diff --git a/src/context.rs b/src/context.rs index 9801db01e..ac56e346f 100644 --- a/src/context.rs +++ b/src/context.rs @@ -45,7 +45,10 @@ use crate::udaf::PyAggregateUDF; use crate::udf::PyScalarUDF; use crate::udtf::PyTableFunction; use crate::udwf::PyWindowUDF; -use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future}; +use crate::utils::{ + foreign_table_provider_from_capsule, get_global_ctx, get_tokio_runtime, validate_pycapsule, + wait_for_future, +}; use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; @@ -71,7 +74,6 @@ use datafusion::prelude::{ AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions, }; use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider}; -use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType}; use pyo3::IntoPyObjectExt; use tokio::task::JoinHandle; @@ -654,12 +656,10 @@ impl PySessionContext { if provider.hasattr("__datafusion_table_provider__")? { let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_table_provider")?; - - let provider = unsafe { capsule.reference::() }; - let provider: ForeignTableProvider = provider.into(); + let provider = foreign_table_provider_from_capsule(capsule)?; + let provider: Arc = Arc::new(provider); - let _ = self.ctx.register_table(name, Arc::new(provider))?; + let _ = self.ctx.register_table(name, provider)?; Ok(()) } else { @@ -1113,12 +1113,10 @@ impl PySessionContext { if table.hasattr("__datafusion_table_provider__")? { let capsule = table.getattr("__datafusion_table_provider__")?.call0()?; let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_table_provider")?; - - let provider = unsafe { capsule.reference::() }; - let provider: ForeignTableProvider = provider.into(); + let provider = foreign_table_provider_from_capsule(capsule)?; + let provider: Arc = Arc::new(provider); - let df = self.ctx.read_table(Arc::new(provider))?; + let df = self.ctx.read_table(provider)?; Ok(PyDataFrame::new(df)) } else { Err(crate::errors::PyDataFusionError::Common( diff --git a/src/udtf.rs b/src/udtf.rs index db16d6c05..608d9d475 100644 --- a/src/udtf.rs +++ b/src/udtf.rs @@ -21,11 +21,10 @@ use std::sync::Arc; use crate::dataframe::PyTableProvider; use crate::errors::{py_datafusion_err, to_datafusion_err}; use crate::expr::PyExpr; -use crate::utils::validate_pycapsule; +use crate::utils::{foreign_table_provider_from_capsule, validate_pycapsule}; use datafusion::catalog::{TableFunctionImpl, TableProvider}; use datafusion::error::Result as DataFusionResult; use datafusion::logical_expr::Expr; -use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use datafusion_ffi::udtf::{FFI_TableFunction, ForeignTableFunction}; use pyo3::exceptions::PyNotImplementedError; use pyo3::types::{PyCapsule, PyTuple}; @@ -102,12 +101,10 @@ fn call_python_table_function( if provider.hasattr("__datafusion_table_provider__")? { let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_table_provider")?; + let provider = foreign_table_provider_from_capsule(capsule)?; + let provider: Arc = Arc::new(provider); - let provider = unsafe { capsule.reference::() }; - let provider: ForeignTableProvider = provider.into(); - - Ok(Arc::new(provider) as Arc) + Ok(provider) } else { Err(PyNotImplementedError::new_err( "__datafusion_table_provider__ does not exist on Table Provider object.", diff --git a/src/utils.rs b/src/utils.rs index 3b30de5de..efb84ef63 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -23,6 +23,7 @@ use crate::{ use datafusion::{ common::ScalarValue, execution::context::SessionContext, logical_expr::Volatility, }; +use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::prelude::*; use pyo3::{exceptions::PyValueError, types::PyCapsule}; use std::{future::Future, sync::OnceLock, time::Duration}; @@ -116,6 +117,13 @@ pub(crate) fn validate_pycapsule(capsule: &Bound, name: &str) -> PyRe Ok(()) } +pub(crate) fn foreign_table_provider_from_capsule( + capsule: &Bound, +) -> PyResult { + validate_pycapsule(capsule, "datafusion_table_provider")?; + Ok(unsafe { capsule.reference::() }.into()) +} + pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult { // convert Python object to PyScalarValue to ScalarValue From 0e5a77b29e2b217ddbc2e698f2f6d34abcd8c8a9 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 20:39:39 +0800 Subject: [PATCH 08/12] fix: replace foreign_table_provider_from_capsule with try_table_provider_from_object for improved clarity and error handling --- src/catalog.rs | 17 +++-------------- src/context.rs | 16 +++------------- src/udtf.rs | 9 ++------- src/utils.rs | 26 +++++++++++++++++++++++--- 4 files changed, 31 insertions(+), 37 deletions(-) diff --git a/src/catalog.rs b/src/catalog.rs index 3574bdd17..86f8dae33 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -18,7 +18,7 @@ use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult}; use crate::utils::{ - foreign_table_provider_from_capsule, get_tokio_runtime, validate_pycapsule, wait_for_future, + get_tokio_runtime, try_table_provider_from_object, validate_pycapsule, wait_for_future, }; use async_trait::async_trait; use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider}; @@ -201,13 +201,7 @@ impl PySchema { fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> { let provider = if let Ok(py_table) = table_provider.extract::() { py_table.table - } else if table_provider.hasattr("__datafusion_table_provider__")? { - let capsule = table_provider - .getattr("__datafusion_table_provider__")? - .call0()?; - let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - let provider = foreign_table_provider_from_capsule(capsule)?; - let provider: Arc = Arc::new(provider); + } else if let Some(provider) = try_table_provider_from_object(&table_provider)? { provider } else { let py = table_provider.py(); @@ -320,12 +314,7 @@ impl RustWrappedPySchemaProvider { return Ok(Some(inner_table.table)); } - if py_table.hasattr("__datafusion_table_provider__")? { - let capsule = py_table.getattr("__datafusion_table_provider__")?.call0()?; - let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - let provider = foreign_table_provider_from_capsule(capsule)?; - let provider: Arc = Arc::new(provider); - + if let Some(provider) = try_table_provider_from_object(&py_table)? { Ok(Some(provider)) } else { if let Ok(inner_table) = py_table.getattr("table") { diff --git a/src/context.rs b/src/context.rs index ac56e346f..039066398 100644 --- a/src/context.rs +++ b/src/context.rs @@ -46,7 +46,7 @@ use crate::udf::PyScalarUDF; use crate::udtf::PyTableFunction; use crate::udwf::PyWindowUDF; use crate::utils::{ - foreign_table_provider_from_capsule, get_global_ctx, get_tokio_runtime, validate_pycapsule, + get_global_ctx, get_tokio_runtime, try_table_provider_from_object, validate_pycapsule, wait_for_future, }; use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; @@ -653,12 +653,7 @@ impl PySessionContext { name: &str, provider: Bound<'_, PyAny>, ) -> PyDataFusionResult<()> { - if provider.hasattr("__datafusion_table_provider__")? { - let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; - let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - let provider = foreign_table_provider_from_capsule(capsule)?; - let provider: Arc = Arc::new(provider); - + if let Some(provider) = try_table_provider_from_object(&provider)? { let _ = self.ctx.register_table(name, provider)?; Ok(()) @@ -1110,12 +1105,7 @@ impl PySessionContext { return Ok(PyDataFrame::new(df)); } - if table.hasattr("__datafusion_table_provider__")? { - let capsule = table.getattr("__datafusion_table_provider__")?.call0()?; - let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - let provider = foreign_table_provider_from_capsule(capsule)?; - let provider: Arc = Arc::new(provider); - + if let Some(provider) = try_table_provider_from_object(&table)? { let df = self.ctx.read_table(provider)?; Ok(PyDataFrame::new(df)) } else { diff --git a/src/udtf.rs b/src/udtf.rs index 608d9d475..f07185e0a 100644 --- a/src/udtf.rs +++ b/src/udtf.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use crate::dataframe::PyTableProvider; use crate::errors::{py_datafusion_err, to_datafusion_err}; use crate::expr::PyExpr; -use crate::utils::{foreign_table_provider_from_capsule, validate_pycapsule}; +use crate::utils::{try_table_provider_from_object, validate_pycapsule}; use datafusion::catalog::{TableFunctionImpl, TableProvider}; use datafusion::error::Result as DataFusionResult; use datafusion::logical_expr::Expr; @@ -98,12 +98,7 @@ fn call_python_table_function( let provider_obj = func.call1(py, py_args)?; let provider = provider_obj.bind(py); - if provider.hasattr("__datafusion_table_provider__")? { - let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; - let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - let provider = foreign_table_provider_from_capsule(capsule)?; - let provider: Arc = Arc::new(provider); - + if let Some(provider) = try_table_provider_from_object(provider)? { Ok(provider) } else { Err(PyNotImplementedError::new_err( diff --git a/src/utils.rs b/src/utils.rs index efb84ef63..20b30cb7c 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -17,16 +17,21 @@ use crate::{ common::data_type::PyScalarValue, - errors::{PyDataFusionError, PyDataFusionResult}, + errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult}, TokioRuntime, }; use datafusion::{ - common::ScalarValue, execution::context::SessionContext, logical_expr::Volatility, + common::ScalarValue, datasource::TableProvider, execution::context::SessionContext, + logical_expr::Volatility, }; use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::prelude::*; use pyo3::{exceptions::PyValueError, types::PyCapsule}; -use std::{future::Future, sync::OnceLock, time::Duration}; +use std::{ + future::Future, + sync::{Arc, OnceLock}, + time::Duration, +}; use tokio::{runtime::Runtime, time::sleep}; /// Utility to get the Tokio Runtime from Python #[inline] @@ -124,6 +129,21 @@ pub(crate) fn foreign_table_provider_from_capsule( Ok(unsafe { capsule.reference::() }.into()) } +pub(crate) fn try_table_provider_from_object( + provider: &Bound<'_, PyAny>, +) -> PyResult>> { + if !provider.hasattr("__datafusion_table_provider__")? { + return Ok(None); + } + + let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + let provider = foreign_table_provider_from_capsule(capsule)?; + let provider: Arc = Arc::new(provider); + + Ok(Some(provider)) +} + pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult { // convert Python object to PyScalarValue to ScalarValue From a4b94f9f733f887b40751b73d03485f33fccf063 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 21:16:56 +0800 Subject: [PATCH 09/12] fix: introduce TableLike type alias for improved readability in read_table method --- python/datafusion/context.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index f88f41f8d..cf1998cba 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -82,6 +82,10 @@ class TableProviderExportable(Protocol): def __datafusion_table_provider__(self) -> object: ... # noqa: D105 +# Type alias for objects accepted by read_table +TableLike = Table | TableProviderExportable + + class CatalogProviderExportable(Protocol): """Type hint for object that has __datafusion_catalog_provider__ PyCapsule. @@ -1164,7 +1168,7 @@ def read_avro( ) def read_table( - self, table: Table | TableProviderExportable + self, table: TableLike ) -> DataFrame: """Creates a :py:class:`~datafusion.dataframe.DataFrame` from a table. From 0ef9c4b79c864774685658829cb30ae4b62c1a44 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 21:50:10 +0800 Subject: [PATCH 10/12] fix: update TableLike type alias to use Union for compatibility with Python 3.9 --- python/datafusion/context.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index cf1998cba..c3fff27ee 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -20,7 +20,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol, Union try: from warnings import deprecated # Python 3.13+ @@ -83,7 +83,10 @@ def __datafusion_table_provider__(self) -> object: ... # noqa: D105 # Type alias for objects accepted by read_table -TableLike = Table | TableProviderExportable +# Use typing.Union here (instead of PEP 604 `|`) because this alias is +# evaluated at import time and must work on Python 3.9 where PEP 604 +# syntax is not supported for runtime expressions. +TableLike = Union[Table, TableProviderExportable] class CatalogProviderExportable(Protocol): From c164ecb81dbb2279a3c39116116a6f3af49346d3 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 21:51:27 +0800 Subject: [PATCH 11/12] fix Ruff errors --- python/datafusion/catalog.py | 2 +- python/datafusion/context.py | 4 +--- python/tests/test_catalog.py | 3 ++- python/tests/test_context.py | 4 +--- 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index 1995de73c..15ec7ca80 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -149,7 +149,7 @@ def __repr__(self) -> str: """Print a string representation of the table.""" return self.table.__repr__() - def __datafusion_table_provider__(self) -> object: # noqa: D105 + def __datafusion_table_provider__(self) -> object: """Expose the internal DataFusion table provider PyCapsule. This forwards the call to the underlying Rust-backed RawTable so the diff --git a/python/datafusion/context.py b/python/datafusion/context.py index c3fff27ee..6cd0c7ce8 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -1170,9 +1170,7 @@ def read_avro( self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension) ) - def read_table( - self, table: TableLike - ) -> DataFrame: + def read_table(self, table: TableLike) -> DataFrame: """Creates a :py:class:`~datafusion.dataframe.DataFrame` from a table. Args: diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py index e44097c5a..2c3f7012f 100644 --- a/python/tests/test_catalog.py +++ b/python/tests/test_catalog.py @@ -169,7 +169,8 @@ def test_register_raw_table_without_capsule(ctx: SessionContext, database, monke raw_table = schema.table("csv").table def fail(*args, **kwargs): - raise AssertionError("RawTable capsule path should not be invoked") + msg = "RawTable capsule path should not be invoked" + raise AssertionError(msg) monkeypatch.setattr(type(raw_table), "__datafusion_table_provider__", fail) diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 6b1d96b69..a14ecd795 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -127,9 +127,7 @@ def test_read_table_accepts_table_provider(ctx): expected = pa.Table.from_batches([batch]) - provider_result = pa.Table.from_batches( - ctx.read_table(provider).collect() - ) + provider_result = pa.Table.from_batches(ctx.read_table(provider).collect()) assert provider_result.equals(expected) table_result = pa.Table.from_batches(ctx.read_table(table).collect()) From fe50c295057c6810b21757c486b708e7289b876e Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 21:45:47 +0800 Subject: [PATCH 12/12] fix: standardize usage of TABLE_PROVIDER_CAPSULE_NAME to table_provider_capsule_name function --- examples/datafusion-ffi-example/Cargo.toml | 7 ++++++- .../datafusion-ffi-example/src/table_provider.rs | 5 ++--- src/catalog.rs | 7 +++---- src/dataframe.rs | 7 +++---- src/utils.rs | 16 +++++++++++++++- 5 files changed, 29 insertions(+), 13 deletions(-) diff --git a/examples/datafusion-ffi-example/Cargo.toml b/examples/datafusion-ffi-example/Cargo.toml index 647f6c51e..fc6e68c45 100644 --- a/examples/datafusion-ffi-example/Cargo.toml +++ b/examples/datafusion-ffi-example/Cargo.toml @@ -23,7 +23,12 @@ edition = "2021" [dependencies] datafusion = { version = "49.0.2" } datafusion-ffi = { version = "49.0.2" } -pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py39"] } +datafusion-python = { path = "../../" } +pyo3 = { version = "0.25", features = [ + "extension-module", + "abi3", + "abi3-py39", +] } arrow = { version = "55.0.0" } arrow-array = { version = "55.0.0" } arrow-schema = { version = "55.0.0" } diff --git a/examples/datafusion-ffi-example/src/table_provider.rs b/examples/datafusion-ffi-example/src/table_provider.rs index e884585b5..b3c3b47a6 100644 --- a/examples/datafusion-ffi-example/src/table_provider.rs +++ b/examples/datafusion-ffi-example/src/table_provider.rs @@ -20,6 +20,7 @@ use arrow_schema::{DataType, Field, Schema}; use datafusion::catalog::MemTable; use datafusion::error::{DataFusionError, Result as DataFusionResult}; use datafusion_ffi::table_provider::FFI_TableProvider; +use datafusion_python::utils::table_provider_capsule_name; use pyo3::exceptions::PyRuntimeError; use pyo3::types::PyCapsule; use pyo3::{pyclass, pymethods, Bound, PyResult, Python}; @@ -91,13 +92,11 @@ impl MyTableProvider { &self, py: Python<'py>, ) -> PyResult> { - let name = cr"datafusion_table_provider".into(); - let provider = self .create_table() .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; let provider = FFI_TableProvider::new(Arc::new(provider), false, None); - PyCapsule::new(py, provider, Some(name)) + PyCapsule::new(py, provider, Some(table_provider_capsule_name().to_owned())) } } diff --git a/src/catalog.rs b/src/catalog.rs index 86f8dae33..6764083ee 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -18,7 +18,8 @@ use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult}; use crate::utils::{ - get_tokio_runtime, try_table_provider_from_object, validate_pycapsule, wait_for_future, + get_tokio_runtime, table_provider_capsule_name, try_table_provider_from_object, + validate_pycapsule, wait_for_future, }; use async_trait::async_trait; use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider}; @@ -36,7 +37,6 @@ use pyo3::types::PyCapsule; use pyo3::IntoPyObjectExt; use std::any::Any; use std::collections::HashSet; -use std::ffi::CString; use std::sync::Arc; #[pyclass(name = "RawCatalog", module = "datafusion.catalog", subclass)] @@ -257,14 +257,13 @@ impl PyTable { &self, py: Python<'py>, ) -> PyResult> { - let name = CString::new("datafusion_table_provider").unwrap(); let runtime = get_tokio_runtime().0.handle().clone(); let provider = Arc::clone(&self.table); let provider: Arc = provider; let provider = FFI_TableProvider::new(provider, false, Some(runtime)); - PyCapsule::new(py, provider, Some(name.clone())) + PyCapsule::new(py, provider, Some(table_provider_capsule_name().to_owned())) } fn __repr__(&self) -> PyResult { diff --git a/src/dataframe.rs b/src/dataframe.rs index 5882acf76..0ba34a4fd 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -51,7 +51,8 @@ use crate::physical_plan::PyExecutionPlan; use crate::record_batch::PyRecordBatchStream; use crate::sql::logical::PyLogicalPlan; use crate::utils::{ - get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future, + get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, table_provider_capsule_name, + validate_pycapsule, wait_for_future, }; use crate::{ errors::PyDataFusionResult, @@ -83,12 +84,10 @@ impl PyTableProvider { &self, py: Python<'py>, ) -> PyResult> { - let name = CString::new("datafusion_table_provider").unwrap(); - let runtime = get_tokio_runtime().0.handle().clone(); let provider = FFI_TableProvider::new(Arc::clone(&self.provider), false, Some(runtime)); - PyCapsule::new(py, provider, Some(name.clone())) + PyCapsule::new(py, provider, Some(table_provider_capsule_name().to_owned())) } } diff --git a/src/utils.rs b/src/utils.rs index 20b30cb7c..03f189a4c 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -27,12 +27,26 @@ use datafusion::{ use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::prelude::*; use pyo3::{exceptions::PyValueError, types::PyCapsule}; +use std::ffi::CString; use std::{ + ffi::CStr, future::Future, sync::{Arc, OnceLock}, time::Duration, }; use tokio::{runtime::Runtime, time::sleep}; + +pub const TABLE_PROVIDER_CAPSULE_NAME_STR: &str = "datafusion_table_provider"; +/// Return a static CStr for the PyCapsule name. +/// +/// We create this lazily from a `CString` to avoid unsafe const +/// initialization from a byte literal and to satisfy compiler lints. +pub fn table_provider_capsule_name() -> &'static CStr { + static NAME: OnceLock = OnceLock::new(); + NAME.get_or_init(|| CString::new(TABLE_PROVIDER_CAPSULE_NAME_STR).unwrap()) + .as_c_str() +} + /// Utility to get the Tokio Runtime from Python #[inline] pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime { @@ -125,7 +139,7 @@ pub(crate) fn validate_pycapsule(capsule: &Bound, name: &str) -> PyRe pub(crate) fn foreign_table_provider_from_capsule( capsule: &Bound, ) -> PyResult { - validate_pycapsule(capsule, "datafusion_table_provider")?; + validate_pycapsule(capsule, TABLE_PROVIDER_CAPSULE_NAME_STR)?; Ok(unsafe { capsule.reference::() }.into()) }