Skip to content

Commit 770198c

Browse files
committed
feat: verify benchmarks with AST parser
1 parent 3c40324 commit 770198c

File tree

5 files changed

+276
-0
lines changed

5 files changed

+276
-0
lines changed

go-runner/src/builder/verifier.rs

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
use gosyn::ast::FuncDecl;
2+
3+
use crate::prelude::*;
4+
5+
pub struct FuncVisitor {
6+
param_name: String,
7+
func_decl: gosyn::ast::FuncDecl,
8+
}
9+
10+
impl FuncVisitor {
11+
pub fn verify_source_code(source: &str, benchmarks: &[String]) -> anyhow::Result<Vec<String>> {
12+
let file = gosyn::parse_source(source)?;
13+
Self::verify_file(&file, benchmarks)
14+
}
15+
16+
pub fn verify_file(
17+
file: &gosyn::ast::File,
18+
benchmarks: &[String],
19+
) -> anyhow::Result<Vec<String>> {
20+
let mut valid_benchmarks = Vec::new();
21+
for decl in &file.decl {
22+
let gosyn::ast::Declaration::Function(func_decl) = decl else {
23+
continue;
24+
};
25+
let func_name = &func_decl.name.name;
26+
if !benchmarks.contains(func_name) {
27+
continue;
28+
}
29+
30+
if FuncVisitor::new(func_decl.clone())?.is_valid().is_err() {
31+
continue;
32+
};
33+
34+
valid_benchmarks.push(func_name.clone());
35+
}
36+
37+
Ok(valid_benchmarks)
38+
}
39+
40+
pub fn new(func: FuncDecl) -> anyhow::Result<Self> {
41+
let param_name = Self::find_testing_b_param_name(&func)?;
42+
43+
Ok(Self {
44+
param_name,
45+
func_decl: func,
46+
})
47+
}
48+
49+
fn find_testing_b_param_name(func_decl: &gosyn::ast::FuncDecl) -> Result<String> {
50+
let params = &func_decl.typ.params;
51+
for param in &params.list {
52+
let type_expr = &param.typ;
53+
54+
let gosyn::ast::Expression::TypePointer(pointer_type) = type_expr else {
55+
continue;
56+
};
57+
let gosyn::ast::Expression::Selector(selector) = pointer_type.typ.as_ref() else {
58+
continue;
59+
};
60+
61+
let gosyn::ast::Expression::Ident(pkg) = selector.x.as_ref() else {
62+
continue;
63+
};
64+
65+
// We need a testing.B parameter
66+
if pkg.name != "testing" || selector.sel.name != "B" {
67+
continue;
68+
}
69+
70+
if let Some(first_name) = param.name.first() {
71+
return Ok(first_name.name.clone());
72+
}
73+
}
74+
75+
bail!("Benchmark function does not have *testing.B parameter");
76+
}
77+
78+
pub fn is_valid(&self) -> anyhow::Result<()> {
79+
let Some(body) = &self.func_decl.body else {
80+
return Ok(());
81+
};
82+
83+
self.is_valid_block(body)
84+
}
85+
86+
fn is_valid_block(&self, block: &gosyn::ast::BlockStmt) -> anyhow::Result<()> {
87+
for stmt in &block.list {
88+
self.is_valid_stmt(stmt)?;
89+
}
90+
Ok(())
91+
}
92+
93+
fn is_valid_stmt(&self, stmt: &gosyn::ast::Statement) -> anyhow::Result<()> {
94+
match stmt {
95+
gosyn::ast::Statement::Expr(expr_stmt) => self.is_valid_expr(&expr_stmt.expr),
96+
gosyn::ast::Statement::Assign(assign_stmt) => {
97+
for expr in &assign_stmt.right {
98+
self.is_valid_expr(expr)?;
99+
}
100+
Ok(())
101+
}
102+
gosyn::ast::Statement::If(if_stmt) => {
103+
self.is_valid_expr(&if_stmt.cond)?;
104+
self.is_valid_block(&if_stmt.body)?;
105+
if let Some(else_stmt) = &if_stmt.else_ {
106+
self.is_valid_stmt(else_stmt)?;
107+
}
108+
Ok(())
109+
}
110+
gosyn::ast::Statement::For(for_stmt) => {
111+
if let Some(condition) = &for_stmt.cond {
112+
self.is_valid_stmt(condition)?;
113+
}
114+
if let Some(init) = &for_stmt.init {
115+
self.is_valid_stmt(init)?;
116+
}
117+
if let Some(post) = &for_stmt.post {
118+
self.is_valid_stmt(post)?;
119+
}
120+
self.is_valid_block(&for_stmt.body)
121+
}
122+
gosyn::ast::Statement::Block(block_stmt) => self.is_valid_block(block_stmt),
123+
_ => Ok(()),
124+
}
125+
}
126+
127+
fn is_valid_expr(&self, expr: &gosyn::ast::Expression) -> anyhow::Result<()> {
128+
match expr {
129+
gosyn::ast::Expression::Call(call_expr) => {
130+
if let gosyn::ast::Expression::Selector(_) = call_expr.func.as_ref() {
131+
for arg in &call_expr.args {
132+
if self.uses_testing_ident(arg) {
133+
bail!(
134+
"testing.B parameter '{}' passed as argument to method call",
135+
self.param_name
136+
);
137+
}
138+
}
139+
} else {
140+
for arg in &call_expr.args {
141+
if self.uses_testing_ident(arg) {
142+
bail!(
143+
"testing.B parameter '{}' passed as argument to function call",
144+
self.param_name
145+
);
146+
}
147+
}
148+
}
149+
150+
for arg in &call_expr.args {
151+
self.is_valid_expr(arg)?;
152+
}
153+
}
154+
gosyn::ast::Expression::Operation(operation) => {
155+
self.is_valid_expr(&operation.x)?;
156+
if let Some(y) = &operation.y {
157+
self.is_valid_expr(y)?;
158+
}
159+
}
160+
_ => {}
161+
}
162+
Ok(())
163+
}
164+
165+
fn uses_testing_ident(&self, expr: &gosyn::ast::Expression) -> bool {
166+
if let gosyn::ast::Expression::Ident(ident) = expr {
167+
ident.name == self.param_name
168+
} else {
169+
false
170+
}
171+
}
172+
}
173+
174+
#[cfg(test)]
175+
mod tests {
176+
use super::*;
177+
178+
#[test]
179+
fn test_valid_benchmark() {
180+
let source = include_str!("../../testdata/verifier/valid_benchmark.go");
181+
let valid_benches =
182+
FuncVisitor::verify_source_code(source, &["BenchmarkValid".into()]).unwrap();
183+
184+
assert_eq!(valid_benches.len(), 1);
185+
assert!(valid_benches.contains(&"BenchmarkValid".to_string()));
186+
}
187+
188+
#[test]
189+
fn test_invalid_benchmark_function_call() {
190+
let source = include_str!("../../testdata/verifier/invalid_benchmark_function_call.go");
191+
let valid_benches =
192+
FuncVisitor::verify_source_code(source, &["BenchmarkInvalid".into()]).unwrap();
193+
194+
assert!(valid_benches.is_empty());
195+
}
196+
197+
#[test]
198+
fn test_valid_benchmark_method_calls() {
199+
let source = include_str!("../../testdata/verifier/valid_benchmark_methods.go");
200+
let valid_benches =
201+
FuncVisitor::verify_source_code(source, &["BenchmarkValidMethods".into()]).unwrap();
202+
203+
assert_eq!(valid_benches.len(), 1);
204+
assert!(valid_benches.contains(&"BenchmarkValidMethods".to_string()));
205+
}
206+
207+
#[test]
208+
fn test_multiple_benchmarks_mixed_validity() {
209+
let source = include_str!("../../testdata/verifier/mixed_validity_benchmarks.go");
210+
let valid_benches = FuncVisitor::verify_source_code(
211+
source,
212+
&["BenchmarkValid".into(), "BenchmarkInvalid".into()],
213+
)
214+
.unwrap();
215+
216+
assert_eq!(valid_benches.len(), 1);
217+
assert!(valid_benches.contains(&"BenchmarkValid".to_string()));
218+
}
219+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package main
2+
3+
import "testing"
4+
5+
func helper(b *testing.B) {
6+
// This function receives testing.B as parameter
7+
}
8+
9+
func BenchmarkInvalid(b *testing.B) {
10+
helper(b) // This is invalid - passing testing.B to another function
11+
for i := 0; i < b.N; i++ {
12+
// Some work
13+
}
14+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package main
2+
3+
import "testing"
4+
5+
func helper(b *testing.B) {
6+
// Helper function
7+
}
8+
9+
func BenchmarkValid(b *testing.B) {
10+
for i := 0; i < b.N; i++ {
11+
// Some work
12+
}
13+
}
14+
15+
func BenchmarkInvalid(b *testing.B) {
16+
helper(b) // Invalid
17+
for i := 0; i < b.N; i++ {
18+
// Some work
19+
}
20+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package main
2+
3+
import "testing"
4+
5+
func BenchmarkValid(b *testing.B) {
6+
for i := 0; i < b.N; i++ {
7+
// Some work
8+
}
9+
b.ReportAllocs()
10+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package main
2+
3+
import "testing"
4+
5+
func BenchmarkValidMethods(b *testing.B) {
6+
b.ResetTimer()
7+
for i := 0; i < b.N; i++ {
8+
// Some work
9+
}
10+
b.StopTimer()
11+
b.ReportAllocs()
12+
b.StartTimer()
13+
}

0 commit comments

Comments
 (0)