@@ -45,7 +45,10 @@ use crate::udaf::PyAggregateUDF;
4545use crate :: udf:: PyScalarUDF ;
4646use crate :: udtf:: PyTableFunction ;
4747use crate :: udwf:: PyWindowUDF ;
48- use crate :: utils:: { get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future} ;
48+ use crate :: utils:: {
49+ foreign_table_provider_from_capsule, get_global_ctx, get_tokio_runtime, validate_pycapsule,
50+ wait_for_future,
51+ } ;
4952use datafusion:: arrow:: datatypes:: { DataType , Schema , SchemaRef } ;
5053use datafusion:: arrow:: pyarrow:: PyArrowType ;
5154use datafusion:: arrow:: record_batch:: RecordBatch ;
@@ -71,7 +74,6 @@ use datafusion::prelude::{
7174 AvroReadOptions , CsvReadOptions , DataFrame , NdJsonReadOptions , ParquetReadOptions ,
7275} ;
7376use datafusion_ffi:: catalog_provider:: { FFI_CatalogProvider , ForeignCatalogProvider } ;
74- use datafusion_ffi:: table_provider:: { FFI_TableProvider , ForeignTableProvider } ;
7577use pyo3:: types:: { PyCapsule , PyDict , PyList , PyTuple , PyType } ;
7678use pyo3:: IntoPyObjectExt ;
7779use tokio:: task:: JoinHandle ;
@@ -654,12 +656,10 @@ impl PySessionContext {
654656 if provider. hasattr ( "__datafusion_table_provider__" ) ? {
655657 let capsule = provider. getattr ( "__datafusion_table_provider__" ) ?. call0 ( ) ?;
656658 let capsule = capsule. downcast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) ?;
657- validate_pycapsule ( capsule, "datafusion_table_provider" ) ?;
658-
659- let provider = unsafe { capsule. reference :: < FFI_TableProvider > ( ) } ;
660- let provider: ForeignTableProvider = provider. into ( ) ;
659+ let provider = foreign_table_provider_from_capsule ( capsule) ?;
660+ let provider: Arc < dyn TableProvider > = Arc :: new ( provider) ;
661661
662- let _ = self . ctx . register_table ( name, Arc :: new ( provider) ) ?;
662+ let _ = self . ctx . register_table ( name, provider) ?;
663663
664664 Ok ( ( ) )
665665 } else {
@@ -1113,12 +1113,10 @@ impl PySessionContext {
11131113 if table. hasattr ( "__datafusion_table_provider__" ) ? {
11141114 let capsule = table. getattr ( "__datafusion_table_provider__" ) ?. call0 ( ) ?;
11151115 let capsule = capsule. downcast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) ?;
1116- validate_pycapsule ( capsule, "datafusion_table_provider" ) ?;
1117-
1118- let provider = unsafe { capsule. reference :: < FFI_TableProvider > ( ) } ;
1119- let provider: ForeignTableProvider = provider. into ( ) ;
1116+ let provider = foreign_table_provider_from_capsule ( capsule) ?;
1117+ let provider: Arc < dyn TableProvider > = Arc :: new ( provider) ;
11201118
1121- let df = self . ctx . read_table ( Arc :: new ( provider) ) ?;
1119+ let df = self . ctx . read_table ( provider) ?;
11221120 Ok ( PyDataFrame :: new ( df) )
11231121 } else {
11241122 Err ( crate :: errors:: PyDataFusionError :: Common (
0 commit comments