|
18 | 18 | use crate::dataset::Dataset; |
19 | 19 | use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult}; |
20 | 20 | use crate::utils::{ |
21 | | - foreign_table_provider_from_capsule, get_tokio_runtime, validate_pycapsule, wait_for_future, |
| 21 | + get_tokio_runtime, try_table_provider_from_object, validate_pycapsule, wait_for_future, |
22 | 22 | }; |
23 | 23 | use async_trait::async_trait; |
24 | 24 | use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider}; |
@@ -201,13 +201,7 @@ impl PySchema { |
201 | 201 | fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> { |
202 | 202 | let provider = if let Ok(py_table) = table_provider.extract::<PyTable>() { |
203 | 203 | py_table.table |
204 | | - } else if table_provider.hasattr("__datafusion_table_provider__")? { |
205 | | - let capsule = table_provider |
206 | | - .getattr("__datafusion_table_provider__")? |
207 | | - .call0()?; |
208 | | - let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?; |
209 | | - let provider = foreign_table_provider_from_capsule(capsule)?; |
210 | | - let provider: Arc<dyn TableProvider> = Arc::new(provider); |
| 204 | + } else if let Some(provider) = try_table_provider_from_object(&table_provider)? { |
211 | 205 | provider |
212 | 206 | } else { |
213 | 207 | let py = table_provider.py(); |
@@ -320,12 +314,7 @@ impl RustWrappedPySchemaProvider { |
320 | 314 | return Ok(Some(inner_table.table)); |
321 | 315 | } |
322 | 316 |
|
323 | | - if py_table.hasattr("__datafusion_table_provider__")? { |
324 | | - let capsule = py_table.getattr("__datafusion_table_provider__")?.call0()?; |
325 | | - let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?; |
326 | | - let provider = foreign_table_provider_from_capsule(capsule)?; |
327 | | - let provider: Arc<dyn TableProvider> = Arc::new(provider); |
328 | | - |
| 317 | + if let Some(provider) = try_table_provider_from_object(&py_table)? { |
329 | 318 | Ok(Some(provider)) |
330 | 319 | } else { |
331 | 320 | if let Ok(inner_table) = py_table.getattr("table") { |
|
0 commit comments