|
18 | 18 | use crate::dataset::Dataset;
|
19 | 19 | use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
|
20 | 20 | use crate::utils::{
|
21 |
| - foreign_table_provider_from_capsule, get_tokio_runtime, validate_pycapsule, wait_for_future, |
| 21 | + get_tokio_runtime, try_table_provider_from_object, validate_pycapsule, wait_for_future, |
22 | 22 | };
|
23 | 23 | use async_trait::async_trait;
|
24 | 24 | use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider};
|
@@ -201,13 +201,7 @@ impl PySchema {
|
201 | 201 | fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> {
|
202 | 202 | let provider = if let Ok(py_table) = table_provider.extract::<PyTable>() {
|
203 | 203 | py_table.table
|
204 |
| - } else if table_provider.hasattr("__datafusion_table_provider__")? { |
205 |
| - let capsule = table_provider |
206 |
| - .getattr("__datafusion_table_provider__")? |
207 |
| - .call0()?; |
208 |
| - let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?; |
209 |
| - let provider = foreign_table_provider_from_capsule(capsule)?; |
210 |
| - let provider: Arc<dyn TableProvider> = Arc::new(provider); |
| 204 | + } else if let Some(provider) = try_table_provider_from_object(&table_provider)? { |
211 | 205 | provider
|
212 | 206 | } else {
|
213 | 207 | let py = table_provider.py();
|
@@ -320,12 +314,7 @@ impl RustWrappedPySchemaProvider {
|
320 | 314 | return Ok(Some(inner_table.table));
|
321 | 315 | }
|
322 | 316 |
|
323 |
| - if py_table.hasattr("__datafusion_table_provider__")? { |
324 |
| - let capsule = py_table.getattr("__datafusion_table_provider__")?.call0()?; |
325 |
| - let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?; |
326 |
| - let provider = foreign_table_provider_from_capsule(capsule)?; |
327 |
| - let provider: Arc<dyn TableProvider> = Arc::new(provider); |
328 |
| - |
| 317 | + if let Some(provider) = try_table_provider_from_object(&py_table)? { |
329 | 318 | Ok(Some(provider))
|
330 | 319 | } else {
|
331 | 320 | if let Ok(inner_table) = py_table.getattr("table") {
|
|
0 commit comments