diff --git a/datafusion-postgres-cli/src/main.rs b/datafusion-postgres-cli/src/main.rs
index eb76ac8..44e33fc 100644
--- a/datafusion-postgres-cli/src/main.rs
+++ b/datafusion-postgres-cli/src/main.rs
@@ -5,7 +5,7 @@ use std::sync::Arc;
 use datafusion::execution::options::{
     ArrowReadOptions, AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions,
 };
-use datafusion::prelude::SessionContext;
+use datafusion::prelude::{SessionConfig, SessionContext};
 use datafusion_postgres::{serve, ServerOptions}; // Assuming the crate name is `datafusion_postgres`
 use structopt::StructOpt;
 
@@ -179,7 +179,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let mut opts = Opt::from_args();
     opts.include_directory_files()?;
 
-    let session_context = SessionContext::new();
+    let session_config = SessionConfig::new().with_information_schema(true);
+    let session_context = SessionContext::new_with_config(session_config);
 
     setup_session_context(&session_context, &opts).await?;
 
diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs
index 8a3f8de..a77f92a 100644
--- a/datafusion-postgres/src/handlers.rs
+++ b/datafusion-postgres/src/handlers.rs
@@ -19,7 +19,6 @@ use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type};
 use tokio::sync::Mutex;
 
 use crate::datatypes;
-use crate::information_schema::{columns_df, schemata_df, tables_df};
 use pgwire::error::{PgWireError, PgWireResult};
 
 pub struct HandlerFactory(pub Arc<DfSessionService>);
@@ -91,31 +90,6 @@ impl DfSessionService {
         Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
     }
 
-    // Mock pg_namespace response
-    async fn mock_pg_namespace<'a>(&self) -> PgWireResult<QueryResponse<'a>> {
-        let fields = Arc::new(vec![FieldInfo::new(
-            "nspname".to_string(),
-            None,
-            None,
-            Type::VARCHAR,
-            FieldFormat::Text,
-        )]);
-
-        let fields_ref = fields.clone();
-        let rows = self
-            .session_context
-            .catalog_names()
-            .into_iter()
-            .map(move |name| {
-                let mut encoder = pgwire::api::results::DataRowEncoder::new(fields_ref.clone());
-                encoder.encode_field(&Some(&name))?; // Return catalog_name as a schema
-                encoder.finish()
-            });
-
-        let row_stream = futures::stream::iter(rows);
-        Ok(QueryResponse::new(fields.clone(), Box::pin(row_stream)))
-    }
-
     async fn try_respond_set_statements<'a>(
         &self,
         query_lower: &str,
@@ -189,39 +163,6 @@ impl DfSessionService {
             Ok(None)
         }
     }
-
-    async fn try_respond_information_schema<'a>(
-        &self,
-        query_lower: &str,
-    ) -> PgWireResult<Option<Response<'a>>> {
-        if query_lower.contains("information_schema.schemata") {
-            let df = schemata_df(&self.session_context)
-                .await
-                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
-            let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
-            return Ok(Some(Response::Query(resp)));
-        } else if query_lower.contains("information_schema.tables") {
-            let df = tables_df(&self.session_context)
-                .await
-                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
-            let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
-            return Ok(Some(Response::Query(resp)));
-        } else if query_lower.contains("information_schema.columns") {
-            let df = columns_df(&self.session_context)
-                .await
-                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
-            let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
-            return Ok(Some(Response::Query(resp)));
-        }
-
-        // Handle pg_catalog.pg_namespace for pgcli compatibility
-        if query_lower.contains("pg_catalog.pg_namespace") {
-            let resp = self.mock_pg_namespace().await?;
-            return Ok(Some(Response::Query(resp)));
-        }
-
-        Ok(None)
-    }
 }
 
 #[async_trait]
@@ -241,10 +182,6 @@ impl SimpleQueryHandler for DfSessionService {
             return Ok(vec![resp]);
         }
 
-        if let Some(resp) = self.try_respond_information_schema(&query_lower).await? {
-            return Ok(vec![resp]);
-        }
-
         let df = self
             .session_context
             .sql(query)
@@ -361,11 +298,8 @@ impl ExtendedQueryHandler for DfSessionService {
             return Ok(resp);
         }
 
-        if let Some(resp) = self.try_respond_information_schema(&query).await? {
-            return Ok(resp);
-        }
-
         let (_, plan) = &portal.statement.statement;
+
         let param_types = plan
             .get_parameter_types()
             .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
diff --git a/datafusion-postgres/src/information_schema.rs b/datafusion-postgres/src/information_schema.rs
deleted file mode 100644
index 641a85c..0000000
--- a/datafusion-postgres/src/information_schema.rs
+++ /dev/null
@@ -1,134 +0,0 @@
-use std::sync::Arc;
-
-use datafusion::arrow::array::{BooleanArray, StringArray, UInt32Array};
-use datafusion::arrow::datatypes::{DataType, Field, Schema};
-use datafusion::arrow::record_batch::RecordBatch;
-use datafusion::error::DataFusionError;
-use datafusion::prelude::{DataFrame, SessionContext};
-
-/// Creates a DataFrame for the `information_schema.schemata` view.
-pub async fn schemata_df(ctx: &SessionContext) -> Result<DataFrame, DataFusionError> {
-    let catalog = ctx.catalog(ctx.catalog_names()[0].as_str()).unwrap(); // Use default catalog
-    let schema_names: Vec<String> = catalog.schema_names();
-
-    let schema = Arc::new(Schema::new(vec![
-        Field::new("catalog_name", DataType::Utf8, false),
-        Field::new("schema_name", DataType::Utf8, false),
-        Field::new("schema_owner", DataType::Utf8, true), // Nullable, not implemented
-        Field::new("default_character_set_catalog", DataType::Utf8, true),
-        Field::new("default_character_set_schema", DataType::Utf8, true),
-        Field::new("default_character_set_name", DataType::Utf8, true),
-    ]));
-
-    let catalog_name = ctx.catalog_names()[0].clone(); // Use the first catalog name
-    let record_batch = RecordBatch::try_new(
-        schema.clone(),
-        vec![
-            Arc::new(StringArray::from(vec![catalog_name; schema_names.len()])),
-            Arc::new(StringArray::from(schema_names.clone())), // Clone to avoid move
-            Arc::new(StringArray::from(vec![None::<String>; schema_names.len()])),
-            Arc::new(StringArray::from(vec![None::<String>; schema_names.len()])),
-            Arc::new(StringArray::from(vec![None::<String>; schema_names.len()])),
-            Arc::new(StringArray::from(vec![None::<String>; schema_names.len()])),
-        ],
-    )?;
-
-    ctx.read_batch(record_batch) // Use read_batch instead of read_table
-}
-
-/// Creates a DataFrame for the `information_schema.tables` view.
-pub async fn tables_df(ctx: &SessionContext) -> Result<DataFrame, DataFusionError> {
-    let catalog = ctx.catalog(ctx.catalog_names()[0].as_str()).unwrap(); // Use default catalog
-    let mut catalog_names = Vec::new();
-    let mut schema_names = Vec::new();
-    let mut table_names = Vec::new();
-    let mut table_types = Vec::new();
-
-    for schema_name in catalog.schema_names() {
-        let schema = catalog.schema(&schema_name).unwrap();
-        for table_name in schema.table_names() {
-            catalog_names.push(ctx.catalog_names()[0].clone()); // Use the first catalog name
-            schema_names.push(schema_name.clone());
-            table_names.push(table_name.clone());
-            table_types.push("BASE TABLE".to_string()); // DataFusion only has base tables
-        }
-    }
-
-    let schema = Arc::new(Schema::new(vec![
-        Field::new("table_catalog", DataType::Utf8, false),
-        Field::new("table_schema", DataType::Utf8, false),
-        Field::new("table_name", DataType::Utf8, false),
-        Field::new("table_type", DataType::Utf8, false),
-    ]));
-
-    let record_batch = RecordBatch::try_new(
-        schema.clone(),
-        vec![
-            Arc::new(StringArray::from(catalog_names)),
-            Arc::new(StringArray::from(schema_names)),
-            Arc::new(StringArray::from(table_names)),
-            Arc::new(StringArray::from(table_types)),
-        ],
-    )?;
-
-    ctx.read_batch(record_batch) // Use read_batch instead of read_table
-}
-
-/// Creates a DataFrame for the `information_schema.columns` view.
-pub async fn columns_df(ctx: &SessionContext) -> Result<DataFrame, DataFusionError> {
-    let catalog = ctx.catalog(ctx.catalog_names()[0].as_str()).unwrap(); // Use default catalog
-    let mut catalog_names = Vec::new();
-    let mut schema_names = Vec::new();
-    let mut table_names = Vec::new();
-    let mut column_names = Vec::new();
-    let mut ordinal_positions = Vec::new();
-    let mut data_types = Vec::new();
-    let mut is_nullables = Vec::new();
-
-    for schema_name in catalog.schema_names() {
-        let schema = catalog.schema(&schema_name).unwrap();
-        for table_name in schema.table_names() {
-            let table = schema
-                .table(&table_name)
-                .await
-                .unwrap_or_else(|_| panic!("Table {} not found", table_name))
-                .unwrap(); // Unwrap the Option after handling the Result
-            let schema_ref = table.schema(); // Store SchemaRef in a variable
-            let fields = schema_ref.fields(); // Borrow fields from the stored SchemaRef
-            for (idx, field) in fields.iter().enumerate() {
-                catalog_names.push(ctx.catalog_names()[0].clone()); // Use the first catalog name
-                schema_names.push(schema_name.clone());
-                table_names.push(table_name.clone());
-                column_names.push(field.name().clone());
-                ordinal_positions.push((idx + 1) as u32); // 1-based index
-                data_types.push(field.data_type().to_string());
-                is_nullables.push(field.is_nullable());
-            }
-        }
-    }
-
-    let schema = Arc::new(Schema::new(vec![
-        Field::new("table_catalog", DataType::Utf8, false),
-        Field::new("table_schema", DataType::Utf8, false),
-        Field::new("table_name", DataType::Utf8, false),
-        Field::new("column_name", DataType::Utf8, false),
-        Field::new("ordinal_position", DataType::UInt32, false),
-        Field::new("data_type", DataType::Utf8, false),
-        Field::new("is_nullable", DataType::Boolean, false),
-    ]));
-
-    let record_batch = RecordBatch::try_new(
-        schema.clone(),
-        vec![
-            Arc::new(StringArray::from(catalog_names)),
-            Arc::new(StringArray::from(schema_names)),
-            Arc::new(StringArray::from(table_names)),
-            Arc::new(StringArray::from(column_names)),
-            Arc::new(UInt32Array::from(ordinal_positions)),
-            Arc::new(StringArray::from(data_types)),
-            Arc::new(BooleanArray::from(is_nullables)),
-        ],
-    )?;
-
-    ctx.read_batch(record_batch) // Use read_batch instead of read_table
-}
diff --git a/datafusion-postgres/src/lib.rs b/datafusion-postgres/src/lib.rs
index d7e5a17..2094d17 100644
--- a/datafusion-postgres/src/lib.rs
+++ b/datafusion-postgres/src/lib.rs
@@ -1,7 +1,6 @@
 mod datatypes;
 mod encoder;
 mod handlers;
-mod information_schema;
 
 use std::sync::Arc;
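
Note, illustrative rather than part of the patch: the removals above lean on DataFusion's built-in `information_schema` support. A minimal sketch of what the flag enables, assuming a hypothetical `users` table created with `CREATE TABLE ... AS VALUES`:

    use datafusion::error::Result;
    use datafusion::prelude::{SessionConfig, SessionContext};

    #[tokio::main]
    async fn main() -> Result<()> {
        // Same configuration the CLI now builds in main().
        let config = SessionConfig::new().with_information_schema(true);
        let ctx = SessionContext::new_with_config(config);

        // Hypothetical table so the views have rows to report.
        ctx.sql("CREATE TABLE users AS VALUES (1, 'alice'), (2, 'bob')")
            .await?;

        // Answered by DataFusion's built-in information_schema views,
        // with no interception in the pgwire handlers.
        ctx.sql("SELECT table_schema, table_name, table_type FROM information_schema.tables")
            .await?
            .show()
            .await?;

        ctx.sql("SELECT column_name, data_type, is_nullable FROM information_schema.columns WHERE table_name = 'users'")
            .await?
            .show()
            .await?;

        Ok(())
    }

With the flag set, such queries are planned and executed by DataFusion itself, which is what makes `schemata_df`/`tables_df`/`columns_df` and the `try_respond_information_schema` interception redundant. The `pg_catalog.pg_namespace` mock is dropped outright; the built-in `information_schema` does not emulate `pg_catalog`, so those probes now fall through to the regular query path.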