Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/ast/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ use sqlparser_derive::{Visit, VisitMut};
/// Primitive SQL values such as number and string
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
#[cfg_attr(
feature = "visitor",
derive(Visit, VisitMut),
visit(with = "visit_value")
)]

pub enum Value {
/// Numeric literal
#[cfg(not(feature = "bigdecimal"))]
Expand Down
99 changes: 92 additions & 7 deletions src/ast/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

//! Recursive visitors for ast Nodes. See [`Visitor`] for more details.

use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor};
use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor, Value};
use core::ops::ControlFlow;

/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
Expand Down Expand Up @@ -233,6 +233,16 @@ pub trait Visitor {
fn post_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any Value that appear in the AST before visiting children
fn pre_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any Value that appear in the AST after visiting children
fn post_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
}

/// A visitor that can be used to mutate an AST tree.
Expand Down Expand Up @@ -337,6 +347,16 @@ pub trait VisitorMut {
fn post_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any value that appear in the AST before visiting children
fn pre_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any statements that appear in the AST after visiting children
fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
}

struct RelationVisitor<F>(F);
Expand Down Expand Up @@ -647,6 +667,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::ast::Statement;
use crate::dialect::GenericDialect;
use crate::parser::Parser;
use crate::tokenizer::Tokenizer;
Expand Down Expand Up @@ -720,17 +741,16 @@ mod tests {
}
}

fn do_visit(sql: &str) -> Vec<String> {
fn do_visit<V: Visitor>(sql: &str, visitor: &mut V) -> Statement {
let dialect = GenericDialect {};
let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
let s = Parser::new(&dialect)
.with_tokens(tokens)
.parse_statement()
.unwrap();

let mut visitor = TestVisitor::default();
s.visit(&mut visitor);
visitor.visited
s.visit(visitor);
s
}

#[test]
Expand Down Expand Up @@ -889,8 +909,9 @@ mod tests {
),
];
for (sql, expected) in tests {
let actual = do_visit(sql);
let actual: Vec<_> = actual.iter().map(|x| x.as_str()).collect();
let mut visitor = TestVisitor::default();
let _ = do_visit(sql, &mut visitor);
let actual: Vec<_> = visitor.visited.iter().map(|x| x.as_str()).collect();
assert_eq!(actual, expected)
}
}
Expand Down Expand Up @@ -920,3 +941,67 @@ mod tests {
s.visit(&mut visitor);
}
}

#[cfg(test)]
mod visit_mut_tests {
use crate::ast::{Statement, Value, VisitMut, VisitorMut};
use crate::dialect::GenericDialect;
use crate::parser::Parser;
use crate::tokenizer::Tokenizer;
use core::ops::ControlFlow;

#[derive(Default)]
struct MutatorVisitor {
index: u64,
}

impl VisitorMut for MutatorVisitor {
type Break = ();

fn pre_visit_value(&mut self, value: &mut Value) -> ControlFlow<Self::Break> {
self.index += 1;
*value = Value::SingleQuotedString(format!("REDACTED_{}", self.index));
ControlFlow::Continue(())
}

fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
}

fn do_visit_mut<V: VisitorMut>(sql: &str, visitor: &mut V) -> Statement {
let dialect = GenericDialect {};
let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
let mut s = Parser::new(&dialect)
.with_tokens(tokens)
.parse_statement()
.unwrap();

s.visit(visitor);
s
}

#[test]
fn test_value_redact() {
let tests = vec![
(
concat!(
"SELECT * FROM monthly_sales ",
"PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
"ORDER BY EMPID"
),
concat!(
"SELECT * FROM monthly_sales ",
"PIVOT(SUM(a.amount) FOR a.MONTH IN ('REDACTED_1', 'REDACTED_2', 'REDACTED_3', 'REDACTED_4')) AS p (c, d) ",
"ORDER BY EMPID"
),
),
];

for (sql, expected) in tests {
let mut visitor = MutatorVisitor::default();
let mutated = do_visit_mut(sql, &mut visitor);
assert_eq!(mutated.to_string(), expected)
}
}
}