@@ -45,7 +45,10 @@ use crate::udaf::PyAggregateUDF;
45
45
use crate :: udf:: PyScalarUDF ;
46
46
use crate :: udtf:: PyTableFunction ;
47
47
use crate :: udwf:: PyWindowUDF ;
48
- use crate :: utils:: { get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future} ;
48
+ use crate :: utils:: {
49
+ foreign_table_provider_from_capsule, get_global_ctx, get_tokio_runtime, validate_pycapsule,
50
+ wait_for_future,
51
+ } ;
49
52
use datafusion:: arrow:: datatypes:: { DataType , Schema , SchemaRef } ;
50
53
use datafusion:: arrow:: pyarrow:: PyArrowType ;
51
54
use datafusion:: arrow:: record_batch:: RecordBatch ;
@@ -71,7 +74,6 @@ use datafusion::prelude::{
71
74
AvroReadOptions , CsvReadOptions , DataFrame , NdJsonReadOptions , ParquetReadOptions ,
72
75
} ;
73
76
use datafusion_ffi:: catalog_provider:: { FFI_CatalogProvider , ForeignCatalogProvider } ;
74
- use datafusion_ffi:: table_provider:: { FFI_TableProvider , ForeignTableProvider } ;
75
77
use pyo3:: types:: { PyCapsule , PyDict , PyList , PyTuple , PyType } ;
76
78
use pyo3:: IntoPyObjectExt ;
77
79
use tokio:: task:: JoinHandle ;
@@ -654,12 +656,10 @@ impl PySessionContext {
654
656
if provider. hasattr ( "__datafusion_table_provider__" ) ? {
655
657
let capsule = provider. getattr ( "__datafusion_table_provider__" ) ?. call0 ( ) ?;
656
658
let capsule = capsule. downcast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) ?;
657
- validate_pycapsule ( capsule, "datafusion_table_provider" ) ?;
658
-
659
- let provider = unsafe { capsule. reference :: < FFI_TableProvider > ( ) } ;
660
- let provider: ForeignTableProvider = provider. into ( ) ;
659
+ let provider = foreign_table_provider_from_capsule ( capsule) ?;
660
+ let provider: Arc < dyn TableProvider > = Arc :: new ( provider) ;
661
661
662
- let _ = self . ctx . register_table ( name, Arc :: new ( provider) ) ?;
662
+ let _ = self . ctx . register_table ( name, provider) ?;
663
663
664
664
Ok ( ( ) )
665
665
} else {
@@ -1113,12 +1113,10 @@ impl PySessionContext {
1113
1113
if table. hasattr ( "__datafusion_table_provider__" ) ? {
1114
1114
let capsule = table. getattr ( "__datafusion_table_provider__" ) ?. call0 ( ) ?;
1115
1115
let capsule = capsule. downcast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) ?;
1116
- validate_pycapsule ( capsule, "datafusion_table_provider" ) ?;
1117
-
1118
- let provider = unsafe { capsule. reference :: < FFI_TableProvider > ( ) } ;
1119
- let provider: ForeignTableProvider = provider. into ( ) ;
1116
+ let provider = foreign_table_provider_from_capsule ( capsule) ?;
1117
+ let provider: Arc < dyn TableProvider > = Arc :: new ( provider) ;
1120
1118
1121
- let df = self . ctx . read_table ( Arc :: new ( provider) ) ?;
1119
+ let df = self . ctx . read_table ( provider) ?;
1122
1120
Ok ( PyDataFrame :: new ( df) )
1123
1121
} else {
1124
1122
Err ( crate :: errors:: PyDataFusionError :: Common (
0 commit comments