diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4e500b424..bcba375d2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,10 +15,15 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions-cool/check-user-permission@v2 + if: github.triggering_actor != 'codegen-sh[bot]' with: require: write username: ${{ github.triggering_actor }} error-if-missing: true + # Skip permission check for codegen-sh[bot] + - name: Skip permission check for bot + if: github.triggering_actor == 'codegen-sh[bot]' + run: echo "Skipping permission check for codegen-sh[bot]" unit-tests: needs: access-check diff --git a/codegen-on-oss/codegen_on_oss/analysis/README.md b/codegen-on-oss/codegen_on_oss/analysis/README.md new file mode 100644 index 000000000..096e5421b --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analysis/README.md @@ -0,0 +1,183 @@ +# Code Analysis Module with Error Context + +This module provides robust and dynamic code analysis capabilities with a focus on error detection and contextual error information. + +## Overview + +The code analysis module consists of several components: + +1. **CodeAnalyzer**: The main class that integrates all analysis components and provides a unified interface. +2. **ErrorContextAnalyzer**: A specialized class for detecting and analyzing errors in code. +3. **CodeError**: A class representing an error in code with detailed context information. +4. **API Endpoints**: FastAPI endpoints for accessing the analysis functionality. + +## Features + +### Code Structure Analysis + +- Analyze codebase structure and dependencies +- Generate dependency graphs for files and symbols +- Analyze import relationships and detect circular imports +- Get detailed information about files, functions, classes, and symbols + +### Error Detection and Analysis + +- Detect syntax errors, type errors, parameter errors, and more +- Analyze function parameters and return statements for errors +- Detect undefined variables and unused imports +- Find circular dependencies between symbols +- Provide detailed context information for errors + +### API Endpoints + +- `/analyze_repo`: Analyze a repository and return various metrics +- `/analyze_symbol`: Analyze a symbol and return detailed information +- `/analyze_file`: Analyze a file and return detailed information +- `/analyze_function`: Analyze a function and return detailed information +- `/analyze_errors`: Analyze errors in a repository, file, or function + +## Error Types + +The module can detect the following types of errors: + +- **Syntax Errors**: Invalid syntax in code +- **Type Errors**: Type mismatches in expressions +- **Parameter Errors**: Incorrect function parameters +- **Call Errors**: Incorrect function calls +- **Undefined Variables**: Variables used without being defined +- **Unused Imports**: Imports that are not used in the code +- **Circular Imports**: Circular dependencies between files +- **Circular Dependencies**: Circular dependencies between symbols +- **Name Errors**: References to undefined names +- **Import Errors**: Problems with import statements +- **Attribute Errors**: References to undefined attributes + +## Error Severity Levels + +The module assigns severity levels to each error: + +- **Critical**: Errors that will definitely cause the code to crash or fail +- **High**: Errors that are likely to cause problems in most execution paths +- **Medium**: Errors that may cause problems in some execution paths +- **Low**: Minor issues that are unlikely to cause problems but 
should be fixed +- **Info**: Informational messages about potential improvements + +## Usage + +### Using the CodeAnalyzer + +```python +from codegen import Codebase +from codegen_on_oss.analysis.analysis import CodeAnalyzer + +# Create a codebase from a repository +codebase = Codebase.from_repo("owner/repo") + +# Create an analyzer +analyzer = CodeAnalyzer(codebase) + +# Analyze errors in the codebase +errors = analyzer.analyze_errors() + +# Get detailed error context for a function +function_errors = analyzer.get_function_error_context("function_name") + +# Get detailed error context for a file +file_errors = analyzer.get_file_error_context("path/to/file.py") +``` + +### Using the API + +```bash +# Analyze a repository +curl -X POST "http://localhost:8000/analyze_repo" \ + -H "Content-Type: application/json" \ + -d '{"repo_url": "owner/repo"}' + +# Analyze errors in a function +curl -X POST "http://localhost:8000/analyze_function" \ + -H "Content-Type: application/json" \ + -d '{"repo_url": "owner/repo", "function_name": "function_name"}' + +# Analyze errors in a file +curl -X POST "http://localhost:8000/analyze_file" \ + -H "Content-Type: application/json" \ + -d '{"repo_url": "owner/repo", "file_path": "path/to/file.py"}' +``` + +## Error Context Example + +Here's an example of the error context information provided for a function: + +```json +{ + "function_name": "calculate_total", + "file_path": "app/utils.py", + "errors": [ + { + "error_type": "parameter_error", + "message": "Function 'calculate_discount' called with 1 arguments but expects 2", + "line_number": 15, + "severity": "high", + "context_lines": { + "13": "def calculate_total(items):", + "14": " total = sum(item.price for item in items)", + "15": " discount = calculate_discount(total)", + "16": " return total - discount", + "17": "" + }, + "suggested_fix": "Update call to provide 2 arguments: calculate_discount(total, discount_percent)" + } + ], + "callers": [ + {"name": "process_order"} + ], + "callees": [ + {"name": "calculate_discount"} + ], + "parameters": [ + { + "name": "items", + "type": "List[Item]", + "default": null + } + ], + "return_info": { + "type": "float", + "statements": ["total - discount"] + } +} +``` + +## Implementation Details + +### ErrorContextAnalyzer + +The `ErrorContextAnalyzer` class is responsible for detecting and analyzing errors in code. It uses various techniques to detect errors, including: + +- **AST Analysis**: Parsing the code into an abstract syntax tree to detect syntax errors and undefined variables +- **Graph Analysis**: Building dependency graphs to detect circular imports and dependencies +- **Pattern Matching**: Using regular expressions to detect potential type errors and other issues +- **Static Analysis**: Analyzing function parameters, return statements, and variable usage + +### CodeError + +The `CodeError` class represents an error in code with detailed context information. It includes: + +- **Error Type**: The type of error (syntax, type, parameter, etc.) +- **Message**: A descriptive message explaining the error +- **Location**: The file path and line number where the error occurs +- **Severity**: The severity of the error (critical, high, medium, low, info) +- **Context Lines**: The lines of code surrounding the error +- **Suggested Fix**: A suggested fix for the error + +## Running the API Server + +To run the API server locally: + +```bash +cd codegen-on-oss +python -m codegen_on_oss.analysis.analysis +``` + +The server will be available at `http://localhost:8000`. 
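+
+## Consuming the API Programmatically
+
+A minimal sketch of calling the error-analysis endpoint from Python (this assumes the `requests` package and a server already running on port 8000; the response shape follows the examples above):
+
+```python
+import requests
+
+# POST the same payload the curl examples use; file_path is optional.
+response = requests.post(
+    "http://localhost:8000/analyze_errors",
+    json={"repo_url": "owner/repo", "file_path": "path/to/file.py"},
+    timeout=300,
+)
+response.raise_for_status()
+print(response.json())
+```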
diff --git a/codegen-on-oss/codegen_on_oss/analysis/analysis.py b/codegen-on-oss/codegen_on_oss/analysis/analysis.py
index 9e956ec06..98ce44030 100644
--- a/codegen-on-oss/codegen_on_oss/analysis/analysis.py
+++ b/codegen-on-oss/codegen_on_oss/analysis/analysis.py
@@ -1,37 +1,101 @@
-from fastapi import FastAPI
-from pydantic import BaseModel
-from typing import Dict, List, Tuple, Any
+"""
+Unified Analysis Module for Codegen-on-OSS
+
+This module serves as a central hub for all code analysis functionality, integrating
+various specialized analysis components into a cohesive system.
+"""
+
+import ast
+import contextlib
+import math
+import os
+import re
+import subprocess
+import tempfile
+from datetime import UTC, datetime, timedelta
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import networkx as nx
+import requests
+import uvicorn
 from codegen import Codebase
+from codegen.sdk.core.class_definition import Class
+from codegen.sdk.core.directory import Directory
+from codegen.sdk.core.expressions.binary_expression import BinaryExpression
+from codegen.sdk.core.expressions.comparison_expression import ComparisonExpression
+from codegen.sdk.core.expressions.unary_expression import UnaryExpression
+from codegen.sdk.core.external_module import ExternalModule
+from codegen.sdk.core.file import SourceFile
+from codegen.sdk.core.function import Function
+from codegen.sdk.core.import_resolution import Import
 from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement
 from codegen.sdk.core.statements.if_block_statement import IfBlockStatement
 from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement
 from codegen.sdk.core.statements.while_statement import WhileStatement
-from codegen.sdk.core.expressions.binary_expression import BinaryExpression
-from codegen.sdk.core.expressions.unary_expression import UnaryExpression
-from codegen.sdk.core.expressions.comparison_expression import ComparisonExpression
-import math
-import re
-import requests
-from datetime import datetime, timedelta
-import subprocess
-import os
-import tempfile
+from codegen.sdk.core.symbol import Symbol
+from codegen.sdk.enums import EdgeType, SymbolType
+from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
-import modal
-
-image = (
-    modal.Image.debian_slim()
-    .apt_install("git")
-    .pip_install(
-        "codegen", "fastapi", "uvicorn", "gitpython", "requests", "pydantic", "datetime"
-    )
-)
+from pydantic import BaseModel
 
-app = modal.App(name="analytics-app", image=image)
+# Import from other analysis modules
+from codegen_on_oss.analysis.analysis_import import (
+    create_graph_from_codebase,
+    convert_all_calls_to_kwargs,
+    find_import_cycles,
+    find_problematic_import_loops,
+)
+from codegen_on_oss.analysis.codebase_analysis import (
+    get_class_summary,
+    get_codebase_summary,
+    get_file_summary,
+    get_function_summary,
+    get_symbol_summary,
+)
+from codegen_on_oss.analysis.codebase_context import CodebaseContext
+from codegen_on_oss.analysis.codegen_sdk_codebase import (
+    get_codegen_sdk_subdirectories,
+    get_codegen_sdk_codebase,
+)
+from codegen_on_oss.analysis.current_code_codebase import (
+    get_graphsitter_repo_path,
+    get_codegen_codebase_base_path,
+    get_current_code_codebase,
+    import_all_codegen_sdk_modules,
+    DocumentedObjects,
+    get_documented_objects,
+)
+from codegen_on_oss.analysis.document_functions import (
+    hop_through_imports,
+    get_extended_context,
+    run as document_functions_run,
+)
+from 
codegen_on_oss.analysis.error_context import CodeError, ErrorContextAnalyzer +from codegen_on_oss.analysis.mdx_docs_generation import ( + render_mdx_page_for_class, + render_mdx_page_title, + render_mdx_inheritence_section, + render_mdx_attributes_section, + render_mdx_methods_section, + render_mdx_for_attribute, + format_parameter_for_mdx, + format_parameters_for_mdx, + format_return_for_mdx, + render_mdx_for_method, + get_mdx_route_for_class, + format_type_string, + resolve_type_string, + format_builtin_type_string, + span_type_string_by_pipe, + parse_link, +) +from codegen_on_oss.analysis.module_dependencies import run as module_dependencies_run +from codegen_on_oss.analysis.symbolattr import print_symbol_attribution -fastapi_app = FastAPI() +# Create FastAPI app +app = FastAPI() -fastapi_app.add_middleware( +app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, @@ -39,393 +102,1111 @@ allow_headers=["*"], ) +class CodeAnalyzer: + """ + Central class for code analysis that integrates all analysis components. -def get_monthly_commits(repo_path: str) -> Dict[str, int]: + This class serves as the main entry point for all code analysis functionality, + providing a unified interface to access various analysis capabilities. """ - Get the number of commits per month for the last 12 months. + + def __init__(self, codebase: Codebase): + """ + Initialize the CodeAnalyzer with a codebase. + + Args: + codebase: The Codebase object to analyze + """ + self.codebase = codebase + self._context = None + self._initialized = False + self._error_analyzer = None + + def initialize(self) -> None: + """ + Initialize the analyzer by setting up the context and other necessary components. + This is called automatically when needed but can be called explicitly for eager initialization. + """ + if self._initialized: + return + + # Initialize context if not already done + if self._context is None: + self._context = self._create_context() + + self._initialized = True + + def _create_context(self) -> CodebaseContext: + """ + Create a CodebaseContext instance for the current codebase. + + Returns: + A new CodebaseContext instance + """ + # If the codebase already has a context, use it + if hasattr(self.codebase, "ctx") and self.codebase.ctx is not None: + return self.codebase.ctx + + # Otherwise, create a new context from the codebase's configuration + from codegen.configs.models.codebase import CodebaseConfig + from codegen.sdk.codebase.config import ProjectConfig + + # Create a project config from the codebase + project_config = ProjectConfig( + repo_operator=self.codebase.repo_operator, + programming_language=self.codebase.programming_language, + base_path=self.codebase.base_path, + ) + + # Create and return a new context + return CodebaseContext([project_config], config=CodebaseConfig()) + + @property + def context(self) -> CodebaseContext: + """ + Get the CodebaseContext for the current codebase. + + Returns: + A CodebaseContext object for the codebase + """ + if not self._initialized: + self.initialize() + + return self._context + + @property + def error_analyzer(self) -> ErrorContextAnalyzer: + """ + Get the ErrorContextAnalyzer for the current codebase. + + Returns: + An ErrorContextAnalyzer object for the codebase + """ + if self._error_analyzer is None: + self._error_analyzer = ErrorContextAnalyzer(self.codebase) + + return self._error_analyzer + + def get_codebase_summary(self) -> str: + """ + Get a comprehensive summary of the codebase. 
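+
+        Example (illustrative sketch; ``owner/repo`` is a placeholder):
+
+            codebase = Codebase.from_repo("owner/repo")
+            summary = CodeAnalyzer(codebase).get_codebase_summary()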
+ + Returns: + A string containing summary information about the codebase + """ + return get_codebase_summary(self.codebase) + + def get_file_summary(self, file_path: str) -> str: + """ + Get a summary of a specific file. + + Args: + file_path: Path to the file to analyze + + Returns: + A string containing summary information about the file + """ + file = self.codebase.get_file(file_path) + if file is None: + return f"File not found: {file_path}" + return get_file_summary(file) + + def get_class_summary(self, class_name: str) -> str: + """ + Get a summary of a specific class. + + Args: + class_name: Name of the class to analyze + + Returns: + A string containing summary information about the class + """ + for cls in self.codebase.classes: + if cls.name == class_name: + return get_class_summary(cls) + return f"Class not found: {class_name}" + + def get_function_summary(self, function_name: str) -> str: + """ + Get a summary of a specific function. + + Args: + function_name: Name of the function to analyze + + Returns: + A string containing summary information about the function + """ + for func in self.codebase.functions: + if func.name == function_name: + return get_function_summary(func) + return f"Function not found: {function_name}" + + def get_symbol_summary(self, symbol_name: str) -> str: + """ + Get a summary of a specific symbol. + + Args: + symbol_name: Name of the symbol to analyze + + Returns: + A string containing summary information about the symbol + """ + for symbol in self.codebase.symbols: + if symbol.name == symbol_name: + return get_symbol_summary(symbol) + return f"Symbol not found: {symbol_name}" + + def find_symbol_by_name(self, symbol_name: str) -> Symbol | None: + """ + Find a symbol by its name. + + Args: + symbol_name: Name of the symbol to find + + Returns: + The Symbol object if found, None otherwise + """ + for symbol in self.codebase.symbols: + if symbol.name == symbol_name: + return symbol + return None + + def find_file_by_path(self, file_path: str) -> SourceFile | None: + """ + Find a file by its path. + + Args: + file_path: Path to the file to find + + Returns: + The SourceFile object if found, None otherwise + """ + return self.codebase.get_file(file_path) + + def find_class_by_name(self, class_name: str) -> Class | None: + """ + Find a class by its name. + + Args: + class_name: Name of the class to find + + Returns: + The Class object if found, None otherwise + """ + for cls in self.codebase.classes: + if cls.name == class_name: + return cls + return None + + def find_function_by_name(self, function_name: str) -> Function | None: + """ + Find a function by its name. + + Args: + function_name: Name of the function to find + + Returns: + The Function object if found, None otherwise + """ + for func in self.codebase.functions: + if func.name == function_name: + return func + return None + + def document_functions(self) -> None: + """ + Generate documentation for functions in the codebase. + """ + document_functions_run(self.codebase) + + def analyze_imports(self) -> dict[str, Any]: + """ + Analyze import relationships in the codebase. 
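+
+        Example (sketch, given a ``CodeAnalyzer`` instance ``analyzer``; the
+        keys mirror the dictionary assembled below):
+
+            result = analyzer.analyze_imports()
+            for cycle in result["cycles"]:
+                print(cycle)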
+ + Returns: + A dictionary containing import analysis results + """ + graph = create_graph_from_codebase(self.codebase) + cycles = find_import_cycles(graph) + problematic_loops = find_problematic_import_loops(graph, cycles) + + return { + "import_graph": graph, + "cycles": cycles, + "problematic_loops": problematic_loops, + } + + def get_dependency_graph(self) -> nx.DiGraph: + """ + Get a dependency graph for the codebase files. + + Returns: + A directed graph representing file dependencies + """ + G = nx.DiGraph() + + # Add nodes for all files + for file in self.codebase.files: + G.add_node(file.name, type="file") + + # Add edges for imports + for file in self.codebase.files: + for imp in file.imports: + if imp.imported_symbol and hasattr(imp.imported_symbol, "file"): + imported_file = imp.imported_symbol.file + if imported_file and imported_file.name != file.name: + G.add_edge(file.name, imported_file.name) + + return G + + def get_symbol_attribution(self, symbol_name: str) -> str: + """ + Get attribution information for a symbol. + + Args: + symbol_name: Name of the symbol to analyze + + Returns: + A string containing attribution information + """ + symbol = self.find_symbol_by_name(symbol_name) + if symbol is None: + return f"Symbol not found: {symbol_name}" + + return print_symbol_attribution(symbol) + + def get_context_for_symbol(self, symbol_name: str) -> dict[str, Any]: + """ + Get context information for a symbol. + + Args: + symbol_name: Name of the symbol to analyze + + Returns: + A dictionary containing context information + """ + symbol = self.find_symbol_by_name(symbol_name) + if symbol is None: + return {"error": f"Symbol not found: {symbol_name}"} + + # Use the context to get more information about the symbol + ctx = self.context + + # Get symbol node ID in the context graph + node_id = None + for n_id, node in enumerate(ctx.nodes): + if isinstance(node, Symbol) and node.name == symbol_name: + node_id = n_id + break + + if node_id is None: + return {"error": f"Symbol not found in context: {symbol_name}"} + + # Get predecessors (symbols that use this symbol) + predecessors = [] + for pred in ctx.predecessors(node_id): + if isinstance(pred, Symbol): + predecessors.append({ + "name": pred.name, + "type": pred.symbol_type.name + if hasattr(pred, "symbol_type") + else "Unknown", + }) + + # Get successors (symbols used by this symbol) + successors = [] + for succ in ctx.successors(node_id): + if isinstance(succ, Symbol): + successors.append({ + "name": succ.name, + "type": succ.symbol_type.name + if hasattr(succ, "symbol_type") + else "Unknown", + }) + + return { + "symbol": { + "name": symbol.name, + "type": symbol.symbol_type.name + if hasattr(symbol, "symbol_type") + else "Unknown", + "file": symbol.file.name if hasattr(symbol, "file") else "Unknown", + }, + "predecessors": predecessors, + "successors": successors, + } + + def get_file_dependencies(self, file_path: str) -> dict[str, Any]: + """ + Get dependency information for a file using CodebaseContext. 
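+
+        Example (sketch; ``path/to/file.py`` is a placeholder):
+
+            deps = analyzer.get_file_dependencies("path/to/file.py")
+            # deps -> {"file": ..., "importers": [...], "imported": [...]}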
+ + Args: + file_path: Path to the file to analyze + + Returns: + A dictionary containing dependency information + """ + file = self.find_file_by_path(file_path) + if file is None: + return {"error": f"File not found: {file_path}"} + + # Use the context to get more information about the file + ctx = self.context + + # Get file node ID in the context graph + node_id = None + for n_id, node in enumerate(ctx.nodes): + if isinstance(node, SourceFile) and node.name == file.name: + node_id = n_id + break + + if node_id is None: + return {"error": f"File not found in context: {file_path}"} + + # Get files that import this file + importers = [] + for pred in ctx.predecessors(node_id, edge_type=EdgeType.IMPORT): + if isinstance(pred, SourceFile): + importers.append(pred.name) + + imported = [] + for succ in ctx.successors(node_id, edge_type=EdgeType.IMPORT): + if isinstance(succ, SourceFile): + imported.append(succ.name) + + return {"file": file.name, "importers": importers, "imported": imported} + + def analyze_codebase_structure(self) -> dict[str, Any]: + """ + Analyze the overall structure of the codebase using CodebaseContext. + + Returns: + A dictionary containing structural analysis results + """ + ctx = self.context + + # Count nodes by type + node_types: dict[str, int] = {} + for node in ctx.nodes: + node_type = type(node).__name__ + node_types[node_type] = node_types.get(node_type, 0) + 1 + + edge_types: dict[str, int] = {} + for _, _, edge in ctx.edges: + edge_type = edge.type.name + edge_types[edge_type] = edge_types.get(edge_type, 0) + 1 + + directories = {} + for path, directory in ctx.directories.items(): + directories[str(path)] = { + "files": len([ + item for item in directory.items if isinstance(item, SourceFile) + ]), + "subdirectories": len([ + item for item in directory.items if isinstance(item, Directory) + ]), + } + + return { + "node_types": node_types, + "edge_types": edge_types, + "directories": directories, + } + + def get_symbol_dependencies(self, symbol_name: str) -> dict[str, list[str]]: + """ + Get direct dependencies of a symbol. + + Args: + symbol_name: Name of the symbol to analyze + + Returns: + A dictionary mapping dependency types to lists of symbol names + """ + symbol = self.find_symbol_by_name(symbol_name) + if symbol is None: + return {"error": [f"Symbol not found: {symbol_name}"]} + + dependencies: dict[str, list[str]] = { + "imports": [], + "functions": [], + "classes": [], + "variables": [], + } + + # Process dependencies based on symbol type + if hasattr(symbol, "dependencies"): + for dep in symbol.dependencies: + if isinstance(dep, Import): + if dep.imported_symbol: + dependencies["imports"].append(dep.imported_symbol.name) + elif isinstance(dep, Symbol): + if dep.symbol_type == SymbolType.Function: + dependencies["functions"].append(dep.name) + elif dep.symbol_type == SymbolType.Class: + dependencies["classes"].append(dep.name) + elif dep.symbol_type == SymbolType.GlobalVar: + dependencies["variables"].append(dep.name) + + return dependencies + + def analyze_errors(self) -> dict[str, list[dict[str, Any]]]: + """ + Analyze the codebase for errors. + + Returns: + A dictionary mapping file paths to lists of errors + """ + return self.error_analyzer.analyze_codebase() + + def get_function_error_context(self, function_name: str) -> dict[str, Any]: + """ + Get detailed error context for a specific function. 
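+
+        Example (sketch; the returned shape is illustrated in the module README):
+
+            ctx = analyzer.get_function_error_context("calculate_total")
+            for err in ctx.get("errors", []):
+                print(err["severity"], err["message"])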
+ + Args: + function_name: The name of the function to analyze + + Returns: + A dictionary with detailed error context + """ + return self.error_analyzer.get_function_error_context(function_name) + + def get_file_error_context(self, file_path: str) -> dict[str, Any]: + """ + Get detailed error context for a specific file. + + Args: + file_path: The path of the file to analyze + + Returns: + A dictionary with detailed error context + """ + return self.error_analyzer.get_file_error_context(file_path) + + def get_error_context(self, error: CodeError) -> dict[str, Any]: + """ + Get detailed context information for an error. + + Args: + error: The error to get context for + + Returns: + A dictionary with detailed context information + """ + return self.error_analyzer.get_error_context(error) + + def convert_args_to_kwargs(self) -> None: + """ + Convert all function call arguments to keyword arguments. + """ + convert_all_calls_to_kwargs(self.codebase) + + def visualize_module_dependencies(self) -> None: + """ + Visualize module dependencies in the codebase. + """ + module_dependencies_run(self.codebase) + + def generate_mdx_documentation(self, class_name: str) -> str: + """ + Generate MDX documentation for a class. + + Args: + class_name: Name of the class to document + + Returns: + MDX documentation as a string + """ + for cls in self.codebase.classes: + if cls.name == class_name: + return render_mdx_page_for_class(cls) + return f"Class not found: {class_name}" + + def print_symbol_attribution(self) -> None: + """ + Print attribution information for symbols in the codebase. + """ + print_symbol_attribution(self.codebase) + + def get_extended_symbol_context(self, symbol_name: str, degree: int = 2) -> Dict[str, List[str]]: + """ + Get extended context (dependencies and usages) for a symbol. + + Args: + symbol_name: Name of the symbol to analyze + degree: How many levels deep to collect dependencies and usages + + Returns: + A dictionary containing dependencies and usages + """ + symbol = self.find_symbol_by_name(symbol_name) + if symbol: + dependencies, usages = get_extended_context(symbol, degree) + return { + "dependencies": [dep.name for dep in dependencies], + "usages": [usage.name for usage in usages] + } + return {"dependencies": [], "usages": []} + + def get_file_imports(self, file_path: str) -> List[str]: + """ + Get all imports in a file. + + Args: + file_path: Path to the file to analyze + + Returns: + A list of import statements + """ + file = self.find_file_by_path(file_path) + if file and hasattr(file, "imports"): + return [imp.source for imp in file.imports] + return [] + + def get_file_exports(self, file_path: str) -> List[str]: + """ + Get all exports from a file. + + Args: + file_path: Path to the file to analyze + + Returns: + A list of exported symbol names + """ + file = self.find_file_by_path(file_path) + if not file: + return [] + + exports = [] + for symbol in self.codebase.symbols: + if hasattr(symbol, "file") and symbol.file == file: + exports.append(symbol.name) + + return exports + + def analyze_complexity(self, file_path: str = None) -> Dict[str, Any]: + """ + Analyze code complexity metrics for the codebase or a specific file. 
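+
+        Example (sketch; keys follow the results dictionary built below):
+
+            metrics = analyzer.analyze_complexity("path/to/file.py")
+            avg_cc = metrics["cyclomatic_complexity"]["average"]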
+ + Args: + file_path: Optional path to a specific file to analyze + + Returns: + A dictionary containing complexity metrics + """ + files_to_analyze = [] + if file_path: + file = self.find_file_by_path(file_path) + if file: + files_to_analyze = [file] + else: + return {"error": f"File not found: {file_path}"} + else: + files_to_analyze = self.codebase.files + + # Calculate complexity metrics + results = { + "cyclomatic_complexity": { + "total": 0, + "average": 0, + "max": 0, + "max_file": "", + "max_function": "", + "by_file": {} + }, + "halstead_complexity": { + "total": 0, + "average": 0, + "max": 0, + "max_file": "", + "by_file": {} + }, + "maintainability_index": { + "total": 0, + "average": 0, + "min": 100, + "min_file": "", + "by_file": {} + }, + "line_metrics": { + "total_loc": 0, + "total_lloc": 0, + "total_sloc": 0, + "total_comments": 0, + "comment_ratio": 0, + "by_file": {} + } + } + + # Process each file + for file in files_to_analyze: + # Skip non-Python files + if not file.name.endswith(".py"): + continue + + file_path = file.name + file_content = file.content + + # Calculate cyclomatic complexity + cc_total = 0 + cc_max = 0 + cc_max_function = "" + + # Count decision points (if, for, while, etc.) + for func in file.functions: + func_cc = 1 # Base complexity + + # Count control structures + for node in func.ast_node.body: + if isinstance(node, (ast.If, ast.For, ast.While, ast.Try)): + func_cc += 1 + + # Count logical operators in conditions + if isinstance(node, ast.If) and isinstance(node.test, ast.BoolOp): + func_cc += len(node.test.values) - 1 + + cc_total += func_cc + if func_cc > cc_max: + cc_max = func_cc + cc_max_function = func.name + + # Update cyclomatic complexity metrics + results["cyclomatic_complexity"]["by_file"][file_path] = { + "total": cc_total, + "average": cc_total / len(file.functions) if file.functions else 0, + "max": cc_max, + "max_function": cc_max_function + } + + results["cyclomatic_complexity"]["total"] += cc_total + if cc_max > results["cyclomatic_complexity"]["max"]: + results["cyclomatic_complexity"]["max"] = cc_max + results["cyclomatic_complexity"]["max_file"] = file_path + results["cyclomatic_complexity"]["max_function"] = cc_max_function + + # Calculate line metrics + loc = len(file_content.splitlines()) + lloc = sum(1 for line in file_content.splitlines() if line.strip() and not line.strip().startswith("#")) + sloc = sum(1 for line in file_content.splitlines() if line.strip()) + comments = sum(1 for line in file_content.splitlines() if line.strip().startswith("#")) + + results["line_metrics"]["by_file"][file_path] = { + "loc": loc, + "lloc": lloc, + "sloc": sloc, + "comments": comments, + "comment_ratio": comments / loc if loc else 0 + } + + results["line_metrics"]["total_loc"] += loc + results["line_metrics"]["total_lloc"] += lloc + results["line_metrics"]["total_sloc"] += sloc + results["line_metrics"]["total_comments"] += comments + + # Simple Halstead complexity approximation + operators = len(re.findall(r'[\+\-\*/=<>!&|^~]', file_content)) + operands = len(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', file_content)) + + n1 = len(set(re.findall(r'[\+\-\*/=<>!&|^~]', file_content))) + n2 = len(set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', file_content))) + + N = operators + operands + n = n1 + n2 + + # Calculate Halstead metrics + if n1 > 0 and n2 > 0: + volume = N * math.log2(n) + difficulty = (n1 / 2) * (operands / n2) + effort = volume * difficulty + else: + volume = 0 + difficulty = 0 + effort = 0 + + 
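+            # Clamp volume to at least 1 so the maintainability-index formula
+            # below can take math.log(volume) even for files where no
+            # operators or operands were detected (volume == 0).
+            volume = max(volume, 1)
+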
results["halstead_complexity"]["by_file"][file_path] = { + "volume": volume, + "difficulty": difficulty, + "effort": effort + } + + results["halstead_complexity"]["total"] += effort + if effort > results["halstead_complexity"]["max"]: + results["halstead_complexity"]["max"] = effort + results["halstead_complexity"]["max_file"] = file_path + + # Calculate maintainability index + if lloc > 0: + mi = 171 - 5.2 * math.log(volume) - 0.23 * cc_total - 16.2 * math.log(lloc) + mi = max(0, min(100, mi)) + else: + mi = 100 + + results["maintainability_index"]["by_file"][file_path] = mi + results["maintainability_index"]["total"] += mi + + if mi < results["maintainability_index"]["min"]: + results["maintainability_index"]["min"] = mi + results["maintainability_index"]["min_file"] = file_path + + # Calculate averages + num_files = len(results["cyclomatic_complexity"]["by_file"]) + if num_files > 0: + results["cyclomatic_complexity"]["average"] = results["cyclomatic_complexity"]["total"] / num_files + results["halstead_complexity"]["average"] = results["halstead_complexity"]["total"] / num_files + results["maintainability_index"]["average"] = results["maintainability_index"]["total"] / num_files + + total_loc = results["line_metrics"]["total_loc"] + if total_loc > 0: + results["line_metrics"]["comment_ratio"] = results["line_metrics"]["total_comments"] / total_loc + + return results + + def find_central_files(self) -> List[Dict[str, Any]]: + """ + Find the most central files in the codebase based on dependency analysis. + + Returns: + A list of dictionaries containing file information and centrality metrics + """ + G = self.get_dependency_graph() + + # Calculate centrality metrics + degree_centrality = nx.degree_centrality(G) + betweenness_centrality = nx.betweenness_centrality(G) + closeness_centrality = nx.closeness_centrality(G) + + # Combine metrics + centrality = {} + for node in G.nodes(): + centrality[node] = { + "file": node, + "degree": degree_centrality.get(node, 0), + "betweenness": betweenness_centrality.get(node, 0), + "closeness": closeness_centrality.get(node, 0), + "combined": ( + degree_centrality.get(node, 0) + + betweenness_centrality.get(node, 0) + + closeness_centrality.get(node, 0) + ) / 3 + } + + # Sort by combined centrality + sorted_centrality = sorted( + centrality.values(), + key=lambda x: x["combined"], + reverse=True + ) + + return sorted_centrality[:10] # Return top 10 most central files + + +# Request models for API endpoints +class RepoRequest(BaseModel): + """Request model for repository analysis.""" + + repo_url: str + + +class SymbolRequest(BaseModel): + """Request model for symbol analysis.""" + + repo_url: str + symbol_name: str + + +class FileRequest(BaseModel): + """Request model for file analysis.""" + + repo_url: str + file_path: str + + +class FunctionRequest(BaseModel): + """Request model for function analysis.""" + + repo_url: str + function_name: str + + +class ErrorRequest(BaseModel): + """Request model for error analysis.""" + + repo_url: str + file_path: str | None = None + function_name: str | None = None + + +class ComplexityRequest(BaseModel): + """Request model for complexity analysis.""" + + repo_url: str + file_path: str | None = None + + +class DocumentationRequest(BaseModel): + """Request model for documentation generation.""" + + repo_url: str + class_name: str | None = None + + +# API endpoints +@app.post("/analyze_repo") +async def analyze_repo(request: RepoRequest) -> dict[str, Any]: + """ + Analyze a repository and return various metrics. 
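+
+    Example request (mirrors the curl example in the module README; assumes a local server on port 8000):
+
+        curl -X POST "http://localhost:8000/analyze_repo" -H "Content-Type: application/json" -d '{"repo_url": "owner/repo"}'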
Args: - repo_path: Path to the git repository + request: The repository request containing the repo URL Returns: - Dictionary with month-year as key and number of commits as value + A dictionary of analysis results + """ + repo_url = request.repo_url + + try: + codebase = Codebase.from_repo(repo_url) + analyzer = CodeAnalyzer(codebase) + + # Get import analysis + import_analysis = analyzer.analyze_imports() + + # Get structure analysis + structure_analysis = analyzer.analyze_codebase_structure() + + # Get error analysis + error_analysis = analyzer.analyze_errors() + + # Combine all results + results = { + "repo_url": repo_url, + "num_files": len(codebase.files), + "num_functions": len(codebase.functions), + "num_classes": len(codebase.classes), + "import_analysis": import_analysis, + "structure_analysis": structure_analysis, + "error_analysis": error_analysis + } + + return results + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error analyzing repository: {str(e)}") from e + + +@app.post("/analyze_symbol") +async def analyze_symbol(request: SymbolRequest) -> dict[str, Any]: """ - end_date = datetime.now() - start_date = end_date - timedelta(days=365) + Analyze a symbol and return detailed information. - date_format = "%Y-%m-%d" - since_date = start_date.strftime(date_format) - until_date = end_date.strftime(date_format) - repo_path = "https://github.com/" + repo_path + Args: + request: The symbol request containing the repo URL and symbol name + + Returns: + A dictionary of analysis results + """ + repo_url = request.repo_url + symbol_name = request.symbol_name try: - original_dir = os.getcwd() - - with tempfile.TemporaryDirectory() as temp_dir: - subprocess.run(["git", "clone", repo_path, temp_dir], check=True) - os.chdir(temp_dir) - - cmd = [ - "git", - "log", - f"--since={since_date}", - f"--until={until_date}", - "--format=%aI", - ] - - result = subprocess.run(cmd, capture_output=True, text=True, check=True) - commit_dates = result.stdout.strip().split("\n") - - monthly_counts = {} - current_date = start_date - while current_date <= end_date: - month_key = current_date.strftime("%Y-%m") - monthly_counts[month_key] = 0 - current_date = ( - current_date.replace(day=1) + timedelta(days=32) - ).replace(day=1) - - for date_str in commit_dates: - if date_str: # Skip empty lines - commit_date = datetime.fromisoformat(date_str.strip()) - month_key = commit_date.strftime("%Y-%m") - if month_key in monthly_counts: - monthly_counts[month_key] += 1 - - os.chdir(original_dir) - return dict(sorted(monthly_counts.items())) - - except subprocess.CalledProcessError as e: - print(f"Error executing git command: {e}") - return {} + codebase = Codebase.from_repo(repo_url) + analyzer = CodeAnalyzer(codebase) + + # Get symbol context + symbol_context = analyzer.get_context_for_symbol(symbol_name) + + # Get symbol dependencies + dependencies = analyzer.get_symbol_dependencies(symbol_name) + + # Get symbol attribution + attribution = analyzer.get_symbol_attribution(symbol_name) + + return { + "symbol_name": symbol_name, + "context": symbol_context, + "dependencies": dependencies, + "attribution": attribution, + } except Exception as e: - print(f"Error processing git commits: {e}") - return {} - finally: - try: - os.chdir(original_dir) - except: - pass - - -def calculate_cyclomatic_complexity(function): - def analyze_statement(statement): - complexity = 0 - - if isinstance(statement, IfBlockStatement): - complexity += 1 - if hasattr(statement, "elif_statements"): - complexity += 
len(statement.elif_statements) - - elif isinstance(statement, (ForLoopStatement, WhileStatement)): - complexity += 1 - - elif isinstance(statement, TryCatchStatement): - complexity += len(getattr(statement, "except_blocks", [])) - - if hasattr(statement, "condition") and isinstance(statement.condition, str): - complexity += statement.condition.count( - " and " - ) + statement.condition.count(" or ") - - if hasattr(statement, "nested_code_blocks"): - for block in statement.nested_code_blocks: - complexity += analyze_block(block) - - return complexity - - def analyze_block(block): - if not block or not hasattr(block, "statements"): - return 0 - return sum(analyze_statement(stmt) for stmt in block.statements) - - return ( - 1 + analyze_block(function.code_block) if hasattr(function, "code_block") else 1 - ) - - -def cc_rank(complexity): - if complexity < 0: - raise ValueError("Complexity must be a non-negative value") - - ranks = [ - (1, 5, "A"), - (6, 10, "B"), - (11, 20, "C"), - (21, 30, "D"), - (31, 40, "E"), - (41, float("inf"), "F"), - ] - for low, high, rank in ranks: - if low <= complexity <= high: - return rank - return "F" - - -def calculate_doi(cls): - """Calculate the depth of inheritance for a given class.""" - return len(cls.superclasses) - - -def get_operators_and_operands(function): - operators = [] - operands = [] - - for statement in function.code_block.statements: - for call in statement.function_calls: - operators.append(call.name) - for arg in call.args: - operands.append(arg.source) - - if hasattr(statement, "expressions"): - for expr in statement.expressions: - if isinstance(expr, BinaryExpression): - operators.extend([op.source for op in expr.operators]) - operands.extend([elem.source for elem in expr.elements]) - elif isinstance(expr, UnaryExpression): - operators.append(expr.ts_node.type) - operands.append(expr.argument.source) - elif isinstance(expr, ComparisonExpression): - operators.extend([op.source for op in expr.operators]) - operands.extend([elem.source for elem in expr.elements]) - - if hasattr(statement, "expression"): - expr = statement.expression - if isinstance(expr, BinaryExpression): - operators.extend([op.source for op in expr.operators]) - operands.extend([elem.source for elem in expr.elements]) - elif isinstance(expr, UnaryExpression): - operators.append(expr.ts_node.type) - operands.append(expr.argument.source) - elif isinstance(expr, ComparisonExpression): - operators.extend([op.source for op in expr.operators]) - operands.extend([elem.source for elem in expr.elements]) - - return operators, operands - - -def calculate_halstead_volume(operators, operands): - n1 = len(set(operators)) - n2 = len(set(operands)) - - N1 = len(operators) - N2 = len(operands) - - N = N1 + N2 - n = n1 + n2 - - if n > 0: - volume = N * math.log2(n) - return volume, N1, N2, n1, n2 - return 0, N1, N2, n1, n2 - - -def count_lines(source: str): - """Count different types of lines in source code.""" - if not source.strip(): - return 0, 0, 0, 0 - - lines = [line.strip() for line in source.splitlines()] - loc = len(lines) - sloc = len([line for line in lines if line]) - - in_multiline = False - comments = 0 - code_lines = [] - - i = 0 - while i < len(lines): - line = lines[i] - code_part = line - if not in_multiline and "#" in line: - comment_start = line.find("#") - if not re.search(r'["\'].*#.*["\']', line[:comment_start]): - code_part = line[:comment_start].strip() - if line[comment_start:].strip(): - comments += 1 - - if ('"""' in line or "'''" in line) and not ( - 
line.count('"""') % 2 == 0 or line.count("'''") % 2 == 0 - ): - if in_multiline: - in_multiline = False - comments += 1 - else: - in_multiline = True - comments += 1 - if line.strip().startswith('"""') or line.strip().startswith("'''"): - code_part = "" - elif in_multiline: - comments += 1 - code_part = "" - elif line.strip().startswith("#"): - comments += 1 - code_part = "" + raise HTTPException( + status_code=500, detail=f"Error analyzing symbol: {e!s}" + ) from e + + +@app.post("/analyze_file") +async def analyze_file(request: FileRequest) -> dict[str, Any]: + """ + Analyze a file and return detailed information. + + Args: + request: The file request containing the repo URL and file path + + Returns: + A dictionary of analysis results + """ + repo_url = request.repo_url + file_path = request.file_path - if code_part.strip(): - code_lines.append(code_part) + try: + codebase = Codebase.from_repo(repo_url) + analyzer = CodeAnalyzer(codebase) - i += 1 + # Get file summary + file_summary = analyzer.get_file_summary(file_path) - lloc = 0 - continued_line = False - for line in code_lines: - if continued_line: - if not any(line.rstrip().endswith(c) for c in ("\\", ",", "{", "[", "(")): - continued_line = False - continue + # Get file dependencies + file_dependencies = analyzer.get_file_dependencies(file_path) - lloc += len([stmt for stmt in line.split(";") if stmt.strip()]) + # Get file error context + file_error_context = analyzer.get_file_error_context(file_path) - if any(line.rstrip().endswith(c) for c in ("\\", ",", "{", "[", "(")): - continued_line = True + return { + "file_path": file_path, + "summary": file_summary, + "dependencies": file_dependencies, + "error_context": file_error_context, + } + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Error analyzing file: {e!s}" + ) from e - return loc, lloc, sloc, comments +@app.post("/analyze_function") +async def analyze_function(request: FunctionRequest) -> dict[str, Any]: + """ + Analyze a function and return detailed information. 
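+
+    Example request (mirrors the curl example in the module README; assumes a local server on port 8000):
+
+        curl -X POST "http://localhost:8000/analyze_function" -H "Content-Type: application/json" -d '{"repo_url": "owner/repo", "function_name": "function_name"}'
+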
-def calculate_maintainability_index( - halstead_volume: float, cyclomatic_complexity: float, loc: int -) -> int: - """Calculate the normalized maintainability index for a given function.""" - if loc <= 0: - return 100 + Args: + request: The function request containing the repo URL and function name + + Returns: + A dictionary of analysis results + """ + repo_url = request.repo_url + function_name = request.function_name try: - raw_mi = ( - 171 - - 5.2 * math.log(max(1, halstead_volume)) - - 0.23 * cyclomatic_complexity - - 16.2 * math.log(max(1, loc)) - ) - normalized_mi = max(0, min(100, raw_mi * 100 / 171)) - return int(normalized_mi) - except (ValueError, TypeError): - return 0 + codebase = Codebase.from_repo(repo_url) + analyzer = CodeAnalyzer(codebase) + # Get function summary + function_summary = analyzer.get_function_summary(function_name) -def get_maintainability_rank(mi_score: float) -> str: - """Convert maintainability index score to a letter grade.""" - if mi_score >= 85: - return "A" - elif mi_score >= 65: - return "B" - elif mi_score >= 45: - return "C" - elif mi_score >= 25: - return "D" - else: - return "F" + # Get function error context + function_error_context = analyzer.get_function_error_context(function_name) + return { + "function_name": function_name, + "summary": function_summary, + "error_context": function_error_context, + } + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Error analyzing function: {e!s}" + ) from e -def get_github_repo_description(repo_url): - api_url = f"https://api.github.com/repos/{repo_url}" - response = requests.get(api_url) +@app.post("/analyze_errors") +async def analyze_errors(request: ErrorRequest) -> dict[str, Any]: + """ + Analyze errors in a repository, file, or function. - if response.status_code == 200: - repo_data = response.json() - return repo_data.get("description", "No description available") - else: - return "" + Args: + request: The error request containing the repo URL and optional file path or function name + Returns: + A dictionary of error analysis results + """ + repo_url = request.repo_url + file_path = request.file_path + function_name = request.function_name -class RepoRequest(BaseModel): - repo_url: str + try: + codebase = Codebase.from_repo(repo_url) + analyzer = CodeAnalyzer(codebase) + + if function_name: + # Analyze errors in a specific function + return analyzer.get_function_error_context(function_name) + elif file_path: + # Analyze errors in a specific file + return analyzer.get_file_error_context(file_path) + else: + # Analyze errors in the entire codebase + return {"error_analysis": analyzer.analyze_errors()} + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Error analyzing errors: {e!s}" + ) from e -@fastapi_app.post("/analyze_repo") -async def analyze_repo(request: RepoRequest) -> Dict[str, Any]: - """Analyze a repository and return comprehensive metrics.""" +@app.post("/analyze_complexity") +async def analyze_complexity(request: ComplexityRequest) -> dict[str, Any]: + """ + Analyze code complexity metrics for a repository or specific file. 
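+
+    Example request (sketch; ``file_path`` may be omitted to analyze the whole repository):
+
+        curl -X POST "http://localhost:8000/analyze_complexity" -H "Content-Type: application/json" -d '{"repo_url": "owner/repo", "file_path": "path/to/file.py"}'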
+ + Args: + request: The complexity request containing the repo URL and optional file path + + Returns: + A dictionary of complexity analysis results + """ repo_url = request.repo_url - codebase = Codebase.from_repo(repo_url) - - num_files = len(codebase.files(extensions="*")) - num_functions = len(codebase.functions) - num_classes = len(codebase.classes) - - total_loc = total_lloc = total_sloc = total_comments = 0 - total_complexity = 0 - total_volume = 0 - total_mi = 0 - total_doi = 0 - - monthly_commits = get_monthly_commits(repo_url) - print(monthly_commits) - - for file in codebase.files: - loc, lloc, sloc, comments = count_lines(file.source) - total_loc += loc - total_lloc += lloc - total_sloc += sloc - total_comments += comments - - callables = codebase.functions + [m for c in codebase.classes for m in c.methods] - - num_callables = 0 - for func in callables: - if not hasattr(func, "code_block"): - continue - - complexity = calculate_cyclomatic_complexity(func) - operators, operands = get_operators_and_operands(func) - volume, _, _, _, _ = calculate_halstead_volume(operators, operands) - loc = len(func.code_block.source.splitlines()) - mi_score = calculate_maintainability_index(volume, complexity, loc) - - total_complexity += complexity - total_volume += volume - total_mi += mi_score - num_callables += 1 - - for cls in codebase.classes: - doi = calculate_doi(cls) - total_doi += doi - - desc = get_github_repo_description(repo_url) - - results = { - "repo_url": repo_url, - "line_metrics": { - "total": { - "loc": total_loc, - "lloc": total_lloc, - "sloc": total_sloc, - "comments": total_comments, - "comment_density": (total_comments / total_loc * 100) - if total_loc > 0 - else 0, - }, - }, - "cyclomatic_complexity": { - "average": total_complexity if num_callables > 0 else 0, - }, - "depth_of_inheritance": { - "average": total_doi / len(codebase.classes) if codebase.classes else 0, - }, - "halstead_metrics": { - "total_volume": int(total_volume), - "average_volume": int(total_volume / num_callables) - if num_callables > 0 - else 0, - }, - "maintainability_index": { - "average": int(total_mi / num_callables) if num_callables > 0 else 0, - }, - "description": desc, - "num_files": num_files, - "num_functions": num_functions, - "num_classes": num_classes, - "monthly_commits": monthly_commits, - } - - return results - - -@app.function(image=image) -@modal.asgi_app() -def fastapi_modal_app(): - return fastapi_app + file_path = request.file_path + + try: + codebase = Codebase.from_repo(repo_url) + analyzer = CodeAnalyzer(codebase) + + # Analyze complexity + complexity_results = analyzer.analyze_complexity(file_path) + + return { + "repo_url": repo_url, + "file_path": file_path, + "complexity_analysis": complexity_results + } + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Error analyzing complexity: {e!s}" + ) from e + + +@app.post("/generate_documentation") +async def generate_documentation(request: DocumentationRequest) -> dict[str, Any]: + """ + Generate documentation for a class or the entire codebase. 
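+
+    Example request (sketch; ``class_name`` may be omitted to document all functions instead):
+
+        curl -X POST "http://localhost:8000/generate_documentation" -H "Content-Type: application/json" -d '{"repo_url": "owner/repo", "class_name": "CodeAnalyzer"}'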
+ + Args: + request: The documentation request containing the repo URL and optional class name + + Returns: + A dictionary containing the generated documentation + """ + repo_url = request.repo_url + class_name = request.class_name + + try: + codebase = Codebase.from_repo(repo_url) + analyzer = CodeAnalyzer(codebase) + + if class_name: + # Generate documentation for a specific class + mdx_doc = analyzer.generate_mdx_documentation(class_name) + return { + "class_name": class_name, + "documentation": mdx_doc + } + else: + # Generate documentation for all functions + analyzer.document_functions() + return {"message": "Documentation generated for all functions"} + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Error generating documentation: {e!s}" + ) from e if __name__ == "__main__": - app.deploy("analytics-app") + # Run the FastAPI app locally with uvicorn + uvicorn.run(app, host="127.0.0.1", port=8000) diff --git a/codegen-on-oss/codegen_on_oss/analysis/codebase_context.py b/codegen-on-oss/codegen_on_oss/analysis/codebase_context.py index 5c0fd47dd..c092356b7 100644 --- a/codegen-on-oss/codegen_on_oss/analysis/codebase_context.py +++ b/codegen-on-oss/codegen_on_oss/analysis/codebase_context.py @@ -121,7 +121,6 @@ class CodebaseContext: dependency_manager: DependencyManager | None language_engine: LanguageEngine | None _computing = False - _graph: PyDiGraph[Importable, Edge] filepath_idx: dict[str, NodeId] _ext_module_idx: dict[str, NodeId] flags: Flags @@ -143,8 +142,6 @@ def __init__( from codegen.sdk.core.parser import Parser self.progress = progress or StubProgress() - self.__graph = PyDiGraph() - self.__graph_ready = False self.filepath_idx = {} self._ext_module_idx = {} self.generation = 0 diff --git a/codegen-on-oss/codegen_on_oss/analysis/enhanced_type_detection.py b/codegen-on-oss/codegen_on_oss/analysis/enhanced_type_detection.py new file mode 100644 index 000000000..839b329b3 --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analysis/enhanced_type_detection.py @@ -0,0 +1,521 @@ +""" +Enhanced Type Detection Module for Error Context Analysis + +This module provides more robust type detection and analysis capabilities +for the ErrorContextAnalyzer. It uses AST analysis and type inference +to detect potential type errors in code. +""" + +import ast +import inspect +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +from codegen.sdk.core.function import Function +from codegen_on_oss.analysis.error_context import CodeError, ErrorType, ErrorSeverity + + +class TypeAnalyzer: + """ + Analyzes code for type-related errors using AST analysis and type inference. + """ + + def __init__(self): + """Initialize the TypeAnalyzer.""" + # Map of known Python types + self.python_types = { + 'str': str, + 'int': int, + 'float': float, + 'bool': bool, + 'list': list, + 'dict': dict, + 'tuple': tuple, + 'set': set, + 'None': type(None), + } + + # Map of compatible binary operations + self.compatible_ops = { + ast.Add: { + str: [str], + int: [int, float], + float: [int, float], + list: [list], + tuple: [tuple], + }, + ast.Sub: { + int: [int, float], + float: [int, float], + set: [set], + }, + ast.Mult: { + int: [int, float, str, list, tuple], + float: [int, float], + str: [int], + list: [int], + tuple: [int], + }, + ast.Div: { + int: [int, float], + float: [int, float], + }, + # Add more operations as needed + } + + def analyze_function(self, function: Function) -> List[CodeError]: + """ + Analyze a function for type-related errors. 
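+
+        Example (sketch; ``my_function`` stands in for a codegen SDK ``Function``):
+
+            analyzer = TypeAnalyzer()
+            errors = analyzer.analyze_function(my_function)
+            for error in errors:
+                print(error.message)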
+ + Args: + function: The function to analyze + + Returns: + A list of type-related errors + """ + errors = [] + + if not hasattr(function, "code_block") or not hasattr(function.code_block, "source"): + return errors + + try: + # Parse the AST + tree = ast.parse(function.code_block.source) + + # Track variable types based on assignments and annotations + variable_types = self._collect_variable_types(tree, function) + + # Check for type mismatches + errors.extend(self._check_type_mismatches(tree, variable_types, function)) + + # Check for parameter type mismatches + errors.extend(self._check_parameter_types(tree, variable_types, function)) + + # Check for return type mismatches + errors.extend(self._check_return_types(tree, variable_types, function)) + + return errors + except SyntaxError: + # If we can't parse the AST, return no errors + return errors + + def _collect_variable_types(self, tree: ast.AST, function: Function) -> Dict[str, Any]: + """ + Collect variable types from assignments and annotations. + + Args: + tree: The AST to analyze + function: The function being analyzed + + Returns: + A dictionary mapping variable names to their types + """ + variable_types = {} + + # Add function parameters with type annotations + if hasattr(function, "parameters"): + for param in function.parameters: + if hasattr(param, "type_annotation") and param.type_annotation: + variable_types[param.name] = self._parse_type_annotation(param.type_annotation) + + # First pass: collect type information from the AST + for node in ast.walk(tree): + if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name): + # Handle type annotations + variable_types[node.target.id] = self._get_type_from_annotation(node.annotation) + elif isinstance(node, ast.Assign): + # Infer types from assignments where possible + for target in node.targets: + if isinstance(target, ast.Name): + inferred_type = self._infer_type_from_value(node.value) + if inferred_type: + variable_types[target.id] = inferred_type + + return variable_types + + def _check_type_mismatches(self, tree: ast.AST, variable_types: Dict[str, Any], function: Function) -> List[CodeError]: + """ + Check for type mismatches in binary operations. + + Args: + tree: The AST to analyze + variable_types: Dictionary mapping variable names to their types + function: The function being analyzed + + Returns: + A list of type-related errors + """ + errors = [] + + for node in ast.walk(tree): + if isinstance(node, ast.BinOp): + # Check binary operations for type mismatches + left_type = self._get_expression_type(node.left, variable_types) + right_type = self._get_expression_type(node.right, variable_types) + + if left_type and right_type and not self._are_types_compatible(left_type, right_type, node.op): + # Found potential type error + line_number = node.lineno + errors.append(CodeError( + error_type=ErrorType.TYPE_ERROR, + message=f"Potential type mismatch: {self._type_name(left_type)} {type(node.op).__name__} {self._type_name(right_type)}", + file_path=function.file.name if hasattr(function, "file") else None, + line_number=line_number, + severity=ErrorSeverity.HIGH, + symbol_name=function.name, + context_lines=self._get_context_lines(function, line_number), + suggested_fix=f"Ensure operands are of compatible types for {type(node.op).__name__} operation" + )) + + return errors + + def _check_parameter_types(self, tree: ast.AST, variable_types: Dict[str, Any], function: Function) -> List[CodeError]: + """ + Check for parameter type mismatches in function calls. 
+ + Args: + tree: The AST to analyze + variable_types: Dictionary mapping variable names to their types + function: The function being analyzed + + Returns: + A list of parameter-related errors + """ + errors = [] + + # Get function calls + for node in ast.walk(tree): + if isinstance(node, ast.Call): + # Check if we're calling a function we know about + if isinstance(node.func, ast.Name) and node.func.id in variable_types: + # This is a simplification - in a real implementation, we would need to + # track function signatures and parameter types + pass + + return errors + + def _check_return_types(self, tree: ast.AST, variable_types: Dict[str, Any], function: Function) -> List[CodeError]: + """ + Check for return type mismatches. + + Args: + tree: The AST to analyze + variable_types: Dictionary mapping variable names to their types + function: The function being analyzed + + Returns: + A list of return-related errors + """ + errors = [] + + # Get the declared return type + declared_return_type = None + if hasattr(function, "return_type") and function.return_type: + declared_return_type = self._parse_type_annotation(function.return_type) + + if not declared_return_type: + return errors + + # Check return statements + for node in ast.walk(tree): + if isinstance(node, ast.Return) and node.value: + returned_type = self._get_expression_type(node.value, variable_types) + + if returned_type and not self._is_return_type_compatible(returned_type, declared_return_type): + line_number = node.lineno + errors.append(CodeError( + error_type=ErrorType.TYPE_ERROR, + message=f"Return type mismatch: returning {self._type_name(returned_type)} but function declares {self._type_name(declared_return_type)}", + file_path=function.file.name if hasattr(function, "file") else None, + line_number=line_number, + severity=ErrorSeverity.HIGH, + symbol_name=function.name, + context_lines=self._get_context_lines(function, line_number), + suggested_fix=f"Ensure the return value matches the declared return type {self._type_name(declared_return_type)}" + )) + + return errors + + def _get_expression_type(self, node: ast.AST, variable_types: Dict[str, Any]) -> Optional[Any]: + """ + Get the type of an expression. 
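+
+        Example (sketch; relies on Python 3.8+ ``ast.Constant`` nodes):
+
+            node = ast.parse("x + 1.0", mode="eval").body
+            self._get_expression_type(node, {"x": int})  # -> float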
+ + Args: + node: The AST node representing the expression + variable_types: Dictionary mapping variable names to their types + + Returns: + The type of the expression, or None if it cannot be determined + """ + if isinstance(node, ast.Name): + # Variable reference + return variable_types.get(node.id) + elif isinstance(node, ast.Constant): + # Literal value + return type(node.value) + elif isinstance(node, ast.List): + # List literal + return list + elif isinstance(node, ast.Dict): + # Dict literal + return dict + elif isinstance(node, ast.Tuple): + # Tuple literal + return tuple + elif isinstance(node, ast.Set): + # Set literal + return set + elif isinstance(node, ast.BinOp): + # Binary operation + left_type = self._get_expression_type(node.left, variable_types) + right_type = self._get_expression_type(node.right, variable_types) + + # Determine result type based on operation and operand types + # This is a simplification - in a real implementation, we would need more sophisticated type inference + if isinstance(node.op, ast.Add): + if left_type == str or right_type == str: + return str + elif left_type in (int, float) and right_type in (int, float): + return float if float in (left_type, right_type) else int + elif left_type == list and right_type == list: + return list + elif left_type == tuple and right_type == tuple: + return tuple + + # Add more operation type inference as needed + + # For other expression types, we can't determine the type + return None + + def _are_types_compatible(self, left_type: Any, right_type: Any, op: ast.operator) -> bool: + """ + Check if two types are compatible for a binary operation. + + Args: + left_type: The type of the left operand + right_type: The type of the right operand + op: The binary operation + + Returns: + True if the types are compatible, False otherwise + """ + op_type = type(op) + + if op_type in self.compatible_ops and left_type in self.compatible_ops[op_type]: + return right_type in self.compatible_ops[op_type][left_type] + + return False + + def _is_return_type_compatible(self, actual_type: Any, declared_type: Any) -> bool: + """ + Check if a return type is compatible with the declared return type. + + Args: + actual_type: The actual return type + declared_type: The declared return type + + Returns: + True if the types are compatible, False otherwise + """ + # This is a simplification - in a real implementation, we would need more sophisticated type compatibility checking + if actual_type == declared_type: + return True + + # Handle numeric types + if declared_type == float and actual_type == int: + return True + + # Handle None + if declared_type == type(None) and actual_type == type(None): + return True + + # Handle Union types (simplified) + if isinstance(declared_type, tuple): + return actual_type in declared_type + + return False + + def _get_type_from_annotation(self, annotation: ast.AST) -> Optional[Any]: + """ + Get a type from an annotation AST node. 
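+
+        Example (illustrative; assumes ``self.python_types`` maps builtin
+        type names such as "int" to the corresponding Python types):
+
+            ann = ast.parse("Optional[int]", mode="eval").body
+            analyzer._get_type_from_annotation(ann)  # -> int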
+ + Args: + annotation: The AST node representing the annotation + + Returns: + The type, or None if it cannot be determined + """ + if isinstance(annotation, ast.Name): + # Simple type name + return self.python_types.get(annotation.id) + elif isinstance(annotation, ast.Subscript): + # Generic type (e.g., List[int]) + if isinstance(annotation.value, ast.Name): + if annotation.value.id == 'List': + return list + elif annotation.value.id == 'Dict': + return dict + elif annotation.value.id == 'Tuple': + return tuple + elif annotation.value.id == 'Set': + return set + elif annotation.value.id == 'Optional': + # For Optional[T], we return the inner type + return self._get_type_from_annotation(annotation.slice) + elif annotation.value.id == 'Union': + # For Union[T1, T2, ...], we return a tuple of types + if isinstance(annotation.slice, ast.Tuple): + types = [self._get_type_from_annotation(elt) for elt in annotation.slice.elts] + return tuple(t for t in types if t is not None) + + return None + + def _parse_type_annotation(self, type_annotation: str) -> Optional[Any]: + """ + Parse a type annotation string. + + Args: + type_annotation: The type annotation string + + Returns: + The type, or None if it cannot be parsed + """ + # This is a simplification - in a real implementation, we would need more sophisticated parsing + if type_annotation == 'str': + return str + elif type_annotation == 'int': + return int + elif type_annotation == 'float': + return float + elif type_annotation == 'bool': + return bool + elif type_annotation == 'list' or type_annotation.startswith('List['): + return list + elif type_annotation == 'dict' or type_annotation.startswith('Dict['): + return dict + elif type_annotation == 'tuple' or type_annotation.startswith('Tuple['): + return tuple + elif type_annotation == 'set' or type_annotation.startswith('Set['): + return set + elif type_annotation == 'None': + return type(None) + elif type_annotation.startswith('Optional['): + # Extract the inner type + inner_type = type_annotation[9:-1] + return self._parse_type_annotation(inner_type) + elif type_annotation.startswith('Union['): + # Extract the union types + union_types = type_annotation[6:-1].split(', ') + types = [self._parse_type_annotation(t) for t in union_types] + return tuple(t for t in types if t is not None) + + return None + + def _infer_type_from_value(self, node: ast.AST) -> Optional[Any]: + """ + Infer the type of a value. + + Args: + node: The AST node representing the value + + Returns: + The inferred type, or None if it cannot be determined + """ + if isinstance(node, ast.Constant): + return type(node.value) + elif isinstance(node, ast.List): + return list + elif isinstance(node, ast.Dict): + return dict + elif isinstance(node, ast.Tuple): + return tuple + elif isinstance(node, ast.Set): + return set + elif isinstance(node, ast.Call): + if isinstance(node.func, ast.Name): + # Function call - try to determine the return type + if node.func.id in self.python_types: + return self.python_types[node.func.id] + + return None + + def _get_context_lines(self, function: Function, line_number: int, context_size: int = 2) -> Dict[int, str]: + """ + Get context lines around a specific line in a function. 
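+
+        Example (illustrative; keys are absolute line numbers when the
+        function exposes ``line_number``):
+
+            analyzer._get_context_lines(function, line_number=15)
+            # -> {13: "...", 14: "...", 15: "...", 16: "...", 17: "..."}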
+
+        Args:
+            function: The function containing the line
+            line_number: The line number to get context for
+            context_size: Number of lines before and after to include
+
+        Returns:
+            Dictionary mapping line numbers to line content
+        """
+        if not hasattr(function, "code_block") or not hasattr(function.code_block, "source"):
+            return {}
+
+        lines = function.code_block.source.splitlines()
+
+        # Adjust line_number to be 1-based relative to the function's code
+        # block, so the context window is centered on the requested line
+        if hasattr(function, "line_number"):
+            relative_line = line_number - function.line_number + 1
+        else:
+            relative_line = line_number
+
+        start_line = max(0, relative_line - context_size - 1)
+        end_line = min(len(lines), relative_line + context_size)
+
+        # Map the relative line numbers back to absolute line numbers
+        if hasattr(function, "line_number"):
+            return {i + function.line_number: lines[i] for i in range(start_line, end_line)}
+        else:
+            return {i + 1: lines[i] for i in range(start_line, end_line)}
+
+    def _type_name(self, type_obj: Any) -> str:
+        """
+        Get a human-readable name for a type.
+
+        Args:
+            type_obj: The type object
+
+        Returns:
+            A string representation of the type
+        """
+        if type_obj == str:
+            return "str"
+        elif type_obj == int:
+            return "int"
+        elif type_obj == float:
+            return "float"
+        elif type_obj == bool:
+            return "bool"
+        elif type_obj == list:
+            return "list"
+        elif type_obj == dict:
+            return "dict"
+        elif type_obj == tuple:
+            return "tuple"
+        elif type_obj == set:
+            return "set"
+        elif type_obj == type(None):
+            return "None"
+        elif isinstance(type_obj, tuple):
+            # Union type
+            return f"Union[{', '.join(self._type_name(t) for t in type_obj)}]"
+
+        return str(type_obj)
+
+
+# Example usage
+def analyze_function_types(function: Function) -> List[CodeError]:
+    """
+    Analyze a function for type-related errors.
+
+    Args:
+        function: The function to analyze
+
+    Returns:
+        A list of type-related errors
+    """
+    analyzer = TypeAnalyzer()
+    return analyzer.analyze_function(function)
+
diff --git a/codegen-on-oss/codegen_on_oss/analysis/error_context.py b/codegen-on-oss/codegen_on_oss/analysis/error_context.py
new file mode 100644
index 000000000..d74414287
--- /dev/null
+++ b/codegen-on-oss/codegen_on_oss/analysis/error_context.py
@@ -0,0 +1,818 @@
+"""
+Error Context Module for Codegen-on-OSS
+
+This module provides robust and dynamic error context analysis for code files and functions.
+It helps identify and contextualize errors in code, providing detailed information about
+the error location, type, and potential fixes.
+""" + +import ast +import inspect +import re +import tokenize +from io import StringIO +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +import networkx as nx +from codegen.sdk.core.class_definition import Class +from codegen.sdk.core.codebase import Codebase +from codegen.sdk.core.file import SourceFile +from codegen.sdk.core.function import Function +from codegen.sdk.core.import_resolution import Import +from codegen.sdk.core.symbol import Symbol +from codegen.sdk.enums import EdgeType, SymbolType + +# Error types +class ErrorType: + SYNTAX_ERROR = "syntax_error" + TYPE_ERROR = "type_error" + NAME_ERROR = "name_error" + IMPORT_ERROR = "import_error" + ATTRIBUTE_ERROR = "attribute_error" + PARAMETER_ERROR = "parameter_error" + CALL_ERROR = "call_error" + UNDEFINED_VARIABLE = "undefined_variable" + UNUSED_IMPORT = "unused_import" + UNUSED_VARIABLE = "unused_variable" + CIRCULAR_IMPORT = "circular_import" + CIRCULAR_DEPENDENCY = "circular_dependency" + + +class ErrorSeverity: + CRITICAL = "critical" + HIGH = "high" + MEDIUM = "medium" + LOW = "low" + INFO = "info" + + +class CodeError: + """Represents an error in code with context.""" + + def __init__( + self, + error_type: str, + message: str, + file_path: Optional[str] = None, + line_number: Optional[int] = None, + column: Optional[int] = None, + severity: str = ErrorSeverity.MEDIUM, + symbol_name: Optional[str] = None, + context_lines: Optional[Dict[int, str]] = None, + suggested_fix: Optional[str] = None, + ): + self.error_type = error_type + self.message = message + self.file_path = file_path + self.line_number = line_number + self.column = column + self.severity = severity + self.symbol_name = symbol_name + self.context_lines = context_lines or {} + self.suggested_fix = suggested_fix + + def to_dict(self) -> Dict[str, Any]: + """Convert the error to a dictionary representation.""" + return { + "error_type": self.error_type, + "message": self.message, + "file_path": self.file_path, + "line_number": self.line_number, + "column": self.column, + "severity": self.severity, + "symbol_name": self.symbol_name, + "context_lines": self.context_lines, + "suggested_fix": self.suggested_fix, + } + + def __str__(self) -> str: + """String representation of the error.""" + location = f"{self.file_path}:{self.line_number}" if self.file_path and self.line_number else "Unknown location" + return f"{self.error_type.upper()} ({self.severity}): {self.message} at {location}" + + +class ErrorContextAnalyzer: + """ + Analyzes code for errors and provides rich context information. + + This class is responsible for detecting various types of errors in code + and providing detailed context information to help understand and fix them. + """ + + def __init__(self, codebase: Codebase): + """ + Initialize the ErrorContextAnalyzer with a codebase. + + Args: + codebase: The Codebase object to analyze + """ + self.codebase = codebase + self._call_graph = None + self._dependency_graph = None + self._import_graph = None + + def get_context_lines(self, file_path: str, line_number: int, context_size: int = 3) -> Dict[int, str]: + """ + Get context lines around a specific line in a file. 
+ + Args: + file_path: Path to the file + line_number: The line number to get context for + context_size: Number of lines before and after to include + + Returns: + Dictionary mapping line numbers to line content + """ + file = self.codebase.get_file(file_path) + if not file or not hasattr(file, "source"): + return {} + + lines = file.source.splitlines() + start_line = max(0, line_number - context_size - 1) + end_line = min(len(lines), line_number + context_size) + + return {i + 1: lines[i] for i in range(start_line, end_line)} + + def build_call_graph(self) -> nx.DiGraph: + """ + Build a call graph for the codebase. + + Returns: + A directed graph representing function calls + """ + if self._call_graph is not None: + return self._call_graph + + G = nx.DiGraph() + + # Add nodes for all functions + for func in self.codebase.functions: + G.add_node(func.name, type="function", function=func) + + # Add edges for function calls + for func in self.codebase.functions: + if not hasattr(func, "function_calls"): + continue + + for call in func.function_calls: + if call.name in G: + G.add_edge(func.name, call.name, type="call") + + self._call_graph = G + return G + + def build_dependency_graph(self) -> nx.DiGraph: + """ + Build a dependency graph for the codebase. + + Returns: + A directed graph representing symbol dependencies + """ + if self._dependency_graph is not None: + return self._dependency_graph + + G = nx.DiGraph() + + # Add nodes for all symbols + for symbol in self.codebase.symbols: + G.add_node(symbol.name, type="symbol", symbol=symbol) + + # Add edges for dependencies + for symbol in self.codebase.symbols: + if not hasattr(symbol, "dependencies"): + continue + + for dep in symbol.dependencies: + if isinstance(dep, Symbol): + G.add_edge(symbol.name, dep.name, type="dependency") + + self._dependency_graph = G + return G + + def build_import_graph(self) -> nx.DiGraph: + """ + Build an import graph for the codebase. + + Returns: + A directed graph representing file imports + """ + if self._import_graph is not None: + return self._import_graph + + G = nx.DiGraph() + + # Add nodes for all files + for file in self.codebase.files: + G.add_node(file.name, type="file", file=file) + + # Add edges for imports + for file in self.codebase.files: + for imp in file.imports: + if imp.imported_symbol and hasattr(imp.imported_symbol, "file"): + imported_file = imp.imported_symbol.file + if imported_file and imported_file.name != file.name: + G.add_edge(file.name, imported_file.name, type="import") + + self._import_graph = G + return G + + def find_circular_imports(self) -> List[List[str]]: + """ + Find circular imports in the codebase. + + Returns: + A list of cycles, where each cycle is a list of file names + """ + import_graph = self.build_import_graph() + return list(nx.simple_cycles(import_graph)) + + def find_circular_dependencies(self) -> List[List[str]]: + """ + Find circular dependencies between symbols. + + Returns: + A list of cycles, where each cycle is a list of symbol names + """ + dependency_graph = self.build_dependency_graph() + return list(nx.simple_cycles(dependency_graph)) + + def analyze_function_parameters(self, function: Function) -> List[CodeError]: + """ + Analyze function parameters for errors. 
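+
+        Example (illustrative):
+
+            errors = analyzer.analyze_function_parameters(function)
+            # flags e.g. a call made with 1 argument to a function that
+            # declares 2 parameters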
+ + Args: + function: The function to analyze + + Returns: + A list of parameter-related errors + """ + errors = [] + + if not hasattr(function, "parameters") or not hasattr(function, "function_calls"): + return errors + + # Check for parameter type mismatches + for param in function.parameters: + if not hasattr(param, "type_annotation") or not param.type_annotation: + continue + + # Check if parameter is used with correct type + # This is a simplified check and would need more sophisticated type inference in practice + param_name = param.name + param_type = param.type_annotation + + # Look for usage of this parameter in the function body + if hasattr(function, "code_block") and hasattr(function.code_block, "source"): + source = function.code_block.source + + # Simple pattern matching for potential type errors + # This is a simplified approach and would need more sophisticated analysis in practice + if re.search(rf"\b{param_name}\s*\+\s*\d+\b", source) and "str" in param_type: + line_number = self._find_line_number(function.code_block.source, rf"\b{param_name}\s*\+\s*\d+\b") + errors.append(CodeError( + error_type=ErrorType.TYPE_ERROR, + message=f"Potential type error: adding integer to string parameter '{param_name}'", + file_path=function.file.name if hasattr(function, "file") else None, + line_number=line_number, + severity=ErrorSeverity.HIGH, + symbol_name=function.name, + context_lines=self.get_context_lines(function.file.name, line_number) if hasattr(function, "file") else None, + suggested_fix=f"Ensure '{param_name}' is converted to int before addition or use string concatenation" + )) + + # Check for call parameter mismatches + call_graph = self.build_call_graph() + for call in function.function_calls: + called_func_name = call.name + + # Find the called function + called_func = None + for func in self.codebase.functions: + if func.name == called_func_name: + called_func = func + break + + if not called_func or not hasattr(called_func, "parameters"): + continue + + # Check if number of arguments matches + if hasattr(call, "args") and len(call.args) != len(called_func.parameters): + # Find the line number of the call + line_number = self._find_line_number(function.code_block.source, rf"\b{called_func_name}\s*\(") + + errors.append(CodeError( + error_type=ErrorType.PARAMETER_ERROR, + message=f"Function '{called_func_name}' called with {len(call.args)} arguments but expects {len(called_func.parameters)}", + file_path=function.file.name if hasattr(function, "file") else None, + line_number=line_number, + severity=ErrorSeverity.HIGH, + symbol_name=function.name, + context_lines=self.get_context_lines(function.file.name, line_number) if hasattr(function, "file") else None, + suggested_fix=f"Update call to provide {len(called_func.parameters)} arguments" + )) + + return errors + + def analyze_function_returns(self, function: Function) -> List[CodeError]: + """ + Analyze function return statements for errors. 
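+
+        Example (illustrative):
+
+            errors = analyzer.analyze_function_returns(function)
+            # flags e.g. a declared return type with no return statement,
+            # or a string literal returned from a function declared `-> int`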
+ + Args: + function: The function to analyze + + Returns: + A list of return-related errors + """ + errors = [] + + if not hasattr(function, "return_type") or not function.return_type: + return errors + + if not hasattr(function, "return_statements") or not function.return_statements: + # Function has return type but no return statements + errors.append(CodeError( + error_type=ErrorType.TYPE_ERROR, + message=f"Function '{function.name}' has return type '{function.return_type}' but no return statements", + file_path=function.file.name if hasattr(function, "file") else None, + line_number=function.line_number if hasattr(function, "line_number") else None, + severity=ErrorSeverity.MEDIUM, + symbol_name=function.name, + context_lines=self.get_context_lines(function.file.name, function.line_number) if hasattr(function, "file") and hasattr(function, "line_number") else None, + suggested_fix=f"Add return statement or change return type to 'None'" + )) + return errors + + # Check if return statements match the declared return type + return_type = function.return_type + for ret_stmt in function.return_statements: + # This is a simplified check and would need more sophisticated type inference in practice + if hasattr(ret_stmt, "expression") and hasattr(ret_stmt.expression, "source"): + expr_source = ret_stmt.expression.source + + # Simple pattern matching for potential type errors + if "int" in return_type and re.search(r"\".*\"", expr_source): + line_number = self._find_line_number(function.code_block.source, rf"return\s+{re.escape(expr_source)}") + errors.append(CodeError( + error_type=ErrorType.TYPE_ERROR, + message=f"Function '{function.name}' returns string but declares return type '{return_type}'", + file_path=function.file.name if hasattr(function, "file") else None, + line_number=line_number, + severity=ErrorSeverity.HIGH, + symbol_name=function.name, + context_lines=self.get_context_lines(function.file.name, line_number) if hasattr(function, "file") else None, + suggested_fix=f"Convert return value to {return_type} or update return type annotation" + )) + + return errors + + def analyze_unused_imports(self, file: SourceFile) -> List[CodeError]: + """ + Find unused imports in a file. 
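+
+        Example (illustrative):
+
+            errors = analyzer.analyze_unused_imports(file)
+            # one LOW-severity error per imported symbol that no symbol in
+            # the file depends on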
+
+        Args:
+            file: The file to analyze
+
+        Returns:
+            A list of unused import errors
+        """
+        errors = []
+
+        if not hasattr(file, "imports") or not hasattr(file, "symbols"):
+            return errors
+
+        # Get all imported symbols
+        imported_symbols = set()
+        for imp in file.imports:
+            if hasattr(imp, "imported_symbol") and imp.imported_symbol:
+                imported_symbols.add(imp.imported_symbol.name)
+
+        # Get all used symbols
+        used_symbols = set()
+        for symbol in file.symbols:
+            if hasattr(symbol, "dependencies"):
+                for dep in symbol.dependencies:
+                    if isinstance(dep, Symbol):
+                        used_symbols.add(dep.name)
+
+        # Find unused imports
+        unused_imports = imported_symbols - used_symbols
+        for unused in unused_imports:
+            # Find the import statement
+            for imp in file.imports:
+                if hasattr(imp, "imported_symbol") and imp.imported_symbol and imp.imported_symbol.name == unused:
+                    errors.append(CodeError(
+                        error_type=ErrorType.UNUSED_IMPORT,
+                        message=f"Unused import: '{unused}'",
+                        file_path=file.name,
+                        line_number=imp.line_number if hasattr(imp, "line_number") else None,
+                        severity=ErrorSeverity.LOW,
+                        context_lines=self.get_context_lines(file.name, imp.line_number) if hasattr(imp, "line_number") else None,
+                        suggested_fix=f"Remove unused import of '{unused}'"
+                    ))
+
+        return errors
+
+    def analyze_undefined_variables(self, function: Function) -> List[CodeError]:
+        """
+        Find undefined variables in a function.
+
+        Args:
+            function: The function to analyze
+
+        Returns:
+            A list of undefined variable errors
+        """
+        import builtins
+
+        errors = []
+
+        if not hasattr(function, "code_block") or not hasattr(function.code_block, "source"):
+            return errors
+
+        # Get parameter names
+        param_names = set()
+        if hasattr(function, "parameters"):
+            for param in function.parameters:
+                param_names.add(param.name)
+
+        # Parse the function body to find variable definitions and usages
+        try:
+            tree = ast.parse(function.code_block.source)
+
+            # Find all names bound in the body: assignments, loop and
+            # comprehension targets, with/except targets, nested definitions,
+            # and local imports
+            assigned_vars = set()
+            for node in ast.walk(tree):
+                if isinstance(node, ast.Assign):
+                    for target in node.targets:
+                        for name_node in ast.walk(target):
+                            if isinstance(name_node, ast.Name):
+                                assigned_vars.add(name_node.id)
+                elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
+                    assigned_vars.add(node.target.id)
+                elif isinstance(node, (ast.For, ast.comprehension)):
+                    for name_node in ast.walk(node.target):
+                        if isinstance(name_node, ast.Name):
+                            assigned_vars.add(name_node.id)
+                elif isinstance(node, ast.withitem) and node.optional_vars is not None:
+                    for name_node in ast.walk(node.optional_vars):
+                        if isinstance(name_node, ast.Name):
+                            assigned_vars.add(name_node.id)
+                elif isinstance(node, ast.ExceptHandler) and node.name:
+                    assigned_vars.add(node.name)
+                elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
+                    assigned_vars.add(node.name)
+                elif isinstance(node, (ast.Import, ast.ImportFrom)):
+                    for alias in node.names:
+                        assigned_vars.add(alias.asname or alias.name.split(".")[0])
+
+            # Find all variable usages, skipping builtins so that names like
+            # `len` and `print` are not flagged as undefined
+            builtin_names = set(dir(builtins))
+            for node in ast.walk(tree):
+                if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load):
+                    var_name = node.id
+                    if (
+                        var_name not in assigned_vars
+                        and var_name not in param_names
+                        and var_name not in builtin_names
+                        and not var_name.startswith("__")
+                    ):
+                        # This is a potential undefined variable
+                        line_number = node.lineno
+
+                        errors.append(CodeError(
+                            error_type=ErrorType.UNDEFINED_VARIABLE,
+                            message=f"Potentially undefined variable: '{var_name}'",
+                            file_path=function.file.name if hasattr(function, "file") else None,
+                            line_number=line_number,
+                            severity=ErrorSeverity.HIGH,
+                            symbol_name=function.name,
+                            context_lines=self.get_context_lines(function.file.name, line_number) if hasattr(function, "file") else None,
+                            suggested_fix=f"Define '{var_name}' before use or check for typos"
+                        ))
+        except SyntaxError:
+            # If there's a syntax error, we can't analyze the function body
+            pass
+
+        return errors
+
+    def analyze_function(self, function: Function) -> List[CodeError]:
+        """
+        Analyze a function for errors.
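+
+        Example (illustrative):
+
+            for err in analyzer.analyze_function(function):
+                print(err)  # e.g. "TYPE_ERROR (high): ... at test.py:2"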
+ + Args: + function: The function to analyze + + Returns: + A list of errors found in the function + """ + errors = [] + + # Analyze parameters + errors.extend(self.analyze_function_parameters(function)) + + # Analyze return statements + errors.extend(self.analyze_function_returns(function)) + + # Analyze undefined variables + errors.extend(self.analyze_undefined_variables(function)) + + return errors + + def analyze_file(self, file: SourceFile) -> List[CodeError]: + """ + Analyze a file for errors. + + Args: + file: The file to analyze + + Returns: + A list of errors found in the file + """ + errors = [] + + # Analyze unused imports + errors.extend(self.analyze_unused_imports(file)) + + # Analyze syntax errors + if hasattr(file, "source"): + try: + ast.parse(file.source) + except SyntaxError as e: + errors.append(CodeError( + error_type=ErrorType.SYNTAX_ERROR, + message=f"Syntax error: {str(e)}", + file_path=file.name, + line_number=e.lineno, + column=e.offset, + severity=ErrorSeverity.CRITICAL, + context_lines=self.get_context_lines(file.name, e.lineno), + suggested_fix="Fix the syntax error" + )) + + # Analyze functions in the file + for func in file.functions: + errors.extend(self.analyze_function(func)) + + return errors + + def analyze_codebase(self) -> Dict[str, List[Dict[str, Any]]]: + """ + Analyze the entire codebase for errors. + + Returns: + A dictionary mapping file paths to lists of errors + """ + results = {} + + # Analyze each file + for file in self.codebase.files: + file_errors = self.analyze_file(file) + if file_errors: + results[file.name] = [error.to_dict() for error in file_errors] + + # Find circular imports + circular_imports = self.find_circular_imports() + for cycle in circular_imports: + for file_name in cycle: + if file_name not in results: + results[file_name] = [] + + results[file_name].append(CodeError( + error_type=ErrorType.CIRCULAR_IMPORT, + message=f"Circular import detected: {' -> '.join(cycle)}", + file_path=file_name, + severity=ErrorSeverity.HIGH, + suggested_fix="Refactor imports to break the circular dependency" + ).to_dict()) + + # Find circular dependencies + circular_deps = self.find_circular_dependencies() + for cycle in circular_deps: + for symbol_name in cycle: + # Find the file containing this symbol + symbol_file = None + for symbol in self.codebase.symbols: + if symbol.name == symbol_name and hasattr(symbol, "file"): + symbol_file = symbol.file.name + break + + if not symbol_file: + continue + + if symbol_file not in results: + results[symbol_file] = [] + + results[symbol_file].append(CodeError( + error_type=ErrorType.CIRCULAR_DEPENDENCY, + message=f"Circular dependency detected: {' -> '.join(cycle)}", + file_path=symbol_file, + symbol_name=symbol_name, + severity=ErrorSeverity.MEDIUM, + suggested_fix="Refactor code to break the circular dependency" + ).to_dict()) + + return results + + def get_error_context(self, error: CodeError) -> Dict[str, Any]: + """ + Get detailed context information for an error. 
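+
+        Example (illustrative):
+
+            context = analyzer.get_error_context(error)
+            # the error's dict plus, depending on the error type, extra
+            # "function_info" or "import_info" entries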
+ + Args: + error: The error to get context for + + Returns: + A dictionary with detailed context information + """ + context = error.to_dict() + + # Add additional context based on error type + if error.error_type == ErrorType.PARAMETER_ERROR and error.symbol_name: + # Get information about the function + func = None + for function in self.codebase.functions: + if function.name == error.symbol_name: + func = function + break + + if func and hasattr(func, "parameters"): + context["function_info"] = { + "name": func.name, + "parameters": [{"name": p.name, "type": p.type_annotation if hasattr(p, "type_annotation") else None} for p in func.parameters], + "return_type": func.return_type if hasattr(func, "return_type") else None + } + + elif error.error_type == ErrorType.CIRCULAR_IMPORT: + # Add information about the import cycle + import_graph = self.build_import_graph() + if error.file_path in import_graph: + context["import_info"] = { + "imports": [n for n in import_graph.successors(error.file_path)], + "imported_by": [n for n in import_graph.predecessors(error.file_path)] + } + + elif error.error_type == ErrorType.UNDEFINED_VARIABLE and error.symbol_name: + # Get information about the function + func = None + for function in self.codebase.functions: + if function.name == error.symbol_name: + func = function + break + + if func and hasattr(func, "parameters"): + context["function_info"] = { + "name": func.name, + "parameters": [p.name for p in func.parameters], + "local_variables": self._extract_local_variables(func) + } + + return context + + def _extract_local_variables(self, function: Function) -> List[str]: + """ + Extract local variables defined in a function. + + Args: + function: The function to analyze + + Returns: + A list of local variable names + """ + if not hasattr(function, "code_block") or not hasattr(function.code_block, "source"): + return [] + + local_vars = [] + try: + tree = ast.parse(function.code_block.source) + + # Find all variable assignments + for node in ast.walk(tree): + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name): + local_vars.append(target.id) + elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name): + local_vars.append(node.target.id) + except SyntaxError: + pass + + return local_vars + + def _find_line_number(self, source: str, pattern: str) -> Optional[int]: + """ + Find the line number where a pattern appears in source code. + + Args: + source: The source code to search + pattern: The regex pattern to search for + + Returns: + The line number (1-based) or None if not found + """ + lines = source.splitlines() + for i, line in enumerate(lines): + if re.search(pattern, line): + return i + 1 + return None + + def get_function_error_context(self, function_name: str) -> Dict[str, Any]: + """ + Get detailed error context for a specific function. 
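+
+        Example (illustrative; "my_function" is a placeholder name):
+
+            ctx = analyzer.get_function_error_context("my_function")
+            print(ctx["errors"], ctx["callers"], ctx["callees"])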
+ + Args: + function_name: The name of the function to analyze + + Returns: + A dictionary with detailed error context + """ + # Find the function + function = None + for func in self.codebase.functions: + if func.name == function_name: + function = func + break + + if not function: + return {"error": f"Function not found: {function_name}"} + + # Analyze the function + errors = self.analyze_function(function) + + # Get call graph information + call_graph = self.build_call_graph() + callers = [] + callees = [] + + if function_name in call_graph: + callers = [{"name": caller} for caller in call_graph.predecessors(function_name)] + callees = [{"name": callee} for callee in call_graph.successors(function_name)] + + # Get parameter information + parameters = [] + if hasattr(function, "parameters"): + for param in function.parameters: + param_info = { + "name": param.name, + "type": param.type_annotation if hasattr(param, "type_annotation") else None, + "default": param.default_value if hasattr(param, "default_value") else None + } + parameters.append(param_info) + + # Get return information + return_info = { + "type": function.return_type if hasattr(function, "return_type") else None, + "statements": [] + } + + if hasattr(function, "return_statements"): + for ret_stmt in function.return_statements: + if hasattr(ret_stmt, "expression") and hasattr(ret_stmt.expression, "source"): + return_info["statements"].append(ret_stmt.expression.source) + + # Combine all information + result = { + "function_name": function_name, + "file_path": function.file.name if hasattr(function, "file") else None, + "errors": [error.to_dict() for error in errors], + "callers": callers, + "callees": callees, + "parameters": parameters, + "return_info": return_info, + "source": function.source if hasattr(function, "source") else None + } + + return result + + def get_file_error_context(self, file_path: str) -> Dict[str, Any]: + """ + Get detailed error context for a specific file. 
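+
+        Example (illustrative):
+
+            ctx = analyzer.get_file_error_context("path/to/file.py")
+            print(len(ctx["errors"]), [f["name"] for f in ctx["functions"]])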
+ + Args: + file_path: The path of the file to analyze + + Returns: + A dictionary with detailed error context + """ + # Find the file + file = self.codebase.get_file(file_path) + if not file: + return {"error": f"File not found: {file_path}"} + + # Analyze the file + errors = self.analyze_file(file) + + # Get import graph information + import_graph = self.build_import_graph() + importers = [] + imported = [] + + if file.name in import_graph: + importers = [{"name": importer} for importer in import_graph.predecessors(file.name)] + imported = [{"name": imp} for imp in import_graph.successors(file.name)] + + # Get function information + functions = [] + for func in file.functions: + func_errors = [error for error in errors if error.symbol_name == func.name] + functions.append({ + "name": func.name, + "line_number": func.line_number if hasattr(func, "line_number") else None, + "errors": [error.to_dict() for error in func_errors] + }) + + # Get class information + classes = [] + for cls in file.classes: + classes.append({ + "name": cls.name, + "line_number": cls.line_number if hasattr(cls, "line_number") else None + }) + + # Combine all information + result = { + "file_path": file_path, + "errors": [error.to_dict() for error in errors], + "importers": importers, + "imported": imported, + "functions": functions, + "classes": classes + } + + return result + diff --git a/codegen-on-oss/codegen_on_oss/analysis/error_example.py b/codegen-on-oss/codegen_on_oss/analysis/error_example.py new file mode 100644 index 000000000..240b264ce --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analysis/error_example.py @@ -0,0 +1,157 @@ +""" +Example script demonstrating the use of the error context analysis functionality. + +This script shows how to use the CodeAnalyzer class to detect and analyze errors +in a codebase, providing detailed contextual information about the errors. +""" + +from codegen import Codebase +from codegen_on_oss.analysis.analysis import CodeAnalyzer +from codegen_on_oss.analysis.error_context import ErrorType, ErrorSeverity + + +def main(): + """ + Main function demonstrating the use of the error context analysis functionality. 
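+
+    Run directly (network access is assumed for cloning the repository):
+
+        python -m codegen_on_oss.analysis.error_example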
+    """
+    print("Analyzing a sample repository for errors...")
+
+    # Load a codebase
+    repo_name = "fastapi/fastapi"
+    codebase = Codebase.from_repo(repo_name)
+
+    print(f"Loaded codebase: {repo_name}")
+    print(f"Files: {len(codebase.files)}")
+    print(f"Functions: {len(codebase.functions)}")
+    print(f"Classes: {len(codebase.classes)}")
+
+    # Create analyzer instance
+    analyzer = CodeAnalyzer(codebase)
+
+    # Analyze errors in the entire codebase
+    print("\n=== Codebase Error Analysis ===")
+    error_analysis = analyzer.analyze_errors()
+
+    # Count errors by type
+    error_counts = {}
+    total_errors = 0
+
+    for file_path, errors in error_analysis.items():
+        for error in errors:
+            error_type = error["error_type"]
+            error_counts[error_type] = error_counts.get(error_type, 0) + 1
+            total_errors += 1
+
+    print(f"Found {total_errors} errors across {len(error_analysis)} files")
+
+    if error_counts:
+        print("\nError types:")
+        for error_type, count in error_counts.items():
+            print(f"- {error_type}: {count}")
+
+    # Find files with the most errors
+    files_with_errors = [(file_path, len(errors)) for file_path, errors in error_analysis.items()]
+    files_with_errors.sort(key=lambda x: x[1], reverse=True)
+
+    if files_with_errors:
+        print("\nTop files with errors:")
+        for file_path, count in files_with_errors[:5]:  # Show top 5
+            print(f"- {file_path}: {count} errors")
+
+    # Analyze a specific file
+    if files_with_errors:
+        file_to_analyze = files_with_errors[0][0]
+        print(f"\n=== Detailed Error Analysis for {file_to_analyze} ===")
+        file_error_context = analyzer.get_file_error_context(file_to_analyze)
+
+        print(f"File: {file_error_context['file_path']}")
+        print(f"Errors: {len(file_error_context['errors'])}")
+
+        if file_error_context['errors']:
+            print("\nDetailed errors:")
+            for i, error in enumerate(file_error_context['errors'][:3], 1):  # Show top 3
+                print(f"\nError {i}:")
+                print(f"- Type: {error['error_type']}")
+                print(f"- Message: {error['message']}")
+                print(f"- Severity: {error['severity']}")
+                if error['line_number']:
+                    print(f"- Line: {error['line_number']}")
+                if error['suggested_fix']:
+                    print(f"- Suggested fix: {error['suggested_fix']}")
+
+                if error['context_lines']:
+                    print("- Context:")
+                    for line_num, line in error['context_lines'].items():
+                        prefix = ">" if line_num == error['line_number'] else " "
+                        print(f"  {prefix} {line_num}: {line}")
+
+        # Show functions in the file
+        if file_error_context['functions']:
+            print("\nFunctions in this file:")
+            for func in file_error_context['functions']:
+                error_count = len(func['errors'])
+                error_suffix = f" ({error_count} errors)" if error_count > 0 else ""
+                print(f"- {func['name']}{error_suffix}")
+
+    # Analyze a specific function with errors
+    function_to_analyze = None
+    for file_path, errors in error_analysis.items():
+        for error in errors:
+            if error['symbol_name']:
+                function_to_analyze = error['symbol_name']
+                break
+        if function_to_analyze:
+            break
+
+    if function_to_analyze:
+        print(f"\n=== Detailed Error Analysis for function {function_to_analyze} ===")
+        function_error_context = analyzer.get_function_error_context(function_to_analyze)
+
+        print(f"Function: {function_error_context['function_name']}")
+        print(f"File: {function_error_context['file_path']}")
+        print(f"Errors: {len(function_error_context['errors'])}")
+
+        if function_error_context['parameters']:
+            print("\nParameters:")
+            for param in function_error_context['parameters']:
+                default = f" = {param['default']}" if param['default'] is not None else ""
+                type_annotation = f": {param['type']}" if param['type'] else ""
+                print(f"- {param['name']}{type_annotation}{default}")
+
+        if function_error_context['return_info']['type']:
+            print(f"\nReturn type: {function_error_context['return_info']['type']}")
+
+        if function_error_context['callers']:
+            print("\nCalled by:")
+            for caller in function_error_context['callers']:
+                print(f"- {caller['name']}")
+
+        if function_error_context['callees']:
+            print("\nCalls:")
+            for callee in function_error_context['callees']:
+                print(f"- {callee['name']}")
+
+        if function_error_context['errors']:
+            print("\nDetailed errors:")
+            for i, error in enumerate(function_error_context['errors'], 1):
+                print(f"\nError {i}:")
+                print(f"- Type: {error['error_type']}")
+                print(f"- Message: {error['message']}")
+                print(f"- Severity: {error['severity']}")
+                if error['line_number']:
+                    print(f"- Line: {error['line_number']}")
+                if error['suggested_fix']:
+                    print(f"- Suggested fix: {error['suggested_fix']}")
+
+                if error['context_lines']:
+                    print("- Context:")
+                    for line_num, line in error['context_lines'].items():
+                        prefix = ">" if line_num == error['line_number'] else " "
+                        print(f"  {prefix} {line_num}: {line}")
+
+    print("\nAnalysis complete!")
+
+
+if __name__ == "__main__":
+    main()
+
diff --git a/codegen-on-oss/codegen_on_oss/analysis/example.py b/codegen-on-oss/codegen_on_oss/analysis/example.py
new file mode 100644
index 000000000..34dd1710a
--- /dev/null
+++ b/codegen-on-oss/codegen_on_oss/analysis/example.py
@@ -0,0 +1,103 @@
+"""
+Example script demonstrating the use of the unified analysis module.
+
+This script shows how to use the CodeAnalyzer and CodeMetrics classes
+to perform comprehensive code analysis on a repository.
+"""
+
+from codegen import Codebase
+from codegen_on_oss.analysis.analysis import CodeAnalyzer
+from codegen_on_oss.metrics import CodeMetrics
+
+
+def main():
+    """
+    Main function demonstrating the use of the analysis module.
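+
+    Run directly (network access is assumed for cloning the repository):
+
+        python -m codegen_on_oss.analysis.example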
+ """ + print("Analyzing a sample repository...") + + # Load a codebase + repo_name = "fastapi/fastapi" + codebase = Codebase.from_repo(repo_name) + + print(f"Loaded codebase: {repo_name}") + print(f"Files: {len(codebase.files)}") + print(f"Functions: {len(codebase.functions)}") + print(f"Classes: {len(codebase.classes)}") + + # Create analyzer instance + analyzer = CodeAnalyzer(codebase) + + # Get codebase summary + print("\n=== Codebase Summary ===") + print(analyzer.get_codebase_summary()) + + # Analyze complexity + print("\n=== Complexity Analysis ===") + complexity_results = analyzer.analyze_complexity() + print(f"Average cyclomatic complexity: {complexity_results['cyclomatic_complexity']['average']:.2f}") + print(f"Complexity rank: {complexity_results['cyclomatic_complexity']['rank']}") + + # Find complex functions + complex_functions = [ + f for f in complexity_results['cyclomatic_complexity']['functions'] + if f['complexity'] > 10 + ][:5] # Show top 5 + + if complex_functions: + print("\nTop complex functions:") + for func in complex_functions: + print(f"- {func['name']}: Complexity {func['complexity']} (Rank {func['rank']})") + + # Analyze imports + print("\n=== Import Analysis ===") + import_analysis = analyzer.analyze_imports() + print(f"Found {len(import_analysis['import_cycles'])} import cycles") + + # Create metrics instance + metrics = CodeMetrics(codebase) + + # Get code quality summary + print("\n=== Code Quality Summary ===") + quality_summary = metrics.get_code_quality_summary() + + print("Overall metrics:") + for metric, value in quality_summary["overall_metrics"].items(): + if isinstance(value, float): + print(f"- {metric}: {value:.2f}") + else: + print(f"- {metric}: {value}") + + print("\nProblem areas:") + for area, count in quality_summary["problem_areas"].items(): + print(f"- {area}: {count}") + + # Find bug-prone functions + print("\n=== Bug-Prone Functions ===") + bug_prone = metrics.find_bug_prone_functions()[:5] # Show top 5 + + if bug_prone: + print("Top bug-prone functions:") + for func in bug_prone: + print(f"- {func['name']}: Estimated bugs {func['bugs_delivered']:.2f}") + + # Analyze dependencies + print("\n=== Dependency Analysis ===") + dependencies = metrics.analyze_dependencies() + + print(f"Dependency graph: {dependencies['dependency_graph']['nodes']} nodes, " + f"{dependencies['dependency_graph']['edges']} edges") + print(f"Dependency density: {dependencies['dependency_graph']['density']:.4f}") + print(f"Number of cycles: {dependencies['cycles']}") + + if dependencies['most_central_files']: + print("\nMost central files:") + for file, score in dependencies['most_central_files'][:5]: # Show top 5 + print(f"- {file}: Centrality {score:.4f}") + + print("\nAnalysis complete!") + + +if __name__ == "__main__": + main() + diff --git a/codegen-on-oss/codegen_on_oss/analysis/examples/__init__.py b/codegen-on-oss/codegen_on_oss/analysis/examples/__init__.py new file mode 100644 index 000000000..32ca6bc47 --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analysis/examples/__init__.py @@ -0,0 +1,4 @@ +""" +Examples for using the code analysis module. +""" + diff --git a/codegen-on-oss/codegen_on_oss/analysis/examples/analyze_errors.py b/codegen-on-oss/codegen_on_oss/analysis/examples/analyze_errors.py new file mode 100644 index 000000000..b3c1e0149 --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analysis/examples/analyze_errors.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +""" +Example script demonstrating how to use the error context analysis functionality. 
+
+This script analyzes a repository for errors and prints detailed error context information.
+"""
+
+import argparse
+import sys
+from typing import Dict, Any
+
+from codegen import Codebase
+from codegen_on_oss.analysis.analysis import CodeAnalyzer
+
+
+def print_error(error: Dict[str, Any]) -> None:
+    """Print a formatted error message."""
+    print(f"ERROR: {error['error_type']} ({error['severity']})")
+    print(f"  Message: {error['message']}")
+
+    if error.get('file_path'):
+        print(f"  File: {error['file_path']}")
+
+    if error.get('line_number'):
+        print(f"  Line: {error['line_number']}")
+
+    if error.get('context_lines'):
+        print("  Context:")
+        for line_num, line in error['context_lines'].items():
+            prefix = ">" if str(line_num) == str(error.get('line_number')) else " "
+            print(f"    {prefix} {line_num}: {line}")
+
+    if error.get('suggested_fix'):
+        print(f"  Suggested Fix: {error['suggested_fix']}")
+
+    print()
+
+
+def analyze_repo(repo_url: str) -> None:
+    """Analyze a repository for errors."""
+    print(f"Analyzing repository: {repo_url}")
+
+    try:
+        # Create a codebase from the repository
+        codebase = Codebase.from_repo(repo_url)
+
+        # Create an analyzer
+        analyzer = CodeAnalyzer(codebase)
+
+        # Analyze errors in the codebase
+        errors = analyzer.analyze_errors()
+
+        # Print summary
+        total_errors = sum(len(file_errors) for file_errors in errors.values())
+        print(f"\nFound {total_errors} errors in {len(errors)} files\n")
+
+        # Print errors by file
+        for file_path, file_errors in errors.items():
+            print(f"File: {file_path}")
+            print(f"  {len(file_errors)} errors found")
+
+            # Print the first 3 errors for each file
+            for i, error in enumerate(file_errors[:3]):
+                print(f"  Error {i+1}:")
+                print_error(error)
+
+            if len(file_errors) > 3:
+                print(f"  ... and {len(file_errors) - 3} more errors\n")
+
+            print()
+
+    except Exception as e:
+        print(f"Error analyzing repository: {e}", file=sys.stderr)
+        sys.exit(1)
+
+
+def analyze_file(repo_url: str, file_path: str) -> None:
+    """Analyze a specific file for errors."""
+    print(f"Analyzing file: {file_path} in repository: {repo_url}")
+
+    try:
+        # Create a codebase from the repository
+        codebase = Codebase.from_repo(repo_url)
+
+        # Create an analyzer
+        analyzer = CodeAnalyzer(codebase)
+
+        # Get file error context
+        file_error_context = analyzer.get_file_error_context(file_path)
+
+        # Print errors
+        if 'errors' in file_error_context:
+            errors = file_error_context['errors']
+            print(f"\nFound {len(errors)} errors\n")
+
+            for i, error in enumerate(errors):
+                print(f"Error {i+1}:")
+                print_error(error)
+        else:
+            print("\nNo errors found or file not found")
+
+    except Exception as e:
+        print(f"Error analyzing file: {e}", file=sys.stderr)
+        sys.exit(1)
+
+
+def analyze_function(repo_url: str, function_name: str) -> None:
+    """Analyze a specific function for errors."""
+    print(f"Analyzing function: {function_name} in repository: {repo_url}")
+
+    try:
+        # Create a codebase from the repository
+        codebase = Codebase.from_repo(repo_url)
+
+        # Create an analyzer
+        analyzer = CodeAnalyzer(codebase)
+
+        # Get function error context
+        function_error_context = analyzer.get_function_error_context(function_name)
+
+        # Print function information
+        if 'function_name' in function_error_context:
+            print(f"\nFunction: {function_error_context['function_name']}")
+
+            if 'file_path' in function_error_context:
+                print(f"File: {function_error_context['file_path']}")
+
+            # Print parameters
+            if 'parameters' in function_error_context:
+                params = function_error_context['parameters']
+                print(f"\nParameters ({len(params)}):")
+                for param in params:
+                    param_type = f": {param['type']}" if param.get('type') else ""
+                    default = f" = {param['default']}" if param.get('default') else ""
+                    print(f"  {param['name']}{param_type}{default}")
+
+            # Print return information
+            if 'return_info' in function_error_context:
+                return_info = function_error_context['return_info']
+                print(f"\nReturn Type: {return_info.get('type', 'Unknown')}")
+                if return_info.get('statements'):
+                    print(f"Return Statements ({len(return_info['statements'])}):")
+                    for stmt in return_info['statements']:
+                        print(f"  return {stmt}")
+
+            # Print callers and callees
+            if 'callers' in function_error_context:
+                callers = function_error_context['callers']
+                print(f"\nCallers ({len(callers)}):")
+                for caller in callers:
+                    print(f"  {caller['name']}")
+
+            if 'callees' in function_error_context:
+                callees = function_error_context['callees']
+                print(f"\nCallees ({len(callees)}):")
+                for callee in callees:
+                    print(f"  {callee['name']}")
+
+            # Print errors
+            if 'errors' in function_error_context:
+                errors = function_error_context['errors']
+                print(f"\nErrors ({len(errors)}):")
+                for i, error in enumerate(errors):
+                    print(f"Error {i+1}:")
+                    print_error(error)
+            else:
+                print("\nNo errors found")
+        else:
+            print("\nFunction not found")
+
+    except Exception as e:
+        print(f"Error analyzing function: {e}", file=sys.stderr)
+        sys.exit(1)
+
+
+def main() -> None:
+    """Main entry point."""
+    parser = argparse.ArgumentParser(description="Analyze code for errors")
+    parser.add_argument("repo_url", help="Repository URL (owner/repo)")
+
+    subparsers = parser.add_subparsers(dest="command", help="Analysis command")
+
+    # Repository analysis command
+    subparsers.add_parser("repo", help="Analyze entire repository")
help="Analyze entire repository") + + # File analysis command + file_parser = subparsers.add_parser("file", help="Analyze a specific file") + file_parser.add_argument("file_path", help="Path to the file to analyze") + + # Function analysis command + function_parser = subparsers.add_parser("function", help="Analyze a specific function") + function_parser.add_argument("function_name", help="Name of the function to analyze") + + args = parser.parse_args() + + if args.command == "file": + analyze_file(args.repo_url, args.file_path) + elif args.command == "function": + analyze_function(args.repo_url, args.function_name) + else: + analyze_repo(args.repo_url) + + +if __name__ == "__main__": + main() + diff --git a/codegen-on-oss/codegen_on_oss/analysis/mdx_docs_generation.py b/codegen-on-oss/codegen_on_oss/analysis/mdx_docs_generation.py index 648a3b68e..9e4543bea 100644 --- a/codegen-on-oss/codegen_on_oss/analysis/mdx_docs_generation.py +++ b/codegen-on-oss/codegen_on_oss/analysis/mdx_docs_generation.py @@ -110,10 +110,10 @@ def format_parameters_for_mdx(parameters: list[ParameterDoc]) -> str: def format_return_for_mdx(return_type: list[str], return_description: str) -> str: description = sanitize_html_for_mdx(return_description) if return_description else "" - return_type = resolve_type_string(return_type[0]) + return_type_str = resolve_type_string(return_type[0]) return f""" - + """ @@ -154,8 +154,8 @@ def get_mdx_route_for_class(cls_doc: ClassDoc) -> str: def format_type_string(type_string: str) -> str: - type_string = type_string.split("|") - return " | ".join([type_str.strip() for type_str in type_string]) + type_strings = type_string.split("|") + return " | ".join([type_str.strip() for type_str in type_strings]) def resolve_type_string(type_string: str) -> str: diff --git a/codegen-on-oss/codegen_on_oss/analysis/tests/__init__.py b/codegen-on-oss/codegen_on_oss/analysis/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/codegen-on-oss/codegen_on_oss/analysis/tests/test_error_context.py b/codegen-on-oss/codegen_on_oss/analysis/tests/test_error_context.py new file mode 100644 index 000000000..33a6e3b01 --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analysis/tests/test_error_context.py @@ -0,0 +1,245 @@ +""" +Tests for the error context analysis functionality. + +This module contains unit tests for the ErrorContextAnalyzer and related classes. 
+""" + +import ast +import unittest +from unittest.mock import MagicMock, patch + +from codegen.sdk.core.codebase import Codebase +from codegen.sdk.core.file import SourceFile +from codegen.sdk.core.function import Function +from codegen.sdk.core.symbol import Symbol +from codegen_on_oss.analysis.error_context import ( + CodeError, + ErrorContextAnalyzer, + ErrorSeverity, + ErrorType +) + + +class TestCodeError(unittest.TestCase): + """Tests for the CodeError class.""" + + def test_code_error_initialization(self): + """Test that a CodeError can be initialized with all parameters.""" + error = CodeError( + error_type=ErrorType.SYNTAX_ERROR, + message="Invalid syntax", + file_path="test.py", + line_number=10, + column=5, + severity=ErrorSeverity.CRITICAL, + symbol_name="test_function", + context_lines={9: "def test_function():", 10: " print('Hello world'"}, + suggested_fix="Fix the syntax error" + ) + + self.assertEqual(error.error_type, ErrorType.SYNTAX_ERROR) + self.assertEqual(error.message, "Invalid syntax") + self.assertEqual(error.file_path, "test.py") + self.assertEqual(error.line_number, 10) + self.assertEqual(error.column, 5) + self.assertEqual(error.severity, ErrorSeverity.CRITICAL) + self.assertEqual(error.symbol_name, "test_function") + self.assertEqual(error.context_lines, {9: "def test_function():", 10: " print('Hello world'"}) + self.assertEqual(error.suggested_fix, "Fix the syntax error") + + def test_code_error_to_dict(self): + """Test that a CodeError can be converted to a dictionary.""" + error = CodeError( + error_type=ErrorType.SYNTAX_ERROR, + message="Invalid syntax", + file_path="test.py", + line_number=10, + severity=ErrorSeverity.CRITICAL + ) + + error_dict = error.to_dict() + + self.assertEqual(error_dict["error_type"], ErrorType.SYNTAX_ERROR) + self.assertEqual(error_dict["message"], "Invalid syntax") + self.assertEqual(error_dict["file_path"], "test.py") + self.assertEqual(error_dict["line_number"], 10) + self.assertEqual(error_dict["severity"], ErrorSeverity.CRITICAL) + + def test_code_error_str(self): + """Test the string representation of a CodeError.""" + error = CodeError( + error_type=ErrorType.SYNTAX_ERROR, + message="Invalid syntax", + file_path="test.py", + line_number=10, + severity=ErrorSeverity.CRITICAL + ) + + error_str = str(error) + + self.assertIn(ErrorType.SYNTAX_ERROR.upper(), error_str) + self.assertIn("Invalid syntax", error_str) + self.assertIn("test.py:10", error_str) + self.assertIn(ErrorSeverity.CRITICAL, error_str) + + +class TestErrorContextAnalyzer(unittest.TestCase): + """Tests for the ErrorContextAnalyzer class.""" + + def setUp(self): + """Set up test fixtures.""" + # Create a mock codebase + self.codebase = MagicMock(spec=Codebase) + + # Create a mock file + self.file = MagicMock(spec=SourceFile) + self.file.name = "test.py" + self.file.source = "def test_function():\n x = 'hello' + 5\n return x" + + # Create a mock function + self.function = MagicMock(spec=Function) + self.function.name = "test_function" + self.function.file = self.file + self.function.line_number = 1 + self.function.code_block = MagicMock() + self.function.code_block.source = "def test_function():\n x = 'hello' + 5\n return x" + + # Set up the codebase with the file and function + self.codebase.files = [self.file] + self.codebase.functions = [self.function] + self.codebase.get_file.return_value = self.file + + # Create the analyzer + self.analyzer = ErrorContextAnalyzer(self.codebase) + + def test_get_context_lines(self): + """Test getting context lines around a 
+        context_lines = self.analyzer.get_context_lines("test.py", 2, context_size=1)
+
+        self.assertEqual(context_lines, {
+            1: "def test_function():",
+            2: "    x = 'hello' + 5",
+            3: "    return x"
+        })
+
+    def test_analyze_function(self):
+        """Test analyzing a function for errors."""
+        errors = self.analyzer.analyze_function(self.function)
+
+        # We should find at least one error (type error)
+        self.assertGreaterEqual(len(errors), 1)
+
+        # Check that we found a type error
+        type_errors = [e for e in errors if e.error_type == ErrorType.TYPE_ERROR]
+        self.assertGreaterEqual(len(type_errors), 1)
+
+        # Check the error details
+        error = type_errors[0]
+        self.assertEqual(error.file_path, "test.py")
+        self.assertEqual(error.symbol_name, "test_function")
+        self.assertEqual(error.severity, ErrorSeverity.HIGH)
+        self.assertIn("'hello' + 5", str(error.context_lines))
+
+    def test_analyze_file(self):
+        """Test analyzing a file for errors."""
+        errors = self.analyzer.analyze_file(self.file)
+
+        # We should find at least one error (type error)
+        self.assertGreaterEqual(len(errors), 1)
+
+        # Check that we found a type error
+        type_errors = [e for e in errors if e.error_type == ErrorType.TYPE_ERROR]
+        self.assertGreaterEqual(len(type_errors), 1)
+
+    def test_analyze_codebase(self):
+        """Test analyzing the entire codebase for errors."""
+        error_dict = self.analyzer.analyze_codebase()
+
+        # We should have errors for our test file
+        self.assertIn("test.py", error_dict)
+        self.assertGreaterEqual(len(error_dict["test.py"]), 1)
+
+    def test_find_circular_imports(self):
+        """Test finding circular imports."""
+        # Mock the build_import_graph method to return a graph with a cycle
+        import networkx as nx
+        G = nx.DiGraph()
+        G.add_edge("a.py", "b.py")
+        G.add_edge("b.py", "c.py")
+        G.add_edge("c.py", "a.py")
+
+        with patch.object(self.analyzer, 'build_import_graph', return_value=G):
+            cycles = self.analyzer.find_circular_imports()
+
+            # We should find one cycle
+            self.assertEqual(len(cycles), 1)
+
+            # The cycle should contain a.py, b.py, and c.py
+            cycle = cycles[0]
+            self.assertIn("a.py", cycle)
+            self.assertIn("b.py", cycle)
+            self.assertIn("c.py", cycle)
+
+    def test_get_function_error_context(self):
+        """Test getting detailed error context for a function."""
+        # Mock the analyze_function method to return a specific error
+        error = CodeError(
+            error_type=ErrorType.TYPE_ERROR,
+            message="Cannot add string and integer",
+            file_path="test.py",
+            line_number=2,
+            severity=ErrorSeverity.HIGH,
+            symbol_name="test_function",
+            context_lines={1: "def test_function():", 2: "    x = 'hello' + 5", 3: "    return x"},
+            suggested_fix="Convert the integer to a string: 'hello' + str(5)"
+        )
+
+        with patch.object(self.analyzer, 'analyze_function', return_value=[error]):
+            context = self.analyzer.get_function_error_context("test_function")
+
+            # Check the context
+            self.assertEqual(context["function_name"], "test_function")
+            self.assertEqual(context["file_path"], "test.py")
+            self.assertEqual(len(context["errors"]), 1)
+
+            # Check the error details
+            error_dict = context["errors"][0]
+            self.assertEqual(error_dict["error_type"], ErrorType.TYPE_ERROR)
+            self.assertEqual(error_dict["message"], "Cannot add string and integer")
+            self.assertEqual(error_dict["line_number"], 2)
+            self.assertEqual(error_dict["severity"], ErrorSeverity.HIGH)
+            self.assertEqual(error_dict["suggested_fix"], "Convert the integer to a string: 'hello' + str(5)")
+
+    def test_get_file_error_context(self):
+        """Test getting detailed error context for a file."""
file.""" + # Mock the analyze_file method to return a specific error + error = CodeError( + error_type=ErrorType.TYPE_ERROR, + message="Cannot add string and integer", + file_path="test.py", + line_number=2, + severity=ErrorSeverity.HIGH, + symbol_name="test_function", + context_lines={1: "def test_function():", 2: " x = 'hello' + 5", 3: " return x"}, + suggested_fix="Convert the integer to a string: 'hello' + str(5)" + ) + + with patch.object(self.analyzer, 'analyze_file', return_value=[error]): + context = self.analyzer.get_file_error_context("test.py") + + # Check the context + self.assertEqual(context["file_path"], "test.py") + self.assertEqual(len(context["errors"]), 1) + + # Check the error details + error_dict = context["errors"][0] + self.assertEqual(error_dict["error_type"], ErrorType.TYPE_ERROR) + self.assertEqual(error_dict["message"], "Cannot add string and integer") + self.assertEqual(error_dict["line_number"], 2) + self.assertEqual(error_dict["severity"], ErrorSeverity.HIGH) + self.assertEqual(error_dict["suggested_fix"], "Convert the integer to a string: 'hello' + str(5)") + + +if __name__ == '__main__': + unittest.main() + diff --git a/codegen-on-oss/codegen_on_oss/metrics.py b/codegen-on-oss/codegen_on_oss/metrics.py index d77b4e686..d81d5b20b 100644 --- a/codegen-on-oss/codegen_on_oss/metrics.py +++ b/codegen-on-oss/codegen_on_oss/metrics.py @@ -1,15 +1,36 @@ +""" +Metrics module for Codegen-on-OSS + +This module provides tools for measuring and recording performance metrics +and code quality metrics for codebases. +""" + import json import os import time +import math from collections.abc import Generator from contextlib import contextmanager from importlib.metadata import version -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import psutil +import networkx as nx +from codegen import Codebase from codegen_on_oss.errors import ParseRunError from codegen_on_oss.outputs.base import BaseOutput +from codegen_on_oss.analysis.analysis import ( + CodeAnalyzer, + calculate_cyclomatic_complexity, + calculate_halstead_volume, + calculate_maintainability_index, + count_lines, + get_operators_and_operands, + cc_rank, + get_maintainability_rank, + calculate_doi +) if TYPE_CHECKING: # Logger only available in type checking context. @@ -19,6 +40,478 @@ codegen_version = str(version("codegen")) +class CodeMetrics: + """ + A class to calculate and provide code quality metrics for a codebase. + Integrates with the analysis module for comprehensive code analysis. + """ + + # Constants for threshold values + COMPLEXITY_THRESHOLD = 10 + MAINTAINABILITY_THRESHOLD = 65 + INHERITANCE_DEPTH_THRESHOLD = 3 + VOLUME_THRESHOLD = 1000 + EFFORT_THRESHOLD = 50000 + BUG_THRESHOLD = 0.5 + + def __init__(self, codebase: Codebase): + """ + Initialize the CodeMetrics class with a codebase. + + Args: + codebase: The Codebase object to analyze + """ + self.codebase = codebase + self.analyzer = CodeAnalyzer(codebase) + self._complexity_metrics = None + self._line_metrics = None + self._maintainability_metrics = None + self._inheritance_metrics = None + self._halstead_metrics = None + + def calculate_all_metrics(self) -> Dict[str, Any]: + """ + Calculate all available metrics for the codebase. 
+
+        Returns:
+            A dictionary containing all metrics categories
+        """
+        return {
+            "complexity": self.complexity_metrics,
+            "lines": self.line_metrics,
+            "maintainability": self.maintainability_metrics,
+            "inheritance": self.inheritance_metrics,
+            "halstead": self.halstead_metrics,
+        }
+
+    @property
+    def complexity_metrics(self) -> Dict[str, Any]:
+        """
+        Calculate cyclomatic complexity metrics for the codebase.
+
+        Returns:
+            A dictionary containing complexity metrics including average,
+            rank, and per-function complexity scores
+        """
+        if self._complexity_metrics is not None:
+            return self._complexity_metrics
+
+        callables = self.codebase.functions + [
+            m for c in self.codebase.classes for m in c.methods
+        ]
+
+        complexities = []
+        for func in callables:
+            if not hasattr(func, "code_block"):
+                continue
+
+            complexity = calculate_cyclomatic_complexity(func)
+            complexities.append({
+                "name": func.name,
+                "complexity": complexity,
+                "rank": cc_rank(complexity)
+            })
+
+        avg_complexity = (
+            sum(item["complexity"] for item in complexities) / len(complexities)
+            if complexities else 0
+        )
+
+        self._complexity_metrics = {
+            "average": avg_complexity,
+            "rank": cc_rank(avg_complexity),
+            "functions": complexities
+        }
+
+        return self._complexity_metrics
+
+    @property
+    def line_metrics(self) -> Dict[str, Any]:
+        """
+        Calculate line-based metrics for the codebase.
+
+        Returns:
+            A dictionary containing line metrics including total counts
+            and per-file metrics for LOC, LLOC, SLOC, and comments
+        """
+        if self._line_metrics is not None:
+            return self._line_metrics
+
+        total_loc = total_lloc = total_sloc = total_comments = 0
+        file_metrics = []
+
+        for file in self.codebase.files:
+            loc, lloc, sloc, comments = count_lines(file.source)
+            comment_density = (comments / loc * 100) if loc > 0 else 0
+
+            file_metrics.append({
+                "file": file.filepath,
+                "loc": loc,
+                "lloc": lloc,
+                "sloc": sloc,
+                "comments": comments,
+                "comment_density": comment_density
+            })
+
+            total_loc += loc
+            total_lloc += lloc
+            total_sloc += sloc
+            total_comments += comments
+
+        total_comment_density = (
+            total_comments / total_loc * 100 if total_loc > 0 else 0
+        )
+
+        self._line_metrics = {
+            "total": {
+                "loc": total_loc,
+                "lloc": total_lloc,
+                "sloc": total_sloc,
+                "comments": total_comments,
+                "comment_density": total_comment_density
+            },
+            "files": file_metrics
+        }
+
+        return self._line_metrics
+
+    @property
+    def maintainability_metrics(self) -> Dict[str, Any]:
+        """
+        Calculate maintainability index metrics for the codebase.
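+
+        The score combines Halstead volume (V), cyclomatic complexity (CC),
+        and lines of code (LOC) via ``calculate_maintainability_index``; the
+        classic formulation (assumed here) is
+        ``MI = max(0, (171 - 5.2*ln(V) - 0.23*CC - 16.2*ln(LOC)) * 100 / 171)``.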
+ + Returns: + A dictionary containing maintainability metrics including average, + rank, and per-function maintainability scores + """ + if self._maintainability_metrics is not None: + return self._maintainability_metrics + + callables = self.codebase.functions + [ + m for c in self.codebase.classes for m in c.methods + ] + + mi_scores = [] + for func in callables: + if not hasattr(func, "code_block"): + continue + + complexity = calculate_cyclomatic_complexity(func) + operators, operands = get_operators_and_operands(func) + volume, _, _, _, _ = calculate_halstead_volume(operators, operands) + loc = len(func.code_block.source.splitlines()) + mi_score = calculate_maintainability_index(volume, complexity, loc) + + mi_scores.append({ + "name": func.name, + "mi_score": mi_score, + "rank": get_maintainability_rank(mi_score) + }) + + avg_mi = ( + sum(item["mi_score"] for item in mi_scores) / len(mi_scores) + if mi_scores else 0 + ) + + self._maintainability_metrics = { + "average": avg_mi, + "rank": get_maintainability_rank(avg_mi), + "functions": mi_scores + } + + return self._maintainability_metrics + + @property + def inheritance_metrics(self) -> Dict[str, Any]: + """ + Calculate inheritance metrics for the codebase. + + Returns: + A dictionary containing inheritance metrics including average + depth of inheritance and per-class inheritance depth + """ + if self._inheritance_metrics is not None: + return self._inheritance_metrics + + class_metrics = [] + for cls in self.codebase.classes: + doi = calculate_doi(cls) + class_metrics.append({ + "name": cls.name, + "doi": doi + }) + + avg_doi = ( + sum(item["doi"] for item in class_metrics) / len(class_metrics) + if class_metrics else 0 + ) + + self._inheritance_metrics = { + "average": avg_doi, + "classes": class_metrics + } + + return self._inheritance_metrics + + @property + def halstead_metrics(self) -> Dict[str, Any]: + """ + Calculate Halstead complexity metrics for the codebase. 
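+
+        The derived measures follow the standard Halstead definitions
+        (assuming ``calculate_halstead_volume`` reports distinct counts as
+        ``n_operators``/``n_operands``): difficulty D = (distinct operators / 2)
+        * (total operands / distinct operands), effort E = D * V, estimated
+        implementation time T = E / 18 seconds, and delivered bugs B = V / 3000.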
+
+        Returns:
+            A dictionary containing Halstead metrics including volume,
+            difficulty, effort, and other Halstead measures
+        """
+        if self._halstead_metrics is not None:
+            return self._halstead_metrics
+
+        callables = self.codebase.functions + [
+            m for c in self.codebase.classes for m in c.methods
+        ]
+
+        halstead_metrics = []
+        for func in callables:
+            if not hasattr(func, "code_block"):
+                continue
+
+            operators, operands = get_operators_and_operands(func)
+            volume, n1, n2, n_operators, n_operands = calculate_halstead_volume(
+                operators, operands
+            )
+
+            # Derived Halstead metrics (difficulty, effort, time, bugs)
+            difficulty = (
+                (n_operators / 2) * (n2 / n_operands) if n_operands > 0 else 0
+            )
+            effort = difficulty * volume if volume > 0 else 0
+            time_required = effort / 18 if effort > 0 else 0  # Seconds
+            bugs_delivered = volume / 3000 if volume > 0 else 0
+
+            halstead_metrics.append({
+                "name": func.name,
+                "volume": volume,
+                "difficulty": difficulty,
+                "effort": effort,
+                "time_required": time_required,  # in seconds
+                "bugs_delivered": bugs_delivered
+            })
+
+        avg_volume = (
+            sum(item["volume"] for item in halstead_metrics) / len(halstead_metrics)
+            if halstead_metrics else 0
+        )
+        avg_difficulty = (
+            sum(item["difficulty"] for item in halstead_metrics) / len(halstead_metrics)
+            if halstead_metrics else 0
+        )
+        avg_effort = (
+            sum(item["effort"] for item in halstead_metrics) / len(halstead_metrics)
+            if halstead_metrics else 0
+        )
+
+        self._halstead_metrics = {
+            "average": {
+                "volume": avg_volume,
+                "difficulty": avg_difficulty,
+                "effort": avg_effort
+            },
+            "functions": halstead_metrics
+        }
+
+        return self._halstead_metrics
+
+    def find_complex_functions(self, threshold: int = COMPLEXITY_THRESHOLD) -> List[Dict[str, Any]]:
+        """
+        Find functions with cyclomatic complexity above the threshold.
+
+        Args:
+            threshold: The complexity threshold (default: 10)
+
+        Returns:
+            A list of functions with complexity above the threshold
+        """
+        metrics = self.complexity_metrics
+        return [
+            func for func in metrics["functions"]
+            if func["complexity"] > threshold
+        ]
+
+    def find_low_maintainability_functions(
+        self, threshold: int = MAINTAINABILITY_THRESHOLD
+    ) -> List[Dict[str, Any]]:
+        """
+        Find functions with maintainability index below the threshold.
+
+        Args:
+            threshold: The maintainability threshold (default: 65)
+
+        Returns:
+            A list of functions with maintainability below the threshold
+        """
+        metrics = self.maintainability_metrics
+        return [
+            func for func in metrics["functions"]
+            if func["mi_score"] < threshold
+        ]
+
+    def find_deep_inheritance_classes(
+        self, threshold: int = INHERITANCE_DEPTH_THRESHOLD
+    ) -> List[Dict[str, Any]]:
+        """
+        Find classes with depth of inheritance above the threshold.
+
+        Args:
+            threshold: The inheritance depth threshold (default: 3)
+
+        Returns:
+            A list of classes with inheritance depth above the threshold
+        """
+        metrics = self.inheritance_metrics
+        return [cls for cls in metrics["classes"] if cls["doi"] > threshold]
+
+    def find_high_volume_functions(self, threshold: int = VOLUME_THRESHOLD) -> List[Dict[str, Any]]:
+        """
+        Find functions with Halstead volume above the threshold.
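+
+        Illustrative usage (a sketch):
+
+            >>> metrics = CodeMetrics(codebase)
+            >>> for func in metrics.find_high_volume_functions(threshold=2000):
+            ...     print(func["name"], round(func["volume"]))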
+ + Args: + threshold: The volume threshold (default: 1000) + + Returns: + A list of functions with volume above the threshold + """ + metrics = self.halstead_metrics + return [ + func for func in metrics["functions"] + if func["volume"] > threshold + ] + + def find_high_effort_functions(self, threshold: int = EFFORT_THRESHOLD) -> List[Dict[str, Any]]: + """ + Find functions with high Halstead effort (difficult to maintain). + + Args: + threshold: The effort threshold (default: 50000) + + Returns: + A list of functions with effort above the threshold + """ + metrics = self.halstead_metrics + return [ + func for func in metrics["functions"] + if func["effort"] > threshold + ] + + def find_bug_prone_functions(self, threshold: float = BUG_THRESHOLD) -> List[Dict[str, Any]]: + """ + Find functions with high estimated bug delivery. + + Args: + threshold: The bugs delivered threshold (default: 0.5) + + Returns: + A list of functions likely to contain bugs + """ + metrics = self.halstead_metrics + return [ + func for func in metrics["functions"] + if func["bugs_delivered"] > threshold + ] + + def get_code_quality_summary(self) -> Dict[str, Any]: + """ + Generate a comprehensive code quality summary. + + Returns: + A dictionary with overall code quality metrics and problem areas + """ + return { + "overall_metrics": { + "complexity": self.complexity_metrics["average"], + "complexity_rank": self.complexity_metrics["rank"], + "maintainability": self.maintainability_metrics["average"], + "maintainability_rank": self.maintainability_metrics["rank"], + "lines_of_code": self.line_metrics["total"]["loc"], + "comment_density": self.line_metrics["total"]["comment_density"], + "inheritance_depth": self.inheritance_metrics["average"], + "halstead_volume": self.halstead_metrics["average"]["volume"], + "halstead_difficulty": self.halstead_metrics["average"]["difficulty"], + }, + "problem_areas": { + "complex_functions": len(self.find_complex_functions()), + "low_maintainability": len(self.find_low_maintainability_functions()), + "deep_inheritance": len(self.find_deep_inheritance_classes()), + "high_volume": len(self.find_high_volume_functions()), + "high_effort": len(self.find_high_effort_functions()), + "bug_prone": len(self.find_bug_prone_functions()), + }, + "import_analysis": self.analyzer.analyze_imports() + } + + def analyze_codebase_structure(self) -> Dict[str, Any]: + """ + Analyze the structure of the codebase. + + Returns: + A dictionary with codebase structure information + """ + return { + "summary": self.analyzer.get_codebase_summary(), + "files": len(self.codebase.files), + "functions": len(self.codebase.functions), + "classes": len(self.codebase.classes), + "imports": len(self.codebase.imports), + "symbols": len(self.codebase.symbols) + } + + def generate_documentation(self) -> None: + """ + Generate documentation for the codebase. + """ + self.analyzer.document_functions() + + def analyze_dependencies(self) -> Dict[str, Any]: + """ + Analyze dependencies in the codebase. 
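+
+        Files are graph nodes and imports are directed edges; a cycle count
+        of zero means the import graph is a DAG. The centrality ranking uses
+        ``networkx.degree_centrality`` to surface the most interconnected
+        files.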
+
+        Returns:
+            A dictionary with dependency analysis results
+        """
+        # Create a dependency graph
+        G = nx.DiGraph()
+
+        # Add nodes for all files
+        for file in self.codebase.files:
+            G.add_node(file.filepath)
+
+        # Add edges for imports
+        for imp in self.codebase.imports:
+            if imp.from_file and imp.to_file:
+                G.add_edge(imp.from_file.filepath, imp.to_file.filepath)
+
+        # Find cycles
+        cycles = list(nx.simple_cycles(G))
+
+        # Calculate centrality metrics
+        centrality = nx.degree_centrality(G)
+
+        return {
+            "dependency_graph": {
+                "nodes": len(G.nodes),
+                "edges": len(G.edges),
+                "density": nx.density(G)
+            },
+            "cycles": len(cycles),
+            "most_central_files": sorted(
+                centrality.items(),
+                key=lambda x: x[1],
+                reverse=True
+            )[:10]
+        }
+
+
 class MetricsProfiler:
     """
     A helper to record performance metrics across multiple profiles and write them to a CSV.
@@ -42,7 +535,7 @@ def __init__(self, output: BaseOutput):
     @contextmanager
     def start_profiler(
         self, name: str, revision: str, language: str | None, logger: "Logger"
-    ) -> Generator["MetricsProfile", None, None]:
+    ) -> Generator[Any, None, None]:
         """
         Starts a new profiling session for a given profile name.
         Returns a MetricsProfile instance that you can use to mark measurements.
@@ -81,9 +574,9 @@ def fields(cls) -> list[str]:
 class MetricsProfile:
     """
     Context-managed profile that records measurements at each call to `measure()`.
-    It tracks the wall-clock duration, CPU time, and memory usage (with delta) at the time of the call.
-    Upon exiting the context, it also writes all collected metrics, including the total time,
-    to a CSV file.
+    It tracks the wall-clock duration, CPU time, and memory usage (with delta)
+    at the time of the call. Upon exiting the context, it also writes all collected
+    metrics, including the total time, to a CSV file.
     """
 
     if TYPE_CHECKING:
@@ -131,7 +624,9 @@ def measure(self, action_name: str):
         """
         current_time = time.perf_counter()
         current_cpu = float(time.process_time())
-        current_mem = int(psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024))
+        current_mem = int(
+            psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
+        )
 
         # Calculate time deltas.
         delta_time = current_time - self.last_measure_time
@@ -168,7 +663,9 @@ def finish(self, error: str | None = None):
         """
         finish_time = time.perf_counter()
         finish_cpu = float(time.process_time())
-        finish_mem = int(psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024))
+        finish_mem = int(
+            psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
+        )
 
         total_duration = finish_time - self.start_time
 
@@ -196,3 +693,4 @@ def write_output(self, measurement: dict[str, Any]):
     """
     self.logger.info(json.dumps(measurement, indent=4))
     self.output.write_output(measurement)
+