diff --git a/Cargo.lock b/Cargo.lock
index 28f9d33..572174a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -380,7 +380,7 @@ checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -480,7 +480,7 @@ dependencies = [
  "regex",
  "rustc-hash",
  "shlex",
- "syn 2.0.96",
+ "syn 2.0.100",
  "which",
 ]
@@ -559,7 +559,7 @@ dependencies = [
  "proc-macro-crate",
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
  "syn_derive",
 ]
@@ -1242,7 +1242,7 @@ checksum = "4800e1ff7ecf8f310887e9b54c9c444b8e215ccbc7b21c2f244cfae373b1ece7"
 dependencies = [
  "datafusion-expr",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -1353,11 +1353,14 @@ dependencies = [
 name = "datafusion-postgres"
 version = "0.3.0"
 dependencies = [
+ "arrow",
  "async-trait",
  "chrono",
  "datafusion",
  "futures",
+ "log",
  "pgwire",
+ "tokio",
 ]

 [[package]]
@@ -1396,7 +1399,7 @@ checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -1418,7 +1421,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -1564,7 +1567,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -1867,7 +1870,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -1978,7 +1981,7 @@ dependencies = [
  "proc-macro2",
  "quote",
  "regex",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -2332,9 +2335,9 @@ dependencies = [

 [[package]]
 name = "once_cell"
-version = "1.19.0"
+version = "1.21.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
+checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"

 [[package]]
 name = "ordered-float"
@@ -2563,7 +2566,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e"
 dependencies = [
  "proc-macro2",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -2705,7 +2708,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b"
 dependencies = [
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -2966,7 +2969,7 @@ checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -3052,7 +3055,7 @@ dependencies = [
  "heck 0.5.0",
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -3096,7 +3099,7 @@ checksum = "da5fc6819faabb412da764b99d3b713bb55083c11e7e0c00144d386cd6a1939c"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -3175,7 +3178,7 @@ dependencies = [
  "proc-macro2",
  "quote",
  "rustversion",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -3197,9 +3200,9 @@ dependencies = [

 [[package]]
 name = "syn"
-version = "2.0.96"
+version = "2.0.100"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80"
+checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -3215,7 +3218,7 @@ dependencies = [
  "proc-macro-error",
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]
@@ -3226,7 +3229,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -3283,7 +3286,7 @@ checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -3294,7 +3297,7 @@ checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -3368,7 +3371,7 @@ checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -3431,7 +3434,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -3470,7 +3473,7 @@ checksum = "f9534daa9fd3ed0bd911d462a37f172228077e7abf18c18a5f67199d959205f8"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -3617,7 +3620,7 @@ dependencies = [
  "once_cell",
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
  "wasm-bindgen-shared",
 ]
@@ -3639,7 +3642,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
  "wasm-bindgen-backend",
  "wasm-bindgen-shared",
 ]
@@ -3840,7 +3843,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
  "synstructure",
 ]
@@ -3862,7 +3865,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -3882,7 +3885,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
  "synstructure",
 ]
@@ -3903,7 +3906,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
@@ -3925,7 +3928,7 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.96",
+ "syn 2.0.100",
 ]

 [[package]]
diff --git a/datafusion-postgres-cli/src/main.rs b/datafusion-postgres-cli/src/main.rs
index a42317c..acdee6d 100644
--- a/datafusion-postgres-cli/src/main.rs
+++ b/datafusion-postgres-cli/src/main.rs
@@ -4,7 +4,7 @@ use datafusion::execution::options::{
     ArrowReadOptions, AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions,
 };
 use datafusion::prelude::SessionContext;
-use datafusion_postgres::{DfSessionService, HandlerFactory};
+use datafusion_postgres::{DfSessionService, HandlerFactory}; // Assuming the crate name is `datafusion_postgres`
 use pgwire::tokio::process_socket;
 use structopt::StructOpt;
 use tokio::net::TcpListener;
@@ -12,7 +12,7 @@
 #[derive(Debug, StructOpt)]
 #[structopt(
     name = "datafusion-postgres",
-    about = "A postgres interface for datatfusion. Serve any CSV/JSON/Arrow files as tables."
+ about = "A postgres interface for datafusion. Serve any CSV/JSON/Arrow files as tables." )] struct Opt { /// CSV files to register as table, using syntax `table_name:file_path` @@ -45,27 +45,30 @@ fn parse_table_def(table_def: &str) -> (&str, &str) { } #[tokio::main] -async fn main() { +async fn main() -> Result<(), Box> { let opts = Opt::from_args(); let session_context = SessionContext::new(); + // Register CSV tables for (table_name, table_path) in opts.csv_tables.iter().map(|s| parse_table_def(s.as_ref())) { session_context .register_csv(table_name, table_path, CsvReadOptions::default()) .await - .unwrap_or_else(|e| panic!("Failed to register table: {table_name}, {e}")); + .map_err(|e| format!("Failed to register CSV table '{}': {}", table_name, e))?; println!("Loaded {} as table {}", table_path, table_name); } + // Register JSON tables for (table_name, table_path) in opts.json_tables.iter().map(|s| parse_table_def(s.as_ref())) { session_context .register_json(table_name, table_path, NdJsonReadOptions::default()) .await - .unwrap_or_else(|e| panic!("Failed to register table: {table_name}, {e}")); + .map_err(|e| format!("Failed to register JSON table '{}': {}", table_name, e))?; println!("Loaded {} as table {}", table_path, table_name); } + // Register Arrow tables for (table_name, table_path) in opts .arrow_tables .iter() @@ -74,10 +77,11 @@ async fn main() { session_context .register_arrow(table_name, table_path, ArrowReadOptions::default()) .await - .unwrap_or_else(|e| panic!("Failed to register table: {table_name}, {e}")); + .map_err(|e| format!("Failed to register Arrow table '{}': {}", table_name, e))?; println!("Loaded {} as table {}", table_path, table_name); } + // Register Parquet tables for (table_name, table_path) in opts .parquet_tables .iter() @@ -86,29 +90,46 @@ async fn main() { session_context .register_parquet(table_name, table_path, ParquetReadOptions::default()) .await - .unwrap_or_else(|e| panic!("Failed to register table: {table_name}, {e}")); + .map_err(|e| format!("Failed to register Parquet table '{}': {}", table_name, e))?; println!("Loaded {} as table {}", table_path, table_name); } + // Register Avro tables for (table_name, table_path) in opts.avro_tables.iter().map(|s| parse_table_def(s.as_ref())) { session_context .register_avro(table_name, table_path, AvroReadOptions::default()) .await - .unwrap_or_else(|e| panic!("Failed to register table: {table_name}, {e}")); + .map_err(|e| format!("Failed to register Avro table '{}': {}", table_name, e))?; println!("Loaded {} as table {}", table_path, table_name); } + // Get the first catalog name from the session context + let catalog_name = session_context + .catalog_names() // Fixed: Removed .catalog_list() + .first() + .cloned(); + + // Create the handler factory with the session context and catalog name let factory = Arc::new(HandlerFactory(Arc::new(DfSessionService::new( session_context, + catalog_name, )))); + // Bind to the specified host and port let server_addr = format!("{}:{}", opts.host, opts.port); - let listener = TcpListener::bind(&server_addr).await.unwrap(); - println!("Listening to {}", server_addr); + let listener = TcpListener::bind(&server_addr).await?; + println!("Listening on {}", server_addr); + + // Accept incoming connections loop { - let incoming_socket = listener.accept().await.unwrap(); + let (socket, addr) = listener.accept().await?; let factory_ref = factory.clone(); + println!("Accepted connection from {}", addr); - tokio::spawn(async move { process_socket(incoming_socket.0, None, 
factory_ref).await }); + tokio::spawn(async move { + if let Err(e) = process_socket(socket, None, factory_ref).await { + eprintln!("Error processing socket: {}", e); + } + }); } } diff --git a/datafusion-postgres/Cargo.toml b/datafusion-postgres/Cargo.toml index 1a4ae7d..1787d7d 100644 --- a/datafusion-postgres/Cargo.toml +++ b/datafusion-postgres/Cargo.toml @@ -18,6 +18,9 @@ readme = "../README.md" [dependencies] pgwire = { workspace = true } datafusion = { workspace = true } +tokio = { version = "1.0", features = ["sync"] } +arrow = "54.2.0" futures = "0.3" async-trait = "0.1" +log = "0.4" chrono = { version = "0.4", features = ["std"] } diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 6067e37..65f741a 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -1,21 +1,26 @@ use std::collections::HashMap; use std::sync::Arc; +use arrow::datatypes::DataType; use async_trait::async_trait; -use datafusion::arrow::datatypes::DataType; use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::*; use pgwire::api::auth::noop::NoopStartupHandler; use pgwire::api::copy::NoopCopyHandler; use pgwire::api::portal::{Format, Portal}; use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; -use pgwire::api::results::{DescribePortalResponse, DescribeStatementResponse, Response}; +use pgwire::api::results::{ + DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo, QueryResponse, + Response, Tag, +}; use pgwire::api::stmt::QueryParser; use pgwire::api::stmt::StoredStatement; use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type}; -use pgwire::error::{PgWireError, PgWireResult}; +use tokio::sync::Mutex; -use crate::datatypes::{self, into_pg_type}; +use crate::datatypes; +use crate::information_schema::{columns_df, schemata_df, tables_df}; +use pgwire::error::{PgWireError, PgWireResult}; pub struct HandlerFactory(pub Arc); @@ -52,19 +57,69 @@ impl PgWireServerHandlers for HandlerFactory { pub struct DfSessionService { session_context: Arc, parser: Arc, + timezone: Arc>, + catalog_name: String, } impl DfSessionService { - pub fn new(session_context: SessionContext) -> DfSessionService { + pub fn new(session_context: SessionContext, catalog_name: Option) -> DfSessionService { let session_context = Arc::new(session_context); let parser = Arc::new(Parser { session_context: session_context.clone(), }); + let catalog_name = catalog_name.unwrap_or_else(|| { + session_context + .catalog_names() + .first() + .cloned() + .unwrap_or_else(|| "datafusion".to_string()) + }); DfSessionService { session_context, parser, + timezone: Arc::new(Mutex::new("UTC".to_string())), + catalog_name, } } + + fn mock_show_response<'a>(name: &str, value: &str) -> PgWireResult> { + let fields = vec![FieldInfo::new( + name.to_string(), + None, + None, + Type::VARCHAR, + FieldFormat::Text, + )]; + + let row = { + let mut encoder = pgwire::api::results::DataRowEncoder::new(Arc::new(fields.clone())); + encoder.encode_field(&Some(value))?; + encoder.finish() + }; + + let row_stream = futures::stream::once(async move { row }); + Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream))) + } + + // Mock pg_namespace response + async fn mock_pg_namespace<'a>(&self) -> PgWireResult> { + let fields = vec![FieldInfo::new( + "nspname".to_string(), + None, + None, + Type::VARCHAR, + FieldFormat::Text, + )]; + + let row = { + let mut encoder = 
pgwire::api::results::DataRowEncoder::new(Arc::new(fields.clone())); + encoder.encode_field(&Some(&self.catalog_name))?; // Return catalog_name as a schema + encoder.finish() + }; + + let row_stream = futures::stream::once(async move { row }); + Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream))) + } } #[async_trait] @@ -77,45 +132,113 @@ impl SimpleQueryHandler for DfSessionService { where C: ClientInfo + Unpin + Send + Sync, { - let ctx = &self.session_context; - let df = ctx - .sql(query) - .await - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let query_lower = query.to_lowercase().trim().to_string(); + log::debug!("Received query: {}", query); // Log the query for debugging + + if query_lower.starts_with("set time zone") { + let parts: Vec<&str> = query_lower.split_whitespace().collect(); + if parts.len() >= 4 { + let tz = parts[3].trim_matches('"'); + let mut timezone = self.timezone.lock().await; + *timezone = tz.to_string(); + return Ok(vec![Response::Execution(Tag::new("SET"))]); + } + return Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "42601".to_string(), + "Invalid SET TIME ZONE syntax".to_string(), + ), + ))); + } - let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?; - Ok(vec![Response::Query(resp)]) - } -} + if query_lower.starts_with("show ") { + match query_lower.as_str() { + "show time zone" => { + let timezone = self.timezone.lock().await.clone(); + let resp = Self::mock_show_response("TimeZone", &timezone)?; + return Ok(vec![Response::Query(resp)]); + } + "show server_version" => { + let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?; + return Ok(vec![Response::Query(resp)]); + } + "show transaction_isolation" => { + let resp = + Self::mock_show_response("transaction_isolation", "read uncommitted")?; + return Ok(vec![Response::Query(resp)]); + } + "show catalogs" => { + let catalogs = self.session_context.catalog_names(); + let value = catalogs.join(", "); + let resp = Self::mock_show_response("Catalogs", &value)?; + return Ok(vec![Response::Query(resp)]); + } + "show search_path" => { + let resp = Self::mock_show_response("search_path", &self.catalog_name)?; + return Ok(vec![Response::Query(resp)]); + } + _ => { + return Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "42704".to_string(), + format!("Unrecognized SHOW command: {}", query), + ), + ))); + } + } + } -pub struct Parser { - session_context: Arc, -} + if query_lower.contains("information_schema.schemata") { + let df = schemata_df(&self.session_context) + .await + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?; + return Ok(vec![Response::Query(resp)]); + } else if query_lower.contains("information_schema.tables") { + let df = tables_df(&self.session_context) + .await + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?; + return Ok(vec![Response::Query(resp)]); + } else if query_lower.contains("information_schema.columns") { + let df = columns_df(&self.session_context) + .await + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?; + return Ok(vec![Response::Query(resp)]); + } -#[async_trait] -impl QueryParser for Parser { - type Statement = LogicalPlan; + // Handle pg_catalog.pg_namespace for pgcli compatibility + if 
query_lower.contains("pg_catalog.pg_namespace") { + let resp = self.mock_pg_namespace().await?; + return Ok(vec![Response::Query(resp)]); + } - async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult { - let context = &self.session_context; - let state = context.state(); + let ctx = &self.session_context; + let qualified_query = if !query_lower.contains(&format!("{}.", self.catalog_name)) + && !query_lower.contains("information_schema") + && !query_lower.contains("pg_catalog") + && query_lower.contains("from") + { + query.replace(" FROM ", &format!(" FROM {}.", self.catalog_name)) + } else { + query.to_string() + }; - let logical_plan = state - .create_logical_plan(sql) + let df = ctx + .sql(&qualified_query) .await .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let optimised = state - .optimize(&logical_plan) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - - Ok(optimised) + let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?; + Ok(vec![Response::Query(resp)]) } } #[async_trait] impl ExtendedQueryHandler for DfSessionService { type Statement = LogicalPlan; - type QueryParser = Parser; fn query_parser(&self) -> Arc { @@ -131,7 +254,6 @@ impl ExtendedQueryHandler for DfSessionService { C: ClientInfo + Unpin + Send + Sync, { let plan = &target.statement; - let schema = plan.schema(); let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), &Format::UnifiedBinary)?; let params = plan @@ -140,8 +262,9 @@ impl ExtendedQueryHandler for DfSessionService { let mut param_types = Vec::with_capacity(params.len()); for param_type in ordered_param_types(¶ms).iter() { + // Fixed: Use ¶ms if let Some(datatype) = param_type { - let pgtype = into_pg_type(datatype)?; + let pgtype = datatypes::into_pg_type(datatype)?; param_types.push(pgtype); } else { param_types.push(Type::UNKNOWN); @@ -176,35 +299,123 @@ impl ExtendedQueryHandler for DfSessionService { where C: ClientInfo + Unpin + Send + Sync, { - let plan = &portal.statement.statement; + let query = portal + .statement + .statement + .to_string() + .to_lowercase() + .trim() + .to_string(); + log::debug!("Received extended query: {}", query); // Log for debugging + + if query.starts_with("show ") { + match query.as_str() { + "show time zone" => { + let timezone = self.timezone.lock().await.clone(); + let resp = Self::mock_show_response("TimeZone", &timezone)?; + return Ok(Response::Query(resp)); + } + "show server_version" => { + let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?; + return Ok(Response::Query(resp)); + } + "show transaction_isolation" => { + let resp = + Self::mock_show_response("transaction_isolation", "read uncommitted")?; + return Ok(Response::Query(resp)); + } + "show catalogs" => { + let catalogs = self.session_context.catalog_names(); + let value = catalogs.join(", "); + let resp = Self::mock_show_response("Catalogs", &value)?; + return Ok(Response::Query(resp)); + } + "show search_path" => { + let resp = Self::mock_show_response("search_path", &self.catalog_name)?; + return Ok(Response::Query(resp)); + } + _ => { + return Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "42704".to_string(), + format!("Unrecognized SHOW command: {}", query), + ), + ))); + } + } + } + + if query.contains("information_schema.schemata") { + let df = schemata_df(&self.session_context) + .await + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let resp = datatypes::encode_dataframe(df, &portal.result_column_format).await?; + 
return Ok(Response::Query(resp)); + } else if query.contains("information_schema.tables") { + let df = tables_df(&self.session_context) + .await + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let resp = datatypes::encode_dataframe(df, &portal.result_column_format).await?; + return Ok(Response::Query(resp)); + } else if query.contains("information_schema.columns") { + let df = columns_df(&self.session_context) + .await + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let resp = datatypes::encode_dataframe(df, &portal.result_column_format).await?; + return Ok(Response::Query(resp)); + } + + if query.contains("pg_catalog.pg_namespace") { + let resp = self.mock_pg_namespace().await?; + return Ok(Response::Query(resp)); + } + let plan = &portal.statement.statement; let param_types = plan .get_parameter_types() .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let param_values = - datatypes::deserialize_parameters(portal, &ordered_param_types(¶m_types))?; - + datatypes::deserialize_parameters(portal, &ordered_param_types(¶m_types))?; // Fixed: Use ¶m_types let plan = plan .clone() .replace_params_with_values(¶m_values) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use ¶m_values let dataframe = self .session_context .execute_logical_plan(plan) .await .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let resp = datatypes::encode_dataframe(dataframe, &portal.result_column_format).await?; Ok(Response::Query(resp)) } } +pub struct Parser { + session_context: Arc, +} + +#[async_trait] +impl QueryParser for Parser { + type Statement = LogicalPlan; + + async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult { + let context = &self.session_context; + let state = context.state(); + let logical_plan = state + .create_logical_plan(sql) + .await + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let optimised = state + .optimize(&logical_plan) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + Ok(optimised) + } +} + fn ordered_param_types(types: &HashMap>) -> Vec> { // Datafusion stores the parameters as a map. In our case, the keys will be // `$1`, `$2` etc. The values will be the parameter types. - let mut types = types.iter().collect::>(); types.sort_by(|a, b| a.0.cmp(b.0)); types.into_iter().map(|pt| pt.1.as_ref()).collect() diff --git a/datafusion-postgres/src/information_schema.rs b/datafusion-postgres/src/information_schema.rs new file mode 100644 index 0000000..641a85c --- /dev/null +++ b/datafusion-postgres/src/information_schema.rs @@ -0,0 +1,134 @@ +use std::sync::Arc; + +use datafusion::arrow::array::{BooleanArray, StringArray, UInt32Array}; +use datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::error::DataFusionError; +use datafusion::prelude::{DataFrame, SessionContext}; + +/// Creates a DataFrame for the `information_schema.schemata` view. 
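Note on the catalog-qualification fallback in SimpleQueryHandler::do_query: queries that do not already reference the configured catalog, information_schema, or pg_catalog get the catalog name spliced in after " FROM " before being handed to DataFusion. The following stand-alone sketch mirrors that rule for illustration only; `qualify` is a hypothetical helper, not code from this patch:

    // Hypothetical helper mirroring the rewrite rule used in do_query.
    fn qualify(query: &str, catalog: &str) -> String {
        let q = query.to_lowercase();
        if !q.contains(&format!("{}.", catalog))
            && !q.contains("information_schema")
            && !q.contains("pg_catalog")
            && q.contains("from")
        {
            // Only the exact token " FROM " is replaced, so the match is case-sensitive.
            query.replace(" FROM ", &format!(" FROM {}.", catalog))
        } else {
            query.to_string()
        }
    }

    fn main() {
        assert_eq!(
            qualify("SELECT * FROM users", "datafusion"),
            "SELECT * FROM datafusion.users"
        );
        // A lower-case "from" passes the check but is left untouched by the replace.
        assert_eq!(qualify("select * from users", "datafusion"), "select * from users");
    }

As the second assertion shows, lower-cased queries currently skip the rewrite; normalizing the match could be a follow-up.
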
diff --git a/datafusion-postgres/src/information_schema.rs b/datafusion-postgres/src/information_schema.rs
new file mode 100644
index 0000000..641a85c
--- /dev/null
+++ b/datafusion-postgres/src/information_schema.rs
@@ -0,0 +1,134 @@
+use std::sync::Arc;
+
+use datafusion::arrow::array::{BooleanArray, StringArray, UInt32Array};
+use datafusion::arrow::datatypes::{DataType, Field, Schema};
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::error::DataFusionError;
+use datafusion::prelude::{DataFrame, SessionContext};
+
+/// Creates a DataFrame for the `information_schema.schemata` view.
+pub async fn schemata_df(ctx: &SessionContext) -> Result<DataFrame, DataFusionError> {
+    let catalog = ctx.catalog(ctx.catalog_names()[0].as_str()).unwrap(); // Use default catalog
+    let schema_names: Vec<String> = catalog.schema_names();
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("catalog_name", DataType::Utf8, false),
+        Field::new("schema_name", DataType::Utf8, false),
+        Field::new("schema_owner", DataType::Utf8, true), // Nullable, not implemented
+        Field::new("default_character_set_catalog", DataType::Utf8, true),
+        Field::new("default_character_set_schema", DataType::Utf8, true),
+        Field::new("default_character_set_name", DataType::Utf8, true),
+    ]));
+
+    let catalog_name = ctx.catalog_names()[0].clone(); // Use the first catalog name
+    let record_batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(StringArray::from(vec![catalog_name; schema_names.len()])),
+            Arc::new(StringArray::from(schema_names.clone())), // Clone to avoid move
+            Arc::new(StringArray::from(vec![None::<String>; schema_names.len()])),
+            Arc::new(StringArray::from(vec![None::<String>; schema_names.len()])),
+            Arc::new(StringArray::from(vec![None::<String>; schema_names.len()])),
+            Arc::new(StringArray::from(vec![None::<String>; schema_names.len()])),
+        ],
+    )?;
+
+    ctx.read_batch(record_batch) // Use read_batch instead of read_table
+}
+
+/// Creates a DataFrame for the `information_schema.tables` view.
+pub async fn tables_df(ctx: &SessionContext) -> Result<DataFrame, DataFusionError> {
+    let catalog = ctx.catalog(ctx.catalog_names()[0].as_str()).unwrap(); // Use default catalog
+    let mut catalog_names = Vec::new();
+    let mut schema_names = Vec::new();
+    let mut table_names = Vec::new();
+    let mut table_types = Vec::new();
+
+    for schema_name in catalog.schema_names() {
+        let schema = catalog.schema(&schema_name).unwrap();
+        for table_name in schema.table_names() {
+            catalog_names.push(ctx.catalog_names()[0].clone()); // Use the first catalog name
+            schema_names.push(schema_name.clone());
+            table_names.push(table_name.clone());
+            table_types.push("BASE TABLE".to_string()); // DataFusion only has base tables
+        }
+    }
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("table_catalog", DataType::Utf8, false),
+        Field::new("table_schema", DataType::Utf8, false),
+        Field::new("table_name", DataType::Utf8, false),
+        Field::new("table_type", DataType::Utf8, false),
+    ]));
+
+    let record_batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(StringArray::from(catalog_names)),
+            Arc::new(StringArray::from(schema_names)),
+            Arc::new(StringArray::from(table_names)),
+            Arc::new(StringArray::from(table_types)),
+        ],
+    )?;
+
+    ctx.read_batch(record_batch) // Use read_batch instead of read_table
+}
+
+/// Creates a DataFrame for the `information_schema.columns` view.
+pub async fn columns_df(ctx: &SessionContext) -> Result<DataFrame, DataFusionError> {
+    let catalog = ctx.catalog(ctx.catalog_names()[0].as_str()).unwrap(); // Use default catalog
+    let mut catalog_names = Vec::new();
+    let mut schema_names = Vec::new();
+    let mut table_names = Vec::new();
+    let mut column_names = Vec::new();
+    let mut ordinal_positions = Vec::new();
+    let mut data_types = Vec::new();
+    let mut is_nullables = Vec::new();
+
+    for schema_name in catalog.schema_names() {
+        let schema = catalog.schema(&schema_name).unwrap();
+        for table_name in schema.table_names() {
+            let table = schema
+                .table(&table_name)
+                .await
+                .unwrap_or_else(|_| panic!("Table {} not found", table_name))
+                .unwrap(); // Unwrap the Option after handling the Result
+            let schema_ref = table.schema(); // Store SchemaRef in a variable
+            let fields = schema_ref.fields(); // Borrow fields from the stored SchemaRef
+            for (idx, field) in fields.iter().enumerate() {
+                catalog_names.push(ctx.catalog_names()[0].clone()); // Use the first catalog name
+                schema_names.push(schema_name.clone());
+                table_names.push(table_name.clone());
+                column_names.push(field.name().clone());
+                ordinal_positions.push((idx + 1) as u32); // 1-based index
+                data_types.push(field.data_type().to_string());
+                is_nullables.push(field.is_nullable());
+            }
+        }
+    }
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("table_catalog", DataType::Utf8, false),
+        Field::new("table_schema", DataType::Utf8, false),
+        Field::new("table_name", DataType::Utf8, false),
+        Field::new("column_name", DataType::Utf8, false),
+        Field::new("ordinal_position", DataType::UInt32, false),
+        Field::new("data_type", DataType::Utf8, false),
+        Field::new("is_nullable", DataType::Boolean, false),
+    ]));
+
+    let record_batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(StringArray::from(catalog_names)),
+            Arc::new(StringArray::from(schema_names)),
+            Arc::new(StringArray::from(table_names)),
+            Arc::new(StringArray::from(column_names)),
+            Arc::new(UInt32Array::from(ordinal_positions)),
+            Arc::new(StringArray::from(data_types)),
+            Arc::new(BooleanArray::from(is_nullables)),
+        ],
+    )?;
+
+    ctx.read_batch(record_batch) // Use read_batch instead of read_table
+}
diff --git a/datafusion-postgres/src/lib.rs b/datafusion-postgres/src/lib.rs
index c08044c..c560526 100644
--- a/datafusion-postgres/src/lib.rs
+++ b/datafusion-postgres/src/lib.rs
@@ -1,4 +1,5 @@
 mod datatypes;
 mod handlers;
+mod information_schema;
 
 pub use handlers::{DfSessionService, HandlerFactory, Parser};
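Usage sketch (not part of the patch): with the server running and some tables registered through the CLI options above, the new SHOW handling and information_schema views can be exercised from any PostgreSQL client. The snippet below uses the external tokio-postgres crate, which is not a dependency of this repository, and assumes a locally reachable host/port and an arbitrary user name (the server uses NoopStartupHandler, so no authentication is enforced):

    // Hypothetical client-side smoke test; crate, host, port, and user are assumptions.
    use tokio_postgres::{NoTls, SimpleQueryMessage};

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        let (client, connection) =
            tokio_postgres::connect("host=127.0.0.1 port=5432 user=postgres", NoTls).await?;
        // The connection object drives the socket; run it in the background.
        tokio::spawn(connection);

        // Simple query protocol -> SimpleQueryHandler::do_query (SHOW short-circuit).
        for msg in client.simple_query("SHOW time zone").await? {
            if let SimpleQueryMessage::Row(row) = msg {
                println!("TimeZone = {:?}", row.get(0));
            }
        }

        // Simple query protocol -> information_schema.tables handled via tables_df.
        for msg in client
            .simple_query("SELECT table_name FROM information_schema.tables")
            .await?
        {
            if let SimpleQueryMessage::Row(row) = msg {
                println!("table: {:?}", row.get(0));
            }
        }
        Ok(())
    }

Prepared statements go through ExtendedQueryHandler::do_query, which applies the same SHOW/information_schema short-circuits, but note that they must first pass through Parser::parse_sql (DataFusion's logical planner) at Parse time.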