22
33import networkx as nx
44
5+ from graph_sitter .core .class_definition import Class
56from graph_sitter .core .codebase import CodebaseType
7+ from graph_sitter .core .detached_symbols .function_call import FunctionCall
68from graph_sitter .core .external_module import ExternalModule
79from graph_sitter .core .function import Function
8- from graph_sitter .core .interfaces .callable import Callable , FunctionCallDefinition
10+ from graph_sitter .core .interfaces .callable import Callable
911from graph_sitter .enums import ProgrammingLanguage
1012from graph_sitter .skills .core .skill import Skill
1113from graph_sitter .skills .core .skill_test import SkillTestCase , SkillTestCasePyFile
@@ -69,7 +71,7 @@ def skill_func(codebase: CodebaseType):
6971 # ===== [ Maximum Recursive Depth ] =====
7072 MAX_DEPTH = 5
7173
72- def create_downstream_call_trace (parent : FunctionCallDefinition | Function | None = None , depth : int = 0 ):
74+ def create_downstream_call_trace (parent : FunctionCall | Function | None = None , depth : int = 0 ):
7375 """Creates call graph for parent
7476
7577 This function recurses through the call graph of a function and creates a visualization
@@ -82,20 +84,14 @@ def create_downstream_call_trace(parent: FunctionCallDefinition | Function | Non
8284 # if the maximum recursive depth has been exceeded return
8385 if MAX_DEPTH <= depth :
8486 return
85- # if parent is of type Function
86- if isinstance (parent , Function ):
87- # set both src_call, src_func to parent
88- src_call , src_func = parent , parent
87+ if isinstance (parent , FunctionCall ):
88+ src_call , src_func = parent , parent .function_definition
8989 else :
90- # get the first callable of parent
91- src_func = parent .callables [0 ]
92- src_call = parent .call
90+ src_call , src_func = parent , parent
9391 # Iterate over all call paths of the symbol
94- for func_call_def in src_func .call_graph_successors ():
95- # the call of a function
96- call = func_call_def .call
92+ for call in src_func .function_calls :
9793 # the symbol being called
98- func = func_call_def . callables [ 0 ]
94+ func = call . function_definition
9995
10096 # ignore direct recursive calls
10197 if func .name == src_func .name :
@@ -108,7 +104,7 @@ def create_downstream_call_trace(parent: FunctionCallDefinition | Function | Non
108104 G .add_edge (src_call , call )
109105
110106 # recursive call to function call
111- create_downstream_call_trace (func_call_def , depth + 1 )
107+ create_downstream_call_trace (call , depth + 1 )
112108 elif GRAPH_EXERNAL_MODULE_CALLS :
113109 # add `call` to the graph and an edge from `src_call` to `call`
114110 G .add_node (call )
@@ -187,12 +183,12 @@ def function_to_trace():
187183class CallGraphFilter (Skill , ABC ):
188184 """This skill shows a visualization of the call graph from a given function or symbol.
189185 It iterates through the usages of the starting function and its subsequent calls,
190- creating a directed graph of function calls. The skill filters out test files and
191- includes only methods with specific names (post, get, patch, delete).
192- By default, the call graph uses red for the starting node, yellow for class methods,
186+ creating a directed graph of function calls. The skill filters out test files and class declarations
187+ and includes only methods with specific names (post, get, patch, delete).
188+ The call graph uses red for the starting node, yellow for class methods,
193189 and can be customized based on user requests. The graph is limited to a specified depth
194- to manage complexity. In its current form,
195- it ignores recursive calls and external modules but can be modified trivially to include them
190+ to manage complexity. In its current form, it ignores recursive calls and external modules
191+ but can be modified trivially to include them
196192 """
197193
198194 @staticmethod
@@ -211,30 +207,30 @@ def skill_func(codebase: CodebaseType):
211207 # ===== [ Maximum Recursive Depth ] =====
212208 MAX_DEPTH = 5
213209
210+ SKIP_CLASS_DECLARATIONS = True
211+
214212 cls = codebase .get_class ("MyClass" )
215213
216214 # Define a recursive function to traverse function calls
217- def create_filtered_downstream_call_trace (parent_func : FunctionCallDefinition | Function , current_depth , max_depth ):
215+ def create_filtered_downstream_call_trace (parent : FunctionCall | Function , current_depth , max_depth ):
218216 if current_depth > max_depth :
219217 return
220218
221219 # if parent is of type Function
222- if isinstance (parent_func , Function ):
220+ if isinstance (parent , Function ):
223221 # set both src_call, src_func to parent
224- src_call , src_func = parent_func , parent_func
222+ src_call , src_func = parent , parent
225223 else :
226224 # get the first callable of parent
227- src_func = parent_func .callables [0 ]
228- src_call = parent_func .call
225+ src_call , src_func = parent , parent .function_definition
229226
230227 # Iterate over all call paths of the symbol
231- for func_call_def in src_func .call_graph_successors ():
232- # the call of a function
233- call = func_call_def .call
228+ for call in src_func .function_calls :
234229 # the symbol being called
235- func = func_call_def . callables [ 0 ]
230+ func = call . function_definition
236231
237- # Skip the successor if the file name starts with 'test'
232+ if SKIP_CLASS_DECLARATIONS and isinstance (func , Class ):
233+ continue
238234
239235 # if the function being called is not from an external module and is not defined in a test file
240236 if not isinstance (func , ExternalModule ) and not func .file .filepath .startswith ("test" ):
@@ -247,7 +243,7 @@ def create_filtered_downstream_call_trace(parent_func: FunctionCallDefinition |
247243 G .add_edge (src_call , call , symbol = cls ) # Add edge from current to successor
248244
249245 # Recursively add successors of the current symbol
250- create_filtered_downstream_call_trace (func_call_def , current_depth + 1 , max_depth )
246+ create_filtered_downstream_call_trace (call , current_depth + 1 , max_depth )
251247
252248 # Start the recursive traversal
253249 create_filtered_downstream_call_trace (func_to_trace , 1 , MAX_DEPTH )
@@ -301,25 +297,22 @@ def skill_func(codebase: CodebaseType):
301297 MAX_DEPTH = 5
302298
303299 # Define a recursive function to traverse usages
304- def create_downstream_call_trace (parent_func : FunctionCallDefinition | Function , end : Callable , current_depth , max_depth ):
300+ def create_downstream_call_trace (parent : FunctionCall | Function , end : Callable , current_depth , max_depth ):
305301 if current_depth > max_depth :
306302 return
307303
308304 # if parent is of type Function
309- if isinstance (parent_func , Function ):
305+ if isinstance (parent , Function ):
310306 # set both src_call, src_func to parent
311- src_call , src_func = parent_func , parent_func
307+ src_call , src_func = parent , parent
312308 else :
313309 # get the first callable of parent
314- src_func = parent_func .callables [0 ]
315- src_call = parent_func .call
310+ src_call , src_func = parent , parent .function_definition
316311
317312 # Iterate over all call paths of the symbol
318- for func_call_def in src_func .call_graph_successors ():
319- # the call of a function
320- call = func_call_def .call
313+ for call in src_func .function_calls :
321314 # the symbol being called
322- func = func_call_def . callables [ 0 ]
315+ func = call . function_definition
323316
324317 # ignore direct recursive calls
325318 if func .name == src_func .name :
@@ -335,7 +328,7 @@ def create_downstream_call_trace(parent_func: FunctionCallDefinition | Function,
335328 G .add_edge (call , end )
336329 return
337330 # recursive call to function call
338- create_downstream_call_trace (func_call_def , end , current_depth + 1 , max_depth )
331+ create_downstream_call_trace (call , end , current_depth + 1 , max_depth )
339332
340333 # Get the start and end function
341334 start = codebase .get_function ("start_func" )
0 commit comments