45 changes: 34 additions & 11 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion Cargo.toml
@@ -35,7 +35,8 @@ strip = "none"
debug = true

[workspace.dependencies]
sqlparser = { version = "^0.52", features = ["bigdecimal", "serde"] }
sqltk = { version = "0.5.0" }
sqltk-parser = { version = "0.52.0" }
thiserror = "2.0.9"
tokio = { version = "1.44", features = ["full"] }
tracing = "0.1"
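The workspace swaps the upstream `sqlparser` crate for `sqltk` plus its parser fork `sqltk-parser`, pinned at the same 0.52 line. A minimal sketch of parsing a statement through the fork, assuming its entry points mirror `sqlparser` 0.52 (only the `Parser::parse_sql` function and `PostgreSqlDialect` seen in the imports later in this diff are relied on):

```rust
use sqltk_parser::ast::Statement;
use sqltk_parser::dialect::PostgreSqlDialect;
use sqltk_parser::parser::Parser;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Parse one statement with the PostgreSQL dialect, as the proxy
    // frontend does before handing it to eql-mapper for type checking.
    let sql = "SELECT email FROM users WHERE id = $1";
    let statements: Vec<Statement> = Parser::parse_sql(&PostgreSqlDialect {}, sql)?;

    assert_eq!(statements.len(), 1);
    println!("{}", statements[0]);
    Ok(())
}
```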
2 changes: 1 addition & 1 deletion mise.local.example.toml
@@ -42,6 +42,6 @@ CS_LOG__AUTHENTICATION_LEVEL = "info"
CS_LOG__CONTEXT_LEVEL = "info"
CS_LOG__KEYSET_LEVEL = "info"
CS_LOG__PROTOCOL_LEVEL = "info"
CS_LOG__MAPPER_LEVEL = "info"
CS_LOG__MAPPER_LEVEL = "debug"
CS_LOG__SCHEMA_LEVEL = "info"
CS_LOG__CONFIG_LEVEL = "info"
10 changes: 10 additions & 0 deletions mise.toml
@@ -108,6 +108,7 @@ run = "docker compose rm --stop --force proxy proxy-tls"
alias = ['t', 'ci']
description = "Runs all tests (hygiene, unit, integration)"
run = """
mise run rust:version
mise run test:check
mise run test:format
mise run test:clippy
@@ -369,6 +370,15 @@ mise --env tls run proxy:down
description = "Runs cargo nextest, skipping integration tests"
run = 'cargo nextest run --no-fail-fast --nocapture -E "not package(cipherstash-proxy-integration)" {{arg(name="test",default="")}}'

[tasks."rust:version"]
description = "Outputs rust toolchain version info"
run = """
echo "rustc --version = " $(rustc --version)
echo "cargo --version = " $(cargo --version)
echo "cargo fmt --version = " $(cargo fmt --version)
echo "cargo clippy --version = " $(cargo clippy --version)
"""

[tasks."test:format"]
description = "Runs cargo fmt"
run = 'cargo fmt --all -- --check'
3 changes: 2 additions & 1 deletion packages/cipherstash-proxy/Cargo.toml
@@ -40,7 +40,8 @@ rustls-pki-types = "1.10.0"
serde = "1.0"
serde_json = "1.0"
socket2 = "0.5.7"
sqlparser = { workspace = true }
sqltk = { workspace = true }
sqltk-parser = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tokio-postgres = { version = "0.7", features = [
2 changes: 1 addition & 1 deletion packages/cipherstash-proxy/src/encrypt/schema/manager.rs
@@ -4,7 +4,7 @@ use crate::error::Error;
use crate::{connect, log::SCHEMA};
use arc_swap::ArcSwap;
use eql_mapper::{Column, Schema, Table};
use sqlparser::ast::Ident;
use sqltk_parser::ast::Ident;
use std::sync::Arc;
use std::time::Duration;
use tokio::{task::JoinHandle, time};
2 changes: 1 addition & 1 deletion packages/cipherstash-proxy/src/eql/mod.rs
@@ -3,7 +3,7 @@ use cipherstash_client::{
zerokms::{encrypted_record, EncryptedRecord},
};
use serde::{Deserialize, Serialize};
use sqlparser::ast::Ident;
use sqltk_parser::ast::Ident;

#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct Plaintext {
8 changes: 6 additions & 2 deletions packages/cipherstash-proxy/src/error.rs
@@ -1,6 +1,7 @@
use crate::{postgresql::Column, Identifier};
use bytes::BytesMut;
use cipherstash_client::encryption;
use eql_mapper::EqlMapperError;
use metrics_exporter_prometheus::BuildError;
use std::{io, time::Duration};
use thiserror::Error;
@@ -92,6 +93,9 @@ pub enum MappingError {

#[error("Statement encountered an internal error. This may be a bug in the statement mapping module of CipherStash Proxy. Please visit {}#mapping-internal-error for more information.", ERROR_DOC_BASE_URL)]
Internal(String),

#[error(transparent)]
EqlMapper(#[from] EqlMapperError),
}

#[derive(Error, Debug)]
@@ -320,8 +324,8 @@ impl From<serde_json::Error> for Error {
}
}

impl From<sqlparser::parser::ParserError> for Error {
fn from(e: sqlparser::parser::ParserError) -> Self {
impl From<sqltk_parser::parser::ParserError> for Error {
fn from(e: sqltk_parser::parser::ParserError) -> Self {
Error::Mapping(MappingError::InvalidSqlStatement(e.to_string()))
}
}
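The new `MappingError::EqlMapper` variant leans on thiserror's `#[error(transparent)]` plus `#[from]`, so an `EqlMapperError` raised during type checking propagates through `?` and surfaces with its own message. A self-contained sketch of that pattern, with a stand-in error type since only the variant itself appears in this diff:

```rust
use thiserror::Error;

// Stand-in for eql_mapper::EqlMapperError; the real type lives in the
// eql-mapper crate and is not shown in this diff.
#[derive(Error, Debug)]
#[error("eql-mapper failed: {0}")]
struct EqlMapperError(String);

#[derive(Error, Debug)]
enum MappingError {
    // `transparent` forwards Display and source() to the wrapped error;
    // `#[from]` generates the From impl that the `?` operator uses.
    #[error(transparent)]
    EqlMapper(#[from] EqlMapperError),
}

fn type_check() -> Result<(), EqlMapperError> {
    Err(EqlMapperError("unsupported expression".into()))
}

fn map_statement() -> Result<(), MappingError> {
    type_check()?; // converts via the generated From impl
    Ok(())
}

fn main() {
    let err = map_statement().unwrap_err();
    println!("{err}"); // prints the inner error's message unchanged
}
```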
2 changes: 1 addition & 1 deletion packages/cipherstash-proxy/src/postgresql/data/from_sql.rs
@@ -11,7 +11,7 @@ use cipherstash_config::ColumnType;
use postgres_types::FromSql;
use postgres_types::Type;
use rust_decimal::Decimal;
use sqlparser::ast::Value;
use sqltk_parser::ast::Value;
use std::str::FromStr;
use tracing::debug;

46 changes: 24 additions & 22 deletions packages/cipherstash-proxy/src/postgresql/frontend.rs
@@ -26,13 +26,14 @@ use crate::prometheus::{
use crate::Encrypted;
use bytes::BytesMut;
use cipherstash_client::encryption::Plaintext;
use eql_mapper::{self, EqlMapperError, EqlValue, NodeKey, TableColumn, TypedStatement};
use eql_mapper::{self, EqlMapperError, EqlValue, TableColumn, TypeCheckedStatement};
use metrics::{counter, histogram};
use pg_escape::quote_literal;
use serde::Serialize;
use sqlparser::ast::{self, Expr, Value};
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;
use sqltk::NodeKey;
use sqltk_parser::ast::{self, Value};
use sqltk_parser::dialect::PostgreSqlDialect;
use sqltk_parser::parser::Parser;
use std::collections::HashMap;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@@ -285,7 +286,7 @@ where

match self.to_encryptable_statement(&typed_statement, vec![])? {
Some(statement) => {
if statement.has_literals() || typed_statement.has_nodes_to_wrap() {
if typed_statement.requires_transform() {
let encrypted_literals = self
.encrypt_literals(&typed_statement, &statement.literal_columns)
.await?;
@@ -356,7 +357,7 @@ where
///
async fn encrypt_literals(
&mut self,
typed_statement: &TypedStatement<'_>,
typed_statement: &TypeCheckedStatement<'_>,
literal_columns: &Vec<Option<Column>>,
) -> Result<Vec<Option<Encrypted>>, Error> {
let literal_values = typed_statement.literal_values();
@@ -402,14 +403,14 @@ where
///
async fn transform_statement(
&mut self,
typed_statement: &TypedStatement<'_>,
typed_statement: &TypeCheckedStatement<'_>,
encrypted_literals: &Vec<Option<Encrypted>>,
) -> Result<Option<ast::Statement>, Error> {
// Convert literals to ast Expr
let mut encrypted_expressions = vec![];
for encrypted in encrypted_literals {
let e = match encrypted {
Some(en) => Some(to_json_literal_expr(&en)?),
Some(en) => Some(to_json_literal_value(&en)?),
None => None,
};
encrypted_expressions.push(e);
@@ -426,11 +427,10 @@ where

debug!(target: MAPPER,
client_id = self.context.client_id,
nodes_to_wrap = typed_statement.nodes_to_wrap.len(),
literals = encrypted_nodes.len(),
);

if !typed_statement.has_nodes_to_wrap() && encrypted_nodes.is_empty() {
if !typed_statement.requires_transform() {
return Ok(None);
}

@@ -500,7 +500,7 @@ where

match self.to_encryptable_statement(&typed_statement, param_types)? {
Some(statement) => {
if statement.has_literals() || typed_statement.has_nodes_to_wrap() {
if typed_statement.requires_transform() {
let encrypted_literals = self
.encrypt_literals(&typed_statement, &statement.literal_columns)
.await?;
@@ -607,7 +607,7 @@ where
///
fn to_encryptable_statement(
&self,
typed_statement: &TypedStatement<'_>,
typed_statement: &TypeCheckedStatement<'_>,
param_types: Vec<i32>,
) -> Result<Option<Statement>, Error> {
let param_columns = self.get_param_columns(typed_statement)?;
@@ -619,8 +619,7 @@

if (param_columns.is_empty() || no_encrypted_param_columns)
&& (projection_columns.is_empty() || no_encrypted_projection_columns)
&& literal_columns.is_empty()
&& !typed_statement.has_nodes_to_wrap()
&& !typed_statement.requires_transform()
{
return Ok(None);
}
@@ -735,7 +734,10 @@ where
Ok(encrypted)
}

fn type_check<'a>(&self, statement: &'a ast::Statement) -> Result<TypedStatement<'a>, Error> {
fn type_check<'a>(
&self,
statement: &'a ast::Statement,
) -> Result<TypeCheckedStatement<'a>, Error> {
match eql_mapper::type_check(self.context.get_table_resolver(), statement) {
Ok(typed_statement) => {
debug!(target: MAPPER,
@@ -778,10 +780,10 @@ where
///
fn get_projection_columns(
&self,
typed_statement: &eql_mapper::TypedStatement<'_>,
typed_statement: &eql_mapper::TypeCheckedStatement<'_>,
) -> Result<Vec<Option<Column>>, Error> {
let mut projection_columns = vec![];
if let Some(eql_mapper::Projection::WithColumns(columns)) = &typed_statement.projection {
if let eql_mapper::Projection::WithColumns(columns) = &typed_statement.projection {
for col in columns {
let eql_mapper::ProjectionColumn { ty, .. } = col;
let configured_column = match ty {
@@ -813,13 +815,13 @@ where
///
fn get_param_columns(
&self,
typed_statement: &eql_mapper::TypedStatement<'_>,
typed_statement: &eql_mapper::TypeCheckedStatement<'_>,
) -> Result<Vec<Option<Column>>, Error> {
let mut param_columns = vec![];

for param in typed_statement.params.iter() {
let configured_column = match param {
eql_mapper::Value::Eql(EqlValue(TableColumn { table, column })) => {
(_, eql_mapper::Value::Eql(EqlValue(TableColumn { table, column }))) => {
let identifier = Identifier::from((table, column));

debug!(
@@ -841,7 +843,7 @@

fn get_literal_columns(
&self,
typed_statement: &eql_mapper::TypedStatement<'_>,
typed_statement: &eql_mapper::TypeCheckedStatement<'_>,
) -> Result<Vec<Option<Column>>, Error> {
let mut literal_columns = vec![];

@@ -967,9 +969,9 @@ fn literals_to_plaintext(
Ok(plaintexts)
}

fn to_json_literal_expr<T>(literal: &T) -> Result<Expr, Error>
fn to_json_literal_value<T>(literal: &T) -> Result<Value, Error>
where
T: ?Sized + Serialize,
{
Ok(serde_json::to_string(literal).map(|json| Expr::Value(Value::SingleQuotedString(json)))?)
Ok(serde_json::to_string(literal).map(Value::SingleQuotedString)?)
}
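With this change the helper returns a bare `ast::Value` instead of a ready-made `Expr`, leaving it to the statement transformation to decide where the JSON literal is placed. A simplified sketch of the new shape, using `serde_json::Error` and a plain JSON value in place of the proxy's own `Error` and `Encrypted` types:

```rust
use serde::Serialize;
use sqltk_parser::ast::{Expr, Value};

// Serialize an encrypted payload to JSON and carry it as a
// single-quoted SQL string value.
fn to_json_literal_value<T>(literal: &T) -> Result<Value, serde_json::Error>
where
    T: ?Sized + Serialize,
{
    serde_json::to_string(literal).map(Value::SingleQuotedString)
}

fn main() -> Result<(), serde_json::Error> {
    // Hypothetical payload standing in for an `Encrypted` value.
    let payload = serde_json::json!({ "c": "ciphertext", "i": "key-id" });

    let value = to_json_literal_value(&payload)?;

    // A call site that still needs an expression can wrap the value
    // exactly as the old helper did.
    let expr = Expr::Value(value);
    println!("{expr}");
    Ok(())
}
```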
12 changes: 12 additions & 0 deletions packages/eql-mapper-macros/Cargo.toml
@@ -0,0 +1,12 @@
[package]
name = "eql-mapper-macros"
version.workspace = true
edition.workspace = true

[lib]
proc-macro = true

[dependencies]
syn = { version = "2.0", features = ["full"] }
quote = "1.0"
proc-macro2 = "1.0"
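The new `eql-mapper-macros` crate is declared as a proc-macro library with `syn`, `quote`, and `proc-macro2`, but its source is not part of this diff. Purely as an illustration of what those dependencies support, here is a generic derive-macro skeleton; the macro name `Example` and the generated method are hypothetical:

```rust
// packages/eql-mapper-macros/src/lib.rs (hypothetical skeleton)
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput};

/// Placeholder derive macro: parses the annotated item with `syn`
/// and emits an inherent method with `quote`.
#[proc_macro_derive(Example)]
pub fn derive_example(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = &input.ident;

    let expanded = quote! {
        impl #name {
            pub fn example_name() -> &'static str {
                stringify!(#name)
            }
        }
    };

    expanded.into()
}
```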