Skip to content

Commit 4f638fe

Browse files
committed
feat: SQL transformations related to GROUP BY
Includes a version bump to the latest `sqltk` version, which vendors `sqlparser` as `sqltk-parser` and adds an extra parameter to `Transform::transform` to allow contextual transformations (instead of solely by type). The following refactorings were done as part of this PR: 1. Broke the sole `Transform` impl into modular rules using the `TransformationRule` trait. 2. Removed `EqlFunctionTracker` - this existed to support `Expr` transformations in `ORDER BY` clauses. With contextual transformation it is no longer required. 3. Removed `TypeCell` - it did not carry its weight and made debugging really hard. The `TypeRegistry` now associates nodes with `Arc<Type>` using an intermediary `TypeVar`. An extra method was added to the post-type-checking step in the `EqlMapper` to resolve all `Value` nodes which have no concrete type after the inference step to `NativeValue`.
1 parent 06f90cd commit 4f638fe

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+2153
-1825
lines changed

Cargo.lock

Lines changed: 24 additions & 11 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ strip = "none"
3535
debug = true
3636

3737
[workspace.dependencies]
38-
sqlparser = { version = "^0.52", features = ["bigdecimal", "serde"] }
38+
sqltk = { version = "0.5.0" }
39+
sqltk-parser = { version = "0.52.0" }
3940
thiserror = "2.0.9"
4041
tokio = { version = "1.42.0", features = ["full"] }
4142
tracing = "0.1"

packages/cipherstash-proxy/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ rustls-pki-types = "1.10.0"
4040
serde = "1.0"
4141
serde_json = "1.0"
4242
socket2 = "0.5.7"
43-
sqlparser = { workspace = true }
43+
sqltk = { workspace = true }
44+
sqltk-parser = { workspace = true }
4445
thiserror = { workspace = true }
4546
tokio = { workspace = true }
4647
tokio-postgres = { version = "0.7", features = [

packages/cipherstash-proxy/src/encrypt/schema/manager.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::error::Error;
44
use crate::{connect, log::SCHEMA};
55
use arc_swap::ArcSwap;
66
use eql_mapper::{Column, Schema, Table};
7-
use sqlparser::ast::Ident;
7+
use sqltk_parser::ast::Ident;
88
use std::sync::Arc;
99
use std::time::Duration;
1010
use tokio::{task::JoinHandle, time};

packages/cipherstash-proxy/src/eql/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use cipherstash_client::{
33
zerokms::{encrypted_record, EncryptedRecord},
44
};
55
use serde::{Deserialize, Serialize};
6-
use sqlparser::ast::Ident;
6+
use sqltk_parser::ast::Ident;
77

88
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
99
pub struct Plaintext {

packages/cipherstash-proxy/src/error.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,8 @@ impl From<serde_json::Error> for Error {
320320
}
321321
}
322322

323-
impl From<sqlparser::parser::ParserError> for Error {
324-
fn from(e: sqlparser::parser::ParserError) -> Self {
323+
impl From<sqltk_parser::parser::ParserError> for Error {
324+
fn from(e: sqltk_parser::parser::ParserError) -> Self {
325325
Error::Mapping(MappingError::InvalidSqlStatement(e.to_string()))
326326
}
327327
}

packages/cipherstash-proxy/src/postgresql/data/from_sql.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use cipherstash_config::ColumnType;
1111
use postgres_types::FromSql;
1212
use postgres_types::Type;
1313
use rust_decimal::Decimal;
14-
use sqlparser::ast::Value;
14+
use sqltk_parser::ast::Value;
1515
use std::str::FromStr;
1616
use tracing::debug;
1717

packages/cipherstash-proxy/src/postgresql/frontend.rs

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,14 @@ use crate::prometheus::{
2626
use crate::Encrypted;
2727
use bytes::BytesMut;
2828
use cipherstash_client::encryption::Plaintext;
29-
use eql_mapper::{self, EqlMapperError, EqlValue, NodeKey, TableColumn, TypedStatement};
29+
use eql_mapper::{self, EqlMapperError, EqlValue, TableColumn, TypedStatement};
3030
use metrics::{counter, histogram};
3131
use pg_escape::quote_literal;
3232
use serde::Serialize;
33-
use sqlparser::ast::{self, Expr, Value};
34-
use sqlparser::dialect::PostgreSqlDialect;
35-
use sqlparser::parser::Parser;
33+
use sqltk::AsNodeKey;
34+
use sqltk_parser::ast::{self, Value};
35+
use sqltk_parser::dialect::PostgreSqlDialect;
36+
use sqltk_parser::parser::Parser;
3637
use std::collections::HashMap;
3738
use std::time::Instant;
3839
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@@ -285,7 +286,7 @@ where
285286

286287
match self.to_encryptable_statement(&typed_statement, vec![])? {
287288
Some(statement) => {
288-
if statement.has_literals() || typed_statement.has_nodes_to_wrap() {
289+
if statement.has_literals() {
289290
let encrypted_literals = self
290291
.encrypt_literals(&typed_statement, &statement.literal_columns)
291292
.await?;
@@ -421,16 +422,15 @@ where
421422
.literals
422423
.iter()
423424
.zip(encrypted_expressions.into_iter())
424-
.filter_map(|((_, original_node), en)| en.map(|en| (NodeKey::new(*original_node), en)))
425+
.filter_map(|((_, original_node), en)| en.map(|en| (original_node.as_node_key(), en)))
425426
.collect::<HashMap<_, _>>();
426427

427428
debug!(target: MAPPER,
428429
client_id = self.context.client_id,
429-
nodes_to_wrap = typed_statement.nodes_to_wrap.len(),
430430
literals = encrypted_nodes.len(),
431431
);
432432

433-
if !typed_statement.has_nodes_to_wrap() && encrypted_nodes.is_empty() {
433+
if encrypted_nodes.is_empty() {
434434
return Ok(None);
435435
}
436436

@@ -500,7 +500,7 @@ where
500500

501501
match self.to_encryptable_statement(&typed_statement, param_types)? {
502502
Some(statement) => {
503-
if statement.has_literals() || typed_statement.has_nodes_to_wrap() {
503+
if statement.has_literals() {
504504
let encrypted_literals = self
505505
.encrypt_literals(&typed_statement, &statement.literal_columns)
506506
.await?;
@@ -620,7 +620,6 @@ where
620620
if (param_columns.is_empty() || no_encrypted_param_columns)
621621
&& (projection_columns.is_empty() || no_encrypted_projection_columns)
622622
&& literal_columns.is_empty()
623-
&& !typed_statement.has_nodes_to_wrap()
624623
{
625624
return Ok(None);
626625
}
@@ -781,7 +780,7 @@ where
781780
typed_statement: &eql_mapper::TypedStatement<'_>,
782781
) -> Result<Vec<Option<Column>>, Error> {
783782
let mut projection_columns = vec![];
784-
if let Some(eql_mapper::Projection::WithColumns(columns)) = &typed_statement.projection {
783+
if let eql_mapper::Projection::WithColumns(columns) = &typed_statement.projection {
785784
for col in columns {
786785
let eql_mapper::ProjectionColumn { ty, .. } = col;
787786
let configured_column = match ty {
@@ -819,7 +818,7 @@ where
819818

820819
for param in typed_statement.params.iter() {
821820
let configured_column = match param {
822-
eql_mapper::Value::Eql(EqlValue(TableColumn { table, column })) => {
821+
(_, eql_mapper::Value::Eql(EqlValue(TableColumn { table, column }))) => {
823822
let identifier = Identifier::from((table, column));
824823

825824
debug!(
@@ -967,9 +966,9 @@ fn literals_to_plaintext(
967966
Ok(plaintexts)
968967
}
969968

970-
fn to_json_literal_expr<T>(literal: &T) -> Result<Expr, Error>
969+
fn to_json_literal_expr<T>(literal: &T) -> Result<Value, Error>
971970
where
972971
T: ?Sized + Serialize,
973972
{
974-
Ok(serde_json::to_string(literal).map(|json| Expr::Value(Value::SingleQuotedString(json)))?)
973+
Ok(serde_json::to_string(literal).map(Value::SingleQuotedString)?)
975974
}

packages/eql-mapper/Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ authors = [
1212

1313
[dependencies]
1414
derive_more = { version = "^1.0", features = ["display", "constructor"] }
15+
impl-trait-for-tuples = "0.2.3"
1516
itertools = "^0.13"
16-
sqlparser = { workspace = true }
17-
sqltk = { git = "https://github.com/cipherstash/sqltk/", rev = "214f9b90e4f07d4414292813ffd6e45dec075fbb" }
17+
sqltk = { workspace = true }
18+
sqltk-parser = { workspace = true }
1819
thiserror = { workspace = true }
1920
tracing = { workspace = true }
2021
tracing-subscriber = { workspace = true }
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
use std::{
2+
collections::HashMap,
3+
fmt::{Debug, Display},
4+
};
5+
6+
use sqltk::NodeKey;
7+
use sqltk_parser::ast::{
8+
Delete, Expr, Function, Insert, Query, Select, SelectItem, SetExpr, Statement, Value, Values,
9+
};
10+
11+
use crate::{EqlValue, Param, Type};
12+
13+
/// Newtype wrapper that lets this crate attach its own `Display` impls to
/// values it wants to render as text (node keys, maps, vectors, options).
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
14+
pub struct Fmt<T>(pub(crate) T);
15+
16+
/// Newtype wrapper around a reference to a displayable node; its `Display`
/// impl forwards to the wrapped value's own `Display` impl.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
17+
pub struct FmtAst<T>(pub(crate) T);
18+
19+
/// Newtype wrapper around a reference to a `Vec` of displayable nodes; its
/// `Display` impl renders the elements comma-separated between `![` and `]!`.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
20+
pub struct FmtAstVec<T>(pub(crate) T);
21+
22+
impl Display for Fmt<NodeKey<'_>> {
23+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24+
if let Some(node) = self.0.get_as::<Statement>() {
25+
return Display::fmt(&FmtAst(node), f);
26+
}
27+
if let Some(node) = self.0.get_as::<Query>() {
28+
return Display::fmt(&FmtAst(node), f);
29+
}
30+
if let Some(node) = self.0.get_as::<Insert>() {
31+
return Display::fmt(&FmtAst(node), f);
32+
}
33+
if let Some(node) = self.0.get_as::<Delete>() {
34+
return Display::fmt(&FmtAst(node), f);
35+
}
36+
if let Some(node) = self.0.get_as::<Expr>() {
37+
return Display::fmt(&FmtAst(node), f);
38+
}
39+
if let Some(node) = self.0.get_as::<SetExpr>() {
40+
return Display::fmt(&FmtAst(node), f);
41+
}
42+
if let Some(node) = self.0.get_as::<Select>() {
43+
return Display::fmt(&FmtAst(node), f);
44+
}
45+
if let Some(node) = self.0.get_as::<SelectItem>() {
46+
return Display::fmt(&FmtAst(node), f);
47+
}
48+
if let Some(node) = self.0.get_as::<Vec<SelectItem>>() {
49+
return Display::fmt(&FmtAstVec(node), f);
50+
}
51+
if let Some(node) = self.0.get_as::<Function>() {
52+
return Display::fmt(&FmtAst(node), f);
53+
}
54+
if let Some(node) = self.0.get_as::<Values>() {
55+
return Display::fmt(&FmtAst(node), f);
56+
}
57+
if let Some(node) = self.0.get_as::<Value>() {
58+
return Display::fmt(&FmtAst(node), f);
59+
}
60+
61+
f.write_str("!! CANNOT RENDER SQL NODE !!!")?;
62+
63+
Ok(())
64+
}
65+
}
66+
67+
impl Display for Fmt<&HashMap<NodeKey<'_>, Type>> {
68+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69+
let mut out: Vec<String> = Vec::new();
70+
out.push("{ ".into());
71+
for (k, v) in self.0.iter() {
72+
out.push(format!("{}: {}", Fmt(*k), v));
73+
}
74+
out.push(" }".into());
75+
f.write_str(&out.join(", "))
76+
}
77+
}
78+
79+
impl<T: Display> Display for FmtAstVec<&Vec<T>> {
80+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81+
f.write_str("![")?;
82+
let children = self
83+
.0
84+
.iter()
85+
.map(|n| n.to_string())
86+
.collect::<Vec<_>>()
87+
.join(", ");
88+
f.write_str(&children)?;
89+
f.write_str("]!")
90+
}
91+
}
92+
93+
impl<T: Display> Display for FmtAst<&T> {
94+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95+
T::fmt(self.0, f)
96+
}
97+
}
98+
99+
impl Display for Fmt<&Vec<(Param, crate::Value)>> {
    /// Renders each `(param, value)` pair as `param: value`, with pairs
    /// separated by `", "`. An empty vector renders as an empty string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut rendered = String::new();
        for (idx, (param, value)) in self.0.iter().enumerate() {
            if idx > 0 {
                rendered.push_str(", ");
            }
            rendered.push_str(&format!("{}: {}", param, value));
        }
        f.write_str(&rendered)
    }
}
110+
111+
impl Display for Fmt<&Vec<(EqlValue, &sqltk_parser::ast::Value)>> {
    /// Renders each pair as `node: eql_value` (the AST value comes first and
    /// the `EqlValue` second), with pairs separated by `", "`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut rendered = String::new();
        for (idx, (eql_value, node)) in self.0.iter().enumerate() {
            if idx > 0 {
                rendered.push_str(", ");
            }
            // Note the order: the AST node is printed before its EQL value.
            rendered.push_str(&format!("{}: {}", node, eql_value));
        }
        f.write_str(&rendered)
    }
}
122+
123+
impl<T: Display> Display for Fmt<Option<T>> {
124+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125+
match &self.0 {
126+
Some(t) => t.fmt(f),
127+
None => Display::fmt("<not initialised>", f),
128+
}
129+
}
130+
}

0 commit comments

Comments
 (0)