Skip to content

Commit 15d7d79

Browse files
committed
feat: enhance read_table to accept custom table providers and update documentation
1 parent bf22c1d commit 15d7d79

File tree

5 files changed

+89
-11
lines changed

5 files changed

+89
-11
lines changed

docs/source/user-guide/io/table_provider.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,13 @@ to the ``SessionContext``.
5656
ctx.register_table_provider("my_table", provider)
5757
5858
ctx.table("my_table").show()
59+
60+
If you already have a provider instance, you can also use
61+
``SessionContext.read_table`` to obtain a :class:`~datafusion.DataFrame`
62+
directly without registering it first:
63+
64+
.. code-block:: python
65+
66+
provider = MyTableProvider()
67+
df = ctx.read_table(provider)
68+
df.show()

python/datafusion/context.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,14 +1163,23 @@ def read_avro(
11631163
self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
11641164
)
11651165

1166-
def read_table(self, table: Table) -> DataFrame:
1166+
def read_table(
1167+
self, table: Table | TableProviderExportable
1168+
) -> DataFrame:
11671169
"""Creates a :py:class:`~datafusion.dataframe.DataFrame` from a table.
11681170
1169-
For a :py:class:`~datafusion.catalog.Table` such as a
1170-
:py:class:`~datafusion.catalog.ListingTable`, create a
1171-
:py:class:`~datafusion.dataframe.DataFrame`.
1171+
Args:
1172+
table: Either a :py:class:`~datafusion.catalog.Table` (such as a
1173+
:py:class:`~datafusion.catalog.ListingTable`) or an object that
1174+
implements ``__datafusion_table_provider__`` and returns a
1175+
PyCapsule describing a custom table provider.
1176+
1177+
Returns:
1178+
A :py:class:`~datafusion.dataframe.DataFrame` backed by the
1179+
provided table provider.
11721180
"""
1173-
return DataFrame(self.ctx.read_table(table.table))
1181+
provider = table.table if isinstance(table, Table) else table
1182+
return DataFrame(self.ctx.read_table(provider))
11741183

11751184
def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
11761185
"""Execute the ``plan`` and return the results."""

python/tests/test_context.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import datetime as dt
1818
import gzip
1919
import pathlib
20+
from uuid import uuid4
2021

2122
import pyarrow as pa
2223
import pyarrow.dataset as ds
@@ -113,6 +114,28 @@ def test_register_record_batches(ctx):
113114
assert result[0].column(1) == pa.array([-3, -3, -3])
114115

115116

117+
def test_read_table_accepts_table_provider(ctx):
118+
batch = pa.RecordBatch.from_arrays(
119+
[pa.array([1, 2]), pa.array(["x", "y"])],
120+
names=["value", "label"],
121+
)
122+
123+
ctx.register_record_batches("capsule_provider", [[batch]])
124+
125+
table = ctx.catalog().schema().table("capsule_provider")
126+
provider = table.table
127+
128+
expected = pa.Table.from_batches([batch])
129+
130+
provider_result = pa.Table.from_batches(
131+
ctx.read_table(provider).collect()
132+
)
133+
assert provider_result.equals(expected)
134+
135+
table_result = pa.Table.from_batches(ctx.read_table(table).collect())
136+
assert table_result.equals(expected)
137+
138+
116139
def test_create_dataframe_registers_unique_table_name(ctx):
117140
# create a RecordBatch and register it as memtable
118141
batch = pa.RecordBatch.from_arrays(
@@ -484,8 +507,6 @@ def test_table_exist(ctx):
484507

485508

486509
def test_table_not_found(ctx):
487-
from uuid import uuid4
488-
489510
with pytest.raises(KeyError):
490511
ctx.table(f"not-found-{uuid4()}")
491512

src/catalog.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
use crate::dataset::Dataset;
1919
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
20-
use crate::utils::{validate_pycapsule, wait_for_future};
20+
use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
2121
use async_trait::async_trait;
2222
use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider};
2323
use datafusion::common::DataFusionError;
@@ -34,6 +34,7 @@ use pyo3::types::PyCapsule;
3434
use pyo3::IntoPyObjectExt;
3535
use std::any::Any;
3636
use std::collections::HashSet;
37+
use std::ffi::CString;
3738
use std::sync::Arc;
3839

3940
#[pyclass(name = "RawCatalog", module = "datafusion.catalog", subclass)]
@@ -261,6 +262,23 @@ impl PyTable {
261262
}
262263
}
263264

265+
fn __datafusion_table_provider__<'py>(
266+
&self,
267+
py: Python<'py>,
268+
) -> PyResult<Bound<'py, PyCapsule>> {
269+
let name = CString::new("datafusion_table_provider").unwrap();
270+
let runtime = get_tokio_runtime().0.handle().clone();
271+
272+
let provider = Arc::clone(&self.table);
273+
let provider_ptr = Arc::into_raw(provider);
274+
let provider: Arc<dyn TableProvider + Send> =
275+
unsafe { Arc::from_raw(provider_ptr as *const (dyn TableProvider + Send)) };
276+
277+
let provider = FFI_TableProvider::new(provider, false, Some(runtime));
278+
279+
PyCapsule::new(py, provider, Some(name.clone()))
280+
}
281+
264282
fn __repr__(&self) -> PyResult<String> {
265283
let kind = self.kind();
266284
Ok(format!("Table(kind={kind})"))

src/context.rs

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,9 +1102,29 @@ impl PySessionContext {
11021102
Ok(PyDataFrame::new(df))
11031103
}
11041104

1105-
pub fn read_table(&self, table: &PyTable) -> PyDataFusionResult<PyDataFrame> {
1106-
let df = self.ctx.read_table(table.table())?;
1107-
Ok(PyDataFrame::new(df))
1105+
pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult<PyDataFrame> {
1106+
if table.hasattr("__datafusion_table_provider__")? {
1107+
let capsule = table.getattr("__datafusion_table_provider__")?.call0()?;
1108+
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
1109+
validate_pycapsule(capsule, "datafusion_table_provider")?;
1110+
1111+
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
1112+
let provider: ForeignTableProvider = provider.into();
1113+
1114+
let df = self.ctx.read_table(Arc::new(provider))?;
1115+
Ok(PyDataFrame::new(df))
1116+
} else {
1117+
match table.extract::<PyTable>() {
1118+
Ok(py_table) => {
1119+
let df = self.ctx.read_table(py_table.table())?;
1120+
Ok(PyDataFrame::new(df))
1121+
}
1122+
Err(_) => Err(crate::errors::PyDataFusionError::Common(
1123+
"Object must be a datafusion.Table or expose __datafusion_table_provider__()."
1124+
.to_string(),
1125+
)),
1126+
}
1127+
}
11081128
}
11091129

11101130
fn __repr__(&self) -> PyResult<String> {

0 commit comments

Comments
 (0)