Skip to content

Commit bb9408b

Browse files
committed
Add frontend logic for rust projects
Signed-off-by: Arthur Chan <[email protected]>
1 parent fca0a8c commit bb9408b

File tree

5 files changed

+376
-0
lines changed

5 files changed

+376
-0
lines changed

frontends/rust/process.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# Copyright 2024 Fuzz Introspector Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import subprocess
16+
import json
17+
import yaml
18+
import os
19+
import sys
20+
import re
21+
from pathlib import Path
22+
from typing import Dict, List, Tuple
23+
24+
def get_rs_files(project_dir: str) -> List[str]:
    """Collect the path of every ``.rs`` file found below ``project_dir``.

    Args:
        project_dir: Root directory that is walked recursively.

    Returns:
        One entry per file whose name ends with ``.rs``; each entry is the
        directory reported by ``os.walk`` joined with the file name.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _, filenames in os.walk(project_dir)
        for filename in filenames
        if filename.endswith(".rs")
    ]
32+
33+
def extract_function_line_info_from_file(file_path: str) -> Dict[Tuple[str, str], Tuple[int, int]]:
    """Map each function declared in a Rust file to an approximate line span.

    The file is scanned line by line for ``fn <name>`` declarations; it is not
    parsed. A function starts on the line of its declaration and is taken to
    end on the line before the next declaration (or on the last line of the
    file for the final function), so blank lines, comments and attributes
    between two functions are attributed to the preceding one. Declarations
    that appear inside comments or string literals are matched as well.

    Args:
        file_path: Path of the Rust source file to scan.

    Returns:
        A dict keyed by ``(function_name, file_path)`` whose values are
        1-based, inclusive ``(start_line, end_line)`` tuples. When a name is
        declared more than once in the file, the last declaration wins.
    """
    # ``\b`` keeps identifiers that merely end in "fn" (e.g. ``defn x(``) from
    # matching; accepting "<" after the name also catches generic functions
    # such as ``fn parse<T>(input: T)``.
    pattern = re.compile(r"\bfn\s+(\w+)\s*[<(]")
    info: Dict[Tuple[str, str], Tuple[int, int]] = {}

    # Rust sources are UTF-8; do not depend on the locale's default encoding.
    # Undecodable bytes are replaced so a stray byte cannot abort the scan.
    with open(file_path, "r", encoding="utf-8", errors="replace") as f:
        lines = f.readlines()

    curr_func = None
    start = 0

    for line_no, line in enumerate(lines, 1):
        match = pattern.search(line)
        if not match:
            continue

        # A new declaration closes the span of the previous function.
        if curr_func:
            info[(curr_func, file_path)] = (start, line_no - 1)

        curr_func = match.group(1)
        start = line_no

    # The last function runs to the end of the file.
    if curr_func:
        info[(curr_func, file_path)] = (start, len(lines))

    return info
57+
58+
def analyze_project_functions(project_dir: str) -> Dict[Tuple[str, str], Tuple[int, int]]:
    """Build the line-span table for every Rust file in a project.

    Args:
        project_dir: Root directory whose ``.rs`` files are scanned.

    Returns:
        The per-file results of ``extract_function_line_info_from_file``
        merged into one dict keyed by ``(function_name, file_path)``.
    """
    merged = {}
    for source_path in get_rs_files(project_dir):
        merged.update(extract_function_line_info_from_file(source_path))
    return merged
68+
69+
def run_rust_analysis(target_directory: str) -> List[Dict]:
    """Run the Rust analyser on a project and return its decoded JSON output.

    ``cargo run -- <target_directory>`` is executed with the working directory
    set to ``rust_function_analyser`` (looked up relative to the current
    working directory), and the text the tool prints on stdout is decoded as
    JSON.

    Args:
        target_directory: Project directory handed to the analyser. Because
            the tool runs with a different working directory, a relative path
            is resolved against ``rust_function_analyser`` rather than the
            caller's working directory.

    Returns:
        The decoded JSON document, or an empty list when the tool cannot be
        launched, exits with a non-zero status, or prints something that is
        not valid JSON. The reason for a failure is reported on stderr.
    """
    try:
        result = subprocess.run(
            ["cargo", "run", "--", target_directory],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            cwd="rust_function_analyser",
            text=True,
            check=True
        )
    except OSError as err:
        # Raised when cargo is not installed or the analyser directory does
        # not exist; keep the best-effort contract instead of crashing.
        print(f"Failed to launch the Rust analyser: {err}", file=sys.stderr)
        return []
    except subprocess.CalledProcessError as err:
        print(
            f"Rust analyser exited with status {err.returncode}:\n{err.stderr}",
            file=sys.stderr,
        )
        return []

    try:
        return json.loads(result.stdout)
    except ValueError as err:
        print(f"Rust analyser produced invalid JSON: {err}", file=sys.stderr)
        return []
85+
86+
def add_line_data(rust_results: List[Dict], project_dir: str = "") -> None:
    """Attach line spans to the analyser's function records, in place.

    Every record gets ``start_line`` and ``end_line`` from the regex-based
    scan of the project (``0, 0`` when the ``(name, file)`` pair was not
    found), and the spaces are stripped from each entry of
    ``called_functions``.

    Args:
        rust_results: Function records produced by the Rust analyser; each
            must provide ``name``, ``file`` and ``called_functions``.
        project_dir: Directory to scan for line information. When left empty
            the module-level ``target_dir`` set by the script entry point is
            used, which raises ``NameError`` if the module was imported
            rather than run as a script.
    """
    # Fall back to the script-level global only when the caller gave no
    # directory, so the function is also usable from importing code.
    if not project_dir:
        project_dir = target_dir

    line_info = analyze_project_functions(project_dir)

    for func in rust_results:
        func["start_line"], func["end_line"] = line_info.get(
            (func["name"], func["file"]), (0, 0)
        )
        func["called_functions"] = [
            callee.replace(" ", "") for callee in func["called_functions"]
        ]
97+
98+
def create_yaml_output(data: List[Dict], output_file="data.yaml"):
    """Write the analysis results to a YAML file.

    Args:
        data: Function records; each must provide ``name``, ``file``,
            ``start_line``, ``end_line``, ``depth``, ``return_type``,
            ``arg_count``, ``arg_types``, ``complexity`` and
            ``called_functions``.
        output_file: Path of the YAML file to create or overwrite.
    """
    # One element per function; fields the analyser does not provide are
    # emitted with fixed placeholder values.
    elements = [
        {
            "functionName": entry["name"],
            "functionSourceFile": entry["file"],
            "linkageType": "",
            "functionLinenumber": entry["start_line"],
            "functionLinenumberEnd": entry["end_line"],
            "functionDepth": entry["depth"],
            "returnType": entry["return_type"],
            "argCount": entry["arg_count"],
            "argTypes": entry["arg_types"],
            "constantsTouched": [],
            "argNames": [],
            "BBCount": 0,
            "ICount": 0,
            "EdgeCount": 0,
            "CyclomaticComplexity": entry["complexity"],
            "functionsReached": entry["called_functions"],
            "functionUses": 0,
            "BranchProfiles": [],
            "Callsites": []
        }
        for entry in data
    ]

    document = {
        "Fuzzer filename": "",
        "All functions": {
            "Function list name": "All functions",
            "Elements": elements
        }
    }

    with open(output_file, "w") as handle:
        yaml.dump(document, handle, default_flow_style=False)

    print(f"YAML output saved to {output_file}")
135+
136+
if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python3 script.py <target_directory>")
        sys.exit(1)

    if not Path(sys.argv[1]).is_dir():
        print(f"Error: {sys.argv[1]} is not a valid directory")
        sys.exit(1)

    # The analyser subprocess runs with cwd="rust_function_analyser", so a
    # relative path would be resolved against that directory instead of the
    # directory the user meant. Using the absolute path also keeps the file
    # names reported by the analyser identical to the ones found by the
    # Python scan. ``target_dir`` stays a module-level name because
    # add_line_data reads it.
    target_dir = os.path.abspath(sys.argv[1])

    # Run the Rust analysis frontend code.
    rust_analysis_results = run_rust_analysis(target_dir)
    if not rust_analysis_results:
        print(
            f"Warning: no functions were reported for {target_dir}",
            file=sys.stderr,
        )

    # Manually extract the line info for each function.
    # This is needed because the syn-based AST analysis cannot retrieve
    # line numbers on stable (non-nightly) Rust.
    add_line_data(rust_analysis_results)

    create_yaml_output(rust_analysis_results)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
/target
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[package]
2+
name = "rust_function_analyser"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
[dependencies]
7+
syn = { version = "2.0", features = ["full", "visit"] }
8+
quote = "1.0"
9+
walkdir = "2.4"
10+
serde = { version = "1.0", features = ["derive"] }
11+
serde_json = "1.0"
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
/* Copyright 2024 Fuzz Introspector Authors
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
use syn::{ItemFn, Stmt, FnArg, ReturnType};
17+
use std::collections::{HashMap, HashSet};
18+
use serde::{Serialize, Deserialize};
19+
use std::fs;
20+
21+
#[derive(Serialize, Deserialize, Debug, Clone)]
22+
pub struct FunctionInfo {
23+
pub name: String,
24+
pub file: String,
25+
pub return_type: String,
26+
pub arg_count: usize,
27+
pub arg_types: Vec<String>,
28+
pub complexity: usize,
29+
pub called_functions: Vec<String>,
30+
pub depth: usize,
31+
}
32+
33+
pub struct FunctionAnalyser {
34+
pub functions: Vec<FunctionInfo>,
35+
pub call_stack: HashMap<String, HashSet<String>>,
36+
}
37+
38+
impl FunctionAnalyser {
39+
pub fn new() -> Self {
40+
Self {
41+
functions: Vec::new(),
42+
call_stack: HashMap::new(),
43+
}
44+
}
45+
46+
pub fn visit_function(&mut self, node: &ItemFn, file_name: &str) {
47+
let function_name = node.sig.ident.to_string();
48+
let return_type = match &node.sig.output {
49+
ReturnType::Default => "void".to_string(),
50+
ReturnType::Type(_, ty) => format!("{}", quote::ToTokens::to_token_stream(&**ty)),
51+
}
52+
.replace(' ', "");
53+
54+
let arg_types = node
55+
.sig
56+
.inputs
57+
.iter()
58+
.filter_map(|arg| {
59+
if let FnArg::Typed(pat) = arg {
60+
Some(format!("{}", quote::ToTokens::to_token_stream(&*pat.ty)).replace(' ', ""))
61+
} else {
62+
None
63+
}
64+
})
65+
.collect::<Vec<_>>();
66+
67+
let complexity = calculate_cyclomatic_complexity(node);
68+
69+
self.functions.push(FunctionInfo {
70+
name: function_name.clone(),
71+
file: file_name.to_string(),
72+
return_type,
73+
arg_count: arg_types.len(),
74+
arg_types,
75+
complexity,
76+
called_functions: vec![],
77+
depth: 0,
78+
});
79+
80+
self.call_stack
81+
.entry(function_name.clone())
82+
.or_default();
83+
}
84+
85+
pub fn calculate_depths(&mut self) {
86+
let mut depth_map: HashMap<String, usize> = HashMap::new();
87+
88+
for function in &self.functions {
89+
let depth = self.calculate_function_depth(&function.name);
90+
depth_map.insert(function.name.clone(), depth);
91+
}
92+
93+
for function in self.functions.iter_mut() {
94+
if let Some(&depth) = depth_map.get(&function.name) {
95+
function.depth = depth;
96+
}
97+
}
98+
}
99+
100+
fn calculate_function_depth(
101+
&self,
102+
function_name: &str,
103+
) -> usize {
104+
let mut max_depth = 0;
105+
let mut stack = vec![(function_name, 0)];
106+
107+
while let Some((func, depth)) = stack.pop() {
108+
if depth > max_depth {
109+
max_depth = depth;
110+
}
111+
112+
if let Some(called) = self.call_stack.get(func) {
113+
for callee in called {
114+
if callee != func {
115+
stack.push((callee, depth + 1));
116+
}
117+
}
118+
}
119+
}
120+
121+
max_depth
122+
}
123+
}
124+
125+
fn calculate_cyclomatic_complexity(node: &ItemFn) -> usize {
126+
let mut complexity = 1;
127+
128+
for stmt in &node.block.stmts {
129+
if matches!(stmt, Stmt::Expr(..)) {
130+
complexity += 1;
131+
}
132+
}
133+
134+
complexity
135+
}
136+
137+
pub fn analyse_directory(dir: &str, exclude_dirs: &[&str]) -> std::io::Result<String> {
138+
let mut analyser = FunctionAnalyser::new();
139+
140+
for entry in fs::read_dir(dir)? {
141+
let entry = entry?;
142+
let path = entry.path();
143+
144+
if path.is_dir() && exclude_dirs.iter().any(|d| path.ends_with(d)) {
145+
continue;
146+
} else if path.is_dir() {
147+
let sub_result = analyse_directory(path.to_str().unwrap(), exclude_dirs)?;
148+
let parsed_functions: Vec<FunctionInfo> = serde_json::from_str(&sub_result).unwrap();
149+
analyser.functions.extend(parsed_functions);
150+
} else if path.extension().and_then(|s| s.to_str()) == Some("rs") {
151+
let file_content = fs::read_to_string(&path)?;
152+
let syntax = syn::parse_file(&file_content)
153+
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
154+
155+
for item in syntax.items {
156+
if let syn::Item::Fn(func) = item {
157+
analyser.visit_function(&func, path.to_str().unwrap());
158+
}
159+
}
160+
}
161+
}
162+
163+
analyser.calculate_depths();
164+
Ok(serde_json::to_string(&analyser.functions).unwrap())
165+
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/* Copyright 2024 Fuzz Introspector Authors
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
mod analyse;
17+
18+
use std::io;
19+
20+
fn main() -> io::Result<()> {
21+
let exclude_dirs = vec![
22+
"target",
23+
"node_modules",
24+
"aflplusplus",
25+
"tests",
26+
"examples",
27+
"benches",
28+
"honggfuzz",
29+
"inspector",
30+
"libfuzzer",
31+
];
32+
33+
let args: Vec<String> = std::env::args().collect();
34+
if args.len() != 2 {
35+
eprintln!("Usage: cargo run -- <source_directory>");
36+
std::process::exit(1);
37+
}
38+
let target_directory = &args[1];
39+
40+
// Collect all results into a single string and print to stdout
41+
let result = analyse::analyse_directory(target_directory, &exclude_dirs)?;
42+
println!("{}", result);
43+
44+
Ok(())
45+
}

0 commit comments

Comments
 (0)