diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 4e500b424..bcba375d2 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -15,10 +15,15 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions-cool/check-user-permission@v2
+ if: github.triggering_actor != 'codegen-sh[bot]'
with:
require: write
username: ${{ github.triggering_actor }}
error-if-missing: true
+ # Skip permission check for codegen-sh[bot]
+ - name: Skip permission check for bot
+ if: github.triggering_actor == 'codegen-sh[bot]'
+ run: echo "Skipping permission check for codegen-sh[bot]"
unit-tests:
needs: access-check
diff --git a/codegen-on-oss/codegen_on_oss/analysis/README.md b/codegen-on-oss/codegen_on_oss/analysis/README.md
new file mode 100644
index 000000000..096e5421b
--- /dev/null
+++ b/codegen-on-oss/codegen_on_oss/analysis/README.md
@@ -0,0 +1,183 @@
+# Code Analysis Module with Error Context
+
+This module provides code analysis capabilities focused on error detection and detailed, contextual error information.
+
+## Overview
+
+The code analysis module consists of several components:
+
+1. **CodeAnalyzer**: The main class that integrates all analysis components and provides a unified interface.
+2. **ErrorContextAnalyzer**: A specialized class for detecting and analyzing errors in code.
+3. **CodeError**: A class representing an error in code with detailed context information.
+4. **API Endpoints**: FastAPI endpoints for accessing the analysis functionality.
+
+## Features
+
+### Code Structure Analysis
+
+- Analyze codebase structure and dependencies
+- Generate dependency graphs for files and symbols
+- Analyze import relationships and detect circular imports
+- Get detailed information about files, functions, classes, and symbols
+
+### Error Detection and Analysis
+
+- Detect syntax errors, type errors, parameter errors, and more
+- Analyze function parameters and return statements for errors
+- Detect undefined variables and unused imports
+- Find circular dependencies between symbols
+- Provide detailed context information for errors
+
+### API Endpoints
+
+- `/analyze_repo`: Analyze a repository and return various metrics
+- `/analyze_symbol`: Analyze a symbol and return detailed information
+- `/analyze_file`: Analyze a file and return detailed information
+- `/analyze_function`: Analyze a function and return detailed information
+- `/analyze_errors`: Analyze errors in a repository, file, or function
+
+## Error Types
+
+The module can detect the following types of errors:
+
+- **Syntax Errors**: Invalid syntax in code
+- **Type Errors**: Type mismatches in expressions
+- **Parameter Errors**: Incorrect function parameters
+- **Call Errors**: Incorrect function calls
+- **Undefined Variables**: Variables used without being defined
+- **Unused Imports**: Imports that are not used in the code
+- **Circular Imports**: Circular dependencies between files
+- **Circular Dependencies**: Circular dependencies between symbols
+- **Name Errors**: References to undefined names
+- **Import Errors**: Problems with import statements
+- **Attribute Errors**: References to undefined attributes
+
+## Error Severity Levels
+
+The module assigns severity levels to each error:
+
+- **Critical**: Errors that will definitely cause the code to crash or fail
+- **High**: Errors that are likely to cause problems in most execution paths
+- **Medium**: Errors that may cause problems in some execution paths
+- **Low**: Minor issues that are unlikely to cause problems but should be fixed
+- **Info**: Informational messages about potential improvements
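+
+For orientation, here is a minimal sketch of how these error types and severities might be modeled as enums. The canonical definitions live in `error_context.py`; the member names beyond `TYPE_ERROR` and the string values are assumptions based on the lists above:
+
+```python
+from enum import Enum
+
+
+class ErrorType(Enum):
+    SYNTAX_ERROR = "syntax_error"
+    TYPE_ERROR = "type_error"
+    PARAMETER_ERROR = "parameter_error"
+    UNDEFINED_VARIABLE = "undefined_variable"
+    CIRCULAR_IMPORT = "circular_import"
+    # ... one member per error type listed above
+
+
+class ErrorSeverity(Enum):
+    CRITICAL = "critical"
+    HIGH = "high"
+    MEDIUM = "medium"
+    LOW = "low"
+    INFO = "info"
+```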
+
+## Usage
+
+### Using the CodeAnalyzer
+
+```python
+from codegen import Codebase
+from codegen_on_oss.analysis.analysis import CodeAnalyzer
+
+# Create a codebase from a repository
+codebase = Codebase.from_repo("owner/repo")
+
+# Create an analyzer
+analyzer = CodeAnalyzer(codebase)
+
+# Analyze errors in the codebase
+errors = analyzer.analyze_errors()
+
+# Get detailed error context for a function
+function_errors = analyzer.get_function_error_context("function_name")
+
+# Get detailed error context for a file
+file_errors = analyzer.get_file_error_context("path/to/file.py")
+```
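+
+The structural analyses are exposed through the same class, for example:
+
+```python
+# Analyze imports and detect circular import cycles
+import_report = analyzer.analyze_imports()
+
+# Build a file-level dependency graph (a networkx DiGraph)
+graph = analyzer.get_dependency_graph()
+
+# Summarize nodes, edges, and directories in the codebase
+structure = analyzer.analyze_codebase_structure()
+```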
+
+### Using the API
+
+```bash
+# Analyze a repository
+curl -X POST "http://localhost:8000/analyze_repo" \
+ -H "Content-Type: application/json" \
+ -d '{"repo_url": "owner/repo"}'
+
+# Analyze errors in a function
+curl -X POST "http://localhost:8000/analyze_function" \
+ -H "Content-Type: application/json" \
+ -d '{"repo_url": "owner/repo", "function_name": "function_name"}'
+
+# Analyze errors in a file
+curl -X POST "http://localhost:8000/analyze_file" \
+ -H "Content-Type: application/json" \
+ -d '{"repo_url": "owner/repo", "file_path": "path/to/file.py"}'
+```
+
+## Error Context Example
+
+Here's an example of the error context information provided for a function:
+
+```json
+{
+ "function_name": "calculate_total",
+ "file_path": "app/utils.py",
+ "errors": [
+ {
+ "error_type": "parameter_error",
+ "message": "Function 'calculate_discount' called with 1 arguments but expects 2",
+ "line_number": 15,
+ "severity": "high",
+ "context_lines": {
+ "13": "def calculate_total(items):",
+ "14": " total = sum(item.price for item in items)",
+ "15": " discount = calculate_discount(total)",
+ "16": " return total - discount",
+ "17": ""
+ },
+ "suggested_fix": "Update call to provide 2 arguments: calculate_discount(total, discount_percent)"
+ }
+ ],
+ "callers": [
+ {"name": "process_order"}
+ ],
+ "callees": [
+ {"name": "calculate_discount"}
+ ],
+ "parameters": [
+ {
+ "name": "items",
+ "type": "List[Item]",
+ "default": null
+ }
+ ],
+ "return_info": {
+ "type": "float",
+ "statements": ["total - discount"]
+ }
+}
+```
+
+## Implementation Details
+
+### ErrorContextAnalyzer
+
+The `ErrorContextAnalyzer` class is responsible for detecting and analyzing errors in code. It uses various techniques to detect errors, including:
+
+- **AST Analysis**: Parsing the code into an abstract syntax tree to detect syntax errors and undefined variables
+- **Graph Analysis**: Building dependency graphs to detect circular imports and dependencies
+- **Pattern Matching**: Using regular expressions to detect potential type errors and other issues
+- **Static Analysis**: Analyzing function parameters, return statements, and variable usage
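+
+As an illustration of the AST-based approach, here is a minimal, scope-insensitive sketch of undefined-name detection using only the standard library. The real analyzer is considerably more thorough, and `find_undefined_names` is a hypothetical helper, not part of the module's API:
+
+```python
+import ast
+import builtins
+
+
+def find_undefined_names(source: str) -> list[tuple[str, int]]:
+    """Return (name, line) pairs for names that are read but never bound."""
+    tree = ast.parse(source)
+    bound = set(dir(builtins))
+
+    # Pass 1: collect every binding (defs, imports, arguments, assignments).
+    for node in ast.walk(tree):
+        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
+            bound.add(node.name)
+        elif isinstance(node, ast.arg):
+            bound.add(node.arg)
+        elif isinstance(node, ast.alias):
+            bound.add(node.asname or node.name.split(".")[0])
+        elif isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store):
+            bound.add(node.id)
+
+    # Pass 2: flag any name that is loaded but has no binding anywhere.
+    return [
+        (node.id, node.lineno)
+        for node in ast.walk(tree)
+        if isinstance(node, ast.Name)
+        and isinstance(node.ctx, ast.Load)
+        and node.id not in bound
+    ]
+```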
+
+### CodeError
+
+The `CodeError` class represents an error in code with detailed context information. It includes:
+
+- **Error Type**: The type of error (syntax, type, parameter, etc.)
+- **Message**: A descriptive message explaining the error
+- **Location**: The file path and line number where the error occurs
+- **Severity**: The severity of the error (critical, high, medium, low, info)
+- **Context Lines**: The lines of code surrounding the error
+- **Suggested Fix**: A suggested fix for the error
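+
+Based on how `CodeError` is constructed throughout this package, its shape is roughly the following. This is a sketch only; the defaults are assumptions, and the canonical definition lives in `error_context.py`:
+
+```python
+from dataclasses import dataclass, field
+
+from codegen_on_oss.analysis.error_context import ErrorSeverity, ErrorType
+
+
+@dataclass
+class CodeError:
+    error_type: ErrorType
+    message: str
+    file_path: str | None = None
+    line_number: int | None = None
+    severity: ErrorSeverity = ErrorSeverity.MEDIUM
+    symbol_name: str | None = None
+    context_lines: dict[str, str] = field(default_factory=dict)  # line no. -> source line
+    suggested_fix: str | None = None
+```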
+
+## Running the API Server
+
+To run the API server locally:
+
+```bash
+cd codegen-on-oss
+python -m codegen_on_oss.analysis.analysis
+```
+
+The server will be available at `http://localhost:8000`.
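+Since the server is a plain FastAPI app, interactive API documentation is also served at `http://localhost:8000/docs`.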
diff --git a/codegen-on-oss/codegen_on_oss/analysis/analysis.py b/codegen-on-oss/codegen_on_oss/analysis/analysis.py
index 9e956ec06..98ce44030 100644
--- a/codegen-on-oss/codegen_on_oss/analysis/analysis.py
+++ b/codegen-on-oss/codegen_on_oss/analysis/analysis.py
@@ -1,37 +1,100 @@
-from fastapi import FastAPI
-from pydantic import BaseModel
-from typing import Dict, List, Tuple, Any
+"""
+Unified Analysis Module for Codegen-on-OSS
+
+This module serves as a central hub for all code analysis functionality, integrating
+various specialized analysis components into a cohesive system.
+"""
+
+import ast
+import contextlib
+import math
+import os
+import re
+import subprocess
+import tempfile
+from datetime import UTC, datetime, timedelta
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import networkx as nx
+import requests
+import uvicorn
from codegen import Codebase
+from codegen.sdk.core.class_definition import Class
+from codegen.sdk.core.directory import Directory
+from codegen.sdk.core.expressions.binary_expression import BinaryExpression
+from codegen.sdk.core.expressions.comparison_expression import ComparisonExpression
+from codegen.sdk.core.expressions.unary_expression import UnaryExpression
+from codegen.sdk.core.external_module import ExternalModule
+from codegen.sdk.core.file import SourceFile
+from codegen.sdk.core.function import Function
+from codegen.sdk.core.import_resolution import Import
from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement
from codegen.sdk.core.statements.if_block_statement import IfBlockStatement
from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement
from codegen.sdk.core.statements.while_statement import WhileStatement
-from codegen.sdk.core.expressions.binary_expression import BinaryExpression
-from codegen.sdk.core.expressions.unary_expression import UnaryExpression
-from codegen.sdk.core.expressions.comparison_expression import ComparisonExpression
-import math
-import re
-import requests
-from datetime import datetime, timedelta
-import subprocess
-import os
-import tempfile
+from codegen.sdk.core.symbol import Symbol
+from codegen.sdk.enums import EdgeType, SymbolType
+from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
-import modal
-
-image = (
- modal.Image.debian_slim()
- .apt_install("git")
- .pip_install(
- "codegen", "fastapi", "uvicorn", "gitpython", "requests", "pydantic", "datetime"
- )
-)
+from pydantic import BaseModel
-app = modal.App(name="analytics-app", image=image)
+# Import from other analysis modules
+from codegen_on_oss.analysis.analysis_import import (
+ create_graph_from_codebase,
+ convert_all_calls_to_kwargs,
+ find_import_cycles,
+ find_problematic_import_loops,
+)
+from codegen_on_oss.analysis.codebase_analysis import (
+ get_class_summary,
+ get_codebase_summary,
+ get_file_summary,
+ get_function_summary,
+ get_symbol_summary,
+)
+from codegen_on_oss.analysis.codebase_context import CodebaseContext
+from codegen_on_oss.analysis.codegen_sdk_codebase import (
+ get_codegen_sdk_subdirectories,
+ get_codegen_sdk_codebase,
+)
+from codegen_on_oss.analysis.current_code_codebase import (
+ get_graphsitter_repo_path,
+ get_codegen_codebase_base_path,
+ get_current_code_codebase,
+ import_all_codegen_sdk_modules,
+ DocumentedObjects,
+ get_documented_objects,
+)
+from codegen_on_oss.analysis.document_functions import (
+ hop_through_imports,
+ get_extended_context,
+ run as document_functions_run,
+)
+from codegen_on_oss.analysis.error_context import CodeError, ErrorContextAnalyzer
+from codegen_on_oss.analysis.mdx_docs_generation import (
+ render_mdx_page_for_class,
+ render_mdx_page_title,
+ render_mdx_inheritence_section,
+ render_mdx_attributes_section,
+ render_mdx_methods_section,
+ render_mdx_for_attribute,
+ format_parameter_for_mdx,
+ format_parameters_for_mdx,
+ format_return_for_mdx,
+ render_mdx_for_method,
+ get_mdx_route_for_class,
+ format_type_string,
+ resolve_type_string,
+ format_builtin_type_string,
+ span_type_string_by_pipe,
+ parse_link,
+)
+from codegen_on_oss.analysis.module_dependencies import run as module_dependencies_run
+from codegen_on_oss.analysis.symbolattr import print_symbol_attribution
-fastapi_app = FastAPI()
+# Create FastAPI app
+app = FastAPI()
-fastapi_app.add_middleware(
+app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
@@ -39,393 +102,1111 @@
allow_headers=["*"],
)
+class CodeAnalyzer:
+ """
+ Central class for code analysis that integrates all analysis components.
-def get_monthly_commits(repo_path: str) -> Dict[str, int]:
+ This class serves as the main entry point for all code analysis functionality,
+ providing a unified interface to access various analysis capabilities.
"""
- Get the number of commits per month for the last 12 months.
+
+ def __init__(self, codebase: Codebase):
+ """
+ Initialize the CodeAnalyzer with a codebase.
+
+ Args:
+ codebase: The Codebase object to analyze
+ """
+ self.codebase = codebase
+ self._context = None
+ self._initialized = False
+ self._error_analyzer = None
+
+ def initialize(self) -> None:
+ """
+ Initialize the analyzer by setting up the context and other necessary components.
+ This is called automatically when needed but can be called explicitly for eager initialization.
+ """
+ if self._initialized:
+ return
+
+ # Initialize context if not already done
+ if self._context is None:
+ self._context = self._create_context()
+
+ self._initialized = True
+
+ def _create_context(self) -> CodebaseContext:
+ """
+ Create a CodebaseContext instance for the current codebase.
+
+ Returns:
+ A new CodebaseContext instance
+ """
+ # If the codebase already has a context, use it
+ if hasattr(self.codebase, "ctx") and self.codebase.ctx is not None:
+ return self.codebase.ctx
+
+ # Otherwise, create a new context from the codebase's configuration
+ from codegen.configs.models.codebase import CodebaseConfig
+ from codegen.sdk.codebase.config import ProjectConfig
+
+ # Create a project config from the codebase
+ project_config = ProjectConfig(
+ repo_operator=self.codebase.repo_operator,
+ programming_language=self.codebase.programming_language,
+ base_path=self.codebase.base_path,
+ )
+
+ # Create and return a new context
+ return CodebaseContext([project_config], config=CodebaseConfig())
+
+ @property
+ def context(self) -> CodebaseContext:
+ """
+ Get the CodebaseContext for the current codebase.
+
+ Returns:
+ A CodebaseContext object for the codebase
+ """
+ if not self._initialized:
+ self.initialize()
+
+ return self._context
+
+ @property
+ def error_analyzer(self) -> ErrorContextAnalyzer:
+ """
+ Get the ErrorContextAnalyzer for the current codebase.
+
+ Returns:
+ An ErrorContextAnalyzer object for the codebase
+ """
+ if self._error_analyzer is None:
+ self._error_analyzer = ErrorContextAnalyzer(self.codebase)
+
+ return self._error_analyzer
+
+ def get_codebase_summary(self) -> str:
+ """
+ Get a comprehensive summary of the codebase.
+
+ Returns:
+ A string containing summary information about the codebase
+ """
+ return get_codebase_summary(self.codebase)
+
+ def get_file_summary(self, file_path: str) -> str:
+ """
+ Get a summary of a specific file.
+
+ Args:
+ file_path: Path to the file to analyze
+
+ Returns:
+ A string containing summary information about the file
+ """
+ file = self.codebase.get_file(file_path)
+ if file is None:
+ return f"File not found: {file_path}"
+ return get_file_summary(file)
+
+ def get_class_summary(self, class_name: str) -> str:
+ """
+ Get a summary of a specific class.
+
+ Args:
+ class_name: Name of the class to analyze
+
+ Returns:
+ A string containing summary information about the class
+ """
+ for cls in self.codebase.classes:
+ if cls.name == class_name:
+ return get_class_summary(cls)
+ return f"Class not found: {class_name}"
+
+ def get_function_summary(self, function_name: str) -> str:
+ """
+ Get a summary of a specific function.
+
+ Args:
+ function_name: Name of the function to analyze
+
+ Returns:
+ A string containing summary information about the function
+ """
+ for func in self.codebase.functions:
+ if func.name == function_name:
+ return get_function_summary(func)
+ return f"Function not found: {function_name}"
+
+ def get_symbol_summary(self, symbol_name: str) -> str:
+ """
+ Get a summary of a specific symbol.
+
+ Args:
+ symbol_name: Name of the symbol to analyze
+
+ Returns:
+ A string containing summary information about the symbol
+ """
+ for symbol in self.codebase.symbols:
+ if symbol.name == symbol_name:
+ return get_symbol_summary(symbol)
+ return f"Symbol not found: {symbol_name}"
+
+ def find_symbol_by_name(self, symbol_name: str) -> Symbol | None:
+ """
+ Find a symbol by its name.
+
+ Args:
+ symbol_name: Name of the symbol to find
+
+ Returns:
+ The Symbol object if found, None otherwise
+ """
+ for symbol in self.codebase.symbols:
+ if symbol.name == symbol_name:
+ return symbol
+ return None
+
+ def find_file_by_path(self, file_path: str) -> SourceFile | None:
+ """
+ Find a file by its path.
+
+ Args:
+ file_path: Path to the file to find
+
+ Returns:
+ The SourceFile object if found, None otherwise
+ """
+ return self.codebase.get_file(file_path)
+
+ def find_class_by_name(self, class_name: str) -> Class | None:
+ """
+ Find a class by its name.
+
+ Args:
+ class_name: Name of the class to find
+
+ Returns:
+ The Class object if found, None otherwise
+ """
+ for cls in self.codebase.classes:
+ if cls.name == class_name:
+ return cls
+ return None
+
+ def find_function_by_name(self, function_name: str) -> Function | None:
+ """
+ Find a function by its name.
+
+ Args:
+ function_name: Name of the function to find
+
+ Returns:
+ The Function object if found, None otherwise
+ """
+ for func in self.codebase.functions:
+ if func.name == function_name:
+ return func
+ return None
+
+ def document_functions(self) -> None:
+ """
+ Generate documentation for functions in the codebase.
+ """
+ document_functions_run(self.codebase)
+
+ def analyze_imports(self) -> dict[str, Any]:
+ """
+ Analyze import relationships in the codebase.
+
+ Returns:
+ A dictionary containing import analysis results
+ """
+ graph = create_graph_from_codebase(self.codebase)
+ cycles = find_import_cycles(graph)
+ problematic_loops = find_problematic_import_loops(graph, cycles)
+
+ return {
+ "import_graph": graph,
+ "cycles": cycles,
+ "problematic_loops": problematic_loops,
+ }
+
+ def get_dependency_graph(self) -> nx.DiGraph:
+ """
+ Get a dependency graph for the codebase files.
+
+ Returns:
+ A directed graph representing file dependencies
+ """
+ G = nx.DiGraph()
+
+ # Add nodes for all files
+ for file in self.codebase.files:
+ G.add_node(file.name, type="file")
+
+ # Add edges for imports
+ for file in self.codebase.files:
+ for imp in file.imports:
+ if imp.imported_symbol and hasattr(imp.imported_symbol, "file"):
+ imported_file = imp.imported_symbol.file
+ if imported_file and imported_file.name != file.name:
+ G.add_edge(file.name, imported_file.name)
+
+ return G
+
+ def get_symbol_attribution(self, symbol_name: str) -> str:
+ """
+ Get attribution information for a symbol.
+
+ Args:
+ symbol_name: Name of the symbol to analyze
+
+ Returns:
+ A string containing attribution information
+ """
+ symbol = self.find_symbol_by_name(symbol_name)
+ if symbol is None:
+ return f"Symbol not found: {symbol_name}"
+
+ return print_symbol_attribution(symbol)
+
+ def get_context_for_symbol(self, symbol_name: str) -> dict[str, Any]:
+ """
+ Get context information for a symbol.
+
+ Args:
+ symbol_name: Name of the symbol to analyze
+
+ Returns:
+ A dictionary containing context information
+ """
+ symbol = self.find_symbol_by_name(symbol_name)
+ if symbol is None:
+ return {"error": f"Symbol not found: {symbol_name}"}
+
+ # Use the context to get more information about the symbol
+ ctx = self.context
+
+ # Get symbol node ID in the context graph
+ node_id = None
+ for n_id, node in enumerate(ctx.nodes):
+ if isinstance(node, Symbol) and node.name == symbol_name:
+ node_id = n_id
+ break
+
+ if node_id is None:
+ return {"error": f"Symbol not found in context: {symbol_name}"}
+
+ # Get predecessors (symbols that use this symbol)
+ predecessors = []
+ for pred in ctx.predecessors(node_id):
+ if isinstance(pred, Symbol):
+ predecessors.append({
+ "name": pred.name,
+ "type": pred.symbol_type.name
+ if hasattr(pred, "symbol_type")
+ else "Unknown",
+ })
+
+ # Get successors (symbols used by this symbol)
+ successors = []
+ for succ in ctx.successors(node_id):
+ if isinstance(succ, Symbol):
+ successors.append({
+ "name": succ.name,
+ "type": succ.symbol_type.name
+ if hasattr(succ, "symbol_type")
+ else "Unknown",
+ })
+
+ return {
+ "symbol": {
+ "name": symbol.name,
+ "type": symbol.symbol_type.name
+ if hasattr(symbol, "symbol_type")
+ else "Unknown",
+ "file": symbol.file.name if hasattr(symbol, "file") else "Unknown",
+ },
+ "predecessors": predecessors,
+ "successors": successors,
+ }
+
+ def get_file_dependencies(self, file_path: str) -> dict[str, Any]:
+ """
+ Get dependency information for a file using CodebaseContext.
+
+ Args:
+ file_path: Path to the file to analyze
+
+ Returns:
+ A dictionary containing dependency information
+ """
+ file = self.find_file_by_path(file_path)
+ if file is None:
+ return {"error": f"File not found: {file_path}"}
+
+ # Use the context to get more information about the file
+ ctx = self.context
+
+ # Get file node ID in the context graph
+ node_id = None
+ for n_id, node in enumerate(ctx.nodes):
+ if isinstance(node, SourceFile) and node.name == file.name:
+ node_id = n_id
+ break
+
+ if node_id is None:
+ return {"error": f"File not found in context: {file_path}"}
+
+ # Get files that import this file
+ importers = []
+ for pred in ctx.predecessors(node_id, edge_type=EdgeType.IMPORT):
+ if isinstance(pred, SourceFile):
+ importers.append(pred.name)
+
+ imported = []
+ for succ in ctx.successors(node_id, edge_type=EdgeType.IMPORT):
+ if isinstance(succ, SourceFile):
+ imported.append(succ.name)
+
+ return {"file": file.name, "importers": importers, "imported": imported}
+
+ def analyze_codebase_structure(self) -> dict[str, Any]:
+ """
+ Analyze the overall structure of the codebase using CodebaseContext.
+
+ Returns:
+ A dictionary containing structural analysis results
+ """
+ ctx = self.context
+
+ # Count nodes by type
+ node_types: dict[str, int] = {}
+ for node in ctx.nodes:
+ node_type = type(node).__name__
+ node_types[node_type] = node_types.get(node_type, 0) + 1
+
+ edge_types: dict[str, int] = {}
+ for _, _, edge in ctx.edges:
+ edge_type = edge.type.name
+ edge_types[edge_type] = edge_types.get(edge_type, 0) + 1
+
+ directories = {}
+ for path, directory in ctx.directories.items():
+ directories[str(path)] = {
+ "files": len([
+ item for item in directory.items if isinstance(item, SourceFile)
+ ]),
+ "subdirectories": len([
+ item for item in directory.items if isinstance(item, Directory)
+ ]),
+ }
+
+ return {
+ "node_types": node_types,
+ "edge_types": edge_types,
+ "directories": directories,
+ }
+
+ def get_symbol_dependencies(self, symbol_name: str) -> dict[str, list[str]]:
+ """
+ Get direct dependencies of a symbol.
+
+ Args:
+ symbol_name: Name of the symbol to analyze
+
+ Returns:
+ A dictionary mapping dependency types to lists of symbol names
+ """
+ symbol = self.find_symbol_by_name(symbol_name)
+ if symbol is None:
+ return {"error": [f"Symbol not found: {symbol_name}"]}
+
+ dependencies: dict[str, list[str]] = {
+ "imports": [],
+ "functions": [],
+ "classes": [],
+ "variables": [],
+ }
+
+ # Process dependencies based on symbol type
+ if hasattr(symbol, "dependencies"):
+ for dep in symbol.dependencies:
+ if isinstance(dep, Import):
+ if dep.imported_symbol:
+ dependencies["imports"].append(dep.imported_symbol.name)
+ elif isinstance(dep, Symbol):
+ if dep.symbol_type == SymbolType.Function:
+ dependencies["functions"].append(dep.name)
+ elif dep.symbol_type == SymbolType.Class:
+ dependencies["classes"].append(dep.name)
+ elif dep.symbol_type == SymbolType.GlobalVar:
+ dependencies["variables"].append(dep.name)
+
+ return dependencies
+
+ def analyze_errors(self) -> dict[str, list[dict[str, Any]]]:
+ """
+ Analyze the codebase for errors.
+
+ Returns:
+ A dictionary mapping file paths to lists of errors
+ """
+ return self.error_analyzer.analyze_codebase()
+
+ def get_function_error_context(self, function_name: str) -> dict[str, Any]:
+ """
+ Get detailed error context for a specific function.
+
+ Args:
+ function_name: The name of the function to analyze
+
+ Returns:
+ A dictionary with detailed error context
+ """
+ return self.error_analyzer.get_function_error_context(function_name)
+
+ def get_file_error_context(self, file_path: str) -> dict[str, Any]:
+ """
+ Get detailed error context for a specific file.
+
+ Args:
+ file_path: The path of the file to analyze
+
+ Returns:
+ A dictionary with detailed error context
+ """
+ return self.error_analyzer.get_file_error_context(file_path)
+
+ def get_error_context(self, error: CodeError) -> dict[str, Any]:
+ """
+ Get detailed context information for an error.
+
+ Args:
+ error: The error to get context for
+
+ Returns:
+ A dictionary with detailed context information
+ """
+ return self.error_analyzer.get_error_context(error)
+
+ def convert_args_to_kwargs(self) -> None:
+ """
+ Convert all function call arguments to keyword arguments.
+ """
+ convert_all_calls_to_kwargs(self.codebase)
+
+ def visualize_module_dependencies(self) -> None:
+ """
+ Visualize module dependencies in the codebase.
+ """
+ module_dependencies_run(self.codebase)
+
+ def generate_mdx_documentation(self, class_name: str) -> str:
+ """
+ Generate MDX documentation for a class.
+
+ Args:
+ class_name: Name of the class to document
+
+ Returns:
+ MDX documentation as a string
+ """
+ for cls in self.codebase.classes:
+ if cls.name == class_name:
+ return render_mdx_page_for_class(cls)
+ return f"Class not found: {class_name}"
+
+ def print_symbol_attribution(self) -> None:
+ """
+ Print attribution information for symbols in the codebase.
+ """
+ print_symbol_attribution(self.codebase)
+
+ def get_extended_symbol_context(self, symbol_name: str, degree: int = 2) -> Dict[str, List[str]]:
+ """
+ Get extended context (dependencies and usages) for a symbol.
+
+ Args:
+ symbol_name: Name of the symbol to analyze
+ degree: How many levels deep to collect dependencies and usages
+
+ Returns:
+ A dictionary containing dependencies and usages
+ """
+ symbol = self.find_symbol_by_name(symbol_name)
+ if symbol:
+ dependencies, usages = get_extended_context(symbol, degree)
+ return {
+ "dependencies": [dep.name for dep in dependencies],
+ "usages": [usage.name for usage in usages]
+ }
+ return {"dependencies": [], "usages": []}
+
+ def get_file_imports(self, file_path: str) -> List[str]:
+ """
+ Get all imports in a file.
+
+ Args:
+ file_path: Path to the file to analyze
+
+ Returns:
+ A list of import statements
+ """
+ file = self.find_file_by_path(file_path)
+ if file and hasattr(file, "imports"):
+ return [imp.source for imp in file.imports]
+ return []
+
+ def get_file_exports(self, file_path: str) -> List[str]:
+ """
+ Get all exports from a file.
+
+ Args:
+ file_path: Path to the file to analyze
+
+ Returns:
+ A list of exported symbol names
+ """
+ file = self.find_file_by_path(file_path)
+ if not file:
+ return []
+
+ exports = []
+ for symbol in self.codebase.symbols:
+ if hasattr(symbol, "file") and symbol.file == file:
+ exports.append(symbol.name)
+
+ return exports
+
+    def analyze_complexity(self, file_path: str | None = None) -> Dict[str, Any]:
+ """
+ Analyze code complexity metrics for the codebase or a specific file.
+
+ Args:
+ file_path: Optional path to a specific file to analyze
+
+ Returns:
+ A dictionary containing complexity metrics
+ """
+ files_to_analyze = []
+ if file_path:
+ file = self.find_file_by_path(file_path)
+ if file:
+ files_to_analyze = [file]
+ else:
+ return {"error": f"File not found: {file_path}"}
+ else:
+ files_to_analyze = self.codebase.files
+
+ # Calculate complexity metrics
+ results = {
+ "cyclomatic_complexity": {
+ "total": 0,
+ "average": 0,
+ "max": 0,
+ "max_file": "",
+ "max_function": "",
+ "by_file": {}
+ },
+ "halstead_complexity": {
+ "total": 0,
+ "average": 0,
+ "max": 0,
+ "max_file": "",
+ "by_file": {}
+ },
+ "maintainability_index": {
+ "total": 0,
+ "average": 0,
+ "min": 100,
+ "min_file": "",
+ "by_file": {}
+ },
+ "line_metrics": {
+ "total_loc": 0,
+ "total_lloc": 0,
+ "total_sloc": 0,
+ "total_comments": 0,
+ "comment_ratio": 0,
+ "by_file": {}
+ }
+ }
+
+ # Process each file
+ for file in files_to_analyze:
+ # Skip non-Python files
+ if not file.name.endswith(".py"):
+ continue
+
+ file_path = file.name
+ file_content = file.content
+
+ # Calculate cyclomatic complexity
+ cc_total = 0
+ cc_max = 0
+ cc_max_function = ""
+
+ # Count decision points (if, for, while, etc.)
+ for func in file.functions:
+ func_cc = 1 # Base complexity
+
+            # Count control structures, walking the whole subtree so nested ones are included
+            for node in ast.walk(func.ast_node):
+ if isinstance(node, (ast.If, ast.For, ast.While, ast.Try)):
+ func_cc += 1
+
+ # Count logical operators in conditions
+ if isinstance(node, ast.If) and isinstance(node.test, ast.BoolOp):
+ func_cc += len(node.test.values) - 1
+
+ cc_total += func_cc
+ if func_cc > cc_max:
+ cc_max = func_cc
+ cc_max_function = func.name
+
+ # Update cyclomatic complexity metrics
+ results["cyclomatic_complexity"]["by_file"][file_path] = {
+ "total": cc_total,
+ "average": cc_total / len(file.functions) if file.functions else 0,
+ "max": cc_max,
+ "max_function": cc_max_function
+ }
+
+ results["cyclomatic_complexity"]["total"] += cc_total
+ if cc_max > results["cyclomatic_complexity"]["max"]:
+ results["cyclomatic_complexity"]["max"] = cc_max
+ results["cyclomatic_complexity"]["max_file"] = file_path
+ results["cyclomatic_complexity"]["max_function"] = cc_max_function
+
+ # Calculate line metrics
+ loc = len(file_content.splitlines())
+ lloc = sum(1 for line in file_content.splitlines() if line.strip() and not line.strip().startswith("#"))
+ sloc = sum(1 for line in file_content.splitlines() if line.strip())
+ comments = sum(1 for line in file_content.splitlines() if line.strip().startswith("#"))
+
+ results["line_metrics"]["by_file"][file_path] = {
+ "loc": loc,
+ "lloc": lloc,
+ "sloc": sloc,
+ "comments": comments,
+ "comment_ratio": comments / loc if loc else 0
+ }
+
+ results["line_metrics"]["total_loc"] += loc
+ results["line_metrics"]["total_lloc"] += lloc
+ results["line_metrics"]["total_sloc"] += sloc
+ results["line_metrics"]["total_comments"] += comments
+
+ # Simple Halstead complexity approximation
+ operators = len(re.findall(r'[\+\-\*/=<>!&|^~]', file_content))
+ operands = len(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', file_content))
+
+ n1 = len(set(re.findall(r'[\+\-\*/=<>!&|^~]', file_content)))
+ n2 = len(set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', file_content)))
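+            # Note: these regexes are deliberately rough; multi-character operators
+            # (e.g. ==, //) are split into single characters and keywords are
+            # counted as operands, so the resulting metrics are approximations.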
+
+ N = operators + operands
+ n = n1 + n2
+
+ # Calculate Halstead metrics
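+            # volume = N * log2(n); difficulty = (n1 / 2) * (N2 / n2),
+            # where `operands` above is N2 (total operand occurrences)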
+ if n1 > 0 and n2 > 0:
+ volume = N * math.log2(n)
+ difficulty = (n1 / 2) * (operands / n2)
+ effort = volume * difficulty
+ else:
+ volume = 0
+ difficulty = 0
+ effort = 0
+
+ results["halstead_complexity"]["by_file"][file_path] = {
+ "volume": volume,
+ "difficulty": difficulty,
+ "effort": effort
+ }
+
+ results["halstead_complexity"]["total"] += effort
+ if effort > results["halstead_complexity"]["max"]:
+ results["halstead_complexity"]["max"] = effort
+ results["halstead_complexity"]["max_file"] = file_path
+
+ # Calculate maintainability index
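+            # Classic MI formula; the raw value is clamped to [0, 100] rather
+            # than rescaled by 100/171 as some variants do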
+ if lloc > 0:
+                # Guard against volume == 0, which would make math.log raise
+                mi = 171 - 5.2 * math.log(max(1, volume)) - 0.23 * cc_total - 16.2 * math.log(lloc)
+ mi = max(0, min(100, mi))
+ else:
+ mi = 100
+
+ results["maintainability_index"]["by_file"][file_path] = mi
+ results["maintainability_index"]["total"] += mi
+
+ if mi < results["maintainability_index"]["min"]:
+ results["maintainability_index"]["min"] = mi
+ results["maintainability_index"]["min_file"] = file_path
+
+ # Calculate averages
+ num_files = len(results["cyclomatic_complexity"]["by_file"])
+ if num_files > 0:
+ results["cyclomatic_complexity"]["average"] = results["cyclomatic_complexity"]["total"] / num_files
+ results["halstead_complexity"]["average"] = results["halstead_complexity"]["total"] / num_files
+ results["maintainability_index"]["average"] = results["maintainability_index"]["total"] / num_files
+
+ total_loc = results["line_metrics"]["total_loc"]
+ if total_loc > 0:
+ results["line_metrics"]["comment_ratio"] = results["line_metrics"]["total_comments"] / total_loc
+
+ return results
+
+ def find_central_files(self) -> List[Dict[str, Any]]:
+ """
+ Find the most central files in the codebase based on dependency analysis.
+
+ Returns:
+ A list of dictionaries containing file information and centrality metrics
+ """
+ G = self.get_dependency_graph()
+
+ # Calculate centrality metrics
+ degree_centrality = nx.degree_centrality(G)
+ betweenness_centrality = nx.betweenness_centrality(G)
+ closeness_centrality = nx.closeness_centrality(G)
+
+ # Combine metrics
+ centrality = {}
+ for node in G.nodes():
+ centrality[node] = {
+ "file": node,
+ "degree": degree_centrality.get(node, 0),
+ "betweenness": betweenness_centrality.get(node, 0),
+ "closeness": closeness_centrality.get(node, 0),
+ "combined": (
+ degree_centrality.get(node, 0) +
+ betweenness_centrality.get(node, 0) +
+ closeness_centrality.get(node, 0)
+ ) / 3
+ }
+
+ # Sort by combined centrality
+ sorted_centrality = sorted(
+ centrality.values(),
+ key=lambda x: x["combined"],
+ reverse=True
+ )
+
+ return sorted_centrality[:10] # Return top 10 most central files
+
+
+# Request models for API endpoints
+class RepoRequest(BaseModel):
+ """Request model for repository analysis."""
+
+ repo_url: str
+
+
+class SymbolRequest(BaseModel):
+ """Request model for symbol analysis."""
+
+ repo_url: str
+ symbol_name: str
+
+
+class FileRequest(BaseModel):
+ """Request model for file analysis."""
+
+ repo_url: str
+ file_path: str
+
+
+class FunctionRequest(BaseModel):
+ """Request model for function analysis."""
+
+ repo_url: str
+ function_name: str
+
+
+class ErrorRequest(BaseModel):
+ """Request model for error analysis."""
+
+ repo_url: str
+ file_path: str | None = None
+ function_name: str | None = None
+
+
+class ComplexityRequest(BaseModel):
+ """Request model for complexity analysis."""
+
+ repo_url: str
+ file_path: str | None = None
+
+
+class DocumentationRequest(BaseModel):
+ """Request model for documentation generation."""
+
+ repo_url: str
+ class_name: str | None = None
+
+
+# API endpoints
+@app.post("/analyze_repo")
+async def analyze_repo(request: RepoRequest) -> dict[str, Any]:
+ """
+ Analyze a repository and return various metrics.
Args:
- repo_path: Path to the git repository
+ request: The repository request containing the repo URL
Returns:
- Dictionary with month-year as key and number of commits as value
+ A dictionary of analysis results
+ """
+ repo_url = request.repo_url
+
+ try:
+ codebase = Codebase.from_repo(repo_url)
+ analyzer = CodeAnalyzer(codebase)
+
+ # Get import analysis
+ import_analysis = analyzer.analyze_imports()
+
+ # Get structure analysis
+ structure_analysis = analyzer.analyze_codebase_structure()
+
+ # Get error analysis
+ error_analysis = analyzer.analyze_errors()
+
+ # Combine all results
+ results = {
+ "repo_url": repo_url,
+ "num_files": len(codebase.files),
+ "num_functions": len(codebase.functions),
+ "num_classes": len(codebase.classes),
+ "import_analysis": import_analysis,
+ "structure_analysis": structure_analysis,
+ "error_analysis": error_analysis
+ }
+
+ return results
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=f"Error analyzing repository: {str(e)}") from e
+
+
+@app.post("/analyze_symbol")
+async def analyze_symbol(request: SymbolRequest) -> dict[str, Any]:
"""
- end_date = datetime.now()
- start_date = end_date - timedelta(days=365)
+ Analyze a symbol and return detailed information.
- date_format = "%Y-%m-%d"
- since_date = start_date.strftime(date_format)
- until_date = end_date.strftime(date_format)
- repo_path = "https://github.com/" + repo_path
+ Args:
+ request: The symbol request containing the repo URL and symbol name
+
+ Returns:
+ A dictionary of analysis results
+ """
+ repo_url = request.repo_url
+ symbol_name = request.symbol_name
try:
- original_dir = os.getcwd()
-
- with tempfile.TemporaryDirectory() as temp_dir:
- subprocess.run(["git", "clone", repo_path, temp_dir], check=True)
- os.chdir(temp_dir)
-
- cmd = [
- "git",
- "log",
- f"--since={since_date}",
- f"--until={until_date}",
- "--format=%aI",
- ]
-
- result = subprocess.run(cmd, capture_output=True, text=True, check=True)
- commit_dates = result.stdout.strip().split("\n")
-
- monthly_counts = {}
- current_date = start_date
- while current_date <= end_date:
- month_key = current_date.strftime("%Y-%m")
- monthly_counts[month_key] = 0
- current_date = (
- current_date.replace(day=1) + timedelta(days=32)
- ).replace(day=1)
-
- for date_str in commit_dates:
- if date_str: # Skip empty lines
- commit_date = datetime.fromisoformat(date_str.strip())
- month_key = commit_date.strftime("%Y-%m")
- if month_key in monthly_counts:
- monthly_counts[month_key] += 1
-
- os.chdir(original_dir)
- return dict(sorted(monthly_counts.items()))
-
- except subprocess.CalledProcessError as e:
- print(f"Error executing git command: {e}")
- return {}
+ codebase = Codebase.from_repo(repo_url)
+ analyzer = CodeAnalyzer(codebase)
+
+ # Get symbol context
+ symbol_context = analyzer.get_context_for_symbol(symbol_name)
+
+ # Get symbol dependencies
+ dependencies = analyzer.get_symbol_dependencies(symbol_name)
+
+ # Get symbol attribution
+ attribution = analyzer.get_symbol_attribution(symbol_name)
+
+ return {
+ "symbol_name": symbol_name,
+ "context": symbol_context,
+ "dependencies": dependencies,
+ "attribution": attribution,
+ }
except Exception as e:
- print(f"Error processing git commits: {e}")
- return {}
- finally:
- try:
- os.chdir(original_dir)
- except:
- pass
-
-
-def calculate_cyclomatic_complexity(function):
- def analyze_statement(statement):
- complexity = 0
-
- if isinstance(statement, IfBlockStatement):
- complexity += 1
- if hasattr(statement, "elif_statements"):
- complexity += len(statement.elif_statements)
-
- elif isinstance(statement, (ForLoopStatement, WhileStatement)):
- complexity += 1
-
- elif isinstance(statement, TryCatchStatement):
- complexity += len(getattr(statement, "except_blocks", []))
-
- if hasattr(statement, "condition") and isinstance(statement.condition, str):
- complexity += statement.condition.count(
- " and "
- ) + statement.condition.count(" or ")
-
- if hasattr(statement, "nested_code_blocks"):
- for block in statement.nested_code_blocks:
- complexity += analyze_block(block)
-
- return complexity
-
- def analyze_block(block):
- if not block or not hasattr(block, "statements"):
- return 0
- return sum(analyze_statement(stmt) for stmt in block.statements)
-
- return (
- 1 + analyze_block(function.code_block) if hasattr(function, "code_block") else 1
- )
-
-
-def cc_rank(complexity):
- if complexity < 0:
- raise ValueError("Complexity must be a non-negative value")
-
- ranks = [
- (1, 5, "A"),
- (6, 10, "B"),
- (11, 20, "C"),
- (21, 30, "D"),
- (31, 40, "E"),
- (41, float("inf"), "F"),
- ]
- for low, high, rank in ranks:
- if low <= complexity <= high:
- return rank
- return "F"
-
-
-def calculate_doi(cls):
- """Calculate the depth of inheritance for a given class."""
- return len(cls.superclasses)
-
-
-def get_operators_and_operands(function):
- operators = []
- operands = []
-
- for statement in function.code_block.statements:
- for call in statement.function_calls:
- operators.append(call.name)
- for arg in call.args:
- operands.append(arg.source)
-
- if hasattr(statement, "expressions"):
- for expr in statement.expressions:
- if isinstance(expr, BinaryExpression):
- operators.extend([op.source for op in expr.operators])
- operands.extend([elem.source for elem in expr.elements])
- elif isinstance(expr, UnaryExpression):
- operators.append(expr.ts_node.type)
- operands.append(expr.argument.source)
- elif isinstance(expr, ComparisonExpression):
- operators.extend([op.source for op in expr.operators])
- operands.extend([elem.source for elem in expr.elements])
-
- if hasattr(statement, "expression"):
- expr = statement.expression
- if isinstance(expr, BinaryExpression):
- operators.extend([op.source for op in expr.operators])
- operands.extend([elem.source for elem in expr.elements])
- elif isinstance(expr, UnaryExpression):
- operators.append(expr.ts_node.type)
- operands.append(expr.argument.source)
- elif isinstance(expr, ComparisonExpression):
- operators.extend([op.source for op in expr.operators])
- operands.extend([elem.source for elem in expr.elements])
-
- return operators, operands
-
-
-def calculate_halstead_volume(operators, operands):
- n1 = len(set(operators))
- n2 = len(set(operands))
-
- N1 = len(operators)
- N2 = len(operands)
-
- N = N1 + N2
- n = n1 + n2
-
- if n > 0:
- volume = N * math.log2(n)
- return volume, N1, N2, n1, n2
- return 0, N1, N2, n1, n2
-
-
-def count_lines(source: str):
- """Count different types of lines in source code."""
- if not source.strip():
- return 0, 0, 0, 0
-
- lines = [line.strip() for line in source.splitlines()]
- loc = len(lines)
- sloc = len([line for line in lines if line])
-
- in_multiline = False
- comments = 0
- code_lines = []
-
- i = 0
- while i < len(lines):
- line = lines[i]
- code_part = line
- if not in_multiline and "#" in line:
- comment_start = line.find("#")
- if not re.search(r'["\'].*#.*["\']', line[:comment_start]):
- code_part = line[:comment_start].strip()
- if line[comment_start:].strip():
- comments += 1
-
- if ('"""' in line or "'''" in line) and not (
- line.count('"""') % 2 == 0 or line.count("'''") % 2 == 0
- ):
- if in_multiline:
- in_multiline = False
- comments += 1
- else:
- in_multiline = True
- comments += 1
- if line.strip().startswith('"""') or line.strip().startswith("'''"):
- code_part = ""
- elif in_multiline:
- comments += 1
- code_part = ""
- elif line.strip().startswith("#"):
- comments += 1
- code_part = ""
+ raise HTTPException(
+ status_code=500, detail=f"Error analyzing symbol: {e!s}"
+ ) from e
+
+
+@app.post("/analyze_file")
+async def analyze_file(request: FileRequest) -> dict[str, Any]:
+ """
+ Analyze a file and return detailed information.
+
+ Args:
+ request: The file request containing the repo URL and file path
+
+ Returns:
+ A dictionary of analysis results
+ """
+ repo_url = request.repo_url
+ file_path = request.file_path
- if code_part.strip():
- code_lines.append(code_part)
+ try:
+ codebase = Codebase.from_repo(repo_url)
+ analyzer = CodeAnalyzer(codebase)
- i += 1
+ # Get file summary
+ file_summary = analyzer.get_file_summary(file_path)
- lloc = 0
- continued_line = False
- for line in code_lines:
- if continued_line:
- if not any(line.rstrip().endswith(c) for c in ("\\", ",", "{", "[", "(")):
- continued_line = False
- continue
+ # Get file dependencies
+ file_dependencies = analyzer.get_file_dependencies(file_path)
- lloc += len([stmt for stmt in line.split(";") if stmt.strip()])
+ # Get file error context
+ file_error_context = analyzer.get_file_error_context(file_path)
- if any(line.rstrip().endswith(c) for c in ("\\", ",", "{", "[", "(")):
- continued_line = True
+ return {
+ "file_path": file_path,
+ "summary": file_summary,
+ "dependencies": file_dependencies,
+ "error_context": file_error_context,
+ }
+ except Exception as e:
+ raise HTTPException(
+ status_code=500, detail=f"Error analyzing file: {e!s}"
+ ) from e
- return loc, lloc, sloc, comments
+@app.post("/analyze_function")
+async def analyze_function(request: FunctionRequest) -> dict[str, Any]:
+ """
+ Analyze a function and return detailed information.
-def calculate_maintainability_index(
- halstead_volume: float, cyclomatic_complexity: float, loc: int
-) -> int:
- """Calculate the normalized maintainability index for a given function."""
- if loc <= 0:
- return 100
+ Args:
+ request: The function request containing the repo URL and function name
+
+ Returns:
+ A dictionary of analysis results
+ """
+ repo_url = request.repo_url
+ function_name = request.function_name
try:
- raw_mi = (
- 171
- - 5.2 * math.log(max(1, halstead_volume))
- - 0.23 * cyclomatic_complexity
- - 16.2 * math.log(max(1, loc))
- )
- normalized_mi = max(0, min(100, raw_mi * 100 / 171))
- return int(normalized_mi)
- except (ValueError, TypeError):
- return 0
+ codebase = Codebase.from_repo(repo_url)
+ analyzer = CodeAnalyzer(codebase)
+ # Get function summary
+ function_summary = analyzer.get_function_summary(function_name)
-def get_maintainability_rank(mi_score: float) -> str:
- """Convert maintainability index score to a letter grade."""
- if mi_score >= 85:
- return "A"
- elif mi_score >= 65:
- return "B"
- elif mi_score >= 45:
- return "C"
- elif mi_score >= 25:
- return "D"
- else:
- return "F"
+ # Get function error context
+ function_error_context = analyzer.get_function_error_context(function_name)
+ return {
+ "function_name": function_name,
+ "summary": function_summary,
+ "error_context": function_error_context,
+ }
+ except Exception as e:
+ raise HTTPException(
+ status_code=500, detail=f"Error analyzing function: {e!s}"
+ ) from e
-def get_github_repo_description(repo_url):
- api_url = f"https://api.github.com/repos/{repo_url}"
- response = requests.get(api_url)
+@app.post("/analyze_errors")
+async def analyze_errors(request: ErrorRequest) -> dict[str, Any]:
+ """
+ Analyze errors in a repository, file, or function.
- if response.status_code == 200:
- repo_data = response.json()
- return repo_data.get("description", "No description available")
- else:
- return ""
+ Args:
+ request: The error request containing the repo URL and optional file path or function name
+ Returns:
+ A dictionary of error analysis results
+ """
+ repo_url = request.repo_url
+ file_path = request.file_path
+ function_name = request.function_name
-class RepoRequest(BaseModel):
- repo_url: str
+ try:
+ codebase = Codebase.from_repo(repo_url)
+ analyzer = CodeAnalyzer(codebase)
+
+ if function_name:
+ # Analyze errors in a specific function
+ return analyzer.get_function_error_context(function_name)
+ elif file_path:
+ # Analyze errors in a specific file
+ return analyzer.get_file_error_context(file_path)
+ else:
+ # Analyze errors in the entire codebase
+ return {"error_analysis": analyzer.analyze_errors()}
+ except Exception as e:
+ raise HTTPException(
+ status_code=500, detail=f"Error analyzing errors: {e!s}"
+ ) from e
-@fastapi_app.post("/analyze_repo")
-async def analyze_repo(request: RepoRequest) -> Dict[str, Any]:
- """Analyze a repository and return comprehensive metrics."""
+@app.post("/analyze_complexity")
+async def analyze_complexity(request: ComplexityRequest) -> dict[str, Any]:
+ """
+ Analyze code complexity metrics for a repository or specific file.
+
+ Args:
+ request: The complexity request containing the repo URL and optional file path
+
+ Returns:
+ A dictionary of complexity analysis results
+ """
repo_url = request.repo_url
- codebase = Codebase.from_repo(repo_url)
-
- num_files = len(codebase.files(extensions="*"))
- num_functions = len(codebase.functions)
- num_classes = len(codebase.classes)
-
- total_loc = total_lloc = total_sloc = total_comments = 0
- total_complexity = 0
- total_volume = 0
- total_mi = 0
- total_doi = 0
-
- monthly_commits = get_monthly_commits(repo_url)
- print(monthly_commits)
-
- for file in codebase.files:
- loc, lloc, sloc, comments = count_lines(file.source)
- total_loc += loc
- total_lloc += lloc
- total_sloc += sloc
- total_comments += comments
-
- callables = codebase.functions + [m for c in codebase.classes for m in c.methods]
-
- num_callables = 0
- for func in callables:
- if not hasattr(func, "code_block"):
- continue
-
- complexity = calculate_cyclomatic_complexity(func)
- operators, operands = get_operators_and_operands(func)
- volume, _, _, _, _ = calculate_halstead_volume(operators, operands)
- loc = len(func.code_block.source.splitlines())
- mi_score = calculate_maintainability_index(volume, complexity, loc)
-
- total_complexity += complexity
- total_volume += volume
- total_mi += mi_score
- num_callables += 1
-
- for cls in codebase.classes:
- doi = calculate_doi(cls)
- total_doi += doi
-
- desc = get_github_repo_description(repo_url)
-
- results = {
- "repo_url": repo_url,
- "line_metrics": {
- "total": {
- "loc": total_loc,
- "lloc": total_lloc,
- "sloc": total_sloc,
- "comments": total_comments,
- "comment_density": (total_comments / total_loc * 100)
- if total_loc > 0
- else 0,
- },
- },
- "cyclomatic_complexity": {
- "average": total_complexity if num_callables > 0 else 0,
- },
- "depth_of_inheritance": {
- "average": total_doi / len(codebase.classes) if codebase.classes else 0,
- },
- "halstead_metrics": {
- "total_volume": int(total_volume),
- "average_volume": int(total_volume / num_callables)
- if num_callables > 0
- else 0,
- },
- "maintainability_index": {
- "average": int(total_mi / num_callables) if num_callables > 0 else 0,
- },
- "description": desc,
- "num_files": num_files,
- "num_functions": num_functions,
- "num_classes": num_classes,
- "monthly_commits": monthly_commits,
- }
-
- return results
-
-
-@app.function(image=image)
-@modal.asgi_app()
-def fastapi_modal_app():
- return fastapi_app
+ file_path = request.file_path
+
+ try:
+ codebase = Codebase.from_repo(repo_url)
+ analyzer = CodeAnalyzer(codebase)
+
+ # Analyze complexity
+ complexity_results = analyzer.analyze_complexity(file_path)
+
+ return {
+ "repo_url": repo_url,
+ "file_path": file_path,
+ "complexity_analysis": complexity_results
+ }
+ except Exception as e:
+ raise HTTPException(
+ status_code=500, detail=f"Error analyzing complexity: {e!s}"
+ ) from e
+
+
+@app.post("/generate_documentation")
+async def generate_documentation(request: DocumentationRequest) -> dict[str, Any]:
+ """
+ Generate documentation for a class or the entire codebase.
+
+ Args:
+ request: The documentation request containing the repo URL and optional class name
+
+ Returns:
+ A dictionary containing the generated documentation
+ """
+ repo_url = request.repo_url
+ class_name = request.class_name
+
+ try:
+ codebase = Codebase.from_repo(repo_url)
+ analyzer = CodeAnalyzer(codebase)
+
+ if class_name:
+ # Generate documentation for a specific class
+ mdx_doc = analyzer.generate_mdx_documentation(class_name)
+ return {
+ "class_name": class_name,
+ "documentation": mdx_doc
+ }
+ else:
+ # Generate documentation for all functions
+ analyzer.document_functions()
+ return {"message": "Documentation generated for all functions"}
+ except Exception as e:
+ raise HTTPException(
+ status_code=500, detail=f"Error generating documentation: {e!s}"
+ ) from e
if __name__ == "__main__":
- app.deploy("analytics-app")
+ # Run the FastAPI app locally with uvicorn
+ uvicorn.run(app, host="127.0.0.1", port=8000)
diff --git a/codegen-on-oss/codegen_on_oss/analysis/codebase_context.py b/codegen-on-oss/codegen_on_oss/analysis/codebase_context.py
index 5c0fd47dd..c092356b7 100644
--- a/codegen-on-oss/codegen_on_oss/analysis/codebase_context.py
+++ b/codegen-on-oss/codegen_on_oss/analysis/codebase_context.py
@@ -121,7 +121,6 @@ class CodebaseContext:
dependency_manager: DependencyManager | None
language_engine: LanguageEngine | None
_computing = False
- _graph: PyDiGraph[Importable, Edge]
filepath_idx: dict[str, NodeId]
_ext_module_idx: dict[str, NodeId]
flags: Flags
@@ -143,8 +142,6 @@ def __init__(
from codegen.sdk.core.parser import Parser
self.progress = progress or StubProgress()
- self.__graph = PyDiGraph()
- self.__graph_ready = False
self.filepath_idx = {}
self._ext_module_idx = {}
self.generation = 0
diff --git a/codegen-on-oss/codegen_on_oss/analysis/enhanced_type_detection.py b/codegen-on-oss/codegen_on_oss/analysis/enhanced_type_detection.py
new file mode 100644
index 000000000..839b329b3
--- /dev/null
+++ b/codegen-on-oss/codegen_on_oss/analysis/enhanced_type_detection.py
@@ -0,0 +1,521 @@
+"""
+Enhanced Type Detection Module for Error Context Analysis
+
+This module provides more robust type detection and analysis capabilities
+for the ErrorContextAnalyzer. It uses AST analysis and type inference
+to detect potential type errors in code.
+"""
+
+import ast
+import inspect
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
+
+from codegen.sdk.core.function import Function
+from codegen_on_oss.analysis.error_context import CodeError, ErrorType, ErrorSeverity
+
+
+class TypeAnalyzer:
+ """
+ Analyzes code for type-related errors using AST analysis and type inference.
+ """
+
+ def __init__(self):
+ """Initialize the TypeAnalyzer."""
+ # Map of known Python types
+ self.python_types = {
+ 'str': str,
+ 'int': int,
+ 'float': float,
+ 'bool': bool,
+ 'list': list,
+ 'dict': dict,
+ 'tuple': tuple,
+ 'set': set,
+ 'None': type(None),
+ }
+
+ # Map of compatible binary operations
+ self.compatible_ops = {
+ ast.Add: {
+ str: [str],
+ int: [int, float],
+ float: [int, float],
+ list: [list],
+ tuple: [tuple],
+ },
+ ast.Sub: {
+ int: [int, float],
+ float: [int, float],
+ set: [set],
+ },
+ ast.Mult: {
+ int: [int, float, str, list, tuple],
+ float: [int, float],
+ str: [int],
+ list: [int],
+ tuple: [int],
+ },
+ ast.Div: {
+ int: [int, float],
+ float: [int, float],
+ },
+ # Add more operations as needed
+ }
+
+ def analyze_function(self, function: Function) -> List[CodeError]:
+ """
+ Analyze a function for type-related errors.
+
+ Args:
+ function: The function to analyze
+
+ Returns:
+ A list of type-related errors
+ """
+ errors = []
+
+ if not hasattr(function, "code_block") or not hasattr(function.code_block, "source"):
+ return errors
+
+ try:
+ # Parse the AST
+ tree = ast.parse(function.code_block.source)
+
+ # Track variable types based on assignments and annotations
+ variable_types = self._collect_variable_types(tree, function)
+
+ # Check for type mismatches
+ errors.extend(self._check_type_mismatches(tree, variable_types, function))
+
+ # Check for parameter type mismatches
+ errors.extend(self._check_parameter_types(tree, variable_types, function))
+
+ # Check for return type mismatches
+ errors.extend(self._check_return_types(tree, variable_types, function))
+
+ return errors
+ except SyntaxError:
+ # If we can't parse the AST, return no errors
+ return errors
+
+ def _collect_variable_types(self, tree: ast.AST, function: Function) -> Dict[str, Any]:
+ """
+ Collect variable types from assignments and annotations.
+
+ Args:
+ tree: The AST to analyze
+ function: The function being analyzed
+
+ Returns:
+ A dictionary mapping variable names to their types
+ """
+ variable_types = {}
+
+ # Add function parameters with type annotations
+ if hasattr(function, "parameters"):
+ for param in function.parameters:
+ if hasattr(param, "type_annotation") and param.type_annotation:
+ variable_types[param.name] = self._parse_type_annotation(param.type_annotation)
+
+ # First pass: collect type information from the AST
+ for node in ast.walk(tree):
+ if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
+ # Handle type annotations
+ variable_types[node.target.id] = self._get_type_from_annotation(node.annotation)
+ elif isinstance(node, ast.Assign):
+ # Infer types from assignments where possible
+ for target in node.targets:
+ if isinstance(target, ast.Name):
+ inferred_type = self._infer_type_from_value(node.value)
+ if inferred_type:
+ variable_types[target.id] = inferred_type
+
+ return variable_types
+
+ def _check_type_mismatches(self, tree: ast.AST, variable_types: Dict[str, Any], function: Function) -> List[CodeError]:
+ """
+ Check for type mismatches in binary operations.
+
+ Args:
+ tree: The AST to analyze
+ variable_types: Dictionary mapping variable names to their types
+ function: The function being analyzed
+
+ Returns:
+ A list of type-related errors
+ """
+ errors = []
+
+ for node in ast.walk(tree):
+ if isinstance(node, ast.BinOp):
+ # Check binary operations for type mismatches
+ left_type = self._get_expression_type(node.left, variable_types)
+ right_type = self._get_expression_type(node.right, variable_types)
+
+ if left_type and right_type and not self._are_types_compatible(left_type, right_type, node.op):
+ # Found potential type error
+ line_number = node.lineno
+ errors.append(CodeError(
+ error_type=ErrorType.TYPE_ERROR,
+ message=f"Potential type mismatch: {self._type_name(left_type)} {type(node.op).__name__} {self._type_name(right_type)}",
+ file_path=function.file.name if hasattr(function, "file") else None,
+ line_number=line_number,
+ severity=ErrorSeverity.HIGH,
+ symbol_name=function.name,
+ context_lines=self._get_context_lines(function, line_number),
+ suggested_fix=f"Ensure operands are of compatible types for {type(node.op).__name__} operation"
+ ))
+
+ return errors
+
+ def _check_parameter_types(self, tree: ast.AST, variable_types: Dict[str, Any], function: Function) -> List[CodeError]:
+ """
+ Check for parameter type mismatches in function calls.
+
+ Args:
+ tree: The AST to analyze
+ variable_types: Dictionary mapping variable names to their types
+ function: The function being analyzed
+
+ Returns:
+ A list of parameter-related errors
+ """
+ errors = []
+
+ # Get function calls
+ for node in ast.walk(tree):
+ if isinstance(node, ast.Call):
+ # Check if we're calling a function we know about
+ if isinstance(node.func, ast.Name) and node.func.id in variable_types:
+ # This is a simplification - in a real implementation, we would need to
+ # track function signatures and parameter types
+ pass
+
+ return errors
+
+ def _check_return_types(self, tree: ast.AST, variable_types: Dict[str, Any], function: Function) -> List[CodeError]:
+ """
+ Check for return type mismatches.
+
+ Args:
+ tree: The AST to analyze
+ variable_types: Dictionary mapping variable names to their types
+ function: The function being analyzed
+
+ Returns:
+ A list of return-related errors
+ """
+ errors = []
+
+ # Get the declared return type
+ declared_return_type = None
+ if hasattr(function, "return_type") and function.return_type:
+ declared_return_type = self._parse_type_annotation(function.return_type)
+
+ if not declared_return_type:
+ return errors
+
+ # Check return statements
+ for node in ast.walk(tree):
+ if isinstance(node, ast.Return) and node.value:
+ returned_type = self._get_expression_type(node.value, variable_types)
+
+ if returned_type and not self._is_return_type_compatible(returned_type, declared_return_type):
+ line_number = node.lineno
+ errors.append(CodeError(
+ error_type=ErrorType.TYPE_ERROR,
+ message=f"Return type mismatch: returning {self._type_name(returned_type)} but function declares {self._type_name(declared_return_type)}",
+ file_path=function.file.name if hasattr(function, "file") else None,
+ line_number=line_number,
+ severity=ErrorSeverity.HIGH,
+ symbol_name=function.name,
+ context_lines=self._get_context_lines(function, line_number),
+ suggested_fix=f"Ensure the return value matches the declared return type {self._type_name(declared_return_type)}"
+ ))
+
+ return errors
+
+ def _get_expression_type(self, node: ast.AST, variable_types: Dict[str, Any]) -> Optional[Any]:
+ """
+ Get the type of an expression.
+
+ Args:
+ node: The AST node representing the expression
+ variable_types: Dictionary mapping variable names to their types
+
+ Returns:
+ The type of the expression, or None if it cannot be determined
+ """
+ if isinstance(node, ast.Name):
+ # Variable reference
+ return variable_types.get(node.id)
+ elif isinstance(node, ast.Constant):
+ # Literal value
+ return type(node.value)
+ elif isinstance(node, ast.List):
+ # List literal
+ return list
+ elif isinstance(node, ast.Dict):
+ # Dict literal
+ return dict
+ elif isinstance(node, ast.Tuple):
+ # Tuple literal
+ return tuple
+ elif isinstance(node, ast.Set):
+ # Set literal
+ return set
+ elif isinstance(node, ast.BinOp):
+ # Binary operation
+ left_type = self._get_expression_type(node.left, variable_types)
+ right_type = self._get_expression_type(node.right, variable_types)
+
+ # Determine result type based on operation and operand types
+ # This is a simplification - in a real implementation, we would need more sophisticated type inference
+ if isinstance(node.op, ast.Add):
+ if left_type == str or right_type == str:
+ return str
+ elif left_type in (int, float) and right_type in (int, float):
+ return float if float in (left_type, right_type) else int
+ elif left_type == list and right_type == list:
+ return list
+ elif left_type == tuple and right_type == tuple:
+ return tuple
+
+ # Add more operation type inference as needed
+
+ # For other expression types, we can't determine the type
+ return None
+
+ def _are_types_compatible(self, left_type: Any, right_type: Any, op: ast.operator) -> bool:
+ """
+ Check if two types are compatible for a binary operation.
+
+ Args:
+ left_type: The type of the left operand
+ right_type: The type of the right operand
+ op: The binary operation
+
+ Returns:
+ True if the types are compatible, False otherwise
+ """
+ op_type = type(op)
+
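+        # compatible_ops is expected to map operator type -> left operand
+        # type -> set of compatible right operand types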
+ if op_type in self.compatible_ops and left_type in self.compatible_ops[op_type]:
+ return right_type in self.compatible_ops[op_type][left_type]
+
+ return False
+
+ def _is_return_type_compatible(self, actual_type: Any, declared_type: Any) -> bool:
+ """
+ Check if a return type is compatible with the declared return type.
+
+ Args:
+ actual_type: The actual return type
+ declared_type: The declared return type
+
+ Returns:
+ True if the types are compatible, False otherwise
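+
+        Example:
+            _is_return_type_compatible(int, float) is True (an int return
+            satisfies a declared float), and _is_return_type_compatible(str,
+            (int, str)) is True (a Union declared as a tuple of types).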
+ """
+ # This is a simplification - in a real implementation, we would need more sophisticated type compatibility checking
+ if actual_type == declared_type:
+ return True
+
+ # Handle numeric types
+ if declared_type == float and actual_type == int:
+ return True
+
+ # Handle Union types (simplified)
+ if isinstance(declared_type, tuple):
+ return actual_type in declared_type
+
+ return False
+
+ def _get_type_from_annotation(self, annotation: ast.AST) -> Optional[Any]:
+ """
+ Get a type from an annotation AST node.
+
+ Args:
+ annotation: The AST node representing the annotation
+
+ Returns:
+ The type, or None if it cannot be determined
+ """
+ if isinstance(annotation, ast.Name):
+ # Simple type name
+ return self.python_types.get(annotation.id)
+ elif isinstance(annotation, ast.Subscript):
+ # Generic type (e.g., List[int])
+ if isinstance(annotation.value, ast.Name):
+ if annotation.value.id == 'List':
+ return list
+ elif annotation.value.id == 'Dict':
+ return dict
+ elif annotation.value.id == 'Tuple':
+ return tuple
+ elif annotation.value.id == 'Set':
+ return set
+ elif annotation.value.id == 'Optional':
+ # For Optional[T], we return the inner type
+ return self._get_type_from_annotation(annotation.slice)
+ elif annotation.value.id == 'Union':
+ # For Union[T1, T2, ...], we return a tuple of types
+ if isinstance(annotation.slice, ast.Tuple):
+ types = [self._get_type_from_annotation(elt) for elt in annotation.slice.elts]
+ return tuple(t for t in types if t is not None)
+
+ return None
+
+ def _parse_type_annotation(self, type_annotation: str) -> Optional[Any]:
+ """
+ Parse a type annotation string.
+
+ Args:
+ type_annotation: The type annotation string
+
+ Returns:
+ The type, or None if it cannot be parsed
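+
+        Example:
+            _parse_type_annotation("Optional[int]") returns int, and
+            _parse_type_annotation("Union[int, str]") returns (int, str).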
+ """
+ # This is a simplification - in a real implementation, we would need more sophisticated parsing
+ if type_annotation == 'str':
+ return str
+ elif type_annotation == 'int':
+ return int
+ elif type_annotation == 'float':
+ return float
+ elif type_annotation == 'bool':
+ return bool
+ elif type_annotation == 'list' or type_annotation.startswith('List['):
+ return list
+ elif type_annotation == 'dict' or type_annotation.startswith('Dict['):
+ return dict
+ elif type_annotation == 'tuple' or type_annotation.startswith('Tuple['):
+ return tuple
+ elif type_annotation == 'set' or type_annotation.startswith('Set['):
+ return set
+ elif type_annotation == 'None':
+ return type(None)
+ elif type_annotation.startswith('Optional['):
+ # Extract the inner type
+ inner_type = type_annotation[9:-1]
+ return self._parse_type_annotation(inner_type)
+ elif type_annotation.startswith('Union['):
+ # Extract the union types
+ union_types = type_annotation[6:-1].split(', ')
+ types = [self._parse_type_annotation(t) for t in union_types]
+ return tuple(t for t in types if t is not None)
+
+ return None
+
+ def _infer_type_from_value(self, node: ast.AST) -> Optional[Any]:
+ """
+ Infer the type of a value.
+
+ Args:
+ node: The AST node representing the value
+
+ Returns:
+ The inferred type, or None if it cannot be determined
+ """
+ if isinstance(node, ast.Constant):
+ return type(node.value)
+ elif isinstance(node, ast.List):
+ return list
+ elif isinstance(node, ast.Dict):
+ return dict
+ elif isinstance(node, ast.Tuple):
+ return tuple
+ elif isinstance(node, ast.Set):
+ return set
+ elif isinstance(node, ast.Call):
+ if isinstance(node.func, ast.Name):
+ # Function call - try to determine the return type
+ if node.func.id in self.python_types:
+ return self.python_types[node.func.id]
+
+ return None
+
+ def _get_context_lines(self, function: Function, line_number: int, context_size: int = 2) -> Dict[int, str]:
+ """
+ Get context lines around a specific line in a function.
+
+ Args:
+ function: The function containing the line
+ line_number: The line number to get context for
+ context_size: Number of lines before and after to include
+
+ Returns:
+ Dictionary mapping line numbers to line content
+ """
+ if not hasattr(function, "code_block") or not hasattr(function.code_block, "source"):
+ return {}
+
+ lines = function.code_block.source.splitlines()
+
+        # line_number comes from AST nodes parsed on the function's code
+        # block source, so it is already 1-based relative to the block
+        start_line = max(0, line_number - context_size - 1)
+        end_line = min(len(lines), line_number + context_size)
+
+        # Map block-relative indices back to absolute file line numbers
+ if hasattr(function, "line_number"):
+ return {i + function.line_number: lines[i] for i in range(start_line, end_line)}
+ else:
+ return {i + 1: lines[i] for i in range(start_line, end_line)}
+
+ def _type_name(self, type_obj: Any) -> str:
+ """
+ Get a human-readable name for a type.
+
+ Args:
+ type_obj: The type object
+
+ Returns:
+ A string representation of the type
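+
+        Example:
+            _type_name(int) returns "int" and _type_name((int, str))
+            returns "Union[int, str]".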
+ """
+ if type_obj == str:
+ return "str"
+ elif type_obj == int:
+ return "int"
+ elif type_obj == float:
+ return "float"
+ elif type_obj == bool:
+ return "bool"
+ elif type_obj == list:
+ return "list"
+ elif type_obj == dict:
+ return "dict"
+ elif type_obj == tuple:
+ return "tuple"
+ elif type_obj == set:
+ return "set"
+ elif type_obj == type(None):
+ return "None"
+ elif isinstance(type_obj, tuple):
+ # Union type
+ return f"Union[{', '.join(self._type_name(t) for t in type_obj)}]"
+
+ return str(type_obj)
+
+
+# Example usage
+def analyze_function_types(function: Function) -> List[CodeError]:
+ """
+ Analyze a function for type-related errors.
+
+ Args:
+ function: The function to analyze
+
+ Returns:
+ A list of type-related errors
+ """
+ analyzer = TypeAnalyzer()
+ return analyzer.analyze_function(function)
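+
+
+# A minimal, standalone sketch of the same literal-based inference that
+# _infer_type_from_value performs, using only the stdlib ast module. The
+# helper name `infer_literal_type` is illustrative and not part of the
+# analyzer's API.
+def infer_literal_type(expr: str) -> Optional[type]:
+    """Infer the Python type of a literal expression, e.g. "42" -> int."""
+    node = ast.parse(expr, mode="eval").body
+    if isinstance(node, ast.Constant):
+        return type(node.value)
+    literal_types = {ast.List: list, ast.Dict: dict, ast.Tuple: tuple, ast.Set: set}
+    return literal_types.get(type(node))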
+
diff --git a/codegen-on-oss/codegen_on_oss/analysis/error_context.py b/codegen-on-oss/codegen_on_oss/analysis/error_context.py
new file mode 100644
index 000000000..d74414287
--- /dev/null
+++ b/codegen-on-oss/codegen_on_oss/analysis/error_context.py
@@ -0,0 +1,818 @@
+"""
+Error Context Module for Codegen-on-OSS
+
+This module provides robust and dynamic error context analysis for code files and functions.
+It helps identify and contextualize errors in code, providing detailed information about
+the error location, type, and potential fixes.
+"""
+
+import ast
+import re
+from typing import Any, Dict, List, Optional
+
+import networkx as nx
+from codegen.sdk.core.codebase import Codebase
+from codegen.sdk.core.file import SourceFile
+from codegen.sdk.core.function import Function
+from codegen.sdk.core.symbol import Symbol
+
+# Error types
+class ErrorType:
+ SYNTAX_ERROR = "syntax_error"
+ TYPE_ERROR = "type_error"
+ NAME_ERROR = "name_error"
+ IMPORT_ERROR = "import_error"
+ ATTRIBUTE_ERROR = "attribute_error"
+ PARAMETER_ERROR = "parameter_error"
+ CALL_ERROR = "call_error"
+ UNDEFINED_VARIABLE = "undefined_variable"
+ UNUSED_IMPORT = "unused_import"
+ UNUSED_VARIABLE = "unused_variable"
+ CIRCULAR_IMPORT = "circular_import"
+ CIRCULAR_DEPENDENCY = "circular_dependency"
+
+
+class ErrorSeverity:
+ CRITICAL = "critical"
+ HIGH = "high"
+ MEDIUM = "medium"
+ LOW = "low"
+ INFO = "info"
+
+
+class CodeError:
+ """Represents an error in code with context."""
+
+ def __init__(
+ self,
+ error_type: str,
+ message: str,
+ file_path: Optional[str] = None,
+ line_number: Optional[int] = None,
+ column: Optional[int] = None,
+ severity: str = ErrorSeverity.MEDIUM,
+ symbol_name: Optional[str] = None,
+ context_lines: Optional[Dict[int, str]] = None,
+ suggested_fix: Optional[str] = None,
+ ):
+ self.error_type = error_type
+ self.message = message
+ self.file_path = file_path
+ self.line_number = line_number
+ self.column = column
+ self.severity = severity
+ self.symbol_name = symbol_name
+ self.context_lines = context_lines or {}
+ self.suggested_fix = suggested_fix
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Convert the error to a dictionary representation."""
+ return {
+ "error_type": self.error_type,
+ "message": self.message,
+ "file_path": self.file_path,
+ "line_number": self.line_number,
+ "column": self.column,
+ "severity": self.severity,
+ "symbol_name": self.symbol_name,
+ "context_lines": self.context_lines,
+ "suggested_fix": self.suggested_fix,
+ }
+
+ def __str__(self) -> str:
+ """String representation of the error."""
+ location = f"{self.file_path}:{self.line_number}" if self.file_path and self.line_number else "Unknown location"
+ return f"{self.error_type.upper()} ({self.severity}): {self.message} at {location}"
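+
+
+# Illustrative usage:
+#
+#     err = CodeError(
+#         error_type=ErrorType.TYPE_ERROR,
+#         message="Cannot add str and int",
+#         file_path="app.py",
+#         line_number=12,
+#         severity=ErrorSeverity.HIGH,
+#     )
+#     str(err)  # 'TYPE_ERROR (high): Cannot add str and int at app.py:12'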
+
+
+class ErrorContextAnalyzer:
+ """
+ Analyzes code for errors and provides rich context information.
+
+ This class is responsible for detecting various types of errors in code
+ and providing detailed context information to help understand and fix them.
+ """
+
+ def __init__(self, codebase: Codebase):
+ """
+ Initialize the ErrorContextAnalyzer with a codebase.
+
+ Args:
+ codebase: The Codebase object to analyze
+ """
+ self.codebase = codebase
+ self._call_graph = None
+ self._dependency_graph = None
+ self._import_graph = None
+
+ def get_context_lines(self, file_path: str, line_number: int, context_size: int = 3) -> Dict[int, str]:
+ """
+ Get context lines around a specific line in a file.
+
+ Args:
+ file_path: Path to the file
+ line_number: The line number to get context for
+ context_size: Number of lines before and after to include
+
+ Returns:
+ Dictionary mapping line numbers to line content
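+
+        Example:
+            get_context_lines("app.py", 5, context_size=2) returns lines
+            3 through 7 of app.py, keyed by 1-based line number.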
+ """
+ file = self.codebase.get_file(file_path)
+ if not file or not hasattr(file, "source"):
+ return {}
+
+ lines = file.source.splitlines()
+ start_line = max(0, line_number - context_size - 1)
+ end_line = min(len(lines), line_number + context_size)
+
+ return {i + 1: lines[i] for i in range(start_line, end_line)}
+
+ def build_call_graph(self) -> nx.DiGraph:
+ """
+ Build a call graph for the codebase.
+
+ Returns:
+ A directed graph representing function calls
+ """
+ if self._call_graph is not None:
+ return self._call_graph
+
+ G = nx.DiGraph()
+
+ # Add nodes for all functions
+ for func in self.codebase.functions:
+ G.add_node(func.name, type="function", function=func)
+
+ # Add edges for function calls
+ for func in self.codebase.functions:
+ if not hasattr(func, "function_calls"):
+ continue
+
+ for call in func.function_calls:
+ if call.name in G:
+ G.add_edge(func.name, call.name, type="call")
+
+ self._call_graph = G
+ return G
+
+ def build_dependency_graph(self) -> nx.DiGraph:
+ """
+ Build a dependency graph for the codebase.
+
+ Returns:
+ A directed graph representing symbol dependencies
+ """
+ if self._dependency_graph is not None:
+ return self._dependency_graph
+
+ G = nx.DiGraph()
+
+ # Add nodes for all symbols
+ for symbol in self.codebase.symbols:
+ G.add_node(symbol.name, type="symbol", symbol=symbol)
+
+ # Add edges for dependencies
+ for symbol in self.codebase.symbols:
+ if not hasattr(symbol, "dependencies"):
+ continue
+
+ for dep in symbol.dependencies:
+ if isinstance(dep, Symbol):
+ G.add_edge(symbol.name, dep.name, type="dependency")
+
+ self._dependency_graph = G
+ return G
+
+ def build_import_graph(self) -> nx.DiGraph:
+ """
+ Build an import graph for the codebase.
+
+ Returns:
+ A directed graph representing file imports
+ """
+ if self._import_graph is not None:
+ return self._import_graph
+
+ G = nx.DiGraph()
+
+ # Add nodes for all files
+ for file in self.codebase.files:
+ G.add_node(file.name, type="file", file=file)
+
+ # Add edges for imports
+ for file in self.codebase.files:
+ for imp in file.imports:
+ if imp.imported_symbol and hasattr(imp.imported_symbol, "file"):
+ imported_file = imp.imported_symbol.file
+ if imported_file and imported_file.name != file.name:
+ G.add_edge(file.name, imported_file.name, type="import")
+
+ self._import_graph = G
+ return G
+
+ def find_circular_imports(self) -> List[List[str]]:
+ """
+ Find circular imports in the codebase.
+
+ Returns:
+ A list of cycles, where each cycle is a list of file names
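+
+        Example:
+            If a.py imports b.py, b.py imports c.py, and c.py imports
+            a.py, the result contains the cycle ["a.py", "b.py", "c.py"].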
+ """
+ import_graph = self.build_import_graph()
+ return list(nx.simple_cycles(import_graph))
+
+ def find_circular_dependencies(self) -> List[List[str]]:
+ """
+ Find circular dependencies between symbols.
+
+ Returns:
+ A list of cycles, where each cycle is a list of symbol names
+ """
+ dependency_graph = self.build_dependency_graph()
+ return list(nx.simple_cycles(dependency_graph))
+
+ def analyze_function_parameters(self, function: Function) -> List[CodeError]:
+ """
+ Analyze function parameters for errors.
+
+ Args:
+ function: The function to analyze
+
+ Returns:
+ A list of parameter-related errors
+ """
+ errors = []
+
+ if not hasattr(function, "parameters") or not hasattr(function, "function_calls"):
+ return errors
+
+ # Check for parameter type mismatches
+ for param in function.parameters:
+ if not hasattr(param, "type_annotation") or not param.type_annotation:
+ continue
+
+ # Check if parameter is used with correct type
+ # This is a simplified check and would need more sophisticated type inference in practice
+ param_name = param.name
+ param_type = param.type_annotation
+
+ # Look for usage of this parameter in the function body
+ if hasattr(function, "code_block") and hasattr(function.code_block, "source"):
+ source = function.code_block.source
+
+ # Simple pattern matching for potential type errors
+ # This is a simplified approach and would need more sophisticated analysis in practice
+ if re.search(rf"\b{param_name}\s*\+\s*\d+\b", source) and "str" in param_type:
+ line_number = self._find_line_number(function.code_block.source, rf"\b{param_name}\s*\+\s*\d+\b")
+ errors.append(CodeError(
+ error_type=ErrorType.TYPE_ERROR,
+ message=f"Potential type error: adding integer to string parameter '{param_name}'",
+ file_path=function.file.name if hasattr(function, "file") else None,
+ line_number=line_number,
+ severity=ErrorSeverity.HIGH,
+ symbol_name=function.name,
+ context_lines=self.get_context_lines(function.file.name, line_number) if hasattr(function, "file") else None,
+ suggested_fix=f"Ensure '{param_name}' is converted to int before addition or use string concatenation"
+ ))
+
+        # Check for call parameter mismatches
+ for call in function.function_calls:
+ called_func_name = call.name
+
+ # Find the called function
+ called_func = None
+ for func in self.codebase.functions:
+ if func.name == called_func_name:
+ called_func = func
+ break
+
+ if not called_func or not hasattr(called_func, "parameters"):
+ continue
+
+ # Check if number of arguments matches
+ if hasattr(call, "args") and len(call.args) != len(called_func.parameters):
+ # Find the line number of the call
+ line_number = self._find_line_number(function.code_block.source, rf"\b{called_func_name}\s*\(")
+
+ errors.append(CodeError(
+ error_type=ErrorType.PARAMETER_ERROR,
+ message=f"Function '{called_func_name}' called with {len(call.args)} arguments but expects {len(called_func.parameters)}",
+ file_path=function.file.name if hasattr(function, "file") else None,
+ line_number=line_number,
+ severity=ErrorSeverity.HIGH,
+ symbol_name=function.name,
+ context_lines=self.get_context_lines(function.file.name, line_number) if hasattr(function, "file") else None,
+ suggested_fix=f"Update call to provide {len(called_func.parameters)} arguments"
+ ))
+
+ return errors
+
+ def analyze_function_returns(self, function: Function) -> List[CodeError]:
+ """
+ Analyze function return statements for errors.
+
+ Args:
+ function: The function to analyze
+
+ Returns:
+ A list of return-related errors
+ """
+ errors = []
+
+ if not hasattr(function, "return_type") or not function.return_type:
+ return errors
+
+ if not hasattr(function, "return_statements") or not function.return_statements:
+ # Function has return type but no return statements
+ errors.append(CodeError(
+ error_type=ErrorType.TYPE_ERROR,
+ message=f"Function '{function.name}' has return type '{function.return_type}' but no return statements",
+ file_path=function.file.name if hasattr(function, "file") else None,
+ line_number=function.line_number if hasattr(function, "line_number") else None,
+ severity=ErrorSeverity.MEDIUM,
+ symbol_name=function.name,
+ context_lines=self.get_context_lines(function.file.name, function.line_number) if hasattr(function, "file") and hasattr(function, "line_number") else None,
+                suggested_fix="Add return statement or change return type to 'None'"
+ ))
+ return errors
+
+ # Check if return statements match the declared return type
+ return_type = function.return_type
+ for ret_stmt in function.return_statements:
+ # This is a simplified check and would need more sophisticated type inference in practice
+ if hasattr(ret_stmt, "expression") and hasattr(ret_stmt.expression, "source"):
+ expr_source = ret_stmt.expression.source
+
+ # Simple pattern matching for potential type errors
+ if "int" in return_type and re.search(r"\".*\"", expr_source):
+ line_number = self._find_line_number(function.code_block.source, rf"return\s+{re.escape(expr_source)}")
+ errors.append(CodeError(
+ error_type=ErrorType.TYPE_ERROR,
+ message=f"Function '{function.name}' returns string but declares return type '{return_type}'",
+ file_path=function.file.name if hasattr(function, "file") else None,
+ line_number=line_number,
+ severity=ErrorSeverity.HIGH,
+ symbol_name=function.name,
+ context_lines=self.get_context_lines(function.file.name, line_number) if hasattr(function, "file") else None,
+ suggested_fix=f"Convert return value to {return_type} or update return type annotation"
+ ))
+
+ return errors
+
+ def analyze_unused_imports(self, file: SourceFile) -> List[CodeError]:
+ """
+ Find unused imports in a file.
+
+ Args:
+ file: The file to analyze
+
+ Returns:
+ A list of unused import errors
+ """
+ errors = []
+
+ if not hasattr(file, "imports") or not hasattr(file, "symbols"):
+ return errors
+
+ # Get all imported symbols
+ imported_symbols = set()
+ for imp in file.imports:
+ if hasattr(imp, "imported_symbol") and imp.imported_symbol:
+ imported_symbols.add(imp.imported_symbol.name)
+
+ # Get all used symbols
+ used_symbols = set()
+ for symbol in file.symbols:
+ if hasattr(symbol, "dependencies"):
+ for dep in symbol.dependencies:
+ if isinstance(dep, Symbol):
+ used_symbols.add(dep.name)
+
+ # Find unused imports
+ unused_imports = imported_symbols - used_symbols
+ for unused in unused_imports:
+ # Find the import statement
+ for imp in file.imports:
+ if hasattr(imp, "imported_symbol") and imp.imported_symbol and imp.imported_symbol.name == unused:
+ errors.append(CodeError(
+ error_type=ErrorType.UNUSED_IMPORT,
+ message=f"Unused import: '{unused}'",
+ file_path=file.name,
+ line_number=imp.line_number if hasattr(imp, "line_number") else None,
+ severity=ErrorSeverity.LOW,
+ context_lines=self.get_context_lines(file.name, imp.line_number) if hasattr(imp, "line_number") else None,
+ suggested_fix=f"Remove unused import of '{unused}'"
+ ))
+
+ return errors
+
+ def analyze_undefined_variables(self, function: Function) -> List[CodeError]:
+ """
+ Find undefined variables in a function.
+
+ Args:
+ function: The function to analyze
+
+ Returns:
+ A list of undefined variable errors
+ """
+ errors = []
+
+ if not hasattr(function, "code_block") or not hasattr(function.code_block, "source"):
+ return errors
+
+ # Get parameter names
+ param_names = set()
+ if hasattr(function, "parameters"):
+ for param in function.parameters:
+ param_names.add(param.name)
+
+ # Parse the function body to find variable definitions and usages
+ try:
+ tree = ast.parse(function.code_block.source)
+
+ # Find all variable assignments
+ assigned_vars = set()
+ for node in ast.walk(tree):
+ if isinstance(node, ast.Assign):
+ for target in node.targets:
+ if isinstance(target, ast.Name):
+ assigned_vars.add(target.id)
+ elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
+ assigned_vars.add(node.target.id)
+
+ # Find all variable usages
+ for node in ast.walk(tree):
+ if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load):
+ var_name = node.id
+ if var_name not in assigned_vars and var_name not in param_names and not var_name.startswith("__"):
+ # This is a potential undefined variable
+ # Find the line number in the source code
+ line_number = node.lineno
+
+ errors.append(CodeError(
+ error_type=ErrorType.UNDEFINED_VARIABLE,
+ message=f"Potentially undefined variable: '{var_name}'",
+ file_path=function.file.name if hasattr(function, "file") else None,
+ line_number=line_number,
+ severity=ErrorSeverity.HIGH,
+ symbol_name=function.name,
+ context_lines=self.get_context_lines(function.file.name, line_number) if hasattr(function, "file") else None,
+ suggested_fix=f"Define '{var_name}' before use or check for typos"
+ ))
+ except SyntaxError:
+ # If there's a syntax error, we can't analyze the function body
+ pass
+
+ return errors
+
+ def analyze_function(self, function: Function) -> List[CodeError]:
+ """
+ Analyze a function for errors.
+
+ Args:
+ function: The function to analyze
+
+ Returns:
+ A list of errors found in the function
+ """
+ errors = []
+
+ # Analyze parameters
+ errors.extend(self.analyze_function_parameters(function))
+
+ # Analyze return statements
+ errors.extend(self.analyze_function_returns(function))
+
+ # Analyze undefined variables
+ errors.extend(self.analyze_undefined_variables(function))
+
+ return errors
+
+ def analyze_file(self, file: SourceFile) -> List[CodeError]:
+ """
+ Analyze a file for errors.
+
+ Args:
+ file: The file to analyze
+
+ Returns:
+ A list of errors found in the file
+ """
+ errors = []
+
+ # Analyze unused imports
+ errors.extend(self.analyze_unused_imports(file))
+
+ # Analyze syntax errors
+ if hasattr(file, "source"):
+ try:
+ ast.parse(file.source)
+ except SyntaxError as e:
+ errors.append(CodeError(
+ error_type=ErrorType.SYNTAX_ERROR,
+ message=f"Syntax error: {str(e)}",
+ file_path=file.name,
+ line_number=e.lineno,
+ column=e.offset,
+ severity=ErrorSeverity.CRITICAL,
+ context_lines=self.get_context_lines(file.name, e.lineno),
+ suggested_fix="Fix the syntax error"
+ ))
+
+ # Analyze functions in the file
+ for func in file.functions:
+ errors.extend(self.analyze_function(func))
+
+ return errors
+
+ def analyze_codebase(self) -> Dict[str, List[Dict[str, Any]]]:
+ """
+ Analyze the entire codebase for errors.
+
+ Returns:
+ A dictionary mapping file paths to lists of errors
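+
+        Example:
+            {"app.py": [{"error_type": "unused_import", ...}]}, where
+            each entry is a CodeError serialized via to_dict()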
+ """
+ results = {}
+
+ # Analyze each file
+ for file in self.codebase.files:
+ file_errors = self.analyze_file(file)
+ if file_errors:
+ results[file.name] = [error.to_dict() for error in file_errors]
+
+ # Find circular imports
+ circular_imports = self.find_circular_imports()
+ for cycle in circular_imports:
+ for file_name in cycle:
+ if file_name not in results:
+ results[file_name] = []
+
+ results[file_name].append(CodeError(
+ error_type=ErrorType.CIRCULAR_IMPORT,
+ message=f"Circular import detected: {' -> '.join(cycle)}",
+ file_path=file_name,
+ severity=ErrorSeverity.HIGH,
+ suggested_fix="Refactor imports to break the circular dependency"
+ ).to_dict())
+
+ # Find circular dependencies
+ circular_deps = self.find_circular_dependencies()
+ for cycle in circular_deps:
+ for symbol_name in cycle:
+ # Find the file containing this symbol
+ symbol_file = None
+ for symbol in self.codebase.symbols:
+ if symbol.name == symbol_name and hasattr(symbol, "file"):
+ symbol_file = symbol.file.name
+ break
+
+ if not symbol_file:
+ continue
+
+ if symbol_file not in results:
+ results[symbol_file] = []
+
+ results[symbol_file].append(CodeError(
+ error_type=ErrorType.CIRCULAR_DEPENDENCY,
+ message=f"Circular dependency detected: {' -> '.join(cycle)}",
+ file_path=symbol_file,
+ symbol_name=symbol_name,
+ severity=ErrorSeverity.MEDIUM,
+ suggested_fix="Refactor code to break the circular dependency"
+ ).to_dict())
+
+ return results
+
+ def get_error_context(self, error: CodeError) -> Dict[str, Any]:
+ """
+ Get detailed context information for an error.
+
+ Args:
+ error: The error to get context for
+
+ Returns:
+ A dictionary with detailed context information
+ """
+ context = error.to_dict()
+
+ # Add additional context based on error type
+ if error.error_type == ErrorType.PARAMETER_ERROR and error.symbol_name:
+ # Get information about the function
+ func = None
+ for function in self.codebase.functions:
+ if function.name == error.symbol_name:
+ func = function
+ break
+
+ if func and hasattr(func, "parameters"):
+ context["function_info"] = {
+ "name": func.name,
+ "parameters": [{"name": p.name, "type": p.type_annotation if hasattr(p, "type_annotation") else None} for p in func.parameters],
+ "return_type": func.return_type if hasattr(func, "return_type") else None
+ }
+
+ elif error.error_type == ErrorType.CIRCULAR_IMPORT:
+ # Add information about the import cycle
+ import_graph = self.build_import_graph()
+ if error.file_path in import_graph:
+ context["import_info"] = {
+ "imports": [n for n in import_graph.successors(error.file_path)],
+ "imported_by": [n for n in import_graph.predecessors(error.file_path)]
+ }
+
+ elif error.error_type == ErrorType.UNDEFINED_VARIABLE and error.symbol_name:
+ # Get information about the function
+ func = None
+ for function in self.codebase.functions:
+ if function.name == error.symbol_name:
+ func = function
+ break
+
+ if func and hasattr(func, "parameters"):
+ context["function_info"] = {
+ "name": func.name,
+ "parameters": [p.name for p in func.parameters],
+ "local_variables": self._extract_local_variables(func)
+ }
+
+ return context
+
+ def _extract_local_variables(self, function: Function) -> List[str]:
+ """
+ Extract local variables defined in a function.
+
+ Args:
+ function: The function to analyze
+
+ Returns:
+ A list of local variable names
+ """
+ if not hasattr(function, "code_block") or not hasattr(function.code_block, "source"):
+ return []
+
+ local_vars = []
+ try:
+ tree = ast.parse(function.code_block.source)
+
+ # Find all variable assignments
+ for node in ast.walk(tree):
+ if isinstance(node, ast.Assign):
+ for target in node.targets:
+ if isinstance(target, ast.Name):
+ local_vars.append(target.id)
+ elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
+ local_vars.append(node.target.id)
+ except SyntaxError:
+ pass
+
+ return local_vars
+
+ def _find_line_number(self, source: str, pattern: str) -> Optional[int]:
+ """
+ Find the line number where a pattern appears in source code.
+
+ Args:
+ source: The source code to search
+ pattern: The regex pattern to search for
+
+ Returns:
+ The line number (1-based) or None if not found
+ """
+ lines = source.splitlines()
+ for i, line in enumerate(lines):
+ if re.search(pattern, line):
+ return i + 1
+ return None
+
+ def get_function_error_context(self, function_name: str) -> Dict[str, Any]:
+ """
+ Get detailed error context for a specific function.
+
+ Args:
+ function_name: The name of the function to analyze
+
+ Returns:
+ A dictionary with detailed error context
+ """
+ # Find the function
+ function = None
+ for func in self.codebase.functions:
+ if func.name == function_name:
+ function = func
+ break
+
+ if not function:
+ return {"error": f"Function not found: {function_name}"}
+
+ # Analyze the function
+ errors = self.analyze_function(function)
+
+ # Get call graph information
+ call_graph = self.build_call_graph()
+ callers = []
+ callees = []
+
+ if function_name in call_graph:
+ callers = [{"name": caller} for caller in call_graph.predecessors(function_name)]
+ callees = [{"name": callee} for callee in call_graph.successors(function_name)]
+
+ # Get parameter information
+ parameters = []
+ if hasattr(function, "parameters"):
+ for param in function.parameters:
+ param_info = {
+ "name": param.name,
+ "type": param.type_annotation if hasattr(param, "type_annotation") else None,
+ "default": param.default_value if hasattr(param, "default_value") else None
+ }
+ parameters.append(param_info)
+
+ # Get return information
+ return_info = {
+ "type": function.return_type if hasattr(function, "return_type") else None,
+ "statements": []
+ }
+
+ if hasattr(function, "return_statements"):
+ for ret_stmt in function.return_statements:
+ if hasattr(ret_stmt, "expression") and hasattr(ret_stmt.expression, "source"):
+ return_info["statements"].append(ret_stmt.expression.source)
+
+ # Combine all information
+ result = {
+ "function_name": function_name,
+ "file_path": function.file.name if hasattr(function, "file") else None,
+ "errors": [error.to_dict() for error in errors],
+ "callers": callers,
+ "callees": callees,
+ "parameters": parameters,
+ "return_info": return_info,
+ "source": function.source if hasattr(function, "source") else None
+ }
+
+ return result
+
+ def get_file_error_context(self, file_path: str) -> Dict[str, Any]:
+ """
+ Get detailed error context for a specific file.
+
+ Args:
+ file_path: The path of the file to analyze
+
+ Returns:
+ A dictionary with detailed error context
+ """
+ # Find the file
+ file = self.codebase.get_file(file_path)
+ if not file:
+ return {"error": f"File not found: {file_path}"}
+
+ # Analyze the file
+ errors = self.analyze_file(file)
+
+ # Get import graph information
+ import_graph = self.build_import_graph()
+ importers = []
+ imported = []
+
+ if file.name in import_graph:
+ importers = [{"name": importer} for importer in import_graph.predecessors(file.name)]
+ imported = [{"name": imp} for imp in import_graph.successors(file.name)]
+
+ # Get function information
+ functions = []
+ for func in file.functions:
+ func_errors = [error for error in errors if error.symbol_name == func.name]
+ functions.append({
+ "name": func.name,
+ "line_number": func.line_number if hasattr(func, "line_number") else None,
+ "errors": [error.to_dict() for error in func_errors]
+ })
+
+ # Get class information
+ classes = []
+ for cls in file.classes:
+ classes.append({
+ "name": cls.name,
+ "line_number": cls.line_number if hasattr(cls, "line_number") else None
+ })
+
+ # Combine all information
+ result = {
+ "file_path": file_path,
+ "errors": [error.to_dict() for error in errors],
+ "importers": importers,
+ "imported": imported,
+ "functions": functions,
+ "classes": classes
+ }
+
+ return result
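+
+
+# Illustrative end-to-end usage (mirrors the example scripts in this
+# package; the repository name is a placeholder):
+#
+#     from codegen import Codebase
+#
+#     codebase = Codebase.from_repo("owner/repo")
+#     analyzer = ErrorContextAnalyzer(codebase)
+#     results = analyzer.analyze_codebase()
+#     for file_path, errors in results.items():
+#         print(f"{file_path}: {len(errors)} errors")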
+
diff --git a/codegen-on-oss/codegen_on_oss/analysis/error_example.py b/codegen-on-oss/codegen_on_oss/analysis/error_example.py
new file mode 100644
index 000000000..240b264ce
--- /dev/null
+++ b/codegen-on-oss/codegen_on_oss/analysis/error_example.py
@@ -0,0 +1,157 @@
+"""
+Example script demonstrating the use of the error context analysis functionality.
+
+This script shows how to use the CodeAnalyzer class to detect and analyze errors
+in a codebase, providing detailed contextual information about the errors.
+"""
+
+from codegen import Codebase
+from codegen_on_oss.analysis.analysis import CodeAnalyzer
+
+
+def main():
+ """
+ Main function demonstrating the use of the error context analysis functionality.
+ """
+ print("Analyzing a sample repository for errors...")
+
+ # Load a codebase
+ repo_name = "fastapi/fastapi"
+ codebase = Codebase.from_repo(repo_name)
+
+ print(f"Loaded codebase: {repo_name}")
+ print(f"Files: {len(codebase.files)}")
+ print(f"Functions: {len(codebase.functions)}")
+ print(f"Classes: {len(codebase.classes)}")
+
+ # Create analyzer instance
+ analyzer = CodeAnalyzer(codebase)
+
+ # Analyze errors in the entire codebase
+ print("\n=== Codebase Error Analysis ===")
+ error_analysis = analyzer.analyze_errors()
+
+ # Count errors by type
+ error_counts = {}
+ total_errors = 0
+
+ for file_path, errors in error_analysis.items():
+ for error in errors:
+ error_type = error["error_type"]
+ error_counts[error_type] = error_counts.get(error_type, 0) + 1
+ total_errors += 1
+
+ print(f"Found {total_errors} errors across {len(error_analysis)} files")
+
+ if error_counts:
+ print("\nError types:")
+ for error_type, count in error_counts.items():
+ print(f"- {error_type}: {count}")
+
+ # Find files with the most errors
+ files_with_errors = [(file_path, len(errors)) for file_path, errors in error_analysis.items()]
+ files_with_errors.sort(key=lambda x: x[1], reverse=True)
+
+ if files_with_errors:
+ print("\nTop files with errors:")
+ for file_path, count in files_with_errors[:5]: # Show top 5
+ print(f"- {file_path}: {count} errors")
+
+ # Analyze a specific file
+ if files_with_errors:
+ file_to_analyze = files_with_errors[0][0]
+ print(f"\n=== Detailed Error Analysis for {file_to_analyze} ===")
+ file_error_context = analyzer.get_file_error_context(file_to_analyze)
+
+ print(f"File: {file_error_context['file_path']}")
+ print(f"Errors: {len(file_error_context['errors'])}")
+
+ if file_error_context['errors']:
+ print("\nDetailed errors:")
+ for i, error in enumerate(file_error_context['errors'][:3], 1): # Show top 3
+ print(f"\nError {i}:")
+ print(f"- Type: {error['error_type']}")
+ print(f"- Message: {error['message']}")
+ print(f"- Severity: {error['severity']}")
+ if error['line_number']:
+ print(f"- Line: {error['line_number']}")
+ if error['suggested_fix']:
+ print(f"- Suggested fix: {error['suggested_fix']}")
+
+ if error['context_lines']:
+ print("- Context:")
+ for line_num, line in error['context_lines'].items():
+ prefix = ">" if line_num == error['line_number'] else " "
+ print(f" {prefix} {line_num}: {line}")
+
+ # Show functions in the file
+ if file_error_context['functions']:
+ print("\nFunctions in this file:")
+ for func in file_error_context['functions']:
+ error_count = len(func['errors'])
+ error_suffix = f" ({error_count} errors)" if error_count > 0 else ""
+ print(f"- {func['name']}{error_suffix}")
+
+ # Analyze a specific function with errors
+ function_to_analyze = None
+ for file_path, errors in error_analysis.items():
+ for error in errors:
+ if error['symbol_name']:
+ function_to_analyze = error['symbol_name']
+ break
+ if function_to_analyze:
+ break
+
+ if function_to_analyze:
+ print(f"\n=== Detailed Error Analysis for function {function_to_analyze} ===")
+ function_error_context = analyzer.get_function_error_context(function_to_analyze)
+
+ print(f"Function: {function_error_context['function_name']}")
+ print(f"File: {function_error_context['file_path']}")
+ print(f"Errors: {len(function_error_context['errors'])}")
+
+ if function_error_context['parameters']:
+ print("\nParameters:")
+ for param in function_error_context['parameters']:
+ default = f" = {param['default']}" if param['default'] is not None else ""
+ type_annotation = f": {param['type']}" if param['type'] else ""
+ print(f"- {param['name']}{type_annotation}{default}")
+
+ if function_error_context['return_info']['type']:
+ print(f"\nReturn type: {function_error_context['return_info']['type']}")
+
+ if function_error_context['callers']:
+ print("\nCalled by:")
+ for caller in function_error_context['callers']:
+ print(f"- {caller['name']}")
+
+ if function_error_context['callees']:
+ print("\nCalls:")
+ for callee in function_error_context['callees']:
+ print(f"- {callee['name']}")
+
+ if function_error_context['errors']:
+ print("\nDetailed errors:")
+ for i, error in enumerate(function_error_context['errors'], 1):
+ print(f"\nError {i}:")
+ print(f"- Type: {error['error_type']}")
+ print(f"- Message: {error['message']}")
+ print(f"- Severity: {error['severity']}")
+ if error['line_number']:
+ print(f"- Line: {error['line_number']}")
+ if error['suggested_fix']:
+ print(f"- Suggested fix: {error['suggested_fix']}")
+
+ if error['context_lines']:
+ print("- Context:")
+ for line_num, line in error['context_lines'].items():
+ prefix = ">" if line_num == error['line_number'] else " "
+ print(f" {prefix} {line_num}: {line}")
+
+ print("\nAnalysis complete!")
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/codegen-on-oss/codegen_on_oss/analysis/example.py b/codegen-on-oss/codegen_on_oss/analysis/example.py
new file mode 100644
index 000000000..34dd1710a
--- /dev/null
+++ b/codegen-on-oss/codegen_on_oss/analysis/example.py
@@ -0,0 +1,103 @@
+"""
+Example script demonstrating the use of the unified analysis module.
+
+This script shows how to use the CodeAnalyzer and CodeMetrics classes
+to perform comprehensive code analysis on a repository.
+"""
+
+from codegen import Codebase
+from codegen_on_oss.analysis.analysis import CodeAnalyzer
+from codegen_on_oss.metrics import CodeMetrics
+
+
+def main():
+ """
+ Main function demonstrating the use of the analysis module.
+ """
+ print("Analyzing a sample repository...")
+
+ # Load a codebase
+ repo_name = "fastapi/fastapi"
+ codebase = Codebase.from_repo(repo_name)
+
+ print(f"Loaded codebase: {repo_name}")
+ print(f"Files: {len(codebase.files)}")
+ print(f"Functions: {len(codebase.functions)}")
+ print(f"Classes: {len(codebase.classes)}")
+
+ # Create analyzer instance
+ analyzer = CodeAnalyzer(codebase)
+
+ # Get codebase summary
+ print("\n=== Codebase Summary ===")
+ print(analyzer.get_codebase_summary())
+
+ # Analyze complexity
+ print("\n=== Complexity Analysis ===")
+ complexity_results = analyzer.analyze_complexity()
+ print(f"Average cyclomatic complexity: {complexity_results['cyclomatic_complexity']['average']:.2f}")
+ print(f"Complexity rank: {complexity_results['cyclomatic_complexity']['rank']}")
+
+ # Find complex functions
+ complex_functions = [
+ f for f in complexity_results['cyclomatic_complexity']['functions']
+ if f['complexity'] > 10
+ ][:5] # Show top 5
+
+ if complex_functions:
+ print("\nTop complex functions:")
+ for func in complex_functions:
+ print(f"- {func['name']}: Complexity {func['complexity']} (Rank {func['rank']})")
+
+ # Analyze imports
+ print("\n=== Import Analysis ===")
+ import_analysis = analyzer.analyze_imports()
+ print(f"Found {len(import_analysis['import_cycles'])} import cycles")
+
+ # Create metrics instance
+ metrics = CodeMetrics(codebase)
+
+ # Get code quality summary
+ print("\n=== Code Quality Summary ===")
+ quality_summary = metrics.get_code_quality_summary()
+
+ print("Overall metrics:")
+ for metric, value in quality_summary["overall_metrics"].items():
+ if isinstance(value, float):
+ print(f"- {metric}: {value:.2f}")
+ else:
+ print(f"- {metric}: {value}")
+
+ print("\nProblem areas:")
+ for area, count in quality_summary["problem_areas"].items():
+ print(f"- {area}: {count}")
+
+ # Find bug-prone functions
+ print("\n=== Bug-Prone Functions ===")
+ bug_prone = metrics.find_bug_prone_functions()[:5] # Show top 5
+
+ if bug_prone:
+ print("Top bug-prone functions:")
+ for func in bug_prone:
+ print(f"- {func['name']}: Estimated bugs {func['bugs_delivered']:.2f}")
+
+ # Analyze dependencies
+ print("\n=== Dependency Analysis ===")
+ dependencies = metrics.analyze_dependencies()
+
+ print(f"Dependency graph: {dependencies['dependency_graph']['nodes']} nodes, "
+ f"{dependencies['dependency_graph']['edges']} edges")
+ print(f"Dependency density: {dependencies['dependency_graph']['density']:.4f}")
+ print(f"Number of cycles: {dependencies['cycles']}")
+
+ if dependencies['most_central_files']:
+ print("\nMost central files:")
+ for file, score in dependencies['most_central_files'][:5]: # Show top 5
+ print(f"- {file}: Centrality {score:.4f}")
+
+ print("\nAnalysis complete!")
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/codegen-on-oss/codegen_on_oss/analysis/examples/__init__.py b/codegen-on-oss/codegen_on_oss/analysis/examples/__init__.py
new file mode 100644
index 000000000..32ca6bc47
--- /dev/null
+++ b/codegen-on-oss/codegen_on_oss/analysis/examples/__init__.py
@@ -0,0 +1,4 @@
+"""
+Examples for using the code analysis module.
+"""
+
diff --git a/codegen-on-oss/codegen_on_oss/analysis/examples/analyze_errors.py b/codegen-on-oss/codegen_on_oss/analysis/examples/analyze_errors.py
new file mode 100644
index 000000000..b3c1e0149
--- /dev/null
+++ b/codegen-on-oss/codegen_on_oss/analysis/examples/analyze_errors.py
@@ -0,0 +1,207 @@
+#!/usr/bin/env python3
+"""
+Example script demonstrating how to use the error context analysis functionality.
+
+This script analyzes a repository for errors and prints detailed error context information.
+"""
+
+import argparse
+import sys
+from typing import Any, Dict
+
+from codegen import Codebase
+from codegen_on_oss.analysis.analysis import CodeAnalyzer
+
+
+def print_error(error: Dict[str, Any]) -> None:
+ """Print a formatted error message."""
+ print(f"ERROR: {error['error_type']} ({error['severity']})")
+ print(f" Message: {error['message']}")
+
+ if error.get('file_path'):
+ print(f" File: {error['file_path']}")
+
+ if error.get('line_number'):
+ print(f" Line: {error['line_number']}")
+
+ if error.get('context_lines'):
+ print(" Context:")
+ for line_num, line in error['context_lines'].items():
+ prefix = ">" if str(line_num) == str(error.get('line_number')) else " "
+ print(f" {prefix} {line_num}: {line}")
+
+ if error.get('suggested_fix'):
+ print(f" Suggested Fix: {error['suggested_fix']}")
+
+ print()
+
+
+def analyze_repo(repo_url: str) -> None:
+ """Analyze a repository for errors."""
+ print(f"Analyzing repository: {repo_url}")
+
+ try:
+ # Create a codebase from the repository
+ codebase = Codebase.from_repo(repo_url)
+
+ # Create an analyzer
+ analyzer = CodeAnalyzer(codebase)
+
+ # Analyze errors in the codebase
+ errors = analyzer.analyze_errors()
+
+ # Print summary
+ total_errors = sum(len(file_errors) for file_errors in errors.values())
+ print(f"\nFound {total_errors} errors in {len(errors)} files\n")
+
+ # Print errors by file
+ for file_path, file_errors in errors.items():
+ print(f"File: {file_path}")
+ print(f" {len(file_errors)} errors found")
+
+ # Print the first 3 errors for each file
+ for i, error in enumerate(file_errors[:3]):
+ print(f" Error {i+1}:")
+ print_error(error)
+
+ if len(file_errors) > 3:
+ print(f" ... and {len(file_errors) - 3} more errors\n")
+
+ print()
+
+ except Exception as e:
+ print(f"Error analyzing repository: {e}", file=sys.stderr)
+ sys.exit(1)
+
+
+def analyze_file(repo_url: str, file_path: str) -> None:
+ """Analyze a specific file for errors."""
+ print(f"Analyzing file: {file_path} in repository: {repo_url}")
+
+ try:
+ # Create a codebase from the repository
+ codebase = Codebase.from_repo(repo_url)
+
+ # Create an analyzer
+ analyzer = CodeAnalyzer(codebase)
+
+ # Get file error context
+ file_error_context = analyzer.get_file_error_context(file_path)
+
+ # Print errors
+ if 'errors' in file_error_context:
+ errors = file_error_context['errors']
+ print(f"\nFound {len(errors)} errors\n")
+
+ for i, error in enumerate(errors):
+ print(f"Error {i+1}:")
+ print_error(error)
+ else:
+ print("\nNo errors found or file not found")
+
+ except Exception as e:
+ print(f"Error analyzing file: {e}", file=sys.stderr)
+ sys.exit(1)
+
+
+def analyze_function(repo_url: str, function_name: str) -> None:
+ """Analyze a specific function for errors."""
+ print(f"Analyzing function: {function_name} in repository: {repo_url}")
+
+ try:
+ # Create a codebase from the repository
+ codebase = Codebase.from_repo(repo_url)
+
+ # Create an analyzer
+ analyzer = CodeAnalyzer(codebase)
+
+ # Get function error context
+ function_error_context = analyzer.get_function_error_context(function_name)
+
+ # Print function information
+ if 'function_name' in function_error_context:
+ print(f"\nFunction: {function_error_context['function_name']}")
+
+ if 'file_path' in function_error_context:
+ print(f"File: {function_error_context['file_path']}")
+
+ # Print parameters
+ if 'parameters' in function_error_context:
+ params = function_error_context['parameters']
+ print(f"\nParameters ({len(params)}):")
+ for param in params:
+ param_type = f": {param['type']}" if param.get('type') else ""
+ default = f" = {param['default']}" if param.get('default') else ""
+ print(f" {param['name']}{param_type}{default}")
+
+ # Print return information
+ if 'return_info' in function_error_context:
+ return_info = function_error_context['return_info']
+ print(f"\nReturn Type: {return_info.get('type', 'Unknown')}")
+ if return_info.get('statements'):
+ print(f"Return Statements ({len(return_info['statements'])}):")
+ for stmt in return_info['statements']:
+ print(f" return {stmt}")
+
+ # Print callers and callees
+ if 'callers' in function_error_context:
+ callers = function_error_context['callers']
+ print(f"\nCallers ({len(callers)}):")
+ for caller in callers:
+ print(f" {caller['name']}")
+
+ if 'callees' in function_error_context:
+ callees = function_error_context['callees']
+ print(f"\nCallees ({len(callees)}):")
+ for callee in callees:
+ print(f" {callee['name']}")
+
+ # Print errors
+ if 'errors' in function_error_context:
+ errors = function_error_context['errors']
+ print(f"\nErrors ({len(errors)}):")
+ for i, error in enumerate(errors):
+ print(f"Error {i+1}:")
+ print_error(error)
+ else:
+ print("\nNo errors found")
+ else:
+ print("\nFunction not found")
+
+ except Exception as e:
+ print(f"Error analyzing function: {e}", file=sys.stderr)
+ sys.exit(1)
+
+
+def main() -> None:
+ """Main entry point."""
+ parser = argparse.ArgumentParser(description="Analyze code for errors")
+ parser.add_argument("repo_url", help="Repository URL (owner/repo)")
+
+ subparsers = parser.add_subparsers(dest="command", help="Analysis command")
+
+ # Repository analysis command
+    subparsers.add_parser("repo", help="Analyze entire repository")
+
+ # File analysis command
+ file_parser = subparsers.add_parser("file", help="Analyze a specific file")
+ file_parser.add_argument("file_path", help="Path to the file to analyze")
+
+ # Function analysis command
+ function_parser = subparsers.add_parser("function", help="Analyze a specific function")
+ function_parser.add_argument("function_name", help="Name of the function to analyze")
+
+ args = parser.parse_args()
+
+ if args.command == "file":
+ analyze_file(args.repo_url, args.file_path)
+ elif args.command == "function":
+ analyze_function(args.repo_url, args.function_name)
+ else:
+ analyze_repo(args.repo_url)
+
+
+if __name__ == "__main__":
+ main()
+
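+# Example invocations (illustrative repository and names; the positional
+# repository argument precedes the subcommand):
+#
+#     python analyze_errors.py fastapi/fastapi repo
+#     python analyze_errors.py fastapi/fastapi file path/to/module.py
+#     python analyze_errors.py fastapi/fastapi function my_function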
diff --git a/codegen-on-oss/codegen_on_oss/analysis/mdx_docs_generation.py b/codegen-on-oss/codegen_on_oss/analysis/mdx_docs_generation.py
index 648a3b68e..9e4543bea 100644
--- a/codegen-on-oss/codegen_on_oss/analysis/mdx_docs_generation.py
+++ b/codegen-on-oss/codegen_on_oss/analysis/mdx_docs_generation.py
@@ -110,10 +110,10 @@ def format_parameters_for_mdx(parameters: list[ParameterDoc]) -> str:
def format_return_for_mdx(return_type: list[str], return_description: str) -> str:
description = sanitize_html_for_mdx(return_description) if return_description else ""
- return_type = resolve_type_string(return_type[0])
+ return_type_str = resolve_type_string(return_type[0])
return f"""
-
+
"""
@@ -154,8 +154,8 @@ def get_mdx_route_for_class(cls_doc: ClassDoc) -> str:
def format_type_string(type_string: str) -> str:
- type_string = type_string.split("|")
- return " | ".join([type_str.strip() for type_str in type_string])
+ type_strings = type_string.split("|")
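+    # e.g. "int|str" becomes "int | str"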
+ return " | ".join([type_str.strip() for type_str in type_strings])
def resolve_type_string(type_string: str) -> str:
diff --git a/codegen-on-oss/codegen_on_oss/analysis/tests/__init__.py b/codegen-on-oss/codegen_on_oss/analysis/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/codegen-on-oss/codegen_on_oss/analysis/tests/test_error_context.py b/codegen-on-oss/codegen_on_oss/analysis/tests/test_error_context.py
new file mode 100644
index 000000000..33a6e3b01
--- /dev/null
+++ b/codegen-on-oss/codegen_on_oss/analysis/tests/test_error_context.py
@@ -0,0 +1,245 @@
+"""
+Tests for the error context analysis functionality.
+
+This module contains unit tests for the ErrorContextAnalyzer and related classes.
+"""
+
+import unittest
+from unittest.mock import MagicMock, patch
+
+from codegen.sdk.core.codebase import Codebase
+from codegen.sdk.core.file import SourceFile
+from codegen.sdk.core.function import Function
+from codegen_on_oss.analysis.error_context import (
+ CodeError,
+ ErrorContextAnalyzer,
+ ErrorSeverity,
+ ErrorType
+)
+
+
+class TestCodeError(unittest.TestCase):
+ """Tests for the CodeError class."""
+
+ def test_code_error_initialization(self):
+ """Test that a CodeError can be initialized with all parameters."""
+ error = CodeError(
+ error_type=ErrorType.SYNTAX_ERROR,
+ message="Invalid syntax",
+ file_path="test.py",
+ line_number=10,
+ column=5,
+ severity=ErrorSeverity.CRITICAL,
+ symbol_name="test_function",
+ context_lines={9: "def test_function():", 10: " print('Hello world'"},
+ suggested_fix="Fix the syntax error"
+ )
+
+ self.assertEqual(error.error_type, ErrorType.SYNTAX_ERROR)
+ self.assertEqual(error.message, "Invalid syntax")
+ self.assertEqual(error.file_path, "test.py")
+ self.assertEqual(error.line_number, 10)
+ self.assertEqual(error.column, 5)
+ self.assertEqual(error.severity, ErrorSeverity.CRITICAL)
+ self.assertEqual(error.symbol_name, "test_function")
+ self.assertEqual(error.context_lines, {9: "def test_function():", 10: " print('Hello world'"})
+ self.assertEqual(error.suggested_fix, "Fix the syntax error")
+
+ def test_code_error_to_dict(self):
+ """Test that a CodeError can be converted to a dictionary."""
+ error = CodeError(
+ error_type=ErrorType.SYNTAX_ERROR,
+ message="Invalid syntax",
+ file_path="test.py",
+ line_number=10,
+ severity=ErrorSeverity.CRITICAL
+ )
+
+ error_dict = error.to_dict()
+
+ self.assertEqual(error_dict["error_type"], ErrorType.SYNTAX_ERROR)
+ self.assertEqual(error_dict["message"], "Invalid syntax")
+ self.assertEqual(error_dict["file_path"], "test.py")
+ self.assertEqual(error_dict["line_number"], 10)
+ self.assertEqual(error_dict["severity"], ErrorSeverity.CRITICAL)
+
+ def test_code_error_str(self):
+ """Test the string representation of a CodeError."""
+ error = CodeError(
+ error_type=ErrorType.SYNTAX_ERROR,
+ message="Invalid syntax",
+ file_path="test.py",
+ line_number=10,
+ severity=ErrorSeverity.CRITICAL
+ )
+
+ error_str = str(error)
+
+ self.assertIn(ErrorType.SYNTAX_ERROR.upper(), error_str)
+ self.assertIn("Invalid syntax", error_str)
+ self.assertIn("test.py:10", error_str)
+ self.assertIn(ErrorSeverity.CRITICAL, error_str)
+
+
+class TestErrorContextAnalyzer(unittest.TestCase):
+ """Tests for the ErrorContextAnalyzer class."""
+
+ def setUp(self):
+ """Set up test fixtures."""
+ # Create a mock codebase
+ self.codebase = MagicMock(spec=Codebase)
+
+ # Create a mock file
+ self.file = MagicMock(spec=SourceFile)
+ self.file.name = "test.py"
+ self.file.source = "def test_function():\n x = 'hello' + 5\n return x"
+
+ # Create a mock function
+ self.function = MagicMock(spec=Function)
+ self.function.name = "test_function"
+ self.function.file = self.file
+ self.function.line_number = 1
+ self.function.code_block = MagicMock()
+ self.function.code_block.source = "def test_function():\n x = 'hello' + 5\n return x"
+
+ # Set up the codebase with the file and function
+ self.codebase.files = [self.file]
+ self.codebase.functions = [self.function]
+ self.codebase.get_file.return_value = self.file
+
+ # Create the analyzer
+ self.analyzer = ErrorContextAnalyzer(self.codebase)
+
+ def test_get_context_lines(self):
+ """Test getting context lines around a specific line."""
+ context_lines = self.analyzer.get_context_lines("test.py", 2, context_size=1)
+
+ self.assertEqual(context_lines, {
+ 1: "def test_function():",
+ 2: " x = 'hello' + 5",
+ 3: " return x"
+ })
+
+ def test_analyze_function(self):
+ """Test analyzing a function for errors."""
+ errors = self.analyzer.analyze_function(self.function)
+
+ # We should find at least one error (type error)
+ self.assertGreaterEqual(len(errors), 1)
+
+ # Check that we found a type error
+ type_errors = [e for e in errors if e.error_type == ErrorType.TYPE_ERROR]
+ self.assertGreaterEqual(len(type_errors), 1)
+
+ # Check the error details
+ error = type_errors[0]
+ self.assertEqual(error.file_path, "test.py")
+ self.assertEqual(error.symbol_name, "test_function")
+ self.assertEqual(error.severity, ErrorSeverity.HIGH)
+ self.assertIn("'hello' + 5", str(error.context_lines))
+
+ def test_analyze_file(self):
+ """Test analyzing a file for errors."""
+ errors = self.analyzer.analyze_file(self.file)
+
+ # We should find at least one error (type error)
+ self.assertGreaterEqual(len(errors), 1)
+
+ # Check that we found a type error
+ type_errors = [e for e in errors if e.error_type == ErrorType.TYPE_ERROR]
+ self.assertGreaterEqual(len(type_errors), 1)
+
+ def test_analyze_codebase(self):
+ """Test analyzing the entire codebase for errors."""
+ error_dict = self.analyzer.analyze_codebase()
+
+ # We should have errors for our test file
+ self.assertIn("test.py", error_dict)
+ self.assertGreaterEqual(len(error_dict["test.py"]), 1)
+
+ def test_find_circular_imports(self):
+ """Test finding circular imports."""
+ # Mock the build_import_graph method to return a graph with a cycle
+ import networkx as nx
+ G = nx.DiGraph()
+ G.add_edge("a.py", "b.py")
+ G.add_edge("b.py", "c.py")
+ G.add_edge("c.py", "a.py")
+
+ with patch.object(self.analyzer, 'build_import_graph', return_value=G):
+ cycles = self.analyzer.find_circular_imports()
+
+ # We should find one cycle
+ self.assertEqual(len(cycles), 1)
+
+ # The cycle should contain a.py, b.py, and c.py
+ cycle = cycles[0]
+ self.assertIn("a.py", cycle)
+ self.assertIn("b.py", cycle)
+ self.assertIn("c.py", cycle)
+
+ def test_get_function_error_context(self):
+ """Test getting detailed error context for a function."""
+ # Mock the analyze_function method to return a specific error
+ error = CodeError(
+ error_type=ErrorType.TYPE_ERROR,
+ message="Cannot add string and integer",
+ file_path="test.py",
+ line_number=2,
+ severity=ErrorSeverity.HIGH,
+ symbol_name="test_function",
+ context_lines={1: "def test_function():", 2: " x = 'hello' + 5", 3: " return x"},
+ suggested_fix="Convert the integer to a string: 'hello' + str(5)"
+ )
+
+ with patch.object(self.analyzer, 'analyze_function', return_value=[error]):
+ context = self.analyzer.get_function_error_context("test_function")
+
+ # Check the context
+ self.assertEqual(context["function_name"], "test_function")
+ self.assertEqual(context["file_path"], "test.py")
+ self.assertEqual(len(context["errors"]), 1)
+
+ # Check the error details
+ error_dict = context["errors"][0]
+ self.assertEqual(error_dict["error_type"], ErrorType.TYPE_ERROR)
+ self.assertEqual(error_dict["message"], "Cannot add string and integer")
+ self.assertEqual(error_dict["line_number"], 2)
+ self.assertEqual(error_dict["severity"], ErrorSeverity.HIGH)
+ self.assertEqual(error_dict["suggested_fix"], "Convert the integer to a string: 'hello' + str(5)")
+
+ def test_get_file_error_context(self):
+ """Test getting detailed error context for a file."""
+ # Mock the analyze_file method to return a specific error
+ error = CodeError(
+ error_type=ErrorType.TYPE_ERROR,
+ message="Cannot add string and integer",
+ file_path="test.py",
+ line_number=2,
+ severity=ErrorSeverity.HIGH,
+ symbol_name="test_function",
+ context_lines={1: "def test_function():", 2: " x = 'hello' + 5", 3: " return x"},
+ suggested_fix="Convert the integer to a string: 'hello' + str(5)"
+ )
+
+ with patch.object(self.analyzer, 'analyze_file', return_value=[error]):
+ context = self.analyzer.get_file_error_context("test.py")
+
+ # Check the context
+ self.assertEqual(context["file_path"], "test.py")
+ self.assertEqual(len(context["errors"]), 1)
+
+ # Check the error details
+ error_dict = context["errors"][0]
+ self.assertEqual(error_dict["error_type"], ErrorType.TYPE_ERROR)
+ self.assertEqual(error_dict["message"], "Cannot add string and integer")
+ self.assertEqual(error_dict["line_number"], 2)
+ self.assertEqual(error_dict["severity"], ErrorSeverity.HIGH)
+ self.assertEqual(error_dict["suggested_fix"], "Convert the integer to a string: 'hello' + str(5)")
+
+
+if __name__ == '__main__':
+ unittest.main()
+
diff --git a/codegen-on-oss/codegen_on_oss/metrics.py b/codegen-on-oss/codegen_on_oss/metrics.py
index d77b4e686..d81d5b20b 100644
--- a/codegen-on-oss/codegen_on_oss/metrics.py
+++ b/codegen-on-oss/codegen_on_oss/metrics.py
@@ -1,15 +1,36 @@
+"""
+Metrics module for Codegen-on-OSS
+
+This module provides tools for measuring and recording performance metrics
+and code quality metrics for codebases.
+"""
+
import json
import os
import time
from collections.abc import Generator
from contextlib import contextmanager
from importlib.metadata import version
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Dict, List
import psutil
+import networkx as nx
+from codegen import Codebase
from codegen_on_oss.errors import ParseRunError
from codegen_on_oss.outputs.base import BaseOutput
+from codegen_on_oss.analysis.analysis import (
+ CodeAnalyzer,
+ calculate_cyclomatic_complexity,
+ calculate_halstead_volume,
+ calculate_maintainability_index,
+ count_lines,
+ get_operators_and_operands,
+ cc_rank,
+ get_maintainability_rank,
+ calculate_doi
+)
if TYPE_CHECKING:
# Logger only available in type checking context.
@@ -19,6 +40,478 @@
codegen_version = str(version("codegen"))
+class CodeMetrics:
+ """
+ A class to calculate and provide code quality metrics for a codebase.
+ Integrates with the analysis module for comprehensive code analysis.
+ """
+
+ # Constants for threshold values
+ COMPLEXITY_THRESHOLD = 10
+ MAINTAINABILITY_THRESHOLD = 65
+ INHERITANCE_DEPTH_THRESHOLD = 3
+ VOLUME_THRESHOLD = 1000
+ EFFORT_THRESHOLD = 50000
+ BUG_THRESHOLD = 0.5
+
+ def __init__(self, codebase: Codebase):
+ """
+ Initialize the CodeMetrics class with a codebase.
+
+ Args:
+ codebase: The Codebase object to analyze
+ """
+ self.codebase = codebase
+ self.analyzer = CodeAnalyzer(codebase)
+ self._complexity_metrics = None
+ self._line_metrics = None
+ self._maintainability_metrics = None
+ self._inheritance_metrics = None
+ self._halstead_metrics = None
+
+ def calculate_all_metrics(self) -> Dict[str, Any]:
+ """
+ Calculate all available metrics for the codebase.
+
+ Returns:
+ A dictionary containing all metrics categories
+ """
+ return {
+ "complexity": self.complexity_metrics,
+ "lines": self.line_metrics,
+ "maintainability": self.maintainability_metrics,
+ "inheritance": self.inheritance_metrics,
+ "halstead": self.halstead_metrics,
+ }
+
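+ # Illustrative usage (a sketch, not a documented API; the Codebase
+ # construction below is an assumption -- build it however the codegen
+ # SDK does in your environment):
+ #
+ #   codebase = Codebase("path/to/repo")
+ #   metrics = CodeMetrics(codebase)
+ #   report = metrics.calculate_all_metrics()
+ #   print(report["complexity"]["rank"], report["maintainability"]["rank"])
+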
+ @property
+ def complexity_metrics(self) -> Dict[str, Any]:
+ """
+ Calculate cyclomatic complexity metrics for the codebase.
+
+ Returns:
+ A dictionary containing complexity metrics including average,
+ rank, and per-function complexity scores
+ """
+ if self._complexity_metrics is not None:
+ return self._complexity_metrics
+
+ callables = self.codebase.functions + [
+ m for c in self.codebase.classes for m in c.methods
+ ]
+
+ complexities = []
+ for func in callables:
+ if not hasattr(func, "code_block"):
+ continue
+
+ complexity = calculate_cyclomatic_complexity(func)
+ complexities.append({
+ "name": func.name,
+ "complexity": complexity,
+ "rank": cc_rank(complexity)
+ })
+
+ avg_complexity = (
+ sum(item["complexity"] for item in complexities) / len(complexities)
+ if complexities else 0
+ )
+
+ self._complexity_metrics = {
+ "average": avg_complexity,
+ "rank": cc_rank(avg_complexity),
+ "functions": complexities
+ }
+
+ return self._complexity_metrics
+
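+ # Note: cc_rank is assumed to follow radon-style letter grades
+ # (A: 1-5, B: 6-10, C: 11-20, D: 21-30, E: 31-40, F: 41+); if the
+ # analysis module buckets differently, the ranks reported here follow
+ # that implementation instead.
+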
+ @property
+ def line_metrics(self) -> Dict[str, Any]:
+ """
+ Calculate line-based metrics for the codebase.
+
+ Returns:
+ A dictionary containing line metrics including total counts
+ and per-file metrics for LOC, LLOC, SLOC, and comments
+ """
+ if self._line_metrics is not None:
+ return self._line_metrics
+
+ total_loc = total_lloc = total_sloc = total_comments = 0
+ file_metrics = []
+
+ for file in self.codebase.files:
+ loc, lloc, sloc, comments = count_lines(file.source)
+ comment_density = (comments / loc * 100) if loc > 0 else 0
+
+ file_metrics.append({
+ "file": file.path,
+ "loc": loc,
+ "lloc": lloc,
+ "sloc": sloc,
+ "comments": comments,
+ "comment_density": comment_density
+ })
+
+ total_loc += loc
+ total_lloc += lloc
+ total_sloc += sloc
+ total_comments += comments
+
+ total_comment_density = (
+ total_comments / total_loc * 100 if total_loc > 0 else 0
+ )
+
+ self._line_metrics = {
+ "total": {
+ "loc": total_loc,
+ "lloc": total_lloc,
+ "sloc": total_sloc,
+ "comments": total_comments,
+ "comment_density": total_comment_density
+ },
+ "files": file_metrics
+ }
+
+ return self._line_metrics
+
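+ # Conventions (assuming count_lines follows the usual raw-metric
+ # definitions): LOC = total lines, LLOC = logical statements,
+ # SLOC = non-blank source lines, comments = comment lines. Comment
+ # density is reported as a percentage of LOC.
+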
+ @property
+ def maintainability_metrics(self) -> Dict[str, Any]:
+ """
+ Calculate maintainability index metrics for the codebase.
+
+ Returns:
+ A dictionary containing maintainability metrics including average,
+ rank, and per-function maintainability scores
+ """
+ if self._maintainability_metrics is not None:
+ return self._maintainability_metrics
+
+ callables = self.codebase.functions + [
+ m for c in self.codebase.classes for m in c.methods
+ ]
+
+ mi_scores = []
+ for func in callables:
+ if not hasattr(func, "code_block"):
+ continue
+
+ complexity = calculate_cyclomatic_complexity(func)
+ operators, operands = get_operators_and_operands(func)
+ volume, _, _, _, _ = calculate_halstead_volume(operators, operands)
+ loc = len(func.code_block.source.splitlines())
+ mi_score = calculate_maintainability_index(volume, complexity, loc)
+
+ mi_scores.append({
+ "name": func.name,
+ "mi_score": mi_score,
+ "rank": get_maintainability_rank(mi_score)
+ })
+
+ avg_mi = (
+ sum(item["mi_score"] for item in mi_scores) / len(mi_scores)
+ if mi_scores else 0
+ )
+
+ self._maintainability_metrics = {
+ "average": avg_mi,
+ "rank": get_maintainability_rank(avg_mi),
+ "functions": mi_scores
+ }
+
+ return self._maintainability_metrics
+
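+ # For reference, the classic maintainability index (which
+ # calculate_maintainability_index is assumed to implement) is
+ #   MI = 171 - 5.2 * ln(V) - 0.23 * CC - 16.2 * ln(LOC)
+ # usually rescaled to a 0-100 range; higher scores read as more
+ # maintainable code.
+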
+ @property
+ def inheritance_metrics(self) -> Dict[str, Any]:
+ """
+ Calculate inheritance metrics for the codebase.
+
+ Returns:
+ A dictionary containing inheritance metrics including average
+ depth of inheritance and per-class inheritance depth
+ """
+ if self._inheritance_metrics is not None:
+ return self._inheritance_metrics
+
+ class_metrics = []
+ for cls in self.codebase.classes:
+ doi = calculate_doi(cls)
+ class_metrics.append({
+ "name": cls.name,
+ "doi": doi
+ })
+
+ avg_doi = (
+ sum(item["doi"] for item in class_metrics) / len(class_metrics)
+ if class_metrics else 0
+ )
+
+ self._inheritance_metrics = {
+ "average": avg_doi,
+ "classes": class_metrics
+ }
+
+ return self._inheritance_metrics
+
+ @property
+ def halstead_metrics(self) -> Dict[str, Any]:
+ """
+ Calculate Halstead complexity metrics for the codebase.
+
+ Returns:
+ A dictionary containing Halstead metrics including volume,
+ difficulty, effort, and other Halstead measures
+ """
+ if self._halstead_metrics is not None:
+ return self._halstead_metrics
+
+ callables = self.codebase.functions + [
+ m for c in self.codebase.classes for m in c.methods
+ ]
+
+ halstead_metrics = []
+ for func in callables:
+ if not hasattr(func, "code_block"):
+ continue
+
+ operators, operands = get_operators_and_operands(func)
+ # calculate_halstead_volume is assumed to return
+ # (volume, total_operators, total_operands, distinct_operators,
+ # distinct_operands); the names below follow standard Halstead
+ # notation (N1/N2 = totals, n1/n2 = distinct counts).
+ volume, N1, N2, n1, n2 = calculate_halstead_volume(
+ operators, operands
+ )
+
+ # Derived Halstead metrics:
+ #   difficulty D = (n1 / 2) * (N2 / n2)
+ #   effort     E = D * V
+ #   time       T = E / 18 seconds (Stroud number)
+ #   bugs       B = V / 3000
+ difficulty = (n1 / 2) * (N2 / n2) if n2 > 0 else 0
+ effort = difficulty * volume
+ time_required = effort / 18
+ bugs_delivered = volume / 3000
+
+ halstead_metrics.append({
+ "name": func.name,
+ "volume": volume,
+ "difficulty": difficulty,
+ "effort": effort,
+ "time_required": time_required, # in seconds
+ "bugs_delivered": bugs_delivered
+ })
+
+ avg_volume = (
+ sum(item["volume"] for item in halstead_metrics) / len(halstead_metrics)
+ if halstead_metrics else 0
+ )
+ avg_difficulty = (
+ sum(item["difficulty"] for item in halstead_metrics) / len(halstead_metrics)
+ if halstead_metrics else 0
+ )
+ avg_effort = (
+ sum(item["effort"] for item in halstead_metrics) / len(halstead_metrics)
+ if halstead_metrics else 0
+ )
+
+ self._halstead_metrics = {
+ "average": {
+ "volume": avg_volume,
+ "difficulty": avg_difficulty,
+ "effort": avg_effort
+ },
+ "functions": halstead_metrics
+ }
+
+ return self._halstead_metrics
+
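+ # Worked example with assumed counts: n1 = 4 distinct operators,
+ # n2 = 5 distinct operands, N1 = 10 total operators, N2 = 12 total
+ # operands gives:
+ #   volume V = (N1 + N2) * log2(n1 + n2) = 22 * log2(9) ~= 69.7
+ #   difficulty D = (n1 / 2) * (N2 / n2) = 2 * 2.4 = 4.8
+ #   effort E = D * V ~= 334.8; time ~= E / 18 ~= 18.6 s; bugs ~= V / 3000 ~= 0.02
+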
+ def find_complex_functions(self, threshold: int = COMPLEXITY_THRESHOLD) -> List[Dict[str, Any]]:
+ """
+ Find functions with cyclomatic complexity above the threshold.
+
+ Args:
+ threshold: The complexity threshold (default: 10)
+
+ Returns:
+ A list of functions with complexity above the threshold
+ """
+ metrics = self.complexity_metrics
+ return [
+ func for func in metrics["functions"]
+ if func["complexity"] > threshold
+ ]
+
+ def find_low_maintainability_functions(
+ self, threshold: int = MAINTAINABILITY_THRESHOLD
+ ) -> List[Dict[str, Any]]:
+ """
+ Find functions with maintainability index below the threshold.
+
+ Args:
+ threshold: The maintainability threshold (default: 65)
+
+ Returns:
+ A list of functions with maintainability below the threshold
+ """
+ metrics = self.maintainability_metrics
+ return [
+ func for func in metrics["functions"]
+ if func["mi_score"] < threshold
+ ]
+
+ def find_deep_inheritance_classes(
+ self, threshold: int = INHERITANCE_DEPTH_THRESHOLD
+ ) -> List[Dict[str, Any]]:
+ """
+ Find classes with depth of inheritance above the threshold.
+
+ Args:
+ threshold: The inheritance depth threshold (default: 3)
+
+ Returns:
+ A list of classes with inheritance depth above the threshold
+ """
+ metrics = self.inheritance_metrics
+ return [cls for cls in metrics["classes"] if cls["doi"] > threshold]
+
+ def find_high_volume_functions(self, threshold: int = VOLUME_THRESHOLD) -> List[Dict[str, Any]]:
+ """
+ Find functions with Halstead volume above the threshold.
+
+ Args:
+ threshold: The volume threshold (default: 1000)
+
+ Returns:
+ A list of functions with volume above the threshold
+ """
+ metrics = self.halstead_metrics
+ return [
+ func for func in metrics["functions"]
+ if func["volume"] > threshold
+ ]
+
+ def find_high_effort_functions(self, threshold: int = EFFORT_THRESHOLD) -> List[Dict[str, Any]]:
+ """
+ Find functions with high Halstead effort (difficult to maintain).
+
+ Args:
+ threshold: The effort threshold (default: 50000)
+
+ Returns:
+ A list of functions with effort above the threshold
+ """
+ metrics = self.halstead_metrics
+ return [
+ func for func in metrics["functions"]
+ if func["effort"] > threshold
+ ]
+
+ def find_bug_prone_functions(self, threshold: float = BUG_THRESHOLD) -> List[Dict[str, Any]]:
+ """
+ Find functions with high estimated bug delivery.
+
+ Args:
+ threshold: The bugs delivered threshold (default: 0.5)
+
+ Returns:
+ A list of functions likely to contain bugs
+ """
+ metrics = self.halstead_metrics
+ return [
+ func for func in metrics["functions"]
+ if func["bugs_delivered"] > threshold
+ ]
+
+ def get_code_quality_summary(self) -> Dict[str, Any]:
+ """
+ Generate a comprehensive code quality summary.
+
+ Returns:
+ A dictionary with overall code quality metrics and problem areas
+ """
+ return {
+ "overall_metrics": {
+ "complexity": self.complexity_metrics["average"],
+ "complexity_rank": self.complexity_metrics["rank"],
+ "maintainability": self.maintainability_metrics["average"],
+ "maintainability_rank": self.maintainability_metrics["rank"],
+ "lines_of_code": self.line_metrics["total"]["loc"],
+ "comment_density": self.line_metrics["total"]["comment_density"],
+ "inheritance_depth": self.inheritance_metrics["average"],
+ "halstead_volume": self.halstead_metrics["average"]["volume"],
+ "halstead_difficulty": self.halstead_metrics["average"]["difficulty"],
+ },
+ "problem_areas": {
+ "complex_functions": len(self.find_complex_functions()),
+ "low_maintainability": len(self.find_low_maintainability_functions()),
+ "deep_inheritance": len(self.find_deep_inheritance_classes()),
+ "high_volume": len(self.find_high_volume_functions()),
+ "high_effort": len(self.find_high_effort_functions()),
+ "bug_prone": len(self.find_bug_prone_functions()),
+ },
+ "import_analysis": self.analyzer.analyze_imports()
+ }
+
+ def analyze_codebase_structure(self) -> Dict[str, Any]:
+ """
+ Analyze the structure of the codebase.
+
+ Returns:
+ A dictionary with codebase structure information
+ """
+ return {
+ "summary": self.analyzer.get_codebase_summary(),
+ "files": len(self.codebase.files),
+ "functions": len(self.codebase.functions),
+ "classes": len(self.codebase.classes),
+ "imports": len(self.codebase.imports),
+ "symbols": len(self.codebase.symbols)
+ }
+
+ def generate_documentation(self) -> None:
+ """
+ Generate documentation for the codebase.
+ """
+ self.analyzer.document_functions()
+
+ def analyze_dependencies(self) -> Dict[str, Any]:
+ """
+ Analyze dependencies in the codebase.
+
+ Returns:
+ A dictionary with dependency analysis results
+ """
+ # Create a dependency graph
+ G = nx.DiGraph()
+
+ # Add nodes for all files, keyed by filepath to match the import edges below
+ for file in self.codebase.files:
+ G.add_node(file.filepath)
+
+ # Add edges for imports
+ for imp in self.codebase.imports:
+ if imp.from_file and imp.to_file:
+ G.add_edge(imp.from_file.filepath, imp.to_file.filepath)
+
+ # Find cycles
+ cycles = list(nx.simple_cycles(G))
+
+ # Calculate centrality metrics
+ centrality = nx.degree_centrality(G)
+
+ return {
+ "dependency_graph": {
+ "nodes": len(G.nodes),
+ "edges": len(G.edges),
+ "density": nx.density(G)
+ },
+ "cycles": len(cycles),
+ "most_central_files": sorted(
+ [(file, score) for file, score in centrality.items()],
+ key=lambda x: x[1],
+ reverse=True
+ )[:10]
+ }
+
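+ # Interpretation note: degree_centrality scores each file by how many
+ # import edges touch it, normalized by graph size, so the files at the
+ # top of most_central_files are the likeliest coupling hotspots. A
+ # quick, illustrative use of the result:
+ #
+ #   report = metrics.analyze_dependencies()
+ #   if report["cycles"]:
+ #       print(f"{report['cycles']} import cycle(s) detected")
+ #   for path, score in report["most_central_files"][:3]:
+ #       print(f"{path}: centrality {score:.2f}")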
+
class MetricsProfiler:
"""
A helper to record performance metrics across multiple profiles and write them to a CSV.
@@ -42,7 +535,7 @@ def __init__(self, output: BaseOutput):
@contextmanager
def start_profiler(
self, name: str, revision: str, language: str | None, logger: "Logger"
- ) -> Generator["MetricsProfile", None, None]:
+ ) -> Generator[Any, None, None]:
"""
Starts a new profiling session for a given profile name.
Returns a MetricsProfile instance that you can use to mark measurements.
@@ -81,9 +574,9 @@ def fields(cls) -> list[str]:
class MetricsProfile:
"""
Context-managed profile that records measurements at each call to `measure()`.
- It tracks the wall-clock duration, CPU time, and memory usage (with delta) at the time of the call.
- Upon exiting the context, it also writes all collected metrics, including the total time,
- to a CSV file.
+ It tracks the wall-clock duration, CPU time, and memory usage (with delta)
+ at the time of the call. Upon exiting the context, it also writes all collected
+ metrics, including the total time, to a CSV file.
"""
if TYPE_CHECKING:
@@ -131,7 +624,9 @@ def measure(self, action_name: str):
"""
current_time = time.perf_counter()
current_cpu = float(time.process_time())
- current_mem = int(psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024))
+ current_mem = int(
+ psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
+ )
# Calculate time deltas.
delta_time = current_time - self.last_measure_time
@@ -168,7 +663,9 @@ def finish(self, error: str | None = None):
"""
finish_time = time.perf_counter()
finish_cpu = float(time.process_time())
- finish_mem = int(psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024))
+ finish_mem = int(
+ psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
+ )
total_duration = finish_time - self.start_time