Skip to content

Commit 55b8cf4

Browse files
authored
frontend-analyser: Fix test xref extraction logic (#2206)
* frontend-analyser: Fix logic for tests xref info extraction Signed-off-by: Arthur Chan <[email protected]> * Fix logic Signed-off-by: Arthur Chan <[email protected]> * Fix formatting Signed-off-by: Arthur Chan <[email protected]> * Fix typo Signed-off-by: Arthur Chan <[email protected]> * Fix formatting and mypy error Signed-off-by: Arthur Chan <[email protected]> --------- Signed-off-by: Arthur Chan <[email protected]>
1 parent 394c9d3 commit 55b8cf4

File tree

1 file changed

+175
-25
lines changed

1 file changed

+175
-25
lines changed

src/fuzz_introspector/analyses/frontend_analyser.py

Lines changed: 175 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,50 @@
1616
import logging
1717
import os
1818

19-
from typing import (Any, List, Dict)
19+
from typing import (Any, List, Dict, Optional)
2020

2121
from fuzz_introspector import (analysis, html_helpers, utils)
2222

23-
from fuzz_introspector.datatypes import (project_profile, fuzzer_profile)
23+
from fuzz_introspector.datatypes import (project_profile, fuzzer_profile,
24+
function_profile)
2425

2526
from fuzz_introspector.frontends import oss_fuzz
2627

28+
from tree_sitter import Language, Parser, Query
29+
import tree_sitter_cpp
30+
2731
logger = logging.getLogger(name=__name__)
2832

33+
# Tree-sitter query source compiled with ``Query(tree_sitter_lang, QUERY)``
# in ``standalone_analysis``. Capture names used by that code:
#   @dt / @dn - type and identifier of a declaration; the enclosing
#               declaration is captured as @dp (pointer declarator),
#               @da (array declarator) or @d (plain identifier declarator).
#   @an / @ai - left-hand identifier and called function identifier of an
#               assignment whose right-hand side is a call (@ae is the whole
#               assignment expression).
#   @cn / @ca - called function identifier and argument list of a call
#               expression.
QUERY = """
34+
(declaration type: (_) @dt declarator: (pointer_declarator declarator: (identifier) @dn)) @dp
35+
36+
(declaration type: (_) @dt declarator: (array_declarator declarator: (identifier) @dn)) @da
37+
38+
(declaration type: (_) @dt declarator: (identifier) @dn) @d
39+
40+
(assignment_expression left: (identifier) @an
41+
right: (call_expression function: (identifier) @ai)) @ae
42+
43+
(call_expression function: (identifier) @cn arguments: (argument_list) @ca)
44+
"""
45+
46+
# Type names that ``FrontendAnalyser._check_primitive`` treats as primitive.
# Every entry is a single word and is listed exactly once (the original list
# repeated 'double').
# NOTE(review): the list appears to mix C/C++ names with names from other
# languages (e.g. 'i8', 'rune', 'unsafe.Pointer') - confirm this is intended.
PRIMITIVE_TYPES = [
    'void', 'auto', '_Bool', 'bool', 'byte', 'char', 'char16_t', 'char32_t',
    'char8_t', 'complex128', 'complex64', 'double', 'f32', 'f64', 'float',
    'float32', 'float64', 'i8', 'i16', 'i32', 'i64', 'i128', 'int', 'int8',
    'int16', 'int32', 'int64', 'isize', 'long', 'nullptr_t', 'rune', 'short',
    'str', 'string', 'u8', 'u16', 'u32', 'u64', 'u128', 'uint', 'uint8',
    'uint16', 'uint32', 'uint64', 'usize', 'uintptr', 'unsafe.Pointer',
    'wchar_t', 'size_t',
]
55+
2956

3057
class FrontendAnalyser(analysis.AnalysisInterface):
3158
"""Analysis utility for a second frontend run and test file analysis."""
59+
# TODO arthur extend to other language
60+
LANGUAGE: dict[str, Language] = {
61+
'c-cpp': Language(tree_sitter_cpp.language()),
62+
}
3263

3364
name: str = 'FrontendAnalyser'
3465

@@ -40,6 +71,15 @@ def __init__(self) -> None:
4071
if os.path.isdir('/src/'):
4172
self.directory.add('/src/')
4273

74+
def _check_primitive(self, type_str: Optional[str]) -> bool:
    """Check if the type str is primitive.

    Pointer (``*``) and array (``[]``) markers are removed first. The
    remaining text is then checked word by word, ignoring the
    ``const``/``volatile``/``signed``/``unsigned`` specifiers, so that
    multi-word spellings such as ``unsigned int`` or ``long long`` are
    recognised when every remaining word is listed in ``PRIMITIVE_TYPES``.

    Args:
        type_str: Type text to classify; may be ``None`` or empty.

    Returns:
        ``True`` when ``type_str`` is ``None``/empty or names a primitive
        type, ``False`` otherwise.
    """
    if not type_str:
        # Nothing to classify: treated as primitive.
        return True

    words = type_str.replace('*', '').replace('[]', '').split()
    if not words:
        # Only markers/whitespace were given; that is not a known type name.
        return False

    # Specifiers that never make a type non-primitive on their own; a bare
    # 'unsigned' or 'signed' therefore counts as primitive.
    specifiers = ('const', 'volatile', 'signed', 'unsigned')
    return all(word in PRIMITIVE_TYPES for word in words
               if word not in specifiers)
82+
4383
@classmethod
4484
def get_name(cls):
4585
"""Return the analyser identifying name for processing.
@@ -119,8 +159,14 @@ def standalone_analysis(self,
119159
out_dir: str) -> None:
120160
"""Standalone analysis."""
121161
super().standalone_analysis(proj_profile, profiles, out_dir)
122-
functions = proj_profile.get_all_functions()
123162

163+
# Extract all functions
164+
functions: list[function_profile.FunctionProfile] = []
165+
for profile in profiles:
166+
functions.extend(profile.all_class_functions.values())
167+
func_names = [f.function_name.split('::')[-1] for f in functions]
168+
169+
# Get test files from json
124170
test_files = set()
125171
if os.path.isfile(os.path.join(out_dir, 'all_tests.json')):
126172
with open(os.path.join(out_dir, 'all_tests.json'), 'r') as f:
@@ -130,7 +176,7 @@ def standalone_analysis(self,
130176
if not self.directory:
131177
paths = [
132178
os.path.abspath(func.function_source_file)
133-
for func in functions.values()
179+
for func in functions
134180
]
135181
common_path = os.path.commonpath(paths)
136182
if os.path.isfile(common_path):
@@ -140,43 +186,147 @@ def standalone_analysis(self,
140186
if not self.language:
141187
self.language = proj_profile.language
142188

189+
# Ensure all test/example files has been added
143190
test_files.update(
144191
analysis.extract_tests_from_directories(self.directory,
145192
self.language, out_dir,
146193
False))
147194

148-
# Get all functions within test files
149-
test_functions: dict[str, list[dict[str, object]]] = {}
150-
seen_functions: dict[str, set[tuple[str, str]]] = {}
151-
for function in functions.values():
152-
test_source = function.function_source_file
195+
tree_sitter_lang = self.LANGUAGE.get(self.language)
196+
if not tree_sitter_lang:
197+
logger.warning('Language not support: %s', self.language)
198+
return None
153199

154-
# Skip unrelated functions
155-
if test_source not in test_files:
200+
# Extract calls from each test/example file
201+
test_functions: dict[str, list[dict[str, object]]] = {}
202+
parser = Parser(tree_sitter_lang)
203+
query = Query(tree_sitter_lang, QUERY)
204+
for test_file in test_files:
205+
func_call_list = []
206+
handled = []
207+
208+
# Tree sitter parsing of the test filees
209+
node = None
210+
if os.path.isfile(test_file):
211+
with open(test_file, 'rb') as file:
212+
src = file.read() # type: bytes
213+
node = parser.parse(src).root_node
214+
215+
if not node:
156216
continue
157217

158-
if test_source not in test_functions:
159-
test_functions[test_source] = []
160-
seen_functions[test_source] = set()
218+
# Extract function calls data from test files
219+
data = query.captures(node)
220+
221+
# Extract variable declarations (normal, pointers, arrays)
222+
declarations = {}
223+
type_nodes = data.get('dt', [])
224+
name_nodes = data.get('dn', [])
225+
kinds = {(n.start_point[0], n.start_point[1]): kind
226+
for kind in ('dp', 'da', 'dp')
227+
for n in data.get(kind, [])}
228+
229+
# Process variable declarations
230+
for name_node, type_node in zip(name_nodes, type_nodes):
231+
if not name_node.text or not type_node.text:
232+
continue
161233

162-
for reached_name in function.functions_reached:
163-
reached = functions.get(reached_name)
234+
name = name_node.text.decode(encoding='utf-8',
235+
errors='ignore').strip()
236+
base = type_node.text.decode(encoding='utf-8',
237+
errors='ignore').strip()
238+
239+
pos = (name_node.start_point[0], name_node.start_point[1])
240+
kind = kinds.get(pos, 'dp')
241+
242+
if kind == 'dp':
243+
full_type = f'{base}*'
244+
elif kind == 'da':
245+
full_type = f'{base}[]'
246+
else:
247+
full_type = base
248+
249+
declarations[name] = {
250+
'type': full_type,
251+
'decl_line': pos[0] + 1,
252+
'init_func': None,
253+
'init_start': -1,
254+
'init_end': -1,
255+
}
256+
257+
# Extract and process variable initialisation and assignment
258+
assign_names = data.get('an', [])
259+
assign_inits = data.get('ai', [])
260+
for name_node, stmt_node in zip(assign_names, assign_inits):
261+
if not name_node.text or not stmt_node.text:
262+
continue
164263

165-
# Skip other test functions or external functions
166-
if not reached or reached.function_source_file in test_files:
264+
name = name_node.text.decode(encoding='utf-8',
265+
errors='ignore').strip()
266+
stmt = stmt_node.text.decode(encoding='utf-8',
267+
errors='ignore').strip()
268+
269+
pos = (stmt_node.start_point[0], stmt_node.end_point[0])
270+
if name in declarations:
271+
declarations[name]['init_func'] = stmt
272+
declarations[name]['init_start'] = pos[0] + 1
273+
declarations[name]['init_end'] = pos[1] + 1
274+
275+
# Capture function called and args by this test files
276+
call_names = data.get('cn', [])
277+
call_args = data.get('ca', [])
278+
for name_node, args_node in zip(call_names, call_args):
279+
if not name_node.text:
167280
continue
168281

169-
key = (reached.function_name, reached.function_source_file)
282+
name = name_node.text.decode(encoding='utf-8',
283+
errors='ignore').strip()
170284

171-
# Skip duplicated, reached funcitons
172-
if key in seen_functions[test_source]:
285+
# Skip non-project functions
286+
if name not in func_names:
173287
continue
174288

175-
seen_functions[test_source].add(key)
176-
test_functions[test_source].append(reached.to_dict())
289+
# Extract declaration and intialisation for params
290+
# of this function call
291+
params = set()
292+
for child in args_node.children:
293+
stack = [child]
294+
while stack:
295+
curr = stack.pop()
296+
297+
if curr.type == 'identifier' and curr.text:
298+
params.add(
299+
curr.text.decode(encoding='utf-8',
300+
errors='ignore').strip())
301+
break
302+
if curr.child_count > 0:
303+
stack.extend(curr.children)
304+
305+
# Filter declaration for this function call and store full
306+
# details including declaration initialisation of parameters
307+
# used for this function call
308+
filtered = [
309+
decl for param, decl in declarations.items()
310+
if param in params
311+
and not self._check_primitive(str(decl.get('type', '')))
312+
]
313+
key = (name, name_node.start_point[0], name_node.end_point[0])
314+
if key in handled:
315+
continue
177316

178-
# Remove useless test files
179-
test_functions = {k: v for k, v in test_functions.items() if v}
317+
handled.append(key)
318+
func_call_list.append({
319+
'function_name': name,
320+
'params': filtered,
321+
'call_start': name_node.start_point[0] + 1,
322+
'call_end': name_node.end_point[0] + 1,
323+
})
324+
325+
func_call_list = [
326+
call for call in func_call_list if call['params']
327+
]
328+
if func_call_list:
329+
test_functions[test_file] = func_call_list
180330

181331
# Store test files
182332
with open(os.path.join(out_dir, 'all_tests.json'), 'w') as f:

0 commit comments

Comments
 (0)