diff --git a/src/graph_sitter/core/interfaces/callable.py b/src/graph_sitter/core/interfaces/callable.py index 72baddab0..019d421bd 100644 --- a/src/graph_sitter/core/interfaces/callable.py +++ b/src/graph_sitter/core/interfaces/callable.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, Generic, Literal, Self, TypeVar, overload +from typing import TYPE_CHECKING, Generic, Self, TypeVar from graph_sitter.core.autocommit import reader from graph_sitter.core.detached_symbols.function_call import FunctionCall @@ -125,97 +125,3 @@ def get_parameter_by_type(self, type: "Symbol") -> TParameter | None: if self._parameters is None: return None return next((x for x in self._parameters if x.type == type), None) - - @overload - def call_graph_successors( - self, - *, - include_classes: Literal[False], - include_external: Literal[False], - ) -> list["Function"]: ... - - @overload - def call_graph_successors( - self, - *, - include_classes: Literal[False], - include_external: Literal[True] = ..., - ) -> list["Function | ExternalModule"]: ... - - @overload - def call_graph_successors( - self, - *, - include_classes: Literal[True] = ..., - include_external: Literal[False], - ) -> list["Function | Class"]: ... - - @overload - def call_graph_successors( - self, - *, - include_classes: Literal[True] = ..., - include_external: Literal[True] = ..., - ) -> list["Function | Class | ExternalModule"]: ... - - @reader - def call_graph_successors( - self, - *, - include_classes: bool = True, - include_external: bool = True, - ) -> list[FunctionCallDefinition]: - """Returns all function call definitions that are reachable from this callable. - - Analyzes the callable's implementation to find all function calls and their corresponding definitions. For classes, if a constructor exists, - returns the call graph successors of the constructor; otherwise returns an empty list. - - Args: - include_classes (bool): If True, includes class definitions in the results. Defaults to True. - include_external (bool): If True, includes external module definitions in the results. Defaults to True. - - Returns: - list[FunctionCallDefinition]: A list of FunctionCallDefinition objects, each containing a function call and its - possible callable definitions (Functions, Classes, or ExternalModules based on include flags). Returns empty list - for non-block symbols or classes without constructors. - """ - from graph_sitter.core.class_definition import Class - from graph_sitter.core.external_module import ExternalModule - from graph_sitter.core.function import Function - from graph_sitter.core.interfaces.has_block import HasBlock - - call_graph_successors: list[FunctionCallDefinition] = [] - - # Check if Callable has function_calls: - if isinstance(self, HasBlock): - # Special handling for classes. - # Classes with no constructors are not included in the code paths. Else, the code path of the constructor is included. - if isinstance(self, Class): - if self.constructor: - return self.constructor.call_graph_successors(include_classes=include_classes, include_external=include_external) - else: - return [] - - for call in self.function_calls: - call_graph_successor = FunctionCallDefinition(call, []) - # Extract function definition - for call_func in call.function_definitions: - # Case - Function with definition - if isinstance(call_func, Function): - call_graph_successor.callables.append(call_func) - - # =====[ Extract `__init__` from classes ]===== - elif isinstance(call_func, Class): - if include_classes: - call_graph_successor.callables.append(call_func) - # Case - external module (leaf node) - elif isinstance(call_func, ExternalModule) and include_external: - call_graph_successor.callables.append(call_func) - - if len(call_graph_successor.callables) > 0: - call_graph_successors.append(call_graph_successor) - else: - # Non-block symbols will not have any function calls - return [] - - return call_graph_successors diff --git a/src/graph_sitter/skills/graph_viz/graph_viz_call_graph.py b/src/graph_sitter/skills/graph_viz/graph_viz_call_graph.py index cb197b4d5..88f6e1594 100644 --- a/src/graph_sitter/skills/graph_viz/graph_viz_call_graph.py +++ b/src/graph_sitter/skills/graph_viz/graph_viz_call_graph.py @@ -2,10 +2,12 @@ import networkx as nx +from graph_sitter.core.class_definition import Class from graph_sitter.core.codebase import CodebaseType +from graph_sitter.core.detached_symbols.function_call import FunctionCall from graph_sitter.core.external_module import ExternalModule from graph_sitter.core.function import Function -from graph_sitter.core.interfaces.callable import Callable, FunctionCallDefinition +from graph_sitter.core.interfaces.callable import Callable from graph_sitter.enums import ProgrammingLanguage from graph_sitter.skills.core.skill import Skill from graph_sitter.skills.core.skill_test import SkillTestCase, SkillTestCasePyFile @@ -69,7 +71,7 @@ def skill_func(codebase: CodebaseType): # ===== [ Maximum Recursive Depth ] ===== MAX_DEPTH = 5 - def create_downstream_call_trace(parent: FunctionCallDefinition | Function | None = None, depth: int = 0): + def create_downstream_call_trace(parent: FunctionCall | Function | None = None, depth: int = 0): """Creates call graph for parent This function recurses through the call graph of a function and creates a visualization @@ -82,20 +84,14 @@ def create_downstream_call_trace(parent: FunctionCallDefinition | Function | Non # if the maximum recursive depth has been exceeded return if MAX_DEPTH <= depth: return - # if parent is of type Function - if isinstance(parent, Function): - # set both src_call, src_func to parent - src_call, src_func = parent, parent + if isinstance(parent, FunctionCall): + src_call, src_func = parent, parent.function_definition else: - # get the first callable of parent - src_func = parent.callables[0] - src_call = parent.call + src_call, src_func = parent, parent # Iterate over all call paths of the symbol - for func_call_def in src_func.call_graph_successors(): - # the call of a function - call = func_call_def.call + for call in src_func.function_calls: # the symbol being called - func = func_call_def.callables[0] + func = call.function_definition # ignore direct recursive calls if func.name == src_func.name: @@ -108,7 +104,7 @@ def create_downstream_call_trace(parent: FunctionCallDefinition | Function | Non G.add_edge(src_call, call) # recursive call to function call - create_downstream_call_trace(func_call_def, depth + 1) + create_downstream_call_trace(call, depth + 1) elif GRAPH_EXERNAL_MODULE_CALLS: # add `call` to the graph and an edge from `src_call` to `call` G.add_node(call) @@ -187,12 +183,12 @@ def function_to_trace(): class CallGraphFilter(Skill, ABC): """This skill shows a visualization of the call graph from a given function or symbol. It iterates through the usages of the starting function and its subsequent calls, - creating a directed graph of function calls. The skill filters out test files and - includes only methods with specific names (post, get, patch, delete). - By default, the call graph uses red for the starting node, yellow for class methods, + creating a directed graph of function calls. The skill filters out test files and class declarations + and includes only methods with specific names (post, get, patch, delete). + The call graph uses red for the starting node, yellow for class methods, and can be customized based on user requests. The graph is limited to a specified depth - to manage complexity. In its current form, - it ignores recursive calls and external modules but can be modified trivially to include them + to manage complexity. In its current form, it ignores recursive calls and external modules + but can be modified trivially to include them """ @staticmethod @@ -211,30 +207,30 @@ def skill_func(codebase: CodebaseType): # ===== [ Maximum Recursive Depth ] ===== MAX_DEPTH = 5 + SKIP_CLASS_DECLARATIONS = True + cls = codebase.get_class("MyClass") # Define a recursive function to traverse function calls - def create_filtered_downstream_call_trace(parent_func: FunctionCallDefinition | Function, current_depth, max_depth): + def create_filtered_downstream_call_trace(parent: FunctionCall | Function, current_depth, max_depth): if current_depth > max_depth: return # if parent is of type Function - if isinstance(parent_func, Function): + if isinstance(parent, Function): # set both src_call, src_func to parent - src_call, src_func = parent_func, parent_func + src_call, src_func = parent, parent else: # get the first callable of parent - src_func = parent_func.callables[0] - src_call = parent_func.call + src_call, src_func = parent, parent.function_definition # Iterate over all call paths of the symbol - for func_call_def in src_func.call_graph_successors(): - # the call of a function - call = func_call_def.call + for call in src_func.function_calls: # the symbol being called - func = func_call_def.callables[0] + func = call.function_definition - # Skip the successor if the file name starts with 'test' + if SKIP_CLASS_DECLARATIONS and isinstance(func, Class): + continue # if the function being called is not from an external module and is not defined in a test file if not isinstance(func, ExternalModule) and not func.file.filepath.startswith("test"): @@ -247,7 +243,7 @@ def create_filtered_downstream_call_trace(parent_func: FunctionCallDefinition | G.add_edge(src_call, call, symbol=cls) # Add edge from current to successor # Recursively add successors of the current symbol - create_filtered_downstream_call_trace(func_call_def, current_depth + 1, max_depth) + create_filtered_downstream_call_trace(call, current_depth + 1, max_depth) # Start the recursive traversal create_filtered_downstream_call_trace(func_to_trace, 1, MAX_DEPTH) @@ -301,25 +297,22 @@ def skill_func(codebase: CodebaseType): MAX_DEPTH = 5 # Define a recursive function to traverse usages - def create_downstream_call_trace(parent_func: FunctionCallDefinition | Function, end: Callable, current_depth, max_depth): + def create_downstream_call_trace(parent: FunctionCall | Function, end: Callable, current_depth, max_depth): if current_depth > max_depth: return # if parent is of type Function - if isinstance(parent_func, Function): + if isinstance(parent, Function): # set both src_call, src_func to parent - src_call, src_func = parent_func, parent_func + src_call, src_func = parent, parent else: # get the first callable of parent - src_func = parent_func.callables[0] - src_call = parent_func.call + src_call, src_func = parent, parent.function_definition # Iterate over all call paths of the symbol - for func_call_def in src_func.call_graph_successors(): - # the call of a function - call = func_call_def.call + for call in src_func.function_calls: # the symbol being called - func = func_call_def.callables[0] + func = call.function_definition # ignore direct recursive calls if func.name == src_func.name: @@ -335,7 +328,7 @@ def create_downstream_call_trace(parent_func: FunctionCallDefinition | Function, G.add_edge(call, end) return # recursive call to function call - create_downstream_call_trace(func_call_def, end, current_depth + 1, max_depth) + create_downstream_call_trace(call, end, current_depth + 1, max_depth) # Get the start and end function start = codebase.get_function("start_func") diff --git a/tests/unit/python/function/test_function_call_graph_successors.py b/tests/unit/python/function/test_function_call_graph_successors.py deleted file mode 100644 index cf4ac8815..000000000 --- a/tests/unit/python/function/test_function_call_graph_successors.py +++ /dev/null @@ -1,115 +0,0 @@ -from graph_sitter.codebase.factory.get_session import get_codebase_session -from graph_sitter.core.external_module import ExternalModule - - -def test_function_call_graph_successors(tmpdir) -> None: - # language=python - content = """ -def f(tmpdir): - pass - -def g(tmpdir): - return f() -""" - with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: - file = codebase.get_file("test.py") - f = file.get_function("f") - g = file.get_function("g") - call_graph_successors = g.call_graph_successors() - function_call = call_graph_successors[0].call - function_called = call_graph_successors[0].callables[0] - assert len(call_graph_successors) == 1 - assert function_called == f - assert function_call.source == "f()" - assert list(function_call.line_range) == [5] - - -def test_function_multiple_call_graph_successors(tmpdir) -> None: - # language=python - content = """ -def f1(tmpdir): - pass - -def f2(tmpdir): - pass - -def f3(tmpdir): - pass - -def g(tmpdir): - return f1() + f2() + f3() -""" - with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: - file = codebase.get_file("test.py") - g = file.get_function("g") - call_graph_successors = g.call_graph_successors() - function_calls = [call_graph_successor.call for call_graph_successor in call_graph_successors] - functions_called = [callable for call_graph_successor in call_graph_successors for callable in call_graph_successor.callables] - assert len(call_graph_successors) == 3 - assert set(functions_called) == {file.get_function("f1"), file.get_function("f2"), file.get_function("f3")} - assert function_calls[0].source == "f1()" - assert function_calls[1].source == "f2()" - assert function_calls[2].source == "f3()" - assert list(function_calls[0].line_range) == [11] - - -def test_function_class_call_graph_successors(tmpdir) -> None: - # language=python - content = """ -class A(): - pass - -class B(): - def __init__(self): - pass - -def foo(): - a = A() - b = B() -""" - with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: - file = codebase.get_file("test.py") - foo = file.get_function("foo") - - call_graph_successors = foo.call_graph_successors() - function_calls = [call_graph_successor.call for call_graph_successor in call_graph_successors] - functions_called = [callable for call_graph_successor in call_graph_successors for callable in call_graph_successor.callables] - assert set(functions_called) == {file.get_class("A"), file.get_class("B").get_method("__init__")} - assert len(call_graph_successors) == 2 - assert function_calls[0].source == "A()" - assert function_calls[1].source == "B()" - assert list(function_calls[0].line_range) == [9] - assert list(function_calls[1].line_range) == [10] - - call_graph_successors2 = foo.call_graph_successors(include_classes=False) - function_calls2 = [call_graph_successor.call for call_graph_successor in call_graph_successors2] - functions_called2 = [callable for call_graph_successor in call_graph_successors2 for callable in call_graph_successor.callables] - assert len(call_graph_successors2) == 1 - assert functions_called2[0] == file.get_class("B").get_method("__init__") - assert function_calls2[0].source == "B()" - assert list(function_calls2[0].line_range) == [10] - - -def test_function_ext_call_graph_successors(tmpdir) -> None: - # language=python - content = """ -from a import b - -def foo(): - thing = b() -""" - with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: - file = codebase.get_file("test.py") - foo = file.get_function("foo") - - call_graph_successors = foo.call_graph_successors() - function_calls = [call_graph_successor.call for call_graph_successor in call_graph_successors] - functions_called = [callable for call_graph_successor in call_graph_successors for callable in call_graph_successor.callables] - assert len(call_graph_successors) == 1 - assert isinstance(functions_called[0], ExternalModule) - assert functions_called[0].name == "b" - assert function_calls[0].source == "b()" - assert list(function_calls[0].line_range) == [4] - - call_graph_successors2 = foo.call_graph_successors(include_external=False) - assert len(call_graph_successors2) == 0 diff --git a/tests/unit/skills/snapshots/test_skills/test_all_example_skills/call_graph_filter-PYTHON-case-0/call_graph_filter_unnamed.json b/tests/unit/skills/snapshots/test_skills/test_all_example_skills/call_graph_filter-PYTHON-case-0/call_graph_filter_unnamed.json index d6131b35b..622bdb147 100644 --- a/tests/unit/skills/snapshots/test_skills/test_all_example_skills/call_graph_filter-PYTHON-case-0/call_graph_filter_unnamed.json +++ b/tests/unit/skills/snapshots/test_skills/test_all_example_skills/call_graph_filter-PYTHON-case-0/call_graph_filter_unnamed.json @@ -24,25 +24,6 @@ "symbol_name": "PyFunction", "id": 17 }, - { - "name": "MyClass", - "text": null, - "code": null, - "color": null, - "shape": null, - "start_point": [ - 4, - 15 - ], - "emoji": null, - "end_point": [ - 4, - 24 - ], - "file_path": "path/to/file1.py", - "symbol_name": "FunctionCall", - "id": "range= filepath='path/to/file1.py'" - }, { "name": "MyClass.get", "text": null, @@ -159,26 +140,6 @@ } ], "links": [ - { - "name": "MyClass", - "text": null, - "code": null, - "color": null, - "shape": null, - "start_point": [ - 1, - 0 - ], - "emoji": null, - "end_point": [ - 21, - 51 - ], - "file_path": "path/to/file.py", - "symbol_name": "PyClass", - "source": 17, - "target": "range= filepath='path/to/file1.py'" - }, { "name": "MyClass", "text": null, diff --git a/tests/unit/typescript/function/test_function_call_graph_successors.py b/tests/unit/typescript/function/test_function_call_graph_successors.py deleted file mode 100644 index b2c49f955..000000000 --- a/tests/unit/typescript/function/test_function_call_graph_successors.py +++ /dev/null @@ -1,122 +0,0 @@ -import pytest - -from graph_sitter.codebase.factory.get_session import get_codebase_session -from graph_sitter.core.external_module import ExternalModule -from graph_sitter.enums import ProgrammingLanguage - - -def test_function_call_graph_successors(tmpdir) -> None: - # language=typescript - content = """ -function f() { - return; -} - -function g() { - return f(); -} -""" - with get_codebase_session(tmpdir=tmpdir, programming_language=ProgrammingLanguage.TYPESCRIPT, files={"test.ts": content}) as codebase: - file = codebase.get_file("test.ts") - f = file.get_function("f") - g = file.get_function("g") - call_graph_successors = g.call_graph_successors() - function_calls = [call_graph_successor.call for call_graph_successor in call_graph_successors] - functions_called = [callable for call_graph_successor in call_graph_successors for callable in call_graph_successor.callables] - assert len(function_calls) == 1 - assert list(function_calls[0].line_range) == [6] - assert functions_called[0] == f - - -def test_function_multiple_call_graph_successors(tmpdir) -> None: - # language=typescript - content = """ -function f1() { - return; -} - -function f2() { - return; -} - -function f3() { - return; -} - -function g() { - return f1(), f2(), f3(); -} -""" - with get_codebase_session(tmpdir=tmpdir, programming_language=ProgrammingLanguage.TYPESCRIPT, files={"test.ts": content}) as codebase: - file = codebase.get_file("test.ts") - g = file.get_function("g") - call_graph_successors = g.call_graph_successors() - function_calls = [call_graph_successor.call for call_graph_successor in call_graph_successors] - functions_called = [callable for call_graph_successor in call_graph_successors for callable in call_graph_successor.callables] - assert len(function_calls) == 3 - assert list(function_calls[0].line_range) == [14] - assert list(function_calls[1].line_range) == [14] - assert list(function_calls[2].line_range) == [14] - assert set(functions_called) == {file.get_function("f1"), file.get_function("f2"), file.get_function("f3")} - - -@pytest.mark.skip("TODO: Classes not in function calls for some reason. TODO @edward") -def test_function_class_call_graph_successors(tmpdir) -> None: - # language=typescript - content = """ -class A { - constructor() {} -} - -class B { - constructor() {} -} - -function foo() { - const a = new A(); - const b = new B(); -} -""" - with get_codebase_session(tmpdir=tmpdir, programming_language=ProgrammingLanguage.TYPESCRIPT, files={"test.ts": content}) as codebase: - file = codebase.get_file("test.ts") - foo = file.get_function("foo") - - call_graph_successors = foo.call_graph_successors() - function_calls = [call_graph_successor.call for call_graph_successor in call_graph_successors] - functions_called = [callable for call_graph_successor in call_graph_successors for callable in call_graph_successor.callables] - assert len(function_calls) == 2 - assert list(function_calls[0].line_range) == [10] - assert list(function_calls[1].line_range) == [11] - assert set(functions_called) == {file.get_class("A").constructor, file.get_class("B").constructor} - - call_graph_successors2 = foo.call_graph_successors(include_classes=False) - function_calls2 = [call_graph_successor.call for call_graph_successor in call_graph_successors2] - functions_called2 = [callable for call_graph_successor in call_graph_successors2 for callable in call_graph_successor.callables] - assert len(function_calls2) == 1 - assert list(function_calls2[0].line_range) == [14] - assert set(functions_called2) == {file.get_class("B").constructor} - - -def test_function_ext_call_graph_successors(tmpdir) -> None: - # language=typescript - content = """ -import { b } from 'a'; - -function foo() { - const thing = b(); -} -""" - with get_codebase_session(tmpdir=tmpdir, programming_language=ProgrammingLanguage.TYPESCRIPT, files={"test.ts": content}) as codebase: - file = codebase.get_file("test.ts") - foo = file.get_function("foo") - - call_graph_successors = foo.call_graph_successors() - function_calls = [call_graph_successor.call for call_graph_successor in call_graph_successors] - functions_called = [callable for call_graph_successor in call_graph_successors for callable in call_graph_successor.callables] - assert len(functions_called) == 1 - assert isinstance(functions_called[0], ExternalModule) - assert functions_called[0].name == "b" - assert list(function_calls[0].line_range) == [4] - - call_graph_successors2 = foo.call_graph_successors(include_external=False) - assert len(call_graph_successors2) == 0