Skip to content

Commit 03842b6

Browse files
committed
fix: replace foreign_table_provider_from_capsule with try_table_provider_from_object for improved clarity and error handling
1 parent d563352 commit 03842b6

File tree

4 files changed

+31
-37
lines changed

4 files changed

+31
-37
lines changed

src/catalog.rs

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
use crate::dataset::Dataset;
1919
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
2020
use crate::utils::{
21-
foreign_table_provider_from_capsule, get_tokio_runtime, validate_pycapsule, wait_for_future,
21+
get_tokio_runtime, try_table_provider_from_object, validate_pycapsule, wait_for_future,
2222
};
2323
use async_trait::async_trait;
2424
use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider};
@@ -201,13 +201,7 @@ impl PySchema {
201201
fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> {
202202
let provider = if let Ok(py_table) = table_provider.extract::<PyTable>() {
203203
py_table.table
204-
} else if table_provider.hasattr("__datafusion_table_provider__")? {
205-
let capsule = table_provider
206-
.getattr("__datafusion_table_provider__")?
207-
.call0()?;
208-
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
209-
let provider = foreign_table_provider_from_capsule(capsule)?;
210-
let provider: Arc<dyn TableProvider> = Arc::new(provider);
204+
} else if let Some(provider) = try_table_provider_from_object(&table_provider)? {
211205
provider
212206
} else {
213207
let py = table_provider.py();
@@ -320,12 +314,7 @@ impl RustWrappedPySchemaProvider {
320314
return Ok(Some(inner_table.table));
321315
}
322316

323-
if py_table.hasattr("__datafusion_table_provider__")? {
324-
let capsule = py_table.getattr("__datafusion_table_provider__")?.call0()?;
325-
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
326-
let provider = foreign_table_provider_from_capsule(capsule)?;
327-
let provider: Arc<dyn TableProvider> = Arc::new(provider);
328-
317+
if let Some(provider) = try_table_provider_from_object(&py_table)? {
329318
Ok(Some(provider))
330319
} else {
331320
if let Ok(inner_table) = py_table.getattr("table") {

src/context.rs

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ use crate::udf::PyScalarUDF;
4646
use crate::udtf::PyTableFunction;
4747
use crate::udwf::PyWindowUDF;
4848
use crate::utils::{
49-
foreign_table_provider_from_capsule, get_global_ctx, get_tokio_runtime, validate_pycapsule,
49+
get_global_ctx, get_tokio_runtime, try_table_provider_from_object, validate_pycapsule,
5050
wait_for_future,
5151
};
5252
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
@@ -653,12 +653,7 @@ impl PySessionContext {
653653
name: &str,
654654
provider: Bound<'_, PyAny>,
655655
) -> PyDataFusionResult<()> {
656-
if provider.hasattr("__datafusion_table_provider__")? {
657-
let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
658-
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
659-
let provider = foreign_table_provider_from_capsule(capsule)?;
660-
let provider: Arc<dyn TableProvider> = Arc::new(provider);
661-
656+
if let Some(provider) = try_table_provider_from_object(&provider)? {
662657
let _ = self.ctx.register_table(name, provider)?;
663658

664659
Ok(())
@@ -1110,12 +1105,7 @@ impl PySessionContext {
11101105
return Ok(PyDataFrame::new(df));
11111106
}
11121107

1113-
if table.hasattr("__datafusion_table_provider__")? {
1114-
let capsule = table.getattr("__datafusion_table_provider__")?.call0()?;
1115-
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
1116-
let provider = foreign_table_provider_from_capsule(capsule)?;
1117-
let provider: Arc<dyn TableProvider> = Arc::new(provider);
1118-
1108+
if let Some(provider) = try_table_provider_from_object(&table)? {
11191109
let df = self.ctx.read_table(provider)?;
11201110
Ok(PyDataFrame::new(df))
11211111
} else {

src/udtf.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use std::sync::Arc;
2121
use crate::dataframe::PyTableProvider;
2222
use crate::errors::{py_datafusion_err, to_datafusion_err};
2323
use crate::expr::PyExpr;
24-
use crate::utils::{foreign_table_provider_from_capsule, validate_pycapsule};
24+
use crate::utils::{try_table_provider_from_object, validate_pycapsule};
2525
use datafusion::catalog::{TableFunctionImpl, TableProvider};
2626
use datafusion::error::Result as DataFusionResult;
2727
use datafusion::logical_expr::Expr;
@@ -98,12 +98,7 @@ fn call_python_table_function(
9898
let provider_obj = func.call1(py, py_args)?;
9999
let provider = provider_obj.bind(py);
100100

101-
if provider.hasattr("__datafusion_table_provider__")? {
102-
let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
103-
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
104-
let provider = foreign_table_provider_from_capsule(capsule)?;
105-
let provider: Arc<dyn TableProvider> = Arc::new(provider);
106-
101+
if let Some(provider) = try_table_provider_from_object(provider)? {
107102
Ok(provider)
108103
} else {
109104
Err(PyNotImplementedError::new_err(

src/utils.rs

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,21 @@
1717

1818
use crate::{
1919
common::data_type::PyScalarValue,
20-
errors::{PyDataFusionError, PyDataFusionResult},
20+
errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult},
2121
TokioRuntime,
2222
};
2323
use datafusion::{
24-
common::ScalarValue, execution::context::SessionContext, logical_expr::Volatility,
24+
common::ScalarValue, datasource::TableProvider, execution::context::SessionContext,
25+
logical_expr::Volatility,
2526
};
2627
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
2728
use pyo3::prelude::*;
2829
use pyo3::{exceptions::PyValueError, types::PyCapsule};
29-
use std::{future::Future, sync::OnceLock, time::Duration};
30+
use std::{
31+
future::Future,
32+
sync::{Arc, OnceLock},
33+
time::Duration,
34+
};
3035
use tokio::{runtime::Runtime, time::sleep};
3136
/// Utility to get the Tokio Runtime from Python
3237
#[inline]
@@ -124,6 +129,21 @@ pub(crate) fn foreign_table_provider_from_capsule(
124129
Ok(unsafe { capsule.reference::<FFI_TableProvider>() }.into())
125130
}
126131

132+
pub(crate) fn try_table_provider_from_object(
133+
provider: &Bound<'_, PyAny>,
134+
) -> PyResult<Option<Arc<dyn TableProvider>>> {
135+
if !provider.hasattr("__datafusion_table_provider__")? {
136+
return Ok(None);
137+
}
138+
139+
let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
140+
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
141+
let provider = foreign_table_provider_from_capsule(capsule)?;
142+
let provider: Arc<dyn TableProvider> = Arc::new(provider);
143+
144+
Ok(Some(provider))
145+
}
146+
127147
pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult<ScalarValue> {
128148
// convert Python object to PyScalarValue to ScalarValue
129149

0 commit comments

Comments
 (0)