Skip to content

Commit d563352

Browse files
committed
fix: refactor foreign table provider handling for improved clarity and safety
1 parent 68cb22f commit d563352

File tree

4 files changed

+32
-31
lines changed

4 files changed

+32
-31
lines changed

src/catalog.rs

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818
use crate::dataset::Dataset;
1919
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
20-
use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
20+
use crate::utils::{
21+
foreign_table_provider_from_capsule, get_tokio_runtime, validate_pycapsule, wait_for_future,
22+
};
2123
use async_trait::async_trait;
2224
use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider};
2325
use datafusion::common::DataFusionError;
@@ -27,7 +29,7 @@ use datafusion::{
2729
datasource::{TableProvider, TableType},
2830
};
2931
use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider};
30-
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
32+
use datafusion_ffi::table_provider::FFI_TableProvider;
3133
use pyo3::exceptions::PyKeyError;
3234
use pyo3::prelude::*;
3335
use pyo3::types::PyCapsule;
@@ -204,11 +206,9 @@ impl PySchema {
204206
.getattr("__datafusion_table_provider__")?
205207
.call0()?;
206208
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
207-
validate_pycapsule(capsule, "datafusion_table_provider")?;
208-
209-
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
210-
let provider: ForeignTableProvider = provider.into();
211-
Arc::new(provider) as Arc<dyn TableProvider>
209+
let provider = foreign_table_provider_from_capsule(capsule)?;
210+
let provider: Arc<dyn TableProvider> = Arc::new(provider);
211+
provider
212212
} else {
213213
let py = table_provider.py();
214214
let provider = Dataset::new(&table_provider, py)?;
@@ -323,12 +323,10 @@ impl RustWrappedPySchemaProvider {
323323
if py_table.hasattr("__datafusion_table_provider__")? {
324324
let capsule = py_table.getattr("__datafusion_table_provider__")?.call0()?;
325325
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
326-
validate_pycapsule(capsule, "datafusion_table_provider")?;
327-
328-
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
329-
let provider: ForeignTableProvider = provider.into();
326+
let provider = foreign_table_provider_from_capsule(capsule)?;
327+
let provider: Arc<dyn TableProvider> = Arc::new(provider);
330328

331-
Ok(Some(Arc::new(provider) as Arc<dyn TableProvider>))
329+
Ok(Some(provider))
332330
} else {
333331
if let Ok(inner_table) = py_table.getattr("table") {
334332
if let Ok(inner_table) = inner_table.extract::<PyTable>() {

src/context.rs

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ use crate::udaf::PyAggregateUDF;
4545
use crate::udf::PyScalarUDF;
4646
use crate::udtf::PyTableFunction;
4747
use crate::udwf::PyWindowUDF;
48-
use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future};
48+
use crate::utils::{
49+
foreign_table_provider_from_capsule, get_global_ctx, get_tokio_runtime, validate_pycapsule,
50+
wait_for_future,
51+
};
4952
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
5053
use datafusion::arrow::pyarrow::PyArrowType;
5154
use datafusion::arrow::record_batch::RecordBatch;
@@ -71,7 +74,6 @@ use datafusion::prelude::{
7174
AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
7275
};
7376
use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
74-
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
7577
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
7678
use pyo3::IntoPyObjectExt;
7779
use tokio::task::JoinHandle;
@@ -654,12 +656,10 @@ impl PySessionContext {
654656
if provider.hasattr("__datafusion_table_provider__")? {
655657
let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
656658
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
657-
validate_pycapsule(capsule, "datafusion_table_provider")?;
658-
659-
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
660-
let provider: ForeignTableProvider = provider.into();
659+
let provider = foreign_table_provider_from_capsule(capsule)?;
660+
let provider: Arc<dyn TableProvider> = Arc::new(provider);
661661

662-
let _ = self.ctx.register_table(name, Arc::new(provider))?;
662+
let _ = self.ctx.register_table(name, provider)?;
663663

664664
Ok(())
665665
} else {
@@ -1113,12 +1113,10 @@ impl PySessionContext {
11131113
if table.hasattr("__datafusion_table_provider__")? {
11141114
let capsule = table.getattr("__datafusion_table_provider__")?.call0()?;
11151115
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
1116-
validate_pycapsule(capsule, "datafusion_table_provider")?;
1117-
1118-
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
1119-
let provider: ForeignTableProvider = provider.into();
1116+
let provider = foreign_table_provider_from_capsule(capsule)?;
1117+
let provider: Arc<dyn TableProvider> = Arc::new(provider);
11201118

1121-
let df = self.ctx.read_table(Arc::new(provider))?;
1119+
let df = self.ctx.read_table(provider)?;
11221120
Ok(PyDataFrame::new(df))
11231121
} else {
11241122
Err(crate::errors::PyDataFusionError::Common(

src/udtf.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,10 @@ use std::sync::Arc;
2121
use crate::dataframe::PyTableProvider;
2222
use crate::errors::{py_datafusion_err, to_datafusion_err};
2323
use crate::expr::PyExpr;
24-
use crate::utils::validate_pycapsule;
24+
use crate::utils::{foreign_table_provider_from_capsule, validate_pycapsule};
2525
use datafusion::catalog::{TableFunctionImpl, TableProvider};
2626
use datafusion::error::Result as DataFusionResult;
2727
use datafusion::logical_expr::Expr;
28-
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
2928
use datafusion_ffi::udtf::{FFI_TableFunction, ForeignTableFunction};
3029
use pyo3::exceptions::PyNotImplementedError;
3130
use pyo3::types::{PyCapsule, PyTuple};
@@ -102,12 +101,10 @@ fn call_python_table_function(
102101
if provider.hasattr("__datafusion_table_provider__")? {
103102
let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
104103
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
105-
validate_pycapsule(capsule, "datafusion_table_provider")?;
104+
let provider = foreign_table_provider_from_capsule(capsule)?;
105+
let provider: Arc<dyn TableProvider> = Arc::new(provider);
106106

107-
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
108-
let provider: ForeignTableProvider = provider.into();
109-
110-
Ok(Arc::new(provider) as Arc<dyn TableProvider>)
107+
Ok(provider)
111108
} else {
112109
Err(PyNotImplementedError::new_err(
113110
"__datafusion_table_provider__ does not exist on Table Provider object.",

src/utils.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use crate::{
2323
use datafusion::{
2424
common::ScalarValue, execution::context::SessionContext, logical_expr::Volatility,
2525
};
26+
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
2627
use pyo3::prelude::*;
2728
use pyo3::{exceptions::PyValueError, types::PyCapsule};
2829
use std::{future::Future, sync::OnceLock, time::Duration};
@@ -116,6 +117,13 @@ pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyRe
116117
Ok(())
117118
}
118119

120+
/// Builds a `ForeignTableProvider` from a Python capsule.
///
/// The capsule is first passed to `validate_pycapsule` together with the name
/// `"datafusion_table_provider"`; any error from that check is returned to the
/// caller via `?`. On success the capsule's contents are borrowed as an
/// `FFI_TableProvider` and converted into a `ForeignTableProvider` with `.into()`.
///
/// # Errors
///
/// Returns the `PyErr` produced by `validate_pycapsule` when validation fails.
pub(crate) fn foreign_table_provider_from_capsule(
121+
capsule: &Bound<PyCapsule>,
122+
) -> PyResult<ForeignTableProvider> {
123+
validate_pycapsule(capsule, "datafusion_table_provider")?;
124+
// SAFETY (NOTE(review)): soundness relies on the capsule really holding an
// `FFI_TableProvider`. The only guard visible here is the preceding
// `validate_pycapsule` call; presumably it verifies the capsule name — confirm.
Ok(unsafe { capsule.reference::<FFI_TableProvider>() }.into())
125+
}
126+
119127
pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult<ScalarValue> {
120128
// convert Python object to PyScalarValue to ScalarValue
121129

0 commit comments

Comments
 (0)