Skip to content

Commit 9cd94de

Browse files
committed
Add support for creating in memory catalog and schema
1 parent 7a66a60 commit 9cd94de

File tree

5 files changed, with 72 additions and 35 deletions.

python/datafusion/catalog.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -58,6 +58,12 @@ def schema_names(self) -> set[str]:
5858
"""Returns the list of schemas in this catalog."""
5959
return self.catalog.schema_names()
6060

61+
@staticmethod
62+
def memory_catalog() -> Catalog:
63+
"""Create an in-memory catalog provider."""
64+
catalog = df_internal.catalog.RawCatalog.memory_catalog()
65+
return Catalog(catalog)
66+
6167
def schema(self, name: str = "public") -> Schema:
6268
"""Returns the database with the given ``name`` from this catalog."""
6369
schema = self.catalog.schema(name)
@@ -73,13 +79,10 @@ def database(self, name: str = "public") -> Schema:
7379
"""Returns the database with the given ``name`` from this catalog."""
7480
return self.schema(name)
7581

76-
def new_in_memory_schema(self, name: str) -> Schema:
77-
"""Create a new schema in this catalog using an in-memory provider."""
78-
self.catalog.new_in_memory_schema(name)
79-
return self.schema(name)
80-
8182
def register_schema(self, name, schema) -> Schema | None:
8283
"""Register a schema with this catalog."""
84+
if isinstance(schema, Schema):
85+
return self.catalog.register_schema(name, schema._raw_schema)
8386
return self.catalog.register_schema(name, schema)
8487

8588
def deregister_schema(self, name: str, cascade: bool = True) -> Schema | None:
@@ -98,6 +101,12 @@ def __repr__(self) -> str:
98101
"""Print a string representation of the schema."""
99102
return self._raw_schema.__repr__()
100103

104+
@staticmethod
105+
def memory_schema() -> Schema:
106+
"""Create an in-memory schema provider."""
107+
schema = df_internal.catalog.RawSchema.memory_schema()
108+
return Schema(schema)
109+
101110
def names(self) -> set[str]:
102111
"""This is an alias for `table_names`."""
103112
return self.table_names()

python/datafusion/context.py

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -759,16 +759,14 @@ def catalog_names(self) -> set[str]:
759759
"""Returns the list of catalogs in this context."""
760760
return self.ctx.catalog_names()
761761

762-
def new_in_memory_catalog(self, name: str) -> Catalog:
763-
"""Create a new catalog in this context using an in-memory provider."""
764-
self.ctx.new_in_memory_catalog(name)
765-
return self.catalog(name)
766-
767762
def register_catalog_provider(
768-
self, name: str, provider: CatalogProviderExportable
763+
self, name: str, provider: CatalogProviderExportable | Catalog
769764
) -> None:
770765
"""Register a catalog provider."""
771-
self.ctx.register_catalog_provider(name, provider)
766+
if isinstance(provider, Catalog):
767+
self.ctx.register_catalog_provider(name, provider.catalog)
768+
else:
769+
self.ctx.register_catalog_provider(name, provider)
772770

773771
def register_table_provider(
774772
self, name: str, provider: TableProviderExportable

python/tests/test_catalog.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -106,6 +106,24 @@ def test_python_catalog_provider(ctx: SessionContext):
106106
assert my_catalog.schema_names() == {"second_schema"}
107107

108108

109+
def test_in_memory_providers(ctx: SessionContext):
110+
catalog = dfn.catalog.Catalog.memory_catalog()
111+
ctx.register_catalog_provider("in_mem_catalog", catalog)
112+
113+
assert ctx.catalog_names() == {"datafusion", "in_mem_catalog"}
114+
115+
schema = dfn.catalog.Schema.memory_schema()
116+
catalog.register_schema("in_mem_schema", schema)
117+
118+
schema.register_table("my_table", create_dataset())
119+
120+
batches = ctx.sql("select * from in_mem_catalog.in_mem_schema.my_table").collect()
121+
122+
assert len(batches) == 1
123+
assert batches[0].column(0) == pa.array([1, 2, 3])
124+
assert batches[0].column(1) == pa.array([4, 5, 6])
125+
126+
109127
def test_python_schema_provider(ctx: SessionContext):
110128
catalog = ctx.catalog()
111129

src/catalog.rs

Lines changed: 22 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ use crate::dataset::Dataset;
1919
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
2020
use crate::utils::{validate_pycapsule, wait_for_future};
2121
use async_trait::async_trait;
22-
use datafusion::catalog::MemorySchemaProvider;
22+
use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider};
2323
use datafusion::common::DataFusionError;
2424
use datafusion::{
2525
arrow::pyarrow::ToPyArrow,
@@ -37,16 +37,19 @@ use std::collections::HashSet;
3737
use std::sync::Arc;
3838

3939
#[pyclass(name = "RawCatalog", module = "datafusion.catalog", subclass)]
40+
#[derive(Clone)]
4041
pub struct PyCatalog {
4142
pub catalog: Arc<dyn CatalogProvider>,
4243
}
4344

4445
#[pyclass(name = "RawSchema", module = "datafusion.catalog", subclass)]
46+
#[derive(Clone)]
4547
pub struct PySchema {
4648
pub schema: Arc<dyn SchemaProvider>,
4749
}
4850

4951
#[pyclass(name = "RawTable", module = "datafusion.catalog", subclass)]
52+
#[derive(Clone)]
5053
pub struct PyTable {
5154
pub table: Arc<dyn TableProvider>,
5255
}
@@ -82,6 +85,13 @@ impl PyCatalog {
8285
catalog_provider.into()
8386
}
8487

88+
#[staticmethod]
89+
fn memory_catalog() -> Self {
90+
let catalog_provider =
91+
Arc::new(MemoryCatalogProvider::default()) as Arc<dyn CatalogProvider>;
92+
catalog_provider.into()
93+
}
94+
8595
fn schema_names(&self) -> HashSet<String> {
8696
self.catalog.schema_names().into_iter().collect()
8797
}
@@ -106,16 +116,6 @@ impl PyCatalog {
106116
})
107117
}
108118

109-
fn new_in_memory_schema(&mut self, name: &str) -> PyResult<()> {
110-
let schema = Arc::new(MemorySchemaProvider::new()) as Arc<dyn SchemaProvider>;
111-
let _ = self
112-
.catalog
113-
.register_schema(name, schema)
114-
.map_err(py_datafusion_err)?;
115-
116-
Ok(())
117-
}
118-
119119
fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> {
120120
let provider = if schema_provider.hasattr("__datafusion_schema_provider__")? {
121121
let capsule = schema_provider
@@ -128,8 +128,11 @@ impl PyCatalog {
128128
let provider: ForeignSchemaProvider = provider.into();
129129
Arc::new(provider) as Arc<dyn SchemaProvider>
130130
} else {
131-
let provider = RustWrappedPySchemaProvider::new(schema_provider.into());
132-
Arc::new(provider) as Arc<dyn SchemaProvider>
131+
match schema_provider.extract::<PySchema>() {
132+
Ok(py_schema) => py_schema.schema,
133+
Err(_) => Arc::new(RustWrappedPySchemaProvider::new(schema_provider.into()))
134+
as Arc<dyn SchemaProvider>,
135+
}
133136
};
134137

135138
let _ = self
@@ -165,6 +168,12 @@ impl PySchema {
165168
schema_provider.into()
166169
}
167170

171+
#[staticmethod]
172+
fn memory_schema() -> Self {
173+
let schema_provider = Arc::new(MemorySchemaProvider::default()) as Arc<dyn SchemaProvider>;
174+
schema_provider.into()
175+
}
176+
168177
#[getter]
169178
fn table_names(&self) -> HashSet<String> {
170179
self.schema.table_names().into_iter().collect()

src/context.rs

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,7 @@ use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_f
4949
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
5050
use datafusion::arrow::pyarrow::PyArrowType;
5151
use datafusion::arrow::record_batch::RecordBatch;
52-
use datafusion::catalog::{CatalogProvider, MemoryCatalogProvider};
52+
use datafusion::catalog::CatalogProvider;
5353
use datafusion::common::TableReference;
5454
use datafusion::common::{exec_err, ScalarValue};
5555
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
@@ -612,13 +612,6 @@ impl PySessionContext {
612612
Ok(())
613613
}
614614

615-
pub fn new_in_memory_catalog(&mut self, name: &str) -> PyResult<()> {
616-
let catalog = Arc::new(MemoryCatalogProvider::new()) as Arc<dyn CatalogProvider>;
617-
let _ = self.ctx.register_catalog(name, catalog);
618-
619-
Ok(())
620-
}
621-
622615
pub fn register_catalog_provider(
623616
&mut self,
624617
name: &str,
@@ -635,8 +628,18 @@ impl PySessionContext {
635628
let provider: ForeignCatalogProvider = provider.into();
636629
Arc::new(provider) as Arc<dyn CatalogProvider>
637630
} else {
638-
let provider = RustWrappedPyCatalogProvider::new(provider.into());
639-
Arc::new(provider) as Arc<dyn CatalogProvider>
631+
println!("Provider has type {}", provider.get_type());
632+
match provider.extract::<PyCatalog>() {
633+
Ok(py_catalog) => {
634+
println!("registering an existing PyCatalog");
635+
py_catalog.catalog
636+
}
637+
Err(_) => {
638+
println!("registering a rust wrapped catalog provider");
639+
Arc::new(RustWrappedPyCatalogProvider::new(provider.into()))
640+
as Arc<dyn CatalogProvider>
641+
}
642+
}
640643
};
641644

642645
let _ = self.ctx.register_catalog(name, provider);

0 commit comments

Comments
 (0)