Skip to content

Commit 5705317

Browse files
authored
Merge pull request #317 from cipherstash/refactor-frontend-parser
♻️ refactor: extract SQL parsing logic into dedicated Parser module
2 parents f1b66a2 + efb30a9 commit 5705317

File tree

3 files changed

+39
-43
lines changed

3 files changed

+39
-43
lines changed

packages/cipherstash-proxy/src/postgresql/frontend.rs

Lines changed: 4 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use super::messages::execute::Execute;
66
use super::messages::parse::Parse;
77
use super::messages::query::Query;
88
use super::messages::FrontendCode as Code;
9+
use super::parser::SqlParser;
910
use super::protocol::{self};
1011
use crate::connect::Sender;
1112
use crate::eql::Identifier;
@@ -22,7 +23,7 @@ use crate::prometheus::{
2223
CLIENTS_BYTES_RECEIVED_TOTAL, ENCRYPTED_VALUES_TOTAL, ENCRYPTION_DURATION_SECONDS,
2324
ENCRYPTION_ERROR_TOTAL, ENCRYPTION_REQUESTS_TOTAL, SERVER_BYTES_SENT_TOTAL,
2425
STATEMENTS_ENCRYPTED_TOTAL, STATEMENTS_PASSTHROUGH_MAPPING_DISABLED_TOTAL,
25-
STATEMENTS_PASSTHROUGH_TOTAL, STATEMENTS_TOTAL, STATEMENTS_UNMAPPABLE_TOTAL,
26+
STATEMENTS_PASSTHROUGH_TOTAL, STATEMENTS_UNMAPPABLE_TOTAL,
2627
};
2728
use crate::proxy::Proxy;
2829
use crate::EqlEncrypted;
@@ -34,17 +35,13 @@ use pg_escape::quote_literal;
3435
use postgres_types::Type;
3536
use serde::Serialize;
3637
use sqltk::parser::ast::{self, Value};
37-
use sqltk::parser::dialect::PostgreSqlDialect;
38-
use sqltk::parser::parser::Parser;
3938
use sqltk::NodeKey;
4039
use std::collections::HashMap;
4140
use std::sync::Arc;
4241
use std::time::Instant;
4342
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
4443
use tracing::{debug, error, info, warn};
4544

46-
const DIALECT: PostgreSqlDialect = PostgreSqlDialect {};
47-
4845
/// The PostgreSQL proxy frontend that handles client-to-server message processing.
4946
///
5047
/// The Frontend intercepts messages from PostgreSQL clients, analyzes SQL statements for
@@ -384,7 +381,7 @@ where
384381
let mut query = Query::try_from(bytes)?;
385382

386383
// Simple Query may contain many statements
387-
let parsed_statements = self.parse_statements(&query.statement)?;
384+
let parsed_statements = SqlParser::parse_statements(&query.statement)?;
388385
let mut transformed_statements = vec![];
389386

390387
debug!(target: MAPPER,
@@ -659,7 +656,7 @@ where
659656
parse = ?message
660657
);
661658

662-
let statement = self.parse_statement(&message.statement)?;
659+
let statement = SqlParser::parse_statement(&message.statement)?;
663660

664661
if let Some(mapping_disabled) = self.context.maybe_set_unsafe_disable_mapping(&statement) {
665662
warn!(
@@ -748,42 +745,6 @@ where
748745
}
749746
}
750747

751-
///
752-
/// Parse a SQL statement string into an SqlParser AST
753-
///
754-
fn parse_statement(&mut self, statement: &str) -> Result<ast::Statement, Error> {
755-
let statement = Parser::new(&DIALECT)
756-
.try_with_sql(statement)?
757-
.parse_statement()?;
758-
759-
debug!(target: MAPPER,
760-
client_id = self.context.client_id,
761-
statement = %statement
762-
);
763-
764-
counter!(STATEMENTS_TOTAL).increment(1);
765-
766-
Ok(statement)
767-
}
768-
769-
///
770-
/// Parse a SQL String potentially containing multiple statements into parsed SqlParser AST
771-
///
772-
fn parse_statements(&mut self, statement: &str) -> Result<Vec<ast::Statement>, Error> {
773-
let statement = Parser::new(&DIALECT)
774-
.try_with_sql(statement)?
775-
.parse_statements()?;
776-
777-
debug!(target: MAPPER,
778-
client_id = self.context.client_id,
779-
statement = ?statement
780-
);
781-
782-
counter!(STATEMENTS_TOTAL).increment(statement.len() as u64);
783-
784-
Ok(statement)
785-
}
786-
787748
///
788749
/// Check the Statement AST for DDL
789750
/// Sets a schema changed flag in the Context

packages/cipherstash-proxy/src/postgresql/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ mod frontend;
77
mod handler;
88
mod message_buffer;
99
mod messages;
10+
mod parser;
1011
mod protocol;
1112
mod startup;
1213

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
use crate::error::Error;
2+
use crate::prometheus::STATEMENTS_TOTAL;
3+
use metrics::counter;
4+
use sqltk::parser::ast;
5+
use sqltk::parser::dialect::PostgreSqlDialect;
6+
use sqltk::parser::parser::Parser;
7+
8+
const DIALECT: PostgreSqlDialect = PostgreSqlDialect {};
9+
10+
pub struct SqlParser;
11+
12+
impl SqlParser {
13+
/// Parse a SQL statement string into an SqlParser AST
14+
pub fn parse_statement(statement: &str) -> Result<ast::Statement, Error> {
15+
let statement = Parser::new(&DIALECT)
16+
.try_with_sql(statement)?
17+
.parse_statement()?;
18+
19+
counter!(STATEMENTS_TOTAL).increment(1);
20+
21+
Ok(statement)
22+
}
23+
24+
/// Parse a SQL String potentially containing multiple statements into parsed SqlParser AST
25+
pub fn parse_statements(statement: &str) -> Result<Vec<ast::Statement>, Error> {
26+
let statement = Parser::new(&DIALECT)
27+
.try_with_sql(statement)?
28+
.parse_statements()?;
29+
30+
counter!(STATEMENTS_TOTAL).increment(statement.len() as u64);
31+
32+
Ok(statement)
33+
}
34+
}

0 commit comments

Comments
 (0)