Skip to content

Commit 213ac23

Browse files
committed
Refactor statement timeout to be truly per-session using ClientInfo metadata
- Remove statement_timeout field from DfSessionService (was application-scoped) - Store timeout in ClientInfo::metadata() using METADATA_STATEMENT_TIMEOUT key - Add helper functions get_statement_timeout() and set_statement_timeout() - Update SET/SHOW statement_timeout handlers to use client metadata - Update query execution logic to read timeout from client session - Add comprehensive MockClient for testing session-specific behavior - Now each PostgreSQL session has its own independent timeout setting This follows the PostgreSQL standard where statement_timeout is a session variable, not a server-wide configuration.
1 parent 1bca5e0 commit 213ac23

File tree

1 file changed

+122
-30
lines changed

1 file changed

+122
-30
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 122 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ use tokio::sync::Mutex;
3131
use arrow_pg::datatypes::df;
3232
use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
3333

34+
// Metadata keys for session-level settings
35+
const METADATA_STATEMENT_TIMEOUT: &str = "statement_timeout_ms";
36+
3437
/// Simple startup handler that does no authentication
3538
/// For production, use DfAuthSource with proper pgwire authentication handlers
3639
pub struct SimpleStartupHandler;
@@ -71,7 +74,6 @@ pub struct DfSessionService {
7174
timezone: Arc<Mutex<String>>,
7275
auth_manager: Arc<AuthManager>,
7376
sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
74-
statement_timeout: Arc<Mutex<Option<std::time::Duration>>>,
7577
}
7678

7779
impl DfSessionService {
@@ -98,7 +100,31 @@ impl DfSessionService {
98100
timezone: Arc::new(Mutex::new("UTC".to_string())),
99101
auth_manager,
100102
sql_rewrite_rules,
101-
statement_timeout: Arc::new(Mutex::new(None)),
103+
}
104+
}
105+
106+
/// Get statement timeout from client metadata
107+
fn get_statement_timeout<C>(client: &C) -> Option<std::time::Duration>
108+
where
109+
C: ClientInfo,
110+
{
111+
client
112+
.metadata()
113+
.get(METADATA_STATEMENT_TIMEOUT)
114+
.and_then(|s| s.parse::<u64>().ok())
115+
.map(std::time::Duration::from_millis)
116+
}
117+
118+
/// Set statement timeout in client metadata
119+
fn set_statement_timeout<C>(client: &mut C, timeout: Option<std::time::Duration>)
120+
where
121+
C: ClientInfo,
122+
{
123+
let metadata = client.metadata_mut();
124+
if let Some(duration) = timeout {
125+
metadata.insert(METADATA_STATEMENT_TIMEOUT.to_string(), duration.as_millis().to_string());
126+
} else {
127+
metadata.remove(METADATA_STATEMENT_TIMEOUT);
102128
}
103129
}
104130

@@ -196,10 +222,14 @@ impl DfSessionService {
196222
Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
197223
}
198224

199-
async fn try_respond_set_statements<'a>(
225+
async fn try_respond_set_statements<'a, C>(
200226
&self,
227+
client: &mut C,
201228
query_lower: &str,
202-
) -> PgWireResult<Option<Response<'a>>> {
229+
) -> PgWireResult<Option<Response<'a>>>
230+
where
231+
C: ClientInfo,
232+
{
203233
if query_lower.starts_with("set") {
204234
if query_lower.starts_with("set time zone") {
205235
let parts: Vec<&str> = query_lower.split_whitespace().collect();
@@ -221,10 +251,9 @@ impl DfSessionService {
221251
let parts: Vec<&str> = query_lower.split_whitespace().collect();
222252
if parts.len() >= 3 {
223253
let timeout_str = parts[2].trim_matches('"').trim_matches('\'');
224-
let mut statement_timeout = self.statement_timeout.lock().await;
225254

226-
if timeout_str == "0" || timeout_str.is_empty() {
227-
*statement_timeout = None;
255+
let timeout = if timeout_str == "0" || timeout_str.is_empty() {
256+
None
228257
} else {
229258
// Parse timeout value (supports ms, s, min formats)
230259
let timeout_ms = if timeout_str.ends_with("ms") {
@@ -245,14 +274,12 @@ impl DfSessionService {
245274
};
246275

247276
match timeout_ms {
248-
Ok(ms) if ms > 0 => {
249-
*statement_timeout = Some(std::time::Duration::from_millis(ms));
250-
}
251-
_ => {
252-
*statement_timeout = None;
253-
}
277+
Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
278+
_ => None,
254279
}
255-
}
280+
};
281+
282+
Self::set_statement_timeout(client, timeout);
256283
Ok(Some(Response::Execution(Tag::new("SET"))))
257284
} else {
258285
Err(PgWireError::UserError(Box::new(
@@ -322,10 +349,14 @@ impl DfSessionService {
322349
}
323350
}
324351

325-
async fn try_respond_show_statements<'a>(
352+
async fn try_respond_show_statements<'a, C>(
326353
&self,
354+
client: &C,
327355
query_lower: &str,
328-
) -> PgWireResult<Option<Response<'a>>> {
356+
) -> PgWireResult<Option<Response<'a>>>
357+
where
358+
C: ClientInfo,
359+
{
329360
if query_lower.starts_with("show ") {
330361
match query_lower.strip_suffix(";").unwrap_or(query_lower) {
331362
"show time zone" => {
@@ -354,7 +385,7 @@ impl DfSessionService {
354385
Ok(Some(Response::Query(resp)))
355386
}
356387
"show statement_timeout" => {
357-
let timeout = *self.statement_timeout.lock().await;
388+
let timeout = Self::get_statement_timeout(client);
358389
let timeout_str = match timeout {
359390
Some(duration) => format!("{}ms", duration.as_millis()),
360391
None => "0".to_string(),
@@ -408,7 +439,7 @@ impl SimpleQueryHandler for DfSessionService {
408439
self.check_query_permission(client, &query).await?;
409440
}
410441

411-
if let Some(resp) = self.try_respond_set_statements(&query_lower).await? {
442+
if let Some(resp) = self.try_respond_set_statements(client, &query_lower).await? {
412443
return Ok(vec![resp]);
413444
}
414445

@@ -419,7 +450,7 @@ impl SimpleQueryHandler for DfSessionService {
419450
return Ok(vec![resp]);
420451
}
421452

422-
if let Some(resp) = self.try_respond_show_statements(&query_lower).await? {
453+
if let Some(resp) = self.try_respond_show_statements(client, &query_lower).await? {
423454
return Ok(vec![resp]);
424455
}
425456

@@ -436,7 +467,7 @@ impl SimpleQueryHandler for DfSessionService {
436467
}
437468

438469
let df_result = {
439-
let timeout = *self.statement_timeout.lock().await;
470+
let timeout = Self::get_statement_timeout(client);
440471
if let Some(timeout_duration) = timeout {
441472
tokio::time::timeout(timeout_duration, self.session_context.sql(&query))
442473
.await
@@ -568,7 +599,7 @@ impl ExtendedQueryHandler for DfSessionService {
568599
.await?;
569600
}
570601

571-
if let Some(resp) = self.try_respond_set_statements(&query).await? {
602+
if let Some(resp) = self.try_respond_set_statements(client, &query).await? {
572603
return Ok(resp);
573604
}
574605

@@ -579,7 +610,7 @@ impl ExtendedQueryHandler for DfSessionService {
579610
return Ok(resp);
580611
}
581612

582-
if let Some(resp) = self.try_respond_show_statements(&query).await? {
613+
if let Some(resp) = self.try_respond_show_statements(client, &query).await? {
583614
return Ok(resp);
584615
}
585616

@@ -613,7 +644,7 @@ impl ExtendedQueryHandler for DfSessionService {
613644
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
614645

615646
let dataframe = {
616-
let timeout = *self.statement_timeout.lock().await;
647+
let timeout = Self::get_statement_timeout(client);
617648
if let Some(timeout_duration) = timeout {
618649
tokio::time::timeout(
619650
timeout_duration,
@@ -690,28 +721,88 @@ mod tests {
690721
use super::*;
691722
use crate::auth::AuthManager;
692723
use datafusion::prelude::SessionContext;
724+
use std::collections::HashMap;
693725
use std::time::Duration;
694726

727+
struct MockClient {
728+
metadata: HashMap<String, String>,
729+
}
730+
731+
impl MockClient {
732+
fn new() -> Self {
733+
Self {
734+
metadata: HashMap::new(),
735+
}
736+
}
737+
}
738+
739+
impl ClientInfo for MockClient {
740+
fn socket_addr(&self) -> std::net::SocketAddr {
741+
"127.0.0.1:5432".parse().unwrap()
742+
}
743+
744+
fn is_secure(&self) -> bool {
745+
false
746+
}
747+
748+
fn protocol_version(&self) -> pgwire::messages::ProtocolVersion {
749+
pgwire::messages::ProtocolVersion::PROTOCOL3_0
750+
}
751+
752+
fn set_protocol_version(&mut self, _version: pgwire::messages::ProtocolVersion) {}
753+
754+
fn pid_and_secret_key(&self) -> (i32, pgwire::messages::startup::SecretKey) {
755+
(0, pgwire::messages::startup::SecretKey::I32(0))
756+
}
757+
758+
fn set_pid_and_secret_key(&mut self, _pid: i32, _secret_key: pgwire::messages::startup::SecretKey) {}
759+
760+
fn state(&self) -> pgwire::api::PgWireConnectionState {
761+
pgwire::api::PgWireConnectionState::ReadyForQuery
762+
}
763+
764+
fn set_state(&mut self, _new_state: pgwire::api::PgWireConnectionState) {}
765+
766+
fn transaction_status(&self) -> pgwire::messages::response::TransactionStatus {
767+
pgwire::messages::response::TransactionStatus::Idle
768+
}
769+
770+
fn set_transaction_status(&mut self, _new_status: pgwire::messages::response::TransactionStatus) {}
771+
772+
fn metadata(&self) -> &HashMap<String, String> {
773+
&self.metadata
774+
}
775+
776+
fn metadata_mut(&mut self) -> &mut HashMap<String, String> {
777+
&mut self.metadata
778+
}
779+
780+
fn client_certificates<'a>(&self) -> Option<&[rustls_pki_types::CertificateDer<'a>]> {
781+
None
782+
}
783+
}
784+
695785
#[tokio::test]
696786
async fn test_statement_timeout_set_and_show() {
697787
let session_context = Arc::new(SessionContext::new());
698788
let auth_manager = Arc::new(AuthManager::new());
699789
let service = DfSessionService::new(session_context, auth_manager);
790+
let mut client = MockClient::new();
700791

701792
// Test setting timeout to 5000ms
702793
let set_response = service
703-
.try_respond_set_statements("set statement_timeout '5000ms'")
794+
.try_respond_set_statements(&mut client, "set statement_timeout '5000ms'")
704795
.await
705796
.unwrap();
706797
assert!(set_response.is_some());
707798

708-
// Verify the timeout was set
709-
let timeout = *service.statement_timeout.lock().await;
799+
// Verify the timeout was set in client metadata
800+
let timeout = DfSessionService::get_statement_timeout(&client);
710801
assert_eq!(timeout, Some(Duration::from_millis(5000)));
711802

712803
// Test SHOW statement_timeout
713804
let show_response = service
714-
.try_respond_show_statements("show statement_timeout")
805+
.try_respond_show_statements(&client, "show statement_timeout")
715806
.await
716807
.unwrap();
717808
assert!(show_response.is_some());
@@ -722,20 +813,21 @@ mod tests {
722813
let session_context = Arc::new(SessionContext::new());
723814
let auth_manager = Arc::new(AuthManager::new());
724815
let service = DfSessionService::new(session_context, auth_manager);
816+
let mut client = MockClient::new();
725817

726818
// Set timeout first
727819
service
728-
.try_respond_set_statements("set statement_timeout '1000ms'")
820+
.try_respond_set_statements(&mut client, "set statement_timeout '1000ms'")
729821
.await
730822
.unwrap();
731823

732824
// Disable timeout with 0
733825
service
734-
.try_respond_set_statements("set statement_timeout '0'")
826+
.try_respond_set_statements(&mut client, "set statement_timeout '0'")
735827
.await
736828
.unwrap();
737829

738-
let timeout = *service.statement_timeout.lock().await;
830+
let timeout = DfSessionService::get_statement_timeout(&client);
739831
assert_eq!(timeout, None);
740832
}
741833
}

0 commit comments

Comments
 (0)