|
| 1 | +use gosyn::ast::FuncDecl; |
| 2 | + |
| 3 | +use crate::prelude::*; |
| 4 | + |
| 5 | +pub struct FuncVisitor { |
| 6 | + param_name: String, |
| 7 | + func_decl: gosyn::ast::FuncDecl, |
| 8 | +} |
| 9 | + |
| 10 | +impl FuncVisitor { |
| 11 | + pub fn verify_source_code(source: &str, benchmarks: &[String]) -> anyhow::Result<Vec<String>> { |
| 12 | + let file = gosyn::parse_source(source)?; |
| 13 | + Self::verify_file(&file, benchmarks) |
| 14 | + } |
| 15 | + |
| 16 | + pub fn verify_file( |
| 17 | + file: &gosyn::ast::File, |
| 18 | + benchmarks: &[String], |
| 19 | + ) -> anyhow::Result<Vec<String>> { |
| 20 | + let mut valid_benchmarks = Vec::new(); |
| 21 | + for decl in &file.decl { |
| 22 | + let gosyn::ast::Declaration::Function(func_decl) = decl else { |
| 23 | + continue; |
| 24 | + }; |
| 25 | + let func_name = &func_decl.name.name; |
| 26 | + if !benchmarks.contains(func_name) { |
| 27 | + continue; |
| 28 | + } |
| 29 | + |
| 30 | + if FuncVisitor::new(func_decl.clone())?.is_valid().is_err() { |
| 31 | + continue; |
| 32 | + }; |
| 33 | + |
| 34 | + valid_benchmarks.push(func_name.clone()); |
| 35 | + } |
| 36 | + |
| 37 | + Ok(valid_benchmarks) |
| 38 | + } |
| 39 | + |
| 40 | + pub fn new(func: FuncDecl) -> anyhow::Result<Self> { |
| 41 | + let param_name = Self::find_testing_b_param_name(&func)?; |
| 42 | + |
| 43 | + Ok(Self { |
| 44 | + param_name, |
| 45 | + func_decl: func, |
| 46 | + }) |
| 47 | + } |
| 48 | + |
| 49 | + fn find_testing_b_param_name(func_decl: &gosyn::ast::FuncDecl) -> Result<String> { |
| 50 | + let params = &func_decl.typ.params; |
| 51 | + for param in ¶ms.list { |
| 52 | + let type_expr = ¶m.typ; |
| 53 | + |
| 54 | + let gosyn::ast::Expression::TypePointer(pointer_type) = type_expr else { |
| 55 | + continue; |
| 56 | + }; |
| 57 | + let gosyn::ast::Expression::Selector(selector) = pointer_type.typ.as_ref() else { |
| 58 | + continue; |
| 59 | + }; |
| 60 | + |
| 61 | + let gosyn::ast::Expression::Ident(pkg) = selector.x.as_ref() else { |
| 62 | + continue; |
| 63 | + }; |
| 64 | + |
| 65 | + // We need a testing.B parameter |
| 66 | + if pkg.name != "testing" || selector.sel.name != "B" { |
| 67 | + continue; |
| 68 | + } |
| 69 | + |
| 70 | + if let Some(first_name) = param.name.first() { |
| 71 | + return Ok(first_name.name.clone()); |
| 72 | + } |
| 73 | + } |
| 74 | + |
| 75 | + bail!("Benchmark function does not have *testing.B parameter"); |
| 76 | + } |
| 77 | + |
| 78 | + pub fn is_valid(&self) -> anyhow::Result<()> { |
| 79 | + let Some(body) = &self.func_decl.body else { |
| 80 | + return Ok(()); |
| 81 | + }; |
| 82 | + |
| 83 | + self.is_valid_block(body) |
| 84 | + } |
| 85 | + |
| 86 | + fn is_valid_block(&self, block: &gosyn::ast::BlockStmt) -> anyhow::Result<()> { |
| 87 | + for stmt in &block.list { |
| 88 | + self.is_valid_stmt(stmt)?; |
| 89 | + } |
| 90 | + Ok(()) |
| 91 | + } |
| 92 | + |
| 93 | + fn is_valid_stmt(&self, stmt: &gosyn::ast::Statement) -> anyhow::Result<()> { |
| 94 | + match stmt { |
| 95 | + gosyn::ast::Statement::Expr(expr_stmt) => self.is_valid_expr(&expr_stmt.expr), |
| 96 | + gosyn::ast::Statement::Assign(assign_stmt) => { |
| 97 | + for expr in &assign_stmt.right { |
| 98 | + self.is_valid_expr(expr)?; |
| 99 | + } |
| 100 | + Ok(()) |
| 101 | + } |
| 102 | + gosyn::ast::Statement::If(if_stmt) => { |
| 103 | + self.is_valid_expr(&if_stmt.cond)?; |
| 104 | + self.is_valid_block(&if_stmt.body)?; |
| 105 | + if let Some(else_stmt) = &if_stmt.else_ { |
| 106 | + self.is_valid_stmt(else_stmt)?; |
| 107 | + } |
| 108 | + Ok(()) |
| 109 | + } |
| 110 | + gosyn::ast::Statement::For(for_stmt) => { |
| 111 | + if let Some(condition) = &for_stmt.cond { |
| 112 | + self.is_valid_stmt(condition)?; |
| 113 | + } |
| 114 | + if let Some(init) = &for_stmt.init { |
| 115 | + self.is_valid_stmt(init)?; |
| 116 | + } |
| 117 | + if let Some(post) = &for_stmt.post { |
| 118 | + self.is_valid_stmt(post)?; |
| 119 | + } |
| 120 | + self.is_valid_block(&for_stmt.body) |
| 121 | + } |
| 122 | + gosyn::ast::Statement::Block(block_stmt) => self.is_valid_block(block_stmt), |
| 123 | + _ => Ok(()), |
| 124 | + } |
| 125 | + } |
| 126 | + |
| 127 | + fn is_valid_expr(&self, expr: &gosyn::ast::Expression) -> anyhow::Result<()> { |
| 128 | + match expr { |
| 129 | + gosyn::ast::Expression::Call(call_expr) => { |
| 130 | + if let gosyn::ast::Expression::Selector(_) = call_expr.func.as_ref() { |
| 131 | + for arg in &call_expr.args { |
| 132 | + if self.uses_testing_ident(arg) { |
| 133 | + bail!( |
| 134 | + "testing.B parameter '{}' passed as argument to method call", |
| 135 | + self.param_name |
| 136 | + ); |
| 137 | + } |
| 138 | + } |
| 139 | + } else { |
| 140 | + for arg in &call_expr.args { |
| 141 | + if self.uses_testing_ident(arg) { |
| 142 | + bail!( |
| 143 | + "testing.B parameter '{}' passed as argument to function call", |
| 144 | + self.param_name |
| 145 | + ); |
| 146 | + } |
| 147 | + } |
| 148 | + } |
| 149 | + |
| 150 | + for arg in &call_expr.args { |
| 151 | + self.is_valid_expr(arg)?; |
| 152 | + } |
| 153 | + } |
| 154 | + gosyn::ast::Expression::Operation(operation) => { |
| 155 | + self.is_valid_expr(&operation.x)?; |
| 156 | + if let Some(y) = &operation.y { |
| 157 | + self.is_valid_expr(y)?; |
| 158 | + } |
| 159 | + } |
| 160 | + _ => {} |
| 161 | + } |
| 162 | + Ok(()) |
| 163 | + } |
| 164 | + |
| 165 | + fn uses_testing_ident(&self, expr: &gosyn::ast::Expression) -> bool { |
| 166 | + if let gosyn::ast::Expression::Ident(ident) = expr { |
| 167 | + ident.name == self.param_name |
| 168 | + } else { |
| 169 | + false |
| 170 | + } |
| 171 | + } |
| 172 | +} |
| 173 | + |
| 174 | +#[cfg(test)] |
| 175 | +mod tests { |
| 176 | + use super::*; |
| 177 | + |
| 178 | + #[test] |
| 179 | + fn test_valid_benchmark() { |
| 180 | + let source = include_str!("../../testdata/verifier/valid_benchmark.go"); |
| 181 | + let valid_benches = |
| 182 | + FuncVisitor::verify_source_code(source, &["BenchmarkValid".into()]).unwrap(); |
| 183 | + |
| 184 | + assert_eq!(valid_benches.len(), 1); |
| 185 | + assert!(valid_benches.contains(&"BenchmarkValid".to_string())); |
| 186 | + } |
| 187 | + |
| 188 | + #[test] |
| 189 | + fn test_invalid_benchmark_function_call() { |
| 190 | + let source = include_str!("../../testdata/verifier/invalid_benchmark_function_call.go"); |
| 191 | + let valid_benches = |
| 192 | + FuncVisitor::verify_source_code(source, &["BenchmarkInvalid".into()]).unwrap(); |
| 193 | + |
| 194 | + assert!(valid_benches.is_empty()); |
| 195 | + } |
| 196 | + |
| 197 | + #[test] |
| 198 | + fn test_valid_benchmark_method_calls() { |
| 199 | + let source = include_str!("../../testdata/verifier/valid_benchmark_methods.go"); |
| 200 | + let valid_benches = |
| 201 | + FuncVisitor::verify_source_code(source, &["BenchmarkValidMethods".into()]).unwrap(); |
| 202 | + |
| 203 | + assert_eq!(valid_benches.len(), 1); |
| 204 | + assert!(valid_benches.contains(&"BenchmarkValidMethods".to_string())); |
| 205 | + } |
| 206 | + |
| 207 | + #[test] |
| 208 | + fn test_multiple_benchmarks_mixed_validity() { |
| 209 | + let source = include_str!("../../testdata/verifier/mixed_validity_benchmarks.go"); |
| 210 | + let valid_benches = FuncVisitor::verify_source_code( |
| 211 | + source, |
| 212 | + &["BenchmarkValid".into(), "BenchmarkInvalid".into()], |
| 213 | + ) |
| 214 | + .unwrap(); |
| 215 | + |
| 216 | + assert_eq!(valid_benches.len(), 1); |
| 217 | + assert!(valid_benches.contains(&"BenchmarkValid".to_string())); |
| 218 | + } |
| 219 | +} |
0 commit comments