Skip to content

Commit f6c0e22

Browse files
committed
WIP
1 parent 06f90cd commit f6c0e22

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

59 files changed

+2513
-1826
lines changed

Cargo.lock

Lines changed: 34 additions & 11 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ strip = "none"
3535
debug = true
3636

3737
[workspace.dependencies]
38-
sqlparser = { version = "^0.52", features = ["bigdecimal", "serde"] }
38+
sqltk = { version = "0.5.0" }
39+
sqltk-parser = { version = "0.52.0" }
3940
thiserror = "2.0.9"
4041
tokio = { version = "1.42.0", features = ["full"] }
4142
tracing = "0.1"

packages/cipherstash-proxy/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ rustls-pki-types = "1.10.0"
4040
serde = "1.0"
4141
serde_json = "1.0"
4242
socket2 = "0.5.7"
43-
sqlparser = { workspace = true }
43+
sqltk = { workspace = true }
44+
sqltk-parser = { workspace = true }
4445
thiserror = { workspace = true }
4546
tokio = { workspace = true }
4647
tokio-postgres = { version = "0.7", features = [

packages/cipherstash-proxy/src/encrypt/schema/manager.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::error::Error;
44
use crate::{connect, log::SCHEMA};
55
use arc_swap::ArcSwap;
66
use eql_mapper::{Column, Schema, Table};
7-
use sqlparser::ast::Ident;
7+
use sqltk_parser::ast::Ident;
88
use std::sync::Arc;
99
use std::time::Duration;
1010
use tokio::{task::JoinHandle, time};

packages/cipherstash-proxy/src/eql/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use cipherstash_client::{
33
zerokms::{encrypted_record, EncryptedRecord},
44
};
55
use serde::{Deserialize, Serialize};
6-
use sqlparser::ast::Ident;
6+
use sqltk_parser::ast::Ident;
77

88
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
99
pub struct Plaintext {

packages/cipherstash-proxy/src/error.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,8 @@ impl From<serde_json::Error> for Error {
320320
}
321321
}
322322

323-
impl From<sqlparser::parser::ParserError> for Error {
324-
fn from(e: sqlparser::parser::ParserError) -> Self {
323+
impl From<sqltk_parser::parser::ParserError> for Error {
324+
fn from(e: sqltk_parser::parser::ParserError) -> Self {
325325
Error::Mapping(MappingError::InvalidSqlStatement(e.to_string()))
326326
}
327327
}

packages/cipherstash-proxy/src/postgresql/data/from_sql.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use cipherstash_config::ColumnType;
1111
use postgres_types::FromSql;
1212
use postgres_types::Type;
1313
use rust_decimal::Decimal;
14-
use sqlparser::ast::Value;
14+
use sqltk_parser::ast::Value;
1515
use std::str::FromStr;
1616
use tracing::debug;
1717

packages/cipherstash-proxy/src/postgresql/frontend.rs

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,14 @@ use crate::prometheus::{
2626
use crate::Encrypted;
2727
use bytes::BytesMut;
2828
use cipherstash_client::encryption::Plaintext;
29-
use eql_mapper::{self, EqlMapperError, EqlValue, NodeKey, TableColumn, TypedStatement};
29+
use eql_mapper::{self, EqlMapperError, EqlValue, TableColumn, TypedStatement};
3030
use metrics::{counter, histogram};
3131
use pg_escape::quote_literal;
3232
use serde::Serialize;
33-
use sqlparser::ast::{self, Expr, Value};
34-
use sqlparser::dialect::PostgreSqlDialect;
35-
use sqlparser::parser::Parser;
33+
use sqltk_parser::ast::{self, Value};
34+
use sqltk_parser::dialect::PostgreSqlDialect;
35+
use sqltk_parser::parser::Parser;
36+
use sqltk::AsNodeKey;
3637
use std::collections::HashMap;
3738
use std::time::Instant;
3839
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@@ -285,7 +286,7 @@ where
285286

286287
match self.to_encryptable_statement(&typed_statement, vec![])? {
287288
Some(statement) => {
288-
if statement.has_literals() || typed_statement.has_nodes_to_wrap() {
289+
if statement.has_literals() {
289290
let encrypted_literals = self
290291
.encrypt_literals(&typed_statement, &statement.literal_columns)
291292
.await?;
@@ -421,16 +422,15 @@ where
421422
.literals
422423
.iter()
423424
.zip(encrypted_expressions.into_iter())
424-
.filter_map(|((_, original_node), en)| en.map(|en| (NodeKey::new(*original_node), en)))
425+
.filter_map(|((_, original_node), en)| en.map(|en| (original_node.as_node_key(), en)))
425426
.collect::<HashMap<_, _>>();
426427

427428
debug!(target: MAPPER,
428429
client_id = self.context.client_id,
429-
nodes_to_wrap = typed_statement.nodes_to_wrap.len(),
430430
literals = encrypted_nodes.len(),
431431
);
432432

433-
if !typed_statement.has_nodes_to_wrap() && encrypted_nodes.is_empty() {
433+
if encrypted_nodes.is_empty() {
434434
return Ok(None);
435435
}
436436

@@ -500,7 +500,7 @@ where
500500

501501
match self.to_encryptable_statement(&typed_statement, param_types)? {
502502
Some(statement) => {
503-
if statement.has_literals() || typed_statement.has_nodes_to_wrap() {
503+
if statement.has_literals() {
504504
let encrypted_literals = self
505505
.encrypt_literals(&typed_statement, &statement.literal_columns)
506506
.await?;
@@ -620,7 +620,6 @@ where
620620
if (param_columns.is_empty() || no_encrypted_param_columns)
621621
&& (projection_columns.is_empty() || no_encrypted_projection_columns)
622622
&& literal_columns.is_empty()
623-
&& !typed_statement.has_nodes_to_wrap()
624623
{
625624
return Ok(None);
626625
}
@@ -781,7 +780,7 @@ where
781780
typed_statement: &eql_mapper::TypedStatement<'_>,
782781
) -> Result<Vec<Option<Column>>, Error> {
783782
let mut projection_columns = vec![];
784-
if let Some(eql_mapper::Projection::WithColumns(columns)) = &typed_statement.projection {
783+
if let eql_mapper::Projection::WithColumns(columns) = &typed_statement.projection {
785784
for col in columns {
786785
let eql_mapper::ProjectionColumn { ty, .. } = col;
787786
let configured_column = match ty {
@@ -819,7 +818,7 @@ where
819818

820819
for param in typed_statement.params.iter() {
821820
let configured_column = match param {
822-
eql_mapper::Value::Eql(EqlValue(TableColumn { table, column })) => {
821+
(_, eql_mapper::Value::Eql(EqlValue(TableColumn { table, column }))) => {
823822
let identifier = Identifier::from((table, column));
824823

825824
debug!(
@@ -967,9 +966,9 @@ fn literals_to_plaintext(
967966
Ok(plaintexts)
968967
}
969968

970-
fn to_json_literal_expr<T>(literal: &T) -> Result<Expr, Error>
969+
fn to_json_literal_expr<T>(literal: &T) -> Result<Value, Error>
971970
where
972971
T: ?Sized + Serialize,
973972
{
974-
Ok(serde_json::to_string(literal).map(|json| Expr::Value(Value::SingleQuotedString(json)))?)
973+
Ok(serde_json::to_string(literal).map(Value::SingleQuotedString)?)
975974
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
[package]
2+
name = "eql-mapper-macros"
3+
version.workspace = true
4+
edition.workspace = true
5+
6+
[lib]
7+
proc-macro = true
8+
9+
[dependencies]
10+
syn = { version = "2.0", features = ["full"] }
11+
quote = "1.0"
12+
proc-macro2 = "1.0"
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
use proc_macro::TokenStream;
2+
use quote::{quote, ToTokens};
3+
use syn::{
4+
parse::Parse, parse_macro_input, parse_quote, Attribute, FnArg, Ident, ImplItem, ImplItemFn,
5+
ItemImpl, Pat, PatType, Signature, Type, TypePath, TypeReference,
6+
};
7+
8+
/// This macro generates consistently defined `#[tracing::instrument]` attributes for `InferType::infer_enter` &
9+
/// `InferType::infer_enter` implementations on `TypeInferencer`.
10+
///
11+
/// This attribute MUST be defined on the trait `impl` itself (not the trait method impls).
12+
#[proc_macro_attribute]
13+
pub fn trace_infer(_attr: TokenStream, item: TokenStream) -> TokenStream {
14+
let mut input = parse_macro_input!(item as ItemImpl);
15+
16+
for item in &mut input.items {
17+
if let ImplItem::Fn(ImplItemFn {
18+
attrs,
19+
sig:
20+
Signature {
21+
ident: method,
22+
inputs,
23+
..
24+
},
25+
..
26+
}) = item
27+
{
28+
let node_ident_and_type: Option<(&Ident, &Type)> =
29+
if let Some(FnArg::Typed(PatType {
30+
ty: node_ty, pat, ..
31+
})) = inputs.get(1)
32+
{
33+
if let Pat::Ident(pat_ident) = &**pat {
34+
Some((&pat_ident.ident, node_ty))
35+
} else {
36+
None
37+
}
38+
} else {
39+
None
40+
};
41+
42+
let vec_ident: Ident = parse_quote!(Vec);
43+
44+
match node_ident_and_type {
45+
Some((node_ident, node_ty)) => {
46+
let (formatter, node_ty_abbrev) = match node_ty {
47+
Type::Reference(TypeReference { elem, .. }) => match &**elem {
48+
Type::Path(TypePath { path, .. }) => {
49+
let last_segment = path.segments.last().unwrap();
50+
let last_segment_ident = &last_segment.ident;
51+
let last_segment_arguments = if last_segment.arguments.is_empty() {
52+
None
53+
} else {
54+
let args = &last_segment.arguments;
55+
Some(quote!(<#args>))
56+
};
57+
match last_segment_ident {
58+
ident if vec_ident == *ident => {
59+
(quote!(crate::FmtAstVec), quote!(#last_segment_ident #last_segment_arguments))
60+
}
61+
_ => (quote!(crate::FmtAst), quote!(#last_segment_ident #last_segment_arguments))
62+
}
63+
},
64+
_ => unreachable!("Infer::infer_enter/infer_exit has sig: infer_..(&mut self, delete: &'ast N) -> Result<(), TypeError>")
65+
},
66+
_ => unreachable!("Infer::infer_enter/infer_exit has sig: infer_..(&mut self, delete: &'ast N) -> Result<(), TypeError>")
67+
};
68+
69+
let node_ty_abbrev = node_ty_abbrev
70+
.to_token_stream()
71+
.to_string()
72+
.replace(" ", "");
73+
74+
let target = format!("eql-mapper::{}", method.to_string().to_uppercase());
75+
76+
let attr: TracingInstrumentAttr = syn::parse2(quote! {
77+
#[tracing::instrument(
78+
target = #target,
79+
level = "trace",
80+
skip(self, #node_ident),
81+
fields(
82+
ast_ty = #node_ty_abbrev,
83+
ast = %#formatter(#node_ident),
84+
),
85+
ret(Debug)
86+
)]
87+
})
88+
.unwrap();
89+
attrs.push(attr.attr);
90+
}
91+
None => {
92+
return quote!(compile_error!(
93+
"could not determine name of node argumemt in Infer impl"
94+
))
95+
.to_token_stream()
96+
.into();
97+
}
98+
}
99+
}
100+
}
101+
102+
input.to_token_stream().into()
103+
}
104+
105+
struct TracingInstrumentAttr {
106+
attr: Attribute,
107+
}
108+
109+
impl Parse for TracingInstrumentAttr {
110+
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
111+
Ok(Self {
112+
attr: Attribute::parse_outer(input)?.first().unwrap().clone(),
113+
})
114+
}
115+
}

0 commit comments

Comments
 (0)