Skip to content

Commit b322274

Browse files
committed
Multistatement
1 parent 892c813 commit b322274

File tree

3 files changed

+86
-36
lines changed

3 files changed

+86
-36
lines changed

Cargo.lock

Lines changed: 15 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion-postgres/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@ readme = "../README.md"
1919
pgwire = { workspace = true }
2020
datafusion = { workspace = true }
2121
futures = "0.3"
22+
sqlparser = "0.55"
2223
async-trait = "0.1"
2324
chrono = { version = "0.4", features = ["std"] }

datafusion-postgres/src/handlers.rs

Lines changed: 70 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
// src/handlers.rs
12
use std::collections::HashMap;
23
use std::sync::Arc;
34

@@ -15,8 +16,14 @@ use pgwire::api::stmt::StoredStatement;
1516
use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type};
1617
use pgwire::error::{PgWireError, PgWireResult};
1718

19+
// --- Imports used for multi-statement parsing ---
20+
use sqlparser::dialect::GenericDialect;
21+
use sqlparser::parser::Parser as SqlParser;
22+
// ------------------------------------------------------
23+
1824
use crate::datatypes::{self, into_pg_type};
1925

26+
/// A factory that creates our handlers for the PGWire server.
2027
pub struct HandlerFactory(pub Arc<DfSessionService>);
2128

2229
impl NoopStartupHandler for DfSessionService {}
@@ -49,9 +56,10 @@ impl PgWireServerHandlers for HandlerFactory {
4956
}
5057
}
5158

59+
/// Our primary session service, storing a DataFusion `SessionContext`.
5260
pub struct DfSessionService {
53-
session_context: Arc<SessionContext>,
54-
parser: Arc<Parser>,
61+
pub session_context: Arc<SessionContext>,
62+
pub parser: Arc<Parser>,
5563
}
5664

5765
impl DfSessionService {
@@ -67,27 +75,7 @@ impl DfSessionService {
6775
}
6876
}
6977

70-
#[async_trait]
71-
impl SimpleQueryHandler for DfSessionService {
72-
async fn do_query<'a, C>(
73-
&self,
74-
_client: &mut C,
75-
query: &'a str,
76-
) -> PgWireResult<Vec<Response<'a>>>
77-
where
78-
C: ClientInfo + Unpin + Send + Sync,
79-
{
80-
let ctx = &self.session_context;
81-
let df = ctx
82-
.sql(query)
83-
.await
84-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
85-
86-
let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
87-
Ok(vec![Response::Query(resp)])
88-
}
89-
}
90-
78+
/// A simple parser that builds a logical plan from SQL text, using DataFusion.
9179
pub struct Parser {
9280
session_context: Arc<SessionContext>,
9381
}
@@ -104,18 +92,68 @@ impl QueryParser for Parser {
10492
.create_logical_plan(sql)
10593
.await
10694
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
107-
let optimised = state
95+
let optimized = state
10896
.optimize(&logical_plan)
10997
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
11098

111-
Ok(optimised)
99+
Ok(optimized)
112100
}
113101
}
114102

103+
// ----------------------------------------------------------------
104+
// SimpleQueryHandler Implementation (multi-statement support)
105+
// ----------------------------------------------------------------
106+
#[async_trait]
107+
impl SimpleQueryHandler for DfSessionService {
108+
async fn do_query<'a, C>(
109+
&self,
110+
_client: &mut C,
111+
query: &'a str,
112+
) -> PgWireResult<Vec<Response<'a>>>
113+
where
114+
C: ClientInfo + Unpin + Send + Sync,
115+
{
116+
// 1) Parse the incoming query string into multiple statements using sqlparser.
117+
let dialect = GenericDialect {};
118+
let stmts = match SqlParser::parse_sql(&dialect, query) {
119+
Ok(s) => s,
120+
Err(e) => {
121+
return Err(PgWireError::ApiError(Box::new(e)));
122+
}
123+
};
124+
125+
// 2) For each parsed statement, execute with DataFusion and collect results.
126+
let mut responses = Vec::with_capacity(stmts.len());
127+
for statement in stmts {
128+
// Convert the AST statement back to SQL text
129+
// (defensive check: sqlparser appears to skip empty statements such as a trailing `;`, so this is not expected to be empty)
130+
let stmt_string = statement.to_string().trim().to_owned();
131+
if stmt_string.is_empty() {
132+
continue;
133+
}
134+
135+
// Execute the statement in DataFusion
136+
let df = self
137+
.session_context
138+
.sql(&stmt_string)
139+
.await
140+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
141+
142+
// 3) Encode the DataFrame into a QueryResponse for the client
143+
let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
144+
responses.push(Response::Query(resp));
145+
}
146+
147+
Ok(responses)
148+
}
149+
}
150+
151+
// ----------------------------------------------------------------
152+
// ExtendedQueryHandler Implementation (not affected by multi-statement support)
153+
// ----------------------------------------------------------------
115154
#[async_trait]
116155
impl ExtendedQueryHandler for DfSessionService {
117156
type Statement = LogicalPlan;
118-
119157
type QueryParser = Parser;
120158

121159
fn query_parser(&self) -> Arc<Self::QueryParser> {
@@ -201,11 +239,11 @@ impl ExtendedQueryHandler for DfSessionService {
201239
}
202240
}
203241

242+
/// Helper to convert DataFusion’s parameter map into an ordered list.
204243
fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<&DataType>> {
205-
// Datafusion stores the parameters as a map. In our case, the keys will be
206-
// `$1`, `$2` etc. The values will be the parameter types.
207-
208-
let mut types = types.iter().collect::<Vec<_>>();
209-
types.sort_by(|a, b| a.0.cmp(b.0));
210-
types.into_iter().map(|pt| pt.1.as_ref()).collect()
244+
// DataFusion stores parameters as a map keyed by "$1", "$2", etc.
245+
// We sort them by key as strings (lexicographically), which matches the positional order only for up to nine parameters: "$10" sorts before "$2".
246+
let mut types_vec = types.iter().collect::<Vec<_>>();
247+
types_vec.sort_by(|a, b| a.0.cmp(b.0));
248+
types_vec.into_iter().map(|pt| pt.1.as_ref()).collect()
211249
}

0 commit comments

Comments
 (0)