diff --git a/CLI_Commands.md b/CLI_Commands.md index c176c15..e4fce4a 100644 --- a/CLI_Commands.md +++ b/CLI_Commands.md @@ -3,6 +3,19 @@ Here is the **complete list** of all CLI commands available in `CodeGraphContext`, categorized by workflow scenario. +## Global Flags + +These flags can be used with any command: + +| Flag | Aliases | Description | +| :--- | :--- | :--- | +| **`--visual`** | `--viz`, `-V` | Shows results as an interactive graph visualization in your browser. Works with analyze, find, and query commands. | +| **`--database`** | `-d` | Temporarily override the database backend (`falkordb` or `neo4j`) for any command. | +| **`--version`** | `-v` | Show version and exit. | +| **`--help`** | `-h` | Show help and exit. | + +> **Note:** The visual flag uses uppercase `-V` to avoid conflict with `-v` which is reserved for `--version`. + ## 1. Project Management Use these commands to manage the repositories in your code graph. @@ -29,22 +42,23 @@ Understand the structure, quality, and relationships of your code. | Command | Arguments | Description | | :--- | :--- | :--- | -| **`cgc analyze calls`** | ``
`--file` | Shows **outgoing** calls: what functions does this function call? | -| **`cgc analyze callers`** | ``
`--file` | Shows **incoming** calls: who calls this function? | -| **`cgc analyze chain`** | ` `
`--depth` | Finds the call path between two functions. Default depth is 5. | -| **`cgc analyze deps`** | ``
`--no-external` | Inspects dependencies (imports and importers) for a module. | -| **`cgc analyze tree`** | ``
`--file` | Visualizes the Class Inheritance hierarchy for a given class. | +| **`cgc analyze calls`** | ``
`--file`
`--visual` | Shows **outgoing** calls: what functions does this function call? Use `--visual` or `-V` for graph view. | +| **`cgc analyze callers`** | ``
`--file`
`--visual` | Shows **incoming** calls: who calls this function? Use `--visual` or `-V` for graph view. | +| **`cgc analyze chain`** | ` `
`--depth`
`--visual` | Finds the call path between two functions. Default depth is 5. Use `--visual` or `-V` for graph view. | +| **`cgc analyze deps`** | ``
`--no-external`
`--visual` | Inspects dependencies (imports and importers) for a module. Use `--visual` or `-V` for graph view. | +| **`cgc analyze tree`** | ``
`--file`
`--visual` | Visualizes the Class Inheritance hierarchy for a given class. Use `--visual` or `-V` for graph view. | | **`cgc analyze complexity`**| `[path]`
`--threshold`
`--limit` | Lists functions with high Cyclomatic Complexity. Default threshold: 10. | | **`cgc analyze dead-code`** | `--exclude` | Finds potentially unused functions (0 callers). Use `--exclude` for decorators. | +| **`cgc analyze overrides`** | ``
`--visual` | Finds all implementations of a function across different classes. Use `--visual` or `-V` for graph view. | ## 4. Discovery & Search Find code elements when you don't know the exact structure. | Command | Arguments | Description | | :--- | :--- | :--- | -| **`cgc find name`** | ``
`--type` | Finds code elements (Class, Function) by their **exact** name. | -| **`cgc find pattern`** | ``
`--case-sensitive` | Finds elements using fuzzy substring matching (e.g. "User" finds "UserHelper"). | -| **`cgc find type`** | ``
`--limit` | Lists all nodes of a specific type (e.g. `function`, `class`, `module`). | +| **`cgc find name`** | ``
`--type`
`--visual` | Finds code elements (Class, Function) by their **exact** name. Use `--visual` or `-V` for graph view. | +| **`cgc find pattern`** | ``
`--case-sensitive`
`--visual` | Finds elements using fuzzy substring matching (e.g. "User" finds "UserHelper"). Use `--visual` or `-V` for graph view. | +| **`cgc find type`** | ``
`--limit`
`--visual` | Lists all nodes of a specific type (e.g. `function`, `class`, `module`). Use `--visual` or `-V` for graph view. | ## 5. Configuration & Setup Manage your environment and database connections. @@ -65,7 +79,7 @@ Helper commands for developers and the MCP server. | :--- | :--- | :--- | | **`cgc doctor`** | None | Runs system diagnostics (DB connection, dependencies, permissions). | | **`cgc visualize`** | `[query]` | Generates a link to open the Neo4j Browser.
*(Alias: `cgc v`)* | -| **`cgc query`** | `` | Executes a raw Cypher query directly against the DB. | +| **`cgc query`** | ``
`--visual` | Executes a raw Cypher query directly against the DB. Use `--visual` or `-V` for graph view. | | **`cgc mcp start`** | None | Starts the MCP Server (used by IDEs). | | **`cgc mcp tools`** | None | Lists all available MCP tools supported by the server. | | **`cgc start`** | None | **Deprecated**. Use `cgc mcp start` instead. | @@ -171,3 +185,47 @@ cgc index . # Start indexing - Press `Ctrl+C` in the watch terminal **💡 Tip:** This is perfect for active development sessions where you want your AI assistant to always have the latest code context! + +### Scenario G: Visual Graph Exploration +Use the `--visual` flag (or `-V`) to see results as interactive graphs in your browser. + +1. **Visualize function calls:** + ```bash + cgc analyze calls process_data --visual + # or use the short form + cgc analyze calls process_data -V + ``` + +2. **Visualize call chain between functions:** + ```bash + cgc analyze chain main handle_request --visual + ``` + +3. **Visualize class inheritance:** + ```bash + cgc analyze tree MyBaseClass --visual + ``` + +4. **Visualize module dependencies:** + ```bash + cgc analyze deps mymodule --visual + ``` + +5. **Visualize search results:** + ```bash + cgc find pattern "Controller" --visual + ``` + +6. **Visualize Cypher query results:** + ```bash + cgc query "MATCH (n:Function)-[r:CALLS]->(m:Function) RETURN n,r,m LIMIT 50" --visual + ``` + +7. **Use global flag (applies to any command):** + ```bash + cgc -V analyze callers my_function + ``` + +**💡 Tip:** The visualizations are interactive! Drag to pan, scroll to zoom, and click on nodes to highlight connections. + +**📍 Output:** Visualization HTML files are saved to `~/.codegraphcontext/visualizations/` and automatically opened in your default browser. diff --git a/src/codegraphcontext/cli/cli_helpers.py b/src/codegraphcontext/cli/cli_helpers.py index 2f0c782..7cad42b 100644 --- a/src/codegraphcontext/cli/cli_helpers.py +++ b/src/codegraphcontext/cli/cli_helpers.py @@ -227,6 +227,39 @@ def cypher_helper(query: str): db_manager.close_driver() +def cypher_helper_visual(query: str): + """Executes a read-only Cypher query and visualizes the results.""" + from .visualizer import visualize_cypher_results + + services = _initialize_services() + if not all(services): + return + + db_manager, _, _ = services + + # Replicating safety checks from MCPServer + forbidden_keywords = ['CREATE', 'MERGE', 'DELETE', 'SET', 'REMOVE', 'DROP', 'CALL apoc'] + if any(keyword in query.upper() for keyword in forbidden_keywords): + console.print("[bold red]Error: This command only supports read-only queries.[/bold red]") + db_manager.close_driver() + return + + try: + with db_manager.get_driver().session() as session: + result = session.run(query) + records = [record.data() for record in result] + + if not records: + console.print("[yellow]No results to visualize.[/yellow]") + return # finally block will close driver + + visualize_cypher_results(records, query) + except Exception as e: + console.print(f"[bold red]An error occurred while executing query:[/bold red] {e}") + finally: + db_manager.close_driver() + + import webbrowser def visualize_helper(query: str): diff --git a/src/codegraphcontext/cli/main.py b/src/codegraphcontext/cli/main.py index aa3f1d0..f3e3627 100644 --- a/src/codegraphcontext/cli/main.py +++ b/src/codegraphcontext/cli/main.py @@ -33,6 +33,7 @@ list_repos_helper, delete_helper, cypher_helper, + cypher_helper_visual, visualize_helper, reindex_helper, clean_helper, @@ -47,6 +48,17 @@ logging.getLogger("neo4j").setLevel(logging.WARNING) logging.getLogger("asyncio").setLevel(logging.WARNING) +# Import visualization module +from .visualizer import ( + visualize_call_graph, + visualize_call_chain, + visualize_dependencies, + visualize_inheritance_tree, + visualize_overrides, + visualize_search_results, + check_visual_flag, +) + # Initialize the Typer app and Rich console for formatted output. app = typer.Typer( name="cgc", @@ -721,8 +733,10 @@ def watching(): @find_app.command("name") def find_by_name( + ctx: typer.Context, name: str = typer.Argument(..., help="Exact name to search for"), - type: Optional[str] = typer.Option(None, "--type", "-t", help="Filter by type (function, class, file, module)") + type: Optional[str] = typer.Option(None, "--type", "-t", help="Filter by type (function, class, file, module)"), + visual: bool = typer.Option(False, "--visual", "--viz", "-V", help="Show results as interactive graph visualization") ): """ Find code elements by exact name. @@ -730,6 +744,7 @@ def find_by_name( Examples: cgc find name MyClass cgc find name calculate --type function + cgc find name MyClass --visual """ _load_credentials() services = _initialize_services() @@ -790,6 +805,11 @@ def find_by_name( if not results: console.print(f"[yellow]No code elements found with name '{name}'[/yellow]") return + + # Check if visual mode is enabled + if check_visual_flag(ctx, visual): + visualize_search_results(results, name, search_type="name") + return table = Table(show_header=True, header_style="bold magenta", box=box.ROUNDED) table.add_column("Name", style="cyan") @@ -814,8 +834,10 @@ def find_by_name( @find_app.command("pattern") def find_by_pattern( + ctx: typer.Context, pattern: str = typer.Argument(..., help="Substring pattern to search (fuzzy search fallback)"), - case_sensitive: bool = typer.Option(False, "--case-sensitive", "-c", help="Case-sensitive search") + case_sensitive: bool = typer.Option(False, "--case-sensitive", "-c", help="Case-sensitive search"), + visual: bool = typer.Option(False, "--visual", "--viz", "-V", help="Show results as interactive graph visualization") ): """ Find code elements using substring matching. @@ -823,6 +845,7 @@ def find_by_pattern( Examples: cgc find pattern "Auth" # Finds Auth, Authentication, Authorize... cgc find pattern "process_" # Finds process_data, process_request... + cgc find pattern "Auth" --visual """ _load_credentials() services = _initialize_services() @@ -869,6 +892,11 @@ def find_by_pattern( if not results: console.print(f"[yellow]No matches found for pattern '{pattern}'[/yellow]") return + + # Check if visual mode is enabled + if check_visual_flag(ctx, visual): + visualize_search_results(results, pattern, search_type="pattern") + return if not case_sensitive and any(c in pattern for c in "*?["): console.print("[yellow]Note: Wildcards/Regex are not fully supported in this mode. Performing substring search.[/yellow]") @@ -898,8 +926,10 @@ def find_by_pattern( @find_app.command("type") def find_by_type( + ctx: typer.Context, element_type: str = typer.Argument(..., help="Type to search for (function, class, file, module)"), - limit: int = typer.Option(50, "--limit", "-l", help="Maximum results to return") + limit: int = typer.Option(50, "--limit", "-l", help="Maximum results to return"), + visual: bool = typer.Option(False, "--visual", "--viz", "-V", help="Show results as interactive graph visualization") ): """ Find all elements of a specific type. @@ -907,6 +937,7 @@ def find_by_type( Examples: cgc find type class cgc find type function --limit 100 + cgc find type class --visual """ _load_credentials() services = _initialize_services() @@ -920,6 +951,15 @@ def find_by_type( if not results: console.print(f"[yellow]No elements found of type '{element_type}'[/yellow]") return + + # Add type to results for visualization + for r in results: + r['type'] = element_type.capitalize() + + # Check if visual mode is enabled + if check_visual_flag(ctx, visual): + visualize_search_results(results, element_type, search_type="type") + return table = Table(show_header=True, header_style="bold magenta", box=box.ROUNDED) table.add_column("Name", style="cyan") @@ -1148,8 +1188,10 @@ def find_by_argument_search( @analyze_app.command("calls") def analyze_calls( + ctx: typer.Context, function: str = typer.Argument(..., help="Function name to analyze"), - file: Optional[str] = typer.Option(None, "--file", "-f", help="Specific file path") + file: Optional[str] = typer.Option(None, "--file", "-f", help="Specific file path"), + visual: bool = typer.Option(False, "--visual", "--viz", "-V", help="Show results as interactive graph visualization") ): """ Show what functions this function calls (callees). @@ -1157,6 +1199,7 @@ def analyze_calls( Example: cgc analyze calls process_data cgc analyze calls process_data --file src/main.py + cgc analyze calls process_data --visual """ _load_credentials() services = _initialize_services() @@ -1171,6 +1214,11 @@ def analyze_calls( console.print(f"[yellow]No function calls found for '{function}'[/yellow]") return + # Check if visual mode is enabled + if check_visual_flag(ctx, visual): + visualize_call_graph(results, function, direction="outgoing") + return + table = Table(show_header=True, header_style="bold magenta", box=box.ROUNDED) table.add_column("Called Function", style="cyan") table.add_column("Location", style="dim", overflow="fold") @@ -1195,8 +1243,10 @@ def analyze_calls( @analyze_app.command("callers") def analyze_callers( + ctx: typer.Context, function: str = typer.Argument(..., help="Function name to analyze"), - file: Optional[str] = typer.Option(None, "--file", "-f", help="Specific file path") + file: Optional[str] = typer.Option(None, "--file", "-f", help="Specific file path"), + visual: bool = typer.Option(False, "--visual", "--viz", "-V", help="Show results as interactive graph visualization") ): """ Show what functions call this function. @@ -1204,6 +1254,7 @@ def analyze_callers( Example: cgc analyze callers process_data cgc analyze callers process_data --file src/main.py + cgc analyze callers process_data --visual """ _load_credentials() services = _initialize_services() @@ -1218,6 +1269,11 @@ def analyze_callers( console.print(f"[yellow]No callers found for '{function}'[/yellow]") return + # Check if visual mode is enabled + if check_visual_flag(ctx, visual): + visualize_call_graph(results, function, direction="incoming") + return + table = Table(show_header=True, header_style="bold magenta", box=box.ROUNDED) table.add_column("Caller Function", style="cyan") table.add_column("Location", style="green") @@ -1244,11 +1300,13 @@ def analyze_callers( @analyze_app.command("chain") def analyze_chain( + ctx: typer.Context, from_func: str = typer.Argument(..., help="Starting function"), to_func: str = typer.Argument(..., help="Target function"), max_depth: int = typer.Option(5, "--depth", "-d", help="Maximum call chain depth"), from_file: Optional[str] = typer.Option(None, "--from-file", help="File for starting function"), - to_file: Optional[str] = typer.Option(None, "--to-file", help="File for target function") + to_file: Optional[str] = typer.Option(None, "--to-file", help="File for target function"), + visual: bool = typer.Option(False, "--visual", "--viz", "-V", help="Show results as interactive graph visualization") ): """ Show call chain between two functions. @@ -1256,6 +1314,7 @@ def analyze_chain( Example: cgc analyze chain main process_data --depth 10 cgc analyze chain main process --from-file main.py --to-file utils.py + cgc analyze chain main process_data --visual """ _load_credentials() services = _initialize_services() @@ -1270,6 +1329,11 @@ def analyze_chain( console.print(f"[yellow]No call chain found between '{from_func}' and '{to_func}' within depth {max_depth}[/yellow]") return + # Check if visual mode is enabled + if check_visual_flag(ctx, visual): + visualize_call_chain(results, from_func, to_func) + return + for idx, chain in enumerate(results, 1): console.print(f"\n[bold cyan]Call Chain #{idx} (length: {chain.get('chain_length', 0)}):[/bold cyan]") @@ -1309,8 +1373,10 @@ def analyze_chain( @analyze_app.command("deps") def analyze_dependencies( + ctx: typer.Context, target: str = typer.Argument(..., help="Module name"), - show_external: bool = typer.Option(True, "--external/--no-external", help="Show external dependencies") + show_external: bool = typer.Option(True, "--external/--no-external", help="Show external dependencies"), + visual: bool = typer.Option(False, "--visual", "--viz", "-V", help="Show results as interactive graph visualization") ): """ Show dependencies and imports for a module. @@ -1318,6 +1384,7 @@ def analyze_dependencies( Example: cgc analyze deps numpy cgc analyze deps mymodule --no-external + cgc analyze deps mymodule --visual """ _load_credentials() services = _initialize_services() @@ -1332,6 +1399,11 @@ def analyze_dependencies( console.print(f"[yellow]No dependency information found for '{target}'[/yellow]") return + # Check if visual mode is enabled + if check_visual_flag(ctx, visual): + visualize_dependencies(results, target) + return + # Show who imports this module if results.get('importers'): console.print(f"\n[bold cyan]Files that import '{target}':[/bold cyan]") @@ -1366,8 +1438,10 @@ def analyze_dependencies( @analyze_app.command("tree") def analyze_inheritance_tree( + ctx: typer.Context, class_name: str = typer.Argument(..., help="Class name"), - file: Optional[str] = typer.Option(None, "--file", "-f", help="Specific file path") + file: Optional[str] = typer.Option(None, "--file", "-f", help="Specific file path"), + visual: bool = typer.Option(False, "--visual", "--viz", "-V", help="Show results as interactive graph visualization") ): """ Show inheritance hierarchy for a class. @@ -1375,6 +1449,7 @@ def analyze_inheritance_tree( Example: cgc analyze tree MyClass cgc analyze tree MyClass --file src/models.py + cgc analyze tree MyClass --visual """ _load_credentials() services = _initialize_services() @@ -1385,6 +1460,15 @@ def analyze_inheritance_tree( try: results = code_finder.find_class_hierarchy(class_name, file) + # Check if visual mode is enabled (check for any hierarchy data) + has_hierarchy = results.get('parent_classes') or results.get('child_classes') + if check_visual_flag(ctx, visual): + if has_hierarchy: + visualize_inheritance_tree(results, class_name) + else: + console.print(f"[yellow]No inheritance hierarchy to visualize for '{class_name}'[/yellow]") + return + console.print(f"\n[bold cyan]Class Hierarchy for '{class_name}':[/bold cyan]\n") # Show parent classes @@ -1533,7 +1617,9 @@ def analyze_dead_code( @analyze_app.command("overrides") def analyze_overrides( - function_name: str = typer.Argument(..., help="Function/method name to find implementations of") + ctx: typer.Context, + function_name: str = typer.Argument(..., help="Function/method name to find implementations of"), + visual: bool = typer.Option(False, "--visual", "--viz", "-V", help="Show results as interactive graph visualization") ): """ Find all implementations of a function across different classes. @@ -1543,6 +1629,7 @@ def analyze_overrides( Example: cgc analyze overrides area cgc analyze overrides process + cgc analyze overrides area --visual """ _load_credentials() services = _initialize_services() @@ -1557,6 +1644,11 @@ def analyze_overrides( console.print(f"[yellow]No implementations found for function '{function_name}'[/yellow]") return + # Check if visual mode is enabled + if check_visual_flag(ctx, visual): + visualize_overrides(results, function_name) + return + table = Table(show_header=True, header_style="bold magenta", box=box.ROUNDED) table.add_column("Class", style="cyan") table.add_column("Function", style="green") @@ -1651,16 +1743,26 @@ def analyze_variable_usage( # ============================================================================ @app.command("query") -def query_graph(query: str = typer.Argument(..., help="Cypher query to execute (read-only)")): +def query_graph( + ctx: typer.Context, + query: str = typer.Argument(..., help="Cypher query to execute (read-only)"), + visual: bool = typer.Option(False, "--visual", "--viz", "-V", help="Show results as interactive graph visualization") +): """ Execute a custom Cypher query on the code graph. Examples: cgc query "MATCH (f:Function) RETURN f.name LIMIT 10" cgc query "MATCH (c:Class)-[:CONTAINS]->(m) RETURN c.name, count(m)" + cgc query "MATCH (n)-[r]->(m) RETURN n,r,m LIMIT 50" --visual """ _load_credentials() - cypher_helper(query) + + # Check if visual mode is enabled + if check_visual_flag(ctx, visual): + cypher_helper_visual(query) + else: + cypher_helper(query) # Keep old 'cypher' as alias for backward compatibility @app.command("cypher", hidden=True) @@ -1730,6 +1832,13 @@ def main( "-d", help="[Global] Temporarily override database backend (falkordb or neo4j) for any command" ), + visual: bool = typer.Option( + False, + "--visual", + "--viz", + "-V", + help="[Global] Show results as interactive graph visualization in browser" + ), version_: bool = typer.Option( None, "--version", @@ -1749,9 +1858,16 @@ def main( Main entry point for the cgc CLI application. If no subcommand is provided, it displays a welcome message with instructions. """ + # Initialize context object for sharing state with subcommands + ctx.ensure_object(dict) + if database: os.environ["CGC_RUNTIME_DB_TYPE"] = database + # Store visual flag in context for subcommands to access + if visual: + ctx.obj["visual"] = True + if version_: console.print(f"CodeGraphContext [bold cyan]{get_version()}[/bold cyan]") raise typer.Exit() @@ -1767,6 +1883,8 @@ def main( console.print(" • [cyan]cgc list[/cyan] - List indexed repositories\n") console.print("📊 [bold]Using Neo4j instead of FalkorDB?[/bold]") console.print(" Run [cyan]cgc neo4j setup[/cyan] (or [cyan]cgc n[/cyan]) to configure Neo4j\n") + console.print("📈 [bold]Want visual graph output?[/bold]") + console.print(" Add [cyan]-V[/cyan] or [cyan]--visual[/cyan] to any analyze/find command\n") console.print("👉 Run [cyan]cgc help[/cyan] to see all available commands") console.print("👉 Run [cyan]cgc --version[/cyan] to check the version\n") console.print("👉 Running [green]codegraphcontext[/green] works the same as using [green]cgc[/green]") diff --git a/src/codegraphcontext/cli/visualizer.py b/src/codegraphcontext/cli/visualizer.py new file mode 100644 index 0000000..72dc013 --- /dev/null +++ b/src/codegraphcontext/cli/visualizer.py @@ -0,0 +1,1082 @@ +# src/codegraphcontext/cli/visualizer.py +""" +Visualization module for CodeGraphContext CLI. + +This module generates interactive HTML graph visualizations using vis-network.js +for various CLI command outputs (analyze calls, callers, chain, deps, tree, etc.). + +The visualizations are standalone HTML files that can be opened in any browser. +""" + +import html +import json +import uuid +import webbrowser +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional, Literal +from rich.console import Console + +console = Console(stderr=True) + + +def escape_html(text: Any) -> str: + """Safely escape HTML special characters to prevent XSS.""" + if text is None: + return "" + return html.escape(str(text)) + + +def get_visualization_dir() -> Path: + """Get or create the visualization output directory.""" + viz_dir = Path.home() / ".codegraphcontext" / "visualizations" + viz_dir.mkdir(parents=True, exist_ok=True) + return viz_dir + + +def generate_filename(prefix: str = "cgc_viz") -> str: + """Generate a unique filename with timestamp.""" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + unique = uuid.uuid4().hex[:8] + return f"{prefix}_{timestamp}_{unique}.html" + + +def _json_for_inline_script(data: Any) -> str: + """Serialize to JSON safe to embed directly inside a from terminating the script. + """ + raw = json.dumps( + data, + ensure_ascii=False, + separators=(",", ":"), + default=str, + ) + # Mitigate XSS via breaking out of script context. + raw = raw.replace(" Dict[str, str]: + """Return color configuration based on node type.""" + colors = { + "Function": {"background": "#4caf50", "border": "#388e3c"}, # Green + "Class": {"background": "#ff9800", "border": "#f57c00"}, # Orange + "Module": {"background": "#9c27b0", "border": "#7b1fa2"}, # Purple + "File": {"background": "#2196f3", "border": "#1976d2"}, # Blue + "Repository": {"background": "#e91e63", "border": "#c2185b"}, # Pink + "Package": {"background": "#607d8b", "border": "#455a64"}, # Grey + "Variable": {"background": "#795548", "border": "#5d4037"}, # Brown + "Caller": {"background": "#00bcd4", "border": "#0097a7"}, # Cyan + "Callee": {"background": "#8bc34a", "border": "#689f38"}, # Light Green + "Target": {"background": "#f44336", "border": "#d32f2f"}, # Red + "Source": {"background": "#3f51b5", "border": "#303f9f"}, # Indigo + "Parent": {"background": "#ff5722", "border": "#e64a19"}, # Deep Orange + "Child": {"background": "#009688", "border": "#00796b"}, # Teal + "Override": {"background": "#673ab7", "border": "#512da8"}, # Deep Purple + "default": {"background": "#97c2fc", "border": "#2b7ce9"}, # Default blue + } + return colors.get(node_type, colors["default"]) + + +def generate_html_template( + nodes: List[Dict], + edges: List[Dict], + title: str, + layout_type: str = "force", + description: str = "" +) -> str: + """ + Generate standalone HTML with vis-network.js visualization. + + Args: + nodes: List of node dictionaries with id, label, group, title, color + edges: List of edge dictionaries with from, to, label, arrows + title: Title for the visualization + layout_type: "force" for force-directed, "hierarchical" for tree layouts + description: Optional description to show in the header + + Returns: + Complete HTML string + """ + # Configure layout options based on type + if layout_type == "hierarchical": + layout_options = """ + layout: { + hierarchical: { + enabled: true, + direction: 'UD', + sortMethod: 'directed', + levelSeparation: 100, + nodeSpacing: 150, + treeSpacing: 200, + blockShifting: true, + edgeMinimization: true, + parentCentralization: true + } + }, + physics: { + enabled: false + } + """ + elif layout_type == "hierarchical_lr": + layout_options = """ + layout: { + hierarchical: { + enabled: true, + direction: 'LR', + sortMethod: 'directed', + levelSeparation: 200, + nodeSpacing: 100, + treeSpacing: 200 + } + }, + physics: { + enabled: false + } + """ + else: # force-directed + layout_options = """ + layout: { + improvedLayout: true + }, + physics: { + enabled: true, + forceAtlas2Based: { + gravitationalConstant: -50, + centralGravity: 0.01, + springLength: 150, + springConstant: 0.08, + damping: 0.4 + }, + maxVelocity: 50, + solver: 'forceAtlas2Based', + timestep: 0.35, + stabilization: { + enabled: true, + iterations: 200, + updateInterval: 25 + } + } + """ + + # Escape user-provided content to prevent XSS + safe_title = escape_html(title) + safe_description = escape_html(description) + + # Escape tooltip HTML (vis-network treats title as HTML) + safe_nodes: List[Dict[str, Any]] = [] + for node in nodes: + node_copy = dict(node) + if "title" in node_copy: + node_copy["title"] = escape_html(node_copy.get("title", "")) + safe_nodes.append(node_copy) + safe_edges: List[Dict[str, Any]] = [dict(edge) for edge in edges] + + html_content = f""" + + + {safe_title} - CodeGraphContext + + + + + +
+
+ + {safe_title} +
+
+
+ Nodes: + {len(nodes)} +
+
+ Edges: + {len(edges)} +
+
+
+ {f'
{safe_description}
' if description else ''} + +
+ +
+
Legend
+
+
+ +
+ Drag to pan • Scroll to zoom • Click node to highlight +
+ + + + +""" + return html_content + + +def visualize_call_graph( + results: List[Dict], + function_name: str, + direction: Literal["outgoing", "incoming"] = "outgoing" +) -> Optional[str]: + """ + Visualize function call relationships (calls or callers). + + Args: + results: List of call results from CodeFinder + function_name: The central function name + direction: "outgoing" for calls, "incoming" for callers + + Returns: + Path to generated HTML file, or None if no results + """ + if not results: + console.print("[yellow]No results to visualize.[/yellow]") + return None + + nodes = [] + edges = [] + seen_nodes = set() + + # Add central function node + central_id = f"central_{function_name}" + central_color = get_node_color("Source" if direction == "outgoing" else "Target") + nodes.append({ + "id": central_id, + "label": function_name, + "group": "Source" if direction == "outgoing" else "Target", + "title": f"{'Caller' if direction == 'outgoing' else 'Called'}: {function_name}", + "color": central_color, + "size": 30, + "font": {"size": 16, "color": "#ffffff"} + }) + seen_nodes.add(central_id) + + for idx, result in enumerate(results): + if direction == "outgoing": + # calls: function_name -> called_function + func_name = result.get("called_function", f"unknown_{idx}") + file_path = result.get("called_file_path", "") + line_num = result.get("called_line_number", "") + is_dep = result.get("called_is_dependency", False) + else: + # callers: caller_function -> function_name + func_name = result.get("caller_function", f"unknown_{idx}") + file_path = result.get("caller_file_path", "") + line_num = result.get("caller_line_number", "") + is_dep = result.get("caller_is_dependency", False) + + node_id = f"node_{func_name}_{idx}" + node_type = "Callee" if direction == "outgoing" else "Caller" + if is_dep: + node_type = "Package" + + if node_id not in seen_nodes: + color = get_node_color(node_type) + nodes.append({ + "id": node_id, + "label": func_name, + "group": node_type, + "title": f"{func_name}\nFile: {file_path}\nLine: {line_num}", + "color": color + }) + seen_nodes.add(node_id) + + if direction == "outgoing": + edges.append({ + "from": central_id, + "to": node_id, + "label": "calls", + "arrows": "to" + }) + else: + edges.append({ + "from": node_id, + "to": central_id, + "label": "calls", + "arrows": "to" + }) + + title = f"{'Outgoing Calls' if direction == 'outgoing' else 'Incoming Callers'}: {function_name}" + description = f"Showing {len(results)} {'called functions' if direction == 'outgoing' else 'caller functions'}" + + html = generate_html_template(nodes, edges, title, layout_type="force", description=description) + return save_and_open_visualization(html, f"cgc_{'calls' if direction == 'outgoing' else 'callers'}") + + +def visualize_call_chain( + results: List[Dict], + from_func: str, + to_func: str +) -> Optional[str]: + """ + Visualize call chain between two functions. + + Args: + results: List of chain results, each containing function_chain + from_func: Starting function name + to_func: Target function name + + Returns: + Path to generated HTML file, or None if no results + """ + if not results: + console.print("[yellow]No call chain found to visualize.[/yellow]") + return None + + nodes = [] + edges = [] + seen_nodes = set() + + for chain_idx, chain in enumerate(results): + functions = chain.get("function_chain", []) + + for idx, func in enumerate(functions): + func_name = func.get("name", f"unknown_{idx}") + file_path = func.get("file_path", "") + line_num = func.get("line_number", "") + + node_id = f"chain{chain_idx}_{func_name}_{idx}" + + # Determine node type based on position + if idx == 0: + node_type = "Source" + elif idx == len(functions) - 1: + node_type = "Target" + else: + node_type = "Function" + + if node_id not in seen_nodes: + color = get_node_color(node_type) + nodes.append({ + "id": node_id, + "label": func_name, + "group": node_type, + "title": f"{func_name}\nFile: {file_path}\nLine: {line_num}", + "color": color, + "level": idx # For hierarchical layout + }) + seen_nodes.add(node_id) + + # Add edge to next function in chain + if idx < len(functions) - 1: + next_func = functions[idx + 1] + next_name = next_func.get("name", f"unknown_{idx+1}") + next_id = f"chain{chain_idx}_{next_name}_{idx+1}" + edges.append({ + "from": node_id, + "to": next_id, + "label": "→", + "arrows": "to" + }) + + title = f"Call Chain: {from_func} → {to_func}" + description = f"Found {len(results)} path(s)" + + html = generate_html_template(nodes, edges, title, layout_type="hierarchical", description=description) + return save_and_open_visualization(html, "cgc_chain") + + +def visualize_dependencies( + results: Dict, + module_name: str +) -> Optional[str]: + """ + Visualize module dependencies (imports and importers). + + Args: + results: Dict with 'importers' and 'imports' lists + module_name: The central module name + + Returns: + Path to generated HTML file, or None if no results + """ + importers = results.get("importers", []) + imports = results.get("imports", []) + + if not importers and not imports: + console.print("[yellow]No dependency information to visualize.[/yellow]") + return None + + nodes = [] + edges = [] + seen_nodes = set() + + # Central module node + central_id = f"central_{module_name}" + color = get_node_color("Module") + nodes.append({ + "id": central_id, + "label": module_name, + "group": "Module", + "title": f"Module: {module_name}", + "color": color, + "size": 30 + }) + seen_nodes.add(central_id) + + # Files that import this module + for idx, imp in enumerate(importers): + file_path = imp.get("importer_file_path", f"file_{idx}") + file_name = Path(file_path).name if file_path else f"file_{idx}" + node_id = f"importer_{idx}" + + if node_id not in seen_nodes: + color = get_node_color("File") + nodes.append({ + "id": node_id, + "label": file_name, + "group": "Importer", + "title": f"File: {file_path}\nLine: {imp.get('import_line_number', '')}", + "color": color + }) + seen_nodes.add(node_id) + + edges.append({ + "from": node_id, + "to": central_id, + "label": "imports", + "arrows": "to" + }) + + # Modules that this module imports + for idx, imp in enumerate(imports): + imported_module = imp.get("imported_module", f"module_{idx}") + alias = imp.get("import_alias", "") + node_id = f"imported_{idx}" + + if node_id not in seen_nodes: + color = get_node_color("Package") + nodes.append({ + "id": node_id, + "label": imported_module + (f" as {alias}" if alias else ""), + "group": "Imported", + "title": f"Module: {imported_module}", + "color": color + }) + seen_nodes.add(node_id) + + edges.append({ + "from": central_id, + "to": node_id, + "label": "imports", + "arrows": "to" + }) + + title = f"Dependencies: {module_name}" + description = f"{len(importers)} importer(s), {len(imports)} import(s)" + + html = generate_html_template(nodes, edges, title, layout_type="force", description=description) + return save_and_open_visualization(html, "cgc_deps") + + +def visualize_inheritance_tree( + results: Dict, + class_name: str +) -> Optional[str]: + """ + Visualize class inheritance hierarchy. + + Args: + results: Dict with 'parent_classes', 'child_classes', and 'methods' + class_name: The central class name + + Returns: + Path to generated HTML file, or None if no results + """ + parents = results.get("parent_classes", []) + children = results.get("child_classes", []) + methods = results.get("methods", []) + + if not parents and not children: + console.print("[yellow]No inheritance hierarchy to visualize.[/yellow]") + return None + + nodes = [] + edges = [] + seen_nodes = set() + + # Central class node + central_id = f"central_{class_name}" + color = get_node_color("Class") + method_list = ", ".join([m.get("method_name", "") for m in methods[:5]]) + if len(methods) > 5: + method_list += f"... (+{len(methods) - 5} more)" + + nodes.append({ + "id": central_id, + "label": class_name, + "group": "Class", + "title": f"Class: {class_name}\nMethods: {method_list or 'None'}", + "color": color, + "size": 30, + "level": 1 # Middle level + }) + seen_nodes.add(central_id) + + # Parent classes (above) + for idx, parent in enumerate(parents): + parent_name = parent.get("parent_class", f"Parent_{idx}") + file_path = parent.get("parent_file_path", "") + node_id = f"parent_{idx}" + + if node_id not in seen_nodes: + color = get_node_color("Parent") + nodes.append({ + "id": node_id, + "label": parent_name, + "group": "Parent", + "title": f"Parent: {parent_name}\nFile: {file_path}", + "color": color, + "level": 0 # Top level + }) + seen_nodes.add(node_id) + + edges.append({ + "from": central_id, + "to": node_id, + "label": "extends", + "arrows": "to" + }) + + # Child classes (below) + for idx, child in enumerate(children): + child_name = child.get("child_class", f"Child_{idx}") + file_path = child.get("child_file_path", "") + node_id = f"child_{idx}" + + if node_id not in seen_nodes: + color = get_node_color("Child") + nodes.append({ + "id": node_id, + "label": child_name, + "group": "Child", + "title": f"Child: {child_name}\nFile: {file_path}", + "color": color, + "level": 2 # Bottom level + }) + seen_nodes.add(node_id) + + edges.append({ + "from": node_id, + "to": central_id, + "label": "extends", + "arrows": "to" + }) + + title = f"Class Hierarchy: {class_name}" + description = f"{len(parents)} parent(s), {len(children)} child(ren), {len(methods)} method(s)" + + html = generate_html_template(nodes, edges, title, layout_type="hierarchical", description=description) + return save_and_open_visualization(html, "cgc_tree") + + +def visualize_overrides( + results: List[Dict], + function_name: str +) -> Optional[str]: + """ + Visualize function/method overrides across classes. + + Args: + results: List of override results with class_name and function info + function_name: The method name being overridden + + Returns: + Path to generated HTML file, or None if no results + """ + if not results: + console.print("[yellow]No overrides to visualize.[/yellow]") + return None + + nodes = [] + edges = [] + seen_nodes = set() + + # Central method name node + central_id = f"method_{function_name}" + color = get_node_color("Function") + nodes.append({ + "id": central_id, + "label": f"Method: {function_name}", + "group": "Method", + "title": f"Method: {function_name}\n{len(results)} implementation(s)", + "color": color, + "size": 30 + }) + seen_nodes.add(central_id) + + # Classes implementing this method + for idx, res in enumerate(results): + class_name = res.get("class_name", f"Class_{idx}") + file_path = res.get("class_file_path", "") + line_num = res.get("function_line_number", "") + node_id = f"class_{idx}" + + if node_id not in seen_nodes: + color = get_node_color("Override") + nodes.append({ + "id": node_id, + "label": class_name, + "group": "Class", + "title": f"Class: {class_name}\nFile: {file_path}\nLine: {line_num}", + "color": color + }) + seen_nodes.add(node_id) + + edges.append({ + "from": node_id, + "to": central_id, + "label": "implements", + "arrows": "to" + }) + + title = f"Overrides: {function_name}" + description = f"{len(results)} implementation(s) found" + + html = generate_html_template(nodes, edges, title, layout_type="force", description=description) + return save_and_open_visualization(html, "cgc_overrides") + + +def visualize_search_results( + results: List[Dict], + search_term: str, + search_type: str = "search" +) -> Optional[str]: + """ + Visualize search/find results as a cluster of nodes. + + Args: + results: List of search results with name, type, file_path, etc. + search_term: The search term used + search_type: Type of search (name, pattern, type) + + Returns: + Path to generated HTML file, or None if no results + """ + if not results: + console.print("[yellow]No search results to visualize.[/yellow]") + return None + + nodes = [] + edges = [] + seen_nodes = set() + + # Central search node + central_id = "search_center" + nodes.append({ + "id": central_id, + "label": f"Search: {search_term}", + "group": "Search", + "title": f"Search term: {search_term}\n{len(results)} result(s)", + "color": {"background": "#ff4081", "border": "#c51162"}, + "size": 35 + }) + seen_nodes.add(central_id) + + # Group results by type + for idx, res in enumerate(results): + name = res.get("name", f"result_{idx}") + node_type = res.get("type", "Unknown") + file_path = res.get("file_path", "") + line_num = res.get("line_number", "") + is_dep = res.get("is_dependency", False) + + node_id = f"result_{idx}" + + if node_id not in seen_nodes: + color = get_node_color(node_type if not is_dep else "Package") + nodes.append({ + "id": node_id, + "label": name, + "group": node_type, + "title": f"{node_type}: {name}\nFile: {file_path}\nLine: {line_num}", + "color": color + }) + seen_nodes.add(node_id) + + edges.append({ + "from": central_id, + "to": node_id, + "label": "matches", + "arrows": "to", + "dashes": True + }) + + title = f"Search Results: {search_term}" + description = f"Found {len(results)} match(es) for '{search_term}'" + + html = generate_html_template(nodes, edges, title, layout_type="force", description=description) + return save_and_open_visualization(html, f"cgc_find_{search_type}") + + +def _safe_json_dumps(obj: Any, indent: int = 2) -> str: + """Safely serialize object to JSON, handling non-serializable types.""" + def default_handler(o): + try: + return str(o) + except Exception: + return "" + + try: + return json.dumps(obj, indent=indent, default=default_handler) + except Exception: + return "{}" + + +def visualize_cypher_results( + records: List[Dict], + query: str +) -> Optional[str]: + """ + Visualize raw Cypher query results. + + Args: + records: List of records returned from Cypher query + query: The original Cypher query + + Returns: + Path to generated HTML file, or None if no results + """ + if not records: + console.print("[yellow]No query results to visualize.[/yellow]") + return None + + nodes = [] + edges = [] + seen_nodes = set() + + for record in records: + for key, value in record.items(): + if isinstance(value, dict): + # Likely a node + node_id = value.get("id", value.get("name", f"node_{len(seen_nodes)}")) + if str(node_id) not in seen_nodes: + labels = value.get("labels", [key]) + label = labels[0] if isinstance(labels, list) and labels else str(labels) + name = value.get("name", str(node_id)) + + color = get_node_color(label) + nodes.append({ + "id": str(node_id), + "label": str(name) if name else str(node_id), + "group": label, + "title": _safe_json_dumps(value), + "color": color + }) + seen_nodes.add(str(node_id)) + elif isinstance(value, list): + # Could be a path or list of nodes + for item in value: + if isinstance(item, dict): + node_id = item.get("id", item.get("name", f"node_{len(seen_nodes)}")) + if str(node_id) not in seen_nodes: + name = item.get("name", str(node_id)) + labels = item.get("labels", ["Node"]) + label = labels[0] if isinstance(labels, list) and labels else "Node" + + color = get_node_color(label) + nodes.append({ + "id": str(node_id), + "label": str(name) if name else str(node_id), + "group": label, + "title": _safe_json_dumps(item), + "color": color + }) + seen_nodes.add(str(node_id)) + + # NOTE: We intentionally do not infer edges when the Cypher query doesn't + # explicitly return relationships. Auto-linking sequential nodes can be + # misleading when the result set contains unrelated nodes. + + title = "Cypher Query Results" + # Truncate query for description + short_query = query[:50] + "..." if len(query) > 50 else query + description = f"Query: {short_query}" + + html = generate_html_template(nodes, edges, title, layout_type="force", description=description) + return save_and_open_visualization(html, "cgc_query") + + +def save_and_open_visualization(html_content: str, prefix: str = "cgc_viz") -> Optional[str]: + """ + Save HTML content to file and open in browser. + + Args: + html_content: The complete HTML string + prefix: Filename prefix + + Returns: + Path to the saved file, or None if saving failed + """ + viz_dir = get_visualization_dir() + filename = generate_filename(prefix) + filepath = viz_dir / filename + + try: + with open(filepath, "w", encoding="utf-8") as f: + f.write(html_content) + except (IOError, OSError) as e: + console.print(f"[red]Error saving visualization: {e}[/red]") + return None + + console.print(f"[green]✓ Visualization saved:[/green] {filepath}") + console.print("[dim]Opening in browser...[/dim]") + + # Open in default browser - use proper file URI format + try: + # Convert to proper file URI (works on Windows and Unix) + file_uri = filepath.as_uri() + webbrowser.open(file_uri) + except Exception as e: + console.print(f"[yellow]Could not open browser automatically: {e}[/yellow]") + console.print(f"[dim]Open this file manually: {filepath}[/dim]") + + return str(filepath) + + +def check_visual_flag(ctx: Any, local_visual: bool = False) -> bool: + """ + Check if visual mode is enabled (either globally or locally). + + Args: + ctx: Typer context object + local_visual: Local --visual flag value + + Returns: + True if visualization should be used + """ + global_visual = False + if ctx and hasattr(ctx, 'obj') and ctx.obj: + global_visual = ctx.obj.get("visual", False) + return local_visual or global_visual diff --git a/tests/test_visualization.py b/tests/test_visualization.py new file mode 100644 index 0000000..1af5ee8 --- /dev/null +++ b/tests/test_visualization.py @@ -0,0 +1,548 @@ +# tests/test_visualization.py +""" +Tests for the visualization module and --visual flag functionality. + +These tests verify that: +- The visualizer module generates correct HTML +- The --visual flag works at both global and command levels +- Different visualization types produce appropriate output +- Edge cases are handled gracefully +""" + +from pathlib import Path +from unittest.mock import patch, MagicMock +from typer.testing import CliRunner + +# Import the visualizer module +from codegraphcontext.cli.visualizer import ( + get_visualization_dir, + generate_filename, + get_node_color, + generate_html_template, + visualize_call_graph, + visualize_call_chain, + visualize_dependencies, + visualize_inheritance_tree, + visualize_overrides, + visualize_search_results, + visualize_cypher_results, + check_visual_flag, + escape_html, + _safe_json_dumps, +) + +# Import the CLI app for integration tests +from codegraphcontext.cli.main import app + + +runner = CliRunner() + + +class TestVisualizerUtilities: + """Tests for utility functions in the visualizer module.""" + + def test_get_visualization_dir_creates_directory(self): + """Test that get_visualization_dir creates the directory if it doesn't exist.""" + viz_dir = get_visualization_dir() + assert viz_dir.exists() + assert viz_dir.is_dir() + assert viz_dir == Path.home() / ".codegraphcontext" / "visualizations" + + def test_generate_filename_format(self): + """Test that generate_filename produces correctly formatted filenames.""" + filename = generate_filename("test_prefix") + assert filename.startswith("test_prefix_") + assert filename.endswith(".html") + # Should have timestamp format: prefix_YYYYMMDD_HHMMSS.html + parts = filename.replace(".html", "").split("_") + assert len(parts) >= 3 + + def test_get_node_color_known_types(self): + """Test that get_node_color returns correct colors for known types.""" + function_color = get_node_color("Function") + assert "background" in function_color + assert "border" in function_color + assert function_color["background"] == "#4caf50" # Green + + class_color = get_node_color("Class") + assert class_color["background"] == "#ff9800" # Orange + + def test_get_node_color_unknown_type(self): + """Test that get_node_color returns default color for unknown types.""" + unknown_color = get_node_color("UnknownType") + assert "background" in unknown_color + assert unknown_color["background"] == "#97c2fc" # Default blue + + def test_escape_html_basic(self): + """Test escape_html escapes special characters.""" + assert escape_html("" in html + + def test_html_has_proper_encoding(self): + """Test that HTML has proper encoding meta tag.""" + nodes = [{"id": "1", "label": "test", "group": "Test"}] + edges = [] + + html = generate_html_template(nodes, edges, "Encoding Test") + + assert 'charset="utf-8"' in html or "charset='utf-8'" in html + + def test_html_handles_special_characters(self): + """Test that HTML handles special characters in node labels.""" + nodes = [ + {"id": "1", "label": 'func', "group": "Function"}, + {"id": "2", "label": "process & handle", "group": "Function"}, + {"id": "3", "label": '"quoted"', "group": "Function"}, + ] + edges = [] + + html = generate_html_template(nodes, edges, "Special Chars Test") + + # Should not raise an error and should produce valid HTML + assert "" in html + # JSON encoding should handle special chars + assert "nodesData" in html + + def test_html_xss_protection_in_nodesdata_script_breakout(self): + """Test that in node labels cannot break out of inline script.""" + malicious = "" + nodes = [{"id": "1", "label": malicious, "group": "Function"}] + html = generate_html_template(nodes, [], "XSS Nodes Test") + # Exact breakout payload should never appear in generated HTML. + assert malicious not in html + # Ensure we actually applied the inline-script JSON escaping. + assert "<\\/script>" in html + + def test_html_xss_protection_in_title(self): + """Test that title is properly escaped to prevent XSS attacks.""" + nodes = [{"id": "1", "label": "test", "group": "Test"}] + edges = [] + + # Attempt XSS via title + malicious_title = '' + html = generate_html_template(nodes, edges, malicious_title) + + # The script tag should be escaped, not raw + assert '' not in html + assert '<script>' in html or 'script>' in html + + def test_html_xss_protection_in_description(self): + """Test that description is properly escaped to prevent XSS attacks.""" + nodes = [{"id": "1", "label": "test", "group": "Test"}] + edges = [] + + # Attempt XSS via description + malicious_desc = '' + html = generate_html_template(nodes, edges, "Safe Title", description=malicious_desc) + + # The malicious tag should be escaped + assert '' not in html + assert '<img' in html or 'src=x' not in html \ No newline at end of file