diff --git a/.gitignore b/.gitignore index 54271a8..35a0847 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,8 @@ .envrc .vscode .aider* -/test_env \ No newline at end of file +/test_env + +# OS +.DS_Store +Thumbs.db \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index f6e55a8..02a83b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1510,6 +1510,7 @@ dependencies = [ "log", "pgwire", "postgres-types", + "regex", "rust_decimal", "rustls-pemfile", "rustls-pki-types", diff --git a/arrow-pg/src/datatypes/df.rs b/arrow-pg/src/datatypes/df.rs index af98b99..c85235e 100644 --- a/arrow-pg/src/datatypes/df.rs +++ b/arrow-pg/src/datatypes/df.rs @@ -104,7 +104,12 @@ where let value = portal.parameter::(i, &pg_type)?; deserialized_params.push(ScalarValue::Int64(value)); } - Type::TEXT | Type::VARCHAR => { + Type::TEXT => { + let value = portal.parameter::(i, &pg_type)?; + // Use Utf8View for TEXT type to match DataFusion's internal schema expectations + deserialized_params.push(ScalarValue::Utf8View(value)); + } + Type::VARCHAR => { let value = portal.parameter::(i, &pg_type)?; deserialized_params.push(ScalarValue::Utf8(value)); } @@ -236,7 +241,17 @@ where &DataType::Float64, ))); } - Type::TEXT_ARRAY | Type::VARCHAR_ARRAY => { + Type::TEXT_ARRAY => { + let value = portal.parameter::>>(i, &pg_type)?; + let scalar_values: Vec = value.map_or(Vec::new(), |v| { + v.into_iter().map(ScalarValue::Utf8View).collect() + }); + deserialized_params.push(ScalarValue::List(ScalarValue::new_list_nullable( + &scalar_values, + &DataType::Utf8View, + ))); + } + Type::VARCHAR_ARRAY => { let value = portal.parameter::>>(i, &pg_type)?; let scalar_values: Vec = value.map_or(Vec::new(), |v| { v.into_iter().map(ScalarValue::Utf8).collect() @@ -262,6 +277,18 @@ where // Store MAC addresses as strings for now deserialized_params.push(ScalarValue::Utf8(value)); } + Type::UNKNOWN => { + // For unknown types, try to deserialize as integer first, then fallback to text + // This handles cases like NULL arithmetic where DataFusion can't infer types + match portal.parameter::(i, &Type::INT4) { + Ok(value) => deserialized_params.push(ScalarValue::Int32(value)), + Err(_) => { + // Fallback to text if integer parsing fails + let value = portal.parameter::(i, &Type::TEXT)?; + deserialized_params.push(ScalarValue::Utf8View(value)); + } + } + } // TODO: add more advanced types (composite types, ranges, etc.) _ => { return Err(PgWireError::UserError(Box::new(ErrorInfo::new( diff --git a/datafusion-postgres-cli/src/main.rs b/datafusion-postgres-cli/src/main.rs index a05fd7d..fecefde 100644 --- a/datafusion-postgres-cli/src/main.rs +++ b/datafusion-postgres-cli/src/main.rs @@ -179,7 +179,7 @@ async fn setup_session_context( } // Register pg_catalog - setup_pg_catalog(session_context, "datafusion")?; + setup_pg_catalog(session_context, "datafusion").await?; Ok(()) } diff --git a/datafusion-postgres/Cargo.toml b/datafusion-postgres/Cargo.toml index c87f968..6160d63 100644 --- a/datafusion-postgres/Cargo.toml +++ b/datafusion-postgres/Cargo.toml @@ -22,6 +22,7 @@ futures.workspace = true getset = "0.1" log = "0.4" pgwire = { workspace = true, features = ["server-api-ring", "scram"] } +regex = "1" postgres-types.workspace = true rust_decimal.workspace = true tokio = { version = "1.47", features = ["sync", "net"] } diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 24d5beb..0b47486 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -42,9 +42,16 @@ pub struct HandlerFactory { } impl HandlerFactory { - pub fn new(session_context: Arc, auth_manager: Arc) -> Self { - let session_service = - Arc::new(DfSessionService::new(session_context, auth_manager.clone())); + pub fn new( + session_context: Arc, + auth_manager: Arc, + query_timeout: std::time::Duration, + ) -> Self { + let session_service = Arc::new(DfSessionService::new( + session_context, + auth_manager.clone(), + query_timeout, + )); HandlerFactory { session_service } } } @@ -70,12 +77,14 @@ pub struct DfSessionService { timezone: Arc>, transaction_state: Arc>, auth_manager: Arc, + query_timeout: std::time::Duration, } impl DfSessionService { pub fn new( session_context: Arc, auth_manager: Arc, + query_timeout: std::time::Duration, ) -> DfSessionService { let parser = Arc::new(Parser { session_context: session_context.clone(), @@ -86,6 +95,7 @@ impl DfSessionService { timezone: Arc::new(Mutex::new("UTC".to_string())), transaction_state: Arc::new(Mutex::new(TransactionState::None)), auth_manager, + query_timeout, } } @@ -305,8 +315,8 @@ impl DfSessionService { Ok(Some(Response::Query(resp))) } "show search_path" => { - let default_catalog = "datafusion"; - let resp = Self::mock_show_response("search_path", default_catalog)?; + let search_path = "\"$user\", public, pg_catalog"; + let resp = Self::mock_show_response("search_path", search_path)?; Ok(Some(Response::Query(resp))) } _ => Err(PgWireError::UserError(Box::new( @@ -321,6 +331,57 @@ impl DfSessionService { Ok(None) } } + + async fn try_respond_maintenance_statements<'a>( + &self, + query_lower: &str, + ) -> PgWireResult>> { + let query_trimmed = query_lower.trim().trim_end_matches(';'); + match query_trimmed { + // Commands that asyncpg commonly sends during cleanup/reset + "unlisten *" | "unlisten" => { + // UNLISTEN is for PostgreSQL LISTEN/NOTIFY feature + // Return success but do nothing + Ok(Some(Response::Execution(Tag::new("UNLISTEN")))) + } + "reset all" => { + // RESET ALL clears all session settings + // Return success but do nothing (we don't persist session settings anyway) + Ok(Some(Response::Execution(Tag::new("RESET")))) + } + "discard all" => { + // DISCARD ALL cleans up session state + // Return success but do nothing + Ok(Some(Response::Execution(Tag::new("DISCARD")))) + } + "deallocate all" => { + // DEALLOCATE ALL removes all prepared statements + // Return success but do nothing (we don't persist prepared statements) + Ok(Some(Response::Execution(Tag::new("DEALLOCATE")))) + } + _ if query_trimmed.starts_with("listen ") => { + // LISTEN is for PostgreSQL LISTEN/NOTIFY feature + // Return success but do nothing + Ok(Some(Response::Execution(Tag::new("LISTEN")))) + } + _ if query_trimmed.starts_with("unlisten ") => { + // UNLISTEN for specific channel + // Return success but do nothing + Ok(Some(Response::Execution(Tag::new("UNLISTEN")))) + } + _ if query_trimmed.starts_with("deallocate ") => { + // DEALLOCATE for specific prepared statement + // Return success but do nothing + Ok(Some(Response::Execution(Tag::new("DEALLOCATE")))) + } + _ if query_trimmed.starts_with("reset ") => { + // RESET for specific setting + // Return success but do nothing + Ok(Some(Response::Execution(Tag::new("RESET")))) + } + _ => Ok(None), + } + } } #[async_trait] @@ -360,6 +421,14 @@ impl SimpleQueryHandler for DfSessionService { return Ok(vec![resp]); } + // Handle PostgreSQL cleanup/maintenance commands that can be safely ignored + if let Some(resp) = self + .try_respond_maintenance_statements(&query_lower) + .await? + { + return Ok(vec![resp]); + } + // Check if we're in a failed transaction and block non-transaction commands { let state = self.transaction_state.lock().await; @@ -374,7 +443,23 @@ impl SimpleQueryHandler for DfSessionService { } } - let df_result = self.session_context.sql(query).await; + // Add query timeout for simple queries + let df_result = + match tokio::time::timeout(self.query_timeout, self.session_context.sql(query)).await { + Ok(result) => result, + Err(_) => { + return Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "57014".to_string(), // PostgreSQL query_canceled error code + format!( + "Query execution timeout after {} seconds", + self.query_timeout.as_secs() + ), + ), + ))); + } + }; // Handle query execution errors and transaction state let df = match df_result { @@ -509,19 +594,116 @@ impl ExtendedQueryHandler for DfSessionService { let (_, plan) = &portal.statement.statement; - let param_types = plan - .get_parameter_types() - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + // Enhanced parameter type inference with fallback for NULL + NULL scenarios + let param_types = match plan.get_parameter_types() { + Ok(types) => types, + Err(e) => { + // Check for specific planning errors related to NULL arithmetic operations + if matches!(e, datafusion::error::DataFusionError::Plan(_)) { + let error_msg = e.to_string(); + if error_msg + .contains("Cannot get result type for arithmetic operation Null + Null") + || error_msg.contains("Invalid arithmetic operation: Null + Null") + { + // Fallback: assume all parameters are integers for arithmetic operations + log::warn!("DataFusion type inference failed for arithmetic operation, using integer fallback"); + let param_count = portal.statement.parameter_types.len(); + std::collections::HashMap::from_iter((0..param_count).map(|i| { + ( + format!("${}", i + 1), + Some(datafusion::arrow::datatypes::DataType::Int32), + ) + })) + } else { + return Err(PgWireError::ApiError(Box::new(e))); + } + } else { + return Err(PgWireError::ApiError(Box::new(e))); + } + } + }; let param_values = df::deserialize_parameters(portal, &ordered_param_types(¶m_types))?; // Fixed: Use ¶m_types - let plan = plan - .clone() - .replace_params_with_values(¶m_values) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use ¶m_values - let dataframe = self - .session_context - .execute_logical_plan(plan) - .await - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + + // Replace parameters with values, with automatic retry for type inference failures + let plan = match plan.clone().replace_params_with_values(¶m_values) { + Ok(plan) => plan, + Err(e) => { + // Check for specific planning errors related to NULL arithmetic operations + if matches!(e, datafusion::error::DataFusionError::Plan(_)) { + let error_msg = e.to_string(); + if error_msg + .contains("Cannot get result type for arithmetic operation Null + Null") + || error_msg.contains("Invalid arithmetic operation: Null + Null") + { + log::info!( + "Retrying query with enhanced type casting for arithmetic operations" + ); + + // Attempt to reparse the query with explicit type casting + let original_query = &portal.statement.statement.0; + let enhanced_query = enhance_query_with_type_casting(original_query); + + // Try to create a new plan with the enhanced query + match self.session_context.sql(&enhanced_query).await { + Ok(new_plan_df) => { + // Get the logical plan from the new dataframe + let new_plan = new_plan_df.logical_plan().clone(); + + // Try parameter substitution again with the new plan + match new_plan.replace_params_with_values(¶m_values) { + Ok(final_plan) => final_plan, + Err(_) => { + // If it still fails, return helpful error message + return Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "42804".to_string(), + "Cannot infer parameter types for arithmetic operation. Please use explicit type casting like $1::integer + $2::integer".to_string(), + ), + ))); + } + } + } + Err(_) => { + // If enhanced query fails, return helpful error message + return Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "42804".to_string(), + "Cannot infer parameter types for arithmetic operation. Please use explicit type casting like $1::integer + $2::integer".to_string(), + ), + ))); + } + } + } else { + return Err(PgWireError::ApiError(Box::new(e))); + } + } else { + return Err(PgWireError::ApiError(Box::new(e))); + } + } + }; + // Add query timeout to prevent long-running queries from hanging connections + let dataframe = match tokio::time::timeout( + self.query_timeout, + self.session_context.execute_logical_plan(plan), + ) + .await + { + Ok(result) => result.map_err(|e| PgWireError::ApiError(Box::new(e)))?, + Err(_) => { + return Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "57014".to_string(), // PostgreSQL query_canceled error code + format!( + "Query execution timeout after {} seconds", + self.query_timeout.as_secs() + ), + ), + ))); + } + }; let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?; Ok(Response::Query(resp)) } @@ -555,6 +737,59 @@ impl QueryParser for Parser { } } +/// Enhance a SQL query by adding type casting to parameters in arithmetic operations +/// and translating PostgreSQL-specific types to DataFusion-compatible types +/// This helps DataFusion's type inference when it encounters ambiguous parameter types +fn enhance_query_with_type_casting(query: &str) -> String { + use regex::Regex; + + let mut enhanced = query.to_string(); + + // First, handle PostgreSQL-specific type translations + // Translate 'oid' type to 'integer' (oid is a 32-bit unsigned integer in PostgreSQL) + let oid_pattern = Regex::new(r"\boid\b").unwrap(); + enhanced = oid_pattern.replace_all(&enhanced, "integer").to_string(); + + // Translate other PostgreSQL types to DataFusion equivalents + let regtype_pattern = Regex::new(r"\bregtype\b").unwrap(); + enhanced = regtype_pattern + .replace_all(&enhanced, "integer") + .to_string(); + + let regclass_pattern = Regex::new(r"\bregclass\b").unwrap(); + enhanced = regclass_pattern + .replace_all(&enhanced, "integer") + .to_string(); + + let regproc_pattern = Regex::new(r"\bregproc\b").unwrap(); + enhanced = regproc_pattern + .replace_all(&enhanced, "integer") + .to_string(); + + // Handle pg_catalog schema references more explicitly + let pg_catalog_pattern = Regex::new(r"\bpg_catalog\.(\w+)").unwrap(); + enhanced = pg_catalog_pattern.replace_all(&enhanced, "$1").to_string(); + + // Pattern to match arithmetic operations with parameters: $1 + $2, $1 - $2, etc. + let arithmetic_pattern = Regex::new(r"\$(\d+)\s*([+\-*/])\s*\$(\d+)").unwrap(); + + // Replace untyped parameters in arithmetic operations with integer-cast parameters + enhanced = arithmetic_pattern + .replace_all(&enhanced, "$$$1::integer $2 $$$3::integer") + .to_string(); + + // Pattern to match single parameters followed by arithmetic operators (avoiding lookaround) + let single_param_pattern = Regex::new(r"\$(\d+)(\s*[+\-*/=<>])").unwrap(); + + // Add integer casting to remaining untyped parameters in arithmetic contexts + enhanced = single_param_pattern + .replace_all(&enhanced, "$$$1::integer$2") + .to_string(); + + log::debug!("Enhanced query: {} -> {}", query, enhanced); + enhanced +} + fn ordered_param_types(types: &HashMap>) -> Vec> { // Datafusion stores the parameters as a map. In our case, the keys will be // `$1`, `$2` etc. The values will be the parameter types. @@ -562,3 +797,55 @@ fn ordered_param_types(types: &HashMap>) -> Vec, tls_key_path: Option, + query_timeout_seconds: u64, } impl ServerOptions { @@ -48,6 +49,7 @@ impl Default for ServerOptions { port: 5432, tls_cert_path: None, tls_key_path: None, + query_timeout_seconds: 300, // 5 minutes default } } } @@ -83,8 +85,13 @@ pub async fn serve( // Create authentication manager let auth_manager = Arc::new(AuthManager::new()); - // Create the handler factory with authentication - let factory = Arc::new(HandlerFactory::new(session_context, auth_manager)); + // Create the handler factory with authentication and timeout + let query_timeout = std::time::Duration::from_secs(opts.query_timeout_seconds); + let factory = Arc::new(HandlerFactory::new( + session_context, + auth_manager, + query_timeout, + )); serve_with_handlers(factory, opts).await } diff --git a/datafusion-postgres/src/pg_catalog.rs b/datafusion-postgres/src/pg_catalog.rs index 4e0a841..8d638b4 100644 --- a/datafusion-postgres/src/pg_catalog.rs +++ b/datafusion-postgres/src/pg_catalog.rs @@ -936,6 +936,40 @@ impl PgCatalogSchemaProvider { None, ); + // current_setting(setting_name) function + data.add_function( + 2077, + "current_setting", + 11, + 10, + 12, + 1.0, + 0.0, + 0, + 0, + "f", + false, + true, + true, + false, + "s", + "s", + 1, + 0, + 25, // returns TEXT + "25", // takes TEXT parameter + None, + None, + None, + None, + None, + "current_setting", + None, + None, + None, + None, + ); + data } } @@ -2020,12 +2054,110 @@ pub fn create_format_type_udf() -> ScalarUDF { ) } +pub fn create_current_setting_udf() -> ScalarUDF { + // Define the function implementation for current_setting(setting_name) + let func = move |args: &[ColumnarValue]| { + let args = ColumnarValue::values_to_arrays(args)?; + let setting_names = &args[0]; + + // Handle different setting name requests with reasonable defaults + let mut builder = StringBuilder::new(); + + for i in 0..setting_names.len() { + let setting_name = setting_names + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::Internal("Expected string array".to_string()))? + .value(i); + + // Provide reasonable defaults for common PostgreSQL settings + let value = match setting_name.to_lowercase().as_str() { + "server_version" => "16.0", // Match modern PostgreSQL version + "server_version_num" => "160000", + "client_encoding" => "UTF8", + "timezone" => "UTC", + "datestyle" => "ISO, MDY", + "default_transaction_isolation" => "read committed", + "application_name" => "datafusion-postgres", + "session_authorization" => "postgres", + "is_superuser" => "on", + "integer_datetimes" => "on", + "search_path" => "\"$user\", public, pg_catalog", + "standard_conforming_strings" => "on", + "synchronous_commit" => "on", + "wal_level" => "replica", + "max_connections" => "100", + "shared_preload_libraries" => "", + "log_statement" => "none", + "log_min_messages" => "warning", + "default_text_search_config" => "pg_catalog.english", + _ => "", // Return empty string for unknown settings + }; + + builder.append_value(value); + } + + let array: ArrayRef = Arc::new(builder.finish()); + Ok(ColumnarValue::Array(array)) + }; + + // Wrap the implementation in a scalar function + create_udf( + "current_setting", + vec![DataType::Utf8], + DataType::Utf8, + Volatility::Stable, + Arc::new(func), + ) +} + +pub fn create_set_config_udf() -> ScalarUDF { + // Define the function implementation for set_config(setting_name, new_value, is_local) + let func = move |args: &[ColumnarValue]| { + let args = ColumnarValue::values_to_arrays(args)?; + let _setting_names = &args[0]; // Setting name + let new_values = &args[1]; // New value + let _is_local = &args[2]; // Whether the setting is local to transaction + + // For asyncpg compatibility, we just return the new value that was "set" + // In a real PostgreSQL server, this would actually modify the setting + let mut builder = StringBuilder::new(); + + for i in 0..new_values.len() { + let new_value = new_values + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::Internal("Expected string array".to_string()))? + .value(i); + + // Just echo back the value that was "set" + builder.append_value(new_value); + } + + let array: ArrayRef = Arc::new(builder.finish()); + Ok(ColumnarValue::Array(array)) + }; + + // Wrap the implementation in a scalar function + create_udf( + "set_config", + vec![DataType::Utf8, DataType::Utf8, DataType::Boolean], + DataType::Utf8, + Volatility::Volatile, + Arc::new(func), + ) +} + /// Install pg_catalog and postgres UDFs to current `SessionContext` -pub fn setup_pg_catalog( +pub async fn setup_pg_catalog( session_context: &SessionContext, catalog_name: &str, ) -> Result<(), Box> { - let pg_catalog = PgCatalogSchemaProvider::new(session_context.state().catalog_list().clone()); + let pg_catalog = Arc::new(PgCatalogSchemaProvider::new( + session_context.state().catalog_list().clone(), + )); + + // Register in pg_catalog schema session_context .catalog(catalog_name) .ok_or_else(|| { @@ -2033,7 +2165,24 @@ pub fn setup_pg_catalog( "Catalog not found when registering pg_catalog: {catalog_name}" )) })? - .register_schema("pg_catalog", Arc::new(pg_catalog))?; + .register_schema("pg_catalog", pg_catalog.clone())?; + + // Also create individual pg_catalog tables in the public schema for asyncpg compatibility + // asyncpg often queries these tables without schema qualifiers + let pg_catalog_for_public = Arc::new(PgCatalogSchemaProvider::new( + session_context.state().catalog_list().clone(), + )); + + // Register all pg_catalog tables that asyncpg might need directly in public schema + for table_name in PG_CATALOG_TABLES { + // Register table directly in the current catalog's public namespace + let table_path = table_name.to_string(); + if let Ok(Some(table)) = pg_catalog_for_public.table(table_name).await { + session_context + .register_table(&table_path, table) + .map_err(Box::new)?; + } + } session_context.register_udf(create_current_schema_udf()); session_context.register_udf(create_current_schemas_udf()); @@ -2042,6 +2191,8 @@ pub fn setup_pg_catalog( session_context.register_udf(create_has_table_privilege_2param_udf()); session_context.register_udf(create_pg_table_is_visible()); session_context.register_udf(create_format_type_udf()); + session_context.register_udf(create_current_setting_udf()); + session_context.register_udf(create_set_config_udf()); Ok(()) }