Skip to content

Commit 864f15d

Browse files
committed
Fix parameter type inference, schema mismatches, and add configurable query timeout
- Fixed parameter type inference for arithmetic operations with untyped parameters (defaults to TEXT instead of throwing fatal errors) - Fixed Arrow schema mismatches by unifying all UTF8 variants to use TEXT type consistently - Added configurable query timeout with ServerOptions.with_query_timeout_secs(0) for no timeout - Added configurable max_connections with ServerOptions.with_max_connections(n) - Added connection limiting with semaphore to prevent resource exhaustion under load - Simplified API with single constructor that takes timeout parameter directly - Added comprehensive unit tests for timeout and connection configuration functionality
1 parent 98d238d commit 864f15d

File tree

4 files changed

+147
-19
lines changed

4 files changed

+147
-19
lines changed

arrow-pg/src/datatypes.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult<Type> {
4242
DataType::Float16 | DataType::Float32 => Type::FLOAT4,
4343
DataType::Float64 => Type::FLOAT8,
4444
DataType::Decimal128(_, _) => Type::NUMERIC,
45-
DataType::Utf8 => Type::VARCHAR,
46-
DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT,
45+
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT,
4746
DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => {
4847
match field.data_type() {
4948
DataType::Boolean => Type::BOOL_ARRAY,
@@ -67,8 +66,7 @@ pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult<Type> {
6766
| DataType::BinaryView => Type::BYTEA_ARRAY,
6867
DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY,
6968
DataType::Float64 => Type::FLOAT8_ARRAY,
70-
DataType::Utf8 => Type::VARCHAR_ARRAY,
71-
DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY,
69+
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY,
7270
struct_type @ DataType::Struct(_) => Type::new(
7371
Type::RECORD_ARRAY.name().into(),
7472
Type::RECORD_ARRAY.oid(),

arrow-pg/src/datatypes/df.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,9 @@ where
6666
} else if let Some(infer_type) = inferenced_type {
6767
into_pg_type(infer_type)
6868
} else {
69-
Err(PgWireError::UserError(Box::new(ErrorInfo::new(
70-
"FATAL".to_string(),
71-
"XX000".to_string(),
72-
"Unknown parameter type".to_string(),
73-
))))
69+
// Default to TEXT/VARCHAR for untyped parameters
70+
// This allows arithmetic operations to work with implicit casting
71+
Ok(Type::TEXT)
7472
}
7573
}
7674

datafusion-postgres/src/handlers.rs

Lines changed: 110 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ pub struct HandlerFactory {
4343
}
4444

4545
impl HandlerFactory {
46-
pub fn new(session_context: Arc<SessionContext>, auth_manager: Arc<AuthManager>) -> Self {
46+
pub fn new(session_context: Arc<SessionContext>, auth_manager: Arc<AuthManager>, query_timeout: Option<std::time::Duration>) -> Self {
4747
let session_service =
48-
Arc::new(DfSessionService::new(session_context, auth_manager.clone()));
48+
Arc::new(DfSessionService::new(session_context, auth_manager.clone(), query_timeout));
4949
HandlerFactory { session_service }
5050
}
5151
}
@@ -71,12 +71,14 @@ pub struct DfSessionService {
7171
timezone: Arc<Mutex<String>>,
7272
auth_manager: Arc<AuthManager>,
7373
sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
74+
query_timeout: Option<std::time::Duration>,
7475
}
7576

7677
impl DfSessionService {
7778
pub fn new(
7879
session_context: Arc<SessionContext>,
7980
auth_manager: Arc<AuthManager>,
81+
query_timeout: Option<std::time::Duration>,
8082
) -> DfSessionService {
8183
let sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![
8284
Arc::new(AliasDuplicatedProjectionRewrite),
@@ -97,9 +99,12 @@ impl DfSessionService {
9799
timezone: Arc::new(Mutex::new("UTC".to_string())),
98100
auth_manager,
99101
sql_rewrite_rules,
102+
query_timeout,
100103
}
101104
}
102105

106+
107+
103108
/// Check if the current user has permission to execute a query
104109
async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
105110
where
@@ -378,7 +383,19 @@ impl SimpleQueryHandler for DfSessionService {
378383
)));
379384
}
380385

381-
let df_result = self.session_context.sql(&query).await;
386+
let df_result = if let Some(timeout) = self.query_timeout {
387+
tokio::time::timeout(timeout, self.session_context.sql(&query))
388+
.await
389+
.map_err(|_| {
390+
PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
391+
"ERROR".to_string(),
392+
"57014".to_string(), // query_canceled error code
393+
"canceling statement due to query timeout".to_string(),
394+
)))
395+
})?
396+
} else {
397+
self.session_context.sql(&query).await
398+
};
382399

383400
// Handle query execution errors and transaction state
384401
let df = match df_result {
@@ -540,10 +557,23 @@ impl ExtendedQueryHandler for DfSessionService {
540557
.optimize(&plan)
541558
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
542559

543-
let dataframe = match self.session_context.execute_logical_plan(optimised).await {
544-
Ok(df) => df,
545-
Err(e) => {
546-
return Err(PgWireError::ApiError(Box::new(e)));
560+
let dataframe = if let Some(timeout) = self.query_timeout {
561+
tokio::time::timeout(timeout, self.session_context.execute_logical_plan(optimised))
562+
.await
563+
.map_err(|_| {
564+
PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
565+
"ERROR".to_string(),
566+
"57014".to_string(), // query_canceled error code
567+
"canceling statement due to query timeout".to_string(),
568+
)))
569+
})?
570+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
571+
} else {
572+
match self.session_context.execute_logical_plan(optimised).await {
573+
Ok(df) => df,
574+
Err(e) => {
575+
return Err(PgWireError::ApiError(Box::new(e)));
576+
}
547577
}
548578
};
549579
let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
@@ -593,3 +623,76 @@ fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<
593623
types.sort_by(|a, b| a.0.cmp(b.0));
594624
types.into_iter().map(|pt| pt.1.as_ref()).collect()
595625
}
626+
627+
#[cfg(test)]
628+
mod tests {
629+
use super::*;
630+
use crate::{auth::AuthManager, ServerOptions};
631+
use datafusion::prelude::SessionContext;
632+
use std::time::Duration;
633+
634+
#[test]
635+
fn test_server_options_default_timeout() {
636+
let opts = ServerOptions::default();
637+
assert_eq!(opts.query_timeout, Some(Duration::from_secs(30)));
638+
}
639+
640+
#[test]
641+
fn test_server_options_no_timeout() {
642+
let mut opts = ServerOptions::new();
643+
opts.query_timeout = None;
644+
assert_eq!(opts.query_timeout, None);
645+
}
646+
647+
#[test]
648+
fn test_handler_factory_with_timeout() {
649+
let session_context = Arc::new(SessionContext::new());
650+
let auth_manager = Arc::new(AuthManager::new());
651+
let timeout = Some(Duration::from_secs(60));
652+
653+
let factory = HandlerFactory::new(session_context, auth_manager, timeout);
654+
assert_eq!(factory.session_service.query_timeout, timeout);
655+
}
656+
657+
#[test]
658+
fn test_session_service_timeout_configuration() {
659+
let session_context = Arc::new(SessionContext::new());
660+
let auth_manager = Arc::new(AuthManager::new());
661+
662+
// Test with timeout
663+
let service_with_timeout = DfSessionService::new(
664+
session_context.clone(),
665+
auth_manager.clone(),
666+
Some(Duration::from_secs(45))
667+
);
668+
assert_eq!(service_with_timeout.query_timeout, Some(Duration::from_secs(45)));
669+
670+
// Test without timeout (None)
671+
let service_no_timeout = DfSessionService::new(
672+
session_context,
673+
auth_manager,
674+
None
675+
);
676+
assert_eq!(service_no_timeout.query_timeout, None);
677+
}
678+
679+
#[test]
680+
fn test_timeout_configuration_from_seconds() {
681+
// Test 0 seconds = no timeout
682+
let opts_no_timeout = ServerOptions::new().with_query_timeout_secs(0);
683+
assert_eq!(opts_no_timeout.query_timeout, None);
684+
685+
// Test positive seconds = Some(Duration)
686+
let opts_with_timeout = ServerOptions::new().with_query_timeout_secs(60);
687+
assert_eq!(opts_with_timeout.query_timeout, Some(Duration::from_secs(60)));
688+
}
689+
690+
#[test]
691+
fn test_max_connections_configuration() {
692+
let opts = ServerOptions::new().with_max_connections(500);
693+
assert_eq!(opts.max_connections, 500);
694+
695+
let opts_default = ServerOptions::default();
696+
assert_eq!(opts_default.max_connections, 1000);
697+
}
698+
}

datafusion-postgres/src/lib.rs

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use pgwire::tokio::process_socket;
1616
use rustls_pemfile::{certs, pkcs8_private_keys};
1717
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
1818
use tokio::net::TcpListener;
19+
use tokio::sync::Semaphore;
1920
use tokio_rustls::rustls::{self, ServerConfig};
2021
use tokio_rustls::TlsAcceptor;
2122

@@ -34,12 +35,24 @@ pub struct ServerOptions {
3435
port: u16,
3536
tls_cert_path: Option<String>,
3637
tls_key_path: Option<String>,
38+
max_connections: usize,
39+
query_timeout: Option<std::time::Duration>,
3740
}
3841

3942
impl ServerOptions {
4043
pub fn new() -> ServerOptions {
4144
ServerOptions::default()
4245
}
46+
47+
/// Set query timeout from seconds. Use 0 for no timeout.
48+
pub fn with_query_timeout_secs(mut self, timeout_secs: u64) -> Self {
49+
self.query_timeout = if timeout_secs == 0 {
50+
None
51+
} else {
52+
Some(std::time::Duration::from_secs(timeout_secs))
53+
};
54+
self
55+
}
4356
}
4457

4558
impl Default for ServerOptions {
@@ -49,6 +62,8 @@ impl Default for ServerOptions {
4962
port: 5432,
5063
tls_cert_path: None,
5164
tls_key_path: None,
65+
max_connections: 1000,
66+
query_timeout: Some(std::time::Duration::from_secs(30)),
5267
}
5368
}
5469
}
@@ -85,7 +100,7 @@ pub async fn serve(
85100
let auth_manager = Arc::new(AuthManager::new());
86101

87102
// Create the handler factory with authentication
88-
let factory = Arc::new(HandlerFactory::new(session_context, auth_manager));
103+
let factory = Arc::new(HandlerFactory::new(session_context, auth_manager, opts.query_timeout));
89104

90105
serve_with_handlers(factory, opts).await
91106
}
@@ -126,17 +141,31 @@ pub async fn serve_with_handlers(
126141
info!("Listening on {server_addr} (unencrypted)");
127142
}
128143

144+
// Connection limiter to prevent resource exhaustion
145+
let connection_semaphore = Arc::new(Semaphore::new(opts.max_connections));
146+
129147
// Accept incoming connections
130148
loop {
131149
match listener.accept().await {
132-
Ok((socket, _addr)) => {
150+
Ok((socket, addr)) => {
133151
let factory_ref = handlers.clone();
134152
let tls_acceptor_ref = tls_acceptor.clone();
153+
let semaphore_ref = connection_semaphore.clone();
135154

136155
tokio::spawn(async move {
156+
// Acquire connection permit to limit concurrency
157+
let _permit = match semaphore_ref.try_acquire() {
158+
Ok(permit) => permit,
159+
Err(_) => {
160+
warn!("Connection rejected from {addr}: max connections reached");
161+
return;
162+
}
163+
};
164+
137165
if let Err(e) = process_socket(socket, tls_acceptor_ref, factory_ref).await {
138-
warn!("Error processing socket: {e}");
166+
warn!("Error processing socket from {addr}: {e}");
139167
}
168+
// Permit is automatically released when _permit is dropped
140169
});
141170
}
142171
Err(e) => {

0 commit comments

Comments
 (0)