Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/source/user-guide/io/table_provider.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,13 @@ to the ``SessionContext``.
ctx.register_table_provider("my_table", provider)

ctx.table("my_table").show()

If you already have a provider instance you can also use
``SessionContext.read_table`` to obtain a :class:`~datafusion.DataFrame`
directly without registering it first:

.. code-block:: python

provider = MyTableProvider()
df = ctx.read_table(provider)
df.show()
7 changes: 6 additions & 1 deletion examples/datafusion-ffi-example/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@ edition = "2021"
[dependencies]
datafusion = { version = "49.0.2" }
datafusion-ffi = { version = "49.0.2" }
pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py39"] }
datafusion-python = { path = "../../" }
pyo3 = { version = "0.25", features = [
"extension-module",
"abi3",
"abi3-py39",
] }
arrow = { version = "55.0.0" }
arrow-array = { version = "55.0.0" }
arrow-schema = { version = "55.0.0" }
Expand Down
5 changes: 2 additions & 3 deletions examples/datafusion-ffi-example/src/table_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use arrow_schema::{DataType, Field, Schema};
use datafusion::catalog::MemTable;
use datafusion::error::{DataFusionError, Result as DataFusionResult};
use datafusion_ffi::table_provider::FFI_TableProvider;
use datafusion_python::utils::table_provider_capsule_name;
use pyo3::exceptions::PyRuntimeError;
use pyo3::types::PyCapsule;
use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
Expand Down Expand Up @@ -91,13 +92,11 @@ impl MyTableProvider {
&self,
py: Python<'py>,
) -> PyResult<Bound<'py, PyCapsule>> {
let name = cr"datafusion_table_provider".into();

let provider = self
.create_table()
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
let provider = FFI_TableProvider::new(Arc::new(provider), false, None);

PyCapsule::new(py, provider, Some(name))
PyCapsule::new(py, provider, Some(table_provider_capsule_name().to_owned()))
}
}
8 changes: 8 additions & 0 deletions python/datafusion/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ def __repr__(self) -> str:
"""Print a string representation of the table."""
return self.table.__repr__()

def __datafusion_table_provider__(self) -> object:
"""Expose the internal DataFusion table provider PyCapsule.

This forwards the call to the underlying Rust-backed RawTable so the
object can be used as a TableProviderExportable by the FFI layer.
"""
return self.table.__datafusion_table_provider__()

@staticmethod
def from_dataset(dataset: pa.dataset.Dataset) -> Table:
"""Turn a pyarrow Dataset into a Table."""
Expand Down
26 changes: 20 additions & 6 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Protocol
from typing import TYPE_CHECKING, Any, Protocol, Union

try:
from warnings import deprecated # Python 3.13+
Expand Down Expand Up @@ -82,6 +82,13 @@ class TableProviderExportable(Protocol):
def __datafusion_table_provider__(self) -> object: ... # noqa: D105


# Type alias for objects accepted by read_table
# Use typing.Union here (instead of PEP 604 `|`) because this alias is
# evaluated at import time and must work on Python 3.9 where PEP 604
# syntax is not supported for runtime expressions.
TableLike = Union[Table, TableProviderExportable]


class CatalogProviderExportable(Protocol):
"""Type hint for object that has __datafusion_catalog_provider__ PyCapsule.

Expand Down Expand Up @@ -1163,14 +1170,21 @@ def read_avro(
self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
)

def read_table(self, table: Table) -> DataFrame:
def read_table(self, table: TableLike) -> DataFrame:
"""Creates a :py:class:`~datafusion.dataframe.DataFrame` from a table.

For a :py:class:`~datafusion.catalog.Table` such as a
:py:class:`~datafusion.catalog.ListingTable`, create a
:py:class:`~datafusion.dataframe.DataFrame`.
Args:
table: Either a :py:class:`~datafusion.catalog.Table` (such as a
:py:class:`~datafusion.catalog.ListingTable`) or an object that
implements ``__datafusion_table_provider__`` and returns a
PyCapsule describing a custom table provider.

Returns:
A :py:class:`~datafusion.dataframe.DataFrame` backed by the
provided table provider.
"""
return DataFrame(self.ctx.read_table(table.table))
provider = table.table if isinstance(table, Table) else table
return DataFrame(self.ctx.read_table(provider))

def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
"""Execute the ``plan`` and return the results."""
Expand Down
23 changes: 23 additions & 0 deletions python/tests/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,29 @@ def test_python_table_provider(ctx: SessionContext):
assert schema.table_names() == {"table4"}


def test_register_raw_table_without_capsule(ctx: SessionContext, database, monkeypatch):
schema = ctx.catalog().schema("public")
raw_table = schema.table("csv").table

def fail(*args, **kwargs):
msg = "RawTable capsule path should not be invoked"
raise AssertionError(msg)

monkeypatch.setattr(type(raw_table), "__datafusion_table_provider__", fail)

schema.register_table("csv_copy", raw_table)

# Restore the original implementation to avoid interfering with later assertions
monkeypatch.undo()

batches = ctx.sql("select count(*) from csv_copy").collect()

assert len(batches) == 1
assert batches[0].column(0) == pa.array([4])

schema.deregister_table("csv_copy")


def test_in_end_to_end_python_providers(ctx: SessionContext):
"""Test registering all python providers and running a query against them."""

Expand Down
23 changes: 21 additions & 2 deletions python/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import datetime as dt
import gzip
import pathlib
from uuid import uuid4

import pyarrow as pa
import pyarrow.dataset as ds
Expand Down Expand Up @@ -113,6 +114,26 @@ def test_register_record_batches(ctx):
assert result[0].column(1) == pa.array([-3, -3, -3])


def test_read_table_accepts_table_provider(ctx):
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2]), pa.array(["x", "y"])],
names=["value", "label"],
)

ctx.register_record_batches("capsule_provider", [[batch]])

table = ctx.catalog().schema().table("capsule_provider")
provider = table.table

expected = pa.Table.from_batches([batch])

provider_result = pa.Table.from_batches(ctx.read_table(provider).collect())
assert provider_result.equals(expected)

table_result = pa.Table.from_batches(ctx.read_table(table).collect())
assert table_result.equals(expected)


def test_create_dataframe_registers_unique_table_name(ctx):
# create a RecordBatch and register it as memtable
batch = pa.RecordBatch.from_arrays(
Expand Down Expand Up @@ -484,8 +505,6 @@ def test_table_exist(ctx):


def test_table_not_found(ctx):
from uuid import uuid4

with pytest.raises(KeyError):
ctx.table(f"not-found-{uuid4()}")

Expand Down
67 changes: 32 additions & 35 deletions src/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

use crate::dataset::Dataset;
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
use crate::utils::{validate_pycapsule, wait_for_future};
use crate::utils::{
get_tokio_runtime, table_provider_capsule_name, try_table_provider_from_object,
validate_pycapsule, wait_for_future,
};
use async_trait::async_trait;
use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider};
use datafusion::common::DataFusionError;
Expand All @@ -27,7 +30,7 @@ use datafusion::{
datasource::{TableProvider, TableType},
};
use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider};
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
use datafusion_ffi::table_provider::FFI_TableProvider;
use pyo3::exceptions::PyKeyError;
use pyo3::prelude::*;
use pyo3::types::PyCapsule;
Expand Down Expand Up @@ -196,25 +199,14 @@ impl PySchema {
}

fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> {
let provider = if table_provider.hasattr("__datafusion_table_provider__")? {
let capsule = table_provider
.getattr("__datafusion_table_provider__")?
.call0()?;
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
validate_pycapsule(capsule, "datafusion_table_provider")?;

let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
let provider: ForeignTableProvider = provider.into();
Arc::new(provider) as Arc<dyn TableProvider>
let provider = if let Ok(py_table) = table_provider.extract::<PyTable>() {
py_table.table
} else if let Some(provider) = try_table_provider_from_object(&table_provider)? {
provider
} else {
match table_provider.extract::<PyTable>() {
Ok(py_table) => py_table.table,
Err(_) => {
let py = table_provider.py();
let provider = Dataset::new(&table_provider, py)?;
Arc::new(provider) as Arc<dyn TableProvider>
}
}
let py = table_provider.py();
let provider = Dataset::new(&table_provider, py)?;
Arc::new(provider) as Arc<dyn TableProvider>
};

let _ = self
Expand Down Expand Up @@ -261,6 +253,19 @@ impl PyTable {
}
}

fn __datafusion_table_provider__<'py>(
&self,
py: Python<'py>,
) -> PyResult<Bound<'py, PyCapsule>> {
let runtime = get_tokio_runtime().0.handle().clone();

let provider = Arc::clone(&self.table);
let provider: Arc<dyn TableProvider + Send> = provider;
let provider = FFI_TableProvider::new(provider, false, Some(runtime));

PyCapsule::new(py, provider, Some(table_provider_capsule_name().to_owned()))
}

fn __repr__(&self) -> PyResult<String> {
let kind = self.kind();
Ok(format!("Table(kind={kind})"))
Expand Down Expand Up @@ -304,29 +309,21 @@ impl RustWrappedPySchemaProvider {
return Ok(None);
}

if py_table.hasattr("__datafusion_table_provider__")? {
let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
validate_pycapsule(capsule, "datafusion_table_provider")?;

let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
let provider: ForeignTableProvider = provider.into();
if let Ok(inner_table) = py_table.extract::<PyTable>() {
return Ok(Some(inner_table.table));
}

Ok(Some(Arc::new(provider) as Arc<dyn TableProvider>))
if let Some(provider) = try_table_provider_from_object(&py_table)? {
Ok(Some(provider))
} else {
if let Ok(inner_table) = py_table.getattr("table") {
if let Ok(inner_table) = inner_table.extract::<PyTable>() {
return Ok(Some(inner_table.table));
}
}

match py_table.extract::<PyTable>() {
Ok(py_table) => Ok(Some(py_table.table)),
Err(_) => {
let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?;
Ok(Some(Arc::new(ds) as Arc<dyn TableProvider>))
}
}
let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?;
Ok(Some(Arc::new(ds) as Arc<dyn TableProvider>))
}
})
}
Expand Down
37 changes: 23 additions & 14 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ use crate::udaf::PyAggregateUDF;
use crate::udf::PyScalarUDF;
use crate::udtf::PyTableFunction;
use crate::udwf::PyWindowUDF;
use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future};
use crate::utils::{
get_global_ctx, get_tokio_runtime, try_table_provider_from_object, validate_pycapsule,
wait_for_future,
};
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
Expand All @@ -71,7 +74,6 @@ use datafusion::prelude::{
AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
};
use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
use pyo3::IntoPyObjectExt;
use tokio::task::JoinHandle;
Expand Down Expand Up @@ -651,15 +653,8 @@ impl PySessionContext {
name: &str,
provider: Bound<'_, PyAny>,
) -> PyDataFusionResult<()> {
if provider.hasattr("__datafusion_table_provider__")? {
let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
validate_pycapsule(capsule, "datafusion_table_provider")?;

let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
let provider: ForeignTableProvider = provider.into();

let _ = self.ctx.register_table(name, Arc::new(provider))?;
if let Some(provider) = try_table_provider_from_object(&provider)? {
let _ = self.ctx.register_table(name, provider)?;

Ok(())
} else {
Expand Down Expand Up @@ -1102,9 +1097,23 @@ impl PySessionContext {
Ok(PyDataFrame::new(df))
}

pub fn read_table(&self, table: &PyTable) -> PyDataFusionResult<PyDataFrame> {
let df = self.ctx.read_table(table.table())?;
Ok(PyDataFrame::new(df))
pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult<PyDataFrame> {
if let Ok(py_table) = table.extract::<PyTable>() {
// RawTable values returned from DataFusion (e.g. ctx.catalog().schema().table(...).table)
// should keep using this native path to avoid an unnecessary FFI round-trip.
let df = self.ctx.read_table(py_table.table())?;
return Ok(PyDataFrame::new(df));
}

if let Some(provider) = try_table_provider_from_object(&table)? {
let df = self.ctx.read_table(provider)?;
Ok(PyDataFrame::new(df))
} else {
Err(crate::errors::PyDataFusionError::Common(
"Object must be a datafusion.Table or expose __datafusion_table_provider__()."
.to_string(),
))
}
}

fn __repr__(&self) -> PyResult<String> {
Expand Down
7 changes: 3 additions & 4 deletions src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::logical::PyLogicalPlan;
use crate::utils::{
get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, table_provider_capsule_name,
validate_pycapsule, wait_for_future,
};
use crate::{
errors::PyDataFusionResult,
Expand Down Expand Up @@ -83,12 +84,10 @@ impl PyTableProvider {
&self,
py: Python<'py>,
) -> PyResult<Bound<'py, PyCapsule>> {
let name = CString::new("datafusion_table_provider").unwrap();

let runtime = get_tokio_runtime().0.handle().clone();
let provider = FFI_TableProvider::new(Arc::clone(&self.provider), false, Some(runtime));

PyCapsule::new(py, provider, Some(name.clone()))
PyCapsule::new(py, provider, Some(table_provider_capsule_name().to_owned()))
}
}

Expand Down
Loading
Loading