@@ -19,7 +19,7 @@ use crate::dataset::Dataset;
1919use crate :: errors:: { py_datafusion_err, to_datafusion_err, PyDataFusionError , PyDataFusionResult } ;
2020use crate :: utils:: { validate_pycapsule, wait_for_future} ;
2121use async_trait:: async_trait;
22- use datafusion:: catalog:: MemorySchemaProvider ;
22+ use datafusion:: catalog:: { MemoryCatalogProvider , MemorySchemaProvider } ;
2323use datafusion:: common:: DataFusionError ;
2424use datafusion:: {
2525 arrow:: pyarrow:: ToPyArrow ,
@@ -37,16 +37,19 @@ use std::collections::HashSet;
3737use std:: sync:: Arc ;
3838
3939#[ pyclass( name = "RawCatalog" , module = "datafusion.catalog" , subclass) ]
40+ #[ derive( Clone ) ]
4041pub struct PyCatalog {
4142 pub catalog : Arc < dyn CatalogProvider > ,
4243}
4344
4445#[ pyclass( name = "RawSchema" , module = "datafusion.catalog" , subclass) ]
46+ #[ derive( Clone ) ]
4547pub struct PySchema {
4648 pub schema : Arc < dyn SchemaProvider > ,
4749}
4850
4951#[ pyclass( name = "RawTable" , module = "datafusion.catalog" , subclass) ]
52+ #[ derive( Clone ) ]
5053pub struct PyTable {
5154 pub table : Arc < dyn TableProvider > ,
5255}
@@ -82,6 +85,13 @@ impl PyCatalog {
8285 catalog_provider. into ( )
8386 }
8487
88+ #[ staticmethod]
89+ fn memory_catalog ( ) -> Self {
90+ let catalog_provider =
91+ Arc :: new ( MemoryCatalogProvider :: default ( ) ) as Arc < dyn CatalogProvider > ;
92+ catalog_provider. into ( )
93+ }
94+
8595 fn schema_names ( & self ) -> HashSet < String > {
8696 self . catalog . schema_names ( ) . into_iter ( ) . collect ( )
8797 }
@@ -106,16 +116,6 @@ impl PyCatalog {
106116 } )
107117 }
108118
109- fn new_in_memory_schema ( & mut self , name : & str ) -> PyResult < ( ) > {
110- let schema = Arc :: new ( MemorySchemaProvider :: new ( ) ) as Arc < dyn SchemaProvider > ;
111- let _ = self
112- . catalog
113- . register_schema ( name, schema)
114- . map_err ( py_datafusion_err) ?;
115-
116- Ok ( ( ) )
117- }
118-
119119 fn register_schema ( & self , name : & str , schema_provider : Bound < ' _ , PyAny > ) -> PyResult < ( ) > {
120120 let provider = if schema_provider. hasattr ( "__datafusion_schema_provider__" ) ? {
121121 let capsule = schema_provider
@@ -128,8 +128,11 @@ impl PyCatalog {
128128 let provider: ForeignSchemaProvider = provider. into ( ) ;
129129 Arc :: new ( provider) as Arc < dyn SchemaProvider >
130130 } else {
131- let provider = RustWrappedPySchemaProvider :: new ( schema_provider. into ( ) ) ;
132- Arc :: new ( provider) as Arc < dyn SchemaProvider >
131+ match schema_provider. extract :: < PySchema > ( ) {
132+ Ok ( py_schema) => py_schema. schema ,
133+ Err ( _) => Arc :: new ( RustWrappedPySchemaProvider :: new ( schema_provider. into ( ) ) )
134+ as Arc < dyn SchemaProvider > ,
135+ }
133136 } ;
134137
135138 let _ = self
@@ -165,6 +168,12 @@ impl PySchema {
165168 schema_provider. into ( )
166169 }
167170
171+ #[ staticmethod]
172+ fn memory_schema ( ) -> Self {
173+ let schema_provider = Arc :: new ( MemorySchemaProvider :: default ( ) ) as Arc < dyn SchemaProvider > ;
174+ schema_provider. into ( )
175+ }
176+
168177 #[ getter]
169178 fn table_names ( & self ) -> HashSet < String > {
170179 self . schema . table_names ( ) . into_iter ( ) . collect ( )
0 commit comments