From 864f15d93b1ce30de615e38bdb4568066b894960 Mon Sep 17 00:00:00 2001 From: Oluwapeluwa Ibrahim Date: Fri, 5 Sep 2025 02:24:10 +0100 Subject: [PATCH 1/3] Fix parameter type inference, schema mismatches, and add configurable query timeout - Fixed parameter type inference for arithmetic operations with untyped parameters (defaults to TEXT instead of throwing fatal errors) - Fixed Arrow schema mismatches by unifying all UTF8 variants to use TEXT type consistently - Added configurable query timeout with ServerOptions.with_query_timeout_secs(0) for no timeout - Added configurable max_connections with ServerOptions.with_max_connections(n) - Added connection limiting with semaphore to prevent resource exhaustion under load - Simplified API with single constructor that takes timeout parameter directly - Added comprehensive unit tests for timeout and connection configuration functionality --- arrow-pg/src/datatypes.rs | 6 +- arrow-pg/src/datatypes/df.rs | 8 +- datafusion-postgres/src/handlers.rs | 117 ++++++++++++++++++++++++++-- datafusion-postgres/src/lib.rs | 35 ++++++++- 4 files changed, 147 insertions(+), 19 deletions(-) diff --git a/arrow-pg/src/datatypes.rs b/arrow-pg/src/datatypes.rs index 7af25f1..c3c6276 100644 --- a/arrow-pg/src/datatypes.rs +++ b/arrow-pg/src/datatypes.rs @@ -42,8 +42,7 @@ pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult { DataType::Float16 | DataType::Float32 => Type::FLOAT4, DataType::Float64 => Type::FLOAT8, DataType::Decimal128(_, _) => Type::NUMERIC, - DataType::Utf8 => Type::VARCHAR, - DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT, DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => { match field.data_type() { DataType::Boolean => Type::BOOL_ARRAY, @@ -67,8 +66,7 @@ pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult { | DataType::BinaryView => Type::BYTEA_ARRAY, DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY, DataType::Float64 => Type::FLOAT8_ARRAY, - DataType::Utf8 => Type::VARCHAR_ARRAY, - DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY, struct_type @ DataType::Struct(_) => Type::new( Type::RECORD_ARRAY.name().into(), Type::RECORD_ARRAY.oid(), diff --git a/arrow-pg/src/datatypes/df.rs b/arrow-pg/src/datatypes/df.rs index af98b99..741d31c 100644 --- a/arrow-pg/src/datatypes/df.rs +++ b/arrow-pg/src/datatypes/df.rs @@ -66,11 +66,9 @@ where } else if let Some(infer_type) = inferenced_type { into_pg_type(infer_type) } else { - Err(PgWireError::UserError(Box::new(ErrorInfo::new( - "FATAL".to_string(), - "XX000".to_string(), - "Unknown parameter type".to_string(), - )))) + // Default to TEXT/VARCHAR for untyped parameters + // This allows arithmetic operations to work with implicit casting + Ok(Type::TEXT) } } diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 7d2ccb3..02dc9f6 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -43,9 +43,9 @@ pub struct HandlerFactory { } impl HandlerFactory { - pub fn new(session_context: Arc, auth_manager: Arc) -> Self { + pub fn new(session_context: Arc, auth_manager: Arc, query_timeout: Option) -> Self { let session_service = - Arc::new(DfSessionService::new(session_context, auth_manager.clone())); + Arc::new(DfSessionService::new(session_context, auth_manager.clone(), query_timeout)); HandlerFactory { session_service } } } @@ -71,12 +71,14 @@ pub struct DfSessionService { timezone: Arc>, auth_manager: Arc, sql_rewrite_rules: Vec>, + query_timeout: Option, } impl DfSessionService { pub fn new( session_context: Arc, auth_manager: Arc, + query_timeout: Option, ) -> DfSessionService { let sql_rewrite_rules: Vec> = vec![ Arc::new(AliasDuplicatedProjectionRewrite), @@ -97,9 +99,12 @@ impl DfSessionService { timezone: Arc::new(Mutex::new("UTC".to_string())), auth_manager, sql_rewrite_rules, + query_timeout, } } + + /// Check if the current user has permission to execute a query async fn check_query_permission(&self, client: &C, query: &str) -> PgWireResult<()> where @@ -378,7 +383,19 @@ impl SimpleQueryHandler for DfSessionService { ))); } - let df_result = self.session_context.sql(&query).await; + let df_result = if let Some(timeout) = self.query_timeout { + tokio::time::timeout(timeout, self.session_context.sql(&query)) + .await + .map_err(|_| { + PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "57014".to_string(), // query_canceled error code + "canceling statement due to query timeout".to_string(), + ))) + })? + } else { + self.session_context.sql(&query).await + }; // Handle query execution errors and transaction state let df = match df_result { @@ -540,10 +557,23 @@ impl ExtendedQueryHandler for DfSessionService { .optimize(&plan) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let dataframe = match self.session_context.execute_logical_plan(optimised).await { - Ok(df) => df, - Err(e) => { - return Err(PgWireError::ApiError(Box::new(e))); + let dataframe = if let Some(timeout) = self.query_timeout { + tokio::time::timeout(timeout, self.session_context.execute_logical_plan(optimised)) + .await + .map_err(|_| { + PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "57014".to_string(), // query_canceled error code + "canceling statement due to query timeout".to_string(), + ))) + })? + .map_err(|e| PgWireError::ApiError(Box::new(e)))? + } else { + match self.session_context.execute_logical_plan(optimised).await { + Ok(df) => df, + Err(e) => { + return Err(PgWireError::ApiError(Box::new(e))); + } } }; let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?; @@ -593,3 +623,76 @@ fn ordered_param_types(types: &HashMap>) -> Vec, tls_key_path: Option, + max_connections: usize, + query_timeout: Option, } impl ServerOptions { pub fn new() -> ServerOptions { ServerOptions::default() } + + /// Set query timeout from seconds. Use 0 for no timeout. + pub fn with_query_timeout_secs(mut self, timeout_secs: u64) -> Self { + self.query_timeout = if timeout_secs == 0 { + None + } else { + Some(std::time::Duration::from_secs(timeout_secs)) + }; + self + } } impl Default for ServerOptions { @@ -49,6 +62,8 @@ impl Default for ServerOptions { port: 5432, tls_cert_path: None, tls_key_path: None, + max_connections: 1000, + query_timeout: Some(std::time::Duration::from_secs(30)), } } } @@ -85,7 +100,7 @@ pub async fn serve( let auth_manager = Arc::new(AuthManager::new()); // Create the handler factory with authentication - let factory = Arc::new(HandlerFactory::new(session_context, auth_manager)); + let factory = Arc::new(HandlerFactory::new(session_context, auth_manager, opts.query_timeout)); serve_with_handlers(factory, opts).await } @@ -126,17 +141,31 @@ pub async fn serve_with_handlers( info!("Listening on {server_addr} (unencrypted)"); } + // Connection limiter to prevent resource exhaustion + let connection_semaphore = Arc::new(Semaphore::new(opts.max_connections)); + // Accept incoming connections loop { match listener.accept().await { - Ok((socket, _addr)) => { + Ok((socket, addr)) => { let factory_ref = handlers.clone(); let tls_acceptor_ref = tls_acceptor.clone(); + let semaphore_ref = connection_semaphore.clone(); tokio::spawn(async move { + // Acquire connection permit to limit concurrency + let _permit = match semaphore_ref.try_acquire() { + Ok(permit) => permit, + Err(_) => { + warn!("Connection rejected from {addr}: max connections reached"); + return; + } + }; + if let Err(e) = process_socket(socket, tls_acceptor_ref, factory_ref).await { - warn!("Error processing socket: {e}"); + warn!("Error processing socket from {addr}: {e}"); } + // Permit is automatically released when _permit is dropped }); } Err(e) => { From 8b95574e14f3bc922c572e506ced1a265c6f8397 Mon Sep 17 00:00:00 2001 From: Oluwapeluwa Ibrahim Date: Fri, 5 Sep 2025 02:36:55 +0100 Subject: [PATCH 2/3] Apply clippy fixes and format code - Fixed compiler warnings and clippy suggestions - Updated test files to use new DfSessionService constructor signature - Applied cargo fmt formatting across the codebase - All tests now pass successfully --- arrow-pg/src/encoder.rs | 2 +- datafusion-postgres/src/handlers.rs | 68 ++++++++++++++----------- datafusion-postgres/src/lib.rs | 6 ++- datafusion-postgres/src/sql.rs | 4 +- datafusion-postgres/tests/common/mod.rs | 2 +- datafusion-postgres/tests/dbeaver.rs | 2 +- 6 files changed, 49 insertions(+), 35 deletions(-) diff --git a/arrow-pg/src/encoder.rs b/arrow-pg/src/encoder.rs index 5490e1f..8ac10da 100644 --- a/arrow-pg/src/encoder.rs +++ b/arrow-pg/src/encoder.rs @@ -574,7 +574,7 @@ mod tests { { let mut bytes = BytesMut::new(); let _sql_text = value.to_sql_text(data_type, &mut bytes); - let string = String::from_utf8((&bytes).to_vec()); + let string = String::from_utf8(bytes.to_vec()); self.encoded_value = string.unwrap(); Ok(()) } diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 02dc9f6..219252b 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -43,9 +43,16 @@ pub struct HandlerFactory { } impl HandlerFactory { - pub fn new(session_context: Arc, auth_manager: Arc, query_timeout: Option) -> Self { - let session_service = - Arc::new(DfSessionService::new(session_context, auth_manager.clone(), query_timeout)); + pub fn new( + session_context: Arc, + auth_manager: Arc, + query_timeout: Option, + ) -> Self { + let session_service = Arc::new(DfSessionService::new( + session_context, + auth_manager.clone(), + query_timeout, + )); HandlerFactory { session_service } } } @@ -103,8 +110,6 @@ impl DfSessionService { } } - - /// Check if the current user has permission to execute a query async fn check_query_permission(&self, client: &C, query: &str) -> PgWireResult<()> where @@ -558,16 +563,19 @@ impl ExtendedQueryHandler for DfSessionService { .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let dataframe = if let Some(timeout) = self.query_timeout { - tokio::time::timeout(timeout, self.session_context.execute_logical_plan(optimised)) - .await - .map_err(|_| { - PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new( - "ERROR".to_string(), - "57014".to_string(), // query_canceled error code - "canceling statement due to query timeout".to_string(), - ))) - })? - .map_err(|e| PgWireError::ApiError(Box::new(e)))? + tokio::time::timeout( + timeout, + self.session_context.execute_logical_plan(optimised), + ) + .await + .map_err(|_| { + PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "57014".to_string(), // query_canceled error code + "canceling statement due to query timeout".to_string(), + ))) + })? + .map_err(|e| PgWireError::ApiError(Box::new(e)))? } else { match self.session_context.execute_logical_plan(optimised).await { Ok(df) => df, @@ -649,7 +657,7 @@ mod tests { let session_context = Arc::new(SessionContext::new()); let auth_manager = Arc::new(AuthManager::new()); let timeout = Some(Duration::from_secs(60)); - + let factory = HandlerFactory::new(session_context, auth_manager, timeout); assert_eq!(factory.session_service.query_timeout, timeout); } @@ -658,21 +666,20 @@ mod tests { fn test_session_service_timeout_configuration() { let session_context = Arc::new(SessionContext::new()); let auth_manager = Arc::new(AuthManager::new()); - + // Test with timeout let service_with_timeout = DfSessionService::new( - session_context.clone(), - auth_manager.clone(), + session_context.clone(), + auth_manager.clone(), + Some(Duration::from_secs(45)), + ); + assert_eq!( + service_with_timeout.query_timeout, Some(Duration::from_secs(45)) ); - assert_eq!(service_with_timeout.query_timeout, Some(Duration::from_secs(45))); - + // Test without timeout (None) - let service_no_timeout = DfSessionService::new( - session_context, - auth_manager, - None - ); + let service_no_timeout = DfSessionService::new(session_context, auth_manager, None); assert_eq!(service_no_timeout.query_timeout, None); } @@ -681,17 +688,20 @@ mod tests { // Test 0 seconds = no timeout let opts_no_timeout = ServerOptions::new().with_query_timeout_secs(0); assert_eq!(opts_no_timeout.query_timeout, None); - + // Test positive seconds = Some(Duration) let opts_with_timeout = ServerOptions::new().with_query_timeout_secs(60); - assert_eq!(opts_with_timeout.query_timeout, Some(Duration::from_secs(60))); + assert_eq!( + opts_with_timeout.query_timeout, + Some(Duration::from_secs(60)) + ); } #[test] fn test_max_connections_configuration() { let opts = ServerOptions::new().with_max_connections(500); assert_eq!(opts.max_connections, 500); - + let opts_default = ServerOptions::default(); assert_eq!(opts_default.max_connections, 1000); } diff --git a/datafusion-postgres/src/lib.rs b/datafusion-postgres/src/lib.rs index 54742cc..8355aeb 100644 --- a/datafusion-postgres/src/lib.rs +++ b/datafusion-postgres/src/lib.rs @@ -100,7 +100,11 @@ pub async fn serve( let auth_manager = Arc::new(AuthManager::new()); // Create the handler factory with authentication - let factory = Arc::new(HandlerFactory::new(session_context, auth_manager, opts.query_timeout)); + let factory = Arc::new(HandlerFactory::new( + session_context, + auth_manager, + opts.query_timeout, + )); serve_with_handlers(factory, opts).await } diff --git a/datafusion-postgres/src/sql.rs b/datafusion-postgres/src/sql.rs index 65c8021..2ae841f 100644 --- a/datafusion-postgres/src/sql.rs +++ b/datafusion-postgres/src/sql.rs @@ -296,7 +296,7 @@ struct RemoveUnsupportedTypesVisitor<'a> { unsupported_types: &'a HashSet, } -impl<'a> VisitorMut for RemoveUnsupportedTypesVisitor<'a> { +impl VisitorMut for RemoveUnsupportedTypesVisitor<'_> { type Break = (); fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow { @@ -444,7 +444,7 @@ struct PrependUnqualifiedTableNameVisitor<'a> { table_names: &'a HashSet, } -impl<'a> VisitorMut for PrependUnqualifiedTableNameVisitor<'a> { +impl VisitorMut for PrependUnqualifiedTableNameVisitor<'_> { type Break = (); fn pre_visit_table_factor( diff --git a/datafusion-postgres/tests/common/mod.rs b/datafusion-postgres/tests/common/mod.rs index 7c7df52..2a56fa8 100644 --- a/datafusion-postgres/tests/common/mod.rs +++ b/datafusion-postgres/tests/common/mod.rs @@ -14,7 +14,7 @@ pub fn setup_handlers() -> DfSessionService { let session_context = SessionContext::new(); setup_pg_catalog(&session_context, "datafusion").expect("Failed to setup sesession context"); - DfSessionService::new(Arc::new(session_context), Arc::new(AuthManager::new())) + DfSessionService::new(Arc::new(session_context), Arc::new(AuthManager::new()), Some(std::time::Duration::from_secs(30))) } #[derive(Debug, Default)] diff --git a/datafusion-postgres/tests/dbeaver.rs b/datafusion-postgres/tests/dbeaver.rs index e132b91..24e5ab8 100644 --- a/datafusion-postgres/tests/dbeaver.rs +++ b/datafusion-postgres/tests/dbeaver.rs @@ -34,6 +34,6 @@ pub async fn test_dbeaver_startup_sql() { for query in DBEAVER_QUERIES { SimpleQueryHandler::do_query(&service, &mut client, query) .await - .expect(&format!("failed to run sql: {query}")); + .unwrap_or_else(|_| panic!("failed to run sql: {query}")); } } From 829dbc0b9a2b3d51fe66f2aae0a6a3175344e776 Mon Sep 17 00:00:00 2001 From: Oluwapeluwa Ibrahim Date: Fri, 5 Sep 2025 11:27:19 +0100 Subject: [PATCH 3/3] Apply clippy fixes and format code --- datafusion-postgres/tests/common/mod.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion-postgres/tests/common/mod.rs b/datafusion-postgres/tests/common/mod.rs index 2a56fa8..5f53588 100644 --- a/datafusion-postgres/tests/common/mod.rs +++ b/datafusion-postgres/tests/common/mod.rs @@ -14,7 +14,11 @@ pub fn setup_handlers() -> DfSessionService { let session_context = SessionContext::new(); setup_pg_catalog(&session_context, "datafusion").expect("Failed to setup sesession context"); - DfSessionService::new(Arc::new(session_context), Arc::new(AuthManager::new()), Some(std::time::Duration::from_secs(30))) + DfSessionService::new( + Arc::new(session_context), + Arc::new(AuthManager::new()), + Some(std::time::Duration::from_secs(30)), + ) } #[derive(Debug, Default)]