diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4e500b424..bcba375d2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,10 +15,15 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions-cool/check-user-permission@v2 + if: github.triggering_actor != 'codegen-sh[bot]' with: require: write username: ${{ github.triggering_actor }} error-if-missing: true + # Skip permission check for codegen-sh[bot] + - name: Skip permission check for bot + if: github.triggering_actor == 'codegen-sh[bot]' + run: echo "Skipping permission check for codegen-sh[bot]" unit-tests: needs: access-check diff --git a/codegen-on-oss/codegen_on_oss/analysis/README.md b/codegen-on-oss/codegen_on_oss/analysis/README.md new file mode 100644 index 000000000..72c78e9b4 --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analysis/README.md @@ -0,0 +1,172 @@ +# Enhanced Code Analysis Module + +This module provides comprehensive code analysis capabilities for Python codebases, focusing on detailed error detection, function call analysis, and type validation. + +## Features + +### Error Detection + +The error detection system identifies various issues in your code: + +- **Parameter Validation**: Detects unused parameters, parameter count mismatches, and missing required parameters +- **Call Validation**: Validates function call-in and call-out points, detects circular dependencies +- **Return Validation**: Checks for inconsistent return types and values +- **Code Quality**: Identifies unreachable code, overly complex functions, and potential exceptions + +### Function Call Analysis + +The function call analysis provides insights into how functions interact: + +- **Call Graph**: Builds a graph of function calls to visualize dependencies +- **Parameter Usage**: Analyzes how parameters are used within functions +- **Call Statistics**: Identifies most called functions, entry points, and leaf functions +- **Call Chains**: Finds paths between functions and calculates call depths + +### Type Validation + +The type validation system checks for type-related issues: + +- **Type Annotations**: Validates type annotations and identifies missing annotations +- **Type Compatibility**: Checks for type mismatches and inconsistencies +- **Type Inference**: Infers types for variables and expressions where possible + +## Usage + +### Using the CodeAnalyzer + +```python +from codegen import Codebase +from codegen_on_oss.analysis.analysis import CodeAnalyzer + +# Create a codebase from a repository +codebase = Codebase.from_repo("owner/repo") + +# Create an analyzer +analyzer = CodeAnalyzer(codebase) + +# Get comprehensive analysis +results = analyzer.analyze_all() + +# Access specific analysis components +error_analysis = analyzer.analyze_errors() +function_call_analysis = analyzer.analyze_function_calls() +type_analysis = analyzer.analyze_types() +complexity_analysis = analyzer.analyze_complexity() +import_analysis = analyzer.analyze_imports() + +# Get detailed information about specific elements +function = analyzer.find_function_by_name("my_function") +call_graph = analyzer.get_function_call_graph() +callers = call_graph.get_callers("my_function") +callees = call_graph.get_callees("my_function") +``` + +### Using the API + +The module provides a FastAPI-based API for analyzing codebases: + +- `POST /analyze_repo`: Analyze an entire repository +- `POST /analyze_file`: Analyze a specific file +- `POST /analyze_function`: Analyze a specific function +- `POST /analyze_errors`: Get detailed error 
analysis with optional filtering
+
+Example request to analyze a repository:
+
+```json
+{
+  "repo_url": "owner/repo",
+  "branch": "main"
+}
+```
+
+Example request to analyze a specific function:
+
+```json
+{
+  "repo_url": "owner/repo",
+  "function_name": "my_function"
+}
+```
+
+## Error Categories
+
+The error detection system identifies the following categories of errors:
+
+- `PARAMETER_TYPE_MISMATCH`: Parameter type doesn't match the expected type
+- `PARAMETER_COUNT_MISMATCH`: Wrong number of arguments in a function call
+- `UNUSED_PARAMETER`: Parameter is declared but never used
+- `UNDEFINED_PARAMETER`: Parameter is used but not declared
+- `MISSING_REQUIRED_PARAMETER`: Required parameter is missing in a function call
+- `RETURN_TYPE_MISMATCH`: Return value type doesn't match the declared return type
+- `UNDEFINED_VARIABLE`: Variable is used but not defined
+- `UNUSED_IMPORT`: Import is never used
+- `UNUSED_VARIABLE`: Variable is defined but never used
+- `POTENTIAL_EXCEPTION`: Function might raise an exception without proper handling
+- `CALL_POINT_ERROR`: Error at a function call-in or call-out point
+- `CIRCULAR_DEPENDENCY`: Circular dependency between functions
+- `INCONSISTENT_RETURN`: Inconsistent return statements in a function
+- `UNREACHABLE_CODE`: Code that will never be executed
+- `COMPLEX_FUNCTION`: Function with high cyclomatic complexity
+
+## Extending the Analysis
+
+You can extend the analysis capabilities by:
+
+1. Creating a new detector class that inherits from `ErrorDetector`
+2. Implementing the custom analysis logic in its `detect_errors` method
+3. Adding the new detector to the `CodeAnalysisError` class
+
+Example:
+
+```python
+from typing import List
+
+from codegen_on_oss.analysis.error_detection import ErrorDetector, ErrorCategory, ErrorSeverity, CodeError
+
+class MyCustomDetector(ErrorDetector):
+    def detect_errors(self) -> List[CodeError]:
+        self.clear_errors()
+        
+        # Custom detection logic: flag functions with long parameter lists
+        for function in self.codebase.functions:
+            if hasattr(function, "parameters") and len(function.parameters) > 5:
+                self.errors.append(CodeError(
+                    category=ErrorCategory.COMPLEX_FUNCTION,
+                    severity=ErrorSeverity.WARNING,
+                    message=f"Function '{function.name}' takes more than 5 parameters",
+                    file_path=function.filepath,
+                    function_name=function.name
+                ))
+        
+        return self.errors
+```
+
+## Running the Server
+
+To run the analysis API server:
+
+```bash
+python -m codegen_on_oss.analysis.server --host 0.0.0.0 --port 8000
+```
+
+The interactive API documentation is then available at http://localhost:8000/docs.
+
+## Example Script
+
+An example script is provided to demonstrate the usage of the analysis module:
+
+```bash
+python -m codegen_on_oss.analysis.example owner/repo main
+```
+
+This will analyze the specified repository and print the results.
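+
+## Calling the API from Python
+
+Once the server is running, the endpoints can be exercised from any HTTP client. Below is a minimal sketch using the third-party `requests` library; the base URL assumes the local server started in "Running the Server", and `owner/repo` is a placeholder:
+
+```python
+import requests
+
+BASE_URL = "http://localhost:8000"
+
+# Analyze a whole repository
+response = requests.post(
+    f"{BASE_URL}/analyze_repo",
+    json={"repo_url": "owner/repo", "branch": "main"},
+)
+response.raise_for_status()
+report = response.json()
+
+# Fetch only ERROR-severity findings
+errors = requests.post(
+    f"{BASE_URL}/analyze_errors",
+    json={"repo_url": "owner/repo", "severity": "ERROR"},
+).json()
+print(errors["total_errors"])
+```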
+ +## Future Enhancements + +Planned enhancements for the analysis module: + +- Integration with external linters and type checkers +- Machine learning-based error detection +- Interactive visualization of analysis results +- Performance optimization for large codebases +- Support for more programming languages + diff --git a/codegen-on-oss/codegen_on_oss/analysis/analysis.py b/codegen-on-oss/codegen_on_oss/analysis/analysis.py index 9e956ec06..d121605f8 100644 --- a/codegen-on-oss/codegen_on_oss/analysis/analysis.py +++ b/codegen-on-oss/codegen_on_oss/analysis/analysis.py @@ -1,431 +1,688 @@ -from fastapi import FastAPI -from pydantic import BaseModel -from typing import Dict, List, Tuple, Any -from codegen import Codebase -from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement -from codegen.sdk.core.statements.if_block_statement import IfBlockStatement -from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement -from codegen.sdk.core.statements.while_statement import WhileStatement -from codegen.sdk.core.expressions.binary_expression import BinaryExpression -from codegen.sdk.core.expressions.unary_expression import UnaryExpression -from codegen.sdk.core.expressions.comparison_expression import ComparisonExpression -import math -import re -import requests -from datetime import datetime, timedelta -import subprocess +""" +Unified Analysis Module for Codegen-on-OSS + +This module serves as a central hub for all code analysis functionality, integrating +various specialized analysis components into a cohesive system for comprehensive +code analysis, error detection, and validation. +""" + +import json import os +import subprocess import tempfile -from fastapi.middleware.cors import CORSMiddleware -import modal +from datetime import UTC, datetime, timedelta +from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union +from urllib.parse import urlparse -image = ( - modal.Image.debian_slim() - .apt_install("git") - .pip_install( - "codegen", "fastapi", "uvicorn", "gitpython", "requests", "pydantic", "datetime" - ) -) - -app = modal.App(name="analytics-app", image=image) +import networkx as nx +import requests +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel -fastapi_app = FastAPI() +from codegen import Codebase +from codegen.sdk.core.class_definition import Class +from codegen.sdk.core.external_module import ExternalModule +from codegen.sdk.core.file import SourceFile +from codegen.sdk.core.function import Function +from codegen.sdk.core.import_resolution import Import +from codegen.sdk.core.symbol import Symbol +from codegen.sdk.enums import EdgeType, SymbolType + +# Import from other analysis modules +from codegen_on_oss.analysis.codebase_context import CodebaseContext +from codegen_on_oss.analysis.codebase_analysis import ( + get_codebase_summary, + get_file_summary, + get_class_summary, + get_function_summary, + get_symbol_summary +) +from codegen_on_oss.analysis.error_detection import ( + CodeAnalysisError, + ErrorCategory, + ErrorSeverity, + CodeError +) +from codegen_on_oss.analysis.function_call_analysis import ( + FunctionCallAnalysis, + FunctionCallGraph, + ParameterUsageAnalysis +) +from codegen_on_oss.analysis.type_validation import ( + TypeValidation, + TypeValidationError, + TypeAnnotationValidator, + TypeCompatibilityChecker, + TypeInference +) +from codegen_on_oss.analysis.analysis_import import ( + create_graph_from_codebase, + find_import_cycles, + find_problematic_import_loops +) 
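+
+# Usage sketch (illustrative, not executed at import time): the CodeAnalyzer
+# defined below wraps a Codebase and lazily exposes the specialized analyzers,
+# so a typical caller only needs:
+#
+#     codebase = Codebase.from_repo("owner/repo")
+#     analyzer = CodeAnalyzer(codebase)
+#     report = analyzer.analyze_all()
+#     print(report["error_analysis"]["total_errors"])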
-fastapi_app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], +# Create FastAPI app +app = FastAPI( + title="Code Analysis API", + description="API for comprehensive code analysis, error detection, and validation", + version="1.0.0" ) -def get_monthly_commits(repo_path: str) -> Dict[str, int]: +class CodeAnalyzer: """ - Get the number of commits per month for the last 12 months. - - Args: - repo_path: Path to the git repository - - Returns: - Dictionary with month-year as key and number of commits as value + Central class for code analysis that integrates all analysis components. + + This class serves as the main entry point for all code analysis functionality, + providing a unified interface to access various analysis capabilities including + error detection, function call analysis, and type validation. """ - end_date = datetime.now() - start_date = end_date - timedelta(days=365) - - date_format = "%Y-%m-%d" - since_date = start_date.strftime(date_format) - until_date = end_date.strftime(date_format) - repo_path = "https://github.com/" + repo_path - - try: - original_dir = os.getcwd() - - with tempfile.TemporaryDirectory() as temp_dir: - subprocess.run(["git", "clone", repo_path, temp_dir], check=True) - os.chdir(temp_dir) - - cmd = [ - "git", - "log", - f"--since={since_date}", - f"--until={until_date}", - "--format=%aI", - ] - - result = subprocess.run(cmd, capture_output=True, text=True, check=True) - commit_dates = result.stdout.strip().split("\n") - - monthly_counts = {} - current_date = start_date - while current_date <= end_date: - month_key = current_date.strftime("%Y-%m") - monthly_counts[month_key] = 0 - current_date = ( - current_date.replace(day=1) + timedelta(days=32) - ).replace(day=1) - - for date_str in commit_dates: - if date_str: # Skip empty lines - commit_date = datetime.fromisoformat(date_str.strip()) - month_key = commit_date.strftime("%Y-%m") - if month_key in monthly_counts: - monthly_counts[month_key] += 1 - - os.chdir(original_dir) - return dict(sorted(monthly_counts.items())) - - except subprocess.CalledProcessError as e: - print(f"Error executing git command: {e}") - return {} - except Exception as e: - print(f"Error processing git commits: {e}") - return {} - finally: - try: - os.chdir(original_dir) - except: - pass - - -def calculate_cyclomatic_complexity(function): - def analyze_statement(statement): - complexity = 0 - - if isinstance(statement, IfBlockStatement): - complexity += 1 - if hasattr(statement, "elif_statements"): - complexity += len(statement.elif_statements) - - elif isinstance(statement, (ForLoopStatement, WhileStatement)): - complexity += 1 - - elif isinstance(statement, TryCatchStatement): - complexity += len(getattr(statement, "except_blocks", [])) - - if hasattr(statement, "condition") and isinstance(statement.condition, str): - complexity += statement.condition.count( - " and " - ) + statement.condition.count(" or ") - - if hasattr(statement, "nested_code_blocks"): - for block in statement.nested_code_blocks: - complexity += analyze_block(block) - - return complexity - - def analyze_block(block): - if not block or not hasattr(block, "statements"): - return 0 - return sum(analyze_statement(stmt) for stmt in block.statements) - - return ( - 1 + analyze_block(function.code_block) if hasattr(function, "code_block") else 1 - ) - - -def cc_rank(complexity): - if complexity < 0: - raise ValueError("Complexity must be a non-negative value") - - ranks = [ - (1, 5, "A"), 
- (6, 10, "B"), - (11, 20, "C"), - (21, 30, "D"), - (31, 40, "E"), - (41, float("inf"), "F"), - ] - for low, high, rank in ranks: - if low <= complexity <= high: - return rank - return "F" - - -def calculate_doi(cls): - """Calculate the depth of inheritance for a given class.""" - return len(cls.superclasses) - - -def get_operators_and_operands(function): - operators = [] - operands = [] - - for statement in function.code_block.statements: - for call in statement.function_calls: - operators.append(call.name) - for arg in call.args: - operands.append(arg.source) - - if hasattr(statement, "expressions"): - for expr in statement.expressions: - if isinstance(expr, BinaryExpression): - operators.extend([op.source for op in expr.operators]) - operands.extend([elem.source for elem in expr.elements]) - elif isinstance(expr, UnaryExpression): - operators.append(expr.ts_node.type) - operands.append(expr.argument.source) - elif isinstance(expr, ComparisonExpression): - operators.extend([op.source for op in expr.operators]) - operands.extend([elem.source for elem in expr.elements]) - - if hasattr(statement, "expression"): - expr = statement.expression - if isinstance(expr, BinaryExpression): - operators.extend([op.source for op in expr.operators]) - operands.extend([elem.source for elem in expr.elements]) - elif isinstance(expr, UnaryExpression): - operators.append(expr.ts_node.type) - operands.append(expr.argument.source) - elif isinstance(expr, ComparisonExpression): - operators.extend([op.source for op in expr.operators]) - operands.extend([elem.source for elem in expr.elements]) - - return operators, operands - - -def calculate_halstead_volume(operators, operands): - n1 = len(set(operators)) - n2 = len(set(operands)) - - N1 = len(operators) - N2 = len(operands) - - N = N1 + N2 - n = n1 + n2 - - if n > 0: - volume = N * math.log2(n) - return volume, N1, N2, n1, n2 - return 0, N1, N2, n1, n2 - - -def count_lines(source: str): - """Count different types of lines in source code.""" - if not source.strip(): - return 0, 0, 0, 0 - - lines = [line.strip() for line in source.splitlines()] - loc = len(lines) - sloc = len([line for line in lines if line]) - - in_multiline = False - comments = 0 - code_lines = [] - - i = 0 - while i < len(lines): - line = lines[i] - code_part = line - if not in_multiline and "#" in line: - comment_start = line.find("#") - if not re.search(r'["\'].*#.*["\']', line[:comment_start]): - code_part = line[:comment_start].strip() - if line[comment_start:].strip(): - comments += 1 - - if ('"""' in line or "'''" in line) and not ( - line.count('"""') % 2 == 0 or line.count("'''") % 2 == 0 - ): - if in_multiline: - in_multiline = False - comments += 1 - else: - in_multiline = True - comments += 1 - if line.strip().startswith('"""') or line.strip().startswith("'''"): - code_part = "" - elif in_multiline: - comments += 1 - code_part = "" - elif line.strip().startswith("#"): - comments += 1 - code_part = "" - - if code_part.strip(): - code_lines.append(code_part) - - i += 1 - - lloc = 0 - continued_line = False - for line in code_lines: - if continued_line: - if not any(line.rstrip().endswith(c) for c in ("\\", ",", "{", "[", "(")): - continued_line = False - continue - - lloc += len([stmt for stmt in line.split(";") if stmt.strip()]) - - if any(line.rstrip().endswith(c) for c in ("\\", ",", "{", "[", "(")): - continued_line = True - - return loc, lloc, sloc, comments - - -def calculate_maintainability_index( - halstead_volume: float, cyclomatic_complexity: float, loc: int -) -> int: - 
"""Calculate the normalized maintainability index for a given function.""" - if loc <= 0: - return 100 - - try: - raw_mi = ( - 171 - - 5.2 * math.log(max(1, halstead_volume)) - - 0.23 * cyclomatic_complexity - - 16.2 * math.log(max(1, loc)) + + def __init__(self, codebase: Codebase): + """ + Initialize the CodeAnalyzer with a codebase. + + Args: + codebase: The Codebase object to analyze + """ + self.codebase = codebase + self._context = None + self._initialized = False + + # Initialize analysis components + self._error_analyzer = None + self._function_call_analyzer = None + self._type_validator = None + + def initialize(self): + """ + Initialize the analyzer by setting up the context and other necessary components. + This is called automatically when needed but can be called explicitly for eager initialization. + """ + if self._initialized: + return + + # Initialize context if not already done + if self._context is None: + self._context = self._create_context() + + self._initialized = True + + def _create_context(self) -> CodebaseContext: + """ + Create a CodebaseContext instance for the current codebase. + + Returns: + A new CodebaseContext instance + """ + # If the codebase already has a context, use it + if hasattr(self.codebase, "ctx") and self.codebase.ctx is not None: + return self.codebase.ctx + + # Otherwise, create a new context from the codebase's configuration + from codegen.sdk.codebase.config import ProjectConfig + from codegen.configs.models.codebase import CodebaseConfig + + # Create a project config from the codebase + project_config = ProjectConfig( + repo_operator=self.codebase.repo_operator, + programming_language=self.codebase.programming_language, + base_path=self.codebase.base_path ) - normalized_mi = max(0, min(100, raw_mi * 100 / 171)) - return int(normalized_mi) - except (ValueError, TypeError): - return 0 - - -def get_maintainability_rank(mi_score: float) -> str: - """Convert maintainability index score to a letter grade.""" - if mi_score >= 85: - return "A" - elif mi_score >= 65: - return "B" - elif mi_score >= 45: - return "C" - elif mi_score >= 25: - return "D" - else: - return "F" - - -def get_github_repo_description(repo_url): - api_url = f"https://api.github.com/repos/{repo_url}" - - response = requests.get(api_url) - - if response.status_code == 200: - repo_data = response.json() - return repo_data.get("description", "No description available") - else: - return "" - - -class RepoRequest(BaseModel): + + # Create and return a new context + return CodebaseContext([project_config], config=CodebaseConfig()) + + @property + def context(self) -> CodebaseContext: + """ + Get the CodebaseContext for the current codebase. + + Returns: + A CodebaseContext object for the codebase + """ + if not self._initialized: + self.initialize() + + return self._context + + @property + def error_analyzer(self) -> CodeAnalysisError: + """ + Get the CodeAnalysisError instance for error detection. + + Returns: + A CodeAnalysisError instance + """ + if self._error_analyzer is None: + self._error_analyzer = CodeAnalysisError(self.codebase) + return self._error_analyzer + + @property + def function_call_analyzer(self) -> FunctionCallAnalysis: + """ + Get the FunctionCallAnalysis instance for function call analysis. 
+ + Returns: + A FunctionCallAnalysis instance + """ + if self._function_call_analyzer is None: + self._function_call_analyzer = FunctionCallAnalysis(self.codebase) + return self._function_call_analyzer + + @property + def type_validator(self) -> TypeValidation: + """ + Get the TypeValidation instance for type validation. + + Returns: + A TypeValidation instance + """ + if self._type_validator is None: + self._type_validator = TypeValidation(self.codebase) + return self._type_validator + + def get_codebase_summary(self) -> str: + """ + Get a comprehensive summary of the codebase. + + Returns: + A string containing summary information about the codebase + """ + return get_codebase_summary(self.codebase) + + def get_file_summary(self, file_path: str) -> str: + """ + Get a summary of a specific file. + + Args: + file_path: Path to the file to analyze + + Returns: + A string containing summary information about the file + """ + file = self.codebase.get_file(file_path) + if file is None: + return f"File not found: {file_path}" + return get_file_summary(file) + + def get_class_summary(self, class_name: str) -> str: + """ + Get a summary of a specific class. + + Args: + class_name: Name of the class to analyze + + Returns: + A string containing summary information about the class + """ + for cls in self.codebase.classes: + if cls.name == class_name: + return get_class_summary(cls) + return f"Class not found: {class_name}" + + def get_function_summary(self, function_name: str) -> str: + """ + Get a summary of a specific function. + + Args: + function_name: Name of the function to analyze + + Returns: + A string containing summary information about the function + """ + for func in self.codebase.functions: + if func.name == function_name: + return get_function_summary(func) + return f"Function not found: {function_name}" + + def get_symbol_summary(self, symbol_name: str) -> str: + """ + Get a summary of a specific symbol. + + Args: + symbol_name: Name of the symbol to analyze + + Returns: + A string containing summary information about the symbol + """ + for symbol in self.codebase.symbols: + if symbol.name == symbol_name: + return get_symbol_summary(symbol) + return f"Symbol not found: {symbol_name}" + + def find_symbol_by_name(self, symbol_name: str) -> Optional[Symbol]: + """ + Find a symbol by its name. + + Args: + symbol_name: Name of the symbol to find + + Returns: + The Symbol object if found, None otherwise + """ + for symbol in self.codebase.symbols: + if symbol.name == symbol_name: + return symbol + return None + + def find_file_by_path(self, file_path: str) -> Optional[SourceFile]: + """ + Find a file by its path. + + Args: + file_path: Path to the file to find + + Returns: + The SourceFile object if found, None otherwise + """ + return self.codebase.get_file(file_path) + + def find_class_by_name(self, class_name: str) -> Optional[Class]: + """ + Find a class by its name. + + Args: + class_name: Name of the class to find + + Returns: + The Class object if found, None otherwise + """ + for cls in self.codebase.classes: + if cls.name == class_name: + return cls + return None + + def find_function_by_name(self, function_name: str) -> Optional[Function]: + """ + Find a function by its name. 
+ + Args: + function_name: Name of the function to find + + Returns: + The Function object if found, None otherwise + """ + for func in self.codebase.functions: + if func.name == function_name: + return func + return None + + def analyze_imports(self) -> Dict[str, Any]: + """ + Analyze import relationships in the codebase. + + Returns: + A dictionary containing import analysis results + """ + graph = create_graph_from_codebase(self.codebase.repo_name) + cycles = find_import_cycles(graph) + problematic_loops = find_problematic_import_loops(graph, cycles) + + return { + "import_cycles": cycles, + "problematic_loops": problematic_loops + } + + def analyze_errors(self, category: Optional[str] = None, severity: Optional[str] = None) -> Dict[str, Any]: + """ + Analyze the codebase for errors. + + Args: + category: Optional error category to filter by + severity: Optional error severity to filter by + + Returns: + A dictionary containing error analysis results + """ + # Get all errors + all_errors = self.error_analyzer.analyze() + + # Filter by category if specified + if category: + try: + category_enum = ErrorCategory[category] + all_errors = [error for error in all_errors if error.category == category_enum] + except KeyError: + pass + + # Filter by severity if specified + if severity: + try: + severity_enum = ErrorSeverity[severity] + all_errors = [error for error in all_errors if error.severity == severity_enum] + except KeyError: + pass + + # Convert errors to dictionaries + error_dicts = [error.to_dict() for error in all_errors] + + # Get error summary + error_summary = self.error_analyzer.get_error_summary() + severity_summary = self.error_analyzer.get_severity_summary() + + return { + "errors": error_dicts, + "error_summary": error_summary, + "severity_summary": severity_summary, + "total_errors": len(all_errors) + } + + def analyze_function_calls(self, function_name: Optional[str] = None) -> Dict[str, Any]: + """ + Analyze function calls in the codebase. + + Args: + function_name: Optional name of a specific function to analyze + + Returns: + A dictionary containing function call analysis results + """ + if function_name: + # Analyze a specific function + return self.function_call_analyzer.analyze_function_dependencies(function_name) + else: + # Analyze all functions + return self.function_call_analyzer.analyze_all() + + def analyze_types(self, function_name: Optional[str] = None) -> Dict[str, Any]: + """ + Analyze type annotations and compatibility in the codebase. 
+ + Args: + function_name: Optional name of a specific function to analyze + + Returns: + A dictionary containing type analysis results + """ + if function_name: + # Find the function + func = self.find_function_by_name(function_name) + if not func: + return {"error": f"Function {function_name} not found"} + + # Analyze the function + annotation_errors = self.type_validator.annotation_validator.validate_function_annotations(func) + compatibility_errors = self.type_validator.compatibility_checker.check_assignment_compatibility(func) + compatibility_errors.extend(self.type_validator.compatibility_checker.check_return_compatibility(func)) + compatibility_errors.extend(self.type_validator.compatibility_checker.check_parameter_compatibility(func)) + inferred_types = self.type_validator.type_inference.infer_variable_types(func) + + return { + "function_name": function_name, + "annotation_errors": [error.to_dict() for error in annotation_errors], + "compatibility_errors": [error.to_dict() for error in compatibility_errors], + "inferred_types": inferred_types + } + else: + # Analyze all types + return self.type_validator.validate_all() + + def analyze_complexity(self) -> Dict[str, Any]: + """ + Analyze code complexity metrics for the codebase. + + Returns: + A dictionary containing complexity metrics + """ + # Get complex functions from error analysis + complex_function_errors = self.error_analyzer.analyze_by_category(ErrorCategory.COMPLEX_FUNCTION) + complex_functions = [ + { + "name": error.function_name, + "file_path": error.file_path, + "message": error.message + } + for error in complex_function_errors + ] + + # Get call graph complexity from function call analysis + call_graph = self.function_call_analyzer.call_graph + most_complex = call_graph.get_most_complex_functions() + most_called = call_graph.get_most_called_functions() + + return { + "complex_functions": complex_functions, + "most_complex_by_calls": most_complex, + "most_called_functions": most_called, + "circular_dependencies": call_graph.get_circular_dependencies() + } + + def get_function_call_graph(self) -> FunctionCallGraph: + """ + Get the function call graph for the codebase. + + Returns: + A FunctionCallGraph instance + """ + return self.function_call_analyzer.call_graph + + def analyze_file(self, file_path: str) -> Dict[str, Any]: + """ + Analyze a specific file. 
+ + Args: + file_path: Path to the file to analyze + + Returns: + A dictionary containing analysis results for the file + """ + file = self.find_file_by_path(file_path) + if not file: + return {"error": f"File {file_path} not found"} + + # Get file summary + summary = get_file_summary(file) + + # Get errors in the file + errors = self.error_analyzer.analyze_file(file_path) + error_dicts = [error.to_dict() for error in errors] + + # Get functions in the file + functions = [] + for func in self.codebase.functions: + if func.filepath == file_path: + functions.append({ + "name": func.name, + "parameters": [p.name for p in func.parameters] if hasattr(func, "parameters") else [], + "return_type": func.return_type if hasattr(func, "return_type") else None + }) + + # Get classes in the file + classes = [] + for cls in self.codebase.classes: + if cls.filepath == file_path: + classes.append({ + "name": cls.name, + "methods": [m.name for m in cls.methods] if hasattr(cls, "methods") else [], + "attributes": [a.name for a in cls.attributes] if hasattr(cls, "attributes") else [] + }) + + # Get imports in the file + imports = [] + if hasattr(file, "imports"): + for imp in file.imports: + if hasattr(imp, "source"): + imports.append(imp.source) + + return { + "file_path": file_path, + "summary": summary, + "errors": error_dicts, + "functions": functions, + "classes": classes, + "imports": imports + } + + def analyze_function(self, function_name: str) -> Dict[str, Any]: + """ + Analyze a specific function. + + Args: + function_name: Name of the function to analyze + + Returns: + A dictionary containing analysis results for the function + """ + func = self.find_function_by_name(function_name) + if not func: + return {"error": f"Function {function_name} not found"} + + # Get function summary + summary = get_function_summary(func) + + # Get errors in the function + errors = self.error_analyzer.analyze_function(function_name) + error_dicts = [error.to_dict() for error in errors] + + # Get function call analysis + call_analysis = self.function_call_analyzer.analyze_function_dependencies(function_name) + + # Get parameter usage analysis + param_analysis = self.function_call_analyzer.parameter_usage.analyze_parameter_usage(function_name) + + # Get type analysis + type_analysis = self.analyze_types(function_name) + + return { + "function_name": function_name, + "file_path": func.filepath, + "summary": summary, + "errors": error_dicts, + "call_analysis": call_analysis, + "parameter_analysis": param_analysis, + "type_analysis": type_analysis + } + + def analyze_all(self) -> Dict[str, Any]: + """ + Perform comprehensive analysis of the codebase. 
+ + Returns: + A dictionary containing all analysis results + """ + return { + "codebase_summary": self.get_codebase_summary(), + "error_analysis": self.analyze_errors(), + "function_call_analysis": self.analyze_function_calls(), + "type_analysis": self.analyze_types(), + "complexity_analysis": self.analyze_complexity(), + "import_analysis": self.analyze_imports() + } + + +# API Models +class AnalyzeRepoRequest(BaseModel): repo_url: str + branch: Optional[str] = None -@fastapi_app.post("/analyze_repo") -async def analyze_repo(request: RepoRequest) -> Dict[str, Any]: - """Analyze a repository and return comprehensive metrics.""" - repo_url = request.repo_url - codebase = Codebase.from_repo(repo_url) +class AnalyzeFileRequest(BaseModel): + repo_url: str + file_path: str + branch: Optional[str] = None - num_files = len(codebase.files(extensions="*")) - num_functions = len(codebase.functions) - num_classes = len(codebase.classes) - total_loc = total_lloc = total_sloc = total_comments = 0 - total_complexity = 0 - total_volume = 0 - total_mi = 0 - total_doi = 0 +class AnalyzeFunctionRequest(BaseModel): + repo_url: str + function_name: str + branch: Optional[str] = None - monthly_commits = get_monthly_commits(repo_url) - print(monthly_commits) - for file in codebase.files: - loc, lloc, sloc, comments = count_lines(file.source) - total_loc += loc - total_lloc += lloc - total_sloc += sloc - total_comments += comments +class AnalyzeErrorsRequest(BaseModel): + repo_url: str + category: Optional[str] = None + severity: Optional[str] = None + branch: Optional[str] = None - callables = codebase.functions + [m for c in codebase.classes for m in c.methods] - num_callables = 0 - for func in callables: - if not hasattr(func, "code_block"): - continue +# Helper function to get codebase from repo URL +def get_codebase_from_url(repo_url: str, branch: Optional[str] = None) -> Codebase: + """ + Get a Codebase object from a repository URL. + + Args: + repo_url: URL of the repository to analyze + branch: Optional branch to analyze + + Returns: + A Codebase object + """ + try: + if branch: + return Codebase.from_repo(repo_url, branch=branch) + else: + return Codebase.from_repo(repo_url) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Failed to load repository: {str(e)}") - complexity = calculate_cyclomatic_complexity(func) - operators, operands = get_operators_and_operands(func) - volume, _, _, _, _ = calculate_halstead_volume(operators, operands) - loc = len(func.code_block.source.splitlines()) - mi_score = calculate_maintainability_index(volume, complexity, loc) - total_complexity += complexity - total_volume += volume - total_mi += mi_score - num_callables += 1 +# API Routes +@app.post("/analyze_repo") +async def analyze_repo(request: AnalyzeRepoRequest): + """ + Analyze an entire repository. + + Args: + request: AnalyzeRepoRequest object + + Returns: + Analysis results for the repository + """ + codebase = get_codebase_from_url(request.repo_url, request.branch) + analyzer = CodeAnalyzer(codebase) + return analyzer.analyze_all() - for cls in codebase.classes: - doi = calculate_doi(cls) - total_doi += doi - desc = get_github_repo_description(repo_url) +@app.post("/analyze_file") +async def analyze_file(request: AnalyzeFileRequest): + """ + Analyze a specific file in a repository. 
+ + Args: + request: AnalyzeFileRequest object + + Returns: + Analysis results for the file + """ + codebase = get_codebase_from_url(request.repo_url, request.branch) + analyzer = CodeAnalyzer(codebase) + return analyzer.analyze_file(request.file_path) - results = { - "repo_url": repo_url, - "line_metrics": { - "total": { - "loc": total_loc, - "lloc": total_lloc, - "sloc": total_sloc, - "comments": total_comments, - "comment_density": (total_comments / total_loc * 100) - if total_loc > 0 - else 0, - }, - }, - "cyclomatic_complexity": { - "average": total_complexity if num_callables > 0 else 0, - }, - "depth_of_inheritance": { - "average": total_doi / len(codebase.classes) if codebase.classes else 0, - }, - "halstead_metrics": { - "total_volume": int(total_volume), - "average_volume": int(total_volume / num_callables) - if num_callables > 0 - else 0, - }, - "maintainability_index": { - "average": int(total_mi / num_callables) if num_callables > 0 else 0, - }, - "description": desc, - "num_files": num_files, - "num_functions": num_functions, - "num_classes": num_classes, - "monthly_commits": monthly_commits, - } - return results +@app.post("/analyze_function") +async def analyze_function(request: AnalyzeFunctionRequest): + """ + Analyze a specific function in a repository. + + Args: + request: AnalyzeFunctionRequest object + + Returns: + Analysis results for the function + """ + codebase = get_codebase_from_url(request.repo_url, request.branch) + analyzer = CodeAnalyzer(codebase) + return analyzer.analyze_function(request.function_name) -@app.function(image=image) -@modal.asgi_app() -def fastapi_modal_app(): - return fastapi_app +@app.post("/analyze_errors") +async def analyze_errors(request: AnalyzeErrorsRequest): + """ + Analyze errors in a repository. + + Args: + request: AnalyzeErrorsRequest object + + Returns: + Error analysis results for the repository + """ + codebase = get_codebase_from_url(request.repo_url, request.branch) + analyzer = CodeAnalyzer(codebase) + return analyzer.analyze_errors(request.category, request.severity) if __name__ == "__main__": - app.deploy("analytics-app") + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) + diff --git a/codegen-on-oss/codegen_on_oss/analysis/codebase_context.py b/codegen-on-oss/codegen_on_oss/analysis/codebase_context.py index 5c0fd47dd..c092356b7 100644 --- a/codegen-on-oss/codegen_on_oss/analysis/codebase_context.py +++ b/codegen-on-oss/codegen_on_oss/analysis/codebase_context.py @@ -121,7 +121,6 @@ class CodebaseContext: dependency_manager: DependencyManager | None language_engine: LanguageEngine | None _computing = False - _graph: PyDiGraph[Importable, Edge] filepath_idx: dict[str, NodeId] _ext_module_idx: dict[str, NodeId] flags: Flags @@ -143,8 +142,6 @@ def __init__( from codegen.sdk.core.parser import Parser self.progress = progress or StubProgress() - self.__graph = PyDiGraph() - self.__graph_ready = False self.filepath_idx = {} self._ext_module_idx = {} self.generation = 0 diff --git a/codegen-on-oss/codegen_on_oss/analysis/error_detection.py b/codegen-on-oss/codegen_on_oss/analysis/error_detection.py new file mode 100644 index 000000000..ab3a4c53f --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analysis/error_detection.py @@ -0,0 +1,783 @@ +""" +Error detection module for code analysis. + +This module provides classes and functions for detecting errors in code, +including parameter errors, type errors, and call-in/call-out point errors. 
+""" + +from enum import Enum, auto +from typing import Dict, List, Optional, Set, Tuple, Union, Any + +from codegen import Codebase +from codegen.sdk.core.function import Function +from codegen.sdk.core.class_definition import Class +from codegen.sdk.core.symbol import Symbol + + +class ErrorSeverity(Enum): + """Severity levels for code errors.""" + INFO = auto() + WARNING = auto() + ERROR = auto() + CRITICAL = auto() + + +class ErrorCategory(Enum): + """Categories of code errors.""" + PARAMETER_TYPE_MISMATCH = auto() + PARAMETER_COUNT_MISMATCH = auto() + UNUSED_PARAMETER = auto() + UNDEFINED_PARAMETER = auto() + MISSING_REQUIRED_PARAMETER = auto() + RETURN_TYPE_MISMATCH = auto() + UNDEFINED_VARIABLE = auto() + UNUSED_IMPORT = auto() + UNUSED_VARIABLE = auto() + POTENTIAL_EXCEPTION = auto() + CALL_POINT_ERROR = auto() + CIRCULAR_DEPENDENCY = auto() + INCONSISTENT_RETURN = auto() + UNREACHABLE_CODE = auto() + COMPLEX_FUNCTION = auto() + + +class CodeError: + """Represents an error detected in the code.""" + + def __init__( + self, + category: ErrorCategory, + severity: ErrorSeverity, + message: str, + file_path: str, + line_number: Optional[int] = None, + column_number: Optional[int] = None, + function_name: Optional[str] = None, + class_name: Optional[str] = None, + code_snippet: Optional[str] = None, + related_symbols: Optional[List[str]] = None, + fix_suggestion: Optional[str] = None + ): + """ + Initialize a CodeError. + + Args: + category: The category of the error + severity: The severity level of the error + message: A descriptive message about the error + file_path: Path to the file containing the error + line_number: Line number where the error occurs (optional) + column_number: Column number where the error occurs (optional) + function_name: Name of the function containing the error (optional) + class_name: Name of the class containing the error (optional) + code_snippet: A snippet of the code containing the error (optional) + related_symbols: List of symbol names related to the error (optional) + fix_suggestion: A suggestion for fixing the error (optional) + """ + self.category = category + self.severity = severity + self.message = message + self.file_path = file_path + self.line_number = line_number + self.column_number = column_number + self.function_name = function_name + self.class_name = class_name + self.code_snippet = code_snippet + self.related_symbols = related_symbols or [] + self.fix_suggestion = fix_suggestion + + def to_dict(self) -> Dict[str, Any]: + """Convert the error to a dictionary.""" + return { + "category": self.category.name, + "severity": self.severity.name, + "message": self.message, + "file_path": self.file_path, + "line_number": self.line_number, + "column_number": self.column_number, + "function_name": self.function_name, + "class_name": self.class_name, + "code_snippet": self.code_snippet, + "related_symbols": self.related_symbols, + "fix_suggestion": self.fix_suggestion + } + + def __str__(self) -> str: + """String representation of the error.""" + location = f"{self.file_path}" + if self.line_number: + location += f":{self.line_number}" + if self.column_number: + location += f":{self.column_number}" + + context = "" + if self.function_name: + context += f" in function '{self.function_name}'" + if self.class_name: + context += f" in class '{self.class_name}'" + + return f"[{self.severity.name}] {self.category.name}: {self.message} at {location}{context}" + + +class ErrorDetector: + """Base class for error detectors.""" + + def __init__(self, 
codebase: Codebase):
+        """
+        Initialize the error detector.
+        
+        Args:
+            codebase: The codebase to analyze
+        """
+        self.codebase = codebase
+        self.errors: List[CodeError] = []
+    
+    def detect_errors(self) -> List[CodeError]:
+        """
+        Detect errors in the codebase.
+        
+        Returns:
+            A list of detected errors
+        """
+        raise NotImplementedError("Subclasses must implement detect_errors")
+    
+    def clear_errors(self) -> None:
+        """Clear the list of detected errors."""
+        self.errors = []
+
+
+class ParameterErrorDetector(ErrorDetector):
+    """Detector for parameter-related errors."""
+    
+    def detect_errors(self) -> List[CodeError]:
+        """
+        Detect parameter-related errors in the codebase.
+        
+        Returns:
+            A list of detected parameter errors
+        """
+        self.clear_errors()
+        
+        for func in self.codebase.functions:
+            # Check for unused parameters
+            self._check_unused_parameters(func)
+            
+            # Check for parameter count mismatches in function calls
+            self._check_parameter_count_mismatches(func)
+            
+            # Check for missing required parameters
+            self._check_missing_required_parameters(func)
+            
+            # Check for parameter type mismatches
+            self._check_parameter_type_mismatches(func)
+        
+        return self.errors
+    
+    def _check_unused_parameters(self, func: Function) -> None:
+        """Check for unused parameters in a function."""
+        if not hasattr(func, "parameters") or not hasattr(func, "code_block"):
+            return
+        
+        for param in func.parameters:
+            # Skip self parameter in methods
+            if param.name == "self" and hasattr(func, "parent") and isinstance(func.parent, Class):
+                continue
+            
+            # Check if parameter is used in the function body
+            if hasattr(func, "code_block") and func.code_block and hasattr(func.code_block, "source"):
+                source = func.code_block.source
+                # Simple substring check - this could be improved with AST analysis
+                if param.name not in source:
+                    self.errors.append(CodeError(
+                        category=ErrorCategory.UNUSED_PARAMETER,
+                        severity=ErrorSeverity.WARNING,
+                        message=f"Parameter '{param.name}' is declared but never used",
+                        file_path=func.filepath,
+                        function_name=func.name,
+                        fix_suggestion=f"Remove the unused parameter '{param.name}' or use it in the function body"
+                    ))
+    
+    def _check_parameter_count_mismatches(self, func: Function) -> None:
+        """Check for parameter count mismatches in function calls."""
+        if not hasattr(func, "function_calls"):
+            return
+        
+        for call in func.function_calls:
+            if hasattr(call, "target") and hasattr(call.target, "parameters"):
+                expected_count = len(call.target.parameters)
+                # Parameters with default values may be omitted at the call site
+                required_count = len([p for p in call.target.parameters if not p.has_default_value])
+                actual_count = len(call.arguments)
+                
+                # Account for self parameter in method calls
+                if hasattr(call.target, "parent") and isinstance(call.target.parent, Class):
+                    expected_count -= 1
+                    required_count = max(0, required_count - 1)
+                
+                # Account for *args and **kwargs
+                has_args = any(p.name == "args" and p.is_variadic for p in call.target.parameters)
+                has_kwargs = any(p.name == "kwargs" and p.is_keyword_variadic for p in call.target.parameters)
+                
+                if not has_args and not has_kwargs and not (required_count <= actual_count <= expected_count):
+                    self.errors.append(CodeError(
+                        category=ErrorCategory.PARAMETER_COUNT_MISMATCH,
+                        severity=ErrorSeverity.ERROR,
+                        message=f"Function call has {actual_count} arguments but between {required_count} and {expected_count} were expected",
+                        file_path=func.filepath,
+                        function_name=func.name,
+                        related_symbols=[call.target.name],
+                        fix_suggestion="Adjust the number of arguments to match the function signature"
+                    ))
+    
+    def _check_missing_required_parameters(self, func: Function) -> None:
+        """Check for missing required parameters in function calls."""
+        if not hasattr(func, 
"function_calls"): + return + + for call in func.function_calls: + if hasattr(call, "target") and hasattr(call.target, "parameters"): + # Get required parameters (those without default values) + required_params = [p.name for p in call.target.parameters if not p.has_default_value] + + # Skip self parameter in method calls + if hasattr(call.target, "parent") and isinstance(call.target.parent, Class): + if "self" in required_params: + required_params.remove("self") + + # Check if all required parameters are provided + provided_params = [arg.name for arg in call.arguments if hasattr(arg, "name")] + + for param in required_params: + if param not in provided_params: + self.errors.append(CodeError( + category=ErrorCategory.MISSING_REQUIRED_PARAMETER, + severity=ErrorSeverity.ERROR, + message=f"Required parameter '{param}' is missing in function call", + file_path=func.filepath, + function_name=func.name, + related_symbols=[call.target.name], + fix_suggestion=f"Add the required parameter '{param}' to the function call" + )) + + def _check_parameter_type_mismatches(self, func: Function) -> None: + """Check for parameter type mismatches in function calls.""" + if not hasattr(func, "function_calls"): + return + + for call in func.function_calls: + if hasattr(call, "target") and hasattr(call.target, "parameters"): + for i, arg in enumerate(call.arguments): + if i < len(call.target.parameters) and hasattr(arg, "type_annotation") and hasattr(call.target.parameters[i], "type_annotation"): + arg_type = arg.type_annotation + param_type = call.target.parameters[i].type_annotation + + if arg_type and param_type and arg_type != param_type: + self.errors.append(CodeError( + category=ErrorCategory.PARAMETER_TYPE_MISMATCH, + severity=ErrorSeverity.WARNING, + message=f"Argument type '{arg_type}' does not match parameter type '{param_type}'", + file_path=func.filepath, + function_name=func.name, + related_symbols=[call.target.name], + fix_suggestion=f"Convert the argument to the expected type '{param_type}'" + )) + + +class ReturnErrorDetector(ErrorDetector): + """Detector for return-related errors.""" + + def detect_errors(self) -> List[CodeError]: + """ + Detect return-related errors in the codebase. 
+ + Returns: + A list of detected return errors + """ + self.clear_errors() + + for func in self.codebase.functions: + # Check for return type mismatches + self._check_return_type_mismatches(func) + + # Check for inconsistent return statements + self._check_inconsistent_returns(func) + + return self.errors + + def _check_return_type_mismatches(self, func: Function) -> None: + """Check for return type mismatches in a function.""" + if not hasattr(func, "return_statements") or not hasattr(func, "return_type"): + return + + for ret in func.return_statements: + if hasattr(ret, "value") and hasattr(ret.value, "type_annotation") and func.return_type: + ret_type = ret.value.type_annotation + + if ret_type and ret_type != func.return_type: + self.errors.append(CodeError( + category=ErrorCategory.RETURN_TYPE_MISMATCH, + severity=ErrorSeverity.WARNING, + message=f"Return value type '{ret_type}' does not match declared return type '{func.return_type}'", + file_path=func.filepath, + function_name=func.name, + fix_suggestion=f"Convert the return value to the declared type '{func.return_type}'" + )) + + def _check_inconsistent_returns(self, func: Function) -> None: + """Check for inconsistent return statements in a function.""" + if not hasattr(func, "return_statements"): + return + + # Check if some paths return values and others don't + has_value_returns = any(hasattr(ret, "value") and ret.value for ret in func.return_statements) + has_void_returns = any(not hasattr(ret, "value") or not ret.value for ret in func.return_statements) + + if has_value_returns and has_void_returns: + self.errors.append(CodeError( + category=ErrorCategory.INCONSISTENT_RETURN, + severity=ErrorSeverity.ERROR, + message=f"Function has inconsistent return statements (some with values, some without)", + file_path=func.filepath, + function_name=func.name, + fix_suggestion=f"Ensure all return statements consistently return values or None" + )) + + +class CallGraphErrorDetector(ErrorDetector): + """Detector for call graph related errors.""" + + def detect_errors(self) -> List[CodeError]: + """ + Detect call graph related errors in the codebase. 
+ + Returns: + A list of detected call graph errors + """ + self.clear_errors() + + # Build call graph + call_graph = self._build_call_graph() + + # Check for circular dependencies + self._check_circular_dependencies(call_graph) + + # Check for call-in/call-out point errors + self._check_call_point_errors() + + return self.errors + + def _build_call_graph(self) -> Dict[str, Set[str]]: + """Build a call graph for the codebase.""" + call_graph = {} + + for func in self.codebase.functions: + if not hasattr(func, "function_calls"): + continue + + caller = func.name + if caller not in call_graph: + call_graph[caller] = set() + + for call in func.function_calls: + if hasattr(call, "target") and hasattr(call.target, "name"): + callee = call.target.name + call_graph[caller].add(callee) + + # Ensure callee is in the graph + if callee not in call_graph: + call_graph[callee] = set() + + return call_graph + + def _check_circular_dependencies(self, call_graph: Dict[str, Set[str]]) -> None: + """Check for circular dependencies in the call graph.""" + visited = set() + path = [] + + def dfs(node): + if node in path: + # Found a cycle + cycle = path[path.index(node):] + [node] + self._report_circular_dependency(cycle) + return + + if node in visited: + return + + visited.add(node) + path.append(node) + + for neighbor in call_graph.get(node, set()): + dfs(neighbor) + + path.pop() + + for node in call_graph: + dfs(node) + + def _report_circular_dependency(self, cycle: List[str]) -> None: + """Report a circular dependency.""" + cycle_str = " -> ".join(cycle) + + # Find the functions involved in the cycle + functions = [] + for name in cycle: + for func in self.codebase.functions: + if func.name == name: + functions.append(func) + break + + if not functions: + return + + # Report the error for the first function in the cycle + func = functions[0] + self.errors.append(CodeError( + category=ErrorCategory.CIRCULAR_DEPENDENCY, + severity=ErrorSeverity.WARNING, + message=f"Circular dependency detected: {cycle_str}", + file_path=func.filepath, + function_name=func.name, + related_symbols=cycle, + fix_suggestion="Break the circular dependency by refactoring one of the functions" + )) + + def _check_call_point_errors(self) -> None: + """Check for call-in/call-out point errors.""" + for func in self.codebase.functions: + if not hasattr(func, "function_calls") or not hasattr(func, "call_sites"): + continue + + # Check if function is called with consistent arguments + call_sites = func.call_sites + if len(call_sites) > 1: + arg_counts = set(len(call.arguments) for call in call_sites if hasattr(call, "arguments")) + + if len(arg_counts) > 1: + self.errors.append(CodeError( + category=ErrorCategory.CALL_POINT_ERROR, + severity=ErrorSeverity.WARNING, + message=f"Function is called with inconsistent number of arguments ({', '.join(map(str, arg_counts))})", + file_path=func.filepath, + function_name=func.name, + fix_suggestion="Ensure the function is called consistently with the same number of arguments" + )) + + +class CodeQualityErrorDetector(ErrorDetector): + """Detector for code quality related errors.""" + + def detect_errors(self) -> List[CodeError]: + """ + Detect code quality related errors in the codebase. 
+ + Returns: + A list of detected code quality errors + """ + self.clear_errors() + + for func in self.codebase.functions: + # Check for complex functions + self._check_complex_function(func) + + # Check for unreachable code + self._check_unreachable_code(func) + + # Check for potential exceptions + self._check_potential_exceptions(func) + + # Check for unused imports + self._check_unused_imports() + + # Check for unused variables + self._check_unused_variables() + + return self.errors + + def _check_complex_function(self, func: Function) -> None: + """Check if a function is too complex.""" + if not hasattr(func, "code_block"): + return + + # Calculate cyclomatic complexity + complexity = self._calculate_cyclomatic_complexity(func) + + if complexity > 10: + self.errors.append(CodeError( + category=ErrorCategory.COMPLEX_FUNCTION, + severity=ErrorSeverity.WARNING, + message=f"Function has high cyclomatic complexity ({complexity})", + file_path=func.filepath, + function_name=func.name, + fix_suggestion="Refactor the function into smaller, more manageable pieces" + )) + + def _calculate_cyclomatic_complexity(self, func: Function) -> int: + """Calculate the cyclomatic complexity of a function.""" + if not hasattr(func, "code_block") or not func.code_block: + return 1 + + # Base complexity is 1 + complexity = 1 + + # Count if statements + if hasattr(func, "if_statements"): + complexity += len(func.if_statements) + + # Count for loops + if hasattr(func, "for_loops"): + complexity += len(func.for_loops) + + # Count while loops + if hasattr(func, "while_loops"): + complexity += len(func.while_loops) + + # Count except blocks + if hasattr(func, "except_blocks"): + complexity += len(func.except_blocks) + + # Count boolean operators + if hasattr(func, "code_block") and hasattr(func.code_block, "source"): + source = func.code_block.source + complexity += source.count(" and ") + source.count(" or ") + + return complexity + + def _check_unreachable_code(self, func: Function) -> None: + """Check for unreachable code in a function.""" + if not hasattr(func, "code_block") or not hasattr(func, "return_statements"): + return + + # Simple check for code after return statements + # This is a simplified approach - a proper implementation would use AST analysis + if hasattr(func.code_block, "source"): + source_lines = func.code_block.source.splitlines() + + for i, line in enumerate(source_lines): + if line.strip().startswith("return "): + # Check if there's non-empty code after this return + for j in range(i + 1, len(source_lines)): + if source_lines[j].strip() and not source_lines[j].strip().startswith(("#", "\"\"\"", "'''", "else:", "except ", "finally:")): + self.errors.append(CodeError( + category=ErrorCategory.UNREACHABLE_CODE, + severity=ErrorSeverity.WARNING, + message=f"Code after return statement will never be executed", + file_path=func.filepath, + line_number=j + 1, # +1 because line numbers are 1-based + function_name=func.name, + code_snippet=source_lines[j], + fix_suggestion="Remove or move the unreachable code" + )) + break + + def _check_potential_exceptions(self, func: Function) -> None: + """Check for potential exceptions in a function.""" + if not hasattr(func, "code_block"): + return + + # Check for common error-prone patterns + if hasattr(func.code_block, "source"): + source = func.code_block.source + + # Check for dictionary access without get() + if "[" in source and not "try:" in source: + self.errors.append(CodeError( + category=ErrorCategory.POTENTIAL_EXCEPTION, + 
severity=ErrorSeverity.INFO,
+                    message="Function may raise KeyError when accessing dictionary",
+                    file_path=func.filepath,
+                    function_name=func.name,
+                    fix_suggestion="Use dict.get() or try-except to handle potential KeyError"
+                ))
+            
+            # Check for division without checking for zero
+            if "/" in source and "try:" not in source and "if " not in source:
+                self.errors.append(CodeError(
+                    category=ErrorCategory.POTENTIAL_EXCEPTION,
+                    severity=ErrorSeverity.INFO,
+                    message="Function may raise ZeroDivisionError",
+                    file_path=func.filepath,
+                    function_name=func.name,
+                    fix_suggestion="Check for zero before division or use try-except"
+                ))
+    
+    def _check_unused_imports(self) -> None:
+        """Check for unused imports in the codebase."""
+        for file in self.codebase.files:
+            if not hasattr(file, "imports") or not hasattr(file, "source"):
+                continue
+            
+            for imp in file.imports:
+                if hasattr(imp, "imported_symbol") and hasattr(imp.imported_symbol, "name"):
+                    symbol_name = imp.imported_symbol.name
+                    
+                    # Heuristic: the name must occur somewhere beyond the import
+                    # statement itself - this could be improved with AST analysis
+                    if file.source.count(symbol_name) <= 1:
+                        self.errors.append(CodeError(
+                            category=ErrorCategory.UNUSED_IMPORT,
+                            severity=ErrorSeverity.INFO,
+                            message=f"Import '{symbol_name}' is never used",
+                            file_path=file.filepath,
+                            fix_suggestion=f"Remove the unused import '{symbol_name}'"
+                        ))
+    
+    def _check_unused_variables(self) -> None:
+        """Check for unused variables in the codebase."""
+        for func in self.codebase.functions:
+            if not hasattr(func, "code_block") or not hasattr(func, "variables"):
+                continue
+            
+            for var in func.variables:
+                if hasattr(var, "name") and hasattr(func.code_block, "source"):
+                    var_name = var.name
+                    source = func.code_block.source
+                    
+                    # Count occurrences of the variable name
+                    # This is a simplified approach - a proper implementation would use AST analysis
+                    occurrences = source.count(var_name)
+                    
+                    # If the variable only appears once (its declaration), it's unused
+                    if occurrences == 1:
+                        self.errors.append(CodeError(
+                            category=ErrorCategory.UNUSED_VARIABLE,
+                            severity=ErrorSeverity.INFO,
+                            message=f"Variable '{var_name}' is defined but never used",
+                            file_path=func.filepath,
+                            function_name=func.name,
+                            fix_suggestion=f"Remove the unused variable '{var_name}'"
+                        ))
+
+
+class CodeAnalysisError:
+    """Main class for code error analysis."""
+    
+    def __init__(self, codebase: Codebase):
+        """
+        Initialize the code error analyzer.
+        
+        Args:
+            codebase: The codebase to analyze
+        """
+        self.codebase = codebase
+        self.detectors = [
+            ParameterErrorDetector(codebase),
+            ReturnErrorDetector(codebase),
+            CallGraphErrorDetector(codebase),
+            CodeQualityErrorDetector(codebase)
+        ]
+    
+    def analyze(self) -> List[CodeError]:
+        """
+        Analyze the codebase for errors.
+        
+        Returns:
+            A list of all detected errors
+        """
+        all_errors = []
+        
+        for detector in self.detectors:
+            errors = detector.detect_errors()
+            all_errors.extend(errors)
+        
+        return all_errors
+    
+    def analyze_by_category(self, category: ErrorCategory) -> List[CodeError]:
+        """
+        Analyze the codebase for errors of a specific category.
+        
+        Args:
+            category: The error category to filter by
+        
+        Returns:
+            A list of errors of the specified category
+        """
+        all_errors = self.analyze()
+        return [error for error in all_errors if error.category == category]
+    
+    def analyze_by_severity(self, severity: ErrorSeverity) -> List[CodeError]:
+        """
+        Analyze the codebase for errors of a specific severity.
+ + Args: + severity: The error severity to filter by + + Returns: + A list of errors of the specified severity + """ + all_errors = self.analyze() + return [error for error in all_errors if error.severity == severity] + + def analyze_file(self, file_path: str) -> List[CodeError]: + """ + Analyze a specific file for errors. + + Args: + file_path: Path to the file to analyze + + Returns: + A list of errors in the specified file + """ + all_errors = self.analyze() + return [error for error in all_errors if error.file_path == file_path] + + def analyze_function(self, function_name: str) -> List[CodeError]: + """ + Analyze a specific function for errors. + + Args: + function_name: Name of the function to analyze + + Returns: + A list of errors in the specified function + """ + all_errors = self.analyze() + return [error for error in all_errors if error.function_name == function_name] + + def analyze_class(self, class_name: str) -> List[CodeError]: + """ + Analyze a specific class for errors. + + Args: + class_name: Name of the class to analyze + + Returns: + A list of errors in the specified class + """ + all_errors = self.analyze() + return [error for error in all_errors if error.class_name == class_name] + + def get_error_summary(self) -> Dict[str, int]: + """ + Get a summary of errors by category. + + Returns: + A dictionary mapping error categories to counts + """ + all_errors = self.analyze() + summary = {} + + for error in all_errors: + category = error.category.name + if category in summary: + summary[category] += 1 + else: + summary[category] = 1 + + return summary + + def get_severity_summary(self) -> Dict[str, int]: + """ + Get a summary of errors by severity. + + Returns: + A dictionary mapping error severities to counts + """ + all_errors = self.analyze() + summary = {} + + for error in all_errors: + severity = error.severity.name + if severity in summary: + summary[severity] += 1 + else: + summary[severity] = 1 + + return summary + diff --git a/codegen-on-oss/codegen_on_oss/analysis/example.py b/codegen-on-oss/codegen_on_oss/analysis/example.py new file mode 100644 index 000000000..f10359103 --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analysis/example.py @@ -0,0 +1,180 @@ +""" +Example script demonstrating the usage of the enhanced code analysis module. 
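+
+Usage:
+    python -m codegen_on_oss.analysis.example <repo_url> [branch]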
+""" + +import json +import sys +from pathlib import Path + +from codegen import Codebase +from codegen_on_oss.analysis.analysis import CodeAnalyzer + + +def print_section(title): + """Print a section title.""" + print("\n" + "=" * 80) + print(f" {title} ".center(80, "=")) + print("=" * 80) + + +def print_json(data): + """Print data as formatted JSON.""" + print(json.dumps(data, indent=2)) + + +def analyze_repo(repo_url, branch=None): + """Analyze a repository and print the results.""" + print_section(f"Analyzing repository: {repo_url}") + + # Load the codebase + print(f"Loading codebase from {repo_url}...") + if branch: + codebase = Codebase.from_repo(repo_url, branch=branch) + else: + codebase = Codebase.from_repo(repo_url) + + # Create analyzer + analyzer = CodeAnalyzer(codebase) + + # Get codebase summary + print_section("Codebase Summary") + print(analyzer.get_codebase_summary()) + + # Analyze errors + print_section("Error Analysis") + error_analysis = analyzer.analyze_errors() + print(f"Total errors: {error_analysis['total_errors']}") + print("\nError summary:") + print_json(error_analysis['error_summary']) + print("\nSeverity summary:") + print_json(error_analysis['severity_summary']) + + # Show some errors if there are any + if error_analysis['errors']: + print("\nSample errors:") + for i, error in enumerate(error_analysis['errors'][:5]): # Show first 5 errors + print(f"\n{i+1}. {error['category']} ({error['severity']}): {error['message']}") + if error['function_name']: + print(f" Function: {error['function_name']}") + if error['file_path']: + print(f" File: {error['file_path']}") + if error['fix_suggestion']: + print(f" Suggestion: {error['fix_suggestion']}") + + # Analyze function calls + print_section("Function Call Analysis") + call_analysis = analyzer.analyze_function_calls() + + print("Most called functions:") + for func, count in call_analysis['call_graph']['most_called']: + print(f"- {func}: {count} calls") + + print("\nMost complex functions (by number of calls made):") + for func, count in call_analysis['call_graph']['most_complex']: + print(f"- {func}: calls {count} other functions") + + print("\nEntry point functions:") + for func in call_analysis['call_graph']['entry_points'][:10]: # Show first 10 + print(f"- {func}") + + print("\nLeaf functions:") + for func in call_analysis['call_graph']['leaf_functions'][:10]: # Show first 10 + print(f"- {func}") + + # Analyze circular dependencies + if call_analysis['call_graph']['circular_dependencies']: + print("\nCircular dependencies:") + for i, cycle in enumerate(call_analysis['call_graph']['circular_dependencies'][:5]): # Show first 5 + print(f"- Cycle {i+1}: {' -> '.join(cycle)}") + + # Analyze type annotations + print_section("Type Analysis") + type_analysis = analyzer.analyze_types() + + print("Type annotation coverage:") + print_json(type_analysis['annotation_coverage']) + + if type_analysis['annotation_errors']: + print("\nSample type annotation errors:") + for i, error in enumerate(type_analysis['annotation_errors'][:5]): # Show first 5 + print(f"\n{i+1}. 
{error['message']}") + if error['function_name']: + print(f" Function: {error['function_name']}") + if error['file_path']: + print(f" File: {error['file_path']}") + if error['fix_suggestion']: + print(f" Suggestion: {error['fix_suggestion']}") + + # Analyze a specific function if there are any + if codebase.functions: + func = next(iter(codebase.functions)) + if hasattr(func, 'name'): + print_section(f"Detailed Analysis of Function: {func.name}") + func_analysis = analyzer.analyze_function(func.name) + + print("Function summary:") + print(func_analysis['summary']) + + print("\nFunction call analysis:") + print(f"- Calls: {', '.join(func_analysis['call_analysis']['calls'])}" if func_analysis['call_analysis']['calls'] else "- Calls: None") + print(f"- Called by: {', '.join(func_analysis['call_analysis']['called_by'])}" if func_analysis['call_analysis']['called_by'] else "- Called by: None") + print(f"- Call depth: {func_analysis['call_analysis']['call_depth']}") + + print("\nParameter analysis:") + if 'parameters' in func_analysis['parameter_analysis']: + for param in func_analysis['parameter_analysis']['parameters']: + print(f"- {param['name']}: {'Used' if param['is_used'] else 'Unused'}, Type: {param['type'] or 'Unknown'}") + + print("\nType analysis:") + if 'inferred_types' in func_analysis['type_analysis']: + print("Inferred types:") + for var, type_name in func_analysis['type_analysis']['inferred_types'].items(): + print(f"- {var}: {type_name}") + + # Analyze a specific file if there are any + if codebase.files: + file = next(iter(codebase.files)) + if hasattr(file, 'filepath'): + print_section(f"Detailed Analysis of File: {file.filepath}") + file_analysis = analyzer.analyze_file(file.filepath) + + print("File summary:") + print(file_analysis['summary']) + + print("\nFunctions in file:") + for func in file_analysis['functions']: + print(f"- {func['name']}: Parameters: {', '.join(func['parameters'])}, Return type: {func['return_type'] or 'Unknown'}") + + print("\nClasses in file:") + for cls in file_analysis['classes']: + print(f"- {cls['name']}: Methods: {', '.join(cls['methods'])}") + + print("\nImports in file:") + for imp in file_analysis['imports']: + print(f"- {imp}") + + if file_analysis['errors']: + print("\nErrors in file:") + for i, error in enumerate(file_analysis['errors']): + print(f"- {error['category']}: {error['message']}") + + +def main(): + """Main function.""" + if len(sys.argv) < 2: + print("Usage: python example.py [branch]") + print("Example: python example.py https://github.com/user/repo main") + return + + repo_url = sys.argv[1] + branch = sys.argv[2] if len(sys.argv) > 2 else None + + try: + analyze_repo(repo_url, branch) + except Exception as e: + print(f"Error analyzing repository: {e}") + + +if __name__ == "__main__": + main() + diff --git a/codegen-on-oss/codegen_on_oss/analysis/function_call_analysis.py b/codegen-on-oss/codegen_on_oss/analysis/function_call_analysis.py new file mode 100644 index 000000000..8ebc8d51d --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analysis/function_call_analysis.py @@ -0,0 +1,485 @@ +""" +Function call analysis module for code analysis. + +This module provides classes and functions for analyzing function calls, +including call graphs, parameter usage, and call chains. 
+""" + +from typing import Dict, List, Optional, Set, Tuple, Union, Any +import networkx as nx + +from codegen import Codebase +from codegen.sdk.core.function import Function + + +class FunctionCallGraph: + """Represents a graph of function calls in a codebase.""" + + def __init__(self, codebase: Codebase): + """ + Initialize the function call graph. + + Args: + codebase: The codebase to analyze + """ + self.codebase = codebase + self.graph = self._build_graph() + self.nx_graph = self._build_networkx_graph() + + def _build_graph(self) -> Dict[str, Set[str]]: + """ + Build a dictionary-based graph of function calls. + + Returns: + A dictionary mapping function names to sets of called function names + """ + graph = {} + + for func in self.codebase.functions: + caller = func.name + if caller not in graph: + graph[caller] = set() + + if hasattr(func, "function_calls"): + for call in func.function_calls: + if hasattr(call, "target") and hasattr(call.target, "name"): + callee = call.target.name + graph[caller].add(callee) + + # Ensure callee is in the graph + if callee not in graph: + graph[callee] = set() + + return graph + + def _build_networkx_graph(self) -> nx.DiGraph: + """ + Build a NetworkX directed graph of function calls. + + Returns: + A NetworkX DiGraph representing the call graph + """ + G = nx.DiGraph() + + # Add nodes + for func_name in self.graph: + G.add_node(func_name) + + # Add edges + for caller, callees in self.graph.items(): + for callee in callees: + G.add_edge(caller, callee) + + return G + + def get_callers(self, function_name: str) -> List[str]: + """ + Get all functions that call the specified function. + + Args: + function_name: Name of the function to find callers for + + Returns: + A list of function names that call the specified function + """ + callers = [] + + for caller, callees in self.graph.items(): + if function_name in callees: + callers.append(caller) + + return callers + + def get_callees(self, function_name: str) -> List[str]: + """ + Get all functions called by the specified function. + + Args: + function_name: Name of the function to find callees for + + Returns: + A list of function names called by the specified function + """ + return list(self.graph.get(function_name, set())) + + def get_call_chain(self, start: str, end: str) -> List[List[str]]: + """ + Get all call chains from start function to end function. + + Args: + start: Name of the starting function + end: Name of the ending function + + Returns: + A list of call chains (each chain is a list of function names) + """ + if start not in self.graph or end not in self.graph: + return [] + + try: + # Find all simple paths from start to end + paths = list(nx.all_simple_paths(self.nx_graph, start, end)) + return paths + except nx.NetworkXNoPath: + return [] + + def get_entry_points(self) -> List[str]: + """ + Get all entry point functions (functions that are not called by any other function). + + Returns: + A list of entry point function names + """ + entry_points = [] + + for func_name in self.graph: + if not self.get_callers(func_name): + entry_points.append(func_name) + + return entry_points + + def get_leaf_functions(self) -> List[str]: + """ + Get all leaf functions (functions that don't call any other function). 
+
+        Returns:
+            A list of leaf function names
+        """
+        leaf_functions = []
+
+        for func_name, callees in self.graph.items():
+            if not callees:
+                leaf_functions.append(func_name)
+
+        return leaf_functions
+
+    def get_call_depth(self, function_name: str, _visited: Optional[Set[str]] = None) -> int:
+        """
+        Get the maximum call depth of a function.
+
+        Args:
+            function_name: Name of the function to find call depth for
+
+        Returns:
+            The maximum call depth (0 for leaf functions)
+        """
+        if function_name not in self.graph:
+            return 0
+
+        # Track visited functions so a cycle in the call graph doesn't
+        # recurse forever
+        if _visited is None:
+            _visited = set()
+        if function_name in _visited:
+            return 0
+        _visited.add(function_name)
+
+        callees = self.graph[function_name]
+        if not callees:
+            return 0
+
+        return 1 + max(self.get_call_depth(callee, _visited) for callee in callees)
+
+    def get_most_called_functions(self, limit: int = 10) -> List[Tuple[str, int]]:
+        """
+        Get the most frequently called functions.
+
+        Args:
+            limit: Maximum number of functions to return
+
+        Returns:
+            A list of (function_name, call_count) tuples, sorted by call count
+        """
+        call_counts = {}
+
+        for func_name in self.graph:
+            call_counts[func_name] = len(self.get_callers(func_name))
+
+        # Sort by call count (descending)
+        sorted_counts = sorted(call_counts.items(), key=lambda x: x[1], reverse=True)
+
+        return sorted_counts[:limit]
+
+    def get_most_complex_functions(self, limit: int = 10) -> List[Tuple[str, int]]:
+        """
+        Get the most complex functions based on the number of function calls they make.
+
+        Args:
+            limit: Maximum number of functions to return
+
+        Returns:
+            A list of (function_name, complexity) tuples, sorted by complexity
+        """
+        complexity = {}
+
+        for func_name, callees in self.graph.items():
+            complexity[func_name] = len(callees)
+
+        # Sort by complexity (descending)
+        sorted_complexity = sorted(complexity.items(), key=lambda x: x[1], reverse=True)
+
+        return sorted_complexity[:limit]
+
+    def get_circular_dependencies(self) -> List[List[str]]:
+        """
+        Get all circular dependencies in the call graph.
+
+        Returns:
+            A list of circular dependency chains
+        """
+        try:
+            # Find all simple cycles in the graph
+            cycles = list(nx.simple_cycles(self.nx_graph))
+            return cycles
+        except Exception:
+            # Fall back to manual cycle detection if NetworkX fails
+            return self._find_cycles_manually()
+
+    def _find_cycles_manually(self) -> List[List[str]]:
+        """
+        Find cycles in the call graph manually.
+
+        Returns:
+            A list of circular dependency chains (the same cycle may be
+            reported once per starting node)
+        """
+        cycles = []
+
+        def dfs(node, visited, path):
+            if node in path:
+                # Found a cycle
+                cycle = path[path.index(node):] + [node]
+                cycles.append(cycle)
+                return
+
+            if node in visited:
+                return
+
+            visited.add(node)
+            path.append(node)
+
+            for neighbor in self.graph.get(node, set()):
+                dfs(neighbor, visited, path)
+
+            path.pop()
+
+        for node in self.graph:
+            dfs(node, set(), [])
+
+        return cycles
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert the call graph to a dictionary.
+
+        Returns:
+            A dictionary representation of the call graph
+        """
+        return {
+            "nodes": list(self.graph.keys()),
+            "edges": [(caller, callee) for caller, callees in self.graph.items() for callee in callees],
+            "entry_points": self.get_entry_points(),
+            "leaf_functions": self.get_leaf_functions(),
+            "most_called": self.get_most_called_functions(),
+            "most_complex": self.get_most_complex_functions(),
+            "circular_dependencies": self.get_circular_dependencies()
+        }
+
+
+class ParameterUsageAnalysis:
+    """Analyzes how parameters are used within functions."""
+
+    def __init__(self, codebase: Codebase):
+        """
+        Initialize the parameter usage analyzer.
+ + Args: + codebase: The codebase to analyze + """ + self.codebase = codebase + + def analyze_parameter_usage(self, function_name: str) -> Dict[str, Any]: + """ + Analyze how parameters are used in a specific function. + + Args: + function_name: Name of the function to analyze + + Returns: + A dictionary with parameter usage information + """ + # Find the function + func = None + for f in self.codebase.functions: + if f.name == function_name: + func = f + break + + if not func or not hasattr(func, "parameters") or not hasattr(func, "code_block"): + return {"error": f"Function {function_name} not found or has no parameters"} + + result = { + "function_name": function_name, + "parameters": [] + } + + for param in func.parameters: + param_info = { + "name": param.name, + "type": param.type_annotation if hasattr(param, "type_annotation") else None, + "has_default": param.has_default_value if hasattr(param, "has_default_value") else False, + "is_used": False, + "usage_count": 0, + "usage_contexts": [] + } + + # Check if parameter is used in the function body + if hasattr(func, "code_block") and func.code_block and hasattr(func.code_block, "source"): + source = func.code_block.source + source_lines = source.splitlines() + + # Count occurrences of the parameter name + param_info["usage_count"] = source.count(param.name) + + # If the parameter appears more than once (beyond its declaration), it's used + if param_info["usage_count"] > 1 or (param.name + "=" not in source and param_info["usage_count"] > 0): + param_info["is_used"] = True + + # Find usage contexts + for i, line in enumerate(source_lines): + if param.name in line and not line.strip().startswith("def "): + param_info["usage_contexts"].append({ + "line_number": i + 1, # +1 because line numbers are 1-based + "line": line.strip() + }) + + result["parameters"].append(param_info) + + return result + + def analyze_all_parameters(self) -> Dict[str, Dict[str, Any]]: + """ + Analyze parameter usage for all functions in the codebase. + + Returns: + A dictionary mapping function names to parameter usage information + """ + result = {} + + for func in self.codebase.functions: + if hasattr(func, "name"): + result[func.name] = self.analyze_parameter_usage(func.name) + + return result + + def get_unused_parameters(self) -> Dict[str, List[str]]: + """ + Get all unused parameters in the codebase. + + Returns: + A dictionary mapping function names to lists of unused parameter names + """ + result = {} + + for func_name, analysis in self.analyze_all_parameters().items(): + if "parameters" in analysis: + unused = [p["name"] for p in analysis["parameters"] if not p["is_used"] and p["name"] != "self"] + if unused: + result[func_name] = unused + + return result + + def get_parameter_type_coverage(self) -> Dict[str, float]: + """ + Get the percentage of parameters with type annotations for each function. + + Returns: + A dictionary mapping function names to type coverage percentages + """ + result = {} + + for func_name, analysis in self.analyze_all_parameters().items(): + if "parameters" in analysis and analysis["parameters"]: + typed_params = [p for p in analysis["parameters"] if p["type"] is not None] + coverage = len(typed_params) / len(analysis["parameters"]) * 100 + result[func_name] = coverage + + return result + + +class FunctionCallAnalysis: + """Main class for function call analysis.""" + + def __init__(self, codebase: Codebase): + """ + Initialize the function call analyzer. 
+ + Args: + codebase: The codebase to analyze + """ + self.codebase = codebase + self.call_graph = FunctionCallGraph(codebase) + self.parameter_usage = ParameterUsageAnalysis(codebase) + + def analyze_call_graph(self) -> Dict[str, Any]: + """ + Analyze the function call graph. + + Returns: + A dictionary with call graph analysis results + """ + return self.call_graph.to_dict() + + def analyze_parameter_usage(self, function_name: Optional[str] = None) -> Dict[str, Any]: + """ + Analyze parameter usage. + + Args: + function_name: Name of the function to analyze (optional) + + Returns: + A dictionary with parameter usage analysis results + """ + if function_name: + return self.parameter_usage.analyze_parameter_usage(function_name) + else: + return { + "all_parameters": self.parameter_usage.analyze_all_parameters(), + "unused_parameters": self.parameter_usage.get_unused_parameters(), + "type_coverage": self.parameter_usage.get_parameter_type_coverage() + } + + def analyze_function_dependencies(self, function_name: str) -> Dict[str, Any]: + """ + Analyze dependencies for a specific function. + + Args: + function_name: Name of the function to analyze + + Returns: + A dictionary with function dependency analysis results + """ + if function_name not in self.call_graph.graph: + return {"error": f"Function {function_name} not found"} + + return { + "function_name": function_name, + "calls": list(self.call_graph.get_callees(function_name)), + "called_by": self.call_graph.get_callers(function_name), + "call_depth": self.call_graph.get_call_depth(function_name), + "circular_dependencies": [cycle for cycle in self.call_graph.get_circular_dependencies() if function_name in cycle] + } + + def analyze_all(self) -> Dict[str, Any]: + """ + Perform comprehensive function call analysis. 
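+
+        Aggregates the call graph, parameter usage, entry and leaf
+        functions, call hot spots, and circular dependencies into a
+        single report.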
+
+        Returns:
+            A dictionary with all analysis results
+        """
+        return {
+            "call_graph": self.analyze_call_graph(),
+            "parameter_usage": self.analyze_parameter_usage(),
+            "entry_points": self.call_graph.get_entry_points(),
+            "leaf_functions": self.call_graph.get_leaf_functions(),
+            "most_called_functions": self.call_graph.get_most_called_functions(),
+            "most_complex_functions": self.call_graph.get_most_complex_functions(),
+            "circular_dependencies": self.call_graph.get_circular_dependencies(),
+            "type_coverage": self.parameter_usage.get_parameter_type_coverage()
+        }
+
diff --git a/codegen-on-oss/codegen_on_oss/analysis/mdx_docs_generation.py b/codegen-on-oss/codegen_on_oss/analysis/mdx_docs_generation.py
index 648a3b68e..9e4543bea 100644
--- a/codegen-on-oss/codegen_on_oss/analysis/mdx_docs_generation.py
+++ b/codegen-on-oss/codegen_on_oss/analysis/mdx_docs_generation.py
@@ -110,10 +110,10 @@ def format_parameters_for_mdx(parameters: list[ParameterDoc]) -> str:
 
 
 def format_return_for_mdx(return_type: list[str], return_description: str) -> str:
     description = sanitize_html_for_mdx(return_description) if return_description else ""
-    return_type = resolve_type_string(return_type[0])
+    return_type_str = resolve_type_string(return_type[0])
     return f"""
-
+
 """
@@ -154,8 +154,8 @@ def get_mdx_route_for_class(cls_doc: ClassDoc) -> str:
 
 
 def format_type_string(type_string: str) -> str:
-    type_string = type_string.split("|")
-    return " | ".join([type_str.strip() for type_str in type_string])
+    type_strings = type_string.split("|")
+    return " | ".join([type_str.strip() for type_str in type_strings])
 
 
 def resolve_type_string(type_string: str) -> str:
diff --git a/codegen-on-oss/codegen_on_oss/analysis/server.py b/codegen-on-oss/codegen_on_oss/analysis/server.py
new file mode 100644
index 000000000..fd5d8248d
--- /dev/null
+++ b/codegen-on-oss/codegen_on_oss/analysis/server.py
@@ -0,0 +1,27 @@
+"""
+Server script for running the code analysis API.
+"""
+
+import argparse
+import uvicorn
+from codegen_on_oss.analysis.analysis import app
+
+
+def main():
+    """Run the code analysis API server."""
+    parser = argparse.ArgumentParser(description="Run the code analysis API server")
+    parser.add_argument("--host", default="0.0.0.0", help="Host to bind the server to")
+    parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to")
+    parser.add_argument("--reload", action="store_true", help="Enable auto-reload on code changes")
+
+    args = parser.parse_args()
+
+    print(f"Starting code analysis API server on {args.host}:{args.port}")
+    print(f"API documentation available at http://{args.host}:{args.port}/docs")
+
+    uvicorn.run(app, host=args.host, port=args.port, reload=args.reload)
+
+
+if __name__ == "__main__":
+    main()
+
diff --git a/codegen-on-oss/codegen_on_oss/analysis/type_validation.py b/codegen-on-oss/codegen_on_oss/analysis/type_validation.py
new file mode 100644
index 000000000..6a36ec374
--- /dev/null
+++ b/codegen-on-oss/codegen_on_oss/analysis/type_validation.py
@@ -0,0 +1,636 @@
+"""
+Type validation module for code analysis.
+
+This module provides classes and functions for validating type annotations,
+checking type compatibility, and inferring types for variables and expressions.
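+
+Example (an illustrative sketch; assumes the SDK can load the target
+repository):
+
+    from codegen import Codebase
+    from codegen_on_oss.analysis.type_validation import TypeValidation
+
+    validator = TypeValidation(Codebase.from_repo("owner/repo"))
+    results = validator.validate_all()
+    print(results["annotation_coverage"])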
+""" + +from typing import Dict, List, Optional, Set, Tuple, Union, Any + +from codegen import Codebase +from codegen.sdk.core.function import Function +from codegen.sdk.core.class_definition import Class +from codegen.sdk.core.symbol import Symbol + + +class TypeValidationError: + """Represents a type validation error.""" + + def __init__( + self, + message: str, + file_path: str, + line_number: Optional[int] = None, + column_number: Optional[int] = None, + function_name: Optional[str] = None, + class_name: Optional[str] = None, + variable_name: Optional[str] = None, + expected_type: Optional[str] = None, + actual_type: Optional[str] = None, + fix_suggestion: Optional[str] = None + ): + """ + Initialize a TypeValidationError. + + Args: + message: A descriptive message about the error + file_path: Path to the file containing the error + line_number: Line number where the error occurs (optional) + column_number: Column number where the error occurs (optional) + function_name: Name of the function containing the error (optional) + class_name: Name of the class containing the error (optional) + variable_name: Name of the variable with the type error (optional) + expected_type: The expected type (optional) + actual_type: The actual type (optional) + fix_suggestion: A suggestion for fixing the error (optional) + """ + self.message = message + self.file_path = file_path + self.line_number = line_number + self.column_number = column_number + self.function_name = function_name + self.class_name = class_name + self.variable_name = variable_name + self.expected_type = expected_type + self.actual_type = actual_type + self.fix_suggestion = fix_suggestion + + def to_dict(self) -> Dict[str, Any]: + """Convert the error to a dictionary.""" + return { + "message": self.message, + "file_path": self.file_path, + "line_number": self.line_number, + "column_number": self.column_number, + "function_name": self.function_name, + "class_name": self.class_name, + "variable_name": self.variable_name, + "expected_type": self.expected_type, + "actual_type": self.actual_type, + "fix_suggestion": self.fix_suggestion + } + + def __str__(self) -> str: + """String representation of the error.""" + location = f"{self.file_path}" + if self.line_number: + location += f":{self.line_number}" + if self.column_number: + location += f":{self.column_number}" + + context = "" + if self.function_name: + context += f" in function '{self.function_name}'" + if self.class_name: + context += f" in class '{self.class_name}'" + if self.variable_name: + context += f" for variable '{self.variable_name}'" + + type_info = "" + if self.expected_type and self.actual_type: + type_info = f" (expected: {self.expected_type}, actual: {self.actual_type})" + + return f"Type Error: {self.message}{type_info} at {location}{context}" + + +class TypeAnnotationValidator: + """Validates type annotations in the codebase.""" + + def __init__(self, codebase: Codebase): + """ + Initialize the type annotation validator. + + Args: + codebase: The codebase to analyze + """ + self.codebase = codebase + self.errors: List[TypeValidationError] = [] + + def validate_function_annotations(self, func: Function) -> List[TypeValidationError]: + """ + Validate type annotations in a function. 
+
+        Args:
+            func: The function to validate
+
+        Returns:
+            A list of type validation errors
+        """
+        errors = []
+
+        # Check return type annotation
+        if not hasattr(func, "return_type") or not func.return_type:
+            errors.append(TypeValidationError(
+                message="Missing return type annotation",
+                file_path=func.filepath,
+                function_name=func.name,
+                fix_suggestion=f"Add a return type annotation to function '{func.name}'"
+            ))
+
+        # Check parameter type annotations
+        if hasattr(func, "parameters"):
+            for param in func.parameters:
+                if not hasattr(param, "type_annotation") or not param.type_annotation:
+                    # Skip self parameter in methods
+                    if param.name == "self" and hasattr(func, "parent") and isinstance(func.parent, Class):
+                        continue
+
+                    errors.append(TypeValidationError(
+                        message=f"Missing type annotation for parameter '{param.name}'",
+                        file_path=func.filepath,
+                        function_name=func.name,
+                        variable_name=param.name,
+                        fix_suggestion=f"Add a type annotation to parameter '{param.name}'"
+                    ))
+
+        return errors
+
+    def validate_variable_annotations(self, func: Function) -> List[TypeValidationError]:
+        """
+        Validate type annotations for variables in a function.
+
+        Args:
+            func: The function to validate
+
+        Returns:
+            A list of type validation errors
+        """
+        errors = []
+
+        if hasattr(func, "variables"):
+            for var in func.variables:
+                # Treat a missing and an empty annotation the same way
+                if hasattr(var, "name") and not getattr(var, "type_annotation", None):
+                    errors.append(TypeValidationError(
+                        message=f"Missing type annotation for variable '{var.name}'",
+                        file_path=func.filepath,
+                        function_name=func.name,
+                        variable_name=var.name,
+                        fix_suggestion=f"Add a type annotation to variable '{var.name}'"
+                    ))
+
+        return errors
+
+    def validate_class_annotations(self, cls: Class) -> List[TypeValidationError]:
+        """
+        Validate type annotations in a class.
+
+        Args:
+            cls: The class to validate
+
+        Returns:
+            A list of type validation errors
+        """
+        errors = []
+
+        # Check attribute type annotations
+        if hasattr(cls, "attributes"):
+            for attr in cls.attributes:
+                if hasattr(attr, "name") and not getattr(attr, "type_annotation", None):
+                    errors.append(TypeValidationError(
+                        message=f"Missing type annotation for attribute '{attr.name}'",
+                        file_path=cls.filepath,
+                        class_name=cls.name,
+                        variable_name=attr.name,
+                        fix_suggestion=f"Add a type annotation to attribute '{attr.name}'"
+                    ))
+
+        # Check method annotations
+        if hasattr(cls, "methods"):
+            for method in cls.methods:
+                method_errors = self.validate_function_annotations(method)
+                errors.extend(method_errors)
+
+        return errors
+
+    def validate_all(self) -> List[TypeValidationError]:
+        """
+        Validate type annotations in the entire codebase.
+
+        Returns:
+            A list of all type validation errors
+        """
+        self.errors = []
+
+        # Validate functions
+        for func in self.codebase.functions:
+            self.errors.extend(self.validate_function_annotations(func))
+            self.errors.extend(self.validate_variable_annotations(func))
+
+        # Validate classes
+        for cls in self.codebase.classes:
+            self.errors.extend(self.validate_class_annotations(cls))
+
+        return self.errors
+
+    def get_annotation_coverage(self) -> Dict[str, float]:
+        """
+        Calculate type annotation coverage for the codebase.
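+
+        Coverage is the share of annotatable elements (function return
+        types, parameters excluding 'self', local variables, and class
+        attributes) that carry an annotation.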
+ + Returns: + A dictionary with coverage percentages for different elements + """ + # Count functions with return type annotations + total_functions = len(list(self.codebase.functions)) + functions_with_return_type = 0 + + for func in self.codebase.functions: + if hasattr(func, "return_type") and func.return_type: + functions_with_return_type += 1 + + # Count parameters with type annotations + total_parameters = 0 + parameters_with_type = 0 + + for func in self.codebase.functions: + if hasattr(func, "parameters"): + for param in func.parameters: + # Skip self parameter in methods + if param.name == "self" and hasattr(func, "parent") and isinstance(func.parent, Class): + continue + + total_parameters += 1 + if hasattr(param, "type_annotation") and param.type_annotation: + parameters_with_type += 1 + + # Count variables with type annotations + total_variables = 0 + variables_with_type = 0 + + for func in self.codebase.functions: + if hasattr(func, "variables"): + for var in func.variables: + total_variables += 1 + if hasattr(var, "type_annotation") and var.type_annotation: + variables_with_type += 1 + + # Count class attributes with type annotations + total_attributes = 0 + attributes_with_type = 0 + + for cls in self.codebase.classes: + if hasattr(cls, "attributes"): + for attr in cls.attributes: + total_attributes += 1 + if hasattr(attr, "type_annotation") and attr.type_annotation: + attributes_with_type += 1 + + # Calculate coverage percentages + function_coverage = (functions_with_return_type / total_functions * 100) if total_functions > 0 else 0 + parameter_coverage = (parameters_with_type / total_parameters * 100) if total_parameters > 0 else 0 + variable_coverage = (variables_with_type / total_variables * 100) if total_variables > 0 else 0 + attribute_coverage = (attributes_with_type / total_attributes * 100) if total_attributes > 0 else 0 + + # Calculate overall coverage + total_elements = total_functions + total_parameters + total_variables + total_attributes + total_with_type = functions_with_return_type + parameters_with_type + variables_with_type + attributes_with_type + overall_coverage = (total_with_type / total_elements * 100) if total_elements > 0 else 0 + + return { + "overall": overall_coverage, + "functions": function_coverage, + "parameters": parameter_coverage, + "variables": variable_coverage, + "attributes": attribute_coverage + } + + +class TypeCompatibilityChecker: + """Checks type compatibility in the codebase.""" + + def __init__(self, codebase: Codebase): + """ + Initialize the type compatibility checker. + + Args: + codebase: The codebase to analyze + """ + self.codebase = codebase + self.errors: List[TypeValidationError] = [] + + def check_assignment_compatibility(self, func: Function) -> List[TypeValidationError]: + """ + Check type compatibility in assignments within a function. 
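+
+        Only annotated assignments of simple literals are checked; this is
+        a line-based heuristic rather than full type inference.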
+
+        Args:
+            func: The function to check
+
+        Returns:
+            A list of type validation errors
+        """
+        errors = []
+
+        # This is a simplified implementation
+        # A proper implementation would use AST analysis to check all assignments
+        if hasattr(func, "code_block") and hasattr(func.code_block, "source"):
+            source_lines = func.code_block.source.splitlines()
+
+            for i, line in enumerate(source_lines):
+                line = line.strip()
+
+                # Check for assignments with type annotations
+                if ":" in line and "=" in line and not line.startswith(("#", "\"\"\"", "'''", "def ", "class ")):
+                    parts = line.split(":", 1)
+                    var_name = parts[0].strip()
+
+                    # Extract type annotation
+                    type_parts = parts[1].split("=", 1)
+                    type_annotation = type_parts[0].strip()
+
+                    # Extract assigned value
+                    if len(type_parts) > 1:
+                        value = type_parts[1].strip()
+
+                        # Simple type checking for literals
+                        if type_annotation == "int" and (value.startswith("\"") or value.startswith("'")):
+                            errors.append(TypeValidationError(
+                                message="Type mismatch in assignment",
+                                file_path=func.filepath,
+                                line_number=i + 1,  # 1-based, relative to the function's code block
+                                function_name=func.name,
+                                variable_name=var_name,
+                                expected_type=type_annotation,
+                                actual_type="str",
+                                fix_suggestion=f"Ensure the assigned value is of type '{type_annotation}'"
+                            ))
+                        elif type_annotation == "str" and value.isdigit():
+                            errors.append(TypeValidationError(
+                                message="Type mismatch in assignment",
+                                file_path=func.filepath,
+                                line_number=i + 1,
+                                function_name=func.name,
+                                variable_name=var_name,
+                                expected_type=type_annotation,
+                                actual_type="int",
+                                fix_suggestion=f"Ensure the assigned value is of type '{type_annotation}'"
+                            ))
+
+        return errors
+
+    def check_return_compatibility(self, func: Function) -> List[TypeValidationError]:
+        """
+        Check type compatibility in return statements within a function.
+
+        Args:
+            func: The function to check
+
+        Returns:
+            A list of type validation errors
+        """
+        errors = []
+
+        if not hasattr(func, "return_type") or not func.return_type or not hasattr(func, "return_statements"):
+            return errors
+
+        return_type = func.return_type
+
+        for ret in func.return_statements:
+            if hasattr(ret, "value") and hasattr(ret.value, "type_annotation") and ret.value.type_annotation:
+                ret_type = ret.value.type_annotation
+
+                # Check if return type matches declared return type
+                if ret_type != return_type:
+                    errors.append(TypeValidationError(
+                        message="Return type mismatch",
+                        file_path=func.filepath,
+                        function_name=func.name,
+                        expected_type=return_type,
+                        actual_type=ret_type,
+                        fix_suggestion=f"Ensure the return value is of type '{return_type}'"
+                    ))
+
+        return errors
+
+    def check_parameter_compatibility(self, func: Function) -> List[TypeValidationError]:
+        """
+        Check type compatibility in function calls within a function.
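+
+        Compares each argument's annotation against the corresponding
+        parameter annotation when both are available.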
+
+        Args:
+            func: The function to check
+
+        Returns:
+            A list of type validation errors
+        """
+        errors = []
+
+        if not hasattr(func, "function_calls"):
+            return errors
+
+        for call in func.function_calls:
+            if hasattr(call, "target") and hasattr(call.target, "parameters") and hasattr(call, "arguments"):
+                for i, arg in enumerate(call.arguments):
+                    if i < len(call.target.parameters) and hasattr(arg, "type_annotation") and hasattr(call.target.parameters[i], "type_annotation"):
+                        arg_type = arg.type_annotation
+                        param_type = call.target.parameters[i].type_annotation
+
+                        if arg_type and param_type and arg_type != param_type:
+                            errors.append(TypeValidationError(
+                                message="Argument type mismatch",
+                                file_path=func.filepath,
+                                function_name=func.name,
+                                variable_name=call.target.parameters[i].name,
+                                expected_type=param_type,
+                                actual_type=arg_type,
+                                fix_suggestion=f"Ensure the argument is of type '{param_type}'"
+                            ))
+
+        return errors
+
+    def check_all(self) -> List[TypeValidationError]:
+        """
+        Check type compatibility in the entire codebase.
+
+        Returns:
+            A list of all type validation errors
+        """
+        self.errors = []
+
+        for func in self.codebase.functions:
+            self.errors.extend(self.check_assignment_compatibility(func))
+            self.errors.extend(self.check_return_compatibility(func))
+            self.errors.extend(self.check_parameter_compatibility(func))
+
+        return self.errors
+
+
+class TypeInference:
+    """Infers types for variables and expressions in the codebase."""
+
+    def __init__(self, codebase: Codebase):
+        """
+        Initialize the type inference engine.
+
+        Args:
+            codebase: The codebase to analyze
+        """
+        self.codebase = codebase
+        self.inferred_types: Dict[str, Dict[str, str]] = {}
+
+    def infer_variable_types(self, func: Function) -> Dict[str, str]:
+        """
+        Infer types for variables in a function.
+
+        Args:
+            func: The function to analyze
+
+        Returns:
+            A dictionary mapping variable names to inferred types
+        """
+        inferred = {}
+
+        if not hasattr(func, "code_block") or not hasattr(func.code_block, "source"):
+            return inferred
+
+        source_lines = func.code_block.source.splitlines()
+
+        for line in source_lines:
+            line = line.strip()
+
+            # Infer types from simple assignments
+            if "=" in line and not line.startswith(("#", "\"\"\"", "'''", "def ", "class ", "if ", "for ", "while ")):
+                parts = line.split("=", 1)
+                var_name = parts[0].strip()
+                value = parts[1].strip()
+
+                # Skip comparisons and augmented assignments (==, +=, etc.),
+                # which would otherwise produce garbage variable names
+                if not var_name.isidentifier():
+                    continue
+
+                # Infer type from literal values
+                if value.isdigit():
+                    inferred[var_name] = "int"
+                elif value.startswith("\"") or value.startswith("'"):
+                    inferred[var_name] = "str"
+                elif value in ("True", "False"):
+                    inferred[var_name] = "bool"
+                elif value.startswith("[") and value.endswith("]"):
+                    inferred[var_name] = "list"
+                elif value.startswith("{") and value.endswith("}"):
+                    # "{}" is an empty dict; a brace literal without ":" is a set
+                    if ":" in value or value == "{}":
+                        inferred[var_name] = "dict"
+                    else:
+                        inferred[var_name] = "set"
+                elif value.startswith("(") and value.endswith(")"):
+                    inferred[var_name] = "tuple"
+                elif value == "None":
+                    inferred[var_name] = "None"
+
+        return inferred
+
+    def infer_all_types(self) -> Dict[str, Dict[str, str]]:
+        """
+        Infer types for variables in all functions.
+
+        Returns:
+            A dictionary mapping function names to dictionaries of inferred types
+        """
+        self.inferred_types = {}
+
+        for func in self.codebase.functions:
+            if hasattr(func, "name"):
+                self.inferred_types[func.name] = self.infer_variable_types(func)
+
+        return self.inferred_types
+
+    def suggest_type_annotations(self) -> Dict[str, Dict[str, str]]:
+        """
+        Suggest type annotations for variables without annotations.
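+
+        Suggestions come from the literal-based inference performed by
+        infer_variable_types().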
+ + Returns: + A dictionary mapping function names to dictionaries of suggested types + """ + suggestions = {} + + # Infer types for all variables + self.infer_all_types() + + for func in self.codebase.functions: + if not hasattr(func, "name") or not hasattr(func, "variables"): + continue + + func_suggestions = {} + + for var in func.variables: + if hasattr(var, "name") and not hasattr(var, "type_annotation"): + var_name = var.name + + # Check if we have an inferred type for this variable + if func.name in self.inferred_types and var_name in self.inferred_types[func.name]: + func_suggestions[var_name] = self.inferred_types[func.name][var_name] + + if func_suggestions: + suggestions[func.name] = func_suggestions + + return suggestions + + +class TypeValidation: + """Main class for type validation.""" + + def __init__(self, codebase: Codebase): + """ + Initialize the type validator. + + Args: + codebase: The codebase to analyze + """ + self.codebase = codebase + self.annotation_validator = TypeAnnotationValidator(codebase) + self.compatibility_checker = TypeCompatibilityChecker(codebase) + self.type_inference = TypeInference(codebase) + + def validate_annotations(self) -> List[TypeValidationError]: + """ + Validate type annotations in the codebase. + + Returns: + A list of type validation errors + """ + return self.annotation_validator.validate_all() + + def check_compatibility(self) -> List[TypeValidationError]: + """ + Check type compatibility in the codebase. + + Returns: + A list of type validation errors + """ + return self.compatibility_checker.check_all() + + def infer_types(self) -> Dict[str, Dict[str, str]]: + """ + Infer types for variables in the codebase. + + Returns: + A dictionary of inferred types + """ + return self.type_inference.infer_all_types() + + def suggest_annotations(self) -> Dict[str, Dict[str, str]]: + """ + Suggest type annotations for variables without annotations. + + Returns: + A dictionary of suggested type annotations + """ + return self.type_inference.suggest_type_annotations() + + def get_annotation_coverage(self) -> Dict[str, float]: + """ + Get type annotation coverage for the codebase. + + Returns: + A dictionary with coverage percentages + """ + return self.annotation_validator.get_annotation_coverage() + + def validate_all(self) -> Dict[str, Any]: + """ + Perform comprehensive type validation. + + Returns: + A dictionary with all validation results + """ + return { + "annotation_errors": [error.to_dict() for error in self.validate_annotations()], + "compatibility_errors": [error.to_dict() for error in self.check_compatibility()], + "inferred_types": self.infer_types(), + "suggested_annotations": self.suggest_annotations(), + "annotation_coverage": self.get_annotation_coverage() + } + diff --git a/codegen-on-oss/codegen_on_oss/metrics.py b/codegen-on-oss/codegen_on_oss/metrics.py index d77b4e686..d81d5b20b 100644 --- a/codegen-on-oss/codegen_on_oss/metrics.py +++ b/codegen-on-oss/codegen_on_oss/metrics.py @@ -1,15 +1,36 @@ +""" +Metrics module for Codegen-on-OSS + +This module provides tools for measuring and recording performance metrics +and code quality metrics for codebases. 
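+
+Example (an illustrative sketch; assumes the SDK can load the target
+repository):
+
+    from codegen import Codebase
+    from codegen_on_oss.metrics import CodeMetrics
+
+    metrics = CodeMetrics(Codebase.from_repo("owner/repo"))
+    print(metrics.get_code_quality_summary()["overall_metrics"])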
+""" + import json import os import time +import math from collections.abc import Generator from contextlib import contextmanager from importlib.metadata import version -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import psutil +import networkx as nx +from codegen import Codebase from codegen_on_oss.errors import ParseRunError from codegen_on_oss.outputs.base import BaseOutput +from codegen_on_oss.analysis.analysis import ( + CodeAnalyzer, + calculate_cyclomatic_complexity, + calculate_halstead_volume, + calculate_maintainability_index, + count_lines, + get_operators_and_operands, + cc_rank, + get_maintainability_rank, + calculate_doi +) if TYPE_CHECKING: # Logger only available in type checking context. @@ -19,6 +40,478 @@ codegen_version = str(version("codegen")) +class CodeMetrics: + """ + A class to calculate and provide code quality metrics for a codebase. + Integrates with the analysis module for comprehensive code analysis. + """ + + # Constants for threshold values + COMPLEXITY_THRESHOLD = 10 + MAINTAINABILITY_THRESHOLD = 65 + INHERITANCE_DEPTH_THRESHOLD = 3 + VOLUME_THRESHOLD = 1000 + EFFORT_THRESHOLD = 50000 + BUG_THRESHOLD = 0.5 + + def __init__(self, codebase: Codebase): + """ + Initialize the CodeMetrics class with a codebase. + + Args: + codebase: The Codebase object to analyze + """ + self.codebase = codebase + self.analyzer = CodeAnalyzer(codebase) + self._complexity_metrics = None + self._line_metrics = None + self._maintainability_metrics = None + self._inheritance_metrics = None + self._halstead_metrics = None + + def calculate_all_metrics(self) -> Dict[str, Any]: + """ + Calculate all available metrics for the codebase. + + Returns: + A dictionary containing all metrics categories + """ + return { + "complexity": self.complexity_metrics, + "lines": self.line_metrics, + "maintainability": self.maintainability_metrics, + "inheritance": self.inheritance_metrics, + "halstead": self.halstead_metrics, + } + + @property + def complexity_metrics(self) -> Dict[str, Any]: + """ + Calculate cyclomatic complexity metrics for the codebase. + + Returns: + A dictionary containing complexity metrics including average, + rank, and per-function complexity scores + """ + if self._complexity_metrics is not None: + return self._complexity_metrics + + callables = self.codebase.functions + [ + m for c in self.codebase.classes for m in c.methods + ] + + complexities = [] + for func in callables: + if not hasattr(func, "code_block"): + continue + + complexity = calculate_cyclomatic_complexity(func) + complexities.append({ + "name": func.name, + "complexity": complexity, + "rank": cc_rank(complexity) + }) + + avg_complexity = ( + sum(item["complexity"] for item in complexities) / len(complexities) + if complexities else 0 + ) + + self._complexity_metrics = { + "average": avg_complexity, + "rank": cc_rank(avg_complexity), + "functions": complexities + } + + return self._complexity_metrics + + @property + def line_metrics(self) -> Dict[str, Any]: + """ + Calculate line-based metrics for the codebase. 
+ + Returns: + A dictionary containing line metrics including total counts + and per-file metrics for LOC, LLOC, SLOC, and comments + """ + if self._line_metrics is not None: + return self._line_metrics + + total_loc = total_lloc = total_sloc = total_comments = 0 + file_metrics = [] + + for file in self.codebase.files: + loc, lloc, sloc, comments = count_lines(file.source) + comment_density = (comments / loc * 100) if loc > 0 else 0 + + file_metrics.append({ + "file": file.path, + "loc": loc, + "lloc": lloc, + "sloc": sloc, + "comments": comments, + "comment_density": comment_density + }) + + total_loc += loc + total_lloc += lloc + total_sloc += sloc + total_comments += comments + + total_comment_density = ( + total_comments / total_loc * 100 if total_loc > 0 else 0 + ) + + self._line_metrics = { + "total": { + "loc": total_loc, + "lloc": total_lloc, + "sloc": total_sloc, + "comments": total_comments, + "comment_density": total_comment_density + }, + "files": file_metrics + } + + return self._line_metrics + + @property + def maintainability_metrics(self) -> Dict[str, Any]: + """ + Calculate maintainability index metrics for the codebase. + + Returns: + A dictionary containing maintainability metrics including average, + rank, and per-function maintainability scores + """ + if self._maintainability_metrics is not None: + return self._maintainability_metrics + + callables = self.codebase.functions + [ + m for c in self.codebase.classes for m in c.methods + ] + + mi_scores = [] + for func in callables: + if not hasattr(func, "code_block"): + continue + + complexity = calculate_cyclomatic_complexity(func) + operators, operands = get_operators_and_operands(func) + volume, _, _, _, _ = calculate_halstead_volume(operators, operands) + loc = len(func.code_block.source.splitlines()) + mi_score = calculate_maintainability_index(volume, complexity, loc) + + mi_scores.append({ + "name": func.name, + "mi_score": mi_score, + "rank": get_maintainability_rank(mi_score) + }) + + avg_mi = ( + sum(item["mi_score"] for item in mi_scores) / len(mi_scores) + if mi_scores else 0 + ) + + self._maintainability_metrics = { + "average": avg_mi, + "rank": get_maintainability_rank(avg_mi), + "functions": mi_scores + } + + return self._maintainability_metrics + + @property + def inheritance_metrics(self) -> Dict[str, Any]: + """ + Calculate inheritance metrics for the codebase. + + Returns: + A dictionary containing inheritance metrics including average + depth of inheritance and per-class inheritance depth + """ + if self._inheritance_metrics is not None: + return self._inheritance_metrics + + class_metrics = [] + for cls in self.codebase.classes: + doi = calculate_doi(cls) + class_metrics.append({ + "name": cls.name, + "doi": doi + }) + + avg_doi = ( + sum(item["doi"] for item in class_metrics) / len(class_metrics) + if class_metrics else 0 + ) + + self._inheritance_metrics = { + "average": avg_doi, + "classes": class_metrics + } + + return self._inheritance_metrics + + @property + def halstead_metrics(self) -> Dict[str, Any]: + """ + Calculate Halstead complexity metrics for the codebase. 
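+
+        Derived metrics follow the standard Halstead definitions:
+        difficulty D = (n1 / 2) * (N2 / n2) for n1 distinct operators,
+        n2 distinct operands, and N2 total operands; effort E = D * V;
+        estimated time E / 18 seconds; estimated bugs V / 3000.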
+
+        Returns:
+            A dictionary containing Halstead metrics including volume,
+            difficulty, effort, and other Halstead measures
+        """
+        if self._halstead_metrics is not None:
+            return self._halstead_metrics
+
+        callables = self.codebase.functions + [
+            m for c in self.codebase.classes for m in c.methods
+        ]
+
+        halstead_metrics = []
+        for func in callables:
+            if not hasattr(func, "code_block"):
+                continue
+
+            operators, operands = get_operators_and_operands(func)
+            volume, n1, n2, n_operators, n_operands = calculate_halstead_volume(
+                operators, operands
+            )
+
+            # Calculate additional Halstead metrics using the standard
+            # formulas: D = (n1 / 2) * (N2 / n2), E = D * V, T = E / 18,
+            # B = V / 3000, where n1/n2 are the distinct operator/operand
+            # counts and n_operands (N2) is the total operand count
+            difficulty = (
+                (n1 / 2) * (n_operands / n2) if n2 > 0 else 0
+            )
+            effort = difficulty * volume if volume > 0 else 0
+            time_required = effort / 18 if effort > 0 else 0  # Seconds
+            bugs_delivered = volume / 3000 if volume > 0 else 0
+
+            halstead_metrics.append({
+                "name": func.name,
+                "volume": volume,
+                "difficulty": difficulty,
+                "effort": effort,
+                "time_required": time_required,  # in seconds
+                "bugs_delivered": bugs_delivered
+            })
+
+        avg_volume = (
+            sum(item["volume"] for item in halstead_metrics) / len(halstead_metrics)
+            if halstead_metrics else 0
+        )
+        avg_difficulty = (
+            sum(item["difficulty"] for item in halstead_metrics) / len(halstead_metrics)
+            if halstead_metrics else 0
+        )
+        avg_effort = (
+            sum(item["effort"] for item in halstead_metrics) / len(halstead_metrics)
+            if halstead_metrics else 0
+        )
+
+        self._halstead_metrics = {
+            "average": {
+                "volume": avg_volume,
+                "difficulty": avg_difficulty,
+                "effort": avg_effort
+            },
+            "functions": halstead_metrics
+        }
+
+        return self._halstead_metrics
+
+    def find_complex_functions(self, threshold: int = COMPLEXITY_THRESHOLD) -> List[Dict[str, Any]]:
+        """
+        Find functions with cyclomatic complexity above the threshold.
+
+        Args:
+            threshold: The complexity threshold (default: 10)
+
+        Returns:
+            A list of functions with complexity above the threshold
+        """
+        metrics = self.complexity_metrics
+        return [
+            func for func in metrics["functions"]
+            if func["complexity"] > threshold
+        ]
+
+    def find_low_maintainability_functions(
+        self, threshold: int = MAINTAINABILITY_THRESHOLD
+    ) -> List[Dict[str, Any]]:
+        """
+        Find functions with maintainability index below the threshold.
+
+        Args:
+            threshold: The maintainability threshold (default: 65)
+
+        Returns:
+            A list of functions with maintainability below the threshold
+        """
+        metrics = self.maintainability_metrics
+        return [
+            func for func in metrics["functions"]
+            if func["mi_score"] < threshold
+        ]
+
+    def find_deep_inheritance_classes(
+        self, threshold: int = INHERITANCE_DEPTH_THRESHOLD
+    ) -> List[Dict[str, Any]]:
+        """
+        Find classes with depth of inheritance above the threshold.
+
+        Args:
+            threshold: The inheritance depth threshold (default: 3)
+
+        Returns:
+            A list of classes with inheritance depth above the threshold
+        """
+        metrics = self.inheritance_metrics
+        return [cls for cls in metrics["classes"] if cls["doi"] > threshold]
+
+    def find_high_volume_functions(self, threshold: int = VOLUME_THRESHOLD) -> List[Dict[str, Any]]:
+        """
+        Find functions with Halstead volume above the threshold.
+ + Args: + threshold: The volume threshold (default: 1000) + + Returns: + A list of functions with volume above the threshold + """ + metrics = self.halstead_metrics + return [ + func for func in metrics["functions"] + if func["volume"] > threshold + ] + + def find_high_effort_functions(self, threshold: int = EFFORT_THRESHOLD) -> List[Dict[str, Any]]: + """ + Find functions with high Halstead effort (difficult to maintain). + + Args: + threshold: The effort threshold (default: 50000) + + Returns: + A list of functions with effort above the threshold + """ + metrics = self.halstead_metrics + return [ + func for func in metrics["functions"] + if func["effort"] > threshold + ] + + def find_bug_prone_functions(self, threshold: float = BUG_THRESHOLD) -> List[Dict[str, Any]]: + """ + Find functions with high estimated bug delivery. + + Args: + threshold: The bugs delivered threshold (default: 0.5) + + Returns: + A list of functions likely to contain bugs + """ + metrics = self.halstead_metrics + return [ + func for func in metrics["functions"] + if func["bugs_delivered"] > threshold + ] + + def get_code_quality_summary(self) -> Dict[str, Any]: + """ + Generate a comprehensive code quality summary. + + Returns: + A dictionary with overall code quality metrics and problem areas + """ + return { + "overall_metrics": { + "complexity": self.complexity_metrics["average"], + "complexity_rank": self.complexity_metrics["rank"], + "maintainability": self.maintainability_metrics["average"], + "maintainability_rank": self.maintainability_metrics["rank"], + "lines_of_code": self.line_metrics["total"]["loc"], + "comment_density": self.line_metrics["total"]["comment_density"], + "inheritance_depth": self.inheritance_metrics["average"], + "halstead_volume": self.halstead_metrics["average"]["volume"], + "halstead_difficulty": self.halstead_metrics["average"]["difficulty"], + }, + "problem_areas": { + "complex_functions": len(self.find_complex_functions()), + "low_maintainability": len(self.find_low_maintainability_functions()), + "deep_inheritance": len(self.find_deep_inheritance_classes()), + "high_volume": len(self.find_high_volume_functions()), + "high_effort": len(self.find_high_effort_functions()), + "bug_prone": len(self.find_bug_prone_functions()), + }, + "import_analysis": self.analyzer.analyze_imports() + } + + def analyze_codebase_structure(self) -> Dict[str, Any]: + """ + Analyze the structure of the codebase. + + Returns: + A dictionary with codebase structure information + """ + return { + "summary": self.analyzer.get_codebase_summary(), + "files": len(self.codebase.files), + "functions": len(self.codebase.functions), + "classes": len(self.codebase.classes), + "imports": len(self.codebase.imports), + "symbols": len(self.codebase.symbols) + } + + def generate_documentation(self) -> None: + """ + Generate documentation for the codebase. + """ + self.analyzer.document_functions() + + def analyze_dependencies(self) -> Dict[str, Any]: + """ + Analyze dependencies in the codebase. 
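+
+        Builds a file-level import graph with networkx, then reports its
+        size, import cycles, and the most central files by degree
+        centrality.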
+
+        Returns:
+            A dictionary with dependency analysis results
+        """
+        # Create a dependency graph
+        G = nx.DiGraph()
+
+        # Add nodes for all files, keyed by filepath to match the edges below
+        for file in self.codebase.files:
+            G.add_node(file.filepath)
+
+        # Add edges for imports
+        for imp in self.codebase.imports:
+            if imp.from_file and imp.to_file:
+                G.add_edge(imp.from_file.filepath, imp.to_file.filepath)
+
+        # Find cycles
+        cycles = list(nx.simple_cycles(G))
+
+        # Calculate centrality metrics
+        centrality = nx.degree_centrality(G)
+
+        return {
+            "dependency_graph": {
+                "nodes": len(G.nodes),
+                "edges": len(G.edges),
+                "density": nx.density(G)
+            },
+            "cycles": len(cycles),
+            "most_central_files": sorted(
+                centrality.items(),
+                key=lambda x: x[1],
+                reverse=True
+            )[:10]
+        }
+
+
 class MetricsProfiler:
     """
     A helper to record performance metrics across multiple profiles and write them to a CSV.
@@ -42,7 +535,7 @@ def __init__(self, output: BaseOutput):
     @contextmanager
     def start_profiler(
         self, name: str, revision: str, language: str | None, logger: "Logger"
-    ) -> Generator["MetricsProfile", None, None]:
+    ) -> Generator[Any, None, None]:
         """
         Starts a new profiling session for a given profile name.
         Returns a MetricsProfile instance that you can use to mark measurements.
@@ -81,9 +574,9 @@ def fields(cls) -> list[str]:
 class MetricsProfile:
     """
     Context-managed profile that records measurements at each call to `measure()`.
-    It tracks the wall-clock duration, CPU time, and memory usage (with delta) at the time of the call.
-    Upon exiting the context, it also writes all collected metrics, including the total time,
-    to a CSV file.
+    It tracks the wall-clock duration, CPU time, and memory usage (with delta)
+    at the time of the call. Upon exiting the context, it also writes all collected
+    metrics, including the total time, to a CSV file.
     """
 
     if TYPE_CHECKING:
@@ -131,7 +624,9 @@ def measure(self, action_name: str):
         """
         current_time = time.perf_counter()
         current_cpu = float(time.process_time())
-        current_mem = int(psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024))
+        current_mem = int(
+            psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
+        )
 
         # Calculate time deltas.
         delta_time = current_time - self.last_measure_time
@@ -168,7 +663,9 @@ def finish(self, error: str | None = None):
         """
         finish_time = time.perf_counter()
         finish_cpu = float(time.process_time())
-        finish_mem = int(psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024))
+        finish_mem = int(
+            psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
+        )
 
         total_duration = finish_time - self.start_time
 
@@ -196,3 +693,4 @@ def write_output(self, measurement: dict[str, Any]):
         """
         self.logger.info(json.dumps(measurement, indent=4))
         self.output.write_output(measurement)
+