From 02684b42b40e03e75caa6b86457a05cd0cc7004d Mon Sep 17 00:00:00 2001 From: "codegen-sh[bot]" <131295404+codegen-sh[bot]@users.noreply.github.com> Date: Wed, 14 May 2025 14:37:51 +0000 Subject: [PATCH] Organize visualization files into structured subdirectories --- .../analyzers/visualization/README.md | 44 + .../visualization/call_graph/__init__.py | 6 + .../visualization/call_graph/call_trace.py | 83 + .../call_graph/graph_viz_call_graph.py | 358 ++++ .../call_graph/method_relationships.py | 107 ++ .../visualization/call_graph/viz_cal_graph.py | 121 ++ .../visualization/codebase_visualizer.py | 1687 +++++++++++++++-- .../dependency_graph/__init__.py | 6 + .../dependency_graph/blast_radius.py | 119 ++ .../dependency_graph/dependency_trace.py | 83 + .../dependency_graph/viz_dead_code.py | 154 ++ .../analyzers/visualization/docs/__init__.py | 6 + .../docs/codebase-visualization.mdx | 399 ++++ .../visualization/structure_graph/__init__.py | 6 + .../structure_graph/graph_viz_dir_tree.py | 111 ++ .../structure_graph/graph_viz_foreign_key.py | 178 ++ 16 files changed, 3270 insertions(+), 198 deletions(-) create mode 100644 codegen-on-oss/codegen_on_oss/analyzers/visualization/README.md create mode 100644 codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/__init__.py create mode 100644 codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/call_trace.py create mode 100644 codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/graph_viz_call_graph.py create mode 100644 codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/method_relationships.py create mode 100644 codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/viz_cal_graph.py create mode 100644 codegen-on-oss/codegen_on_oss/analyzers/visualization/dependency_graph/__init__.py create mode 100644 codegen-on-oss/codegen_on_oss/analyzers/visualization/dependency_graph/blast_radius.py create mode 100644 codegen-on-oss/codegen_on_oss/analyzers/visualization/dependency_graph/dependency_trace.py create mode 100644 codegen-on-oss/codegen_on_oss/analyzers/visualization/dependency_graph/viz_dead_code.py create mode 100644 codegen-on-oss/codegen_on_oss/analyzers/visualization/docs/__init__.py create mode 100644 codegen-on-oss/codegen_on_oss/analyzers/visualization/docs/codebase-visualization.mdx create mode 100644 codegen-on-oss/codegen_on_oss/analyzers/visualization/structure_graph/__init__.py create mode 100644 codegen-on-oss/codegen_on_oss/analyzers/visualization/structure_graph/graph_viz_dir_tree.py create mode 100644 codegen-on-oss/codegen_on_oss/analyzers/visualization/structure_graph/graph_viz_foreign_key.py diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/README.md b/codegen-on-oss/codegen_on_oss/analyzers/visualization/README.md new file mode 100644 index 000000000..2595849ea --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/README.md @@ -0,0 +1,44 @@ +# Codebase Visualization + +This directory contains tools and utilities for visualizing various aspects of a codebase. + +## Directory Structure + +- **call_graph/**: Visualizations related to function call relationships and method interactions + - `call_trace.py`: Traces function call paths through a codebase + - `graph_viz_call_graph.py`: Creates directed call graphs for functions + - `method_relationships.py`: Visualizes relationships between methods in a class + - `viz_cal_graph.py`: Generates call graphs with detailed metadata + +- **dependency_graph/**: Visualizations related to code dependencies and impact analysis + - `blast_radius.py`: Shows the "blast radius" of changes to a function + - `dependency_trace.py`: Traces symbol dependencies through a codebase + - `viz_dead_code.py`: Identifies and visualizes dead/unused code + +- **structure_graph/**: Visualizations related to code structure and organization + - `graph_viz_dir_tree.py`: Displays directory structure as a graph + - `graph_viz_foreign_key.py`: Visualizes database schema relationships + +- **docs/**: Documentation and examples for visualization tools + - `codebase-visualization.mdx`: Comprehensive guide to codebase visualization + +## Base Visualization Files + +- `analysis_visualizer.py`: Core visualization for analysis results +- `code_visualizer.py`: Visualization tools for code elements +- `codebase_visualizer.py`: Main visualization engine for codebases +- `visualizer.py`: Base visualization framework + +## Usage + +These visualization tools can be used to: + +1. Understand complex codebases +2. Plan refactoring efforts +3. Identify tightly coupled components +4. Analyze critical paths +5. Document system architecture +6. Find dead code +7. Visualize database schemas +8. Understand directory structures + diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/__init__.py b/codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/__init__.py new file mode 100644 index 000000000..e9e9da182 --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/__init__.py @@ -0,0 +1,6 @@ +""" +Call Graph Visualization Module + +This module provides tools for visualizing call graphs and function relationships in a codebase. +""" + diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/call_trace.py b/codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/call_trace.py new file mode 100644 index 000000000..85448ac4f --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/call_trace.py @@ -0,0 +1,83 @@ +import codegen +import networkx as nx +from codegen import Codebase +from codegen.sdk.core.class_definition import Class +from codegen.sdk.core.import_resolution import Import +from codegen.sdk.core.symbol import Symbol + +G = nx.DiGraph() + +IGNORE_EXTERNAL_MODULE_CALLS = True +IGNORE_CLASS_CALLS = False +MAX_DEPTH = 10 + +COLOR_PALETTE = { + "StartFunction": "#9cdcfe", # Light blue for the starting function + "PyFunction": "#a277ff", # Purple for Python functions + "PyClass": "#ffca85", # Orange for Python classes + "ExternalModule": "#f694ff", # Pink for external module references +} + +# Dictionary to track visited nodes and prevent cycles +visited = {} + + +def create_dependencies_visualization(symbol: Symbol, depth: int = 0): + """Creates a visualization of symbol dependencies in the codebase + + Recursively traverses the dependency tree of a symbol (function, class, etc.) + and creates a directed graph representation. Dependencies can be either direct + symbol references or imports. + + Args: + symbol (Symbol): The starting symbol whose dependencies will be mapped + depth (int): Current depth in the recursive traversal + """ + if depth >= MAX_DEPTH: + return + + for dep in symbol.dependencies: + dep_symbol = None + + if isinstance(dep, Symbol): + dep_symbol = dep + elif isinstance(dep, Import): + dep_symbol = dep.resolved_symbol if dep.resolved_symbol else None + + if dep_symbol: + G.add_node(dep_symbol, color=COLOR_PALETTE.get(dep_symbol.__class__.__name__, "#f694ff")) + G.add_edge(symbol, dep_symbol) + + if not isinstance(dep_symbol, Class): + create_dependencies_visualization(dep_symbol, depth + 1) + + +@codegen.function("visualize-symbol-dependencies") +def run(codebase: Codebase): + """Generate a visualization of symbol dependencies in a codebase. + + This codemod: + 1. Creates a directed graph of symbol dependencies starting from a target function + 2. Tracks relationships between functions, classes, and imports + 3. Generates a visual representation of the dependency hierarchy + """ + global G + G = nx.DiGraph() + + target_func = codebase.get_function("get_query_runner") + G.add_node(target_func, color=COLOR_PALETTE.get("StartFunction")) + + create_dependencies_visualization(target_func) + + print(G) + print("Use codegen.sh to visualize the graph!") + + +if __name__ == "__main__": + print("Initializing codebase...") + codebase = Codebase.from_repo("codegen-oss/posthog", commit="b174f2221ea4ae50e715eb6a7e70e9a2b0760800", language="python") + print(f"Codebase with {len(codebase.files)} files and {len(codebase.functions)} functions.") + print("Creating graph...") + + run(codebase) + diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/graph_viz_call_graph.py b/codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/graph_viz_call_graph.py new file mode 100644 index 000000000..9fd770841 --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/graph_viz_call_graph.py @@ -0,0 +1,358 @@ +from abc import ABC + +import networkx as nx + +from codegen.sdk.core.class_definition import Class +from codegen.sdk.core.codebase import CodebaseType +from codegen.sdk.core.detached_symbols.function_call import FunctionCall +from codegen.sdk.core.external_module import ExternalModule +from codegen.sdk.core.function import Function +from codegen.sdk.core.interfaces.callable import Callable +from codegen.shared.enums.programming_language import ProgrammingLanguage +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill +from tests.shared.skills.skill_test import SkillTestCase, SkillTestCasePyFile + +CallGraphFromNodeTest = SkillTestCase( + [ + SkillTestCasePyFile( + input=""" +def function_to_trace(): + Y() + Z() + +def Y(): + A() + +def Z(): + B() + +def A(): + pass + +def B(): + C() + +def C(): + pass +""", + filepath="example.py", + ) + ], + graph=True, +) + + +@skill(eval_skill=False, prompt="Show me a visualization of the call graph from X", uid="81e8fbb7-a00a-4e74-b9c2-24f79d24d389") +class CallGraphFromNode(Skill, ABC): + """This skill creates a directed call graph for a given function. Starting from the specified function, it recursively iterates + through its function calls and the functions called by them, building a graph of the call paths to a maximum depth. The root of the directed graph + is the starting function, each node represents a function call, and edge from node A to node B indicates that function A calls function B. In its current form, + it ignores recursive calls and external modules but can be modified trivially to include them. Furthermore, this skill can easily be adapted to support + creating a call graph for a class method. In order to do this one simply needs to replace + + `function_to_trace = codebase.get_function("function_to_trace")` + + with + + `function_to_trace = codebase.get_class("class_of_method_to_trace").get_method("method_to_trace")` + """ + + @staticmethod + @skill_impl(test_cases=[CallGraphFromNodeTest], language=ProgrammingLanguage.PYTHON) + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) + def skill_func(codebase: CodebaseType): + # Create a directed graph + G = nx.DiGraph() + + # ===== [ Whether to Graph External Modules] ===== + GRAPH_EXERNAL_MODULE_CALLS = False + + # ===== [ Maximum Recursive Depth ] ===== + MAX_DEPTH = 5 + + def create_downstream_call_trace(parent: FunctionCall | Function | None = None, depth: int = 0): + """Creates call graph for parent + + This function recurses through the call graph of a function and creates a visualization + + Args: + parent (FunctionCallDefinition| Function): The function for which a call graph will be created. + depth (int): The current depth of the recursive stack. + + """ + # if the maximum recursive depth has been exceeded return + if MAX_DEPTH <= depth: + return + if isinstance(parent, FunctionCall): + src_call, src_func = parent, parent.function_definition + else: + src_call, src_func = parent, parent + # Iterate over all call paths of the symbol + for call in src_func.function_calls: + # the symbol being called + func = call.function_definition + + # ignore direct recursive calls + if func.name == src_func.name: + continue + + # if the function being called is not from an external module + if not isinstance(func, ExternalModule): + # add `call` to the graph and an edge from `src_call` to `call` + G.add_node(call) + G.add_edge(src_call, call) + + # recursive call to function call + create_downstream_call_trace(call, depth + 1) + elif GRAPH_EXERNAL_MODULE_CALLS: + # add `call` to the graph and an edge from `src_call` to `call` + G.add_node(call) + G.add_edge(src_call, call) + + # ===== [ Function To Be Traced] ===== + function_to_trace = codebase.get_function("function_to_trace") + + # Set starting node + G.add_node(function_to_trace, color="yellow") + + # Add all the children (and sub-children) to the graph + create_downstream_call_trace(function_to_trace) + + # Visualize the graph + codebase.visualize(G) + + +CallGraphFilterTest = SkillTestCase( + [ + SkillTestCasePyFile( + input=""" +class MyClass: + def get(self): + self.helper_method() + return "GET request" + + def post(self): + self.helper_method() + return "POST request" + + def patch(self): + return "PATCH request" + + def delete(self): + return "DELETE request" + + def helper_method(self): + pass + + def other_method(self): + self.helper_method() + return "This method should not be included" + +def external_function(): + instance = MyClass() + instance.get() + instance.post() + instance.other_method() +""", + filepath="path/to/file.py", + ), + SkillTestCasePyFile( + input=""" +from path.to.file import MyClass + +def function_to_trace(): + instance = MyClass() + assert instance.get() == "GET request" + assert instance.post() == "POST request" + assert instance.patch() == "PATCH request" + assert instance.delete() == "DELETE request" +""", + filepath="path/to/file1.py", + ), + ], + graph=True, +) + + +@skill( + eval_skill=False, + prompt="Show me a visualization of the call graph from MyClass and filter out test files and include only the methods that have the name post, get, patch, delete", + uid="fc1f3ea0-46e7-460a-88ad-5312d4ca1a12", +) +class CallGraphFilter(Skill, ABC): + """This skill shows a visualization of the call graph from a given function or symbol. + It iterates through the usages of the starting function and its subsequent calls, + creating a directed graph of function calls. The skill filters out test files and class declarations + and includes only methods with specific names (post, get, patch, delete). + The call graph uses red for the starting node, yellow for class methods, + and can be customized based on user requests. The graph is limited to a specified depth + to manage complexity. In its current form, it ignores recursive calls and external modules + but can be modified trivially to include them + """ + + @staticmethod + @skill_impl(test_cases=[CallGraphFilterTest], language=ProgrammingLanguage.PYTHON) + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) + def skill_func(codebase: CodebaseType): + # Create a directed graph + G = nx.DiGraph() + + # Get the symbol for my_class + func_to_trace = codebase.get_function("function_to_trace") + + # Add the main symbol as a node + G.add_node(func_to_trace, color="red") + + # ===== [ Maximum Recursive Depth ] ===== + MAX_DEPTH = 5 + + SKIP_CLASS_DECLARATIONS = True + + cls = codebase.get_class("MyClass") + + # Define a recursive function to traverse function calls + def create_filtered_downstream_call_trace(parent: FunctionCall | Function, current_depth, max_depth): + if current_depth > max_depth: + return + + # if parent is of type Function + if isinstance(parent, Function): + # set both src_call, src_func to parent + src_call, src_func = parent, parent + else: + # get the first callable of parent + src_call, src_func = parent, parent.function_definition + + # Iterate over all call paths of the symbol + for call in src_func.function_calls: + # the symbol being called + func = call.function_definition + + if SKIP_CLASS_DECLARATIONS and isinstance(func, Class): + continue + + # if the function being called is not from an external module and is not defined in a test file + if not isinstance(func, ExternalModule) and not func.file.filepath.startswith("test"): + # add `call` to the graph and an edge from `src_call` to `call` + metadata = {} + if isinstance(func, Function) and func.is_method and func.name in ["post", "get", "patch", "delete"]: + name = f"{func.parent_class.name}.{func.name}" + metadata = {"color": "yellow", "name": name} + G.add_node(call, **metadata) + G.add_edge(src_call, call, symbol=cls) # Add edge from current to successor + + # Recursively add successors of the current symbol + create_filtered_downstream_call_trace(call, current_depth + 1, max_depth) + + # Start the recursive traversal + create_filtered_downstream_call_trace(func_to_trace, 1, MAX_DEPTH) + + # Visualize the graph + codebase.visualize(G) + + +CallPathsBetweenNodesTest = SkillTestCase( + [ + SkillTestCasePyFile( + input=""" +def start_func(): + intermediate_func() +def intermediate_func(): + end_func() + +def end_func(): + pass +""", + filepath="example.py", + ) + ], + graph=True, +) + + +@skill(eval_skill=False, prompt="Show me a visualization of the call paths between start_class and end_class", uid="aa3f70c3-ac1c-4737-a8b8-7ba89e3c5671") +class CallPathsBetweenNodes(Skill, ABC): + """This skill generates and visualizes a call graph between two specified functions. + It starts from a given function and iteratively traverses through its function calls, + building a directed graph of the call paths. The skill then identifies all simple paths between the + start and end functions, creating a subgraph that includes only the nodes in these paths. + + By default, the call graph uses blue for the starting node and red for the ending node, but these + colors can be customized based on user preferences. The visualization provides a clear representation + of how functions are interconnected, helping developers understand the flow of execution and + dependencies between different parts of the codebase. + + In its current form, it ignores recursive calls and external modules but can be modified trivially to include them + """ + + @staticmethod + @skill_impl(test_cases=[CallPathsBetweenNodesTest], language=ProgrammingLanguage.PYTHON) + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) + def skill_func(codebase: CodebaseType): + # Create a directed graph + G = nx.DiGraph() + + # ===== [ Maximum Recursive Depth ] ===== + MAX_DEPTH = 5 + + # Define a recursive function to traverse usages + def create_downstream_call_trace(parent: FunctionCall | Function, end: Callable, current_depth, max_depth): + if current_depth > max_depth: + return + + # if parent is of type Function + if isinstance(parent, Function): + # set both src_call, src_func to parent + src_call, src_func = parent, parent + else: + # get the first callable of parent + src_call, src_func = parent, parent.function_definition + + # Iterate over all call paths of the symbol + for call in src_func.function_calls: + # the symbol being called + func = call.function_definition + + # ignore direct recursive calls + if func.name == src_func.name: + continue + + # if the function being called is not from an external module + if not isinstance(func, ExternalModule): + # add `call` to the graph and an edge from `src_call` to `call` + G.add_node(call) + G.add_edge(src_call, call) + + if func == end: + G.add_edge(call, end) + return + # recursive call to function call + create_downstream_call_trace(call, end, current_depth + 1, max_depth) + + # Get the start and end function + start = codebase.get_function("start_func") + end = codebase.get_function("end_func") + + # Set starting node as blue + G.add_node(start, color="blue") + # Set ending node as red + G.add_node(end, color="red") + + # Start the recursive traversal + create_downstream_call_trace(start, end, 1, MAX_DEPTH) + + # Find all the simple paths between start and end + all_paths = nx.all_simple_paths(G, source=start, target=end) + + # Collect all nodes that are part of these paths + nodes_in_paths = set() + for path in all_paths: + nodes_in_paths.update(path) + + # Create a new subgraph with only the nodes in the paths + G = G.subgraph(nodes_in_paths) + + # Visualize the graph + codebase.visualize(G) + diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/method_relationships.py b/codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/method_relationships.py new file mode 100644 index 000000000..b45e1e3fd --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/method_relationships.py @@ -0,0 +1,107 @@ +import codegen +import networkx as nx +from codegen import Codebase +from codegen.sdk.core.class_definition import Class +from codegen.sdk.core.detached_symbols.function_call import FunctionCall +from codegen.sdk.core.external_module import ExternalModule +from codegen.sdk.core.function import Function + +G = nx.DiGraph() + +# Configuration Settings +IGNORE_EXTERNAL_MODULE_CALLS = False +IGNORE_CLASS_CALLS = True +MAX_DEPTH = 100 + +# Track visited nodes to prevent duplicate processing +visited = set() + +COLOR_PALETTE = { + "StartMethod": "#9cdcfe", # Light blue for root/entry point methods + "PyFunction": "#a277ff", # Purple for regular Python functions + "PyClass": "#ffca85", # Warm peach for class definitions + "ExternalModule": "#f694ff", # Pink for external module calls + "StartClass": "#FFE082", # Yellow for the starting class +} + + +def graph_class_methods(target_class: Class): + """Creates a graph visualization of all methods in a class and their call relationships""" + G.add_node(target_class, color=COLOR_PALETTE["StartClass"]) + + for method in target_class.methods: + method_name = f"{target_class.name}.{method.name}" + G.add_node(method, name=method_name, color=COLOR_PALETTE["StartMethod"]) + visited.add(method) + G.add_edge(target_class, method) + + for method in target_class.methods: + create_downstream_call_trace(method) + + +def generate_edge_meta(call: FunctionCall) -> dict: + """Generate metadata for graph edges representing function calls""" + return {"name": call.name, "file_path": call.filepath, "start_point": call.start_point, "end_point": call.end_point, "symbol_name": "FunctionCall"} + + +def create_downstream_call_trace(src_func: Function, depth: int = 0): + """Creates call graph for parent function by recursively traversing all function calls""" + if MAX_DEPTH <= depth or isinstance(src_func, ExternalModule): + return + + for call in src_func.function_calls: + if call.name == src_func.name: + continue + + func = call.function_definition + if not func: + continue + + if isinstance(func, ExternalModule) and IGNORE_EXTERNAL_MODULE_CALLS: + continue + if isinstance(func, Class) and IGNORE_CLASS_CALLS: + continue + + if isinstance(func, (Class, ExternalModule)): + func_name = func.name + elif isinstance(func, Function): + func_name = f"{func.parent_class.name}.{func.name}" if func.is_method else func.name + + if func not in visited: + G.add_node(func, name=func_name, color=COLOR_PALETTE.get(func.__class__.__name__, None)) + visited.add(func) + + G.add_edge(src_func, func, **generate_edge_meta(call)) + + if isinstance(func, Function): + create_downstream_call_trace(func, depth + 1) + + +@codegen.function("visualize-class-method-relationships") +def run(codebase: Codebase): + """Generate a visualization of method call relationships within a class. + + This codemod: + 1. Creates a directed graph with the target class as the root node + 2. Adds all class methods and their downstream function calls + 3. Generates a visual representation of the call hierarchy + """ + global G, visited + G = nx.DiGraph() + visited = set() + + target_class = codebase.get_class("_Client") + graph_class_methods(target_class) + + print(G) + print("Use codegen.sh to visualize the graph!") + + +if __name__ == "__main__": + print("Initializing codebase...") + codebase = Codebase.from_repo("codegen-oss/modal-client", commit="00bf226a1526f9d775d2d70fc7711406aaf42958", language="python") + print(f"Codebase with {len(codebase.files)} files and {len(codebase.functions)} functions.") + print("Creating graph...") + + run(codebase) + diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/viz_cal_graph.py b/codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/viz_cal_graph.py new file mode 100644 index 000000000..095e5f92b --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/call_graph/viz_cal_graph.py @@ -0,0 +1,121 @@ +import codegen +import networkx as nx +from codegen import Codebase +from codegen.sdk.core.class_definition import Class +from codegen.sdk.core.detached_symbols.function_call import FunctionCall +from codegen.sdk.core.external_module import ExternalModule +from codegen.sdk.core.function import Function + +G = nx.DiGraph() + +IGNORE_EXTERNAL_MODULE_CALLS = True +IGNORE_CLASS_CALLS = False +MAX_DEPTH = 10 + +# Color scheme for different types of nodes in the visualization +# Each node type has a distinct color for better visual differentiation +COLOR_PALETTE = { + "StartFunction": "#9cdcfe", # Base purple - draws attention to the root node + "PyFunction": "#a277ff", # Mint green - complementary to purple + "PyClass": "#ffca85", # Warm peach - provides contrast + "ExternalModule": "#f694ff", # Light pink - analogous to base purple +} + + +def generate_edge_meta(call: FunctionCall) -> dict: + """Generate metadata for graph edges representing function calls + + Args: + call (FunctionCall): Object containing information about the function call + + Returns: + dict: Metadata including name, file path, and location information + """ + return {"name": call.name, "file_path": call.filepath, "start_point": call.start_point, "end_point": call.end_point, "symbol_name": "FunctionCall"} + + +def create_downstream_call_trace(src_func: Function, depth: int = 0): + """Creates call graph for parent function by recursively traversing all function calls + + This function builds a directed graph showing all downstream function calls, + up to MAX_DEPTH levels deep. Each node represents a function and edges + represent calls between functions. + + Args: + src_func (Function): The function for which a call graph will be created + depth (int): Current depth in the recursive traversal + """ + # Stop recursion if max depth reached + if MAX_DEPTH <= depth: + return + # Stop if the source is an external module + if isinstance(src_func, ExternalModule): + return + + # Examine each function call made by the source function + for call in src_func.function_calls: + # Skip recursive calls + if call.name == src_func.name: + continue + + # Get the function definition being called + func = call.function_definition + + # Skip if function definition not found + if not func: + continue + # Apply filtering based on configuration flags + if isinstance(func, ExternalModule) and IGNORE_EXTERNAL_MODULE_CALLS: + continue + if isinstance(func, Class) and IGNORE_CLASS_CALLS: + continue + + # Generate the display name for the function + # For methods, include the class name + if isinstance(func, (Class, ExternalModule)): + func_name = func.name + elif isinstance(func, Function): + func_name = f"{func.parent_class.name}.{func.name}" if func.is_method else func.name + + # Add node and edge to the graph with appropriate metadata + G.add_node(func, name=func_name, color=COLOR_PALETTE.get(func.__class__.__name__)) + G.add_edge(src_func, func, **generate_edge_meta(call)) + + # Recursively process called function if it's a regular function + if isinstance(func, Function): + create_downstream_call_trace(func, depth + 1) + + +@codegen.function("visualize-function-call-relationships") +def run(codebase: Codebase): + """Generate a visualization of function call relationships in a codebase. + + This codemod: + 1. Creates a directed graph of function calls starting from a target method + 2. Tracks relationships between functions, classes, and external modules + 3. Generates a visual representation of the call hierarchy + """ + global G + G = nx.DiGraph() + + target_class = codebase.get_class("SharingConfigurationViewSet") + target_method = target_class.get_method("patch") + + # Generate the call graph starting from the target method + create_downstream_call_trace(target_method) + + # Add the root node (target method) to the graph + G.add_node(target_method, name=f"{target_class.name}.{target_method.name}", color=COLOR_PALETTE.get("StartFunction")) + + print(G) + print("Use codegen.sh to visualize the graph!") + + +if __name__ == "__main__": + print("Initializing codebase...") + codebase = Codebase.from_repo("codegen-oss/posthog", commit="b174f2221ea4ae50e715eb6a7e70e9a2b0760800", language="python") + print(f"Codebase with {len(codebase.files)} files and {len(codebase.functions)} functions.") + print("Creating graph...") + + run(codebase) + diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/codebase_visualizer.py b/codegen-on-oss/codegen_on_oss/analyzers/visualization/codebase_visualizer.py index 52f77eade..2cea2331b 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/visualization/codebase_visualizer.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/codebase_visualizer.py @@ -2,23 +2,60 @@ """ Codebase Visualizer Module -This module provides a unified interface to all visualization capabilities -for codebases. It integrates the specialized visualizers into a single, -easy-to-use API for generating various types of visualizations. +This module provides comprehensive visualization capabilities for codebases and PR analyses. +It integrates with codebase_analyzer.py and context_codebase.py to provide visual representations +of code structure, dependencies, and issues. It supports multiple visualization types to help +developers understand codebase architecture and identify potential problems. """ -import argparse +import json import logging import os import sys +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any -from .analysis_visualizer import AnalysisVisualizer -from .code_visualizer import CodeVisualizer -from .visualizer import ( - OutputFormat, - VisualizationConfig, - VisualizationType, -) +try: + import matplotlib.pyplot as plt + import networkx as nx + from matplotlib.colors import LinearSegmentedColormap +except ImportError: + print( + "Visualization dependencies not found. Please install them with: pip install networkx matplotlib" + ) + sys.exit(1) + +try: + from codegen.sdk.core.class_definition import Class + from codegen.sdk.core.codebase import Codebase + from codegen.sdk.core.detached_symbols.function_call import FunctionCall + from codegen.sdk.core.file import SourceFile + from codegen.sdk.core.function import Function + from codegen.sdk.core.import_resolution import Import + from codegen.sdk.core.symbol import Symbol + from codegen.sdk.enums import EdgeType, SymbolType + + from codegen_on_oss.codebase_analyzer import ( + AnalysisType, + CodebaseAnalyzer, + Issue, + IssueSeverity, + ) + + # Import custom modules + from codegen_on_oss.context_codebase import ( + GLOBAL_FILE_IGNORE_LIST, + CodebaseContext, + get_node_classes, + ) + from codegen_on_oss.current_code_codebase import get_selected_codebase +except ImportError: + print( + "Codegen SDK or custom modules not found. Please ensure all dependencies are installed." + ) + sys.exit(1) # Configure logging logging.basicConfig( @@ -29,22 +66,85 @@ logger = logging.getLogger(__name__) +class VisualizationType(str, Enum): + """Types of visualizations supported by this module.""" + + CALL_GRAPH = "call_graph" + DEPENDENCY_GRAPH = "dependency_graph" + BLAST_RADIUS = "blast_radius" + CLASS_METHODS = "class_methods" + MODULE_DEPENDENCIES = "module_dependencies" + DEAD_CODE = "dead_code" + CYCLOMATIC_COMPLEXITY = "cyclomatic_complexity" + ISSUES_HEATMAP = "issues_heatmap" + PR_COMPARISON = "pr_comparison" + + +class OutputFormat(str, Enum): + """Output formats for visualizations.""" + + JSON = "json" + PNG = "png" + SVG = "svg" + HTML = "html" + DOT = "dot" + + +@dataclass +class VisualizationConfig: + """Configuration for visualization generation.""" + + max_depth: int = 5 + ignore_external: bool = True + ignore_tests: bool = True + node_size_base: int = 300 + edge_width_base: float = 1.0 + filename_filter: list[str] | None = None + symbol_filter: list[str] | None = None + output_format: OutputFormat = OutputFormat.JSON + output_directory: str | None = None + layout_algorithm: str = "spring" + highlight_nodes: list[str] = field(default_factory=list) + highlight_color: str = "#ff5555" + color_palette: dict[str, str] = field( + default_factory=lambda: { + "Function": "#a277ff", # Purple + "Class": "#ffca85", # Orange + "File": "#80CBC4", # Teal + "Module": "#81D4FA", # Light Blue + "Variable": "#B39DDB", # Light Purple + "Root": "#ef5350", # Red + "Warning": "#FFCA28", # Amber + "Error": "#EF5350", # Red + "Dead": "#78909C", # Gray + "External": "#B0BEC5", # Light Gray + } + ) + + class CodebaseVisualizer: """ - Main visualizer class providing a unified interface to all visualization capabilities. + Visualizer for codebase structures and analytics. - This class acts as a facade to the specialized visualizers, simplifying - the generation of different types of visualizations for codebases. + This class provides methods to generate various visualizations of a codebase, + including call graphs, dependency graphs, complexity heatmaps, and more. + It integrates with CodebaseAnalyzer to visualize analysis results. """ - def __init__(self, analyzer=None, codebase=None, context=None, config=None): + def __init__( + self, + analyzer: CodebaseAnalyzer | None = None, + codebase: Codebase | None = None, + context: CodebaseContext | None = None, + config: VisualizationConfig | None = None, + ): """ Initialize the CodebaseVisualizer. Args: - analyzer: Optional analyzer with analysis results - codebase: Optional codebase to visualize - context: Optional context providing graph representation + analyzer: Optional CodebaseAnalyzer instance with analysis results + codebase: Optional Codebase instance to visualize + context: Optional CodebaseContext providing graph representation config: Visualization configuration options """ self.analyzer = analyzer @@ -52,196 +152,1377 @@ def __init__(self, analyzer=None, codebase=None, context=None, config=None): self.context = context or (analyzer.base_context if analyzer else None) self.config = config or VisualizationConfig() - # Initialize specialized visualizers - self.code_visualizer = CodeVisualizer( - analyzer=analyzer, - codebase=self.codebase, - context=self.context, - config=self.config, - ) - - self.analysis_visualizer = AnalysisVisualizer( - analyzer=analyzer, - codebase=self.codebase, - context=self.context, - config=self.config, - ) - # Create visualization directory if specified if self.config.output_directory: os.makedirs(self.config.output_directory, exist_ok=True) + # Initialize graph for visualization + self.graph = nx.DiGraph() + # Initialize codebase if needed if not self.codebase and not self.context: + logger.info( + "No codebase or context provided, initializing from current directory" + ) + self.codebase = get_selected_codebase() + self.context = CodebaseContext( + codebase=self.codebase, base_path=os.getcwd() + ) + elif self.codebase and not self.context: + logger.info("Creating context from provided codebase") + self.context = CodebaseContext( + codebase=self.codebase, + base_path=os.getcwd() + if not hasattr(self.codebase, "base_path") + else self.codebase.base_path, + ) + + def _initialize_graph(self): + """Initialize a fresh graph for visualization.""" + self.graph = nx.DiGraph() + + def _add_node(self, node: Any, **attrs): + """ + Add a node to the visualization graph with attributes. + + Args: + node: Node object to add + **attrs: Node attributes + """ + # Skip if node already exists + if self.graph.has_node(node): + return + + # Generate node ID (memory address for unique identification) + node_id = id(node) + + # Get node name + if "name" in attrs: + node_name = attrs["name"] + elif hasattr(node, "name"): + node_name = node.name + elif hasattr(node, "path"): + node_name = str(node.path).split("/")[-1] + else: + node_name = str(node) + + # Determine node type and color + node_type = node.__class__.__name__ + color = attrs.get("color", self.config.color_palette.get(node_type, "#BBBBBB")) + + # Add node with attributes + self.graph.add_node( + node_id, + original_node=node, + name=node_name, + type=node_type, + color=color, + **attrs, + ) + + return node_id + + def _add_edge(self, source: Any, target: Any, **attrs): + """ + Add an edge to the visualization graph with attributes. + + Args: + source: Source node + target: Target node + **attrs: Edge attributes + """ + # Get node IDs + source_id = id(source) + target_id = id(target) + + # Add edge with attributes + self.graph.add_edge(source_id, target_id, **attrs) + + def _generate_filename( + self, visualization_type: VisualizationType, entity_name: str + ): + """ + Generate a filename for the visualization. + + Args: + visualization_type: Type of visualization + entity_name: Name of the entity being visualized + + Returns: + Generated filename + """ + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + sanitized_name = ( + entity_name.replace("/", "_").replace("\\", "_").replace(".", "_") + ) + return f"{visualization_type.value}_{sanitized_name}_{timestamp}.{self.config.output_format.value}" + + def _save_visualization( + self, visualization_type: VisualizationType, entity_name: str, data: Any + ): + """ + Save a visualization to file or return it. + + Args: + visualization_type: Type of visualization + entity_name: Name of the entity being visualized + data: Visualization data to save + + Returns: + Path to saved file or visualization data + """ + filename = self._generate_filename(visualization_type, entity_name) + + if self.config.output_directory: + filepath = os.path.join(self.config.output_directory, filename) + else: + filepath = filename + + if self.config.output_format == OutputFormat.JSON: + with open(filepath, "w") as f: + json.dump(data, f, indent=2) + elif self.config.output_format in [OutputFormat.PNG, OutputFormat.SVG]: + # Save matplotlib figure + plt.savefig( + filepath, format=self.config.output_format.value, bbox_inches="tight" + ) + plt.close() + elif self.config.output_format == OutputFormat.DOT: + # Save as DOT file for Graphviz try: - from codegen_on_oss.analyzers.context_codebase import CodebaseContext - from codegen_on_oss.current_code_codebase import get_selected_codebase + from networkx.drawing.nx_agraph import write_dot - logger.info( - "No codebase or context provided, initializing from current directory" + write_dot(self.graph, filepath) + except ImportError: + logger.exception( + "networkx.drawing.nx_agraph not available. Install pygraphviz for DOT format." ) - self.codebase = get_selected_codebase() - self.context = CodebaseContext( - codebase=self.codebase, base_path=os.getcwd() + return None + + logger.info(f"Visualization saved to {filepath}") + return filepath + + def _convert_graph_to_json(self): + """ + Convert the networkx graph to a JSON-serializable dictionary. + + Returns: + Dictionary representation of the graph + """ + nodes = [] + for node, attrs in self.graph.nodes(data=True): + # Create a serializable node + node_data = { + "id": node, + "name": attrs.get("name", ""), + "type": attrs.get("type", ""), + "color": attrs.get("color", "#BBBBBB"), + } + + # Add file path if available + if "file_path" in attrs: + node_data["file_path"] = attrs["file_path"] + + # Add other attributes + for key, value in attrs.items(): + if key not in ["name", "type", "color", "file_path", "original_node"]: + if ( + isinstance(value, str | int | float | bool | list | dict) + or value is None + ): + node_data[key] = value + + nodes.append(node_data) + + edges = [] + for source, target, attrs in self.graph.edges(data=True): + # Create a serializable edge + edge_data = { + "source": source, + "target": target, + } + + # Add other attributes + for key, value in attrs.items(): + if ( + isinstance(value, str | int | float | bool | list | dict) + or value is None + ): + edge_data[key] = value + + edges.append(edge_data) + + return { + "nodes": nodes, + "edges": edges, + "metadata": { + "visualization_type": self.current_visualization_type, + "entity_name": self.current_entity_name, + "timestamp": datetime.now().isoformat(), + "node_count": len(nodes), + "edge_count": len(edges), + }, + } + + def _plot_graph(self): + """ + Plot the graph using matplotlib. + + Returns: + Matplotlib figure + """ + plt.figure(figsize=(12, 10)) + + # Extract node positions using specified layout algorithm + if self.config.layout_algorithm == "spring": + pos = nx.spring_layout(self.graph, seed=42) + elif self.config.layout_algorithm == "kamada_kawai": + pos = nx.kamada_kawai_layout(self.graph) + elif self.config.layout_algorithm == "spectral": + pos = nx.spectral_layout(self.graph) + else: + # Default to spring layout + pos = nx.spring_layout(self.graph, seed=42) + + # Extract node colors + node_colors = [ + attrs.get("color", "#BBBBBB") for _, attrs in self.graph.nodes(data=True) + ] + + # Extract node sizes (can be based on some metric) + node_sizes = [self.config.node_size_base for _ in self.graph.nodes()] + + # Draw nodes + nx.draw_networkx_nodes( + self.graph, pos, node_color=node_colors, node_size=node_sizes, alpha=0.8 + ) + + # Draw edges + nx.draw_networkx_edges( + self.graph, + pos, + width=self.config.edge_width_base, + alpha=0.6, + arrows=True, + arrowsize=10, + ) + + # Draw labels + nx.draw_networkx_labels( + self.graph, + pos, + labels={ + node: attrs.get("name", "") + for node, attrs in self.graph.nodes(data=True) + }, + font_size=8, + font_weight="bold", + ) + + plt.title(f"{self.current_visualization_type} - {self.current_entity_name}") + plt.axis("off") + + return plt.gcf() + + def visualize_call_graph(self, function_name: str, max_depth: int | None = None): + """ + Generate a call graph visualization for a function. + + Args: + function_name: Name of the function to visualize + max_depth: Maximum depth of the call graph (overrides config) + + Returns: + Visualization data or path to saved file + """ + self.current_visualization_type = VisualizationType.CALL_GRAPH + self.current_entity_name = function_name + + # Set max depth + current_max_depth = ( + max_depth if max_depth is not None else self.config.max_depth + ) + + # Initialize graph + self._initialize_graph() + + # Find the function in the codebase + function = None + for func in self.codebase.functions: + if func.name == function_name: + function = func + break + + if not function: + logger.error(f"Function {function_name} not found in codebase") + return None + + # Add root node + self._add_node( + function, + name=function_name, + color=self.config.color_palette.get("Root"), + is_root=True, + ) + + # Recursively add call relationships + visited = {function} + + def add_calls(func, depth=0): + if depth >= current_max_depth: + return + + # Skip if no function calls attribute + if not hasattr(func, "function_calls"): + return + + for call in func.function_calls: + # Skip recursive calls + if call.name == func.name: + continue + + # Get the called function + called_func = call.function_definition + if not called_func: + continue + + # Skip external modules if configured + if ( + self.config.ignore_external + and hasattr(called_func, "is_external") + and called_func.is_external + ): + continue + + # Generate name for display + if ( + hasattr(called_func, "is_method") + and called_func.is_method + and hasattr(called_func, "parent_class") + ): + called_name = f"{called_func.parent_class.name}.{called_func.name}" + else: + called_name = called_func.name + + # Add node for called function + self._add_node( + called_func, + name=called_name, + color=self.config.color_palette.get("Function"), + file_path=called_func.file.path + if hasattr(called_func, "file") + and hasattr(called_func.file, "path") + else None, ) - # Update specialized visualizers - self.code_visualizer.codebase = self.codebase - self.code_visualizer.context = self.context - self.analysis_visualizer.codebase = self.codebase - self.analysis_visualizer.context = self.context - except ImportError: - logger.exception( - "Could not automatically initialize codebase. Please provide a codebase or context." + # Add edge for call relationship + self._add_edge( + function, + called_func, + type="call", + file_path=call.filepath if hasattr(call, "filepath") else None, + line=call.line if hasattr(call, "line") else None, ) - def visualize(self, visualization_type: VisualizationType, **kwargs): + # Recursively process called function + if isinstance(called_func, Function) and called_func not in visited: + visited.add(called_func) + add_calls(called_func, depth + 1) + + # Start from the root function + add_calls(function) + + # Generate visualization data + if self.config.output_format == OutputFormat.JSON: + data = self._convert_graph_to_json() + return self._save_visualization( + VisualizationType.CALL_GRAPH, function_name, data + ) + else: + fig = self._plot_graph() + return self._save_visualization( + VisualizationType.CALL_GRAPH, function_name, fig + ) + + def visualize_dependency_graph( + self, symbol_name: str, max_depth: int | None = None + ): """ - Generate a visualization of the specified type. + Generate a dependency graph visualization for a symbol. Args: - visualization_type: Type of visualization to generate - **kwargs: Additional arguments for the specific visualization + symbol_name: Name of the symbol to visualize + max_depth: Maximum depth of the dependency graph (overrides config) Returns: Visualization data or path to saved file """ - # Route to the appropriate specialized visualizer based on visualization type - if visualization_type in [ - VisualizationType.CALL_GRAPH, - VisualizationType.DEPENDENCY_GRAPH, - VisualizationType.BLAST_RADIUS, - VisualizationType.CLASS_METHODS, - VisualizationType.MODULE_DEPENDENCIES, - ]: - # Code structure visualizations - return self._visualize_code_structure(visualization_type, **kwargs) - elif visualization_type in [ - VisualizationType.DEAD_CODE, - VisualizationType.CYCLOMATIC_COMPLEXITY, - VisualizationType.ISSUES_HEATMAP, - VisualizationType.PR_COMPARISON, - ]: - # Analysis result visualizations - return self._visualize_analysis_results(visualization_type, **kwargs) + self.current_visualization_type = VisualizationType.DEPENDENCY_GRAPH + self.current_entity_name = symbol_name + + # Set max depth + current_max_depth = ( + max_depth if max_depth is not None else self.config.max_depth + ) + + # Initialize graph + self._initialize_graph() + + # Find the symbol in the codebase + symbol = None + for sym in self.codebase.symbols: + if hasattr(sym, "name") and sym.name == symbol_name: + symbol = sym + break + + if not symbol: + logger.error(f"Symbol {symbol_name} not found in codebase") + return None + + # Add root node + self._add_node( + symbol, + name=symbol_name, + color=self.config.color_palette.get("Root"), + is_root=True, + ) + + # Recursively add dependencies + visited = {symbol} + + def add_dependencies(sym, depth=0): + if depth >= current_max_depth: + return + + # Skip if no dependencies attribute + if not hasattr(sym, "dependencies"): + return + + for dep in sym.dependencies: + dep_symbol = None + + if isinstance(dep, Symbol): + dep_symbol = dep + elif isinstance(dep, Import) and hasattr(dep, "resolved_symbol"): + dep_symbol = dep.resolved_symbol + + if not dep_symbol: + continue + + # Skip external modules if configured + if ( + self.config.ignore_external + and hasattr(dep_symbol, "is_external") + and dep_symbol.is_external + ): + continue + + # Add node for dependency + self._add_node( + dep_symbol, + name=dep_symbol.name + if hasattr(dep_symbol, "name") + else str(dep_symbol), + color=self.config.color_palette.get( + dep_symbol.__class__.__name__, "#BBBBBB" + ), + file_path=dep_symbol.file.path + if hasattr(dep_symbol, "file") and hasattr(dep_symbol.file, "path") + else None, + ) + + # Add edge for dependency relationship + self._add_edge(sym, dep_symbol, type="depends_on") + + # Recursively process dependency + if dep_symbol not in visited: + visited.add(dep_symbol) + add_dependencies(dep_symbol, depth + 1) + + # Start from the root symbol + add_dependencies(symbol) + + # Generate visualization data + if self.config.output_format == OutputFormat.JSON: + data = self._convert_graph_to_json() + return self._save_visualization( + VisualizationType.DEPENDENCY_GRAPH, symbol_name, data + ) else: - logger.error(f"Unsupported visualization type: {visualization_type}") + fig = self._plot_graph() + return self._save_visualization( + VisualizationType.DEPENDENCY_GRAPH, symbol_name, fig + ) + + def visualize_blast_radius(self, symbol_name: str, max_depth: int | None = None): + """ + Generate a blast radius visualization for a symbol. + + Args: + symbol_name: Name of the symbol to visualize + max_depth: Maximum depth of the blast radius (overrides config) + + Returns: + Visualization data or path to saved file + """ + self.current_visualization_type = VisualizationType.BLAST_RADIUS + self.current_entity_name = symbol_name + + # Set max depth + current_max_depth = ( + max_depth if max_depth is not None else self.config.max_depth + ) + + # Initialize graph + self._initialize_graph() + + # Find the symbol in the codebase + symbol = None + for sym in self.codebase.symbols: + if hasattr(sym, "name") and sym.name == symbol_name: + symbol = sym + break + + if not symbol: + logger.error(f"Symbol {symbol_name} not found in codebase") return None - def _visualize_code_structure( - self, visualization_type: VisualizationType, **kwargs - ): + # Add root node + self._add_node( + symbol, + name=symbol_name, + color=self.config.color_palette.get("Root"), + is_root=True, + ) + + # Recursively add usages (reverse dependencies) + visited = {symbol} + + def add_usages(sym, depth=0): + if depth >= current_max_depth: + return + + # Skip if no usages attribute + if not hasattr(sym, "usages"): + return + + for usage in sym.usages: + # Skip if no usage symbol + if not hasattr(usage, "usage_symbol"): + continue + + usage_symbol = usage.usage_symbol + + # Skip external modules if configured + if ( + self.config.ignore_external + and hasattr(usage_symbol, "is_external") + and usage_symbol.is_external + ): + continue + + # Add node for usage + self._add_node( + usage_symbol, + name=usage_symbol.name + if hasattr(usage_symbol, "name") + else str(usage_symbol), + color=self.config.color_palette.get( + usage_symbol.__class__.__name__, "#BBBBBB" + ), + file_path=usage_symbol.file.path + if hasattr(usage_symbol, "file") + and hasattr(usage_symbol.file, "path") + else None, + ) + + # Add edge for usage relationship + self._add_edge(sym, usage_symbol, type="used_by") + + # Recursively process usage + if usage_symbol not in visited: + visited.add(usage_symbol) + add_usages(usage_symbol, depth + 1) + + # Start from the root symbol + add_usages(symbol) + + # Generate visualization data + if self.config.output_format == OutputFormat.JSON: + data = self._convert_graph_to_json() + return self._save_visualization( + VisualizationType.BLAST_RADIUS, symbol_name, data + ) + else: + fig = self._plot_graph() + return self._save_visualization( + VisualizationType.BLAST_RADIUS, symbol_name, fig + ) + + def visualize_class_methods(self, class_name: str): """ - Generate a code structure visualization. + Generate a class methods visualization. Args: - visualization_type: Type of visualization to generate - **kwargs: Additional arguments for the specific visualization + class_name: Name of the class to visualize Returns: Visualization data or path to saved file """ - if visualization_type == VisualizationType.CALL_GRAPH: - return self.code_visualizer.visualize_call_graph( - function_name=kwargs.get("entity"), max_depth=kwargs.get("max_depth") + self.current_visualization_type = VisualizationType.CLASS_METHODS + self.current_entity_name = class_name + + # Initialize graph + self._initialize_graph() + + # Find the class in the codebase + class_obj = None + for cls in self.codebase.classes: + if cls.name == class_name: + class_obj = cls + break + + if not class_obj: + logger.error(f"Class {class_name} not found in codebase") + return None + + # Add class node + self._add_node( + class_obj, + name=class_name, + color=self.config.color_palette.get("Class"), + is_root=True, + ) + + # Skip if no methods attribute + if not hasattr(class_obj, "methods"): + logger.error(f"Class {class_name} has no methods attribute") + return None + + # Add method nodes and connections + method_ids = {} + for method in class_obj.methods: + method_name = f"{class_name}.{method.name}" + + # Add method node + method_id = self._add_node( + method, + name=method_name, + color=self.config.color_palette.get("Function"), + file_path=method.file.path + if hasattr(method, "file") and hasattr(method.file, "path") + else None, ) - elif visualization_type == VisualizationType.DEPENDENCY_GRAPH: - return self.code_visualizer.visualize_dependency_graph( - symbol_name=kwargs.get("entity"), max_depth=kwargs.get("max_depth") + + method_ids[method.name] = method_id + + # Add edge from class to method + self._add_edge(class_obj, method, type="contains") + + # Add call relationships between methods + for method in class_obj.methods: + # Skip if no function calls attribute + if not hasattr(method, "function_calls"): + continue + + for call in method.function_calls: + # Get the called function + called_func = call.function_definition + if not called_func: + continue + + # Only add edges between methods of this class + if ( + hasattr(called_func, "is_method") + and called_func.is_method + and hasattr(called_func, "parent_class") + and called_func.parent_class == class_obj + ): + self._add_edge( + method, + called_func, + type="calls", + line=call.line if hasattr(call, "line") else None, + ) + + # Generate visualization data + if self.config.output_format == OutputFormat.JSON: + data = self._convert_graph_to_json() + return self._save_visualization( + VisualizationType.CLASS_METHODS, class_name, data ) - elif visualization_type == VisualizationType.BLAST_RADIUS: - return self.code_visualizer.visualize_blast_radius( - symbol_name=kwargs.get("entity"), max_depth=kwargs.get("max_depth") + else: + fig = self._plot_graph() + return self._save_visualization( + VisualizationType.CLASS_METHODS, class_name, fig ) - elif visualization_type == VisualizationType.CLASS_METHODS: - return self.code_visualizer.visualize_class_methods( - class_name=kwargs.get("entity") + + def visualize_module_dependencies(self, module_path: str): + """ + Generate a module dependencies visualization. + + Args: + module_path: Path to the module to visualize + + Returns: + Visualization data or path to saved file + """ + self.current_visualization_type = VisualizationType.MODULE_DEPENDENCIES + self.current_entity_name = module_path + + # Initialize graph + self._initialize_graph() + + # Get all files in the module + module_files = [] + for file in self.codebase.files: + if hasattr(file, "path") and str(file.path).startswith(module_path): + module_files.append(file) + + if not module_files: + logger.error(f"No files found in module {module_path}") + return None + + # Add file nodes + module_node_ids = {} + for file in module_files: + file_name = str(file.path).split("/")[-1] + file_module = "/".join(str(file.path).split("/")[:-1]) + + # Add file node + file_id = self._add_node( + file, + name=file_name, + module=file_module, + color=self.config.color_palette.get("File"), + file_path=str(file.path), ) - elif visualization_type == VisualizationType.MODULE_DEPENDENCIES: - return self.code_visualizer.visualize_module_dependencies( - module_path=kwargs.get("entity") + + module_node_ids[str(file.path)] = file_id + + # Add import relationships + for file in module_files: + # Skip if no imports attribute + if not hasattr(file, "imports"): + continue + + for imp in file.imports: + imported_file = None + + # Try to get imported file + if hasattr(imp, "resolved_file"): + imported_file = imp.resolved_file + elif hasattr(imp, "resolved_symbol") and hasattr( + imp.resolved_symbol, "file" + ): + imported_file = imp.resolved_symbol.file + + if not imported_file: + continue + + # Skip external modules if configured + if ( + self.config.ignore_external + and hasattr(imported_file, "is_external") + and imported_file.is_external + ): + continue + + # Add node for imported file if not already added + imported_path = ( + str(imported_file.path) if hasattr(imported_file, "path") else "" + ) + + if imported_path not in module_node_ids: + imported_name = imported_path.split("/")[-1] + imported_module = "/".join(imported_path.split("/")[:-1]) + + imported_id = self._add_node( + imported_file, + name=imported_name, + module=imported_module, + color=self.config.color_palette.get( + "External" + if imported_path.startswith(module_path) + else "File" + ), + file_path=imported_path, + ) + + module_node_ids[imported_path] = imported_id + + # Add edge for import relationship + self._add_edge( + file, + imported_file, + type="imports", + import_name=imp.name if hasattr(imp, "name") else "", + ) + + # Generate visualization data + if self.config.output_format == OutputFormat.JSON: + data = self._convert_graph_to_json() + return self._save_visualization( + VisualizationType.MODULE_DEPENDENCIES, module_path, data + ) + else: + fig = self._plot_graph() + return self._save_visualization( + VisualizationType.MODULE_DEPENDENCIES, module_path, fig ) - def _visualize_analysis_results( - self, visualization_type: VisualizationType, **kwargs - ): + def visualize_dead_code(self, path_filter: str | None = None): """ - Generate an analysis results visualization. + Generate a visualization of dead (unused) code in the codebase. Args: - visualization_type: Type of visualization to generate - **kwargs: Additional arguments for the specific visualization + path_filter: Optional path to filter files Returns: Visualization data or path to saved file """ + self.current_visualization_type = VisualizationType.DEAD_CODE + self.current_entity_name = path_filter or "codebase" + + # Initialize graph + self._initialize_graph() + + # Initialize analyzer if needed if not self.analyzer: - logger.error(f"Analyzer required for {visualization_type} visualization") + logger.info("Initializing analyzer for dead code detection") + self.analyzer = CodebaseAnalyzer( + codebase=self.codebase, + repo_path=self.context.base_path + if hasattr(self.context, "base_path") + else None, + ) + + # Perform analysis if not already done + if not hasattr(self.analyzer, "results") or not self.analyzer.results: + logger.info("Running code analysis") + self.analyzer.analyze(AnalysisType.CODEBASE) + + # Extract dead code information from analysis results + if not hasattr(self.analyzer, "results"): + logger.error("Analysis results not available") + return None + + dead_code = {} + if ( + "static_analysis" in self.analyzer.results + and "dead_code" in self.analyzer.results["static_analysis"] + ): + dead_code = self.analyzer.results["static_analysis"]["dead_code"] + + if not dead_code: + logger.warning("No dead code detected in analysis results") return None - if visualization_type == VisualizationType.DEAD_CODE: - return self.analysis_visualizer.visualize_dead_code( - path_filter=kwargs.get("path_filter") + # Create file nodes for containing dead code + file_nodes = {} + + # Process unused functions + if "unused_functions" in dead_code: + for unused_func in dead_code["unused_functions"]: + file_path = unused_func.get("file", "") + + # Skip if path filter is specified and doesn't match + if path_filter and not file_path.startswith(path_filter): + continue + + # Add file node if not already added + if file_path not in file_nodes: + # Find file in codebase + file_obj = None + for file in self.codebase.files: + if hasattr(file, "path") and str(file.path) == file_path: + file_obj = file + break + + if file_obj: + file_name = file_path.split("/")[-1] + self._add_node( + file_obj, + name=file_name, + color=self.config.color_palette.get("File"), + file_path=file_path, + ) + + file_nodes[file_path] = file_obj + + # Add unused function node + func_name = unused_func.get("name", "") + func_line = unused_func.get("line", None) + + # Create a placeholder for the function (we don't have the actual object) + func_obj = { + "name": func_name, + "file_path": file_path, + "line": func_line, + "type": "Function", + } + + self._add_node( + func_obj, + name=func_name, + color=self.config.color_palette.get("Dead"), + file_path=file_path, + line=func_line, + is_dead=True, + ) + + # Add edge from file to function + if file_path in file_nodes: + self._add_edge( + file_nodes[file_path], func_obj, type="contains_dead" + ) + + # Process unused variables + if "unused_variables" in dead_code: + for unused_var in dead_code["unused_variables"]: + file_path = unused_var.get("file", "") + + # Skip if path filter is specified and doesn't match + if path_filter and not file_path.startswith(path_filter): + continue + + # Add file node if not already added + if file_path not in file_nodes: + # Find file in codebase + file_obj = None + for file in self.codebase.files: + if hasattr(file, "path") and str(file.path) == file_path: + file_obj = file + break + + if file_obj: + file_name = file_path.split("/")[-1] + self._add_node( + file_obj, + name=file_name, + color=self.config.color_palette.get("File"), + file_path=file_path, + ) + + file_nodes[file_path] = file_obj + + # Add unused variable node + var_name = unused_var.get("name", "") + var_line = unused_var.get("line", None) + + # Create a placeholder for the variable + var_obj = { + "name": var_name, + "file_path": file_path, + "line": var_line, + "type": "Variable", + } + + self._add_node( + var_obj, + name=var_name, + color=self.config.color_palette.get("Dead"), + file_path=file_path, + line=var_line, + is_dead=True, + ) + + # Add edge from file to variable + if file_path in file_nodes: + self._add_edge(file_nodes[file_path], var_obj, type="contains_dead") + + # Generate visualization data + if self.config.output_format == OutputFormat.JSON: + data = self._convert_graph_to_json() + return self._save_visualization( + VisualizationType.DEAD_CODE, self.current_entity_name, data ) - elif visualization_type == VisualizationType.CYCLOMATIC_COMPLEXITY: - return self.analysis_visualizer.visualize_cyclomatic_complexity( - path_filter=kwargs.get("path_filter") + else: + fig = self._plot_graph() + return self._save_visualization( + VisualizationType.DEAD_CODE, self.current_entity_name, fig ) - elif visualization_type == VisualizationType.ISSUES_HEATMAP: - return self.analysis_visualizer.visualize_issues_heatmap( - severity=kwargs.get("severity"), path_filter=kwargs.get("path_filter") + + def visualize_cyclomatic_complexity(self, path_filter: str | None = None): + """ + Generate a heatmap visualization of cyclomatic complexity. + + Args: + path_filter: Optional path to filter files + + Returns: + Visualization data or path to saved file + """ + self.current_visualization_type = VisualizationType.CYCLOMATIC_COMPLEXITY + self.current_entity_name = path_filter or "codebase" + + # Initialize analyzer if needed + if not self.analyzer: + logger.info("Initializing analyzer for complexity analysis") + self.analyzer = CodebaseAnalyzer( + codebase=self.codebase, + repo_path=self.context.base_path + if hasattr(self.context, "base_path") + else None, ) - elif visualization_type == VisualizationType.PR_COMPARISON: - return self.analysis_visualizer.visualize_pr_comparison() - # Convenience methods for common visualizations - def visualize_call_graph(self, function_name: str, max_depth: int | None = None): - """Convenience method for call graph visualization.""" - return self.visualize( - VisualizationType.CALL_GRAPH, entity=function_name, max_depth=max_depth + # Perform analysis if not already done + if not hasattr(self.analyzer, "results") or not self.analyzer.results: + logger.info("Running code analysis") + self.analyzer.analyze(AnalysisType.CODEBASE) + + # Extract complexity information from analysis results + if not hasattr(self.analyzer, "results"): + logger.error("Analysis results not available") + return None + + complexity_data = {} + if ( + "static_analysis" in self.analyzer.results + and "code_complexity" in self.analyzer.results["static_analysis"] + ): + complexity_data = self.analyzer.results["static_analysis"][ + "code_complexity" + ] + + if not complexity_data: + logger.warning("No complexity data found in analysis results") + return None + + # Extract function complexities + functions = [] + if "function_complexity" in complexity_data: + for func_data in complexity_data["function_complexity"]: + # Skip if path filter is specified and doesn't match + if path_filter and not func_data.get("file", "").startswith( + path_filter + ): + continue + + functions.append({ + "name": func_data.get("name", ""), + "file": func_data.get("file", ""), + "complexity": func_data.get("complexity", 1), + "line": func_data.get("line", None), + }) + + # Sort functions by complexity (descending) + functions.sort(key=lambda x: x.get("complexity", 0), reverse=True) + + # Generate heatmap visualization + plt.figure(figsize=(12, 10)) + + # Extract data for heatmap + func_names = [ + f"{func['name']} ({func['file'].split('/')[-1]})" for func in functions[:30] + ] + complexities = [func.get("complexity", 0) for func in functions[:30]] + + # Create horizontal bar chart + bars = plt.barh(func_names, complexities) + + # Color bars by complexity + norm = plt.Normalize(1, max(10, max(complexities))) + cmap = plt.cm.get_cmap("YlOrRd") + + for i, bar in enumerate(bars): + complexity = complexities[i] + bar.set_color(cmap(norm(complexity))) + + # Add labels and title + plt.xlabel("Cyclomatic Complexity") + plt.title("Top Functions by Cyclomatic Complexity") + plt.grid(axis="x", linestyle="--", alpha=0.6) + + # Add colorbar + plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), label="Complexity") + + # Save and return visualization + return self._save_visualization( + VisualizationType.CYCLOMATIC_COMPLEXITY, self.current_entity_name, plt.gcf() ) - def visualize_dependency_graph( - self, symbol_name: str, max_depth: int | None = None + def visualize_issues_heatmap( + self, + severity: IssueSeverity | None = None, + path_filter: str | None = None, ): - """Convenience method for dependency graph visualization.""" - return self.visualize( - VisualizationType.DEPENDENCY_GRAPH, entity=symbol_name, max_depth=max_depth - ) + """ + Generate a heatmap visualization of issues in the codebase. - def visualize_blast_radius(self, symbol_name: str, max_depth: int | None = None): - """Convenience method for blast radius visualization.""" - return self.visualize( - VisualizationType.BLAST_RADIUS, entity=symbol_name, max_depth=max_depth - ) + Args: + severity: Optional severity level to filter issues + path_filter: Optional path to filter files - def visualize_class_methods(self, class_name: str): - """Convenience method for class methods visualization.""" - return self.visualize(VisualizationType.CLASS_METHODS, entity=class_name) + Returns: + Visualization data or path to saved file + """ + self.current_visualization_type = VisualizationType.ISSUES_HEATMAP + self.current_entity_name = f"{severity.value if severity else 'all'}_issues" - def visualize_module_dependencies(self, module_path: str): - """Convenience method for module dependencies visualization.""" - return self.visualize(VisualizationType.MODULE_DEPENDENCIES, entity=module_path) + # Initialize analyzer if needed + if not self.analyzer: + logger.info("Initializing analyzer for issues analysis") + self.analyzer = CodebaseAnalyzer( + codebase=self.codebase, + repo_path=self.context.base_path + if hasattr(self.context, "base_path") + else None, + ) - def visualize_dead_code(self, path_filter: str | None = None): - """Convenience method for dead code visualization.""" - return self.visualize(VisualizationType.DEAD_CODE, path_filter=path_filter) + # Perform analysis if not already done + if not hasattr(self.analyzer, "results") or not self.analyzer.results: + logger.info("Running code analysis") + self.analyzer.analyze(AnalysisType.CODEBASE) - def visualize_cyclomatic_complexity(self, path_filter: str | None = None): - """Convenience method for cyclomatic complexity visualization.""" - return self.visualize( - VisualizationType.CYCLOMATIC_COMPLEXITY, path_filter=path_filter + # Extract issues from analysis results + if ( + not hasattr(self.analyzer, "results") + or "issues" not in self.analyzer.results + ): + logger.error("Issues not available in analysis results") + return None + + issues = self.analyzer.results["issues"] + + # Filter issues by severity if specified + if severity: + issues = [issue for issue in issues if issue.get("severity") == severity] + + # Filter issues by path if specified + if path_filter: + issues = [ + issue + for issue in issues + if issue.get("file", "").startswith(path_filter) + ] + + if not issues: + logger.warning("No issues found matching the criteria") + return None + + # Group issues by file + file_issues = {} + for issue in issues: + file_path = issue.get("file", "") + if file_path not in file_issues: + file_issues[file_path] = [] + + file_issues[file_path].append(issue) + + # Generate heatmap visualization + plt.figure(figsize=(12, 10)) + + # Extract data for heatmap + files = list(file_issues.keys()) + file_names = [file_path.split("/")[-1] for file_path in files] + issue_counts = [len(file_issues[file_path]) for file_path in files] + + # Sort by issue count + sorted_data = sorted( + zip(file_names, issue_counts, files, strict=False), + key=lambda x: x[1], + reverse=True, ) + file_names, issue_counts, files = zip(*sorted_data, strict=False) + + # Create horizontal bar chart + bars = plt.barh(file_names[:20], issue_counts[:20]) + + # Color bars by issue count + norm = plt.Normalize(1, max(5, max(issue_counts[:20]))) + cmap = plt.cm.get_cmap("OrRd") + + for i, bar in enumerate(bars): + count = issue_counts[i] + bar.set_color(cmap(norm(count))) + + # Add labels and title + plt.xlabel("Number of Issues") + severity_text = f" ({severity.value})" if severity else "" + plt.title(f"Files with the Most Issues{severity_text}") + plt.grid(axis="x", linestyle="--", alpha=0.6) - def visualize_issues_heatmap(self, severity=None, path_filter: str | None = None): - """Convenience method for issues heatmap visualization.""" - return self.visualize( - VisualizationType.ISSUES_HEATMAP, severity=severity, path_filter=path_filter + # Add colorbar + plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), label="Issue Count") + + # Save and return visualization + return self._save_visualization( + VisualizationType.ISSUES_HEATMAP, self.current_entity_name, plt.gcf() ) def visualize_pr_comparison(self): - """Convenience method for PR comparison visualization.""" - return self.visualize(VisualizationType.PR_COMPARISON) + """ + Generate a visualization comparing base branch with PR. + + Returns: + Visualization data or path to saved file + """ + self.current_visualization_type = VisualizationType.PR_COMPARISON + + # Check if analyzer has PR data + if ( + not self.analyzer + or not self.analyzer.pr_codebase + or not self.analyzer.base_codebase + ): + logger.error("PR comparison requires analyzer with PR data") + return None + + self.current_entity_name = ( + f"pr_{self.analyzer.pr_number}" + if self.analyzer.pr_number + else "pr_comparison" + ) + + # Perform comparison analysis if not already done + if not hasattr(self.analyzer, "results") or not self.analyzer.results: + logger.info("Running PR comparison analysis") + self.analyzer.analyze(AnalysisType.COMPARISON) + + # Extract comparison data from analysis results + if ( + not hasattr(self.analyzer, "results") + or "comparison" not in self.analyzer.results + ): + logger.error("Comparison data not available in analysis results") + return None + + comparison = self.analyzer.results["comparison"] + + # Initialize graph + self._initialize_graph() + + # Process symbol comparison data + if "symbol_comparison" in comparison: + for symbol_data in comparison["symbol_comparison"]: + symbol_name = symbol_data.get("name", "") + in_base = symbol_data.get("in_base", False) + in_pr = symbol_data.get("in_pr", False) + + # Create a placeholder for the symbol + symbol_obj = { + "name": symbol_name, + "in_base": in_base, + "in_pr": in_pr, + "type": "Symbol", + } + + # Determine node color based on presence in base and PR + if in_base and in_pr: + color = "#A5D6A7" # Light green (modified) + elif in_base: + color = "#EF9A9A" # Light red (removed) + else: + color = "#90CAF9" # Light blue (added) + + # Add node for symbol + self._add_node( + symbol_obj, + name=symbol_name, + color=color, + in_base=in_base, + in_pr=in_pr, + ) + + # Process parameter changes if available + if "parameter_changes" in symbol_data: + param_changes = symbol_data["parameter_changes"] + + # Process removed parameters + for param in param_changes.get("removed", []): + param_obj = { + "name": param, + "change_type": "removed", + "type": "Parameter", + } + + self._add_node( + param_obj, + name=param, + color="#EF9A9A", # Light red (removed) + change_type="removed", + ) + + self._add_edge(symbol_obj, param_obj, type="removed_parameter") + + # Process added parameters + for param in param_changes.get("added", []): + param_obj = { + "name": param, + "change_type": "added", + "type": "Parameter", + } + + self._add_node( + param_obj, + name=param, + color="#90CAF9", # Light blue (added) + change_type="added", + ) + + self._add_edge(symbol_obj, param_obj, type="added_parameter") + + # Process return type changes if available + if "return_type_change" in symbol_data: + return_type_change = symbol_data["return_type_change"] + old_type = return_type_change.get("old", "None") + new_type = return_type_change.get("new", "None") + + return_obj = { + "name": f"{old_type} -> {new_type}", + "old_type": old_type, + "new_type": new_type, + "type": "ReturnType", + } + + self._add_node( + return_obj, + name=f"{old_type} -> {new_type}", + color="#FFD54F", # Amber (changed) + old_type=old_type, + new_type=new_type, + ) + + self._add_edge(symbol_obj, return_obj, type="return_type_change") + + # Process call site issues if available + if "call_site_issues" in symbol_data: + for issue in symbol_data["call_site_issues"]: + issue_file = issue.get("file", "") + issue_line = issue.get("line", None) + issue_text = issue.get("issue", "") + + # Create a placeholder for the issue + issue_obj = { + "name": issue_text, + "file": issue_file, + "line": issue_line, + "type": "Issue", + } + + self._add_node( + issue_obj, + name=f"{issue_file.split('/')[-1]}:{issue_line}", + color="#EF5350", # Red (error) + file_path=issue_file, + line=issue_line, + issue_text=issue_text, + ) + + self._add_edge(symbol_obj, issue_obj, type="call_site_issue") + + # Generate visualization data + if self.config.output_format == OutputFormat.JSON: + data = self._convert_graph_to_json() + return self._save_visualization( + VisualizationType.PR_COMPARISON, self.current_entity_name, data + ) + else: + fig = self._plot_graph() + return self._save_visualization( + VisualizationType.PR_COMPARISON, self.current_entity_name, fig + ) # Command-line interface @@ -284,7 +1565,11 @@ def main(): viz_group.add_argument( "--ignore-external", action="store_true", help="Ignore external dependencies" ) - viz_group.add_argument("--severity", help="Filter issues by severity") + viz_group.add_argument( + "--severity", + choices=[s.value for s in IssueSeverity], + help="Filter issues by severity", + ) viz_group.add_argument("--path-filter", help="Filter by file path") # PR options @@ -323,30 +1608,16 @@ def main(): layout_algorithm=args.layout, ) - try: - # Import analyzer only if needed - if ( - args.type - in ["pr_comparison", "dead_code", "cyclomatic_complexity", "issues_heatmap"] - or args.pr_number - ): - from codegen_on_oss.analyzers.codebase_analyzer import CodebaseAnalyzer - - # Create analyzer - analyzer = CodebaseAnalyzer( - repo_url=args.repo_url, - repo_path=args.repo_path, - base_branch=args.base_branch, - pr_number=args.pr_number, - language=args.language, - ) - else: - analyzer = None - except ImportError: - logger.warning( - "CodebaseAnalyzer not available. Some visualizations may not work." + # Create codebase analyzer if needed for PR comparison + analyzer = None + if args.type == VisualizationType.PR_COMPARISON.value or args.pr_number: + analyzer = CodebaseAnalyzer( + repo_url=args.repo_url, + repo_path=args.repo_path, + base_branch=args.base_branch, + pr_number=args.pr_number, + language=args.language, ) - analyzer = None # Create visualizer visualizer = CodebaseVisualizer(analyzer=analyzer, config=config) @@ -355,37 +1626,57 @@ def main(): viz_type = VisualizationType(args.type) result = None - # Process specific requirements for each visualization type - if ( - viz_type - in [ - VisualizationType.CALL_GRAPH, - VisualizationType.DEPENDENCY_GRAPH, - VisualizationType.BLAST_RADIUS, - VisualizationType.CLASS_METHODS, - VisualizationType.MODULE_DEPENDENCIES, - ] - and not args.entity - ): - logger.error(f"Entity name required for {viz_type} visualization") - sys.exit(1) + if viz_type == VisualizationType.CALL_GRAPH: + if not args.entity: + logger.error("Entity name required for call graph visualization") + sys.exit(1) - if ( - viz_type == VisualizationType.PR_COMPARISON - and not args.pr_number - and not (analyzer and hasattr(analyzer, "pr_number")) - ): - logger.error("PR number required for PR comparison visualization") - sys.exit(1) + result = visualizer.visualize_call_graph(args.entity) - # Generate visualization - result = visualizer.visualize( - viz_type, - entity=args.entity, - max_depth=args.max_depth, - severity=args.severity, - path_filter=args.path_filter, - ) + elif viz_type == VisualizationType.DEPENDENCY_GRAPH: + if not args.entity: + logger.error("Entity name required for dependency graph visualization") + sys.exit(1) + + result = visualizer.visualize_dependency_graph(args.entity) + + elif viz_type == VisualizationType.BLAST_RADIUS: + if not args.entity: + logger.error("Entity name required for blast radius visualization") + sys.exit(1) + + result = visualizer.visualize_blast_radius(args.entity) + + elif viz_type == VisualizationType.CLASS_METHODS: + if not args.entity: + logger.error("Class name required for class methods visualization") + sys.exit(1) + + result = visualizer.visualize_class_methods(args.entity) + + elif viz_type == VisualizationType.MODULE_DEPENDENCIES: + if not args.entity: + logger.error("Module path required for module dependencies visualization") + sys.exit(1) + + result = visualizer.visualize_module_dependencies(args.entity) + + elif viz_type == VisualizationType.DEAD_CODE: + result = visualizer.visualize_dead_code(args.path_filter) + + elif viz_type == VisualizationType.CYCLOMATIC_COMPLEXITY: + result = visualizer.visualize_cyclomatic_complexity(args.path_filter) + + elif viz_type == VisualizationType.ISSUES_HEATMAP: + severity = IssueSeverity(args.severity) if args.severity else None + result = visualizer.visualize_issues_heatmap(severity, args.path_filter) + + elif viz_type == VisualizationType.PR_COMPARISON: + if not args.pr_number: + logger.error("PR number required for PR comparison visualization") + sys.exit(1) + + result = visualizer.visualize_pr_comparison() # Output result if result: diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/dependency_graph/__init__.py b/codegen-on-oss/codegen_on_oss/analyzers/visualization/dependency_graph/__init__.py new file mode 100644 index 000000000..5b9d135f7 --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/dependency_graph/__init__.py @@ -0,0 +1,6 @@ +""" +Dependency Graph Visualization Module + +This module provides tools for visualizing dependency relationships and impact analysis in a codebase. +""" + diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/dependency_graph/blast_radius.py b/codegen-on-oss/codegen_on_oss/analyzers/visualization/dependency_graph/blast_radius.py new file mode 100644 index 000000000..42b039632 --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/dependency_graph/blast_radius.py @@ -0,0 +1,119 @@ +import codegen +import networkx as nx +from codegen import Codebase +from codegen.sdk.core.dataclasses.usage import Usage +from codegen.sdk.core.function import PyFunction +from codegen.sdk.core.symbol import PySymbol + +# Create a directed graph for visualizing relationships between code elements +G = nx.DiGraph() + +# Maximum depth to traverse in the call graph to prevent infinite recursion +MAX_DEPTH = 5 + +# Define colors for different types of nodes in the visualization +COLOR_PALETTE = { + "StartFunction": "#9cdcfe", # Starting function (light blue) + "PyFunction": "#a277ff", # Python functions (purple) + "PyClass": "#ffca85", # Python classes (orange) + "ExternalModule": "#f694ff", # External module imports (pink) + "HTTP_METHOD": "#ffca85", # HTTP method handlers (orange) +} + +# List of common HTTP method names to identify route handlers +HTTP_METHODS = ["get", "put", "patch", "post", "head", "delete"] + + +def generate_edge_meta(usage: Usage) -> dict: + """ + Generate metadata for graph edges based on a usage relationship. + + Args: + usage: A Usage object representing how a symbol is used + + Returns: + dict: Edge metadata including source location and symbol info + """ + return {"name": usage.match.source, "file_path": usage.match.filepath, "start_point": usage.match.start_point, "end_point": usage.match.end_point, "symbol_name": usage.match.__class__.__name__} + + +def is_http_method(symbol: PySymbol) -> bool: + """ + Check if a symbol represents an HTTP method handler. + + Args: + symbol: A Python symbol to check + + Returns: + bool: True if symbol is an HTTP method handler + """ + if isinstance(symbol, PyFunction) and symbol.is_method: + return symbol.name in HTTP_METHODS + return False + + +def create_blast_radius_visualization(symbol: PySymbol, depth: int = 0): + """ + Recursively build a graph visualization showing how a symbol is used. + Shows the "blast radius" - everything that would be affected by changes. + + Args: + symbol: Starting symbol to analyze + depth: Current recursion depth + """ + # Stop recursion if we hit max depth + if depth >= MAX_DEPTH: + return + + # Process each usage of the symbol + for usage in symbol.usages: + usage_symbol = usage.usage_symbol + + # Determine node color based on symbol type + if is_http_method(usage_symbol): + color = COLOR_PALETTE.get("HTTP_METHOD") + else: + color = COLOR_PALETTE.get(usage_symbol.__class__.__name__, "#f694ff") + + # Add node and edge to graph + G.add_node(usage_symbol, color=color) + G.add_edge(symbol, usage_symbol, **generate_edge_meta(usage)) + + # Recurse to process usages of this symbol + create_blast_radius_visualization(usage_symbol, depth + 1) + + +@codegen.function("visualize-function-blast-radius") +def run(codebase: Codebase): + """ + Generate a visualization showing the blast radius of changes to a function. + + This codemod: + 1. Identifies all usages of a target function + 2. Creates a graph showing how the function is used throughout the codebase + 3. Highlights HTTP method handlers and different types of code elements + """ + global G + G = nx.DiGraph() + + # Get the target function to analyze + target_func = codebase.get_function("export_asset") + + # Add starting function to graph with special color + G.add_node(target_func, color=COLOR_PALETTE.get("StartFunction")) + + # Build the visualization starting from target function + create_blast_radius_visualization(target_func) + + print(G) + print("Use codegen.sh to visualize the graph!") + + +if __name__ == "__main__": + print("Initializing codebase...") + codebase = Codebase.from_repo("codegen-oss/posthog", commit="b174f2221ea4ae50e715eb6a7e70e9a2b0760800", language="python") + print(f"Codebase with {len(codebase.files)} files and {len(codebase.functions)} functions.") + print("Creating graph...") + + run(codebase) + diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/dependency_graph/dependency_trace.py b/codegen-on-oss/codegen_on_oss/analyzers/visualization/dependency_graph/dependency_trace.py new file mode 100644 index 000000000..85448ac4f --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/dependency_graph/dependency_trace.py @@ -0,0 +1,83 @@ +import codegen +import networkx as nx +from codegen import Codebase +from codegen.sdk.core.class_definition import Class +from codegen.sdk.core.import_resolution import Import +from codegen.sdk.core.symbol import Symbol + +G = nx.DiGraph() + +IGNORE_EXTERNAL_MODULE_CALLS = True +IGNORE_CLASS_CALLS = False +MAX_DEPTH = 10 + +COLOR_PALETTE = { + "StartFunction": "#9cdcfe", # Light blue for the starting function + "PyFunction": "#a277ff", # Purple for Python functions + "PyClass": "#ffca85", # Orange for Python classes + "ExternalModule": "#f694ff", # Pink for external module references +} + +# Dictionary to track visited nodes and prevent cycles +visited = {} + + +def create_dependencies_visualization(symbol: Symbol, depth: int = 0): + """Creates a visualization of symbol dependencies in the codebase + + Recursively traverses the dependency tree of a symbol (function, class, etc.) + and creates a directed graph representation. Dependencies can be either direct + symbol references or imports. + + Args: + symbol (Symbol): The starting symbol whose dependencies will be mapped + depth (int): Current depth in the recursive traversal + """ + if depth >= MAX_DEPTH: + return + + for dep in symbol.dependencies: + dep_symbol = None + + if isinstance(dep, Symbol): + dep_symbol = dep + elif isinstance(dep, Import): + dep_symbol = dep.resolved_symbol if dep.resolved_symbol else None + + if dep_symbol: + G.add_node(dep_symbol, color=COLOR_PALETTE.get(dep_symbol.__class__.__name__, "#f694ff")) + G.add_edge(symbol, dep_symbol) + + if not isinstance(dep_symbol, Class): + create_dependencies_visualization(dep_symbol, depth + 1) + + +@codegen.function("visualize-symbol-dependencies") +def run(codebase: Codebase): + """Generate a visualization of symbol dependencies in a codebase. + + This codemod: + 1. Creates a directed graph of symbol dependencies starting from a target function + 2. Tracks relationships between functions, classes, and imports + 3. Generates a visual representation of the dependency hierarchy + """ + global G + G = nx.DiGraph() + + target_func = codebase.get_function("get_query_runner") + G.add_node(target_func, color=COLOR_PALETTE.get("StartFunction")) + + create_dependencies_visualization(target_func) + + print(G) + print("Use codegen.sh to visualize the graph!") + + +if __name__ == "__main__": + print("Initializing codebase...") + codebase = Codebase.from_repo("codegen-oss/posthog", commit="b174f2221ea4ae50e715eb6a7e70e9a2b0760800", language="python") + print(f"Codebase with {len(codebase.files)} files and {len(codebase.functions)} functions.") + print("Creating graph...") + + run(codebase) + diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/dependency_graph/viz_dead_code.py b/codegen-on-oss/codegen_on_oss/analyzers/visualization/dependency_graph/viz_dead_code.py new file mode 100644 index 000000000..17e72a5a6 --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/dependency_graph/viz_dead_code.py @@ -0,0 +1,154 @@ +from abc import ABC + +import networkx as nx + +from codegen.sdk.core.codebase import CodebaseType +from codegen.sdk.core.function import Function +from codegen.sdk.core.import_resolution import Import +from codegen.sdk.core.symbol import Symbol +from codegen.shared.enums.programming_language import ProgrammingLanguage +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill +from tests.shared.skills.skill_test import SkillTestCase, SkillTestCasePyFile + +PyDeadCodeTest = SkillTestCase( + [ + SkillTestCasePyFile( + input=""" +# Live code +def used_function(): + return "I'm used!" + +class UsedClass: + def used_method(self): + return "I'm a used method!" + +# Dead code +def unused_function(): + return "I'm never called!" + +class UnusedClass: + def unused_method(self): + return "I'm never used!" + +# Second-order dead code +def second_order_dead(): + unused_function() + UnusedClass().unused_method() + +# More live code +def another_used_function(): + return used_function() + +# Main execution +def main(): + print(used_function()) + print(UsedClass().used_method()) + print(another_used_function()) + +if __name__ == "__main__": + main() +""", + filepath="example.py", + ), + SkillTestCasePyFile( + input=""" +# This file should be ignored by the DeadCode skill + +from example import used_function, UsedClass + +def test_used_function(): + assert used_function() == "I'm used!" + +def test_used_class(): + assert UsedClass().used_method() == "I'm a used method!" +""", + filepath="test_example.py", + ), + SkillTestCasePyFile( + input=""" +# This file contains a decorated function that should be ignored + +from functools import lru_cache + +@lru_cache +def cached_function(): + return "I'm cached!" + +# This function is dead code but should be ignored due to decoration +@deprecated +def old_function(): + return "I'm old but decorated!" + +# This function is dead code and should be detected +def real_dead_code(): + return "I'm really dead!" +""", + filepath="decorated_functions.py", + ), + ], + graph=True, +) + + +@skill( + eval_skill=False, + prompt="Show me a visualization of the call graph from my_class and filter out test files and include only the methods that have the name post, get, patch, delete", + uid="ec5e98c9-b57f-43f8-8b3c-af1b30bb91e6", +) +class DeadCode(Skill, ABC): + """This skill shows a visualization of the dead code in the codebase. + It iterates through all functions in the codebase, identifying those + that have no usages and are not in test files or decorated. These functions + are considered 'dead code' and are added to a directed graph. The skill + then explores the dependencies of these dead code functions, adding them to + the graph as well. This process helps to identify not only directly unused code + but also code that might only be used by other dead code (second-order dead code). + The resulting visualization provides a clear picture of potentially removable code, + helping developers to clean up and optimize their codebase. + """ + + @staticmethod + @skill_impl(test_cases=[PyDeadCodeTest], language=ProgrammingLanguage.PYTHON) + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) + def skill_func(codebase: CodebaseType): + # Create a directed graph to visualize dead and second-order dead code + G = nx.DiGraph() + + # First, identify all dead code + dead_code: list[Function] = [] + + # Iterate through all functions in the codebase + for function in codebase.functions: + # Filter down functions + if "test" in function.file.filepath: + continue + + if function.decorators: + continue + + # Check if the function has no usages + if not function.symbol_usages: + # Add the function to the dead code list + dead_code.append(function) + # Add the function to the graph as dead code + G.add_node(function, color="red") + + # # Now, find second-order dead code + for symbol in dead_code: + # Get all usages of the dead code symbol + for dep in symbol.dependencies: + if isinstance(dep, Import): + dep = dep.imported_symbol + if isinstance(dep, Symbol): + if "test" not in dep.name: + G.add_node(dep) + G.add_edge(symbol, dep, color="red") + for usage_symbol in dep.symbol_usages: + if isinstance(usage_symbol, Function): + if "test" not in usage_symbol.name: + G.add_edge(usage_symbol, dep) + + # Visualize the graph to show dead and second-order dead code + codebase.visualize(G) + diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/docs/__init__.py b/codegen-on-oss/codegen_on_oss/analyzers/visualization/docs/__init__.py new file mode 100644 index 000000000..97a69d1fe --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/docs/__init__.py @@ -0,0 +1,6 @@ +""" +Visualization Documentation Module + +This module contains documentation and examples for using the visualization tools. +""" + diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/docs/codebase-visualization.mdx b/codegen-on-oss/codegen_on_oss/analyzers/visualization/docs/codebase-visualization.mdx new file mode 100644 index 000000000..521d6277f --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/docs/codebase-visualization.mdx @@ -0,0 +1,399 @@ +--- +title: "Codebase Visualization" +sidebarTitle: "Visualization" +description: "This guide will show you how to create codebase visualizations using [codegen](/introduction/overview)." +icon: "share-nodes" +iconType: "solid" +--- + + + + + +## Overview + +To demonstrate the visualization capabilities of the codegen we will generate three different visualizations of PostHog's open source [repository](https://github.com/PostHog/posthog). + - [Call Trace Visualization](#call-trace-visualization) + - [Function Dependency Graph](#function-dependency-graph) + - [Blast Radius Visualization](#blast-radius-visualization) + + +## Call Trace Visualization + +Visualizing the call trace of a function is a great way to understand the flow of a function and for debugging. In this tutorial we will create a call trace visualization of the `patch` method of the `SharingConfigurationViewSet` class. View the source code [here](https://github.com/PostHog/posthog/blob/c2986d9ac7502aa107a4afbe31b3633848be6582/posthog/api/sharing.py#L163). + + +### Basic Setup +First, we'll set up our codebase, graph and configure some basic parameters: + +```python +import networkx as nx +from codegen import Codebase + +# Initialize codebase +codebase = Codebase("path/to/posthog/") + +# Create a directed graph for representing call relationships +G = nx.DiGraph() + +# Configuration flags +IGNORE_EXTERNAL_MODULE_CALLS = True # Skip calls to external modules +IGNORE_CLASS_CALLS = False # Include class definition calls +MAX_DEPTH = 10 + +COLOR_PALETTE = { + "StartFunction": "#9cdcfe", # Light blue - Start Function + "PyFunction": "#a277ff", # Soft purple/periwinkle - PyFunction + "PyClass": "#ffca85", # Warm peach/orange - PyClass + "ExternalModule": "#f694ff" # Bright magenta/pink - ExternalModule +} +``` + +### Building the Visualization +We'll create a function that will recursively traverse the call trace of a function and add nodes and edges to the graph: + +```python +def create_downstream_call_trace(src_func: Function, depth: int = 0): + """Creates call graph by recursively traversing function calls + + Args: + src_func (Function): Starting function for call graph + depth (int): Current recursion depth + """ + # Prevent infinite recursion + if MAX_DEPTH <= depth: + return + + # External modules are not functions + if isinstance(src_func, ExternalModule): + return + + # Process each function call + for call in src_func.function_calls: + # Skip self-recursive calls + if call.name == src_func.name: + continue + + # Get called function definition + func = call.function_definition + if not func: + continue + + # Apply configured filters + if isinstance(func, ExternalModule) and IGNORE_EXTERNAL_MODULE_CALLS: + continue + if isinstance(func, Class) and IGNORE_CLASS_CALLS: + continue + + # Generate display name (include class for methods) + if isinstance(func, Class) or isinstance(func, ExternalModule): + func_name = func.name + elif isinstance(func, Function): + func_name = f"{func.parent_class.name}.{func.name}" if func.is_method else func.name + + # Add node and edge with metadata + G.add_node(func, name=func_name, + color=COLOR_PALETTE.get(func.__class__.__name__)) + G.add_edge(src_func, func, **generate_edge_meta(call)) + + # Recurse for regular functions + if isinstance(func, Function): + create_downstream_call_trace(func, depth + 1) +``` + +### Adding Edge Metadata +We can enrich our edges with metadata about the function calls: + +```python +def generate_edge_meta(call: FunctionCall) -> dict: + """Generate metadata for call graph edges + + Args: + call (FunctionCall): Function call information + + Returns: + dict: Edge metadata including name and location + """ + return { + "name": call.name, + "file_path": call.filepath, + "start_point": call.start_point, + "end_point": call.end_point, + "symbol_name": "FunctionCall" + } +``` +### Visualizing the Graph +Finally, we can visualize our call graph starting from a specific function: +```python +# Get target function to analyze +target_class = codebase.get_class('SharingConfigurationViewSet') +target_method = target_class.get_method('patch') + +# Add root node +G.add_node(target_method, + name=f"{target_class.name}.{target_method.name}", + color=COLOR_PALETTE["StartFunction"]) + +# Build the call graph +create_downstream_call_trace(target_method) + +# Render the visualization +codebase.visualize(G) +``` + + +### Take a look + + +View on [codegen.sh](https://www.codegen.sh/codemod/6a34b45d-c8ad-422e-95a8-46d4dc3ce2b0/public/diff) + + +### Common Use Cases +The call graph visualization is particularly useful for: + - Understanding complex codebases + - Planning refactoring efforts + - Identifying tightly coupled components + - Analyzing critical paths + - Documenting system architecture + +## Function Dependency Graph + +Understanding symbol dependencies is crucial for maintaining and refactoring code. This tutorial will show you how to create visual dependency graphs using Codegen and NetworkX. We will be creating a dependency graph of the `get_query_runner` function. View the source code [here](https://github.com/PostHog/posthog/blob/c2986d9ac7502aa107a4afbe31b3633848be6582/posthog/hogql_queries/query_runner.py#L152). + +### Basic Setup + +We'll use the same basic setup as the [Call Trace Visualization](/tutorials/codebase-visualization#call-trace-visualization) tutorial. + + +### Building the Dependency Graph +The core function for building our dependency graph: +```python +def create_dependencies_visualization(symbol: Symbol, depth: int = 0): + """Creates visualization of symbol dependencies + + Args: + symbol (Symbol): Starting symbol to analyze + depth (int): Current recursion depth + """ + # Prevent excessive recursion + if depth >= MAX_DEPTH: + return + + # Process each dependency + for dep in symbol.dependencies: + dep_symbol = None + + # Handle different dependency types + if isinstance(dep, Symbol): + # Direct symbol reference + dep_symbol = dep + elif isinstance(dep, Import): + # Import statement - get resolved symbol + dep_symbol = dep.resolved_symbol if dep.resolved_symbol else None + + if dep_symbol: + # Add node with appropriate styling + G.add_node(dep_symbol, + color=COLOR_PALETTE.get(dep_symbol.__class__.__name__, + "#f694ff")) + + # Add dependency relationship + G.add_edge(symbol, dep_symbol) + + # Recurse unless it's a class (avoid complexity) + if not isinstance(dep_symbol, PyClass): + create_dependencies_visualization(dep_symbol, depth + 1) +``` + +### Visualizing the Graph +Finally, we can visualize our dependency graph starting from a specific symbol: +```python +# Get target symbol +target_func = codebase.get_function("get_query_runner") + +# Add root node +G.add_node(target_func, color=COLOR_PALETTE["StartFunction"]) + +# Generate dependency graph +create_dependencies_visualization(target_func) + +# Render visualization +codebase.visualize(G) +``` + +### Take a look + + +View on [codegen.sh](https://www.codegen.sh/codemod/39a36f0c-9d35-4666-9db7-12ae7c28fc17/public/diff) + + +## Blast Radius visualization + +Understanding the impact of code changes is crucial for safe refactoring. A blast radius visualization shows how changes to one function might affect other parts of the codebase by tracing usage relationships. In this tutorial we will create a blast radius visualization of the `export_asset` function. View the source code [here](https://github.com/PostHog/posthog/blob/c2986d9ac7502aa107a4afbe31b3633848be6582/posthog/tasks/exporter.py#L57). + +### Basic Setup + +We'll use the same basic setup as the [Call Trace Visualization](/tutorials/codebase-visualization#call-trace-visualization) tutorial. + + +### Helper Functions +We'll create some utility functions to help build our visualization: +```python +# List of HTTP methods to highlight +HTTP_METHODS = ["get", "put", "patch", "post", "head", "delete"] + +def generate_edge_meta(usage: Usage) -> dict: + """Generate metadata for graph edges + + Args: + usage (Usage): Usage relationship information + + Returns: + dict: Edge metadata including name and location + """ + return { + "name": usage.match.source, + "file_path": usage.match.filepath, + "start_point": usage.match.start_point, + "end_point": usage.match.end_point, + "symbol_name": usage.match.__class__.__name__ + } + +def is_http_method(symbol: PySymbol) -> bool: + """Check if a symbol is an HTTP endpoint method + + Args: + symbol (PySymbol): Symbol to check + + Returns: + bool: True if symbol is an HTTP method + """ + if isinstance(symbol, PyFunction) and symbol.is_method: + return symbol.name in HTTP_METHODS + return False +``` + +### Building the Blast Radius Visualization +The main function for creating our blast radius visualization: +```python +def create_blast_radius_visualization(symbol: PySymbol, depth: int = 0): + """Create visualization of symbol usage relationships + + Args: + symbol (PySymbol): Starting symbol to analyze + depth (int): Current recursion depth + """ + # Prevent excessive recursion + if depth >= MAX_DEPTH: + return + + # Process each usage of the symbol + for usage in symbol.usages: + usage_symbol = usage.usage_symbol + + # Determine node color based on type + if is_http_method(usage_symbol): + color = COLOR_PALETTE.get("HTTP_METHOD") + else: + color = COLOR_PALETTE.get(usage_symbol.__class__.__name__, "#f694ff") + + # Add node and edge to graph + G.add_node(usage_symbol, color=color) + G.add_edge(symbol, usage_symbol, **generate_edge_meta(usage)) + + # Recursively process usage symbol + create_blast_radius_visualization(usage_symbol, depth + 1) +``` + +### Visualizing the Graph +Finally, we can create our blast radius visualization: +```python +# Get target function to analyze +target_func = codebase.get_function('export_asset') + +# Add root node +G.add_node(target_func, color=COLOR_PALETTE.get("StartFunction")) + +# Build the visualization +create_blast_radius_visualization(target_func) + +# Render graph to show impact flow +# Note: a -> b means changes to a will impact b +codebase.visualize(G) +``` + +### Take a look + + +View on [codegen.sh](https://www.codegen.sh/codemod/d255db6c-9a86-4197-9b78-16c506858a3b/public/diff) + + +## What's Next? + + + + Learn how to use Codegen to create modular codebases. + + + Learn how to use Codegen to delete dead code. + + + Learn how to use Codegen to increase type coverage. + + + Explore the complete API documentation for all Codegen classes and methods. + + + diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/structure_graph/__init__.py b/codegen-on-oss/codegen_on_oss/analyzers/visualization/structure_graph/__init__.py new file mode 100644 index 000000000..82dfcb765 --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/structure_graph/__init__.py @@ -0,0 +1,6 @@ +""" +Structure Graph Visualization Module + +This module provides tools for visualizing code structure, directory trees, and database relationships. +""" + diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/structure_graph/graph_viz_dir_tree.py b/codegen-on-oss/codegen_on_oss/analyzers/visualization/structure_graph/graph_viz_dir_tree.py new file mode 100644 index 000000000..67fe5e0a7 --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/structure_graph/graph_viz_dir_tree.py @@ -0,0 +1,111 @@ +from abc import ABC + +import networkx as nx + +from codegen.sdk.core.codebase import CodebaseType +from codegen.shared.enums.programming_language import ProgrammingLanguage +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill +from tests.shared.skills.skill_test import SkillTestCase, SkillTestCasePyFile + +PyRepoDirTreeTest = SkillTestCase( + [ + SkillTestCasePyFile(input="# Root level file", filepath="README.md"), + SkillTestCasePyFile(input="# Configuration file", filepath="config.yaml"), + SkillTestCasePyFile( + input=""" +def main(): + print("Hello, World!") + +if __name__ == "__main__": + main() +""", + filepath="src/main.py", + ), + SkillTestCasePyFile( + input=""" +class User: + def __init__(self, name): + self.name = name +""", + filepath="src/models/user.py", + ), + SkillTestCasePyFile( + input=""" +from src.models.user import User + +def create_user(name): + return User(name) +""", + filepath="src/services/user_service.py", + ), + SkillTestCasePyFile( + input=""" +import unittest +from src.models.user import User + +class TestUser(unittest.TestCase): + def test_user_creation(self): + user = User("Alice") + self.assertEqual(user.name, "Alice") +""", + filepath="tests/test_user.py", + ), + SkillTestCasePyFile( + input=""" +{ + "name": "my-project", + "version": "1.0.0", + "description": "A sample project" +} +""", + filepath="package.json", + ), + SkillTestCasePyFile( + input=""" +node_modules/ +*.log +.DS_Store +""", + filepath=".gitignore", + ), + ], + graph=True, +) + + +@skill(eval_skill=False, prompt="Show me the directory structure of this codebase", uid="ef9a5a54-d793-4749-992d-63ea3958056b") +class RepoDirTree(Skill, ABC): + """This skill displays the directory or repository tree structure of a codebase. It analyzes the file paths within the codebase and constructs a hierarchical + representation of the directory structure. The skill creates a visual graph where each node represents a directory or file, and edges represent the parent-child + relationships between directories. This visualization helps developers understand the overall organization of the codebase, making it easier to navigate and + manage large projects. Additionally, it can be useful for identifying potential structural issues or inconsistencies in the project layout. + """ + + @staticmethod + @skill_impl(test_cases=[PyRepoDirTreeTest], language=ProgrammingLanguage.PYTHON) + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) + def skill_func(codebase: CodebaseType): + # Create a directed graph + G = nx.DiGraph() + + # Iterate over all files in the codebase + for file in codebase.files: + # Get the full filepath + filepath = file.filepath + # Split the filepath into parts + parts = filepath.split("/") + + # Add nodes and edges to the graph + for i in range(len(parts)): + # Create a path from the root to the current part + path = "/".join(parts[: i + 1]) + # Add the node for the current directory + G.add_node(path) + # If it's not the root, add an edge from the parent directory to the current directory + if i > 0: + parent_path = "/".join(parts[:i]) + G.add_edge(parent_path, path) + + codebase.visualize(G) + diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/structure_graph/graph_viz_foreign_key.py b/codegen-on-oss/codegen_on_oss/analyzers/visualization/structure_graph/graph_viz_foreign_key.py new file mode 100644 index 000000000..1f453223b --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/structure_graph/graph_viz_foreign_key.py @@ -0,0 +1,178 @@ +from abc import ABC + +import networkx as nx + +from codegen.sdk.core.codebase import CodebaseType +from codegen.shared.enums.programming_language import ProgrammingLanguage +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill +from tests.shared.skills.skill_test import SkillTestCase, SkillTestCasePyFile + +PyForeignKeyGraphTest = SkillTestCase( + [ + SkillTestCasePyFile( + input=""" +from sqlalchemy import Column, Integer, String, ForeignKey, BigInteger +from app.models.base import BaseModel + +class UserModel(BaseModel): + __tablename__ = 'users' + + id = Column(Integer, primary_key=True) + name = Column(String(100), nullable=False) + email = Column(String(100), unique=True, nullable=False) + +class TaskModel(BaseModel): + __tablename__ = 'tasks' + + id = Column(Integer, primary_key=True) + title = Column(String(200), nullable=False) + description = Column(String(500)) + user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + +class CommentModel(BaseModel): + __tablename__ = 'comments' + + id = Column(Integer, primary_key=True) + content = Column(String(500), nullable=False) + task_id = Column(Integer, ForeignKey("tasks.id", ondelete="CASCADE"), nullable=False) + user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + +class ProjectModel(BaseModel): + __tablename__ = 'projects' + + id = Column(Integer, primary_key=True) + name = Column(String(200), nullable=False) + description = Column(String(500)) + +class TaskProjectModel(BaseModel): + __tablename__ = 'task_projects' + + id = Column(Integer, primary_key=True) + task_id = Column(Integer, ForeignKey("tasks.id", ondelete="CASCADE"), nullable=False) + project_id = Column(Integer, ForeignKey("projects.id", ondelete="CASCADE"), nullable=False) + +class AgentRunModel(BaseModel): + __tablename__ = 'agent_runs' + + id = Column(BigInteger, primary_key=True) + task_id = Column(BigInteger, ForeignKey("tasks.id", ondelete="CASCADE"), nullable=False) + agent_id = Column(BigInteger, ForeignKey("agents.id", ondelete="CASCADE"), nullable=False) + +class AgentModel(BaseModel): + __tablename__ = 'agents' + + id = Column(BigInteger, primary_key=True) + name = Column(String(100), nullable=False) +""", + filepath="app/models/schema.py", + ) + ], + graph=True, +) + + +@skill( + eval_skill=False, + prompt="Help me analyze my data schema. I have a bunch of SQLAlchemy models with foreign keys to each other, all of them are classes like this that inherit BaseModel, like the one in this file.", + uid="2a5d8f4d-5f02-445e-9d00-77bdb9a0d268", +) +class ForeignKeyGraph(Skill, ABC): + """This skill helps analyze a data schema by creating a graph representation of SQLAlchemy models and their foreign key relationships. + + It processes a collection of SQLAlchemy models with foreign keys referencing each other. All of these models are classes that inherit from BaseModel, similar to the one in this file. Foreign keys + are typically defined in the following format: + agent_run_id = Column(BigInteger, ForeignKey("AgentRun.id", ondelete="CASCADE"), nullable=False) + + The skill iterates through all classes in the codebase, identifying those that are subclasses of BaseModel. For each relevant class, it examines the attributes to find ForeignKey definitions. It + then builds a mapping of these relationships. + + Using this mapping, the skill constructs a directed graph where: + - Nodes represent the models (with the 'Model' suffix stripped from their names) + - Edges represent the foreign key relationships between models + + This graph visualization allows for easy analysis of the data schema, showing how different models are interconnected through their foreign key relationships. The resulting graph can be used to + understand data dependencies, optimize queries, or refactor the database schema. + """ + + @staticmethod + @skill_impl(test_cases=[PyForeignKeyGraphTest], language=ProgrammingLanguage.PYTHON) + def skill_func(codebase: CodebaseType): + # Create a mapping dictionary to hold relationships + foreign_key_mapping = {} + + # Iterate through all classes in the codebase + for cls in codebase.classes: + # Check if the class is a subclass of BaseModel and defined in the correct file + if cls.is_subclass_of("BaseModel") and "from app.models.base import BaseModel" in cls.file.content: + # Initialize an empty list for the current class + foreign_key_mapping[cls.name] = [] + + # Iterate through the attributes of the class + for attr in cls.attributes: + # Check if the attribute's source contains a ForeignKey definition + if "ForeignKey" in attr.source: + # Extract the table name from the ForeignKey string + start_index = attr.source.find('("') + 2 + end_index = attr.source.find(".id", start_index) + if end_index != -1: + target_table = attr.source[start_index:end_index] + # Append the target table to the mapping, avoiding duplicates + if target_table not in foreign_key_mapping[cls.name]: + foreign_key_mapping[cls.name].append(target_table) + + # Now foreign_key_mapping contains the desired relationships + # print(foreign_key_mapping) + + # Create a directed graph + G = nx.DiGraph() + + # Iterate through the foreign_key_mapping to add nodes and edges + for model, targets in foreign_key_mapping.items(): + # Add the model node (strip 'Model' suffix) + model_name = model.replace("Model", "") + G.add_node(model_name) + + # Add edges to the target tables + for target in targets: + G.add_node(target) # Ensure the target is also a node + G.add_edge(model_name, target) + + # Now G contains the directed graph of models and their foreign key relationships + # You can visualize or analyze the graph as needed + codebase.visualize(G) + + ############################################################################################################## + # IN DEGREE + ############################################################################################################## + + # Calculate in-degrees for each node + in_degrees = G.in_degree() + + # Create a list of nodes with their in-degree counts + in_degree_list = [(node, degree) for node, degree in in_degrees] + + # Sort the list by in-degree in descending order + sorted_in_degrees = sorted(in_degree_list, key=lambda x: x[1], reverse=True) + + # Print the nodes with their in-degrees + for node, degree in sorted_in_degrees: + print(f"Node: {node}, In-Degree: {degree}") + if degree == 0: + G.nodes[node]["color"] = "red" + + ############################################################################################################## + # FIND MODELS MAPPING TO TASK + ############################################################################################################## + + # Collect models that map to the Task model + models_mapping_to_task = [] + for model, targets in foreign_key_mapping.items(): + if "Task" in targets: + models_mapping_to_task.append(model) + + # Print the models that map to Task + print("Models mapping to 'Task':") + for model in models_mapping_to_task: + print(f"> {model}") +