Skip to content

Commit 4eaabdb

Browse files
committed
Refactor session context updates and add datestyle support
1 parent 1c167cc commit 4eaabdb

File tree

1 file changed

+122
-87
lines changed

1 file changed

+122
-87
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 122 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// src/handlers.rs
2+
23
use std::collections::HashMap;
34
use std::sync::Arc;
45

@@ -21,6 +22,7 @@ use pgwire::error::{PgWireError, PgWireResult};
2122
use sqlparser::ast::{Expr, Ident, ObjectName, Statement};
2223
use sqlparser::dialect::GenericDialect;
2324
use sqlparser::parser::Parser as SqlParser;
25+
use tokio::sync::RwLock;
2426

2527
use crate::datatypes::{self, into_pg_type};
2628

@@ -56,37 +58,41 @@ impl PgWireServerHandlers for HandlerFactory {
5658
}
5759
}
5860

61+
5962
pub struct DfSessionService {
60-
pub session_context: Arc<tokio::sync::RwLock<SessionContext>>,
63+
pub session_context: Arc<RwLock<SessionContext>>,
6164
pub parser: Arc<Parser>,
62-
custom_session_vars: Arc<tokio::sync::RwLock<HashMap<String, String>>>,
65+
custom_session_vars: Arc<RwLock<HashMap<String, String>>>,
6366
}
6467

6568
impl DfSessionService {
6669
pub fn new(session_context: SessionContext) -> DfSessionService {
67-
let session_context = Arc::new(tokio::sync::RwLock::new(session_context));
70+
let session_context = Arc::new(RwLock::new(session_context));
6871
let parser = Arc::new(Parser {
6972
session_context: session_context.clone(),
7073
});
7174
DfSessionService {
7275
session_context,
7376
parser,
74-
custom_session_vars: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
77+
custom_session_vars: Arc::new(RwLock::new(HashMap::new())),
7578
}
7679
}
7780

7881
async fn handle_set(&self, variable: &ObjectName, value: &[Expr]) -> PgWireResult<()> {
79-
let var_name = variable.0.get(0)
82+
let var_name = variable
83+
.0
84+
.get(0)
8085
.map(|ident| ident.to_string().to_lowercase())
8186
.unwrap_or_default();
87+
8288
let value_str = match value.get(0) {
8389
Some(Expr::Value(v)) => match &v.value {
8490
sqlparser::ast::Value::SingleQuotedString(s)
8591
| sqlparser::ast::Value::DoubleQuotedString(s) => s.clone(),
8692
sqlparser::ast::Value::Number(n, _) => n.to_string(),
8793
_ => v.to_string(),
8894
},
89-
Some(expr) => expr.to_string(),
95+
Some(expr) => expr.to_string(),
9096
None => {
9197
return Err(PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
9298
"ERROR".to_string(),
@@ -98,86 +104,110 @@ impl DfSessionService {
98104

99105
match var_name.as_str() {
100106
"timezone" => {
101-
let config = {
102-
let ctx = self.session_context.read().await;
103-
ctx.state().config().options().clone()
104-
};
105-
let mut new_config = config;
106-
new_config.execution.time_zone = Some(value_str);
107-
let new_context = SessionContext::new_with_config(new_config.into());
108-
{
109-
let ctx = self.session_context.read().await;
110-
for catalog_name in ctx.catalog_names() {
111-
if let Some(catalog) = ctx.catalog(&catalog_name) {
112-
for schema_name in catalog.schema_names() {
113-
if let Some(schema) = catalog.schema(&schema_name) {
114-
for table_name in schema.table_names() {
115-
if let Ok(Some(table)) = schema.table(&table_name).await {
116-
new_context.register_table(&table_name, table)
117-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
118-
}
107+
let mut sc_guard = self.session_context.write().await;
108+
109+
let mut config = sc_guard.state().config().options().clone();
110+
config.execution.time_zone = Some(value_str);
111+
112+
let new_context = SessionContext::new_with_config(config.into());
113+
114+
let old_catalog_names = sc_guard.catalog_names();
115+
for catalog_name in old_catalog_names {
116+
if let Some(catalog) = sc_guard.catalog(&catalog_name) {
117+
for schema_name in catalog.schema_names() {
118+
if let Some(schema) = catalog.schema(&schema_name) {
119+
for table_name in schema.table_names() {
120+
if let Ok(Some(table)) = schema.table(&table_name).await {
121+
new_context
122+
.register_table(&table_name, table)
123+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
119124
}
120125
}
121126
}
122127
}
123128
}
124129
}
125-
{
126-
let mut ctx = self.session_context.write().await;
127-
*ctx = new_context;
128-
}
130+
131+
*sc_guard = new_context;
129132
Ok(())
130133
}
131-
"client_encoding" | "search_path" | "application_name" => {
134+
"client_encoding" | "search_path" | "application_name" | "datestyle" => {
132135
let mut vars = self.custom_session_vars.write().await;
133136
vars.insert(var_name, value_str);
134137
Ok(())
135138
}
136139
_ => Err(PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
137140
"ERROR".to_string(),
138-
"42704".to_string(), // Undefined object
141+
"42704".to_string(),
139142
format!("Unrecognized configuration parameter '{}'", var_name),
140143
)))),
141144
}
142145
}
143146

144147
async fn handle_show<'a>(&self, variable: &[Ident]) -> PgWireResult<QueryResponse<'a>> {
145-
let var_name = variable.get(0)
148+
let var_name = variable
149+
.get(0)
146150
.map(|ident| ident.to_string().to_lowercase())
147151
.unwrap_or_default();
148-
let config = {
149-
let ctx = self.session_context.read().await;
150-
ctx.state().config().options().clone()
151-
};
152+
153+
let sc_guard = self.session_context.read().await;
154+
let config = sc_guard.state().config().options().clone();
155+
drop(sc_guard);
152156
let value = match var_name.as_str() {
153-
"timezone" => config.execution.time_zone.clone().unwrap_or_else(|| "UTC".to_string()),
154-
"client_encoding" => self.custom_session_vars
155-
.read().await
157+
"timezone" => config
158+
.execution
159+
.time_zone
160+
.clone()
161+
.unwrap_or_else(|| "UTC".to_string()),
162+
163+
"client_encoding" => self
164+
.custom_session_vars
165+
.read()
166+
.await
156167
.get(&var_name)
157168
.cloned()
158169
.unwrap_or_else(|| "UTF8".to_string()),
159-
"search_path" => self.custom_session_vars
160-
.read().await
170+
171+
"search_path" => self
172+
.custom_session_vars
173+
.read()
174+
.await
161175
.get(&var_name)
162176
.cloned()
163177
.unwrap_or_else(|| "public".to_string()),
164-
"application_name" => self.custom_session_vars
165-
.read().await
178+
179+
"application_name" => self
180+
.custom_session_vars
181+
.read()
182+
.await
166183
.get(&var_name)
167184
.cloned()
168185
.unwrap_or_else(|| "".to_string()),
186+
187+
"datestyle" => self
188+
.custom_session_vars
189+
.read()
190+
.await
191+
.get(&var_name)
192+
.cloned()
193+
.unwrap_or_else(|| "ISO, MDY".to_string()),
194+
169195
"all" => {
170196
let mut names = Vec::new();
171197
let mut values = Vec::new();
198+
172199
if let Some(tz) = &config.execution.time_zone {
173200
names.push("timezone".to_string());
174201
values.push(tz.clone());
175202
}
203+
176204
let custom_vars = self.custom_session_vars.read().await;
177205
for (name, value) in custom_vars.iter() {
178206
names.push(name.clone());
179207
values.push(value.clone());
180208
}
209+
210+
// Provide defaults if not set
181211
if !custom_vars.contains_key("client_encoding") {
182212
names.push("client_encoding".to_string());
183213
values.push("UTF8".to_string());
@@ -190,10 +220,16 @@ impl DfSessionService {
190220
names.push("application_name".to_string());
191221
values.push("".to_string());
192222
}
223+
if !custom_vars.contains_key("datestyle") {
224+
names.push("datestyle".to_string());
225+
values.push("ISO, MDY".to_string());
226+
}
227+
193228
let schema = Arc::new(Schema::new(vec![
194229
Field::new("name", DataType::Utf8, false),
195230
Field::new("setting", DataType::Utf8, false),
196231
]));
232+
197233
let batch = RecordBatch::try_new(
198234
schema.clone(),
199235
vec![
@@ -202,11 +238,13 @@ impl DfSessionService {
202238
],
203239
)
204240
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
205-
let df = {
206-
let ctx = self.session_context.read().await;
207-
ctx.read_batch(batch)
208-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
209-
};
241+
242+
let sc_guard = self.session_context.read().await;
243+
let df = sc_guard
244+
.read_batch(batch)
245+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
246+
drop(sc_guard);
247+
210248
return datatypes::encode_dataframe(df, &Format::UnifiedText).await;
211249
}
212250
_ => {
@@ -218,42 +256,40 @@ impl DfSessionService {
218256
}
219257
};
220258

221-
let schema = Arc::new(Schema::new(vec![
222-
Field::new(&var_name, DataType::Utf8, false),
223-
]));
224-
let batch = RecordBatch::try_new(
225-
schema.clone(),
226-
vec![Arc::new(StringArray::from(vec![value]))],
227-
)
228-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
229-
let df = {
230-
let ctx = self.session_context.read().await;
231-
ctx.read_batch(batch)
232-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
233-
};
259+
let schema = Arc::new(Schema::new(vec![Field::new(&var_name, DataType::Utf8, false)]));
260+
let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(StringArray::from(vec![value]))])
261+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
262+
263+
let sc_guard = self.session_context.read().await;
264+
let df = sc_guard
265+
.read_batch(batch)
266+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
267+
drop(sc_guard);
268+
234269
datatypes::encode_dataframe(df, &Format::UnifiedText).await
235270
}
236271
}
237272

238273
pub struct Parser {
239-
session_context: Arc<tokio::sync::RwLock<SessionContext>>,
274+
session_context: Arc<RwLock<SessionContext>>,
240275
}
241276

242277
#[async_trait]
243278
impl QueryParser for Parser {
244279
type Statement = LogicalPlan;
245280

246281
async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult<Self::Statement> {
247-
let ctx = self.session_context.read().await;
248-
let logical_plan = ctx
249-
.state()
282+
let sc_guard = self.session_context.read().await;
283+
let state = sc_guard.state();
284+
285+
let logical_plan = state
250286
.create_logical_plan(sql)
251287
.await
252288
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
253-
let optimized = ctx
254-
.state()
289+
let optimized = state
255290
.optimize(&logical_plan)
256291
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
292+
257293
Ok(optimized)
258294
}
259295
}
@@ -295,12 +331,13 @@ impl SimpleQueryHandler for DfSessionService {
295331
responses.push(Response::Query(resp));
296332
}
297333
_ => {
298-
let df = {
299-
let ctx = self.session_context.read().await;
300-
ctx.sql(&stmt_string)
301-
.await
302-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
303-
};
334+
let sc_guard = self.session_context.read().await;
335+
let df = sc_guard
336+
.sql(&stmt_string)
337+
.await
338+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
339+
drop(sc_guard);
340+
304341
let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
305342
responses.push(Response::Query(resp));
306343
}
@@ -361,8 +398,8 @@ impl ExtendedQueryHandler for DfSessionService {
361398
let plan = &target.statement.statement;
362399
let format = &target.result_column_format;
363400
let schema = plan.schema();
364-
let fields =
365-
datatypes::df_schema_to_pg_fields(schema.as_ref(), format)?;
401+
let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), format)?;
402+
366403
Ok(DescribePortalResponse::new(fields))
367404
}
368405

@@ -388,9 +425,7 @@ impl ExtendedQueryHandler for DfSessionService {
388425
sqlparser::ast::OneOrManyWithParens::Many(ref names) => names.first().unwrap(),
389426
};
390427
self.handle_set(var, &value).await?;
391-
return Ok(Response::Execution(
392-
pgwire::api::results::Tag::new("SET").into(),
393-
));
428+
return Ok(Response::Execution(pgwire::api::results::Tag::new("SET").into()));
394429
}
395430
} else if stmt_upper.starts_with("SHOW ") {
396431
let dialect = GenericDialect {};
@@ -406,20 +441,20 @@ impl ExtendedQueryHandler for DfSessionService {
406441
let param_types = plan
407442
.get_parameter_types()
408443
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
409-
let param_values = datatypes::deserialize_parameters(
410-
portal,
411-
&ordered_param_types(&param_types),
412-
)?;
444+
let param_values =
445+
datatypes::deserialize_parameters(portal, &ordered_param_types(&param_types))?;
413446
let plan = plan
414447
.clone()
415448
.replace_params_with_values(&param_values)
416449
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
417-
let dataframe = {
418-
let ctx = self.session_context.read().await;
419-
ctx.execute_logical_plan(plan)
420-
.await
421-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
422-
};
450+
451+
let sc_guard = self.session_context.read().await;
452+
let dataframe = sc_guard
453+
.execute_logical_plan(plan)
454+
.await
455+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
456+
drop(sc_guard);
457+
423458
let resp = datatypes::encode_dataframe(dataframe, &portal.result_column_format).await?;
424459
Ok(Response::Query(resp))
425460
}

0 commit comments

Comments
 (0)