Skip to content

Commit 68bbf32

Browse files
committed
mypy: Fix mypy type check
Signed-off-by: Arthur Chan <[email protected]>
1 parent aa06050 commit 68bbf32

File tree

6 files changed

+50
-23
lines changed

6 files changed

+50
-23
lines changed

src/fuzz_introspector/frontends/datatypes.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,16 @@ def has_libfuzzer_harness(self) -> bool:
9494
"""Dummy function for source code files."""
9595
return False
9696

97+
# TODO To be removed after combning treesitter for C and C++
98+
def get_c_function_node(self, target_function_name):
99+
"""Dummy function for retrieving tree-sitter node of a function"""
100+
return None
101+
102+
# TODO To be removed after combning treesitter for C and C++
103+
def get_linenumber(self, bytepos) -> int:
104+
"""Dummy function to get line number from byte range"""
105+
return -1
106+
97107

98108
class Project(Generic[T]):
99109
"""Wrapper for doing analysis of a collection of source files."""
@@ -182,7 +192,7 @@ def dump_module_logic(self,
182192

183193
def extract_calltree(self,
184194
source_file: str = '',
185-
source_code: Optional[T] = None,
195+
source_code: Optional[SourceCodeFile] = None,
186196
function: Optional[str] = None,
187197
visited_functions: Optional[set[str]] = None,
188198
depth: int = 0,
@@ -195,7 +205,7 @@ def extract_calltree(self,
195205
def get_reachable_functions(
196206
self,
197207
source_file: str = '',
198-
source_code: Optional[T] = None,
208+
source_code: Optional[SourceCodeFile] = None,
199209
function: Optional[str] = None,
200210
visited_functions: Optional[set[str]] = None) -> set[str]:
201211
"""Get a list of reachable functions for a provided function name."""

src/fuzz_introspector/frontends/frontend_c.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def generate_report(self,
204204

205205
def get_source_code_with_target(self, target_func_name):
206206
for source_code in self.source_code_files:
207-
tfunc = source_code.get_function_node(target_func_name)
207+
tfunc = source_code.get_c_function_node(target_func_name)
208208
if not tfunc:
209209
continue
210210
return source_code
@@ -215,7 +215,7 @@ def get_source_codes_with_harnesses(self) -> list['CSourceCodeFile']:
215215

216216
def extract_calltree(self,
217217
source_file: str = '',
218-
source_code: Optional['CSourceCodeFile'] = None,
218+
source_code: Optional[SourceCodeFile] = None,
219219
function: Optional[str] = None,
220220
visited_functions: Optional[set[str]] = None,
221221
depth: int = 0,
@@ -244,7 +244,7 @@ def extract_calltree(self,
244244
line_to_print += '\n'
245245
if not source_code:
246246
return line_to_print
247-
func = source_code.get_function_node(function)
247+
func = source_code.get_c_function_node(function)
248248
callsites = func.callsites()
249249
if function in visited_functions:
250250
return line_to_print
@@ -263,7 +263,7 @@ def extract_calltree(self,
263263
def get_reachable_functions(
264264
self,
265265
source_file: str = '',
266-
source_code: Optional['CSourceCodeFile'] = None,
266+
source_code: Optional[SourceCodeFile] = None,
267267
function: Optional[str] = None,
268268
visited_functions: Optional[set[str]] = None) -> set[str]:
269269
"""Gets the reachable frunctions from a given function."""
@@ -284,7 +284,7 @@ def get_reachable_functions(
284284
if not source_code:
285285
return visited_functions
286286

287-
func = source_code.get_function_node(function)
287+
func = source_code.get_c_function_node(function)
288288
if not func:
289289
return visited_functions
290290

@@ -323,7 +323,7 @@ def get_function(self, target_function_name):
323323
"""Gets the first instance of a given function."""
324324

325325
for source_code in self.source_code_files:
326-
func = source_code.get_function_node(target_function_name)
326+
func = source_code.get_c_function_node(target_function_name)
327327
if func is not None:
328328
return func
329329
return None
@@ -940,7 +940,7 @@ def get_defined_function_names(self):
940940
func_names.append(func.name())
941941
return func_names
942942

943-
def get_function_node(self, target_function_name):
943+
def get_c_function_node(self, target_function_name):
944944
"""Gets the tree-sitter node corresponding to a function."""
945945

946946
# Find the first instance of the function name

src/fuzz_introspector/frontends/frontend_cpp.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,8 @@ def get_function_from_name(self, function_name):
712712

713713
def extract_calltree(self,
714714
source_file: str = '',
715-
source_code: Optional[CppSourceCodeFile] = None,
715+
source_code: Optional[
716+
datatypes.SourceCodeFile] = None,
716717
function: Optional[str] = None,
717718
visited_functions: Optional[set[str]] = None,
718719
depth: int = 0,
@@ -738,7 +739,7 @@ def extract_calltree(self,
738739

739740
func_node = None
740741
if function:
741-
if source_code:
742+
if source_code and isinstance(source_code, CppSourceCodeFile):
742743
logger.debug('Using source code var to extract node')
743744
func_node = source_code.get_function_node(function)
744745
else:
@@ -793,7 +794,7 @@ def extract_calltree(self,
793794
def get_reachable_functions(
794795
self,
795796
source_file: str = '',
796-
source_code: Optional[CppSourceCodeFile] = None,
797+
source_code: Optional[datatypes.SourceCodeFile] = None,
797798
function: Optional[str] = None,
798799
visited_functions: Optional[set[str]] = None) -> set[str]:
799800
"""Gets the reachable frunctions from a given function."""

src/fuzz_introspector/frontends/frontend_go.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def generate_report(self,
247247

248248
def extract_calltree(self,
249249
source_file: str = '',
250-
source_code: Optional[GoSourceCodeFile] = None,
250+
source_code: Optional[SourceCodeFile] = None,
251251
function: Optional[str] = None,
252252
visited_functions: Optional[set[str]] = None,
253253
depth: int = 0,
@@ -271,7 +271,7 @@ def extract_calltree(self,
271271
line_to_print += ' '
272272
line_to_print += source_file
273273

274-
if not source_code:
274+
if not source_code or not isinstance(source_code, GoSourceCodeFile):
275275
source_code = self.find_source_with_func_def(function)
276276

277277
line_to_print += ' '
@@ -301,7 +301,7 @@ def extract_calltree(self,
301301
def get_reachable_functions(
302302
self,
303303
source_file: str = '',
304-
source_code: Optional[GoSourceCodeFile] = None,
304+
source_code: Optional[SourceCodeFile] = None,
305305
function: Optional[str] = None,
306306
visited_functions: Optional[set[str]] = None) -> set[str]:
307307
"""Get a list of reachable functions for a provided function name."""
@@ -318,7 +318,7 @@ def get_reachable_functions(
318318
if not source_code and function:
319319
source_code = self.find_source_with_func_def(function)
320320

321-
if not source_code:
321+
if not source_code or not isinstance(source_code, GoSourceCodeFile):
322322
visited_functions.add(function)
323323
return visited_functions
324324

src/fuzz_introspector/frontends/frontend_jvm.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,7 +1211,7 @@ def _recursive_method_depth(method: JavaMethod) -> int:
12111211

12121212
def extract_calltree(self,
12131213
source_file: str = '',
1214-
source_code: Optional[JvmSourceCodeFile] = None,
1214+
source_code: Optional[SourceCodeFile] = None,
12151215
function: Optional[str] = None,
12161216
visited_functions: Optional[set[str]] = None,
12171217
depth: int = 0,
@@ -1228,7 +1228,10 @@ def extract_calltree(self,
12281228
source_code = self.find_source_with_method(function)
12291229

12301230
if not function and source_code:
1231-
function = source_code.get_entry_method_name(True)
1231+
if not isinstance(source_code, JvmSourceCodeFile):
1232+
function = source_code.get_entry_function_name()
1233+
else:
1234+
function = source_code.get_entry_method_name(True)
12321235

12331236
if not function:
12341237
return ''
@@ -1241,7 +1244,7 @@ def extract_calltree(self,
12411244
line_to_print += str(line_number)
12421245
line_to_print += '\n'
12431246

1244-
if not source_code:
1247+
if not source_code or not isinstance(source_code, JvmSourceCodeFile):
12451248
return line_to_print
12461249

12471250
function_node = source_code.get_method_node(function)
@@ -1270,7 +1273,7 @@ def get_source_codes_with_harnesses(self) -> list[JvmSourceCodeFile]:
12701273
def get_reachable_functions(
12711274
self,
12721275
source_file: str = '',
1273-
source_code: Optional[JvmSourceCodeFile] = None,
1276+
source_code: Optional[SourceCodeFile] = None,
12741277
function: Optional[str] = None,
12751278
visited_functions: Optional[set[str]] = None) -> set[str]:
12761279
"""Get a list of reachable functions for a provided function name."""
@@ -1281,9 +1284,15 @@ def get_reachable_functions(
12811284
source_code = self.find_source_with_method(function)
12821285

12831286
if not function and source_code:
1284-
function = source_code.get_entry_method_name(True)
1287+
if not isinstance(source_code, JvmSourceCodeFile):
1288+
function = source_code.get_entry_function_name()
1289+
else:
1290+
function = source_code.get_entry_method_name(True)
12851291

12861292
if source_code and function:
1293+
if not isinstance(source_code, JvmSourceCodeFile):
1294+
return visited_functions
1295+
12871296
function_node = source_code.get_method_node(function)
12881297
if not function_node:
12891298
visited_functions.add(function)

src/fuzz_introspector/frontends/frontend_rust.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,8 @@ def _recursive_function_depth(function: RustFunction) -> int:
754754

755755
def extract_calltree(self,
756756
source_file: str = '',
757-
source_code: Optional[RustSourceCodeFile] = None,
757+
source_code: Optional[
758+
datatypes.SourceCodeFile] = None,
758759
function: Optional[str] = None,
759760
visited_functions: Optional[set[str]] = None,
760761
depth: int = 0,
@@ -775,6 +776,9 @@ def extract_calltree(self,
775776
source_code = self._find_source_with_function(function)
776777

777778
if not function and source_code:
779+
if not isinstance(source_code, RustSourceCodeFile):
780+
return ''
781+
778782
func_node = source_code.get_entry_function()
779783
if func_node:
780784
function = func_node.name
@@ -827,7 +831,7 @@ def extract_calltree(self,
827831
def get_reachable_functions(
828832
self,
829833
source_file: str = '',
830-
source_code: Optional[RustSourceCodeFile] = None,
834+
source_code: Optional[datatypes.SourceCodeFile] = None,
831835
function: Optional[str] = None,
832836
visited_functions: Optional[set[str]] = None) -> set[str]:
833837
"""Get a list of reachable functions for a provided function name."""
@@ -841,6 +845,9 @@ def get_reachable_functions(
841845
source_code = self._find_source_with_function(function)
842846

843847
if not function and source_code:
848+
if not isinstance(source_code, RustSourceCodeFile):
849+
return visited_functions
850+
844851
func_node = source_code.get_entry_function()
845852
if func_node:
846853
function = func_node.name

0 commit comments

Comments
 (0)