1616import logging
1717import os
1818
19- from typing import (Any , List , Dict )
19+ from typing import (Any , List , Dict , Optional )
2020
2121from fuzz_introspector import (analysis , html_helpers , utils )
2222
23- from fuzz_introspector .datatypes import (project_profile , fuzzer_profile )
23+ from fuzz_introspector .datatypes import (project_profile , fuzzer_profile ,
24+ function_profile )
2425
2526from fuzz_introspector .frontends import oss_fuzz
2627
28+ from tree_sitter import Language , Parser , Query
29+ import tree_sitter_cpp
30+
2731logger = logging .getLogger (name = __name__ )
2832
33+ QUERY = """
34+ (declaration type: (_) @dt declarator: (pointer_declarator declarator: (identifier) @dn)) @dp
35+
36+ (declaration type: (_) @dt declarator: (array_declarator declarator: (identifier) @dn)) @da
37+
38+ (declaration type: (_) @dt declarator: (identifier) @dn) @d
39+
40+ (assignment_expression left: (identifier) @an
41+ right: (call_expression function: (identifier) @ai)) @ae
42+
43+ (call_expression function: (identifier) @cn arguments: (argument_list) @ca)
44+ """
45+
46+ PRIMITIVE_TYPES = [
47+ 'void' , 'auto' , '_Bool' , 'bool' , 'byte' , 'char' , 'char16_t' , 'char32_t' ,
48+ 'char8_t' , 'complex128' , 'complex64' , 'double' , 'f32' , 'f64' , 'float' ,
49+ 'float32' , 'float64' , 'i8' , 'i16' , 'i32' , 'i64' , 'i128' , 'int' , 'int8' ,
50+ 'int16' , 'int32' , 'int64' , 'isize' , 'long' , 'double' , 'nullptr_t' , 'rune' ,
51+ 'short' , 'str' , 'string' , 'u8' , 'u16' , 'u32' , 'u64' , 'u128' , 'uint' ,
52+ 'uint8' , 'uint16' , 'uint32' , 'uint64' , 'usize' , 'uintptr' ,
53+ 'unsafe.Pointer' , 'wchar_t' , 'size_t'
54+ ]
55+
2956
3057class FrontendAnalyser (analysis .AnalysisInterface ):
3158 """Analysis utility for a second frontend run and test file analysis."""
59+ # TODO arthur extend to other language
60+ LANGUAGE : dict [str , Language ] = {
61+ 'c-cpp' : Language (tree_sitter_cpp .language ()),
62+ }
3263
3364 name : str = 'FrontendAnalyser'
3465
@@ -40,6 +71,15 @@ def __init__(self) -> None:
4071 if os .path .isdir ('/src/' ):
4172 self .directory .add ('/src/' )
4273
74+ def _check_primitive (self , type_str : Optional [str ]) -> bool :
75+ """Check if the type str is primitive."""
76+ if not type_str :
77+ return True
78+
79+ type_str = type_str .replace ('*' , '' ).replace ('[]' , '' )
80+
81+ return type_str in PRIMITIVE_TYPES
82+
4383 @classmethod
4484 def get_name (cls ):
4585 """Return the analyser identifying name for processing.
@@ -119,8 +159,14 @@ def standalone_analysis(self,
119159 out_dir : str ) -> None :
120160 """Standalone analysis."""
121161 super ().standalone_analysis (proj_profile , profiles , out_dir )
122- functions = proj_profile .get_all_functions ()
123162
163+ # Extract all functions
164+ functions : list [function_profile .FunctionProfile ] = []
165+ for profile in profiles :
166+ functions .extend (profile .all_class_functions .values ())
167+ func_names = [f .function_name .split ('::' )[- 1 ] for f in functions ]
168+
169+ # Get test files from json
124170 test_files = set ()
125171 if os .path .isfile (os .path .join (out_dir , 'all_tests.json' )):
126172 with open (os .path .join (out_dir , 'all_tests.json' ), 'r' ) as f :
@@ -130,7 +176,7 @@ def standalone_analysis(self,
130176 if not self .directory :
131177 paths = [
132178 os .path .abspath (func .function_source_file )
133- for func in functions . values ()
179+ for func in functions
134180 ]
135181 common_path = os .path .commonpath (paths )
136182 if os .path .isfile (common_path ):
@@ -140,43 +186,147 @@ def standalone_analysis(self,
140186 if not self .language :
141187 self .language = proj_profile .language
142188
189+ # Ensure all test/example files has been added
143190 test_files .update (
144191 analysis .extract_tests_from_directories (self .directory ,
145192 self .language , out_dir ,
146193 False ))
147194
148- # Get all functions within test files
149- test_functions : dict [str , list [dict [str , object ]]] = {}
150- seen_functions : dict [str , set [tuple [str , str ]]] = {}
151- for function in functions .values ():
152- test_source = function .function_source_file
195+ tree_sitter_lang = self .LANGUAGE .get (self .language )
196+ if not tree_sitter_lang :
197+ logger .warning ('Language not support: %s' , self .language )
198+ return None
153199
154- # Skip unrelated functions
155- if test_source not in test_files :
200+ # Extract calls from each test/example file
201+ test_functions : dict [str , list [dict [str , object ]]] = {}
202+ parser = Parser (tree_sitter_lang )
203+ query = Query (tree_sitter_lang , QUERY )
204+ for test_file in test_files :
205+ func_call_list = []
206+ handled = []
207+
208+ # Tree sitter parsing of the test filees
209+ node = None
210+ if os .path .isfile (test_file ):
211+ with open (test_file , 'rb' ) as file :
212+ src = file .read () # type: bytes
213+ node = parser .parse (src ).root_node
214+
215+ if not node :
156216 continue
157217
158- if test_source not in test_functions :
159- test_functions [test_source ] = []
160- seen_functions [test_source ] = set ()
218+ # Extract function calls data from test files
219+ data = query .captures (node )
220+
221+ # Extract variable declarations (normal, pointers, arrays)
222+ declarations = {}
223+ type_nodes = data .get ('dt' , [])
224+ name_nodes = data .get ('dn' , [])
225+ kinds = {(n .start_point [0 ], n .start_point [1 ]): kind
226+ for kind in ('dp' , 'da' , 'dp' )
227+ for n in data .get (kind , [])}
228+
229+ # Process variable declarations
230+ for name_node , type_node in zip (name_nodes , type_nodes ):
231+ if not name_node .text or not type_node .text :
232+ continue
161233
162- for reached_name in function .functions_reached :
163- reached = functions .get (reached_name )
234+ name = name_node .text .decode (encoding = 'utf-8' ,
235+ errors = 'ignore' ).strip ()
236+ base = type_node .text .decode (encoding = 'utf-8' ,
237+ errors = 'ignore' ).strip ()
238+
239+ pos = (name_node .start_point [0 ], name_node .start_point [1 ])
240+ kind = kinds .get (pos , 'dp' )
241+
242+ if kind == 'dp' :
243+ full_type = f'{ base } *'
244+ elif kind == 'da' :
245+ full_type = f'{ base } []'
246+ else :
247+ full_type = base
248+
249+ declarations [name ] = {
250+ 'type' : full_type ,
251+ 'decl_line' : pos [0 ] + 1 ,
252+ 'init_func' : None ,
253+ 'init_start' : - 1 ,
254+ 'init_end' : - 1 ,
255+ }
256+
257+ # Extract and process variable initialisation and assignment
258+ assign_names = data .get ('an' , [])
259+ assign_inits = data .get ('ai' , [])
260+ for name_node , stmt_node in zip (assign_names , assign_inits ):
261+ if not name_node .text or not stmt_node .text :
262+ continue
164263
165- # Skip other test functions or external functions
166- if not reached or reached .function_source_file in test_files :
264+ name = name_node .text .decode (encoding = 'utf-8' ,
265+ errors = 'ignore' ).strip ()
266+ stmt = stmt_node .text .decode (encoding = 'utf-8' ,
267+ errors = 'ignore' ).strip ()
268+
269+ pos = (stmt_node .start_point [0 ], stmt_node .end_point [0 ])
270+ if name in declarations :
271+ declarations [name ]['init_func' ] = stmt
272+ declarations [name ]['init_start' ] = pos [0 ] + 1
273+ declarations [name ]['init_end' ] = pos [1 ] + 1
274+
275+ # Capture function called and args by this test files
276+ call_names = data .get ('cn' , [])
277+ call_args = data .get ('ca' , [])
278+ for name_node , args_node in zip (call_names , call_args ):
279+ if not name_node .text :
167280 continue
168281
169- key = (reached .function_name , reached .function_source_file )
282+ name = name_node .text .decode (encoding = 'utf-8' ,
283+ errors = 'ignore' ).strip ()
170284
171- # Skip duplicated, reached funcitons
172- if key in seen_functions [ test_source ] :
285+ # Skip non-project functions
286+ if name not in func_names :
173287 continue
174288
175- seen_functions [test_source ].add (key )
176- test_functions [test_source ].append (reached .to_dict ())
289+ # Extract declaration and intialisation for params
290+ # of this function call
291+ params = set ()
292+ for child in args_node .children :
293+ stack = [child ]
294+ while stack :
295+ curr = stack .pop ()
296+
297+ if curr .type == 'identifier' and curr .text :
298+ params .add (
299+ curr .text .decode (encoding = 'utf-8' ,
300+ errors = 'ignore' ).strip ())
301+ break
302+ if curr .child_count > 0 :
303+ stack .extend (curr .children )
304+
305+ # Filter declaration for this function call and store full
306+ # details including declaration initialisation of parameters
307+ # used for this function call
308+ filtered = [
309+ decl for param , decl in declarations .items ()
310+ if param in params
311+ and not self ._check_primitive (str (decl .get ('type' , '' )))
312+ ]
313+ key = (name , name_node .start_point [0 ], name_node .end_point [0 ])
314+ if key in handled :
315+ continue
177316
178- # Remove useless test files
179- test_functions = {k : v for k , v in test_functions .items () if v }
317+ handled .append (key )
318+ func_call_list .append ({
319+ 'function_name' : name ,
320+ 'params' : filtered ,
321+ 'call_start' : name_node .start_point [0 ] + 1 ,
322+ 'call_end' : name_node .end_point [0 ] + 1 ,
323+ })
324+
325+ func_call_list = [
326+ call for call in func_call_list if call ['params' ]
327+ ]
328+ if func_call_list :
329+ test_functions [test_file ] = func_call_list
180330
181331 # Store test files
182332 with open (os .path .join (out_dir , 'all_tests.json' ), 'w' ) as f :
0 commit comments