Skip to content

Commit 5a46c38

Browse files
committed
feat: TransformationRule dry-run support
The Proxy must be able to reliably detect whether a statement requires transformation because if a statement does not require transformation then the potentially expensive AST rebuilding step can be skipped. The result of type-checking is insufficient in general to tell whether a statement requires transformation unless the `TransformationRule` logic is duplicated in the Proxy - which we don't want of course. This commit extends the `TransformationRule` trait with a `would_edit` method which answers the question "would this rule change the AST if it was applied?". Additionally, a new `TranformationRule` impl `DryRunnable` wraps another rule in such a way that it can "pretend" to be performing a `Transform` (as far as `sqltk` is concerned) when really its doing a dry-run after which it will tell us if it *would* change the AST. This is all wrapped up by `TypedStatement::transform` which now returns `Result<Cow<Statement>, _>`. When the statement does not need to be modified it returns `Cow::Borrowed(unmodified_statement)` and when the statement *has* been modified it returns a `Cow::Owned(modified_statement)` The Proxy `frontend` can now distinguish between the two cases via the `Cow`, and manage its passthrough counter correctly without having to be coupled to implementation details of the EQL Mapper.
1 parent dc01dea commit 5a46c38

13 files changed

+494
-175
lines changed

mise.local.example.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,6 @@ CS_LOG__AUTHENTICATION_LEVEL = "info"
4242
CS_LOG__CONTEXT_LEVEL = "info"
4343
CS_LOG__KEYSET_LEVEL = "info"
4444
CS_LOG__PROTOCOL_LEVEL = "info"
45-
CS_LOG__MAPPER_LEVEL = "info"
45+
CS_LOG__MAPPER_LEVEL = "debug"
4646
CS_LOG__SCHEMA_LEVEL = "info"
4747
CS_LOG__CONFIG_LEVEL = "info"

packages/cipherstash-proxy/src/config/tandem.rs

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -426,36 +426,40 @@ mod tests {
426426

427427
#[test]
428428
fn prometheus_config() {
429-
let config = TandemConfig::build("tests/config/cipherstash-proxy-test.toml").unwrap();
430-
assert!(!config.prometheus_enabled());
431-
432-
temp_env::with_vars([("CS_PROMETHEUS__ENABLED", Some("true"))], || {
433-
let config = TandemConfig::build("tests/config/cipherstash-proxy-test.toml").unwrap();
434-
assert!(config.prometheus_enabled());
435-
assert!(config.prometheus.enabled);
436-
assert_eq!(config.prometheus.port, 9930);
437-
});
438-
439-
temp_env::with_vars([("CS_PROMETHEUS__PORT", Some("7777"))], || {
429+
temp_env::with_vars_unset(["CS_PROMETHEUS__ENABLED"], || {
440430
let config = TandemConfig::build("tests/config/cipherstash-proxy-test.toml").unwrap();
441431
assert!(!config.prometheus_enabled());
442-
assert!(!config.prometheus.enabled);
443-
assert_eq!(config.prometheus.port, 7777);
444-
});
445432

446-
temp_env::with_vars(
447-
[
448-
("CS_PROMETHEUS__ENABLED", Some("true")),
449-
("CS_PROMETHEUS__PORT", Some("7777")),
450-
],
451-
|| {
433+
temp_env::with_vars([("CS_PROMETHEUS__ENABLED", Some("true"))], || {
452434
let config =
453435
TandemConfig::build("tests/config/cipherstash-proxy-test.toml").unwrap();
454436
assert!(config.prometheus_enabled());
455437
assert!(config.prometheus.enabled);
438+
assert_eq!(config.prometheus.port, 9930);
439+
});
440+
441+
temp_env::with_vars([("CS_PROMETHEUS__PORT", Some("7777"))], || {
442+
let config =
443+
TandemConfig::build("tests/config/cipherstash-proxy-test.toml").unwrap();
444+
assert!(!config.prometheus_enabled());
445+
assert!(!config.prometheus.enabled);
456446
assert_eq!(config.prometheus.port, 7777);
457-
},
458-
);
447+
});
448+
449+
temp_env::with_vars(
450+
[
451+
("CS_PROMETHEUS__ENABLED", Some("true")),
452+
("CS_PROMETHEUS__PORT", Some("7777")),
453+
],
454+
|| {
455+
let config =
456+
TandemConfig::build("tests/config/cipherstash-proxy-test.toml").unwrap();
457+
assert!(config.prometheus_enabled());
458+
assert!(config.prometheus.enabled);
459+
assert_eq!(config.prometheus.port, 7777);
460+
},
461+
);
462+
});
459463
}
460464

461465
#[test]

packages/cipherstash-proxy/src/postgresql/frontend.rs

Lines changed: 61 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ use sqltk::AsNodeKey;
3434
use sqltk_parser::ast::{self, Value};
3535
use sqltk_parser::dialect::PostgreSqlDialect;
3636
use sqltk_parser::parser::Parser;
37+
use std::borrow::Cow;
3738
use std::collections::HashMap;
3839
use std::time::Instant;
3940
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@@ -255,7 +256,7 @@ where
255256

256257
// Simple Query may contain many statements
257258
let parsed_statements = self.parse_statements(&query.statement)?;
258-
let mut transformed_statements = vec![];
259+
let mut transformed_statements: Vec<Cow<'_, ast::Statement>> = vec![];
259260

260261
debug!(target: MAPPER,
261262
client_id = self.context.client_id,
@@ -265,15 +266,15 @@ where
265266
let mut portal = Portal::passthrough();
266267
let mut encrypted = false;
267268

268-
for statement in parsed_statements {
269-
self.check_for_schema_change(&statement);
269+
for statement in parsed_statements.iter() {
270+
self.check_for_schema_change(statement);
270271

271-
if !eql_mapper::requires_type_check(&statement) {
272+
if !eql_mapper::requires_type_check(statement) {
272273
counter!(STATEMENTS_PASSTHROUGH_TOTAL).increment(1);
273274
continue;
274275
}
275276

276-
let typed_statement = match self.type_check(&statement) {
277+
let typed_statement = match self.type_check(statement) {
277278
Ok(ts) => ts,
278279
Err(err) => {
279280
if self.encrypt.config.mapping_errors_enabled() {
@@ -284,25 +285,25 @@ where
284285
}
285286
};
286287

287-
match self.to_encryptable_statement(&typed_statement, vec![])? {
288-
Some(statement) => {
289-
let encrypted_literals = self
290-
.encrypt_literals(&typed_statement, &statement.literal_columns)
291-
.await?;
292-
293-
if let Some(transformed_statement) = self
294-
.transform_statement(&typed_statement, &encrypted_literals)
295-
.await?
296-
{
297-
debug!(target: MAPPER,
298-
client_id = self.context.client_id,
299-
transformed_statement = ?transformed_statement,
300-
transformed_statement_text = %transformed_statement,
301-
);
288+
let statement = self.to_encryptable_statement(&typed_statement, vec![])?;
289+
let encrypted_literals = self
290+
.encrypt_literals(&typed_statement, &statement.literal_columns)
291+
.await?;
302292

303-
transformed_statements.push(transformed_statement);
304-
encrypted = true;
305-
}
293+
let transformed_statement = self
294+
.transform_statement(&typed_statement, &encrypted_literals)
295+
.await?;
296+
297+
debug!(target: MAPPER,
298+
client_id = self.context.client_id,
299+
transformed_statement = ?transformed_statement,
300+
transformed_statement_text = %transformed_statement,
301+
transform_was_noop = matches!(transformed_statement, Cow::Borrowed(_)),
302+
);
303+
304+
match &transformed_statement {
305+
Cow::Owned(_) => {
306+
encrypted = true;
306307
debug!(target: MAPPER,
307308
client_id = self.context.client_id,
308309
msg = "Encrypted Statement"
@@ -312,15 +313,16 @@ where
312313
// Set Encrypted portal
313314
portal = Portal::encrypted_text();
314315
}
315-
None => {
316+
Cow::Borrowed(_) => {
316317
debug!(target: MAPPER,
317318
client_id = self.context.client_id,
318319
msg = "Passthrough Statement"
319320
);
320321
counter!(STATEMENTS_PASSTHROUGH_TOTAL).increment(1);
321-
transformed_statements.push(statement);
322322
}
323-
};
323+
}
324+
325+
transformed_statements.push(transformed_statement);
324326
}
325327

326328
self.context.add_portal(Name::unnamed(), portal);
@@ -400,11 +402,11 @@ where
400402
/// - rewrites any encrypted literal values
401403
/// - wraps any nodes in appropriate EQL function
402404
///
403-
async fn transform_statement(
405+
async fn transform_statement<'ast>(
404406
&mut self,
405-
typed_statement: &TypedStatement<'_>,
407+
typed_statement: &TypedStatement<'ast>,
406408
encrypted_literals: &Vec<Option<Encrypted>>,
407-
) -> Result<Option<ast::Statement>, Error> {
409+
) -> Result<Cow<'ast, ast::Statement>, Error> {
408410
// Convert literals to ast Expr
409411
let mut encrypted_expressions = vec![];
410412
for encrypted in encrypted_literals {
@@ -429,11 +431,9 @@ where
429431
literals = encrypted_nodes.len(),
430432
);
431433

432-
let transformed_statement = typed_statement
434+
Ok(typed_statement
433435
.transform(encrypted_nodes)
434-
.map_err(|e| MappingError::StatementCouldNotBeTransformed(e.to_string()))?;
435-
436-
Ok(Some(transformed_statement))
436+
.map_err(|e| MappingError::StatementCouldNotBeTransformed(e.to_string()))?)
437437
}
438438

439439
///
@@ -493,39 +493,37 @@ where
493493
// These override the underlying column type
494494
let param_types = message.param_types.clone();
495495

496-
match self.to_encryptable_statement(&typed_statement, param_types)? {
497-
Some(statement) => {
498-
let encrypted_literals = self
499-
.encrypt_literals(&typed_statement, &statement.literal_columns)
500-
.await?;
496+
let statement = self.to_encryptable_statement(&typed_statement, param_types)?;
497+
let encrypted_literals = self
498+
.encrypt_literals(&typed_statement, &statement.literal_columns)
499+
.await?;
501500

502-
if let Some(transformed_statement) = self
503-
.transform_statement(&typed_statement, &encrypted_literals)
504-
.await?
505-
{
506-
debug!(target: MAPPER,
507-
client_id = self.context.client_id,
508-
transformed_statement = ?transformed_statement,
509-
transformed_statement_text = %transformed_statement,
510-
);
501+
let transformed_statement = self
502+
.transform_statement(&typed_statement, &encrypted_literals)
503+
.await?;
511504

512-
message.rewrite_statement(transformed_statement.to_string());
513-
}
505+
debug!(target: MAPPER,
506+
client_id = self.context.client_id,
507+
transformed_statement = ?transformed_statement,
508+
transformed_statement_text = %transformed_statement,
509+
transform_was_noop = matches!(transformed_statement, Cow::Borrowed(_)),
510+
);
514511

515-
counter!(STATEMENTS_ENCRYPTED_TOTAL).increment(1);
512+
if let Cow::Owned(transformed_statement) = transformed_statement {
513+
message.rewrite_statement(transformed_statement.to_string());
514+
counter!(STATEMENTS_ENCRYPTED_TOTAL).increment(1);
516515

517-
message.rewrite_param_types(&statement.param_columns);
518-
self.context
519-
.add_statement(message.name.to_owned(), statement);
520-
}
521-
_ => {
522-
debug!(target: MAPPER,
523-
client_id = self.context.client_id,
524-
msg = "Passthrough Parse"
525-
);
526-
counter!(STATEMENTS_PASSTHROUGH_TOTAL).increment(1);
527-
}
516+
message.rewrite_param_types(&statement.param_columns);
517+
self.context
518+
.add_statement(message.name.to_owned(), statement);
519+
} else {
520+
debug!(target: MAPPER,
521+
client_id = self.context.client_id,
522+
msg = "Passthrough Parse"
523+
);
524+
counter!(STATEMENTS_PASSTHROUGH_TOTAL).increment(1);
528525
}
526+
529527
let bytes = BytesMut::try_from(message)?;
530528

531529
debug!(target: MAPPER,
@@ -598,7 +596,7 @@ where
598596
&self,
599597
typed_statement: &TypedStatement<'_>,
600598
param_types: Vec<i32>,
601-
) -> Result<Option<Statement>, Error> {
599+
) -> Result<Statement, Error> {
602600
let param_columns = self.get_param_columns(typed_statement)?;
603601
let projection_columns = self.get_projection_columns(typed_statement)?;
604602
let literal_columns = self.get_literal_columns(typed_statement)?;
@@ -618,7 +616,7 @@ where
618616
param_types,
619617
);
620618

621-
Ok(Some(statement))
619+
Ok(statement)
622620
}
623621

624622
///

packages/eql-mapper/src/encrypted_statement.rs

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,25 @@ use std::{collections::HashMap, sync::Arc};
33
use sqltk::{NodeKey, NodePath, Transform, Visitable};
44

55
use crate::{
6-
EqlMapperError, FailOnPlaceholderChange, GroupByEqlCol, PreserveEffectiveAliases,
7-
ReplacePlaintextEqlLiterals, TransformationRule, Type, UseEquivalentSqlFuncForEqlTypes,
8-
WrapEqlColsInOrderByWithOreFn, WrapGroupedEqlColInAggregateFn,
6+
DryRunnable, EqlMapperError, FailOnPlaceholderChange, GroupByEqlCol, Mode,
7+
PreserveEffectiveAliases, ReplacePlaintextEqlLiterals, TransformationRule, Type,
8+
UseEquivalentSqlFuncForEqlTypes, WrapEqlColsInOrderByWithOreFn, WrapGroupedEqlColInAggregateFn,
99
};
1010

1111
#[derive(Debug)]
1212
pub(crate) struct EncryptedStatement<'ast> {
13-
transformation_rules: (
14-
WrapGroupedEqlColInAggregateFn<'ast>,
15-
GroupByEqlCol<'ast>,
16-
WrapEqlColsInOrderByWithOreFn<'ast>,
17-
PreserveEffectiveAliases,
18-
ReplacePlaintextEqlLiterals<'ast>,
19-
UseEquivalentSqlFuncForEqlTypes<'ast>,
20-
FailOnPlaceholderChange,
21-
),
13+
transformation_rules: DryRunnable<
14+
'ast,
15+
(
16+
WrapGroupedEqlColInAggregateFn<'ast>,
17+
GroupByEqlCol<'ast>,
18+
WrapEqlColsInOrderByWithOreFn<'ast>,
19+
PreserveEffectiveAliases,
20+
ReplacePlaintextEqlLiterals<'ast>,
21+
UseEquivalentSqlFuncForEqlTypes<'ast>,
22+
FailOnPlaceholderChange<'ast>,
23+
),
24+
>,
2225
}
2326

2427
impl<'ast> EncryptedStatement<'ast> {
@@ -27,17 +30,25 @@ impl<'ast> EncryptedStatement<'ast> {
2730
node_types: Arc<HashMap<NodeKey<'ast>, Type>>,
2831
) -> Self {
2932
Self {
30-
transformation_rules: (
33+
transformation_rules: DryRunnable::new((
3134
WrapGroupedEqlColInAggregateFn::new(Arc::clone(&node_types)),
3235
GroupByEqlCol::new(Arc::clone(&node_types)),
3336
WrapEqlColsInOrderByWithOreFn::new(Arc::clone(&node_types)),
3437
PreserveEffectiveAliases,
3538
ReplacePlaintextEqlLiterals::new(encrypted_literals),
3639
UseEquivalentSqlFuncForEqlTypes::new(Arc::clone(&node_types)),
37-
FailOnPlaceholderChange,
38-
),
40+
FailOnPlaceholderChange::new(),
41+
)),
3942
}
4043
}
44+
45+
pub(crate) fn set_mode(&mut self, new_mode: Mode) {
46+
self.transformation_rules.set_mode(new_mode);
47+
}
48+
49+
pub(crate) fn did_edit(&self) -> bool {
50+
self.transformation_rules.did_edit()
51+
}
4152
}
4253

4354
/// Applies all of the transormation rules from the `EncryptedStatement`.

0 commit comments

Comments
 (0)