Skip to content

Commit 512442b

Browse files
committed
TableProvider refactor & PyDataFrame integration
* Removed unused helpers (`extract_table_provider`, `_wrap`) and dead code to simplify maintenance. * Consolidated and streamlined table-provider extraction and registration logic; improved error handling and replaced a hardcoded error message with `EXPECTED_PROVIDER_MSG`. * Marked `from_view` as deprecated; updated deprecation message formatting and adjusted the warning `stacklevel` so it points to caller code. * Removed the `Send` marker from TableProvider trait objects to increase type flexibility — review threading assumptions. * Added type hints to `register_schema` and `deregister_table` methods. * Adjusted tests and exceptions (e.g., changed one test to expect `RuntimeError`) and updated test coverage accordingly. * Introduced a refactored `TableProvider` class and enhanced Python integration by adding support for extracting `PyDataFrame` in `PySchema`. Notes: * Consumers should migrate away from `TableProvider::from_view` to the new TableProvider integration. * Audit any code relying on `Send` for trait objects passed across threads. * Update downstream tests and documentation to reflect the changed exception types and deprecation.
1 parent 3da3f93 commit 512442b

File tree

12 files changed

+160
-90
lines changed

12 files changed

+160
-90
lines changed

python/datafusion/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from . import functions, object_store, substrait, unparser
3434

3535
# The following imports are okay to remain as opaque to the user.
36-
from ._internal import Config, TableProvider
36+
from ._internal import Config, EXPECTED_PROVIDER_MSG
3737
from .catalog import Catalog, Database, Table
3838
from .col import col, column
3939
from .common import (
@@ -54,6 +54,7 @@
5454
from .io import read_avro, read_csv, read_json, read_parquet
5555
from .plan import ExecutionPlan, LogicalPlan
5656
from .record_batch import RecordBatch, RecordBatchStream
57+
from .table_provider import TableProvider
5758
from .user_defined import (
5859
Accumulator,
5960
AggregateUDF,
@@ -76,6 +77,7 @@
7677
"DFSchema",
7778
"DataFrame",
7879
"Database",
80+
"EXPECTED_PROVIDER_MSG",
7981
"ExecutionPlan",
8082
"Expr",
8183
"LogicalPlan",

python/datafusion/catalog.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,11 @@ def database(self, name: str = "public") -> Schema:
8585
"""Returns the database with the given ``name`` from this catalog."""
8686
return self.schema(name)
8787

88-
def register_schema(self, name, schema) -> Schema | None:
88+
def register_schema(
89+
self,
90+
name: str,
91+
schema: Schema | SchemaProvider | SchemaProviderExportable,
92+
) -> Schema | None:
8993
"""Register a schema with this catalog."""
9094
if isinstance(schema, Schema):
9195
return self.catalog.register_schema(name, schema._raw_schema)
@@ -126,7 +130,7 @@ def table(self, name: str) -> Table:
126130
return Table(self._raw_schema.table(name))
127131

128132
def register_table(
129-
self, name, table: Table | TableProvider | TableProviderExportable
133+
self, name: str, table: Table | TableProvider | TableProviderExportable
130134
) -> None:
131135
"""Register a table or table provider in this schema.
132136
@@ -240,7 +244,7 @@ def register_table( # noqa: B027
240244
and treated as :class:`TableProvider` instances.
241245
"""
242246

243-
def deregister_table(self, name, cascade: bool) -> None: # noqa: B027
247+
def deregister_table(self, name: str, cascade: bool) -> None: # noqa: B027
244248
"""Remove a table from this schema.
245249
246250
This method is optional. If your schema provides a fixed list of tables, you do

python/datafusion/dataframe.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@
5252
import polars as pl
5353
import pyarrow as pa
5454

55-
from datafusion._internal import TableProvider
5655
from datafusion._internal import expr as expr_internal
56+
from datafusion.table_provider import TableProvider
5757

5858
from enum import Enum
5959

@@ -316,7 +316,9 @@ def into_view(self) -> TableProvider:
316316
``TableProvider.from_dataframe`` calls this method under the hood,
317317
and the older ``TableProvider.from_view`` helper is deprecated.
318318
"""
319-
return self.df.into_view()
319+
from datafusion.table_provider import TableProvider as _TableProvider
320+
321+
return _TableProvider(self.df.into_view())
320322

321323
def __getitem__(self, key: str | list[str]) -> DataFrame:
322324
"""Return a new :py:class`DataFrame` with the specified column or columns.
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Wrapper helpers for :mod:`datafusion._internal.TableProvider`."""
18+
19+
from __future__ import annotations
20+
21+
import warnings
22+
from typing import Any
23+
24+
import datafusion._internal as df_internal
25+
from datafusion._internal import EXPECTED_PROVIDER_MSG
26+
27+
_InternalTableProvider = df_internal.TableProvider
28+
29+
30+
class TableProvider:
31+
"""High level wrapper around :mod:`datafusion._internal.TableProvider`."""
32+
33+
__slots__ = ("_table_provider",)
34+
35+
def __init__(self, table_provider: _InternalTableProvider) -> None:
36+
"""Wrap a low level :class:`~datafusion._internal.TableProvider`."""
37+
if isinstance(table_provider, TableProvider):
38+
table_provider = table_provider._table_provider
39+
40+
if not isinstance(table_provider, _InternalTableProvider):
41+
raise TypeError(EXPECTED_PROVIDER_MSG)
42+
43+
self._table_provider = table_provider
44+
45+
@classmethod
46+
def from_capsule(cls, capsule: Any) -> TableProvider:
47+
"""Create a :class:`TableProvider` from a PyCapsule."""
48+
provider = _InternalTableProvider.from_capsule(capsule)
49+
return cls(provider)
50+
51+
@classmethod
52+
def from_dataframe(cls, df: Any) -> TableProvider:
53+
"""Create a :class:`TableProvider` from a :class:`DataFrame`."""
54+
from datafusion.dataframe import DataFrame as DataFrameWrapper
55+
56+
if isinstance(df, DataFrameWrapper):
57+
df = df.df
58+
59+
provider = _InternalTableProvider.from_dataframe(df)
60+
return cls(provider)
61+
62+
@classmethod
63+
def from_view(cls, df: Any) -> TableProvider:
64+
"""Deprecated.
65+
66+
Use :meth:`DataFrame.into_view` or :meth:`TableProvider.from_dataframe`.
67+
"""
68+
from datafusion.dataframe import DataFrame as DataFrameWrapper
69+
70+
if isinstance(df, DataFrameWrapper):
71+
df = df.df
72+
73+
provider = _InternalTableProvider.from_view(df)
74+
warnings.warn(
75+
"TableProvider.from_view is deprecated; use DataFrame.into_view or "
76+
"TableProvider.from_dataframe instead.",
77+
DeprecationWarning,
78+
stacklevel=2,
79+
)
80+
return cls(provider)
81+
82+
# ------------------------------------------------------------------
83+
# passthrough helpers
84+
# ------------------------------------------------------------------
85+
def __getattr__(self, name: str) -> Any:
86+
"""Delegate attribute lookup to the wrapped provider."""
87+
return getattr(self._table_provider, name)
88+
89+
def __dir__(self) -> list[str]:
90+
"""Expose delegated attributes via :func:`dir`."""
91+
return dir(self._table_provider) + super().__dir__()
92+
93+
def __repr__(self) -> str: # pragma: no cover - simple delegation
94+
"""Return a representation of the wrapped provider."""
95+
return repr(self._table_provider)
96+
97+
def __datafusion_table_provider__(self) -> Any:
98+
"""Expose the wrapped provider for FFI integrations."""
99+
return self._table_provider.__datafusion_table_provider__()
100+
101+
102+
__all__ = ["TableProvider"]

python/tests/test_context.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import pytest
2424
from datafusion import (
2525
DataFrame,
26+
EXPECTED_PROVIDER_MSG,
2627
RuntimeEnvBuilder,
2728
SessionConfig,
2829
SessionContext,
@@ -350,7 +351,7 @@ def test_table_provider_from_capsule(ctx):
350351

351352

352353
def test_table_provider_from_dataframe(ctx):
353-
df = ctx.from_pydict({"a": [1, 2]}).df
354+
df = ctx.from_pydict({"a": [1, 2]})
354355
provider = TableProvider.from_dataframe(df)
355356
ctx.register_table("from_dataframe_tbl", provider)
356357
result = ctx.sql("SELECT * FROM from_dataframe_tbl").collect()
@@ -374,19 +375,16 @@ def __datafusion_table_provider__(self):
374375

375376

376377
def test_table_provider_from_capsule_invalid():
377-
with pytest.raises(Exception): # noqa: B017
378+
with pytest.raises(RuntimeError):
378379
TableProvider.from_capsule(object())
379380

380381

381382
def test_register_table_with_dataframe_errors(ctx):
382383
df = ctx.from_pydict({"a": [1]})
383-
with pytest.raises(Exception) as exc_info: # noqa: B017
384+
with pytest.raises(Exception) as exc_info:
384385
ctx.register_table("bad", df)
385386

386-
assert (
387-
str(exc_info.value)
388-
== 'Expected a Table or TableProvider. Convert DataFrames with "DataFrame.into_view()" or "TableProvider.from_dataframe()".'
389-
)
387+
assert str(exc_info.value) == EXPECTED_PROVIDER_MSG
390388

391389

392390
def test_register_dataset(ctx):

src/catalog.rs

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::dataframe::PyDataFrame;
1819
use crate::dataset::Dataset;
1920
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
2021
use crate::table::PyTableProvider;
2122
use crate::utils::{
22-
table_provider_from_pycapsule, table_provider_send_to_table_provider, table_provider_to_send,
23-
validate_pycapsule, wait_for_future,
23+
table_provider_from_pycapsule, validate_pycapsule, wait_for_future, EXPECTED_PROVIDER_MSG,
2424
};
2525
use async_trait::async_trait;
2626
use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider};
@@ -54,7 +54,7 @@ pub struct PySchema {
5454
#[pyclass(name = "RawTable", module = "datafusion.catalog", subclass)]
5555
#[derive(Clone)]
5656
pub struct PyTable {
57-
pub table: Arc<dyn TableProvider + Send>,
57+
pub table: Arc<dyn TableProvider>,
5858
}
5959

6060
impl From<Arc<dyn CatalogProvider>> for PyCatalog {
@@ -70,11 +70,11 @@ impl From<Arc<dyn SchemaProvider>> for PySchema {
7070
}
7171

7272
impl PyTable {
73-
pub fn new(table: Arc<dyn TableProvider + Send>) -> Self {
73+
pub fn new(table: Arc<dyn TableProvider>) -> Self {
7474
Self { table }
7575
}
7676

77-
pub fn table(&self) -> Arc<dyn TableProvider + Send> {
77+
pub fn table(&self) -> Arc<dyn TableProvider> {
7878
self.table.clone()
7979
}
8080
}
@@ -203,12 +203,14 @@ impl PySchema {
203203
py_table.table
204204
} else if let Ok(py_provider) = table_provider.extract::<PyTableProvider>() {
205205
py_provider.into_inner()
206+
} else if table_provider.extract::<PyDataFrame>().is_ok() {
207+
return Err(PyDataFusionError::Common(EXPECTED_PROVIDER_MSG.to_string()).into());
206208
} else if let Some(provider) = table_provider_from_pycapsule(&table_provider)? {
207209
provider
208210
} else {
209211
let py = table_provider.py();
210212
let provider = Dataset::new(&table_provider, py)?;
211-
Arc::new(provider) as Arc<dyn TableProvider + Send>
213+
Arc::new(provider) as Arc<dyn TableProvider>
212214
};
213215

214216
let _ = self
@@ -288,7 +290,7 @@ impl RustWrappedPySchemaProvider {
288290
}
289291
}
290292

291-
fn table_inner(&self, name: &str) -> PyResult<Option<Arc<dyn TableProvider + Send>>> {
293+
fn table_inner(&self, name: &str) -> PyResult<Option<Arc<dyn TableProvider>>> {
292294
Python::with_gil(|py| {
293295
let provider = self.schema_provider.bind(py);
294296
let py_table_method = provider.getattr("table")?;
@@ -315,7 +317,7 @@ impl RustWrappedPySchemaProvider {
315317
Ok(py_table) => Ok(Some(py_table.table)),
316318
Err(_) => {
317319
let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?;
318-
Ok(Some(Arc::new(ds) as Arc<dyn TableProvider + Send>))
320+
Ok(Some(Arc::new(ds) as Arc<dyn TableProvider>))
319321
}
320322
}
321323
}
@@ -351,22 +353,15 @@ impl SchemaProvider for RustWrappedPySchemaProvider {
351353
&self,
352354
name: &str,
353355
) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
354-
// Convert from our internal Send type to the trait expected type
355-
match self.table_inner(name).map_err(to_datafusion_err)? {
356-
Some(table) => Ok(Some(table_provider_send_to_table_provider(table))),
357-
None => Ok(None),
358-
}
356+
self.table_inner(name).map_err(to_datafusion_err)
359357
}
360358

361359
fn register_table(
362360
&self,
363361
name: String,
364362
table: Arc<dyn TableProvider>,
365363
) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
366-
// Convert from trait type to our internal Send type
367-
let send_table = table_provider_to_send(table);
368-
369-
let py_table = PyTable::new(send_table);
364+
let py_table = PyTable::new(table);
370365
Python::with_gil(|py| {
371366
let provider = self.schema_provider.bind(py);
372367
let _ = provider
@@ -395,10 +390,7 @@ impl SchemaProvider for RustWrappedPySchemaProvider {
395390
// If we can turn this table provider into a `Dataset`, return it.
396391
// Otherwise, return None.
397392
let dataset = match Dataset::new(&table, py) {
398-
Ok(dataset) => {
399-
let send_table = Arc::new(dataset) as Arc<dyn TableProvider + Send>;
400-
Some(table_provider_send_to_table_provider(send_table))
401-
}
393+
Ok(dataset) => Some(Arc::new(dataset) as Arc<dyn TableProvider>),
402394
Err(_) => None,
403395
};
404396

src/context.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ use crate::udtf::PyTableFunction;
4848
use crate::udwf::PyWindowUDF;
4949
use crate::utils::{
5050
get_global_ctx, get_tokio_runtime, table_provider_from_pycapsule, validate_pycapsule,
51-
wait_for_future,
51+
wait_for_future, EXPECTED_PROVIDER_MSG,
5252
};
5353
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
5454
use datafusion::arrow::pyarrow::PyArrowType;
@@ -618,7 +618,7 @@ impl PySessionContext {
618618
provider
619619
} else {
620620
return Err(crate::errors::PyDataFusionError::Common(
621-
"Expected a Table or TableProvider. Convert DataFrames with \"DataFrame.into_view()\" or \"TableProvider.from_dataframe()\".".to_string(),
621+
EXPECTED_PROVIDER_MSG.to_string(),
622622
));
623623
};
624624

@@ -852,7 +852,7 @@ impl PySessionContext {
852852
dataset: &Bound<'_, PyAny>,
853853
py: Python,
854854
) -> PyDataFusionResult<()> {
855-
let table: Arc<dyn TableProvider + Send> = Arc::new(Dataset::new(dataset, py)?);
855+
let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
856856

857857
self.ctx.register_table(name, table)?;
858858

src/dataframe.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ impl PyDataFrame {
268268
}
269269
}
270270

271-
pub(crate) fn to_view_provider(&self) -> Arc<dyn TableProvider + Send> {
271+
pub(crate) fn to_view_provider(&self) -> Arc<dyn TableProvider> {
272272
self.df.as_ref().clone().into_view()
273273
}
274274

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
8181
// Initialize logging
8282
pyo3_log::init();
8383

84+
m.add("EXPECTED_PROVIDER_MSG", crate::utils::EXPECTED_PROVIDER_MSG)?;
85+
8486
// Register the python classes
8587
m.add_class::<context::PyRuntimeEnvBuilder>()?;
8688
m.add_class::<context::PySessionConfig>()?;

0 commit comments

Comments
 (0)