From c18f37c24a67a631130077d784dca057a42f0b39 Mon Sep 17 00:00:00 2001 From: Alex Chi Date: Fri, 22 Nov 2024 23:40:54 -0500 Subject: [PATCH] fix(bridge): handle information schema Signed-off-by: Alex Chi --- optd-datafusion-bridge/src/lib.rs | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/optd-datafusion-bridge/src/lib.rs b/optd-datafusion-bridge/src/lib.rs index c2ee6632..a328772b 100644 --- a/optd-datafusion-bridge/src/lib.rs +++ b/optd-datafusion-bridge/src/lib.rs @@ -14,6 +14,7 @@ use std::sync::{Arc, Mutex}; use async_trait::async_trait; use datafusion::arrow::datatypes::DataType; +use datafusion::catalog::information_schema::InformationSchemaProvider; use datafusion::catalog::CatalogList; use datafusion::error::Result; use datafusion::execution::context::{QueryPlanner, SessionState}; @@ -60,8 +61,24 @@ impl DatafusionCatalog { impl Catalog for DatafusionCatalog { fn get(&self, name: &str) -> optd_datafusion_repr::properties::schema::Schema { let catalog = self.catalog.catalog("datafusion").unwrap(); - let schema = catalog.schema("public").unwrap(); - let table = futures_lite::future::block_on(schema.table(name.as_ref())).unwrap(); + let (schema, table) = if let Some((schema, table)) = name.split_once('.') { + (schema, table) + } else { + ("public", name) + }; + let schema = if schema == "information_schema" { + // This is `INFORMATION_SCHEMA` but datafusion didn't expose this constant. + // Also, note that we didn't check `session_state.is_information_schema_enabled()` so this schema is always + // available. + Arc::new(InformationSchemaProvider::new(Arc::clone(&self.catalog))) + } else if let Some(schema) = catalog.schema(schema) { + schema + } else { + panic!("schema not found in datafusion catalog: {}", schema) + }; + let Some(table) = futures_lite::future::block_on(schema.table(table.as_ref())) else { + panic!("table not found in datafusion catalog: {}", name); + }; let schema = table.schema(); let fields = schema.fields(); let mut optd_fields = Vec::with_capacity(fields.len());