|
17 | 17 |
|
18 | 18 | //! Recursive visitors for ast Nodes. See [`Visitor`] for more details. |
19 | 19 |
|
20 | | -use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor}; |
| 20 | +use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor, Value}; |
21 | 21 | use core::ops::ControlFlow; |
22 | 22 |
|
23 | 23 | /// A type that can be visited by a [`Visitor`]. See [`Visitor`] for |
@@ -233,6 +233,16 @@ pub trait Visitor { |
233 | 233 | fn post_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> { |
234 | 234 | ControlFlow::Continue(()) |
235 | 235 | } |
| 236 | + |
| 237 | + /// Invoked for any Value that appear in the AST before visiting children |
| 238 | + fn pre_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> { |
| 239 | + ControlFlow::Continue(()) |
| 240 | + } |
| 241 | + |
| 242 | + /// Invoked for any Value that appear in the AST after visiting children |
| 243 | + fn post_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> { |
| 244 | + ControlFlow::Continue(()) |
| 245 | + } |
236 | 246 | } |
237 | 247 |
|
238 | 248 | /// A visitor that can be used to mutate an AST tree. |
@@ -337,6 +347,16 @@ pub trait VisitorMut { |
337 | 347 | fn post_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> { |
338 | 348 | ControlFlow::Continue(()) |
339 | 349 | } |
| 350 | + |
| 351 | + /// Invoked for any value that appear in the AST before visiting children |
| 352 | + fn pre_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> { |
| 353 | + ControlFlow::Continue(()) |
| 354 | + } |
| 355 | + |
| 356 | + /// Invoked for any statements that appear in the AST after visiting children |
| 357 | + fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> { |
| 358 | + ControlFlow::Continue(()) |
| 359 | + } |
340 | 360 | } |
341 | 361 |
|
342 | 362 | struct RelationVisitor<F>(F); |
@@ -647,6 +667,7 @@ where |
647 | 667 | #[cfg(test)] |
648 | 668 | mod tests { |
649 | 669 | use super::*; |
| 670 | + use crate::ast::Statement; |
650 | 671 | use crate::dialect::GenericDialect; |
651 | 672 | use crate::parser::Parser; |
652 | 673 | use crate::tokenizer::Tokenizer; |
@@ -720,17 +741,16 @@ mod tests { |
720 | 741 | } |
721 | 742 | } |
722 | 743 |
|
723 | | - fn do_visit(sql: &str) -> Vec<String> { |
| 744 | + fn do_visit<V: Visitor>(sql: &str, visitor: &mut V) -> Statement { |
724 | 745 | let dialect = GenericDialect {}; |
725 | 746 | let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap(); |
726 | 747 | let s = Parser::new(&dialect) |
727 | 748 | .with_tokens(tokens) |
728 | 749 | .parse_statement() |
729 | 750 | .unwrap(); |
730 | 751 |
|
731 | | - let mut visitor = TestVisitor::default(); |
732 | | - s.visit(&mut visitor); |
733 | | - visitor.visited |
| 752 | + s.visit(visitor); |
| 753 | + s |
734 | 754 | } |
735 | 755 |
|
736 | 756 | #[test] |
@@ -889,34 +909,76 @@ mod tests { |
889 | 909 | ), |
890 | 910 | ]; |
891 | 911 | for (sql, expected) in tests { |
892 | | - let actual = do_visit(sql); |
893 | | - let actual: Vec<_> = actual.iter().map(|x| x.as_str()).collect(); |
| 912 | + let mut visitor = TestVisitor::default(); |
| 913 | + let _ = do_visit(sql, &mut visitor); |
| 914 | + let actual: Vec<_> = visitor.visited.iter().map(|x| x.as_str()).collect(); |
894 | 915 | assert_eq!(actual, expected) |
895 | 916 | } |
896 | 917 | } |
| 918 | +} |
897 | 919 |
|
898 | | - struct QuickVisitor; // [`TestVisitor`] is too slow to iterate over thousands of nodes |
| 920 | +#[cfg(test)] |
| 921 | +mod visit_mut_tests { |
| 922 | + use core::ops::ControlFlow; |
| 923 | + use crate::dialect::GenericDialect; |
| 924 | + use crate::tokenizer::Tokenizer; |
| 925 | + use crate::parser::Parser; |
| 926 | + use crate::ast::{VisitorMut, VisitMut, Value, Statement}; |
899 | 927 |
|
900 | | - impl Visitor for QuickVisitor { |
901 | | - type Break = (); |
| 928 | + #[derive(Default)] |
| 929 | + struct MutatorVisitor { |
| 930 | + index: u64 |
902 | 931 | } |
903 | 932 |
|
904 | | - #[test] |
905 | | - fn overflow() { |
906 | | - let cond = (0..1000) |
907 | | - .map(|n| format!("X = {}", n)) |
908 | | - .collect::<Vec<_>>() |
909 | | - .join(" OR "); |
910 | | - let sql = format!("SELECT x where {0}", cond); |
| 933 | + impl VisitorMut for MutatorVisitor { |
| 934 | + type Break = (); |
| 935 | + |
| 936 | + fn pre_visit_value(&mut self, value: &mut Value) -> ControlFlow<Self::Break> { |
| 937 | + self.index += 1; |
| 938 | + *value = Value::SingleQuotedString(format!("REDACTED_{}", self.index)); |
| 939 | + ControlFlow::Continue(()) |
| 940 | + } |
911 | 941 |
|
| 942 | + fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> { |
| 943 | + ControlFlow::Continue(()) |
| 944 | + } |
| 945 | + } |
| 946 | + |
| 947 | + fn do_visit_mut<V: VisitorMut>(sql: &str, visitor: &mut V) -> Statement { |
912 | 948 | let dialect = GenericDialect {}; |
913 | | - let tokens = Tokenizer::new(&dialect, sql.as_str()).tokenize().unwrap(); |
914 | | - let s = Parser::new(&dialect) |
| 949 | + let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap(); |
| 950 | + let mut s = Parser::new(&dialect) |
915 | 951 | .with_tokens(tokens) |
916 | 952 | .parse_statement() |
917 | 953 | .unwrap(); |
918 | 954 |
|
919 | | - let mut visitor = QuickVisitor {}; |
920 | | - s.visit(&mut visitor); |
| 955 | + s.visit(visitor); |
| 956 | + s |
| 957 | + } |
| 958 | + |
| 959 | + #[test] |
| 960 | + fn test_value_redact() { |
| 961 | + let tests = vec![ |
| 962 | + ( |
| 963 | + concat!( |
| 964 | + "SELECT * FROM monthly_sales ", |
| 965 | + "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ", |
| 966 | + "ORDER BY EMPID" |
| 967 | + ), |
| 968 | + concat!( |
| 969 | + "SELECT * FROM monthly_sales ", |
| 970 | + "PIVOT(SUM(a.amount) FOR a.MONTH IN ('REDACTED_1', 'REDACTED_2', 'REDACTED_3', 'REDACTED_4')) AS p (c, d) ", |
| 971 | + "ORDER BY EMPID" |
| 972 | + ), |
| 973 | + |
| 974 | + ), |
| 975 | + |
| 976 | + ]; |
| 977 | + |
| 978 | + for (sql, expected) in tests { |
| 979 | + let mut visitor = MutatorVisitor::default(); |
| 980 | + let mutated = do_visit_mut(sql, &mut visitor); |
| 981 | + assert_eq!(mutated.to_string(), expected) |
| 982 | + } |
921 | 983 | } |
922 | 984 | } |
0 commit comments