diff --git a/README.md b/README.md index 985a7f196..3c4d76a17 100644 --- a/README.md +++ b/README.md @@ -7,40 +7,47 @@ A powerful static code analysis system that provides extensive information about This analyzer provides comprehensive analysis of your codebase, including: ### 1. Codebase Structure Analysis + - File Statistics (count, language, size) - Symbol Tree Analysis - Import/Export Analysis - Module Organization ### 2. Symbol-Level Analysis + - Function Analysis (parameters, return types, complexity) - Class Analysis (methods, attributes, inheritance) - Variable Analysis - Type Analysis ### 3. Dependency and Flow Analysis + - Call Graph Generation - Data Flow Analysis - Control Flow Analysis - Symbol Usage Analysis ### 4. Code Quality Analysis + - Unused Code Detection - Code Duplication Analysis - Complexity Metrics - Style and Convention Analysis ### 5. Visualization Capabilities + - Dependency Graphs - Call Graphs - Symbol Trees - Heat Maps ### 6. Language-Specific Analysis + - Python-Specific Analysis - TypeScript-Specific Analysis ### 7. Code Metrics + - Monthly Commits - Cyclomatic Complexity - Halstead Volume @@ -49,12 +56,14 @@ This analyzer provides comprehensive analysis of your codebase, including: ## Installation 1. Clone the repository: + ```bash git clone https://github.com/yourusername/codebase-analyzer.git cd codebase-analyzer ``` 2. Install dependencies: + ```bash pip install -r requirements.txt ``` @@ -111,4 +120,3 @@ python codebase_analyzer.py --repo-url https://github.com/username/repo --output ## License MIT - diff --git a/codegen-examples/examples/snapshot_event_handler/README.md b/codegen-examples/examples/snapshot_event_handler/README.md index 8899580e1..447d154cd 100644 --- a/codegen-examples/examples/snapshot_event_handler/README.md +++ b/codegen-examples/examples/snapshot_event_handler/README.md @@ -1,6 +1,6 @@ # Event Handler with codebase snapshotting -This project is designed to using Modal snapshotting to provide parsed codebase instances with minimal latency, make it more manageable to write event based handlers. +This project is designed to using Modal snapshotting to provide parsed codebase instances with minimal latency, make it more manageable to write event based handlers. Follow the instructions below to set up and deploy the application. @@ -9,7 +9,7 @@ Follow the instructions below to set up and deploy the application. Before you begin, ensure you have the following installed and configured: 1. **uv**: A tool for managing virtual environments and syncing dependencies. -2. **Modal**: Ensure you have Modal configured on your system. +1. **Modal**: Ensure you have Modal configured on your system. ## Setup Instructions @@ -23,7 +23,7 @@ Before you begin, ensure you have the following installed and configured: source ./venv/bin/activate ``` -2. **Sync Dependencies** +1. **Sync Dependencies** Sync the project dependencies using `uv`: @@ -31,7 +31,7 @@ Before you begin, ensure you have the following installed and configured: uv sync ``` -3. **Deploy to Modal** +1. **Deploy to Modal** Deploy the application to Modal by running: @@ -48,7 +48,6 @@ Before you begin, ensure you have the following installed and configured: - `.env.template` and `.env`: Environment variable templates and configurations. - `pyproject.toml`: Project configuration and dependencies. - ## Integration -Once deployed, you can use the deployed web_url as the webhook endpoint for your slack, linear, or github webhooks. 
\ No newline at end of file +Once deployed, you can use the deployed web_url as the webhook endpoint for your slack, linear, or github webhooks. diff --git a/codegen-examples/examples/swebench_agent_run/local_run.ipynb b/codegen-examples/examples/swebench_agent_run/local_run.ipynb index f2f73c922..237732bbf 100644 --- a/codegen-examples/examples/swebench_agent_run/local_run.ipynb +++ b/codegen-examples/examples/swebench_agent_run/local_run.ipynb @@ -32,7 +32,14 @@ "metadata": {}, "outputs": [], "source": [ - "await run_eval(use_existing_preds=None, dataset=\"lite\", length=5, repo=\"django/django\", num_workers=10, model=\"claude-3-7-sonnet-latest\")" + "await run_eval(\n", + " use_existing_preds=None,\n", + " dataset=\"lite\",\n", + " length=5,\n", + " repo=\"django/django\",\n", + " num_workers=10,\n", + " model=\"claude-3-7-sonnet-latest\",\n", + ")" ] }, { @@ -76,7 +83,12 @@ "source": [ "from codegen.agents.code_agent import CodeAgent\n", "\n", - "agent = CodeAgent(codebase=codebase, tags=[\"local_test\"], model_name=\"claude-3-5-sonnet-latest\", model_provider=\"anthropic\")" + "agent = CodeAgent(\n", + " codebase=codebase,\n", + " tags=[\"local_test\"],\n", + " model_name=\"claude-3-5-sonnet-latest\",\n", + " model_provider=\"anthropic\",\n", + ")" ] }, { diff --git a/codegen-on-oss/codegen_on_oss/analyzers/README.md b/codegen-on-oss/codegen_on_oss/analyzers/README.md index e268fbd32..c62eafe97 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/README.md +++ b/codegen-on-oss/codegen_on_oss/analyzers/README.md @@ -25,6 +25,7 @@ analyzers/ ### 1. API Interface (`api.py`) The main entry point for frontend applications. Provides REST-like endpoints for: + - Codebase analysis - PR analysis - Dependency visualization @@ -34,6 +35,7 @@ The main entry point for frontend applications. Provides REST-like endpoints for ### 2. Analyzer System (`analyzer.py`) Plugin-based system that coordinates different types of analysis: + - Code quality analysis (complexity, maintainability) - Dependency analysis (imports, cycles, coupling) - PR impact analysis @@ -42,6 +44,7 @@ Plugin-based system that coordinates different types of analysis: ### 3. Issue Tracking (`issues.py`) Comprehensive issue model with: + - Severity levels (critical, error, warning, info) - Categories (dead code, complexity, dependency, etc.) - Location information and suggestions @@ -50,6 +53,7 @@ Comprehensive issue model with: ### 4. Dependency Analysis (`dependencies.py`) Analysis of codebase dependencies: + - Import dependencies between modules - Circular dependency detection - Module coupling analysis @@ -59,6 +63,7 @@ Analysis of codebase dependencies: ### 5. 
Code Quality Analysis (`code_quality.py`) Analysis of code quality aspects: + - Dead code detection (unused functions, variables) - Complexity metrics (cyclomatic, cognitive) - Parameter checking (types, usage) @@ -121,11 +126,7 @@ dependency_issues = api.get_issues(category="dependency_cycle") module_deps = api.get_module_dependencies(format="json") # Get function call graph -call_graph = api.get_function_call_graph( - function_name="main", - depth=3, - format="json" -) +call_graph = api.get_function_call_graph(function_name="main", depth=3, format="json") # Export visualization to file api.export_visualization(call_graph, format="html", filename="call_graph.html") @@ -197,52 +198,42 @@ For a web application exposing these endpoints with Flask: ```python from flask import Flask, request, jsonify -from codegen_on_oss.analyzers.api import ( - api_analyze_codebase, - api_analyze_pr, - api_get_visualization, - api_get_static_errors -) +from codegen_on_oss.analyzers.api import api_analyze_codebase, api_analyze_pr, api_get_visualization, api_get_static_errors app = Flask(__name__) + @app.route("/api/analyze/codebase", methods=["POST"]) def analyze_codebase(): data = request.json - result = api_analyze_codebase( - repo_path=data.get("repo_path"), - analysis_types=data.get("analysis_types") - ) + result = api_analyze_codebase(repo_path=data.get("repo_path"), analysis_types=data.get("analysis_types")) return jsonify(result) + @app.route("/api/analyze/pr", methods=["POST"]) def analyze_pr(): data = request.json - result = api_analyze_pr( - repo_path=data.get("repo_path"), - pr_number=data.get("pr_number") - ) + result = api_analyze_pr(repo_path=data.get("repo_path"), pr_number=data.get("pr_number")) return jsonify(result) + @app.route("/api/visualize", methods=["POST"]) def visualize(): data = request.json - result = api_get_visualization( - repo_path=data.get("repo_path"), - viz_type=data.get("viz_type"), - params=data.get("params", {}) - ) + result = api_get_visualization(repo_path=data.get("repo_path"), viz_type=data.get("viz_type"), params=data.get("params", {})) return jsonify(result) + @app.route("/api/issues", methods=["GET"]) def get_issues(): repo_path = request.args.get("repo_path") severity = request.args.get("severity") category = request.args.get("category") - + api = create_api(repo_path=repo_path) return jsonify(api.get_issues(severity=severity, category=category)) + if __name__ == "__main__": app.run(debug=True) -``` \ No newline at end of file +``` diff --git a/codegen-on-oss/codegen_on_oss/analyzers/__init__.py b/codegen-on-oss/codegen_on_oss/analyzers/__init__.py index f1ef5c5b4..5d4a9394f 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/__init__.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/__init__.py @@ -7,32 +7,43 @@ """ # Main API interface -from codegen_on_oss.analyzers.api import ( - CodegenAnalyzerAPI, - create_api, - api_analyze_codebase, - api_analyze_pr, - api_get_visualization, - api_get_static_errors -) - # Modern analyzer architecture from codegen_on_oss.analyzers.analyzer import ( AnalyzerManager, AnalyzerPlugin, AnalyzerRegistry, CodeQualityPlugin, - DependencyPlugin + DependencyPlugin, +) +from codegen_on_oss.analyzers.api import ( + CodegenAnalyzerAPI, + api_analyze_codebase, + api_analyze_pr, + api_get_static_errors, + api_get_visualization, + create_api, ) +# Legacy analyzer interfaces (for backward compatibility) +from codegen_on_oss.analyzers.base_analyzer import BaseCodeAnalyzer + +# Core analysis modules +from codegen_on_oss.analyzers.code_quality 
import CodeQualityAnalyzer +from codegen_on_oss.analyzers.codebase_analyzer import CodebaseAnalyzer +from codegen_on_oss.analyzers.dependencies import DependencyAnalyzer + +# Diff tracking +from codegen_on_oss.analyzers.diff_lite import ChangeType, DiffLite +from codegen_on_oss.analyzers.error_analyzer import CodebaseAnalyzer as ErrorAnalyzer + # Issue tracking system from codegen_on_oss.analyzers.issues import ( + AnalysisType, + CodeLocation, Issue, + IssueCategory, IssueCollection, IssueSeverity, - AnalysisType, - IssueCategory, - CodeLocation ) # Analysis result models @@ -40,54 +51,43 @@ AnalysisResult, CodeQualityResult, DependencyResult, - PrAnalysisResult + PrAnalysisResult, ) -# Core analysis modules -from codegen_on_oss.analyzers.code_quality import CodeQualityAnalyzer -from codegen_on_oss.analyzers.dependencies import DependencyAnalyzer - -# Legacy analyzer interfaces (for backward compatibility) -from codegen_on_oss.analyzers.base_analyzer import BaseCodeAnalyzer -from codegen_on_oss.analyzers.codebase_analyzer import CodebaseAnalyzer -from codegen_on_oss.analyzers.error_analyzer import CodebaseAnalyzer as ErrorAnalyzer - __all__ = [ - # Main API - 'CodegenAnalyzerAPI', - 'create_api', - 'api_analyze_codebase', - 'api_analyze_pr', - 'api_get_visualization', - 'api_get_static_errors', - - # Modern architecture - 'AnalyzerManager', - 'AnalyzerPlugin', - 'AnalyzerRegistry', - 'CodeQualityPlugin', - 'DependencyPlugin', - - # Issue tracking - 'Issue', - 'IssueCollection', - 'IssueSeverity', - 'AnalysisType', - 'IssueCategory', - 'CodeLocation', - # Analysis results - 'AnalysisResult', - 'CodeQualityResult', - 'DependencyResult', - 'PrAnalysisResult', - - # Core analyzers - 'CodeQualityAnalyzer', - 'DependencyAnalyzer', - + "AnalysisResult", + "AnalysisType", + # Modern architecture + "AnalyzerManager", + "AnalyzerPlugin", + "AnalyzerRegistry", # Legacy interfaces (for backward compatibility) - 'BaseCodeAnalyzer', - 'CodebaseAnalyzer', - 'ErrorAnalyzer', -] \ No newline at end of file + "BaseCodeAnalyzer", + # Diff tracking + "ChangeType", + "CodeLocation", + # Core analyzers + "CodeQualityAnalyzer", + "CodeQualityPlugin", + "CodeQualityResult", + "CodebaseAnalyzer", + # Main API + "CodegenAnalyzerAPI", + "DependencyAnalyzer", + "DependencyPlugin", + "DependencyResult", + "DiffLite", + "ErrorAnalyzer", + # Issue tracking + "Issue", + "IssueCategory", + "IssueCollection", + "IssueSeverity", + "PrAnalysisResult", + "api_analyze_codebase", + "api_analyze_pr", + "api_get_static_errors", + "api_get_visualization", + "create_api", +] diff --git a/codegen-on-oss/codegen_on_oss/analyzers/analyzer.py b/codegen-on-oss/codegen_on_oss/analyzers/analyzer.py index 4337bba5b..1b0e774cc 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/analyzer.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/analyzer.py @@ -7,35 +7,37 @@ It serves as the primary API entry point for the analyzer backend. 
""" -import os -import sys import json import logging +import sys from datetime import datetime -from pathlib import Path -from typing import Dict, List, Set, Tuple, Any, Optional, Union, Type, Callable -from enum import Enum +from typing import Any try: - from codegen.sdk.core.codebase import Codebase from codegen.configs.models.codebase import CodebaseConfig from codegen.configs.models.secrets import SecretsConfig - from codegen.sdk.codebase.config import ProjectConfig - from codegen.git.schemas.repo_config import RepoConfig from codegen.git.repo_operator.repo_operator import RepoOperator + from codegen.git.schemas.repo_config import RepoConfig + from codegen.sdk.codebase.config import ProjectConfig + from codegen.sdk.core.codebase import Codebase from codegen.shared.enums.programming_language import ProgrammingLanguage except ImportError: print("Codegen SDK not found. Please install it first.") sys.exit(1) # Import internal modules - these will be replaced with actual imports once implemented -from codegen_on_oss.analyzers.issues import Issue, IssueSeverity, AnalysisType, IssueCategory +from codegen_on_oss.analyzers.issues import ( + AnalysisType, + Issue, + IssueCategory, + IssueSeverity, +) # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) @@ -56,108 +58,115 @@ ".vscode", ] + class AnalyzerRegistry: """Registry of analyzer plugins.""" - + _instance = None - + def __new__(cls): if cls._instance is None: - cls._instance = super(AnalyzerRegistry, cls).__new__(cls) - cls._instance._analyzers = {} + cls._instance = super().__new__(cls) + cls._analyzers = {} return cls._instance - - def register(self, analysis_type: AnalysisType, analyzer_class: Type['AnalyzerPlugin']): + + def register( + self, analysis_type: AnalysisType, analyzer_class: type["AnalyzerPlugin"] + ): """Register an analyzer plugin.""" self._analyzers[analysis_type] = analyzer_class - - def get_analyzer(self, analysis_type: AnalysisType) -> Optional[Type['AnalyzerPlugin']]: + + def get_analyzer( + self, analysis_type: AnalysisType + ) -> type["AnalyzerPlugin"] | None: """Get the analyzer plugin for a specific analysis type.""" return self._analyzers.get(analysis_type) - - def list_analyzers(self) -> Dict[AnalysisType, Type['AnalyzerPlugin']]: + + def list_analyzers(self) -> dict[AnalysisType, type["AnalyzerPlugin"]]: """Get all registered analyzers.""" return self._analyzers.copy() + class AnalyzerPlugin: """Base class for analyzer plugins.""" - - def __init__(self, manager: 'AnalyzerManager'): + + def __init__(self, manager: "AnalyzerManager"): """Initialize the analyzer plugin.""" self.manager = manager self.issues = [] - - def analyze(self) -> Dict[str, Any]: + + def analyze(self) -> dict[str, Any]: """Perform analysis using this plugin.""" raise NotImplementedError("Analyzer plugins must implement analyze()") - + def add_issue(self, issue: Issue): """Add an issue to the list.""" self.manager.add_issue(issue) self.issues.append(issue) + class CodeQualityPlugin(AnalyzerPlugin): """Plugin for code quality analysis.""" - - def analyze(self) -> Dict[str, Any]: + + def analyze(self) -> dict[str, Any]: """Perform code quality analysis.""" # This is a simplified placeholder - would import and use code_quality.py result = { "dead_code": self._find_dead_code(), "complexity": 
self._analyze_complexity(), "maintainability": self._analyze_maintainability(), - "style_issues": self._analyze_style_issues() + "style_issues": self._analyze_style_issues(), } return result - - def _find_dead_code(self) -> Dict[str, Any]: + + def _find_dead_code(self) -> dict[str, Any]: """Find unused code in the codebase.""" # This is a placeholder return {"unused_functions": [], "unused_classes": [], "unused_variables": []} - - def _analyze_complexity(self) -> Dict[str, Any]: + + def _analyze_complexity(self) -> dict[str, Any]: """Analyze code complexity.""" # This is a placeholder return {"complex_functions": [], "average_complexity": 0} - - def _analyze_maintainability(self) -> Dict[str, Any]: + + def _analyze_maintainability(self) -> dict[str, Any]: """Analyze code maintainability.""" # This is a placeholder return {"maintainability_index": {}} - - def _analyze_style_issues(self) -> Dict[str, Any]: + + def _analyze_style_issues(self) -> dict[str, Any]: """Analyze code style issues.""" # This is a placeholder return {"style_violations": []} + class DependencyPlugin(AnalyzerPlugin): """Plugin for dependency analysis.""" - def analyze(self) -> Dict[str, Any]: + def analyze(self) -> dict[str, Any]: """Perform dependency analysis using the DependencyAnalyzer.""" - from codegen_on_oss.analyzers.dependencies import DependencyAnalyzer from codegen_on_oss.analyzers.codebase_context import CodebaseContext + from codegen_on_oss.analyzers.dependencies import DependencyAnalyzer # Create context if needed - context = getattr(self.manager, 'base_context', None) - if not context and hasattr(self.manager, 'base_codebase'): + context = getattr(self.manager, "base_context", None) + if not context and hasattr(self.manager, "base_codebase"): try: context = CodebaseContext( codebase=self.manager.base_codebase, base_path=self.manager.repo_path, pr_branch=None, - base_branch=self.manager.base_branch + base_branch=self.manager.base_branch, ) # Save context for future use self.manager.base_context = context - except Exception as e: - logger.error(f"Error initializing context: {e}") + except Exception: + logger.exception("Error initializing context") # Initialize and run the dependency analyzer if context: dependency_analyzer = DependencyAnalyzer( - codebase=self.manager.base_codebase, - context=context + codebase=self.manager.base_codebase, context=context ) # Run analysis @@ -173,43 +182,44 @@ def analyze(self) -> Dict[str, Any]: result = { "import_dependencies": self._analyze_imports(), "circular_dependencies": self._find_circular_dependencies(), - "module_coupling": self._analyze_module_coupling() + "module_coupling": self._analyze_module_coupling(), } return result - def _analyze_imports(self) -> Dict[str, Any]: + def _analyze_imports(self) -> dict[str, Any]: """Fallback import analysis if context initialization failed.""" return {"module_dependencies": [], "external_dependencies": []} - def _find_circular_dependencies(self) -> Dict[str, Any]: + def _find_circular_dependencies(self) -> dict[str, Any]: """Fallback circular dependencies analysis if context initialization failed.""" return {"circular_imports": []} - def _analyze_module_coupling(self) -> Dict[str, Any]: + def _analyze_module_coupling(self) -> dict[str, Any]: """Fallback module coupling analysis if context initialization failed.""" return {"high_coupling_modules": []} + class AnalyzerManager: """ Unified manager for codebase analysis. 
- + This class serves as the main entry point for all analysis operations, coordinating different analyzer plugins and managing results. """ - + def __init__( self, - repo_url: Optional[str] = None, - repo_path: Optional[str] = None, + repo_url: str | None = None, + repo_path: str | None = None, base_branch: str = "main", - pr_number: Optional[int] = None, - language: Optional[str] = None, - file_ignore_list: Optional[List[str]] = None, - config: Optional[Dict[str, Any]] = None + pr_number: int | None = None, + language: str | None = None, + file_ignore_list: list[str] | None = None, + config: dict[str, Any] | None = None, ): """ Initialize the analyzer manager. - + Args: repo_url: URL of the repository to analyze repo_path: Local path to the repository to analyze @@ -224,88 +234,89 @@ def __init__( self.base_branch = base_branch self.pr_number = pr_number self.language = language - + # Use custom ignore list or default global list self.file_ignore_list = file_ignore_list or GLOBAL_FILE_IGNORE_LIST - + # Configuration options self.config = config or {} - + # Codebase and context objects self.base_codebase = None self.pr_codebase = None - + # Analysis results self.issues = [] self.results = {} - + # PR comparison data self.pr_diff = None self.commit_shas = None self.modified_symbols = None self.pr_branch = None - + # Initialize codebase(s) based on provided parameters if repo_url: self._init_from_url(repo_url, language) elif repo_path: self._init_from_path(repo_path, language) - + # If PR number is provided, initialize PR-specific data if self.pr_number is not None and self.base_codebase is not None: self._init_pr_data(self.pr_number) - + # Register default analyzers self._register_default_analyzers() - - def _init_from_url(self, repo_url: str, language: Optional[str] = None): + + def _init_from_url(self, repo_url: str, language: str | None = None): """Initialize codebase from a repository URL.""" try: # Extract repository information - if repo_url.endswith('.git'): + if repo_url.endswith(".git"): repo_url = repo_url[:-4] - - parts = repo_url.rstrip('/').split('/') + + parts = repo_url.rstrip("/").split("/") repo_name = parts[-1] owner = parts[-2] repo_full_name = f"{owner}/{repo_name}" - + # Create temporary directory for cloning import tempfile + tmp_dir = tempfile.mkdtemp(prefix="analyzer_") - + # Set up configuration config = CodebaseConfig( debug=False, allow_external=True, py_resolve_syspath=True, ) - + secrets = SecretsConfig() - + # Determine programming language prog_lang = None if language: prog_lang = ProgrammingLanguage(language.upper()) - + # Initialize the codebase logger.info(f"Initializing codebase from {repo_url}") - + self.base_codebase = Codebase.from_github( repo_full_name=repo_full_name, tmp_dir=tmp_dir, language=prog_lang, config=config, - secrets=secrets + secrets=secrets, ) - + logger.info(f"Successfully initialized codebase from {repo_url}") - - except Exception as e: - logger.error(f"Error initializing codebase from URL: {e}") + + except Exception: + logger.exception("Error initializing codebase from URL") raise - - def _init_from_path(self, repo_path: str, language: Optional[str] = None): + + def _init_from_path(self, repo_path: str, language: str | None = None): """Initialize codebase from a local repository path.""" try: # Set up configuration @@ -314,163 +325,167 @@ def _init_from_path(self, repo_path: str, language: Optional[str] = None): allow_external=True, py_resolve_syspath=True, ) - + secrets = SecretsConfig() - + # Initialize the codebase 
logger.info(f"Initializing codebase from {repo_path}") - + # Determine programming language prog_lang = None if language: prog_lang = ProgrammingLanguage(language.upper()) - + # Set up repository configuration repo_config = RepoConfig.from_repo_path(repo_path) repo_config.respect_gitignore = False repo_operator = RepoOperator(repo_config=repo_config, bot_commit=False) - + # Create project configuration project_config = ProjectConfig( repo_operator=repo_operator, - programming_language=prog_lang if prog_lang else None + programming_language=prog_lang if prog_lang else None, ) - + # Initialize codebase self.base_codebase = Codebase( - projects=[project_config], - config=config, - secrets=secrets + projects=[project_config], config=config, secrets=secrets ) - + logger.info(f"Successfully initialized codebase from {repo_path}") - - except Exception as e: - logger.error(f"Error initializing codebase from path: {e}") + + except Exception: + logger.exception("Error initializing codebase from path") raise - + def _init_pr_data(self, pr_number: int): """Initialize PR-specific data.""" try: logger.info(f"Fetching PR #{pr_number} data") result = self.base_codebase.get_modified_symbols_in_pr(pr_number) - + # Unpack the result tuple if len(result) >= 3: self.pr_diff, self.commit_shas, self.modified_symbols = result[:3] if len(result) >= 4: self.pr_branch = result[3] - + logger.info(f"Found {len(self.modified_symbols)} modified symbols in PR") - + # Initialize PR codebase self._init_pr_codebase() - - except Exception as e: - logger.error(f"Error initializing PR data: {e}") + + except Exception: + logger.exception("Error initializing PR data") raise - + def _init_pr_codebase(self): """Initialize PR codebase by checking out the PR branch.""" if not self.base_codebase or not self.pr_number: logger.error("Base codebase or PR number not initialized") return - + try: # Get PR data if not already fetched if not self.pr_branch: self._init_pr_data(self.pr_number) - + if not self.pr_branch: logger.error("Failed to get PR branch") return - + # Clone the base codebase self.pr_codebase = self.base_codebase - + # Checkout PR branch logger.info(f"Checking out PR branch: {self.pr_branch}") self.pr_codebase.checkout(self.pr_branch) - + logger.info("Successfully initialized PR codebase") - - except Exception as e: - logger.error(f"Error initializing PR codebase: {e}") + + except Exception: + logger.exception("Error initializing PR codebase") raise - + def _register_default_analyzers(self): """Register default analyzers.""" registry = AnalyzerRegistry() registry.register(AnalysisType.CODE_QUALITY, CodeQualityPlugin) registry.register(AnalysisType.DEPENDENCY, DependencyPlugin) - + def add_issue(self, issue: Issue): """Add an issue to the list.""" # Check if issue should be skipped if self._should_skip_issue(issue): return - + self.issues.append(issue) - + def _should_skip_issue(self, issue: Issue) -> bool: """Check if an issue should be skipped.""" # Skip issues in ignored files file_path = issue.file - + # Check against ignore list for pattern in self.file_ignore_list: if pattern in file_path: return True - + # Check if the file is a test file - if "test" in file_path.lower() or "tests" in file_path.lower(): + if ( + "test" in file_path.lower() or "tests" in file_path.lower() + ) and issue.severity in [IssueSeverity.INFO, IssueSeverity.WARNING]: # Skip low-severity issues in test files - if issue.severity in [IssueSeverity.INFO, IssueSeverity.WARNING]: - return True - + return False + return False - - def get_issues(self, 
severity: Optional[IssueSeverity] = None, category: Optional[IssueCategory] = None) -> List[Issue]: + + def get_issues( + self, + severity: IssueSeverity | None = None, + category: IssueCategory | None = None, + ) -> list[Issue]: """ Get all issues matching the specified criteria. - + Args: severity: Optional severity level to filter by category: Optional category to filter by - + Returns: List of matching issues """ filtered_issues = self.issues - + if severity: filtered_issues = [i for i in filtered_issues if i.severity == severity] - + if category: filtered_issues = [i for i in filtered_issues if i.category == category] - + return filtered_issues - + def analyze( - self, - analysis_types: Optional[List[Union[AnalysisType, str]]] = None, - output_file: Optional[str] = None, - output_format: str = "json" - ) -> Dict[str, Any]: + self, + analysis_types: list[AnalysisType | str] | None = None, + output_file: str | None = None, + output_format: str = "json", + ) -> dict[str, Any]: """ Perform analysis on the codebase. - + Args: analysis_types: List of analysis types to perform output_file: Path to save results to output_format: Format of the output file - + Returns: Dictionary containing analysis results """ if not self.base_codebase: - raise ValueError("Codebase not initialized") - + msg = "Codebase not initialized" + raise ValueError(msg) + # Convert string analysis types to enums if analysis_types: analysis_types = [ @@ -480,78 +495,94 @@ def analyze( else: # Default to code quality and dependency analysis analysis_types = [AnalysisType.CODE_QUALITY, AnalysisType.DEPENDENCY] - + # Initialize results self.results = { "metadata": { "analysis_time": datetime.now().isoformat(), "analysis_types": [t.value for t in analysis_types], - "repo_name": getattr(self.base_codebase.ctx, 'repo_name', None), - "language": str(getattr(self.base_codebase.ctx, 'programming_language', None)), + "repo_name": getattr(self.base_codebase.ctx, "repo_name", None), + "language": str( + getattr(self.base_codebase.ctx, "programming_language", None) + ), }, "summary": {}, - "results": {} + "results": {}, } - + # Reset issues self.issues = [] - + # Run each analyzer registry = AnalyzerRegistry() - + for analysis_type in analysis_types: analyzer_class = registry.get_analyzer(analysis_type) - + if analyzer_class: logger.info(f"Running {analysis_type.value} analysis") analyzer = analyzer_class(self) analysis_result = analyzer.analyze() - + # Add results to unified results self.results["results"][analysis_type.value] = analysis_result else: logger.warning(f"No analyzer found for {analysis_type.value}") - + # Add issues to results self.results["issues"] = [issue.to_dict() for issue in self.issues] - + # Add issue statistics self.results["issue_stats"] = { "total": len(self.issues), "by_severity": { - "critical": sum(1 for issue in self.issues if issue.severity == IssueSeverity.CRITICAL), - "error": sum(1 for issue in self.issues if issue.severity == IssueSeverity.ERROR), - "warning": sum(1 for issue in self.issues if issue.severity == IssueSeverity.WARNING), - "info": sum(1 for issue in self.issues if issue.severity == IssueSeverity.INFO), - } + "critical": sum( + 1 + for issue in self.issues + if issue.severity == IssueSeverity.CRITICAL + ), + "error": sum( + 1 for issue in self.issues if issue.severity == IssueSeverity.ERROR + ), + "warning": sum( + 1 + for issue in self.issues + if issue.severity == IssueSeverity.WARNING + ), + "info": sum( + 1 for issue in self.issues if issue.severity == IssueSeverity.INFO + ), + }, } - 
+ # Save results if output file is specified if output_file: self.save_results(output_file, output_format) - + return self.results - - def save_results(self, output_file: str, format: str = "json"): - """ - Save analysis results to a file. - + + def save_results(self, output_file: str, output_format: str = "json"): + """Save analysis results to a file. + Args: output_file: Path to the output file - format: Output format (json, html) + output_format: Format to save results in (json, yaml, markdown) """ - if format == "json": - with open(output_file, 'w') as f: - json.dump(self.results, f, indent=2) - elif format == "html": - self._generate_html_report(output_file) - else: - # Default to JSON - with open(output_file, 'w') as f: - json.dump(self.results, f, indent=2) - - logger.info(f"Results saved to {output_file}") - + with open(output_file, "w") as f: + if output_format == "json": + import json + + f.write(json.dumps(self.results.to_dict(), indent=2)) + elif output_format == "yaml": + import yaml + + f.write(yaml.dump(self.results.to_dict())) + elif output_format == "markdown": + f.write(self.generate_report("detailed")) + else: + msg = f"Unsupported format: {output_format}" + raise ValueError(msg) + def _generate_html_report(self, output_file: str): """Generate an HTML report of the analysis results.""" html_content = f""" @@ -578,307 +609,373 @@ def _generate_html_report(self, output_file: str):

             Codebase Analysis Report
             Summary
-            Repository: {self.results['metadata'].get('repo_name', 'Unknown')}
-            Language: {self.results['metadata'].get('language', 'Unknown')}
-            Analysis Time: {self.results['metadata'].get('analysis_time', 'Unknown')}
-            Analysis Types: {', '.join(self.results['metadata'].get('analysis_types', []))}
+            Repository: {self.results["metadata"].get("repo_name", "Unknown")}
+            Language: {self.results["metadata"].get("language", "Unknown")}
+            Analysis Time: {self.results["metadata"].get("analysis_time", "Unknown")}
+            Analysis Types: {", ".join(self.results["metadata"].get("analysis_types", []))}
             Total Issues: {len(self.issues)}
         """
-
+
         # Add issues section
         html_content += """
             Issues
         """
-
+
         # Add issues by severity
         for severity in ["critical", "error", "warning", "info"]:
-            severity_issues = [issue for issue in self.issues if issue.severity.value == severity]
-
+            severity_issues = [
+                issue for issue in self.issues if issue.severity.value == severity
+            ]
+
             if severity_issues:
                 html_content += f"""
             {severity.upper()} Issues ({len(severity_issues)})
         """
-
+
                 for issue in severity_issues:
-                    location = f"{issue.file}:{issue.line}" if issue.line else issue.file
-                    category = f"[{issue.category.value}]" if hasattr(issue, 'category') and issue.category else ""
-
+                    location = (
+                        f"{issue.file}:{issue.line}" if issue.line else issue.file
+                    )
+                    category = (
+                        f"[{issue.category.value}]"
+                        if hasattr(issue, "category") and issue.category
+                        else ""
+                    )
+
                     html_content += f"""
             {location} {category} {issue.message}
-            {issue.suggestion if hasattr(issue, 'suggestion') else ""}
+            {issue.suggestion if hasattr(issue, "suggestion") else ""}
         """
                 html_content += """
         """
-
+
         # Add detailed analysis sections
         html_content += """
             Detailed Analysis
         """
-
-        for analysis_type, results in self.results.get('results', {}).items():
+
+        for analysis_type, results in self.results.get("results", {}).items():
             html_content += f"""
             {analysis_type}
             {json.dumps(results, indent=2)}
         """
-
+
         html_content += """
""" - - with open(output_file, 'w') as f: + + with open(output_file, "w") as f: f.write(html_content) - + def generate_report(self, report_type: str = "summary") -> str: - """ - Generate a report from the analysis results. - + """Generate a report of the analysis results. + Args: report_type: Type of report to generate (summary, detailed, issues) - + Returns: Report as a string """ - if not self.results: - raise ValueError("No analysis results available") - if report_type == "summary": return self._generate_summary_report() elif report_type == "detailed": return self._generate_detailed_report() elif report_type == "issues": return self._generate_issues_report() - else: - raise ValueError(f"Unknown report type: {report_type}") - + + msg = f"Unknown report type: {report_type}" + raise ValueError(msg) + def _generate_summary_report(self) -> str: """Generate a summary report.""" report = "===== Codebase Analysis Summary Report =====\n\n" - + # Add metadata - report += f"Repository: {self.results['metadata'].get('repo_name', 'Unknown')}\n" + report += ( + f"Repository: {self.results['metadata'].get('repo_name', 'Unknown')}\n" + ) report += f"Language: {self.results['metadata'].get('language', 'Unknown')}\n" report += f"Analysis Time: {self.results['metadata'].get('analysis_time', 'Unknown')}\n" report += f"Analysis Types: {', '.join(self.results['metadata'].get('analysis_types', []))}\n\n" - + # Add issue statistics report += f"Total Issues: {len(self.issues)}\n" report += f"Critical: {self.results['issue_stats']['by_severity'].get('critical', 0)}\n" - report += f"Errors: {self.results['issue_stats']['by_severity'].get('error', 0)}\n" + report += ( + f"Errors: {self.results['issue_stats']['by_severity'].get('error', 0)}\n" + ) report += f"Warnings: {self.results['issue_stats']['by_severity'].get('warning', 0)}\n" - report += f"Info: {self.results['issue_stats']['by_severity'].get('info', 0)}\n\n" - + report += ( + f"Info: {self.results['issue_stats']['by_severity'].get('info', 0)}\n\n" + ) + # Add analysis summaries - for analysis_type, results in self.results.get('results', {}).items(): + for analysis_type, results in self.results.get("results", {}).items(): report += f"===== {analysis_type.upper()} Analysis =====\n" - + if analysis_type == "code_quality": if "dead_code" in results: dead_code = results["dead_code"] report += f"Dead Code: {len(dead_code.get('unused_functions', []))} unused functions, " - report += f"{len(dead_code.get('unused_classes', []))} unused classes\n" - + report += ( + f"{len(dead_code.get('unused_classes', []))} unused classes\n" + ) + if "complexity" in results: complexity = results["complexity"] report += f"Complexity: {len(complexity.get('complex_functions', []))} complex functions\n" - + elif analysis_type == "dependency": if "circular_dependencies" in results: circular = results["circular_dependencies"] report += f"Circular Dependencies: {len(circular.get('circular_imports', []))}\n" - + if "module_coupling" in results: coupling = results["module_coupling"] report += f"High Coupling Modules: {len(coupling.get('high_coupling_modules', []))}\n" - + report += "\n" - + return report - - def _generate_detailed_report(self) -> str: + + def _generate_detailed_report(self) -> str: # noqa: C901 """Generate a detailed report.""" report = "===== Codebase Analysis Detailed Report =====\n\n" - + # Add metadata - report += f"Repository: {self.results['metadata'].get('repo_name', 'Unknown')}\n" + report += ( + f"Repository: {self.results['metadata'].get('repo_name', 
'Unknown')}\n" + ) report += f"Language: {self.results['metadata'].get('language', 'Unknown')}\n" report += f"Analysis Time: {self.results['metadata'].get('analysis_time', 'Unknown')}\n" report += f"Analysis Types: {', '.join(self.results['metadata'].get('analysis_types', []))}\n\n" - + # Add detailed issue report report += "===== Issues =====\n\n" - + for severity in ["critical", "error", "warning", "info"]: - severity_issues = [issue for issue in self.issues if issue.severity.value == severity] - + severity_issues = [ + issue for issue in self.issues if issue.severity.value == severity + ] + if severity_issues: report += f"{severity.upper()} Issues ({len(severity_issues)}):\n" - + for issue in severity_issues: - location = f"{issue.file}:{issue.line}" if issue.line else issue.file - category = f"[{issue.category.value}]" if hasattr(issue, 'category') and issue.category else "" - + location = ( + f"{issue.file}:{issue.line}" if issue.line else issue.file + ) + category = ( + f"[{issue.category.value}]" + if hasattr(issue, "category") and issue.category + else "" + ) + report += f"- {location} {category} {issue.message}\n" - if hasattr(issue, 'suggestion') and issue.suggestion: + if hasattr(issue, "suggestion") and issue.suggestion: report += f" Suggestion: {issue.suggestion}\n" - + report += "\n" - + # Add detailed analysis - for analysis_type, results in self.results.get('results', {}).items(): + for analysis_type, results in self.results.get("results", {}).items(): report += f"===== {analysis_type.upper()} Analysis =====\n\n" - + # Format based on analysis type if analysis_type == "code_quality": # Dead code details if "dead_code" in results: dead_code = results["dead_code"] report += "Dead Code:\n" - - if dead_code.get('unused_functions'): + + if dead_code.get("unused_functions"): report += " Unused Functions:\n" - for func in dead_code.get('unused_functions', [])[:10]: # Limit to 10 + for func in dead_code.get("unused_functions", [])[ + :10 + ]: # Limit to 10 report += f" - {func.get('name')} ({func.get('file')})\n" - - if len(dead_code.get('unused_functions', [])) > 10: + + if len(dead_code.get("unused_functions", [])) > 10: report += f" ... and {len(dead_code.get('unused_functions', [])) - 10} more\n" - - if dead_code.get('unused_classes'): + + if dead_code.get("unused_classes"): report += " Unused Classes:\n" - for cls in dead_code.get('unused_classes', [])[:10]: # Limit to 10 + for cls in dead_code.get("unused_classes", [])[ + :10 + ]: # Limit to 10 report += f" - {cls.get('name')} ({cls.get('file')})\n" - - if len(dead_code.get('unused_classes', [])) > 10: + + if len(dead_code.get("unused_classes", [])) > 10: report += f" ... and {len(dead_code.get('unused_classes', [])) - 10} more\n" - + report += "\n" - + # Complexity details if "complexity" in results: complexity = results["complexity"] report += "Code Complexity:\n" - - if complexity.get('complex_functions'): + + if complexity.get("complex_functions"): report += " Complex Functions:\n" - for func in complexity.get('complex_functions', [])[:10]: # Limit to 10 + for func in complexity.get("complex_functions", [])[ + :10 + ]: # Limit to 10 report += f" - {func.get('name')} (Complexity: {func.get('complexity')}, {func.get('file')})\n" - - if len(complexity.get('complex_functions', [])) > 10: + + if len(complexity.get("complex_functions", [])) > 10: report += f" ... 
and {len(complexity.get('complex_functions', [])) - 10} more\n" - + report += "\n" - + elif analysis_type == "dependency": # Circular dependencies if "circular_dependencies" in results: circular = results["circular_dependencies"] report += "Circular Dependencies:\n" - - if circular.get('circular_imports'): - for i, cycle in enumerate(circular.get('circular_imports', [])[:5]): # Limit to 5 - report += f" Cycle {i+1} (Length: {cycle.get('length')}):\n" - for j, file_path in enumerate(cycle.get('files', [])): - report += f" {j+1}. {file_path}\n" - - if len(circular.get('circular_imports', [])) > 5: + + if circular.get("circular_imports"): + for i, cycle in enumerate( + circular.get("circular_imports", [])[:5] + ): # Limit to 5 + report += ( + f" Cycle {i + 1} (Length: {cycle.get('length')}):\n" + ) + for j, file_path in enumerate(cycle.get("files", [])): + report += f" {j + 1}. {file_path}\n" + + if len(circular.get("circular_imports", [])) > 5: report += f" ... and {len(circular.get('circular_imports', [])) - 5} more cycles\n" - + report += "\n" - + # Module coupling if "module_coupling" in results: coupling = results["module_coupling"] report += "Module Coupling:\n" - - if coupling.get('high_coupling_modules'): + + if coupling.get("high_coupling_modules"): report += " High Coupling Modules:\n" - for module in coupling.get('high_coupling_modules', [])[:10]: # Limit to 10 + for module in coupling.get("high_coupling_modules", [])[ + :10 + ]: # Limit to 10 report += f" - {module.get('module')} (Ratio: {module.get('coupling_ratio'):.2f})\n" - - if len(coupling.get('high_coupling_modules', [])) > 10: + + if len(coupling.get("high_coupling_modules", [])) > 10: report += f" ... and {len(coupling.get('high_coupling_modules', [])) - 10} more\n" - + report += "\n" - + return report - + def _generate_issues_report(self) -> str: """Generate an issues-focused report.""" report = "===== Codebase Analysis Issues Report =====\n\n" - + # Add issue statistics report += f"Total Issues: {len(self.issues)}\n" report += f"Critical: {self.results['issue_stats']['by_severity'].get('critical', 0)}\n" - report += f"Errors: {self.results['issue_stats']['by_severity'].get('error', 0)}\n" + report += ( + f"Errors: {self.results['issue_stats']['by_severity'].get('error', 0)}\n" + ) report += f"Warnings: {self.results['issue_stats']['by_severity'].get('warning', 0)}\n" - report += f"Info: {self.results['issue_stats']['by_severity'].get('info', 0)}\n\n" - + report += ( + f"Info: {self.results['issue_stats']['by_severity'].get('info', 0)}\n\n" + ) + # Add issues by severity for severity in ["critical", "error", "warning", "info"]: - severity_issues = [issue for issue in self.issues if issue.severity.value == severity] - + severity_issues = [ + issue for issue in self.issues if issue.severity.value == severity + ] + if severity_issues: report += f"{severity.upper()} Issues ({len(severity_issues)}):\n" - + for issue in severity_issues: - location = f"{issue.file}:{issue.line}" if issue.line else issue.file - category = f"[{issue.category.value}]" if hasattr(issue, 'category') and issue.category else "" - + location = ( + f"{issue.file}:{issue.line}" if issue.line else issue.file + ) + category = ( + f"[{issue.category.value}]" + if hasattr(issue, "category") and issue.category + else "" + ) + report += f"- {location} {category} {issue.message}\n" - if hasattr(issue, 'suggestion') and issue.suggestion: + if hasattr(issue, "suggestion") and issue.suggestion: report += f" Suggestion: {issue.suggestion}\n" - + report += "\n" - + 
return report + def main(): """Command-line entry point.""" import argparse - + parser = argparse.ArgumentParser(description="Unified Codebase Analyzer") - + # Repository source options source_group = parser.add_mutually_exclusive_group(required=True) source_group.add_argument("--repo-url", help="URL of the repository to analyze") - source_group.add_argument("--repo-path", help="Local path to the repository to analyze") - + source_group.add_argument( + "--repo-path", help="Local path to the repository to analyze" + ) + # Analysis options - parser.add_argument("--analysis-types", nargs="+", choices=[at.value for at in AnalysisType], - default=["code_quality", "dependency"], - help="Types of analysis to perform") - parser.add_argument("--language", choices=["python", "typescript"], - help="Programming language (auto-detected if not provided)") - parser.add_argument("--base-branch", default="main", - help="Base branch for PR comparison (default: main)") - parser.add_argument("--pr-number", type=int, - help="PR number to analyze") - + parser.add_argument( + "--analysis-types", + nargs="+", + choices=[at.value for at in AnalysisType], + default=["code_quality", "dependency"], + help="Types of analysis to perform", + ) + parser.add_argument( + "--language", + choices=["python", "typescript"], + help="Programming language (auto-detected if not provided)", + ) + parser.add_argument( + "--base-branch", + default="main", + help="Base branch for PR comparison (default: main)", + ) + parser.add_argument("--pr-number", type=int, help="PR number to analyze") + # Output options - parser.add_argument("--output-file", - help="Path to the output file") - parser.add_argument("--output-format", choices=["json", "html", "console"], default="json", - help="Output format") - parser.add_argument("--report-type", choices=["summary", "detailed", "issues"], default="summary", - help="Type of report to generate (default: summary)") - + parser.add_argument("--output-file", help="Path to the output file") + parser.add_argument( + "--output-format", + choices=["json", "yaml", "markdown", "console"], + default="json", + help="Output format", + ) + parser.add_argument( + "--report-type", + choices=["summary", "detailed", "issues"], + default="summary", + help="Type of report to generate (default: summary)", + ) + args = parser.parse_args() - + try: # Initialize the analyzer manager manager = AnalyzerManager( @@ -886,26 +983,28 @@ def main(): repo_path=args.repo_path, language=args.language, base_branch=args.base_branch, - pr_number=args.pr_number + pr_number=args.pr_number, ) - + # Run the analysis manager.analyze( analysis_types=args.analysis_types, output_file=args.output_file, - output_format=args.output_format + output_format=args.output_format, ) - + # Generate and print report if format is console if args.output_format == "console": report = manager.generate_report(args.report_type) print(report) - - except Exception as e: - logger.error(f"Error: {e}") + + except Exception: + logger.exception("Error") import traceback + traceback.print_exc() sys.exit(1) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/codegen-on-oss/codegen_on_oss/analyzers/analyzer_manager.py b/codegen-on-oss/codegen_on_oss/analyzers/analyzer_manager.py index 4458ee541..b4855dfcb 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/analyzer_manager.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/analyzer_manager.py @@ -6,22 +6,18 @@ It coordinates the execution of different analyzer types and aggregates their 
results. """ -import os -import sys -import json import logging -from typing import Dict, List, Set, Tuple, Any, Optional, Union, Type -from datetime import datetime -from pathlib import Path +import sys +from typing import Any try: - from codegen_on_oss.analyzers.unified_analyzer import ( - UnifiedCodeAnalyzer, - AnalyzerRegistry, - CodeQualityAnalyzerPlugin, - DependencyAnalyzerPlugin + from codegen_on_oss.analyzers.issue_types import ( + AnalysisType, + Issue, + IssueCategory, + IssueSeverity, ) - from codegen_on_oss.analyzers.issue_types import Issue, IssueSeverity, AnalysisType, IssueCategory + from codegen_on_oss.analyzers.unified_analyzer import UnifiedCodeAnalyzer except ImportError: print("Required analyzer modules not found.") sys.exit(1) @@ -29,29 +25,32 @@ # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) + class AnalyzerManager: """ Central manager for running different types of code analysis. - + This class provides a unified interface for running various analyzers and aggregating their results. """ - - def __init__(self, - repo_url: Optional[str] = None, - repo_path: Optional[str] = None, - language: Optional[str] = None, - base_branch: str = "main", - pr_number: Optional[int] = None, - config: Optional[Dict[str, Any]] = None): + + def __init__( + self, + repo_url: str | None = None, + repo_path: str | None = None, + language: str | None = None, + base_branch: str = "main", + pr_number: int | None = None, + config: dict[str, Any] | None = None, + ): """ Initialize the analyzer manager. - + Args: repo_url: URL of the repository to analyze repo_path: Local path to the repository to analyze @@ -66,7 +65,7 @@ def __init__(self, self.base_branch = base_branch self.pr_number = pr_number self.config = config or {} - + # Initialize the unified analyzer self.analyzer = UnifiedCodeAnalyzer( repo_url=repo_url, @@ -74,88 +73,98 @@ def __init__(self, base_branch=base_branch, pr_number=pr_number, language=language, - config=config + config=config, ) - + # Register additional analyzers (if any) self._register_custom_analyzers() - + def _register_custom_analyzers(self): """Register custom analyzers with the registry.""" # The default analyzers (CODE_QUALITY and DEPENDENCY) are registered automatically # This method can be overridden by subclasses to register additional analyzers pass - - def run_analysis(self, - analysis_types: Optional[List[AnalysisType]] = None, - output_file: Optional[str] = None, - output_format: str = "json") -> Dict[str, Any]: + + def run_analysis( + self, + analysis_types: list[AnalysisType] | None = None, + output_file: str | None = None, + output_format: str = "json", + ) -> dict[str, Any]: """ Run analysis on the codebase. 
- + Args: analysis_types: Types of analysis to run (defaults to CODE_QUALITY and DEPENDENCY) output_file: Path to save results to (None for no save) output_format: Format for output file (json, html, console) - + Returns: Dictionary containing analysis results """ # Default to code quality and dependency analysis if analysis_types is None: analysis_types = [AnalysisType.CODE_QUALITY, AnalysisType.DEPENDENCY] - + try: # Run the analysis - logger.info(f"Running analysis: {', '.join([at.value for at in analysis_types])}") + logger.info( + f"Running analysis: {', '.join([at.value for at in analysis_types])}" + ) results = self.analyzer.analyze(analysis_types) - + # Save results if output file is specified if output_file: logger.info(f"Saving results to {output_file}") self.analyzer.save_results(output_file, output_format) - - return results - - except Exception as e: - logger.error(f"Error running analysis: {e}") + + return results + else: + return results + + except Exception: + logger.exception("Error running analysis") import traceback + traceback.print_exc() raise - - def get_issues(self, - severity: Optional[IssueSeverity] = None, - category: Optional[IssueCategory] = None) -> List[Issue]: + + def get_issues( + self, + severity: IssueSeverity | None = None, + category: IssueCategory | None = None, + ) -> list[Issue]: """ Get issues from the analyzer. - + Args: severity: Filter issues by severity category: Filter issues by category - + Returns: List of issues matching the filters """ return self.analyzer.get_issues(severity, category) - - def generate_report(self, - report_type: str = "summary", - output_file: Optional[str] = None) -> str: + + def generate_report( + self, report_type: str = "summary", output_file: str | None = None + ) -> str: """ Generate a report from the analysis results. - + Args: report_type: Type of report to generate (summary, detailed, issues) output_file: Path to save report to (None for returning as string) - + Returns: Report as a string (if output_file is None) """ - if not hasattr(self.analyzer, 'results') or not self.analyzer.results: - raise ValueError("No analysis results available. Run analysis first.") - + if not hasattr(self.analyzer, "results") or not self.analyzer.results: + msg = "No analysis results available. Run analysis first." 
+ raise ValueError(msg) + report = "" - + if report_type == "summary": report = self._generate_summary_report() elif report_type == "detailed": @@ -163,266 +172,357 @@ def generate_report(self, elif report_type == "issues": report = self._generate_issues_report() else: - raise ValueError(f"Unknown report type: {report_type}") - + msg = f"Unknown report type: {report_type}" + raise ValueError(msg) + if output_file: - with open(output_file, 'w') as f: + with open(output_file, "w") as f: f.write(report) logger.info(f"Report saved to {output_file}") return "" else: return report - + def _generate_summary_report(self) -> str: """Generate a summary report of the analysis results.""" results = self.analyzer.results - + report = "===== Codebase Analysis Summary Report =====\n\n" - + # Add metadata report += "Metadata:\n" report += f" Repository: {results['metadata'].get('repo_name', 'Unknown')}\n" report += f" Language: {results['metadata'].get('language', 'Unknown')}\n" - report += f" Analysis Time: {results['metadata'].get('analysis_time', 'Unknown')}\n" + report += ( + f" Analysis Time: {results['metadata'].get('analysis_time', 'Unknown')}\n" + ) report += f" Analysis Types: {', '.join(results['metadata'].get('analysis_types', []))}\n" - + # Add issue statistics report += "\nIssue Statistics:\n" report += f" Total Issues: {results['issue_stats']['total']}\n" - report += f" Critical: {results['issue_stats']['by_severity'].get('critical', 0)}\n" + report += ( + f" Critical: {results['issue_stats']['by_severity'].get('critical', 0)}\n" + ) report += f" Errors: {results['issue_stats']['by_severity'].get('error', 0)}\n" - report += f" Warnings: {results['issue_stats']['by_severity'].get('warning', 0)}\n" + report += ( + f" Warnings: {results['issue_stats']['by_severity'].get('warning', 0)}\n" + ) report += f" Info: {results['issue_stats']['by_severity'].get('info', 0)}\n" - + # Add codebase summary - if 'summary' in results: + if "summary" in results: report += "\nCodebase Summary:\n" - summary = results['summary'] + summary = results["summary"] report += f" Files: {summary.get('file_count', 0)}\n" report += f" Lines of Code: {summary.get('total_loc', 0)}\n" report += f" Functions: {summary.get('function_count', 0)}\n" report += f" Classes: {summary.get('class_count', 0)}\n" - + # Add analysis summaries - for analysis_type, analysis_results in results.get('results', {}).items(): + for analysis_type, analysis_results in results.get("results", {}).items(): report += f"\n{analysis_type.title()} Analysis Summary:\n" - - if analysis_type == 'code_quality': - if 'dead_code' in analysis_results: - dead_code = analysis_results['dead_code'] + + if analysis_type == "code_quality": + if "dead_code" in analysis_results: + dead_code = analysis_results["dead_code"] report += f" Dead Code Items: {dead_code['summary']['total_dead_code_count']}\n" report += f" Unused Functions: {dead_code['summary']['unused_functions_count']}\n" report += f" Unused Classes: {dead_code['summary']['unused_classes_count']}\n" report += f" Unused Variables: {dead_code['summary']['unused_variables_count']}\n" report += f" Unused Imports: {dead_code['summary']['unused_imports_count']}\n" - - if 'complexity' in analysis_results: - complexity = analysis_results['complexity'] + + if "complexity" in analysis_results: + complexity = analysis_results["complexity"] report += f" Average Complexity: {complexity.get('average_complexity', 0):.2f}\n" report += f" High Complexity Functions: {len(complexity.get('high_complexity_functions', []))}\n" - + 
# Distribution - dist = complexity.get('complexity_distribution', {}) - report += f" Complexity Distribution:\n" + dist = complexity.get("complexity_distribution", {}) + report += " Complexity Distribution:\n" report += f" Low: {dist.get('low', 0)}\n" report += f" Medium: {dist.get('medium', 0)}\n" report += f" High: {dist.get('high', 0)}\n" report += f" Very High: {dist.get('very_high', 0)}\n" - - elif analysis_type == 'dependency': - if 'circular_dependencies' in analysis_results: - circular = analysis_results['circular_dependencies'] + + elif analysis_type == "dependency": + if "circular_dependencies" in analysis_results: + circular = analysis_results["circular_dependencies"] report += f" Circular Dependencies: {circular.get('circular_dependencies_count', 0)}\n" report += f" Affected Modules: {len(circular.get('affected_modules', []))}\n" - - if 'module_coupling' in analysis_results: - coupling = analysis_results['module_coupling'] + + if "module_coupling" in analysis_results: + coupling = analysis_results["module_coupling"] report += f" Average Coupling: {coupling.get('average_coupling', 0):.2f}\n" report += f" High Coupling Modules: {len(coupling.get('high_coupling_modules', []))}\n" report += f" Low Coupling Modules: {len(coupling.get('low_coupling_modules', []))}\n" - + return report - - def _generate_detailed_report(self) -> str: + + def _generate_detailed_report(self) -> str: # noqa: C901 """Generate a detailed report of the analysis results.""" results = self.analyzer.results - + report = "===== Codebase Analysis Detailed Report =====\n\n" - + # Add metadata report += "Metadata:\n" report += f" Repository: {results['metadata'].get('repo_name', 'Unknown')}\n" report += f" Language: {results['metadata'].get('language', 'Unknown')}\n" - report += f" Analysis Time: {results['metadata'].get('analysis_time', 'Unknown')}\n" + report += ( + f" Analysis Time: {results['metadata'].get('analysis_time', 'Unknown')}\n" + ) report += f" Analysis Types: {', '.join(results['metadata'].get('analysis_types', []))}\n" - - # Add detailed analysis sections - for analysis_type, analysis_results in results.get('results', {}).items(): - report += f"\n{analysis_type.title()} Analysis:\n" - - # Add relevant sections from each analysis type - if analysis_type == 'code_quality': + + # Add issue statistics + report += "\nIssue Statistics:\n" + report += f" Total Issues: {results['issue_stats']['total']}\n" + report += ( + f" Critical: {results['issue_stats']['by_severity'].get('critical', 0)}\n" + ) + report += f" Errors: {results['issue_stats']['by_severity'].get('error', 0)}\n" + report += ( + f" Warnings: {results['issue_stats']['by_severity'].get('warning', 0)}\n" + ) + report += f" Info: {results['issue_stats']['by_severity'].get('info', 0)}\n\n" + + # Add codebase summary + if "summary" in results: + report += "Codebase Summary:\n" + summary = results["summary"] + report += f" Files: {summary.get('file_count', 0)}\n" + report += f" Lines of Code: {summary.get('total_loc', 0)}\n" + report += f" Functions: {summary.get('function_count', 0)}\n" + report += f" Classes: {summary.get('class_count', 0)}\n" + report += f" Imports: {summary.get('import_count', 0)}\n" + report += f" Dependencies: {summary.get('dependency_count', 0)}\n\n" + + # Language breakdown + if "language_breakdown" in summary: + report += " Language Breakdown:\n" + for lang, stats in summary["language_breakdown"].items(): + report += f" {lang}: {stats['file_count']} files, {stats['loc']} lines\n" + + # Add detailed analysis results + 
report += "\nDetailed Analysis Results:\n" + + # Add analysis summaries + for analysis_type, analysis_results in results.get("results", {}).items(): + report += f"\n=== {analysis_type.upper()} ANALYSIS ===\n" + + if analysis_type == "code_quality": # Dead code - if 'dead_code' in analysis_results: - dead_code = analysis_results['dead_code'] - report += f"\n Dead Code Analysis:\n" + if "dead_code" in analysis_results: + dead_code = analysis_results["dead_code"] + report += "\n Dead Code Analysis:\n" report += f" Total Dead Code Items: {dead_code['summary']['total_dead_code_count']}\n" - + report += f" Unused Functions: {dead_code['summary']['unused_functions_count']}\n" + report += f" Unused Classes: {dead_code['summary']['unused_classes_count']}\n" + report += f" Unused Variables: {dead_code['summary']['unused_variables_count']}\n" + report += f" Unused Imports: {dead_code['summary']['unused_imports_count']}\n" + # Unused functions - if dead_code['unused_functions']: + if dead_code["unused_functions"]: report += f"\n Unused Functions ({len(dead_code['unused_functions'])}):\n" - for func in dead_code['unused_functions'][:10]: # Limit to top 10 + for func in dead_code["unused_functions"][ + :10 + ]: # Limit to top 10 report += f" {func['name']} ({func['file']}:{func['line']})\n" - if len(dead_code['unused_functions']) > 10: + if len(dead_code["unused_functions"]) > 10: report += f" ... and {len(dead_code['unused_functions']) - 10} more\n" - + # Unused classes - if dead_code['unused_classes']: + if dead_code["unused_classes"]: report += f"\n Unused Classes ({len(dead_code['unused_classes'])}):\n" - for cls in dead_code['unused_classes'][:10]: # Limit to top 10 - report += f" {cls['name']} ({cls['file']}:{cls['line']})\n" - if len(dead_code['unused_classes']) > 10: + for cls in dead_code["unused_classes"][:10]: # Limit to top 10 + report += ( + f" {cls['name']} ({cls['file']}:{cls['line']})\n" + ) + if len(dead_code["unused_classes"]) > 10: report += f" ... and {len(dead_code['unused_classes']) - 10} more\n" - + # Complexity - if 'complexity' in analysis_results: - complexity = analysis_results['complexity'] - report += f"\n Code Complexity Analysis:\n" + if "complexity" in analysis_results: + complexity = analysis_results["complexity"] + report += "\n Code Complexity Analysis:\n" report += f" Average Complexity: {complexity.get('average_complexity', 0):.2f}\n" - + # High complexity functions - high_complexity = complexity.get('high_complexity_functions', []) + high_complexity = complexity.get("high_complexity_functions", []) if high_complexity: report += f"\n High Complexity Functions ({len(high_complexity)}):\n" for func in high_complexity[:10]: # Limit to top 10 report += f" {func['name']} (Complexity: {func['complexity']}, {func['file']}:{func['line']})\n" if len(high_complexity) > 10: - report += f" ... and {len(high_complexity) - 10} more\n" - + report += ( + f" ... 
and {len(high_complexity) - 10} more\n" + ) + # Maintainability - if 'maintainability' in analysis_results: - maintain = analysis_results['maintainability'] - report += f"\n Maintainability Analysis:\n" + if "maintainability" in analysis_results: + maintain = analysis_results["maintainability"] + report += "\n Maintainability Analysis:\n" report += f" Average Maintainability: {maintain.get('average_maintainability', 0):.2f}\n" - + # Low maintainability functions - low_maintain = maintain.get('low_maintainability_functions', []) + low_maintain = maintain.get("low_maintainability_functions", []) if low_maintain: report += f"\n Low Maintainability Functions ({len(low_maintain)}):\n" for func in low_maintain[:10]: # Limit to top 10 report += f" {func['name']} (Index: {func['maintainability']:.1f}, {func['file']}:{func['line']})\n" if len(low_maintain) > 10: report += f" ... and {len(low_maintain) - 10} more\n" - - elif analysis_type == 'dependency': + + elif analysis_type == "dependency": # Circular dependencies - if 'circular_dependencies' in analysis_results: - circular = analysis_results['circular_dependencies'] - report += f"\n Circular Dependencies Analysis:\n" + if "circular_dependencies" in analysis_results: + circular = analysis_results["circular_dependencies"] + report += "\n Circular Dependencies Analysis:\n" report += f" Total Circular Dependencies: {circular.get('circular_dependencies_count', 0)}\n" - + # List circular import chains - if circular.get('circular_imports', []): + if circular.get("circular_imports", []): report += f"\n Circular Import Chains ({len(circular['circular_imports'])}):\n" - for i, cycle in enumerate(circular['circular_imports'][:5]): # Limit to top 5 - report += f" Chain {i+1} (Length: {cycle['length']}):\n" - for j, file_path in enumerate(cycle['files']): - report += f" {j+1}. {file_path}\n" - if len(circular['circular_imports']) > 5: + for i, cycle in enumerate( + circular["circular_imports"][:5] + ): # Limit to top 5 + report += ( + f" Chain {i + 1} (Length: {cycle['length']}):\n" + ) + for j, file_path in enumerate(cycle["files"]): + report += f" {j + 1}. {file_path}\n" + if len(circular["circular_imports"]) > 5: report += f" ... and {len(circular['circular_imports']) - 5} more chains\n" - + # Module coupling - if 'module_coupling' in analysis_results: - coupling = analysis_results['module_coupling'] - report += f"\n Module Coupling Analysis:\n" + if "module_coupling" in analysis_results: + coupling = analysis_results["module_coupling"] + report += "\n Module Coupling Analysis:\n" report += f" Average Coupling: {coupling.get('average_coupling', 0):.2f}\n" - + # High coupling modules - high_coupling = coupling.get('high_coupling_modules', []) + high_coupling = coupling.get("high_coupling_modules", []) if high_coupling: - report += f"\n High Coupling Modules ({len(high_coupling)}):\n" + report += ( + f"\n High Coupling Modules ({len(high_coupling)}):\n" + ) for module in high_coupling[:10]: # Limit to top 10 report += f" {module['module']} (Ratio: {module['coupling_ratio']:.2f}, Files: {module['file_count']}, Imports: {module['import_count']})\n" if len(high_coupling) > 10: report += f" ... 
and {len(high_coupling) - 10} more\n" - + # External dependencies - if 'external_dependencies' in analysis_results: - ext_deps = analysis_results['external_dependencies'] - most_used = ext_deps.get('most_used_external_modules', []) - + if "external_dependencies" in analysis_results: + ext_deps = analysis_results["external_dependencies"] + most_used = ext_deps.get("most_used_external_modules", []) + if most_used: - report += f"\n Most Used External Modules:\n" + report += "\n Most Used External Modules:\n" for module in most_used[:10]: report += f" {module['module']} (Used {module['usage_count']} times)\n" - + return report - + def _generate_issues_report(self) -> str: """Generate a report focused on issues found during analysis.""" issues = self.analyzer.issues - + report = "===== Codebase Analysis Issues Report =====\n\n" - + # Issue statistics report += f"Total Issues: {len(issues)}\n" report += f"Critical: {sum(1 for issue in issues if issue.severity == IssueSeverity.CRITICAL)}\n" report += f"Errors: {sum(1 for issue in issues if issue.severity == IssueSeverity.ERROR)}\n" report += f"Warnings: {sum(1 for issue in issues if issue.severity == IssueSeverity.WARNING)}\n" report += f"Info: {sum(1 for issue in issues if issue.severity == IssueSeverity.INFO)}\n" - + # Group issues by severity issues_by_severity = {} - for severity in [IssueSeverity.CRITICAL, IssueSeverity.ERROR, IssueSeverity.WARNING, IssueSeverity.INFO]: - issues_by_severity[severity] = [issue for issue in issues if issue.severity == severity] - + for severity in [ + IssueSeverity.CRITICAL, + IssueSeverity.ERROR, + IssueSeverity.WARNING, + IssueSeverity.INFO, + ]: + issues_by_severity[severity] = [ + issue for issue in issues if issue.severity == severity + ] + # Format issues by severity - for severity in [IssueSeverity.CRITICAL, IssueSeverity.ERROR, IssueSeverity.WARNING, IssueSeverity.INFO]: + for severity in [ + IssueSeverity.CRITICAL, + IssueSeverity.ERROR, + IssueSeverity.WARNING, + IssueSeverity.INFO, + ]: severity_issues = issues_by_severity[severity] - + if severity_issues: - report += f"\n{severity.value.upper()} Issues ({len(severity_issues)}):\n" - + report += ( + f"\n{severity.value.upper()} Issues ({len(severity_issues)}):\n" + ) + for issue in severity_issues: - location = f"{issue.file}:{issue.line}" if issue.line else issue.file + location = ( + f"{issue.file}:{issue.line}" if issue.line else issue.file + ) category = f"[{issue.category.value}]" if issue.category else "" report += f"- {location} {category} {issue.message}\n" report += f" Suggestion: {issue.suggestion}\n" - + return report + def main(): """Command-line entry point for running analyzers.""" import argparse - + parser = argparse.ArgumentParser(description="Codebase Analyzer Manager") - + # Repository source options source_group = parser.add_mutually_exclusive_group(required=True) source_group.add_argument("--repo-url", help="URL of the repository to analyze") - source_group.add_argument("--repo-path", help="Local path to the repository to analyze") - + source_group.add_argument( + "--repo-path", help="Local path to the repository to analyze" + ) + # Analysis options - parser.add_argument("--analysis-types", nargs="+", choices=[at.value for at in AnalysisType], - default=["code_quality", "dependency"], - help="Types of analysis to perform") - parser.add_argument("--language", choices=["python", "typescript"], - help="Programming language (auto-detected if not provided)") - parser.add_argument("--base-branch", default="main", - help="Base branch 
for PR comparison (default: main)") - parser.add_argument("--pr-number", type=int, - help="PR number to analyze") - + parser.add_argument( + "--analysis-types", + nargs="+", + choices=[at.value for at in AnalysisType], + default=["code_quality", "dependency"], + help="Types of analysis to perform", + ) + parser.add_argument( + "--language", + choices=["python", "typescript"], + help="Programming language (auto-detected if not provided)", + ) + parser.add_argument( + "--base-branch", + default="main", + help="Base branch for PR comparison (default: main)", + ) + parser.add_argument("--pr-number", type=int, help="PR number to analyze") + # Output options - parser.add_argument("--output-file", - help="Path to the output file") - parser.add_argument("--output-format", choices=["json", "html", "console"], default="json", - help="Output format") - parser.add_argument("--report-type", choices=["summary", "detailed", "issues"], default="summary", - help="Type of report to generate (default: summary)") - + parser.add_argument("--output-file", help="Path to the output file") + parser.add_argument( + "--output-format", + choices=["json", "html", "console"], + default="json", + help="Output format", + ) + parser.add_argument( + "--report-type", + choices=["summary", "detailed", "issues"], + default="summary", + help="Type of report to generate (default: summary)", + ) + args = parser.parse_args() - + try: # Initialize the analyzer manager manager = AnalyzerManager( @@ -430,23 +530,25 @@ def main(): repo_path=args.repo_path, language=args.language, base_branch=args.base_branch, - pr_number=args.pr_number + pr_number=args.pr_number, ) - + # Run the analysis analysis_types = [AnalysisType(at) for at in args.analysis_types] manager.run_analysis(analysis_types, args.output_file, args.output_format) - + # Generate and print report if args.output_format == "console": report = manager.generate_report(args.report_type) print(report) - + except Exception as e: print(f"Error: {e}") import traceback + traceback.print_exc() sys.exit(1) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/codegen-on-oss/codegen_on_oss/analyzers/api.py b/codegen-on-oss/codegen_on_oss/analyzers/api.py index b774f37c0..c2fa25779 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/api.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/api.py @@ -7,265 +7,344 @@ and data export. """ -import os -import sys -import json import logging -from typing import Dict, List, Set, Tuple, Any, Optional, Union +from typing import Any # Import analyzer components from codegen_on_oss.analyzers.analyzer import AnalyzerManager -from codegen_on_oss.analyzers.issues import Issue, IssueSeverity, AnalysisType, IssueCategory -from codegen_on_oss.analyzers.visualization import Visualizer, VisualizationType, OutputFormat +from codegen_on_oss.analyzers.issues import ( + AnalysisType, + IssueCategory, + IssueSeverity, +) +from codegen_on_oss.analyzers.visualization import ( + Visualizer, +) # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) + class CodegenAnalyzerAPI: """ Backend API for codegit-on-git. - + This class provides a unified interface for the frontend to interact with the codebase analysis backend, including analysis, visualization, and data export. 
""" - - def __init__(self, repo_path: Optional[str] = None, repo_url: Optional[str] = None): + + def __init__(self, repo_path: str | None = None, repo_url: str | None = None): """ Initialize the API with a repository. - + Args: repo_path: Local path to the repository repo_url: URL of the repository """ # Initialize analyzer self.analyzer = AnalyzerManager(repo_path=repo_path, repo_url=repo_url) - + # Initialize visualizer when needed self._visualizer = None - + # Cache for analysis results self._analysis_cache = {} - + @property def visualizer(self) -> Visualizer: """Get or initialize visualizer.""" if self._visualizer is None: self._visualizer = Visualizer() return self._visualizer - + def analyze_codebase( - self, - analysis_types: Optional[List[Union[str, AnalysisType]]] = None, - force_refresh: bool = False - ) -> Dict[str, Any]: + self, + analysis_types: list[str | AnalysisType] | None = None, + force_refresh: bool = False, + ) -> dict[str, Any]: """ Analyze the entire codebase. - + Args: analysis_types: Types of analysis to perform force_refresh: Whether to force a refresh of the analysis - + Returns: Analysis results """ cache_key = str(analysis_types) if analysis_types else "default" - + # Check cache first if not force_refresh and cache_key in self._analysis_cache: return self._analysis_cache[cache_key] - + # Run analysis results = self.analyzer.analyze(analysis_types=analysis_types) - + # Cache results self._analysis_cache[cache_key] = results - + return results - + def analyze_pr( - self, + self, pr_number: int, - analysis_types: Optional[List[Union[str, AnalysisType]]] = None, - force_refresh: bool = False - ) -> Dict[str, Any]: + analysis_types: list[str | AnalysisType] | None = None, + force_refresh: bool = False, + ) -> dict[str, Any]: """ Analyze a specific PR. - + Args: pr_number: PR number to analyze analysis_types: Types of analysis to perform force_refresh: Whether to force a refresh of the analysis - + Returns: Analysis results """ - cache_key = f"pr_{pr_number}_{str(analysis_types)}" - + cache_key = f"pr_{pr_number}_{analysis_types!s}" + # Check cache first if not force_refresh and cache_key in self._analysis_cache: return self._analysis_cache[cache_key] - + # Set PR number self.analyzer.pr_number = pr_number - + # Use default analysis types if none provided if analysis_types is None: analysis_types = ["pr", "code_quality"] - + # Run analysis results = self.analyzer.analyze(analysis_types=analysis_types) - + # Cache results self._analysis_cache[cache_key] = results - + return results - + def get_issues( - self, - severity: Optional[Union[str, IssueSeverity]] = None, - category: Optional[Union[str, IssueCategory]] = None - ) -> List[Dict[str, Any]]: + self, + severity: str | IssueSeverity | None = None, + category: str | IssueCategory | None = None, + ) -> list[dict[str, Any]]: """ Get issues matching criteria. 
- + Args: severity: Issue severity to filter by category: Issue category to filter by - + Returns: List of matching issues """ # Run analysis if not already done if not self._analysis_cache: self.analyze_codebase() - + # Convert string severity to enum if needed if isinstance(severity, str): severity = IssueSeverity(severity) - + # Convert string category to enum if needed if isinstance(category, str): category = IssueCategory(category) - + # Get issues issues = self.analyzer.get_issues(severity=severity, category=category) - + # Convert to dictionaries return [issue.to_dict() for issue in issues] - - def find_symbol(self, symbol_name: str) -> Optional[Dict[str, Any]]: + + def find_symbol(self, symbol_name: str) -> dict[str, Any] | None: """ Find a specific symbol in the codebase. - + Args: symbol_name: Name of the symbol to find - + Returns: Symbol information if found, None otherwise """ # Run analysis if not already done if not self._analysis_cache: self.analyze_codebase() - + # Get symbol symbol = self.analyzer.base_codebase.get_symbol(symbol_name) - + if symbol: # Convert to dictionary return self._symbol_to_dict(symbol) - + return None - + def get_module_dependencies( - self, - module_path: Optional[str] = None, + self, + module_path: str | None = None, layout: str = "hierarchical", - format: str = "json" - ) -> Dict[str, Any]: + output_format: str = "json", + ) -> dict[str, Any]: """ Get module dependencies. - + Args: module_path: Path to the module to analyze layout: Layout algorithm to use - format: Output format - + output_format: Output format + Returns: Module dependency visualization """ # Run analysis if not already done if not self._analysis_cache: self.analyze_codebase(analysis_types=["dependency"]) - + # Generate visualization viz = self.visualizer.generate_module_dependency_graph( codebase_context=self.analyzer.base_context, module_path=module_path, - layout=layout + layout=layout, ) - + # Export if needed - if format != "json": - return self.visualizer.export(viz, format=format) - + if output_format != "json": + return self.visualizer.export(viz, format=output_format) + return viz - + + def generate_dependency_graph( + self, + repo_path: str | None = None, + module_path: str | None = None, + layout: str = "hierarchical", + output_format: str = "json", + ) -> dict[str, Any]: + """ + Generate a dependency graph for the codebase. + + Args: + repo_path: Path to the repository (optional, uses self.repo_path if not provided) + module_path: Path to the specific module to analyze (optional) + layout: Graph layout algorithm (hierarchical, force, circular) + output_format: Output format (json, dot, graphml) + + Returns: + Dictionary containing the dependency graph data + """ + # Run analysis if not already done + if not self._analysis_cache: + self.analyze_codebase(analysis_types=["dependency"]) + + # Generate visualization + viz = self.visualizer.generate_module_dependency_graph( + codebase_context=self.analyzer.base_context, + module_path=module_path, + layout=layout, + ) + + # Export if needed + if output_format != "json": + return self.visualizer.export(viz, format=output_format) + + return viz + def get_function_call_graph( - self, - function_name: Union[str, List[str]], + self, + function_name: str | list[str], depth: int = 2, layout: str = "hierarchical", - format: str = "json" - ) -> Dict[str, Any]: + output_format: str = "json", + ) -> dict[str, Any]: """ Get function call graph. 
- + Args: function_name: Name of the function(s) to analyze depth: Maximum depth of the call graph layout: Layout algorithm to use - format: Output format - + output_format: Output format + Returns: Function call graph visualization """ # Run analysis if not already done if not self._analysis_cache: self.analyze_codebase(analysis_types=["code_quality"]) - + # Generate visualization viz = self.visualizer.generate_function_call_graph( functions=function_name, codebase_context=self.analyzer.base_context, depth=depth, - layout=layout + layout=layout, ) - + # Export if needed - if format != "json": - return self.visualizer.export(viz, format=format) - + if output_format != "json": + return self.visualizer.export(viz, format=output_format) + + return viz + + def generate_call_graph( + self, + function_name: str | None = None, + file_path: str | None = None, + depth: int = 2, + layout: str = "hierarchical", + output_format: str = "json", + ) -> dict[str, Any]: + """ + Generate a call graph for a specific function or file. + + Args: + function_name: Name of the function to analyze + file_path: Path to the file containing the function + depth: Maximum depth of the call graph + layout: Graph layout algorithm (hierarchical, force, circular) + output_format: Output format (json, dot, graphml) + + Returns: + Dictionary containing the call graph data + """ + # Run analysis if not already done + if not self._analysis_cache: + self.analyze_codebase(analysis_types=["code_quality"]) + + # Generate visualization + viz = self.visualizer.generate_function_call_graph( + functions=function_name, + codebase_context=self.analyzer.base_context, + depth=depth, + layout=layout, + ) + + # Export if needed + if output_format != "json": + return self.visualizer.export(viz, format=output_format) + return viz - + def get_pr_impact( - self, - pr_number: Optional[int] = None, + self, + pr_number: int | None = None, layout: str = "force", - format: str = "json" - ) -> Dict[str, Any]: + output_format: str = "json", + ) -> dict[str, Any]: """ Get PR impact visualization. - + Args: pr_number: PR number to analyze layout: Layout algorithm to use - format: Output format - + output_format: Output format + Returns: PR impact visualization """ @@ -273,192 +352,289 @@ def get_pr_impact( if pr_number is not None: self.analyze_pr(pr_number, analysis_types=["pr"]) elif self.analyzer.pr_number is None: - raise ValueError("No PR number specified") - + msg = "No PR number specified" + raise ValueError(msg) + # Generate visualization viz = self.visualizer.generate_pr_diff_visualization( - pr_analysis=self.analyzer.results["results"]["pr"], - layout=layout + pr_analysis=self.analyzer.results["results"]["pr"], layout=layout ) - + # Export if needed - if format != "json": - return self.visualizer.export(viz, format=format) - + if output_format != "json": + return self.visualizer.export(viz, format=output_format) + return viz - + def export_visualization( - self, - visualization: Dict[str, Any], - format: str = "json", - filename: Optional[str] = None - ) -> Union[str, Dict[str, Any]]: + self, + visualization: dict[str, Any], + output_format: str = "json", + filename: str | None = None, + ) -> str | dict[str, Any]: """ Export visualization in specified format. 
- + Args: visualization: Visualization to export - format: Output format + output_format: Output format filename: Output filename - + Returns: Exported visualization or path to saved file """ return self.visualizer.export( - visualization, - format=format, - filename=filename + visualization, format=output_format, filename=filename ) - - def get_static_errors(self) -> List[Dict[str, Any]]: + + def get_static_errors(self) -> list[dict[str, Any]]: """ Get static errors in the codebase. - + Returns: List of static errors """ # Run analysis if not already done if not self._analysis_cache: self.analyze_codebase(analysis_types=["code_quality"]) - + # Get errors errors = self.analyzer.get_issues(severity=IssueSeverity.ERROR) - + # Convert to dictionaries return [error.to_dict() for error in errors] - - def get_parameter_issues(self) -> List[Dict[str, Any]]: + + def get_parameter_issues(self) -> list[dict[str, Any]]: """ Get parameter-related issues. - + Returns: List of parameter issues """ # Run analysis if not already done if not self._analysis_cache: self.analyze_codebase(analysis_types=["code_quality"]) - + # Get parameter issues issues = self.analyzer.get_issues(category=IssueCategory.PARAMETER_MISMATCH) - + # Convert to dictionaries return [issue.to_dict() for issue in issues] - - def get_unimplemented_functions(self) -> List[Dict[str, Any]]: + + def get_unimplemented_functions(self) -> list[dict[str, Any]]: """ Get unimplemented functions. - + Returns: List of unimplemented functions """ # Run analysis if not already done if not self._analysis_cache: self.analyze_codebase(analysis_types=["code_quality"]) - + # Get implementation issues issues = self.analyzer.get_issues(category=IssueCategory.IMPLEMENTATION_ERROR) - + # Convert to dictionaries return [issue.to_dict() for issue in issues] - - def get_circular_dependencies(self) -> List[Dict[str, Any]]: + + def get_circular_dependencies(self) -> list[dict[str, Any]]: """ Get circular dependencies. - + Returns: List of circular dependencies """ # Run analysis if not already done if not self._analysis_cache: self.analyze_codebase(analysis_types=["dependency"]) - + # Get circular dependencies if "dependency" in self.analyzer.results.get("results", {}): - return self.analyzer.results["results"]["dependency"].get("circular_dependencies", {}).get("circular_imports", []) - + return ( + self.analyzer.results["results"]["dependency"] + .get("circular_dependencies", {}) + .get("circular_imports", []) + ) + return [] - - def get_module_coupling(self) -> List[Dict[str, Any]]: + + def get_module_coupling(self) -> list[dict[str, Any]]: """ Get module coupling metrics. - + Returns: Module coupling metrics """ # Run analysis if not already done if not self._analysis_cache: self.analyze_codebase(analysis_types=["dependency"]) - + # Get module coupling if "dependency" in self.analyzer.results.get("results", {}): - return self.analyzer.results["results"]["dependency"].get("module_coupling", {}).get("high_coupling_modules", []) - + return ( + self.analyzer.results["results"]["dependency"] + .get("module_coupling", {}) + .get("high_coupling_modules", []) + ) + return [] - - def get_diff_analysis(self, pr_number: int) -> Dict[str, Any]: + + def get_diff_analysis(self, pr_number: int) -> dict[str, Any]: """ Get diff analysis for a PR. 
- + Args: pr_number: PR number to analyze - + Returns: Diff analysis results """ # Analyze PR self.analyze_pr(pr_number, analysis_types=["pr"]) - + # Get diff analysis if "pr" in self.analyzer.results.get("results", {}): return self.analyzer.results["results"]["pr"] - + return {} - + def clear_cache(self): """Clear the analysis cache.""" self._analysis_cache = {} - - def _symbol_to_dict(self, symbol) -> Dict[str, Any]: + + def _symbol_to_dict(self, symbol) -> dict[str, Any]: """Convert symbol to dictionary.""" symbol_dict = { - "name": symbol.name if hasattr(symbol, 'name') else str(symbol), - "type": str(symbol.symbol_type) if hasattr(symbol, 'symbol_type') else "unknown", - "file": symbol.file.file_path if hasattr(symbol, 'file') and hasattr(symbol.file, 'file_path') else "unknown", - "line": symbol.line if hasattr(symbol, 'line') else None, + "name": symbol.name if hasattr(symbol, "name") else str(symbol), + "type": str(symbol.symbol_type) + if hasattr(symbol, "symbol_type") + else "unknown", + "file": symbol.file.file_path + if hasattr(symbol, "file") and hasattr(symbol.file, "file_path") + else "unknown", + "line": symbol.line if hasattr(symbol, "line") else None, } - + # Add function-specific info - if hasattr(symbol, 'parameters'): + if hasattr(symbol, "parameters"): symbol_dict["parameters"] = [ { - "name": p.name if hasattr(p, 'name') else str(p), - "type": str(p.type) if hasattr(p, 'type') and p.type else None, - "has_default": p.has_default if hasattr(p, 'has_default') else False + "name": p.name if hasattr(p, "name") else str(p), + "type": str(p.type) if hasattr(p, "type") and p.type else None, + "has_default": p.has_default + if hasattr(p, "has_default") + else False, } for p in symbol.parameters ] - - symbol_dict["return_type"] = str(symbol.return_type) if hasattr(symbol, 'return_type') and symbol.return_type else None - symbol_dict["is_async"] = symbol.is_async if hasattr(symbol, 'is_async') else False - + + symbol_dict["return_type"] = ( + str(symbol.return_type) + if hasattr(symbol, "return_type") and symbol.return_type + else None + ) + symbol_dict["is_async"] = ( + symbol.is_async if hasattr(symbol, "is_async") else False + ) + # Add class-specific info - if hasattr(symbol, 'superclasses'): + if hasattr(symbol, "superclasses"): symbol_dict["superclasses"] = [ - sc.name if hasattr(sc, 'name') else str(sc) + sc.name if hasattr(sc, "name") else str(sc) for sc in symbol.superclasses ] - + return symbol_dict + def generate_class_diagram( + self, + class_name: str | None = None, + module_name: str | None = None, + include_methods: bool = True, + include_attributes: bool = True, + output_format: str = "json", + ) -> dict[str, Any]: + """ + Generate a class diagram for the codebase. 
+ + Args: + class_name: Name of the class to analyze (optional) + module_name: Name of the module containing the class (optional) + include_methods: Whether to include methods in the diagram + include_attributes: Whether to include attributes in the diagram + output_format: Output format (json, dot, graphml, plantuml) -def create_api(repo_path: Optional[str] = None, repo_url: Optional[str] = None) -> CodegenAnalyzerAPI: + Returns: + Dictionary containing the class diagram data + """ + # Run analysis if not already done + if not self._analysis_cache: + self.analyze_codebase(analysis_types=["dependency"]) + + # Generate visualization + viz = self.visualizer.generate_class_diagram( + codebase_context=self.analyzer.base_context, + class_name=class_name, + module_name=module_name, + include_methods=include_methods, + include_attributes=include_attributes, + ) + + # Export if needed + if output_format != "json": + return self.visualizer.export(viz, format=output_format) + + return viz + + def generate_sequence_diagram( + self, + function_name: str, + file_path: str | None = None, + max_depth: int = 3, + output_format: str = "json", + ) -> dict[str, Any]: + """ + Generate a sequence diagram for a specific function. + + Args: + function_name: Name of the function to analyze + file_path: Path to the file containing the function (optional) + max_depth: Maximum depth of the sequence diagram + output_format: Output format (json, plantuml) + + Returns: + Dictionary containing the sequence diagram data + """ + # Run analysis if not already done + if not self._analysis_cache: + self.analyze_codebase(analysis_types=["code_quality"]) + + # Generate visualization + viz = self.visualizer.generate_sequence_diagram( + codebase_context=self.analyzer.base_context, + function_name=function_name, + file_path=file_path, + max_depth=max_depth, + ) + + # Export if needed + if output_format != "json": + return self.visualizer.export(viz, format=output_format) + + return viz + + +def create_api( + repo_path: str | None = None, repo_url: str | None = None +) -> CodegenAnalyzerAPI: """ Create an API instance. - + Args: repo_path: Local path to the repository repo_url: URL of the repository - + Returns: API instance """ @@ -466,14 +642,16 @@ def create_api(repo_path: Optional[str] = None, repo_url: Optional[str] = None) # API endpoints for Flask or FastAPI integration -def api_analyze_codebase(repo_path: str, analysis_types: Optional[List[str]] = None) -> Dict[str, Any]: +def api_analyze_codebase( + repo_path: str, analysis_types: list[str] | None = None +) -> dict[str, Any]: """ API endpoint for codebase analysis. - + Args: repo_path: Path to the repository analysis_types: Types of analysis to perform - + Returns: Analysis results """ @@ -481,14 +659,14 @@ def api_analyze_codebase(repo_path: str, analysis_types: Optional[List[str]] = N return api.analyze_codebase(analysis_types=analysis_types) -def api_analyze_pr(repo_path: str, pr_number: int) -> Dict[str, Any]: +def api_analyze_pr(repo_path: str, pr_number: int) -> dict[str, Any]: """ API endpoint for PR analysis. - + Args: repo_path: Path to the repository pr_number: PR number to analyze - + Returns: Analysis results """ @@ -497,23 +675,21 @@ def api_analyze_pr(repo_path: str, pr_number: int) -> Dict[str, Any]: def api_get_visualization( - repo_path: str, - viz_type: str, - params: Dict[str, Any] -) -> Dict[str, Any]: + repo_path: str, viz_type: str, params: dict[str, Any] +) -> dict[str, Any]: """ API endpoint for visualizations. 
- + Args: repo_path: Path to the repository viz_type: Type of visualization params: Visualization parameters - + Returns: Visualization data """ api = create_api(repo_path=repo_path) - + # Run appropriate analysis based on visualization type if viz_type == "module_dependencies": api.analyze_codebase(analysis_types=["dependency"]) @@ -521,38 +697,39 @@ def api_get_visualization( api.analyze_codebase(analysis_types=["code_quality"]) elif viz_type == "pr_impact": api.analyze_pr(pr_number=params["pr_number"]) - + # Generate visualization if viz_type == "module_dependencies": return api.get_module_dependencies( module_path=params.get("module_path"), layout=params.get("layout", "hierarchical"), - format=params.get("format", "json") + format=params.get("format", "json"), ) elif viz_type == "function_calls": return api.get_function_call_graph( function_name=params["function_name"], depth=params.get("depth", 2), layout=params.get("layout", "hierarchical"), - format=params.get("format", "json") + format=params.get("format", "json"), ) elif viz_type == "pr_impact": return api.get_pr_impact( pr_number=params.get("pr_number"), layout=params.get("layout", "force"), - format=params.get("format", "json") + format=params.get("format", "json"), ) else: - raise ValueError(f"Unknown visualization type: {viz_type}") + msg = f"Unknown visualization type: {viz_type}" + raise ValueError(msg) -def api_get_static_errors(repo_path: str) -> List[Dict[str, Any]]: +def api_get_static_errors(repo_path: str) -> list[dict[str, Any]]: """ API endpoint for static errors. - + Args: repo_path: Path to the repository - + Returns: List of static errors """ @@ -560,38 +737,47 @@ def api_get_static_errors(repo_path: str) -> List[Dict[str, Any]]: return api.get_static_errors() -def api_get_function_issues(repo_path: str, function_name: str) -> List[Dict[str, Any]]: +def api_get_function_issues(repo_path: str, function_name: str) -> list[dict[str, Any]]: """ API endpoint for function issues. - + Args: repo_path: Path to the repository function_name: Name of the function - + Returns: List of function issues """ api = create_api(repo_path=repo_path) api.analyze_codebase(analysis_types=["code_quality"]) - + # Get symbol symbol = api.analyzer.base_codebase.get_symbol(function_name) - + if not symbol: return [] - + # Get file path - file_path = symbol.file.file_path if hasattr(symbol, 'file') and hasattr(symbol.file, 'file_path') else None - + file_path = ( + symbol.file.file_path + if hasattr(symbol, "file") and hasattr(symbol.file, "file_path") + else None + ) + if not file_path: return [] - + # Get issues for this file and symbol issues = api.analyzer.get_issues() return [ - issue.to_dict() for issue in issues - if issue.file == file_path and ( - issue.symbol == function_name or - (hasattr(issue, 'related_symbols') and function_name in issue.related_symbols) + issue.to_dict() + for issue in issues + if issue.file == file_path + and ( + issue.symbol == function_name + or ( + hasattr(issue, "related_symbols") + and function_name in issue.related_symbols + ) ) - ] \ No newline at end of file + ] diff --git a/codegen-on-oss/codegen_on_oss/analyzers/base_analyzer.py b/codegen-on-oss/codegen_on_oss/analyzers/base_analyzer.py index aec1c571f..20d75e9f1 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/base_analyzer.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/base_analyzer.py @@ -6,29 +6,36 @@ It defines a common interface and shared functionality for codebase analysis. 
""" -import os -import sys import json import logging +import sys import tempfile -from datetime import datetime -from pathlib import Path -from typing import Dict, List, Set, Tuple, Any, Optional, Union, TypeVar, cast from abc import ABC, abstractmethod +from typing import Any try: - from codegen.sdk.core.codebase import Codebase from codegen.configs.models.codebase import CodebaseConfig from codegen.configs.models.secrets import SecretsConfig - from codegen.sdk.codebase.config import ProjectConfig - from codegen.git.schemas.repo_config import RepoConfig from codegen.git.repo_operator.repo_operator import RepoOperator + from codegen.git.schemas.repo_config import RepoConfig + from codegen.sdk.codebase.config import ProjectConfig + from codegen.sdk.core.codebase import Codebase from codegen.shared.enums.programming_language import ProgrammingLanguage - + + from codegen_on_oss.analyzers.issue_types import ( + AnalysisType, + Issue, + IssueCategory, + IssueSeverity, + ) + # Import from our own modules - from codegen_on_oss.context_codebase import CodebaseContext, get_node_classes, GLOBAL_FILE_IGNORE_LIST + from codegen_on_oss.context_codebase import ( + GLOBAL_FILE_IGNORE_LIST, + CodebaseContext, + get_node_classes, + ) from codegen_on_oss.current_code_codebase import get_selected_codebase - from codegen_on_oss.analyzers.issue_types import Issue, IssueSeverity, AnalysisType, IssueCategory except ImportError: print("Codegen SDK or required modules not found.") sys.exit(1) @@ -36,33 +43,34 @@ # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) + class BaseCodeAnalyzer(ABC): """ Base class for all code analyzers. - + This abstract class defines the common interface and shared functionality for all code analyzers in the system. Specific analyzers should inherit from this class and implement the abstract methods. """ - + def __init__( self, - repo_url: Optional[str] = None, - repo_path: Optional[str] = None, + repo_url: str | None = None, + repo_path: str | None = None, base_branch: str = "main", - pr_number: Optional[int] = None, - language: Optional[str] = None, - file_ignore_list: Optional[List[str]] = None, - config: Optional[Dict[str, Any]] = None + pr_number: int | None = None, + language: str | None = None, + file_ignore_list: list[str] | None = None, + config: dict[str, Any] | None = None, ): """ Initialize the base analyzer. 
- + Args: repo_url: URL of the repository to analyze repo_path: Local path to the repository to analyze @@ -77,98 +85,98 @@ def __init__( self.base_branch = base_branch self.pr_number = pr_number self.language = language - + # Use custom ignore list or default global list self.file_ignore_list = file_ignore_list or GLOBAL_FILE_IGNORE_LIST - + # Configuration options self.config = config or {} - + # Codebase and context objects self.base_codebase = None self.pr_codebase = None self.base_context = None self.pr_context = None - + # Analysis results - self.issues: List[Issue] = [] - self.results: Dict[str, Any] = {} - + self.issues: list[Issue] = [] + self.results: dict[str, Any] = {} + # PR comparison data self.pr_diff = None self.commit_shas = None self.modified_symbols = None self.pr_branch = None - + # Initialize codebase(s) based on provided parameters if repo_url: self._init_from_url(repo_url, language) elif repo_path: self._init_from_path(repo_path, language) - + # If PR number is provided, initialize PR-specific data if self.pr_number is not None and self.base_codebase is not None: self._init_pr_data(self.pr_number) - + # Initialize contexts self._init_contexts() - - def _init_from_url(self, repo_url: str, language: Optional[str] = None): + + def _init_from_url(self, repo_url: str, language: str | None = None): """ Initialize codebase from a repository URL. - + Args: repo_url: URL of the repository language: Programming language of the codebase """ try: # Extract repository information - if repo_url.endswith('.git'): + if repo_url.endswith(".git"): repo_url = repo_url[:-4] - - parts = repo_url.rstrip('/').split('/') + + parts = repo_url.rstrip("/").split("/") repo_name = parts[-1] owner = parts[-2] repo_full_name = f"{owner}/{repo_name}" - + # Create temporary directory for cloning tmp_dir = tempfile.mkdtemp(prefix="analyzer_") - + # Set up configuration config = CodebaseConfig( debug=False, allow_external=True, py_resolve_syspath=True, ) - + secrets = SecretsConfig() - + # Determine programming language prog_lang = None if language: prog_lang = ProgrammingLanguage(language.upper()) - + # Initialize the codebase logger.info(f"Initializing codebase from {repo_url}") - + self.base_codebase = Codebase.from_github( repo_full_name=repo_full_name, tmp_dir=tmp_dir, language=prog_lang, config=config, - secrets=secrets + secrets=secrets, ) - + logger.info(f"Successfully initialized codebase from {repo_url}") - + except Exception as e: - logger.error(f"Error initializing codebase from URL: {e}") + logger.exception(f"Error initializing codebase from URL: {e}") raise - - def _init_from_path(self, repo_path: str, language: Optional[str] = None): + + def _init_from_path(self, repo_path: str, language: str | None = None): """ Initialize codebase from a local repository path. 
- + Args: repo_path: Path to the repository language: Programming language of the codebase @@ -180,95 +188,93 @@ def _init_from_path(self, repo_path: str, language: Optional[str] = None): allow_external=True, py_resolve_syspath=True, ) - + secrets = SecretsConfig() - + # Initialize the codebase logger.info(f"Initializing codebase from {repo_path}") - + # Determine programming language prog_lang = None if language: prog_lang = ProgrammingLanguage(language.upper()) - + # Set up repository configuration repo_config = RepoConfig.from_repo_path(repo_path) repo_config.respect_gitignore = False repo_operator = RepoOperator(repo_config=repo_config, bot_commit=False) - + # Create project configuration project_config = ProjectConfig( repo_operator=repo_operator, - programming_language=prog_lang if prog_lang else None + programming_language=prog_lang if prog_lang else None, ) - + # Initialize codebase self.base_codebase = Codebase( - projects=[project_config], - config=config, - secrets=secrets + projects=[project_config], config=config, secrets=secrets ) - + logger.info(f"Successfully initialized codebase from {repo_path}") - + except Exception as e: - logger.error(f"Error initializing codebase from path: {e}") + logger.exception(f"Error initializing codebase from path: {e}") raise - + def _init_pr_data(self, pr_number: int): """ Initialize PR-specific data. - + Args: pr_number: PR number to analyze """ try: logger.info(f"Fetching PR #{pr_number} data") result = self.base_codebase.get_modified_symbols_in_pr(pr_number) - + # Unpack the result tuple if len(result) >= 3: self.pr_diff, self.commit_shas, self.modified_symbols = result[:3] if len(result) >= 4: self.pr_branch = result[3] - + logger.info(f"Found {len(self.modified_symbols)} modified symbols in PR") - + # Initialize PR codebase self._init_pr_codebase() - + except Exception as e: - logger.error(f"Error initializing PR data: {e}") + logger.exception(f"Error initializing PR data: {e}") raise - + def _init_pr_codebase(self): """Initialize PR codebase by checking out the PR branch.""" if not self.base_codebase or not self.pr_number: logger.error("Base codebase or PR number not initialized") return - + try: # Get PR data if not already fetched if not self.pr_branch: self._init_pr_data(self.pr_number) - + if not self.pr_branch: logger.error("Failed to get PR branch") return - + # Clone the base codebase self.pr_codebase = self.base_codebase - + # Checkout PR branch logger.info(f"Checking out PR branch: {self.pr_branch}") self.pr_codebase.checkout(self.pr_branch) - + logger.info("Successfully initialized PR codebase") - + except Exception as e: - logger.error(f"Error initializing PR codebase: {e}") + logger.exception(f"Error initializing PR codebase: {e}") raise - + def _init_contexts(self): """Initialize CodebaseContext objects for both base and PR codebases.""" if self.base_codebase: @@ -277,75 +283,79 @@ def _init_contexts(self): codebase=self.base_codebase, base_path=self.repo_path, pr_branch=None, - base_branch=self.base_branch + base_branch=self.base_branch, ) logger.info("Successfully initialized base context") except Exception as e: - logger.error(f"Error initializing base context: {e}") - + logger.exception(f"Error initializing base context: {e}") + if self.pr_codebase: try: self.pr_context = CodebaseContext( codebase=self.pr_codebase, base_path=self.repo_path, pr_branch=self.pr_branch, - base_branch=self.base_branch + base_branch=self.base_branch, ) logger.info("Successfully initialized PR context") except Exception as e: - 
logger.error(f"Error initializing PR context: {e}") - + logger.exception(f"Error initializing PR context: {e}") + def add_issue(self, issue: Issue): """ Add an issue to the list of detected issues. - + Args: issue: Issue to add """ self.issues.append(issue) - - def get_issues(self, severity: Optional[IssueSeverity] = None, category: Optional[IssueCategory] = None) -> List[Issue]: + + def get_issues( + self, + severity: IssueSeverity | None = None, + category: IssueCategory | None = None, + ) -> list[Issue]: """ Get all issues matching the specified criteria. - + Args: severity: Optional severity level to filter by category: Optional category to filter by - + Returns: List of matching issues """ filtered_issues = self.issues - + if severity: filtered_issues = [i for i in filtered_issues if i.severity == severity] - + if category: filtered_issues = [i for i in filtered_issues if i.category == category] - + return filtered_issues - + def save_results(self, output_file: str): """ Save analysis results to a file. - + Args: output_file: Path to the output file """ - with open(output_file, 'w') as f: + with open(output_file, "w") as f: json.dump(self.results, f, indent=2) - + logger.info(f"Results saved to {output_file}") - + @abstractmethod - def analyze(self, analysis_type: AnalysisType) -> Dict[str, Any]: + def analyze(self, analysis_type: AnalysisType) -> dict[str, Any]: """ Perform analysis on the codebase. - + Args: analysis_type: Type of analysis to perform - + Returns: Dictionary containing analysis results """ - pass \ No newline at end of file + pass diff --git a/codegen-on-oss/codegen_on_oss/analyzers/code_quality.py b/codegen-on-oss/codegen_on_oss/analyzers/code_quality.py index f40c79eaf..1c93c6e8e 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/code_quality.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/code_quality.py @@ -7,106 +7,108 @@ functions with excessive complexity, parameter errors, and implementation problems. """ -import os -import re -import sys -import math import logging -from typing import Dict, List, Set, Tuple, Any, Optional, Union, cast +import math +import re +from typing import Any + +from codegen_on_oss.analyzers.codebase_context import CodebaseContext # Import from our own modules from codegen_on_oss.analyzers.issues import ( - Issue, IssueSeverity, IssueCategory, IssueCollection, - CodeLocation, create_issue, AnalysisType + IssueCategory, + IssueCollection, + IssueSeverity, + create_issue, ) -from codegen_on_oss.analyzers.codebase_context import CodebaseContext # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) + class CodeQualityAnalyzer: """ Analyzer for code quality issues. - + This class analyzes code quality issues in a codebase, including dead code, complexity, style, and maintainability issues. """ - + def __init__( - self, + self, codebase_context: CodebaseContext, - issue_collection: Optional[IssueCollection] = None + issue_collection: IssueCollection | None = None, ): """ Initialize the analyzer. 
- + Args: codebase_context: Context for the codebase to analyze issue_collection: Collection for storing issues """ self.context = codebase_context self.issues = issue_collection or IssueCollection() - + # Register default issue filters self._register_default_filters() - + def _register_default_filters(self): """Register default issue filters.""" # Filter out issues in test files self.issues.add_filter( lambda issue: "test" not in issue.location.file.lower(), - "Skip issues in test files" + "Skip issues in test files", ) - + # Filter out issues in generated files self.issues.add_filter( lambda issue: "generated" not in issue.location.file.lower(), - "Skip issues in generated files" + "Skip issues in generated files", ) - - def analyze(self) -> Dict[str, Any]: + + def analyze(self) -> dict[str, Any]: """ Perform code quality analysis. - + Returns: Dictionary containing analysis results """ logger.info("Starting code quality analysis") - + # Clear existing issues self.issues = IssueCollection() self._register_default_filters() - + # Analyze dead code dead_code = self._find_dead_code() - + # Analyze complexity complexity = self._analyze_complexity() - + # Analyze parameters parameter_issues = self._check_function_parameters() - + # Analyze style issues style_issues = self._check_style_issues() - + # Analyze implementations implementation_issues = self._check_implementations() - + # Analyze maintainability maintainability = self._calculate_maintainability() - + # Combine results results = { "summary": { "issue_count": len(self.issues.issues), "analyzed_functions": len(self.context.get_functions()), "analyzed_classes": len(self.context.get_classes()), - "analyzed_files": len(self.context.get_files()) + "analyzed_files": len(self.context.get_files()), }, "dead_code": dead_code, "complexity": complexity, @@ -114,176 +116,219 @@ def analyze(self) -> Dict[str, Any]: "style_issues": style_issues, "implementation_issues": implementation_issues, "maintainability": maintainability, - "issues": self.issues.to_dict() + "issues": self.issues.to_dict(), } - - logger.info(f"Code quality analysis complete. Found {len(self.issues.issues)} issues.") - + + logger.info( + f"Code quality analysis complete. Found {len(self.issues.issues)} issues." + ) + return results - - def _find_dead_code(self) -> Dict[str, Any]: + + def _find_dead_code(self) -> dict[str, Any]: """ Find unused code (dead code) in the codebase. 
- + Returns: Dictionary containing dead code analysis results """ logger.info("Analyzing dead code") - + dead_code = { "unused_functions": [], "unused_classes": [], "unused_variables": [], - "unused_imports": [] + "unused_imports": [], } - + # Find unused functions for function in self.context.get_functions(): # Skip if function should be excluded if self._should_skip_symbol(function): continue - + # Skip decorated functions (as they might be used indirectly) - if hasattr(function, 'decorators') and function.decorators: + if hasattr(function, "decorators") and function.decorators: continue - + # Check if function has no call sites or usages - has_call_sites = hasattr(function, 'call_sites') and len(function.call_sites) > 0 - has_usages = hasattr(function, 'usages') and len(function.usages) > 0 - + has_call_sites = ( + hasattr(function, "call_sites") and len(function.call_sites) > 0 + ) + has_usages = hasattr(function, "usages") and len(function.usages) > 0 + if not has_call_sites and not has_usages: # Skip magic methods and main functions - if (hasattr(function, 'is_magic') and function.is_magic) or ( - hasattr(function, 'name') and function.name in ['main', '__main__']): + if (hasattr(function, "is_magic") and function.is_magic) or ( + hasattr(function, "name") and function.name in ["main", "__main__"] + ): continue - + # Get file path and name safely - file_path = function.file.file_path if hasattr(function, 'file') and hasattr(function.file, 'file_path') else "unknown" - func_name = function.name if hasattr(function, 'name') else str(function) - + file_path = ( + function.file.file_path + if hasattr(function, "file") and hasattr(function.file, "file_path") + else "unknown" + ) + func_name = ( + function.name if hasattr(function, "name") else str(function) + ) + # Add to dead code list dead_code["unused_functions"].append({ "name": func_name, "file": file_path, - "line": function.line if hasattr(function, 'line') else None + "line": function.line if hasattr(function, "line") else None, }) - + # Add issue - self.issues.add_issue(create_issue( - message=f"Unused function: {func_name}", - severity=IssueSeverity.WARNING, - file=file_path, - line=function.line if hasattr(function, 'line') else None, - category=IssueCategory.DEAD_CODE, - symbol=func_name, - suggestion="Consider removing this unused function or documenting why it's needed" - )) - + self.issues.add_issue( + create_issue( + message=f"Unused function: {func_name}", + severity=IssueSeverity.WARNING, + file=file_path, + line=function.line if hasattr(function, "line") else None, + category=IssueCategory.DEAD_CODE, + symbol=func_name, + suggestion="Consider removing this unused function or documenting why it's needed", + ) + ) + # Find unused classes for cls in self.context.get_classes(): # Skip if class should be excluded if self._should_skip_symbol(cls): continue - + # Check if class has no usages - has_usages = hasattr(cls, 'usages') and len(cls.usages) > 0 - + has_usages = hasattr(cls, "usages") and len(cls.usages) > 0 + if not has_usages: # Get file path and name safely - file_path = cls.file.file_path if hasattr(cls, 'file') and hasattr(cls.file, 'file_path') else "unknown" - cls_name = cls.name if hasattr(cls, 'name') else str(cls) - + file_path = ( + cls.file.file_path + if hasattr(cls, "file") and hasattr(cls.file, "file_path") + else "unknown" + ) + cls_name = cls.name if hasattr(cls, "name") else str(cls) + # Add to dead code list dead_code["unused_classes"].append({ "name": cls_name, "file": file_path, - "line": cls.line 
if hasattr(cls, 'line') else None + "line": cls.line if hasattr(cls, "line") else None, }) - + # Add issue - self.issues.add_issue(create_issue( - message=f"Unused class: {cls_name}", - severity=IssueSeverity.WARNING, - file=file_path, - line=cls.line if hasattr(cls, 'line') else None, - category=IssueCategory.DEAD_CODE, - symbol=cls_name, - suggestion="Consider removing this unused class or documenting why it's needed" - )) - + self.issues.add_issue( + create_issue( + message=f"Unused class: {cls_name}", + severity=IssueSeverity.WARNING, + file=file_path, + line=cls.line if hasattr(cls, "line") else None, + category=IssueCategory.DEAD_CODE, + symbol=cls_name, + suggestion="Consider removing this unused class or documenting why it's needed", + ) + ) + # Find unused variables for function in self.context.get_functions(): - if not hasattr(function, 'code_block') or not hasattr(function.code_block, 'local_var_assignments'): + if not hasattr(function, "code_block") or not hasattr( + function.code_block, "local_var_assignments" + ): continue - + for var_assignment in function.code_block.local_var_assignments: # Check if variable has no usages - has_usages = hasattr(var_assignment, 'local_usages') and len(var_assignment.local_usages) > 0 - + has_usages = ( + hasattr(var_assignment, "local_usages") + and len(var_assignment.local_usages) > 0 + ) + if not has_usages: # Skip if variable name indicates it's intentionally unused (e.g., _) - var_name = var_assignment.name if hasattr(var_assignment, 'name') else str(var_assignment) + var_name = ( + var_assignment.name + if hasattr(var_assignment, "name") + else str(var_assignment) + ) if var_name == "_" or var_name.startswith("_unused"): continue - + # Get file path - file_path = function.file.file_path if hasattr(function, 'file') and hasattr(function.file, 'file_path') else "unknown" - + file_path = ( + function.file.file_path + if hasattr(function, "file") + and hasattr(function.file, "file_path") + else "unknown" + ) + # Add to dead code list dead_code["unused_variables"].append({ "name": var_name, "file": file_path, - "line": var_assignment.line if hasattr(var_assignment, 'line') else None, - "function": function.name if hasattr(function, 'name') else str(function) + "line": var_assignment.line + if hasattr(var_assignment, "line") + else None, + "function": function.name + if hasattr(function, "name") + else str(function), }) - + # Add issue - self.issues.add_issue(create_issue( - message=f"Unused variable '{var_name}' in function '{function.name if hasattr(function, 'name') else 'unknown'}'", - severity=IssueSeverity.INFO, - file=file_path, - line=var_assignment.line if hasattr(var_assignment, 'line') else None, - category=IssueCategory.DEAD_CODE, - symbol=var_name, - suggestion="Consider removing this unused variable" - )) - + self.issues.add_issue( + create_issue( + message=f"Unused variable '{var_name}' in function '{function.name if hasattr(function, 'name') else 'unknown'}'", + severity=IssueSeverity.INFO, + file=file_path, + line=var_assignment.line + if hasattr(var_assignment, "line") + else None, + category=IssueCategory.DEAD_CODE, + symbol=var_name, + suggestion="Consider removing this unused variable", + ) + ) + # Find unused imports for file in self.context.get_files(): - if hasattr(file, 'is_binary') and file.is_binary: + if hasattr(file, "is_binary") and file.is_binary: continue - - if not hasattr(file, 'imports'): + + if not hasattr(file, "imports"): continue - - file_path = file.file_path if hasattr(file, 'file_path') else str(file) 
- + + file_path = file.file_path if hasattr(file, "file_path") else str(file) + for imp in file.imports: - if not hasattr(imp, 'usages'): + if not hasattr(imp, "usages"): continue - + if len(imp.usages) == 0: # Get import source safely - import_source = imp.source if hasattr(imp, 'source') else str(imp) - + import_source = imp.source if hasattr(imp, "source") else str(imp) + # Add to dead code list dead_code["unused_imports"].append({ "import": import_source, "file": file_path, - "line": imp.line if hasattr(imp, 'line') else None + "line": imp.line if hasattr(imp, "line") else None, }) - + # Add issue - self.issues.add_issue(create_issue( - message=f"Unused import: {import_source}", - severity=IssueSeverity.INFO, - file=file_path, - line=imp.line if hasattr(imp, 'line') else None, - category=IssueCategory.DEAD_CODE, - code=import_source, - suggestion="Remove this unused import" - )) - + self.issues.add_issue( + create_issue( + message=f"Unused import: {import_source}", + severity=IssueSeverity.INFO, + file=file_path, + line=imp.line if hasattr(imp, "line") else None, + category=IssueCategory.DEAD_CODE, + code=import_source, + suggestion="Remove this unused import", + ) + ) + # Add summary statistics dead_code["summary"] = { "unused_functions_count": len(dead_code["unused_functions"]), @@ -291,24 +336,24 @@ def _find_dead_code(self) -> Dict[str, Any]: "unused_variables_count": len(dead_code["unused_variables"]), "unused_imports_count": len(dead_code["unused_imports"]), "total_dead_code_count": ( - len(dead_code["unused_functions"]) + - len(dead_code["unused_classes"]) + - len(dead_code["unused_variables"]) + - len(dead_code["unused_imports"]) - ) + len(dead_code["unused_functions"]) + + len(dead_code["unused_classes"]) + + len(dead_code["unused_variables"]) + + len(dead_code["unused_imports"]) + ), } - + return dead_code - - def _analyze_complexity(self) -> Dict[str, Any]: + + def _analyze_complexity(self) -> dict[str, Any]: """ Analyze code complexity. 
- + Returns: Dictionary containing complexity analysis results """ logger.info("Analyzing code complexity") - + complexity_result = { "function_complexity": [], "high_complexity_functions": [], @@ -317,42 +362,46 @@ def _analyze_complexity(self) -> Dict[str, Any]: "low": 0, "medium": 0, "high": 0, - "very_high": 0 - } + "very_high": 0, + }, } - + # Process all functions to calculate complexity total_complexity = 0 function_count = 0 - + for function in self.context.get_functions(): # Skip if function should be excluded if self._should_skip_symbol(function): continue - + # Skip if no code block - if not hasattr(function, 'code_block'): + if not hasattr(function, "code_block"): continue - + # Calculate cyclomatic complexity complexity = self._calculate_cyclomatic_complexity(function) - + # Get file path and name safely - file_path = function.file.file_path if hasattr(function, 'file') and hasattr(function.file, 'file_path') else "unknown" - func_name = function.name if hasattr(function, 'name') else str(function) - + file_path = ( + function.file.file_path + if hasattr(function, "file") and hasattr(function.file, "file_path") + else "unknown" + ) + func_name = function.name if hasattr(function, "name") else str(function) + # Add to complexity list complexity_result["function_complexity"].append({ "name": func_name, "file": file_path, - "line": function.line if hasattr(function, 'line') else None, - "complexity": complexity + "line": function.line if hasattr(function, "line") else None, + "complexity": complexity, }) - + # Track total complexity total_complexity += complexity function_count += 1 - + # Categorize complexity if complexity <= 5: complexity_result["complexity_distribution"]["low"] += 1 @@ -362,212 +411,246 @@ def _analyze_complexity(self) -> Dict[str, Any]: complexity_result["complexity_distribution"]["high"] += 1 else: complexity_result["complexity_distribution"]["very_high"] += 1 - + # Flag high complexity functions if complexity > 10: complexity_result["high_complexity_functions"].append({ "name": func_name, "file": file_path, - "line": function.line if hasattr(function, 'line') else None, - "complexity": complexity + "line": function.line if hasattr(function, "line") else None, + "complexity": complexity, }) - + # Add issue - severity = IssueSeverity.WARNING if complexity <= 15 else IssueSeverity.ERROR - self.issues.add_issue(create_issue( - message=f"Function '{func_name}' has high cyclomatic complexity ({complexity})", - severity=severity, - file=file_path, - line=function.line if hasattr(function, 'line') else None, - category=IssueCategory.COMPLEXITY, - symbol=func_name, - suggestion="Consider refactoring this function to reduce complexity" - )) - + severity = ( + IssueSeverity.WARNING if complexity <= 15 else IssueSeverity.ERROR + ) + self.issues.add_issue( + create_issue( + message=f"Function '{func_name}' has high cyclomatic complexity ({complexity})", + severity=severity, + file=file_path, + line=function.line if hasattr(function, "line") else None, + category=IssueCategory.COMPLEXITY, + symbol=func_name, + suggestion="Consider refactoring this function to reduce complexity", + ) + ) + # Calculate average complexity - complexity_result["average_complexity"] = total_complexity / function_count if function_count > 0 else 0.0 - + complexity_result["average_complexity"] = ( + total_complexity / function_count if function_count > 0 else 0.0 + ) + # Sort high complexity functions by complexity - complexity_result["high_complexity_functions"].sort(key=lambda x: 
x["complexity"], reverse=True) - + complexity_result["high_complexity_functions"].sort( + key=lambda x: x["complexity"], reverse=True + ) + return complexity_result - + def _calculate_cyclomatic_complexity(self, function) -> int: """ Calculate cyclomatic complexity for a function. - + Args: function: Function to analyze - + Returns: Cyclomatic complexity score """ complexity = 1 # Base complexity - + def analyze_statement(statement): nonlocal complexity - + # Check for if statements (including elif branches) - if hasattr(statement, 'if_clause'): + if hasattr(statement, "if_clause"): complexity += 1 - + # Count elif branches - if hasattr(statement, 'elif_statements'): + if hasattr(statement, "elif_statements"): complexity += len(statement.elif_statements) - + # Count else branches - if hasattr(statement, 'else_clause') and statement.else_clause: + if hasattr(statement, "else_clause") and statement.else_clause: complexity += 1 - + # Count for loops - if hasattr(statement, 'is_for_loop') and statement.is_for_loop: + if hasattr(statement, "is_for_loop") and statement.is_for_loop: complexity += 1 - + # Count while loops - if hasattr(statement, 'is_while_loop') and statement.is_while_loop: + if hasattr(statement, "is_while_loop") and statement.is_while_loop: complexity += 1 - + # Count try/except blocks (each except adds a path) - if hasattr(statement, 'is_try_block') and statement.is_try_block: - if hasattr(statement, 'except_clauses'): + if hasattr(statement, "is_try_block") and statement.is_try_block: + if hasattr(statement, "except_clauses"): complexity += len(statement.except_clauses) - + # Recursively process nested statements - if hasattr(statement, 'statements'): + if hasattr(statement, "statements"): for nested_stmt in statement.statements: analyze_statement(nested_stmt) - + # Process all statements in the function's code block - if hasattr(function, 'code_block') and hasattr(function.code_block, 'statements'): + if hasattr(function, "code_block") and hasattr( + function.code_block, "statements" + ): for statement in function.code_block.statements: analyze_statement(statement) - + # If we can't analyze the AST, fall back to simple pattern matching - elif hasattr(function, 'source'): + elif hasattr(function, "source"): source = function.source # Count branch points - complexity += source.count('if ') - complexity += source.count('elif ') - complexity += source.count('for ') - complexity += source.count('while ') - complexity += source.count('except:') - complexity += source.count('except ') - complexity += source.count('case ') - + complexity += source.count("if ") + complexity += source.count("elif ") + complexity += source.count("for ") + complexity += source.count("while ") + complexity += source.count("except:") + complexity += source.count("except ") + complexity += source.count("case ") + return complexity - - def _check_function_parameters(self) -> Dict[str, Any]: + + def _check_function_parameters(self) -> dict[str, Any]: """ Check for function parameter issues. 
- + Returns: Dictionary containing parameter analysis results """ logger.info("Analyzing function parameters") - + parameter_issues = { "missing_types": [], "inconsistent_types": [], "unused_parameters": [], - "incorrect_usage": [] + "incorrect_usage": [], } - + for function in self.context.get_functions(): # Skip if function should be excluded if self._should_skip_symbol(function): continue - + # Skip if no parameters - if not hasattr(function, 'parameters'): + if not hasattr(function, "parameters"): continue - - file_path = function.file.file_path if hasattr(function, 'file') and hasattr(function.file, 'file_path') else "unknown" - func_name = function.name if hasattr(function, 'name') else str(function) - + + file_path = ( + function.file.file_path + if hasattr(function, "file") and hasattr(function.file, "file_path") + else "unknown" + ) + func_name = function.name if hasattr(function, "name") else str(function) + # Check for missing type annotations missing_types = [] for param in function.parameters: - if not hasattr(param, 'name'): + if not hasattr(param, "name"): continue - - if not hasattr(param, 'type') or not param.type: + + if not hasattr(param, "type") or not param.type: missing_types.append(param.name) - + if missing_types: parameter_issues["missing_types"].append({ "function": func_name, "file": file_path, - "line": function.line if hasattr(function, 'line') else None, - "parameters": missing_types + "line": function.line if hasattr(function, "line") else None, + "parameters": missing_types, }) - - self.issues.add_issue(create_issue( - message=f"Function '{func_name}' has parameters without type annotations: {', '.join(missing_types)}", - severity=IssueSeverity.WARNING, - file=file_path, - line=function.line if hasattr(function, 'line') else None, - category=IssueCategory.TYPE_ERROR, - symbol=func_name, - suggestion="Add type annotations to all parameters" - )) - + + self.issues.add_issue( + create_issue( + message=f"Function '{func_name}' has parameters without type annotations: {', '.join(missing_types)}", + severity=IssueSeverity.WARNING, + file=file_path, + line=function.line if hasattr(function, "line") else None, + category=IssueCategory.TYPE_ERROR, + symbol=func_name, + suggestion="Add type annotations to all parameters", + ) + ) + # Check for unused parameters - if hasattr(function, 'source'): + if hasattr(function, "source"): # This is a simple check that looks for parameter names in the function body # A more sophisticated check would analyze the AST unused_params = [] for param in function.parameters: - if not hasattr(param, 'name'): + if not hasattr(param, "name"): continue - + # Skip self/cls parameter in methods - if param.name in ['self', 'cls'] and hasattr(function, 'parent') and function.parent: + if ( + param.name in ["self", "cls"] + and hasattr(function, "parent") + and function.parent + ): continue - + # Check if parameter name appears in function body # This is a simple heuristic and may produce false positives - param_regex = r'\b' + re.escape(param.name) + r'\b' - body_lines = function.source.split('\n')[1:] if function.source.count('\n') > 0 else [] - body_text = '\n'.join(body_lines) - + param_regex = r"\b" + re.escape(param.name) + r"\b" + body_lines = ( + function.source.split("\n")[1:] + if function.source.count("\n") > 0 + else [] + ) + body_text = "\n".join(body_lines) + if not re.search(param_regex, body_text): unused_params.append(param.name) - + if unused_params: parameter_issues["unused_parameters"].append({ "function": func_name, "file": 
file_path, - "line": function.line if hasattr(function, 'line') else None, - "parameters": unused_params + "line": function.line if hasattr(function, "line") else None, + "parameters": unused_params, }) - - self.issues.add_issue(create_issue( - message=f"Function '{func_name}' has unused parameters: {', '.join(unused_params)}", - severity=IssueSeverity.INFO, - file=file_path, - line=function.line if hasattr(function, 'line') else None, - category=IssueCategory.DEAD_CODE, - symbol=func_name, - suggestion="Remove unused parameters or use them in the function body" - )) - + + self.issues.add_issue( + create_issue( + message=f"Function '{func_name}' has unused parameters: {', '.join(unused_params)}", + severity=IssueSeverity.INFO, + file=file_path, + line=function.line if hasattr(function, "line") else None, + category=IssueCategory.DEAD_CODE, + symbol=func_name, + suggestion="Remove unused parameters or use them in the function body", + ) + ) + # Check for incorrect parameter usage at call sites - if hasattr(function, 'call_sites'): + if hasattr(function, "call_sites"): for call_site in function.call_sites: # Skip if call site has no arguments - if not hasattr(call_site, 'args'): + if not hasattr(call_site, "args"): continue - + # Get required parameter count (excluding those with defaults) required_count = 0 - if hasattr(function, 'parameters'): - required_count = sum(1 for p in function.parameters - if not hasattr(p, 'has_default') or not p.has_default) - + if hasattr(function, "parameters"): + required_count = sum( + 1 + for p in function.parameters + if not hasattr(p, "has_default") or not p.has_default + ) + # Get call site file info - call_file = call_site.file.file_path if hasattr(call_site, 'file') and hasattr(call_site.file, 'file_path') else "unknown" - call_line = call_site.line if hasattr(call_site, 'line') else None - + call_file = ( + call_site.file.file_path + if hasattr(call_site, "file") + and hasattr(call_site.file, "file_path") + else "unknown" + ) + call_line = call_site.line if hasattr(call_site, "line") else None + # Check parameter count arg_count = len(call_site.args) if arg_count < required_count: @@ -576,76 +659,127 @@ def _check_function_parameters(self) -> Dict[str, Any]: "caller_file": call_file, "caller_line": call_line, "required_count": required_count, - "provided_count": arg_count + "provided_count": arg_count, }) - - self.issues.add_issue(create_issue( - message=f"Call to '{func_name}' has too few arguments ({arg_count} provided, {required_count} required)", - severity=IssueSeverity.ERROR, - file=call_file, - line=call_line, - category=IssueCategory.PARAMETER_MISMATCH, - symbol=func_name, - suggestion=f"Provide all required arguments to '{func_name}'" - )) - + + self.issues.add_issue( + create_issue( + message=f"Call to '{func_name}' has too few arguments ({arg_count} provided, {required_count} required)", + severity=IssueSeverity.ERROR, + file=call_file, + line=call_line, + category=IssueCategory.PARAMETER_MISMATCH, + symbol=func_name, + suggestion=f"Provide all required arguments to '{func_name}'", + ) + ) + # Check for inconsistent parameter types across overloaded functions functions_by_name = {} for function in self.context.get_functions(): - if hasattr(function, 'name'): + if hasattr(function, "name"): if function.name not in functions_by_name: functions_by_name[function.name] = [] functions_by_name[function.name].append(function) - + for func_name, overloads in functions_by_name.items(): if len(overloads) > 1: # Check for inconsistent parameter types 
for i, func1 in enumerate(overloads): - for func2 in overloads[i+1:]: + for func2 in overloads[i + 1 :]: inconsistent_types = [] - + # Skip if either function has no parameters - if not hasattr(func1, 'parameters') or not hasattr(func2, 'parameters'): + if not hasattr(func1, "parameters") or not hasattr( + func2, "parameters" + ): continue - + # Get common parameter names - func1_param_names = {p.name for p in func1.parameters if hasattr(p, 'name')} - func2_param_names = {p.name for p in func2.parameters if hasattr(p, 'name')} - common_params = func1_param_names.intersection(func2_param_names) - + func1_param_names = { + p.name for p in func1.parameters if hasattr(p, "name") + } + func2_param_names = { + p.name for p in func2.parameters if hasattr(p, "name") + } + common_params = func1_param_names.intersection( + func2_param_names + ) + # Check parameter types for param_name in common_params: # Get parameter objects - param1 = next((p for p in func1.parameters if hasattr(p, 'name') and p.name == param_name), None) - param2 = next((p for p in func2.parameters if hasattr(p, 'name') and p.name == param_name), None) - - if param1 and param2 and hasattr(param1, 'type') and hasattr(param2, 'type'): - if param1.type and param2.type and str(param1.type) != str(param2.type): - inconsistent_types.append({ - "parameter": param_name, - "type1": str(param1.type), - "type2": str(param2.type), - "function1": f"{func1.file.file_path}:{func1.line}" if hasattr(func1, 'file') and hasattr(func1.file, 'file_path') and hasattr(func1, 'line') else str(func1), - "function2": f"{func2.file.file_path}:{func2.line}" if hasattr(func2, 'file') and hasattr(func2.file, 'file_path') and hasattr(func2, 'line') else str(func2) - }) - + param1 = next( + ( + p + for p in func1.parameters + if hasattr(p, "name") and p.name == param_name + ), + None, + ) + param2 = next( + ( + p + for p in func2.parameters + if hasattr(p, "name") and p.name == param_name + ), + None, + ) + + if ( + param1 + and param2 + and hasattr(param1, "type") + and hasattr(param2, "type") + ) and ( + param1.type + and param2.type + and str(param1.type) != str(param2.type) + ): + inconsistent_types.append({ + "parameter": param_name, + "type1": str(param1.type), + "type2": str(param2.type), + "function1": f"{func1.file.file_path}:{func1.line}" + if hasattr(func1, "file") + and hasattr(func1.file, "file_path") + and hasattr(func1, "line") + else str(func1), + "function2": f"{func2.file.file_path}:{func2.line}" + if hasattr(func2, "file") + and hasattr(func2.file, "file_path") + and hasattr(func2, "line") + else str(func2), + }) + if inconsistent_types: - parameter_issues["inconsistent_types"].extend(inconsistent_types) - + parameter_issues["inconsistent_types"].extend( + inconsistent_types + ) + for issue in inconsistent_types: - func1_file = func1.file.file_path if hasattr(func1, 'file') and hasattr(func1.file, 'file_path') else "unknown" - func1_line = func1.line if hasattr(func1, 'line') else None - - self.issues.add_issue(create_issue( - message=f"Inconsistent types for parameter '{issue['parameter']}': {issue['type1']} vs {issue['type2']}", - severity=IssueSeverity.ERROR, - file=func1_file, - line=func1_line, - category=IssueCategory.TYPE_ERROR, - symbol=func_name, - suggestion="Use consistent parameter types across function overloads" - )) - + func1_file = ( + func1.file.file_path + if hasattr(func1, "file") + and hasattr(func1.file, "file_path") + else "unknown" + ) + func1_line = ( + func1.line if hasattr(func1, "line") else None + ) + + 
self.issues.add_issue( + create_issue( + message=f"Inconsistent types for parameter '{issue['parameter']}': {issue['type1']} vs {issue['type2']}", + severity=IssueSeverity.ERROR, + file=func1_file, + line=func1_line, + category=IssueCategory.TYPE_ERROR, + symbol=func_name, + suggestion="Use consistent parameter types across function overloads", + ) + ) + # Add summary statistics parameter_issues["summary"] = { "missing_types_count": len(parameter_issues["missing_types"]), @@ -653,24 +787,24 @@ def _check_function_parameters(self) -> Dict[str, Any]: "unused_parameters_count": len(parameter_issues["unused_parameters"]), "incorrect_usage_count": len(parameter_issues["incorrect_usage"]), "total_issues": ( - len(parameter_issues["missing_types"]) + - len(parameter_issues["inconsistent_types"]) + - len(parameter_issues["unused_parameters"]) + - len(parameter_issues["incorrect_usage"]) - ) + len(parameter_issues["missing_types"]) + + len(parameter_issues["inconsistent_types"]) + + len(parameter_issues["unused_parameters"]) + + len(parameter_issues["incorrect_usage"]) + ), } - + return parameter_issues - - def _check_style_issues(self) -> Dict[str, Any]: + + def _check_style_issues(self) -> dict[str, Any]: """ Check for code style issues. - + Returns: Dictionary containing style analysis results """ logger.info("Analyzing code style") - + style_result = { "long_functions": [], "long_lines": [], @@ -678,105 +812,124 @@ def _check_style_issues(self) -> Dict[str, Any]: "summary": { "long_functions_count": 0, "long_lines_count": 0, - "inconsistent_naming_count": 0 - } + "inconsistent_naming_count": 0, + }, } - + # Check for long functions (too many lines) for function in self.context.get_functions(): # Skip if function should be excluded if self._should_skip_symbol(function): continue - + # Get function code - if hasattr(function, 'source'): + if hasattr(function, "source"): code = function.source - lines = code.split('\n') - + lines = code.split("\n") + # Check function length if len(lines) > 50: # Threshold for "too long" # Get file path and name safely - file_path = function.file.file_path if hasattr(function, 'file') and hasattr(function.file, 'file_path') else "unknown" - func_name = function.name if hasattr(function, 'name') else str(function) - + file_path = ( + function.file.file_path + if hasattr(function, "file") + and hasattr(function.file, "file_path") + else "unknown" + ) + func_name = ( + function.name if hasattr(function, "name") else str(function) + ) + # Add to long functions list style_result["long_functions"].append({ "name": func_name, "file": file_path, - "line": function.line if hasattr(function, 'line') else None, - "line_count": len(lines) + "line": function.line if hasattr(function, "line") else None, + "line_count": len(lines), }) - + # Add issue - self.issues.add_issue(create_issue( - message=f"Function '{func_name}' is too long ({len(lines)} lines)", - severity=IssueSeverity.INFO, - file=file_path, - line=function.line if hasattr(function, 'line') else None, - category=IssueCategory.STYLE_ISSUE, - symbol=func_name, - suggestion="Consider breaking this function into smaller, more focused functions" - )) - + self.issues.add_issue( + create_issue( + message=f"Function '{func_name}' is too long ({len(lines)} lines)", + severity=IssueSeverity.INFO, + file=file_path, + line=function.line if hasattr(function, "line") else None, + category=IssueCategory.STYLE_ISSUE, + symbol=func_name, + suggestion="Consider breaking this function into smaller, more focused functions", + ) + ) + # 
Check for long lines for file in self.context.get_files(): # Skip binary files - if hasattr(file, 'is_binary') and file.is_binary: + if hasattr(file, "is_binary") and file.is_binary: continue - + # Get file content - if hasattr(file, 'content'): - lines = file.content.split('\n') - file_path = file.file_path if hasattr(file, 'file_path') else str(file) - + if hasattr(file, "content"): + lines = file.content.split("\n") + file_path = file.file_path if hasattr(file, "file_path") else str(file) + # Find long lines for i, line in enumerate(lines): if len(line) > 100: # Threshold for "too long" # Skip comment lines - if line.lstrip().startswith('#') or line.lstrip().startswith('//'): + if line.lstrip().startswith("#") or line.lstrip().startswith( + "//" + ): continue - + # Skip lines with strings that can't be easily broken if '"' in line or "'" in line: # If the line is mostly a string, skip it if line.count('"') >= 2 or line.count("'") >= 2: continue - + # Add to long lines list style_result["long_lines"].append({ "file": file_path, "line_number": i + 1, "line_length": len(line), - "line_content": line[:50] + "..." if len(line) > 50 else line + "line_content": line[:50] + "..." + if len(line) > 50 + else line, }) - + # Add issue (only for very long lines) if len(line) > 120: - self.issues.add_issue(create_issue( - message=f"Line is too long ({len(line)} characters)", - severity=IssueSeverity.INFO, - file=file_path, - line=i + 1, - category=IssueCategory.STYLE_ISSUE, - suggestion="Consider breaking this line into multiple lines" - )) - + self.issues.add_issue( + create_issue( + message=f"Line is too long ({len(line)} characters)", + severity=IssueSeverity.INFO, + file=file_path, + line=i + 1, + category=IssueCategory.STYLE_ISSUE, + suggestion="Consider breaking this line into multiple lines", + ) + ) + # Update summary - style_result["summary"]["long_functions_count"] = len(style_result["long_functions"]) + style_result["summary"]["long_functions_count"] = len( + style_result["long_functions"] + ) style_result["summary"]["long_lines_count"] = len(style_result["long_lines"]) - style_result["summary"]["inconsistent_naming_count"] = len(style_result["inconsistent_naming"]) - + style_result["summary"]["inconsistent_naming_count"] = len( + style_result["inconsistent_naming"] + ) + return style_result - - def _check_implementations(self) -> Dict[str, Any]: + + def _check_implementations(self) -> dict[str, Any]: """ Check for implementation issues. 
- + Returns: Dictionary containing implementation analysis results """ logger.info("Analyzing implementations") - + implementation_issues = { "unimplemented_functions": [], "empty_functions": [], @@ -786,231 +939,286 @@ def _check_implementations(self) -> Dict[str, Any]: "unimplemented_functions_count": 0, "empty_functions_count": 0, "abstract_methods_without_implementation_count": 0, - "interface_methods_not_implemented_count": 0 - } + "interface_methods_not_implemented_count": 0, + }, } - + # Check for empty functions for function in self.context.get_functions(): # Skip if function should be excluded if self._should_skip_symbol(function): continue - + # Get function source - if hasattr(function, 'source'): + if hasattr(function, "source"): source = function.source - + # Check if function is empty or just has 'pass' is_empty = False - + if not source or source.strip() == "": is_empty = True else: # Extract function body (skip the first line with the def) - body_lines = source.split('\n')[1:] if '\n' in source else [] - + body_lines = source.split("\n")[1:] if "\n" in source else [] + # Check if body is empty or just has whitespace, docstring, or pass non_empty_lines = [ - line for line in body_lines - if line.strip() and - not line.strip().startswith('#') and - not (line.strip().startswith('"""') or line.strip().startswith("'''")) and - not line.strip() == 'pass' + line + for line in body_lines + if line.strip() + and not line.strip().startswith("#") + and not ( + line.strip().startswith('"""') + or line.strip().startswith("'''") + ) + and line.strip() != "pass" ] - + if not non_empty_lines: is_empty = True - + if is_empty: # Get file path and name safely - file_path = function.file.file_path if hasattr(function, 'file') and hasattr(function.file, 'file_path') else "unknown" - func_name = function.name if hasattr(function, 'name') else str(function) - + file_path = ( + function.file.file_path + if hasattr(function, "file") + and hasattr(function.file, "file_path") + else "unknown" + ) + func_name = ( + function.name if hasattr(function, "name") else str(function) + ) + # Skip interface/abstract methods that are supposed to be empty is_abstract = ( - hasattr(function, 'is_abstract') and function.is_abstract or - hasattr(function, 'parent') and hasattr(function.parent, 'is_interface') and function.parent.is_interface + hasattr(function, "is_abstract") and function.is_abstract + ) or ( + hasattr(function, "parent") + and hasattr(function.parent, "is_interface") + and function.parent.is_interface ) - + if not is_abstract: # Add to empty functions list implementation_issues["empty_functions"].append({ "name": func_name, "file": file_path, - "line": function.line if hasattr(function, 'line') else None + "line": function.line + if hasattr(function, "line") + else None, }) - + # Add issue - self.issues.add_issue(create_issue( - message=f"Function '{func_name}' is empty", - severity=IssueSeverity.WARNING, - file=file_path, - line=function.line if hasattr(function, 'line') else None, - category=IssueCategory.MISSING_IMPLEMENTATION, - symbol=func_name, - suggestion="Implement this function or remove it if not needed" - )) - + self.issues.add_issue( + create_issue( + message=f"Function '{func_name}' is empty", + severity=IssueSeverity.WARNING, + file=file_path, + line=function.line + if hasattr(function, "line") + else None, + category=IssueCategory.MISSING_IMPLEMENTATION, + symbol=func_name, + suggestion="Implement this function or remove it if not needed", + ) + ) + # Check for abstract methods 
without implementations abstract_methods = [] for function in self.context.get_functions(): # Skip if function should be excluded if self._should_skip_symbol(function): continue - + # Check if function is abstract is_abstract = ( - hasattr(function, 'is_abstract') and function.is_abstract or - hasattr(function, 'decorators') and any( - hasattr(d, 'name') and d.name in ['abstractmethod', 'abc.abstractmethod'] + hasattr(function, "is_abstract") and function.is_abstract + ) or ( + hasattr(function, "decorators") + and any( + hasattr(d, "name") + and d.name in ["abstractmethod", "abc.abstractmethod"] for d in function.decorators ) ) - - if is_abstract and hasattr(function, 'parent') and hasattr(function, 'name'): + + if ( + is_abstract + and hasattr(function, "parent") + and hasattr(function, "name") + ): abstract_methods.append((function.parent, function.name)) - + # For each abstract method, check if it has implementations in subclasses for parent, method_name in abstract_methods: - if not hasattr(parent, 'name'): + if not hasattr(parent, "name"): continue - + parent_name = parent.name - + # Find all subclasses subclasses = [] for cls in self.context.get_classes(): - if hasattr(cls, 'superclasses'): + if hasattr(cls, "superclasses"): for superclass in cls.superclasses: - if hasattr(superclass, 'name') and superclass.name == parent_name: + if ( + hasattr(superclass, "name") + and superclass.name == parent_name + ): subclasses.append(cls) - + # Check if method is implemented in all subclasses for subclass in subclasses: - if not hasattr(subclass, 'methods'): + if not hasattr(subclass, "methods"): continue - + # Check if method is implemented implemented = any( - hasattr(m, 'name') and m.name == method_name + hasattr(m, "name") and m.name == method_name for m in subclass.methods ) - + if not implemented: # Get file path and name safely - file_path = subclass.file.file_path if hasattr(subclass, 'file') and hasattr(subclass.file, 'file_path') else "unknown" - cls_name = subclass.name if hasattr(subclass, 'name') else str(subclass) - + file_path = ( + subclass.file.file_path + if hasattr(subclass, "file") + and hasattr(subclass.file, "file_path") + else "unknown" + ) + cls_name = ( + subclass.name if hasattr(subclass, "name") else str(subclass) + ) + # Add to unimplemented list - implementation_issues["abstract_methods_without_implementation"].append({ + implementation_issues[ + "abstract_methods_without_implementation" + ].append({ "method": method_name, "parent_class": parent_name, "subclass": cls_name, "file": file_path, - "line": subclass.line if hasattr(subclass, 'line') else None + "line": subclass.line if hasattr(subclass, "line") else None, }) - + # Add issue - self.issues.add_issue(create_issue( - message=f"Class '{cls_name}' does not implement abstract method '{method_name}' from '{parent_name}'", - severity=IssueSeverity.ERROR, - file=file_path, - line=subclass.line if hasattr(subclass, 'line') else None, - category=IssueCategory.MISSING_IMPLEMENTATION, - symbol=cls_name, - suggestion=f"Implement the '{method_name}' method in '{cls_name}'" - )) - + self.issues.add_issue( + create_issue( + message=f"Class '{cls_name}' does not implement abstract method '{method_name}' from '{parent_name}'", + severity=IssueSeverity.ERROR, + file=file_path, + line=subclass.line if hasattr(subclass, "line") else None, + category=IssueCategory.MISSING_IMPLEMENTATION, + symbol=cls_name, + suggestion=f"Implement the '{method_name}' method in '{cls_name}'", + ) + ) + # Update summary - 
implementation_issues["summary"]["unimplemented_functions_count"] = len(implementation_issues["unimplemented_functions"]) - implementation_issues["summary"]["empty_functions_count"] = len(implementation_issues["empty_functions"]) - implementation_issues["summary"]["abstract_methods_without_implementation_count"] = len(implementation_issues["abstract_methods_without_implementation"]) - implementation_issues["summary"]["interface_methods_not_implemented_count"] = len(implementation_issues["interface_methods_not_implemented"]) - + implementation_issues["summary"]["unimplemented_functions_count"] = len( + implementation_issues["unimplemented_functions"] + ) + implementation_issues["summary"]["empty_functions_count"] = len( + implementation_issues["empty_functions"] + ) + implementation_issues["summary"][ + "abstract_methods_without_implementation_count" + ] = len(implementation_issues["abstract_methods_without_implementation"]) + implementation_issues["summary"]["interface_methods_not_implemented_count"] = ( + len(implementation_issues["interface_methods_not_implemented"]) + ) + return implementation_issues - - def _calculate_maintainability(self) -> Dict[str, Any]: + + def _calculate_maintainability(self) -> dict[str, Any]: """ Calculate maintainability metrics. - + Returns: Dictionary containing maintainability analysis results """ logger.info("Analyzing maintainability") - + maintainability_result = { "function_maintainability": [], "low_maintainability_functions": [], "average_maintainability": 0.0, - "maintainability_distribution": { - "high": 0, - "medium": 0, - "low": 0 - } + "maintainability_distribution": {"high": 0, "medium": 0, "low": 0}, } - + # Process all functions to calculate maintainability total_maintainability = 0 function_count = 0 - + for function in self.context.get_functions(): # Skip if function should be excluded if self._should_skip_symbol(function): continue - + # Skip if no code block - if not hasattr(function, 'code_block'): + if not hasattr(function, "code_block"): continue - + # Calculate metrics complexity = self._calculate_cyclomatic_complexity(function) - + # Calculate Halstead volume (approximation) operators = 0 operands = 0 - - if hasattr(function, 'source'): + + if hasattr(function, "source"): code = function.source # Simple approximation of operators and operands - operators = len([c for c in code if c in '+-*/=<>!&|^~%']) + operators = len([c for c in code if c in "+-*/=<>!&|^~%"]) # Counting words as potential operands - operands = len(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', code)) - - halstead_volume = operators * operands * math.log2(operators + operands) if operators + operands > 0 else 0 - + operands = len(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", code)) + + halstead_volume = ( + operators * operands * math.log2(operators + operands) + if operators + operands > 0 + else 0 + ) + # Count lines of code - loc = len(function.source.split('\n')) if hasattr(function, 'source') else 0 - + loc = len(function.source.split("\n")) if hasattr(function, "source") else 0 + # Calculate maintainability index # Formula: 171 - 5.2 * ln(Halstead Volume) - 0.23 * (Cyclomatic Complexity) - 16.2 * ln(LOC) - halstead_term = 5.2 * math.log(max(1, halstead_volume)) if halstead_volume > 0 else 0 + halstead_term = ( + 5.2 * math.log(max(1, halstead_volume)) if halstead_volume > 0 else 0 + ) complexity_term = 0.23 * complexity loc_term = 16.2 * math.log(max(1, loc)) if loc > 0 else 0 - + maintainability = 171 - halstead_term - complexity_term - loc_term - + # Normalize to 
0-100 scale maintainability = max(0, min(100, maintainability * 100 / 171)) - + # Get file path and name safely - file_path = function.file.file_path if hasattr(function, 'file') and hasattr(function.file, 'file_path') else "unknown" - func_name = function.name if hasattr(function, 'name') else str(function) - + file_path = ( + function.file.file_path + if hasattr(function, "file") and hasattr(function.file, "file_path") + else "unknown" + ) + func_name = function.name if hasattr(function, "name") else str(function) + # Add to maintainability list maintainability_result["function_maintainability"].append({ "name": func_name, "file": file_path, - "line": function.line if hasattr(function, 'line') else None, + "line": function.line if hasattr(function, "line") else None, "maintainability": maintainability, "complexity": complexity, "halstead_volume": halstead_volume, - "loc": loc + "loc": loc, }) - + # Track total maintainability total_maintainability += maintainability function_count += 1 - + # Categorize maintainability if maintainability >= 70: maintainability_result["maintainability_distribution"]["high"] += 1 @@ -1018,85 +1226,84 @@ def _calculate_maintainability(self) -> Dict[str, Any]: maintainability_result["maintainability_distribution"]["medium"] += 1 else: maintainability_result["maintainability_distribution"]["low"] += 1 - + # Flag low maintainability functions maintainability_result["low_maintainability_functions"].append({ "name": func_name, "file": file_path, - "line": function.line if hasattr(function, 'line') else None, + "line": function.line if hasattr(function, "line") else None, "maintainability": maintainability, "complexity": complexity, "halstead_volume": halstead_volume, - "loc": loc + "loc": loc, }) - + # Add issue - self.issues.add_issue(create_issue( - message=f"Function '{func_name}' has low maintainability index ({maintainability:.1f})", - severity=IssueSeverity.WARNING, - file=file_path, - line=function.line if hasattr(function, 'line') else None, - category=IssueCategory.COMPLEXITY, - symbol=func_name, - suggestion="Consider refactoring this function to improve maintainability" - )) - + self.issues.add_issue( + create_issue( + message=f"Function '{func_name}' has low maintainability index ({maintainability:.1f})", + severity=IssueSeverity.WARNING, + file=file_path, + line=function.line if hasattr(function, "line") else None, + category=IssueCategory.COMPLEXITY, + symbol=func_name, + suggestion="Consider refactoring this function to improve maintainability", + ) + ) + # Calculate average maintainability - maintainability_result["average_maintainability"] = total_maintainability / function_count if function_count > 0 else 0.0 - + maintainability_result["average_maintainability"] = ( + total_maintainability / function_count if function_count > 0 else 0.0 + ) + # Sort low maintainability functions - maintainability_result["low_maintainability_functions"].sort(key=lambda x: x["maintainability"]) - + maintainability_result["low_maintainability_functions"].sort( + key=lambda x: x["maintainability"] + ) + return maintainability_result - + def _should_skip_symbol(self, symbol) -> bool: """ Check if a symbol should be skipped during analysis. 
- + Args: symbol: Symbol to check - + Returns: True if the symbol should be skipped, False otherwise """ # Skip if no file - if not hasattr(symbol, 'file'): + if not hasattr(symbol, "file"): return True - + # Skip if file should be skipped - if self._should_skip_file(symbol.file): - return True - - return False - + return bool(self._should_skip_file(symbol.file)) + def _should_skip_file(self, file) -> bool: """ Check if a file should be skipped during analysis. - + Args: file: File to check - + Returns: True if the file should be skipped, False otherwise """ # Skip binary files - if hasattr(file, 'is_binary') and file.is_binary: + if hasattr(file, "is_binary") and file.is_binary: return True - + # Get file path - file_path = file.file_path if hasattr(file, 'file_path') else str(file) - + file_path = file.file_path if hasattr(file, "file_path") else str(file) + # Skip test files if "test" in file_path.lower(): return True - + # Skip generated files if "generated" in file_path.lower(): return True - + # Skip files in ignore list - for pattern in self.context.file_ignore_list: - if pattern in file_path: - return True - - return False \ No newline at end of file + return any(pattern in file_path for pattern in self.context.file_ignore_list) diff --git a/codegen-on-oss/codegen_on_oss/analyzers/code_quality_analyzer.py b/codegen-on-oss/codegen_on_oss/analyzers/code_quality_analyzer.py index 8e8983e4d..90da4122d 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/code_quality_analyzer.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/code_quality_analyzer.py @@ -6,52 +6,60 @@ dead code, complexity, style, and maintainability. """ -import os -import sys -import math import logging -from typing import Dict, List, Set, Tuple, Any, Optional, Union +import math +from typing import Any from codegen_on_oss.analyzers.base_analyzer import BaseCodeAnalyzer -from codegen_on_oss.analyzers.issue_types import Issue, IssueSeverity, AnalysisType, IssueCategory +from codegen_on_oss.analyzers.issue_types import ( + AnalysisType, + Issue, + IssueCategory, + IssueSeverity, +) # Configure logging logger = logging.getLogger(__name__) + class CodeQualityAnalyzer(BaseCodeAnalyzer): """ Analyzer for code quality issues. - + This analyzer detects issues related to code quality, including dead code, complexity, style, and maintainability. """ - - def analyze(self, analysis_type: AnalysisType = AnalysisType.CODE_QUALITY) -> Dict[str, Any]: + + def analyze( + self, analysis_type: AnalysisType = AnalysisType.CODE_QUALITY + ) -> dict[str, Any]: """ Perform code quality analysis on the codebase. 
- + Args: analysis_type: Type of analysis to perform - + Returns: Dictionary containing analysis results """ if not self.base_codebase: raise ValueError("Codebase not initialized") - + result = { "metadata": { "analysis_time": str(datetime.now()), "analysis_type": analysis_type, - "repo_name": getattr(self.base_codebase.ctx, 'repo_name', None), - "language": str(getattr(self.base_codebase.ctx, 'programming_language', None)), + "repo_name": getattr(self.base_codebase.ctx, "repo_name", None), + "language": str( + getattr(self.base_codebase.ctx, "programming_language", None) + ), }, "summary": {}, } - + # Reset issues list self.issues = [] - + # Perform appropriate analysis based on type if analysis_type == AnalysisType.CODE_QUALITY: # Run all code quality checks @@ -59,33 +67,47 @@ def analyze(self, analysis_type: AnalysisType = AnalysisType.CODE_QUALITY) -> Di result["complexity"] = self._analyze_code_complexity() result["style_issues"] = self._check_style_issues() result["maintainability"] = self._calculate_maintainability() - + # Add issues to the result result["issues"] = [issue.to_dict() for issue in self.issues] result["issue_counts"] = { "total": len(self.issues), "by_severity": { - "critical": sum(1 for issue in self.issues if issue.severity == IssueSeverity.CRITICAL), - "error": sum(1 for issue in self.issues if issue.severity == IssueSeverity.ERROR), - "warning": sum(1 for issue in self.issues if issue.severity == IssueSeverity.WARNING), - "info": sum(1 for issue in self.issues if issue.severity == IssueSeverity.INFO), + "critical": sum( + 1 + for issue in self.issues + if issue.severity == IssueSeverity.CRITICAL + ), + "error": sum( + 1 for issue in self.issues if issue.severity == IssueSeverity.ERROR + ), + "warning": sum( + 1 + for issue in self.issues + if issue.severity == IssueSeverity.WARNING + ), + "info": sum( + 1 for issue in self.issues if issue.severity == IssueSeverity.INFO + ), }, "by_category": { - category.value: sum(1 for issue in self.issues if issue.category == category) + category.value: sum( + 1 for issue in self.issues if issue.category == category + ) for category in IssueCategory if any(issue.category == category for issue in self.issues) - } + }, } - + # Store results self.results = result - + return result - - def _find_dead_code(self) -> Dict[str, Any]: + + def _find_dead_code(self) -> dict[str, Any]: """ Find unused code (dead code) in the codebase. 
- + Returns: Dictionary containing dead code analysis results """ @@ -93,113 +115,154 @@ def _find_dead_code(self) -> Dict[str, Any]: "unused_functions": [], "unused_classes": [], "unused_variables": [], - "unused_imports": [] + "unused_imports": [], } - + # Find unused functions - if hasattr(self.base_codebase, 'functions'): + if hasattr(self.base_codebase, "functions"): for func in self.base_codebase.functions: # Skip test files - if hasattr(func, 'file') and hasattr(func.file, 'filepath') and "test" in func.file.filepath: + if ( + hasattr(func, "file") + and hasattr(func.file, "filepath") + and "test" in func.file.filepath + ): continue - + # Skip decorated functions (as they might be used indirectly) - if hasattr(func, 'decorators') and func.decorators: + if hasattr(func, "decorators") and func.decorators: continue - + # Check if function has no call sites or usages - has_call_sites = hasattr(func, 'call_sites') and len(func.call_sites) > 0 - has_usages = hasattr(func, 'usages') and len(func.usages) > 0 - + has_call_sites = ( + hasattr(func, "call_sites") and len(func.call_sites) > 0 + ) + has_usages = hasattr(func, "usages") and len(func.usages) > 0 + if not has_call_sites and not has_usages: # Get file path and name safely - file_path = func.file.filepath if hasattr(func, 'file') and hasattr(func.file, 'filepath') else "unknown" - func_name = func.name if hasattr(func, 'name') else str(func) - + file_path = ( + func.file.filepath + if hasattr(func, "file") and hasattr(func.file, "filepath") + else "unknown" + ) + func_name = func.name if hasattr(func, "name") else str(func) + # Add to dead code list dead_code["unused_functions"].append({ "name": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None + "line": func.line if hasattr(func, "line") else None, }) - + # Add issue - self.add_issue(Issue( - file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"Unused function: {func_name}", - severity=IssueSeverity.WARNING, - category=IssueCategory.DEAD_CODE, - symbol=func_name, - suggestion="Consider removing this unused function or documenting why it's needed" - )) - + self.add_issue( + Issue( + file=file_path, + line=func.line if hasattr(func, "line") else None, + message=f"Unused function: {func_name}", + severity=IssueSeverity.WARNING, + category=IssueCategory.DEAD_CODE, + symbol=func_name, + suggestion="Consider removing this unused function or documenting why it's needed", + ) + ) + # Find unused classes - if hasattr(self.base_codebase, 'classes'): + if hasattr(self.base_codebase, "classes"): for cls in self.base_codebase.classes: # Skip test files - if hasattr(cls, 'file') and hasattr(cls.file, 'filepath') and "test" in cls.file.filepath: + if ( + hasattr(cls, "file") + and hasattr(cls.file, "filepath") + and "test" in cls.file.filepath + ): continue - + # Check if class has no usages - has_usages = hasattr(cls, 'usages') and len(cls.usages) > 0 - + has_usages = hasattr(cls, "usages") and len(cls.usages) > 0 + if not has_usages: # Get file path and name safely - file_path = cls.file.filepath if hasattr(cls, 'file') and hasattr(cls.file, 'filepath') else "unknown" - cls_name = cls.name if hasattr(cls, 'name') else str(cls) - + file_path = ( + cls.file.filepath + if hasattr(cls, "file") and hasattr(cls.file, "filepath") + else "unknown" + ) + cls_name = cls.name if hasattr(cls, "name") else str(cls) + # Add to dead code list dead_code["unused_classes"].append({ "name": cls_name, "file": file_path, - "line": cls.line if 
hasattr(cls, 'line') else None + "line": cls.line if hasattr(cls, "line") else None, }) - + # Add issue - self.add_issue(Issue( - file=file_path, - line=cls.line if hasattr(cls, 'line') else None, - message=f"Unused class: {cls_name}", - severity=IssueSeverity.WARNING, - category=IssueCategory.DEAD_CODE, - symbol=cls_name, - suggestion="Consider removing this unused class or documenting why it's needed" - )) - + self.add_issue( + Issue( + file=file_path, + line=cls.line if hasattr(cls, "line") else None, + message=f"Unused class: {cls_name}", + severity=IssueSeverity.WARNING, + category=IssueCategory.DEAD_CODE, + symbol=cls_name, + suggestion="Consider removing this unused class or documenting why it's needed", + ) + ) + # Find unused variables - if hasattr(self.base_codebase, 'functions'): + if hasattr(self.base_codebase, "functions"): for func in self.base_codebase.functions: - if not hasattr(func, 'code_block') or not hasattr(func.code_block, 'local_var_assignments'): + if not hasattr(func, "code_block") or not hasattr( + func.code_block, "local_var_assignments" + ): continue - + for var_assignment in func.code_block.local_var_assignments: # Check if variable has no usages - has_usages = hasattr(var_assignment, 'local_usages') and len(var_assignment.local_usages) > 0 - + has_usages = ( + hasattr(var_assignment, "local_usages") + and len(var_assignment.local_usages) > 0 + ) + if not has_usages: # Get file path and name safely - file_path = func.file.filepath if hasattr(func, 'file') and hasattr(func.file, 'filepath') else "unknown" - var_name = var_assignment.name if hasattr(var_assignment, 'name') else str(var_assignment) - + file_path = ( + func.file.filepath + if hasattr(func, "file") and hasattr(func.file, "filepath") + else "unknown" + ) + var_name = ( + var_assignment.name + if hasattr(var_assignment, "name") + else str(var_assignment) + ) + # Add to dead code list dead_code["unused_variables"].append({ "name": var_name, "file": file_path, - "line": var_assignment.line if hasattr(var_assignment, 'line') else None + "line": var_assignment.line + if hasattr(var_assignment, "line") + else None, }) - + # Add issue - self.add_issue(Issue( - file=file_path, - line=var_assignment.line if hasattr(var_assignment, 'line') else None, - message=f"Unused variable: {var_name}", - severity=IssueSeverity.INFO, - category=IssueCategory.DEAD_CODE, - symbol=var_name, - suggestion="Consider removing this unused variable" - )) - + self.add_issue( + Issue( + file=file_path, + line=var_assignment.line + if hasattr(var_assignment, "line") + else None, + message=f"Unused variable: {var_name}", + severity=IssueSeverity.INFO, + category=IssueCategory.DEAD_CODE, + symbol=var_name, + suggestion="Consider removing this unused variable", + ) + ) + # Summarize findings dead_code["summary"] = { "unused_functions_count": len(dead_code["unused_functions"]), @@ -207,19 +270,19 @@ def _find_dead_code(self) -> Dict[str, Any]: "unused_variables_count": len(dead_code["unused_variables"]), "unused_imports_count": len(dead_code["unused_imports"]), "total_dead_code_count": ( - len(dead_code["unused_functions"]) + - len(dead_code["unused_classes"]) + - len(dead_code["unused_variables"]) + - len(dead_code["unused_imports"]) - ) + len(dead_code["unused_functions"]) + + len(dead_code["unused_classes"]) + + len(dead_code["unused_variables"]) + + len(dead_code["unused_imports"]) + ), } - + return dead_code - - def _analyze_code_complexity(self) -> Dict[str, Any]: + + def _analyze_code_complexity(self) -> dict[str, Any]: """ 
Analyze code complexity. - + Returns: Dictionary containing complexity analysis results """ @@ -231,39 +294,43 @@ def _analyze_code_complexity(self) -> Dict[str, Any]: "low": 0, "medium": 0, "high": 0, - "very_high": 0 - } + "very_high": 0, + }, } - + # Process all functions to calculate complexity total_complexity = 0 function_count = 0 - - if hasattr(self.base_codebase, 'functions'): + + if hasattr(self.base_codebase, "functions"): for func in self.base_codebase.functions: # Skip if no code block - if not hasattr(func, 'code_block'): + if not hasattr(func, "code_block"): continue - + # Calculate cyclomatic complexity complexity = self._calculate_cyclomatic_complexity(func) - + # Get file path and name safely - file_path = func.file.filepath if hasattr(func, 'file') and hasattr(func.file, 'filepath') else "unknown" - func_name = func.name if hasattr(func, 'name') else str(func) - + file_path = ( + func.file.filepath + if hasattr(func, "file") and hasattr(func.file, "filepath") + else "unknown" + ) + func_name = func.name if hasattr(func, "name") else str(func) + # Add to complexity list complexity_result["function_complexity"].append({ "name": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None, - "complexity": complexity + "line": func.line if hasattr(func, "line") else None, + "complexity": complexity, }) - + # Track total complexity total_complexity += complexity function_count += 1 - + # Categorize complexity if complexity <= 5: complexity_result["complexity_distribution"]["low"] += 1 @@ -273,92 +340,104 @@ def _analyze_code_complexity(self) -> Dict[str, Any]: complexity_result["complexity_distribution"]["high"] += 1 else: complexity_result["complexity_distribution"]["very_high"] += 1 - + # Flag high complexity functions if complexity > 10: complexity_result["high_complexity_functions"].append({ "name": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None, - "complexity": complexity + "line": func.line if hasattr(func, "line") else None, + "complexity": complexity, }) - + # Add issue - severity = IssueSeverity.WARNING if complexity <= 15 else IssueSeverity.ERROR - self.add_issue(Issue( - file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"High cyclomatic complexity: {complexity}", - severity=severity, - category=IssueCategory.COMPLEXITY, - symbol=func_name, - suggestion="Consider refactoring this function to reduce complexity" - )) - + severity = ( + IssueSeverity.WARNING + if complexity <= 15 + else IssueSeverity.ERROR + ) + self.add_issue( + Issue( + file=file_path, + line=func.line if hasattr(func, "line") else None, + message=f"High cyclomatic complexity: {complexity}", + severity=severity, + category=IssueCategory.COMPLEXITY, + symbol=func_name, + suggestion="Consider refactoring this function to reduce complexity", + ) + ) + # Calculate average complexity - complexity_result["average_complexity"] = total_complexity / function_count if function_count > 0 else 0.0 - + complexity_result["average_complexity"] = ( + total_complexity / function_count if function_count > 0 else 0.0 + ) + # Sort high complexity functions by complexity - complexity_result["high_complexity_functions"].sort(key=lambda x: x["complexity"], reverse=True) - + complexity_result["high_complexity_functions"].sort( + key=lambda x: x["complexity"], reverse=True + ) + return complexity_result - + def _calculate_cyclomatic_complexity(self, function) -> int: """ Calculate cyclomatic complexity for a function. 
- + Args: function: Function to analyze - + Returns: Cyclomatic complexity score """ complexity = 1 # Base complexity - + def analyze_statement(statement): nonlocal complexity - + # Check for if statements (including elif branches) - if hasattr(statement, 'if_clause'): + if hasattr(statement, "if_clause"): complexity += 1 - + # Count elif branches - if hasattr(statement, 'elif_statements'): + if hasattr(statement, "elif_statements"): complexity += len(statement.elif_statements) - + # Count else branches - if hasattr(statement, 'else_clause') and statement.else_clause: + if hasattr(statement, "else_clause") and statement.else_clause: complexity += 1 - + # Count for loops - if hasattr(statement, 'is_for_loop') and statement.is_for_loop: + if hasattr(statement, "is_for_loop") and statement.is_for_loop: complexity += 1 - + # Count while loops - if hasattr(statement, 'is_while_loop') and statement.is_while_loop: + if hasattr(statement, "is_while_loop") and statement.is_while_loop: complexity += 1 - + # Count try/except blocks (each except adds a path) - if hasattr(statement, 'is_try_block') and statement.is_try_block: - if hasattr(statement, 'except_clauses'): + if hasattr(statement, "is_try_block") and statement.is_try_block: + if hasattr(statement, "except_clauses"): complexity += len(statement.except_clauses) - + # Recursively process nested statements - if hasattr(statement, 'statements'): + if hasattr(statement, "statements"): for nested_stmt in statement.statements: analyze_statement(nested_stmt) - + # Process all statements in the function's code block - if hasattr(function, 'code_block') and hasattr(function.code_block, 'statements'): + if hasattr(function, "code_block") and hasattr( + function.code_block, "statements" + ): for statement in function.code_block.statements: analyze_statement(statement) - + return complexity - - def _check_style_issues(self) -> Dict[str, Any]: + + def _check_style_issues(self) -> dict[str, Any]: """ Check for code style issues. 
- + Returns: Dictionary containing style issues analysis results """ @@ -369,54 +448,64 @@ def _check_style_issues(self) -> Dict[str, Any]: "summary": { "long_functions_count": 0, "long_lines_count": 0, - "inconsistent_naming_count": 0 - } + "inconsistent_naming_count": 0, + }, } - + # Check for long functions (too many lines) - if hasattr(self.base_codebase, 'functions'): + if hasattr(self.base_codebase, "functions"): for func in self.base_codebase.functions: # Get function code - if hasattr(func, 'code_block') and hasattr(func.code_block, 'source'): + if hasattr(func, "code_block") and hasattr(func.code_block, "source"): code = func.code_block.source - lines = code.split('\n') - + lines = code.split("\n") + # Check function length if len(lines) > 50: # Threshold for "too long" # Get file path and name safely - file_path = func.file.filepath if hasattr(func, 'file') and hasattr(func.file, 'filepath') else "unknown" - func_name = func.name if hasattr(func, 'name') else str(func) - + file_path = ( + func.file.filepath + if hasattr(func, "file") and hasattr(func.file, "filepath") + else "unknown" + ) + func_name = func.name if hasattr(func, "name") else str(func) + # Add to long functions list style_result["long_functions"].append({ "name": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None, - "line_count": len(lines) + "line": func.line if hasattr(func, "line") else None, + "line_count": len(lines), }) - + # Add issue - self.add_issue(Issue( - file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"Long function: {len(lines)} lines", - severity=IssueSeverity.INFO, - category=IssueCategory.STYLE_ISSUE, - symbol=func_name, - suggestion="Consider breaking this function into smaller, more focused functions" - )) - + self.add_issue( + Issue( + file=file_path, + line=func.line if hasattr(func, "line") else None, + message=f"Long function: {len(lines)} lines", + severity=IssueSeverity.INFO, + category=IssueCategory.STYLE_ISSUE, + symbol=func_name, + suggestion="Consider breaking this function into smaller, more focused functions", + ) + ) + # Update summary - style_result["summary"]["long_functions_count"] = len(style_result["long_functions"]) + style_result["summary"]["long_functions_count"] = len( + style_result["long_functions"] + ) style_result["summary"]["long_lines_count"] = len(style_result["long_lines"]) - style_result["summary"]["inconsistent_naming_count"] = len(style_result["inconsistent_naming"]) - + style_result["summary"]["inconsistent_naming_count"] = len( + style_result["inconsistent_naming"] + ) + return style_result - - def _calculate_maintainability(self) -> Dict[str, Any]: + + def _calculate_maintainability(self) -> dict[str, Any]: """ Calculate maintainability metrics. 
- + Returns: Dictionary containing maintainability analysis results """ @@ -424,107 +513,129 @@ def _calculate_maintainability(self) -> Dict[str, Any]: "function_maintainability": [], "low_maintainability_functions": [], "average_maintainability": 0.0, - "maintainability_distribution": { - "high": 0, - "medium": 0, - "low": 0 - } + "maintainability_distribution": {"high": 0, "medium": 0, "low": 0}, } - + # Process all functions to calculate maintainability total_maintainability = 0 function_count = 0 - - if hasattr(self.base_codebase, 'functions'): + + if hasattr(self.base_codebase, "functions"): for func in self.base_codebase.functions: # Skip if no code block - if not hasattr(func, 'code_block'): + if not hasattr(func, "code_block"): continue - + # Calculate metrics complexity = self._calculate_cyclomatic_complexity(func) - + # Calculate Halstead volume (approximation) operators = 0 operands = 0 - - if hasattr(func, 'code_block') and hasattr(func.code_block, 'source'): + + if hasattr(func, "code_block") and hasattr(func.code_block, "source"): code = func.code_block.source # Simple approximation of operators and operands - operators = len([c for c in code if c in '+-*/=<>!&|^~%']) + operators = len([c for c in code if c in "+-*/=<>!&|^~%"]) # Counting words as potential operands import re - operands = len(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', code)) - - halstead_volume = operators * operands * math.log2(operators + operands) if operators + operands > 0 else 0 - + + operands = len(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", code)) + + halstead_volume = ( + operators * operands * math.log2(operators + operands) + if operators + operands > 0 + else 0 + ) + # Count lines of code - loc = len(func.code_block.source.split('\n')) if hasattr(func, 'code_block') and hasattr(func.code_block, 'source') else 0 - + loc = ( + len(func.code_block.source.split("\n")) + if hasattr(func, "code_block") + and hasattr(func.code_block, "source") + else 0 + ) + # Calculate maintainability index # Formula: 171 - 5.2 * ln(Halstead Volume) - 0.23 * (Cyclomatic Complexity) - 16.2 * ln(LOC) - halstead_term = 5.2 * math.log(max(1, halstead_volume)) if halstead_volume > 0 else 0 + halstead_term = ( + 5.2 * math.log(max(1, halstead_volume)) + if halstead_volume > 0 + else 0 + ) complexity_term = 0.23 * complexity loc_term = 16.2 * math.log(max(1, loc)) if loc > 0 else 0 - + maintainability = 171 - halstead_term - complexity_term - loc_term - + # Normalize to 0-100 scale maintainability = max(0, min(100, maintainability * 100 / 171)) - + # Get file path and name safely - file_path = func.file.filepath if hasattr(func, 'file') and hasattr(func.file, 'filepath') else "unknown" - func_name = func.name if hasattr(func, 'name') else str(func) - + file_path = ( + func.file.filepath + if hasattr(func, "file") and hasattr(func.file, "filepath") + else "unknown" + ) + func_name = func.name if hasattr(func, "name") else str(func) + # Add to maintainability list maintainability_result["function_maintainability"].append({ "name": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None, + "line": func.line if hasattr(func, "line") else None, "maintainability": maintainability, "complexity": complexity, "halstead_volume": halstead_volume, - "loc": loc + "loc": loc, }) - + # Track total maintainability total_maintainability += maintainability function_count += 1 - + # Categorize maintainability if maintainability >= 70: maintainability_result["maintainability_distribution"]["high"] += 1 elif 
maintainability >= 50: - maintainability_result["maintainability_distribution"]["medium"] += 1 + maintainability_result["maintainability_distribution"][ + "medium" + ] += 1 else: maintainability_result["maintainability_distribution"]["low"] += 1 - + # Flag low maintainability functions maintainability_result["low_maintainability_functions"].append({ "name": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None, + "line": func.line if hasattr(func, "line") else None, "maintainability": maintainability, "complexity": complexity, "halstead_volume": halstead_volume, - "loc": loc + "loc": loc, }) - + # Add issue - self.add_issue(Issue( - file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"Low maintainability index: {maintainability:.1f}", - severity=IssueSeverity.WARNING, - category=IssueCategory.COMPLEXITY, - symbol=func_name, - suggestion="Consider refactoring this function to improve maintainability" - )) - + self.add_issue( + Issue( + file=file_path, + line=func.line if hasattr(func, "line") else None, + message=f"Low maintainability index: {maintainability:.1f}", + severity=IssueSeverity.WARNING, + category=IssueCategory.COMPLEXITY, + symbol=func_name, + suggestion="Consider refactoring this function to improve maintainability", + ) + ) + # Calculate average maintainability - maintainability_result["average_maintainability"] = total_maintainability / function_count if function_count > 0 else 0.0 - + maintainability_result["average_maintainability"] = ( + total_maintainability / function_count if function_count > 0 else 0.0 + ) + # Sort low maintainability functions - maintainability_result["low_maintainability_functions"].sort(key=lambda x: x["maintainability"]) - - return maintainability_result \ No newline at end of file + maintainability_result["low_maintainability_functions"].sort( + key=lambda x: x["maintainability"] + ) + + return maintainability_result diff --git a/codegen-on-oss/codegen_on_oss/analyzers/codebase_analyzer.py b/codegen-on-oss/codegen_on_oss/analyzers/codebase_analyzer.py index c555e44fd..ce78f235b 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/codebase_analyzer.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/codebase_analyzer.py @@ -7,38 +7,43 @@ issues, errors, and quality problems. 
""" -import os -import sys +import argparse import json -import time import logging -import argparse +import sys import tempfile -import networkx as nx -from datetime import datetime -from pathlib import Path -from typing import Dict, List, Set, Tuple, Any, Optional, Union, TypeVar, cast from dataclasses import dataclass +from datetime import datetime from enum import Enum +from typing import Any, cast + +import networkx as nx try: - from codegen.sdk.core.codebase import Codebase from codegen.configs.models.codebase import CodebaseConfig from codegen.configs.models.secrets import SecretsConfig - from codegen.sdk.codebase.config import ProjectConfig - from codegen.git.schemas.repo_config import RepoConfig from codegen.git.repo_operator.repo_operator import RepoOperator - from codegen.shared.enums.programming_language import ProgrammingLanguage - from codegen.sdk.codebase.codebase_analysis import get_codebase_summary, get_file_summary + from codegen.git.schemas.repo_config import RepoConfig + from codegen.git.utils.pr_review import CodegenPR + from codegen.sdk.codebase.codebase_analysis import ( + get_codebase_summary, + get_file_summary, + ) + from codegen.sdk.codebase.config import ProjectConfig + from codegen.sdk.core.class_definition import Class + from codegen.sdk.core.codebase import Codebase from codegen.sdk.core.file import SourceFile - from codegen.sdk.enums import EdgeType, SymbolType - from codegen.sdk.core.symbol import Symbol from codegen.sdk.core.function import Function - from codegen.sdk.core.class_definition import Class - from codegen.git.utils.pr_review import CodegenPR + from codegen.sdk.core.symbol import Symbol + from codegen.sdk.enums import EdgeType, SymbolType + from codegen.shared.enums.programming_language import ProgrammingLanguage # Import our custom CodebaseContext - from codegen_on_oss.context_codebase import CodebaseContext, get_node_classes, GLOBAL_FILE_IGNORE_LIST + from codegen_on_oss.context_codebase import ( + GLOBAL_FILE_IGNORE_LIST, + CodebaseContext, + get_node_classes, + ) except ImportError: print("Codegen SDK not found. Please install it first.") sys.exit(1) @@ -46,35 +51,41 @@ # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) + class AnalysisType(str, Enum): """Types of analysis that can be performed.""" + CODEBASE = "codebase" PR = "pr" COMPARISON = "comparison" + class IssueSeverity(str, Enum): """Severity levels for issues.""" + ERROR = "error" WARNING = "warning" INFO = "info" + @dataclass class Issue: """Represents an issue found during analysis.""" + file: str - line: Optional[int] + line: int | None message: str severity: IssueSeverity - symbol: Optional[str] = None - code: Optional[str] = None - suggestion: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: + symbol: str | None = None + code: str | None = None + suggestion: str | None = None + + def to_dict(self) -> dict[str, Any]: return { "file": self.file, "line": self.line, @@ -82,9 +93,10 @@ def to_dict(self) -> Dict[str, Any]: "severity": self.severity, "symbol": self.symbol, "code": self.code, - "suggestion": self.suggestion + "suggestion": self.suggestion, } + class CodebaseAnalyzer: """ Advanced analyzer for codebases and PRs using the Codegen SDK. 
@@ -100,12 +112,12 @@ class CodebaseAnalyzer: def __init__( self, - repo_url: Optional[str] = None, - repo_path: Optional[str] = None, + repo_url: str | None = None, + repo_path: str | None = None, base_branch: str = "main", - pr_number: Optional[int] = None, - language: Optional[str] = None, - file_ignore_list: Optional[List[str]] = None + pr_number: int | None = None, + language: str | None = None, + file_ignore_list: list[str] | None = None, ): """Initialize the CodebaseAnalyzer. @@ -155,7 +167,7 @@ def __init__( codebase=self.base_codebase, base_path=self.repo_path, pr_branch=None, - base_branch=self.base_branch + base_branch=self.base_branch, ) if self.pr_codebase: @@ -163,60 +175,60 @@ def __init__( codebase=self.pr_codebase, base_path=self.repo_path, pr_branch=self.pr_branch, - base_branch=self.base_branch + base_branch=self.base_branch, ) - - def _init_from_url(self, repo_url: str, language: Optional[str] = None): + + def _init_from_url(self, repo_url: str, language: str | None = None): """Initialize base codebase from a repository URL.""" try: # Extract owner and repo name from URL - if repo_url.endswith('.git'): + if repo_url.endswith(".git"): repo_url = repo_url[:-4] - - parts = repo_url.rstrip('/').split('/') + + parts = repo_url.rstrip("/").split("/") repo_name = parts[-1] owner = parts[-2] repo_full_name = f"{owner}/{repo_name}" - + # Create a temporary directory for cloning tmp_dir = tempfile.mkdtemp(prefix="codebase_analyzer_") - + # Configure the codebase config = CodebaseConfig( debug=False, allow_external=True, py_resolve_syspath=True, ) - + secrets = SecretsConfig() - + # Initialize the codebase logger.info(f"Initializing codebase from {repo_url}...") - + prog_lang = None if language: prog_lang = ProgrammingLanguage(language.upper()) - + # Initialize base codebase self.base_codebase = Codebase.from_github( repo_full_name=repo_full_name, tmp_dir=tmp_dir, language=prog_lang, config=config, - secrets=secrets + secrets=secrets, ) - + logger.info(f"Successfully initialized codebase from {repo_url}") - + # If PR number is specified, also initialize PR codebase if self.pr_number: self._init_pr_codebase() - + except Exception as e: - logger.error(f"Error initializing codebase from URL: {e}") + logger.exception(f"Error initializing codebase from URL: {e}") raise - - def _init_from_path(self, repo_path: str, language: Optional[str] = None): + + def _init_from_path(self, repo_path: str, language: str | None = None): """Initialize codebase from a local repository path.""" try: # Configure the codebase @@ -225,104 +237,104 @@ def _init_from_path(self, repo_path: str, language: Optional[str] = None): allow_external=True, py_resolve_syspath=True, ) - + secrets = SecretsConfig() - + # Initialize the codebase logger.info(f"Initializing codebase from {repo_path}...") - + # Set up programming language prog_lang = None if language: prog_lang = ProgrammingLanguage(language.upper()) - + # Create repo config and repo operator repo_config = RepoConfig.from_repo_path(repo_path) repo_config.respect_gitignore = False repo_operator = RepoOperator(repo_config=repo_config, bot_commit=False) - + # Configure project with repo operator and language project_config = ProjectConfig( repo_operator=repo_operator, - programming_language=prog_lang if prog_lang else None + programming_language=prog_lang if prog_lang else None, ) - + # Initialize codebase with proper project configuration self.base_codebase = Codebase( - projects=[project_config], - config=config, - secrets=secrets + projects=[project_config], 
config=config, secrets=secrets ) - + logger.info(f"Successfully initialized codebase from {repo_path}") - + # If PR number is specified, also initialize PR codebase if self.pr_number: self._init_pr_codebase() - + except Exception as e: - logger.error(f"Error initializing codebase from path: {e}") + logger.exception(f"Error initializing codebase from path: {e}") raise - + def _init_pr_data(self, pr_number: int): """Initialize PR-specific data.""" try: logger.info(f"Fetching PR #{pr_number} data...") result = self.base_codebase.get_modified_symbols_in_pr(pr_number) - + # Unpack the result tuple if len(result) >= 3: self.pr_diff, self.commit_shas, self.modified_symbols = result[:3] if len(result) >= 4: self.pr_branch = result[3] - + logger.info(f"Found {len(self.modified_symbols)} modified symbols in PR") - + except Exception as e: - logger.error(f"Error initializing PR data: {e}") + logger.exception(f"Error initializing PR data: {e}") raise - + def _init_pr_codebase(self): """Initialize PR codebase by checking out the PR branch.""" if not self.base_codebase or not self.pr_number: logger.error("Base codebase or PR number not initialized") return - + try: # Get PR data if not already fetched if not self.pr_branch: self._init_pr_data(self.pr_number) - + if not self.pr_branch: logger.error("Failed to get PR branch") return - + # Clone the base codebase self.pr_codebase = self.base_codebase - + # Checkout PR branch logger.info(f"Checking out PR branch: {self.pr_branch}") self.pr_codebase.checkout(self.pr_branch) - + logger.info("Successfully initialized PR codebase") - + except Exception as e: - logger.error(f"Error initializing PR codebase: {e}") + logger.exception(f"Error initializing PR codebase: {e}") raise - - def analyze(self, analysis_type: AnalysisType = AnalysisType.CODEBASE) -> Dict[str, Any]: + + def analyze( + self, analysis_type: AnalysisType = AnalysisType.CODEBASE + ) -> dict[str, Any]: """ Perform a comprehensive analysis of the codebase or PR. 
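+        Besides the analysis-type-specific section, the result always
+        includes run metadata, a codebase summary, the collected issues,
+        and per-severity issue counts.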
- + Args: analysis_type: Type of analysis to perform (codebase, pr, or comparison) - + Returns: Dict containing the analysis results """ if not self.base_codebase: raise ValueError("Codebase not initialized") - + result = { "metadata": { "analysis_time": datetime.now().isoformat(), @@ -332,45 +344,55 @@ def analyze(self, analysis_type: AnalysisType = AnalysisType.CODEBASE) -> Dict[s }, "summary": get_codebase_summary(self.base_codebase), } - + # Reset issues list self.issues = [] - + if analysis_type == AnalysisType.CODEBASE: # Perform static analysis on base codebase logger.info("Performing static analysis on codebase...") - result["static_analysis"] = self._perform_static_analysis(self.base_codebase) - + result["static_analysis"] = self._perform_static_analysis( + self.base_codebase + ) + elif analysis_type == AnalysisType.PR: # Analyze PR changes if not self.pr_number: raise ValueError("PR number not provided") - + logger.info(f"Analyzing PR #{self.pr_number}...") result["pr_analysis"] = self._analyze_pr() - + elif analysis_type == AnalysisType.COMPARISON: # Compare base codebase with PR if not self.pr_codebase: raise ValueError("PR codebase not initialized") - + logger.info("Comparing base codebase with PR...") result["comparison"] = self._compare_codebases() - + # Add issues to the result result["issues"] = [issue.to_dict() for issue in self.issues] result["issue_counts"] = { "total": len(self.issues), "by_severity": { - "error": sum(1 for issue in self.issues if issue.severity == IssueSeverity.ERROR), - "warning": sum(1 for issue in self.issues if issue.severity == IssueSeverity.WARNING), - "info": sum(1 for issue in self.issues if issue.severity == IssueSeverity.INFO), - } + "error": sum( + 1 for issue in self.issues if issue.severity == IssueSeverity.ERROR + ), + "warning": sum( + 1 + for issue in self.issues + if issue.severity == IssueSeverity.WARNING + ), + "info": sum( + 1 for issue in self.issues if issue.severity == IssueSeverity.INFO + ), + }, } - + return result - - def _perform_static_analysis(self, codebase: Codebase) -> Dict[str, Any]: + + def _perform_static_analysis(self, codebase: Codebase) -> dict[str, Any]: """ Perform static analysis on a codebase using the CodebaseContext for deep graph-based analysis. 
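+        Graph-based checks (dependency chains, circular dependencies, and
+        module coupling) are only included when a CodebaseContext is
+        available for the codebase.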
@@ -414,10 +436,14 @@ def _perform_static_analysis(self, codebase: Codebase) -> Dict[str, Any]: # Add graph-based analysis if context is available if context: # Analyze dependency chains - analysis_result["dependency_chains"] = self._analyze_dependency_chains(context) + analysis_result["dependency_chains"] = self._analyze_dependency_chains( + context + ) # Analyze circular dependencies - analysis_result["circular_dependencies"] = self._find_circular_dependencies(context) + analysis_result["circular_dependencies"] = self._find_circular_dependencies( + context + ) # Analyze module coupling analysis_result["module_coupling"] = self._analyze_module_coupling(context) @@ -427,16 +453,13 @@ def _perform_static_analysis(self, codebase: Codebase) -> Dict[str, Any]: return analysis_result - def _analyze_dependency_chains(self, context: CodebaseContext) -> Dict[str, Any]: + def _analyze_dependency_chains(self, context: CodebaseContext) -> dict[str, Any]: """Analyze dependency chains in the codebase.""" - result = { - "long_chains": [], - "critical_paths": [] - } + result = {"long_chains": [], "critical_paths": []} # Find long dependency chains for node in context.nodes: - if not hasattr(node, 'name'): + if not hasattr(node, "name"): continue # Skip non-symbol nodes @@ -452,7 +475,7 @@ def _analyze_dependency_chains(self, context: CodebaseContext) -> Dict[str, Any] # Find paths paths = [] for target in symbol_nodes: - if node != target and hasattr(target, 'name'): + if node != target and hasattr(target, "name"): try: path = nx.shortest_path(subgraph, node, target) if len(path) > 3: # Only track paths with at least 3 edges @@ -465,12 +488,18 @@ def _analyze_dependency_chains(self, context: CodebaseContext) -> Dict[str, Any] if paths and len(paths[0]) > 3: path_info = { "source": node.name, - "targets": [paths[0][-1].name if hasattr(paths[0][-1], 'name') else str(paths[0][-1])], + "targets": [ + paths[0][-1].name + if hasattr(paths[0][-1], "name") + else str(paths[0][-1]) + ], "length": len(paths[0]), - "path": [n.name if hasattr(n, 'name') else str(n) for n in paths[0]] + "path": [ + n.name if hasattr(n, "name") else str(n) for n in paths[0] + ], } result["long_chains"].append(path_info) - except Exception as e: + except Exception: # Skip errors in graph analysis pass @@ -480,12 +509,9 @@ def _analyze_dependency_chains(self, context: CodebaseContext) -> Dict[str, Any] return result - def _find_circular_dependencies(self, context: CodebaseContext) -> Dict[str, Any]: + def _find_circular_dependencies(self, context: CodebaseContext) -> dict[str, Any]: """Find circular dependencies in the codebase.""" - result = { - "circular_imports": [], - "circular_function_calls": [] - } + result = {"circular_imports": [], "circular_function_calls": []} # Find circular dependencies in the context graph try: @@ -494,54 +520,68 @@ def _find_circular_dependencies(self, context: CodebaseContext) -> Dict[str, Any # Filter and categorize cycles for cycle in cycles: # Check if it's an import cycle - if all(hasattr(node, 'symbol_type') and hasattr(node, 'name') for node in cycle): - cycle_type = "unknown" - + if all( + hasattr(node, "symbol_type") and hasattr(node, "name") + for node in cycle + ): # Check if all nodes in the cycle are files if all(isinstance(node, SourceFile) for node in cycle): - cycle_type = "import" result["circular_imports"].append({ - "files": [node.path if hasattr(node, 'path') else str(node) for node in cycle], - "length": len(cycle) + "files": [ + node.path if hasattr(node, "path") else str(node) + for 
node in cycle + ], + "length": len(cycle), }) # Check if all nodes in the cycle are functions elif all(isinstance(node, Function) for node in cycle): - cycle_type = "function_call" result["circular_function_calls"].append({ - "functions": [node.name if hasattr(node, 'name') else str(node) for node in cycle], - "length": len(cycle) + "functions": [ + node.name if hasattr(node, "name") else str(node) + for node in cycle + ], + "length": len(cycle), }) # Add as an issue - if len(cycle) > 0 and hasattr(cycle[0], 'file') and hasattr(cycle[0].file, 'file_path'): - self.issues.append(Issue( - file=cycle[0].file.file_path, - line=cycle[0].line if hasattr(cycle[0], 'line') else None, - message=f"Circular function call dependency detected", - severity=IssueSeverity.ERROR, - symbol=cycle[0].name if hasattr(cycle[0], 'name') else str(cycle[0]), - suggestion="Refactor the code to eliminate circular dependencies" - )) - except Exception as e: + if ( + len(cycle) > 0 + and hasattr(cycle[0], "file") + and hasattr(cycle[0].file, "file_path") + ): + self.issues.append( + Issue( + file=cycle[0].file.file_path, + line=cycle[0].line + if hasattr(cycle[0], "line") + else None, + message="Circular function call dependency detected", + severity=IssueSeverity.ERROR, + symbol=cycle[0].name + if hasattr(cycle[0], "name") + else str(cycle[0]), + suggestion="Refactor the code to eliminate circular dependencies", + ) + ) + except Exception: # Skip errors in cycle detection pass return result - def _analyze_module_coupling(self, context: CodebaseContext) -> Dict[str, Any]: + def _analyze_module_coupling(self, context: CodebaseContext) -> dict[str, Any]: """Analyze module coupling in the codebase.""" - result = { - "high_coupling": [], - "low_cohesion": [] - } + result = {"high_coupling": [], "low_cohesion": []} # Create a mapping of files to their dependencies file_dependencies = {} # Iterate over all files - for file_node in [node for node in context.nodes if isinstance(node, SourceFile)]: - if not hasattr(file_node, 'path'): + for file_node in [ + node for node in context.nodes if isinstance(node, SourceFile) + ]: + if not hasattr(file_node, "path"): continue file_path = str(file_node.path) @@ -549,57 +589,61 @@ def _analyze_module_coupling(self, context: CodebaseContext) -> Dict[str, Any]: # Get all outgoing dependencies dependencies = [] for succ in context.successors(file_node): - if isinstance(succ, SourceFile) and hasattr(succ, 'path'): + if isinstance(succ, SourceFile) and hasattr(succ, "path"): dependencies.append(str(succ.path)) # Get all symbols in the file - file_symbols = [node for node in context.nodes if isinstance(node, Symbol) and - hasattr(node, 'file') and hasattr(node.file, 'path') and - str(node.file.path) == file_path] + file_symbols = [ + node + for node in context.nodes + if isinstance(node, Symbol) + and hasattr(node, "file") + and hasattr(node.file, "path") + and str(node.file.path) == file_path + ] # Calculate coupling metrics file_dependencies[file_path] = { "dependencies": dependencies, "dependency_count": len(dependencies), "symbol_count": len(file_symbols), - "coupling_ratio": len(dependencies) / max(1, len(file_symbols)) + "coupling_ratio": len(dependencies) / max(1, len(file_symbols)), } # Identify files with high coupling (many dependencies) high_coupling_files = sorted( file_dependencies.items(), key=lambda x: x[1]["dependency_count"], - reverse=True + reverse=True, )[:10] result["high_coupling"] = [ { "file": file_path, "dependency_count": data["dependency_count"], - "dependencies": 
data["dependencies"][:5] # Limit to first 5 for brevity + "dependencies": data["dependencies"][ + :5 + ], # Limit to first 5 for brevity } for file_path, data in high_coupling_files - if data["dependency_count"] > 5 # Only include if it has more than 5 dependencies + if data["dependency_count"] + > 5 # Only include if it has more than 5 dependencies ] return result - def _analyze_call_hierarchy(self, context: CodebaseContext) -> Dict[str, Any]: + def _analyze_call_hierarchy(self, context: CodebaseContext) -> dict[str, Any]: """Analyze function call hierarchy in the codebase.""" - result = { - "entry_points": [], - "leaf_functions": [], - "deep_call_chains": [] - } + result = {"entry_points": [], "leaf_functions": [], "deep_call_chains": []} # Find potential entry points (functions not called by others) entry_points = [] for node in context.nodes: - if isinstance(node, Function) and hasattr(node, 'name'): + if isinstance(node, Function) and hasattr(node, "name"): # Check if this function has no incoming CALLS edges has_callers = False - for pred, _, data in context.in_edges(node, data=True): - if 'type' in data and data['type'] == EdgeType.CALLS: + for _pred, _, data in context.in_edges(node, data=True): + if "type" in data and data["type"] == EdgeType.CALLS: has_callers = True break @@ -609,11 +653,11 @@ def _analyze_call_hierarchy(self, context: CodebaseContext) -> Dict[str, Any]: # Find leaf functions (those that don't call other functions) leaf_functions = [] for node in context.nodes: - if isinstance(node, Function) and hasattr(node, 'name'): + if isinstance(node, Function) and hasattr(node, "name"): # Check if this function has no outgoing CALLS edges has_callees = False - for _, succ, data in context.out_edges(node, data=True): - if 'type' in data and data['type'] == EdgeType.CALLS: + for _, _succ, data in context.out_edges(node, data=True): + if "type" in data and data["type"] == EdgeType.CALLS: has_callees = True break @@ -624,7 +668,9 @@ def _analyze_call_hierarchy(self, context: CodebaseContext) -> Dict[str, Any]: result["entry_points"] = [ { "name": func.name, - "file": func.file.file_path if hasattr(func, 'file') and hasattr(func.file, 'file_path') else "unknown" + "file": func.file.file_path + if hasattr(func, "file") and hasattr(func.file, "file_path") + else "unknown", } for func in entry_points[:20] # Limit to 20 for brevity ] @@ -633,7 +679,9 @@ def _analyze_call_hierarchy(self, context: CodebaseContext) -> Dict[str, Any]: result["leaf_functions"] = [ { "name": func.name, - "file": func.file.file_path if hasattr(func, 'file') and hasattr(func.file, 'file_path') else "unknown" + "file": func.file.file_path + if hasattr(func, "file") and hasattr(func.file, "file_path") + else "unknown", } for func in leaf_functions[:20] # Limit to 20 for brevity ] @@ -659,10 +707,12 @@ def _analyze_call_hierarchy(self, context: CodebaseContext) -> Dict[str, Any]: call_chain = { "entry_point": entry_point.name, "length": len(longest_path), - "calls": [func.name for func in longest_path if hasattr(func, 'name')] + "calls": [ + func.name for func in longest_path if hasattr(func, "name") + ], } result["deep_call_chains"].append(call_chain) - except Exception as e: + except Exception: # Skip errors in path finding pass @@ -671,19 +721,19 @@ def _analyze_call_hierarchy(self, context: CodebaseContext) -> Dict[str, Any]: result["deep_call_chains"] = result["deep_call_chains"][:10] return result - - def _analyze_pr(self) -> Dict[str, Any]: + + def _analyze_pr(self) -> dict[str, Any]: """Analyze a PR 
and find issues.""" if not self.pr_codebase or not self.pr_diff or not self.commit_shas: raise ValueError("PR data not initialized") - + pr_analysis = {} - + # Get modified symbols and files modified_files = set(self.commit_shas.keys()) pr_analysis["modified_files_count"] = len(modified_files) pr_analysis["modified_symbols_count"] = len(self.modified_symbols) - + # Analyze modified files file_issues = [] for file_path in modified_files: @@ -691,24 +741,28 @@ def _analyze_pr(self) -> Dict[str, Any]: if file: # Check file issues self._check_file_issues(file) - + # Add file summary file_issues.append({ "file": file_path, - "issues": [issue.to_dict() for issue in self.issues if issue.file == file_path] + "issues": [ + issue.to_dict() + for issue in self.issues + if issue.file == file_path + ], }) - + pr_analysis["file_issues"] = file_issues - + # Perform targeted static analysis on modified symbols new_func_count = 0 modified_func_count = 0 - + for symbol_name in self.modified_symbols: symbol = self.pr_codebase.get_symbol(symbol_name) if not symbol: continue - + # Check if function is new or modified if symbol.symbol_type == SymbolType.Function: # Try to find in base codebase @@ -720,17 +774,17 @@ def _analyze_pr(self) -> Dict[str, Any]: modified_func_count += 1 except: new_func_count += 1 - + # Check function for issues func = cast(Function, symbol) self._check_function_for_issues(func) - + pr_analysis["new_functions"] = new_func_count pr_analysis["modified_functions"] = modified_func_count - + return pr_analysis - - def _compare_codebases(self) -> Dict[str, Any]: + + def _compare_codebases(self) -> dict[str, Any]: """ Compare base codebase with PR codebase using advanced CodebaseContext. @@ -741,13 +795,15 @@ def _compare_codebases(self) -> Dict[str, Any]: raise ValueError("Both base and PR codebases must be initialized") if not self.base_context or not self.pr_context: - raise ValueError("Both base and PR CodebaseContext objects must be initialized") + raise ValueError( + "Both base and PR CodebaseContext objects must be initialized" + ) comparison = { "graph_analysis": {}, "structure_changes": {}, "dependency_changes": {}, - "api_changes": {} + "api_changes": {}, } # Compare graph structures using CodebaseContext @@ -760,10 +816,13 @@ def _compare_codebases(self) -> Dict[str, Any]: pr_only_nodes = [] for base_node in base_nodes: - if hasattr(base_node, 'name'): + if hasattr(base_node, "name"): node_name = base_node.name # Look for matching node in PR - pr_node = next((n for n in pr_nodes if hasattr(n, 'name') and n.name == node_name), None) + pr_node = next( + (n for n in pr_nodes if hasattr(n, "name") and n.name == node_name), + None, + ) if pr_node: common_nodes.append((base_node, pr_node)) @@ -772,22 +831,24 @@ def _compare_codebases(self) -> Dict[str, Any]: # Find PR-only nodes for pr_node in pr_nodes: - if hasattr(pr_node, 'name'): + if hasattr(pr_node, "name"): node_name = pr_node.name # Check if it already exists in base - if not any(hasattr(n, 'name') and n.name == node_name for n in base_nodes): + if not any( + hasattr(n, "name") and n.name == node_name for n in base_nodes + ): pr_only_nodes.append(pr_node) # Add graph analysis results comparison["graph_analysis"] = { "common_node_count": len(common_nodes), "base_only_node_count": len(base_only_nodes), - "pr_only_node_count": len(pr_only_nodes) + "pr_only_node_count": len(pr_only_nodes), } # Compare dependencies using graph edges - base_edges = list(self.base_context.edges(data=True)) - pr_edges = 
list(self.pr_context.edges(data=True)) + list(self.base_context.edges(data=True)) + list(self.pr_context.edges(data=True)) # Analyze dependency changes removed_dependencies = [] @@ -813,9 +874,14 @@ def _compare_codebases(self) -> Dict[str, Any]: } # For functions, compare parameters - if (base_symbol and hasattr(base_symbol, 'symbol_type') and base_symbol.symbol_type == SymbolType.Function and - pr_symbol and hasattr(pr_symbol, 'symbol_type') and pr_symbol.symbol_type == SymbolType.Function): - + if ( + base_symbol + and hasattr(base_symbol, "symbol_type") + and base_symbol.symbol_type == SymbolType.Function + and pr_symbol + and hasattr(pr_symbol, "symbol_type") + and pr_symbol.symbol_type == SymbolType.Function + ): base_func = cast(Function, base_symbol) pr_func = cast(Function, pr_symbol) @@ -825,115 +891,185 @@ def _compare_codebases(self) -> Dict[str, Any]: # Analyze dependency changes for this function for dep in base_dependencies: - if hasattr(dep, 'name') and not any(hasattr(d, 'name') and d.name == dep.name for d in pr_dependencies): + if hasattr(dep, "name") and not any( + hasattr(d, "name") and d.name == dep.name + for d in pr_dependencies + ): removed_dependencies.append((base_func.name, dep.name)) for dep in pr_dependencies: - if hasattr(dep, 'name') and not any(hasattr(d, 'name') and d.name == dep.name for d in base_dependencies): + if hasattr(dep, "name") and not any( + hasattr(d, "name") and d.name == dep.name + for d in base_dependencies + ): added_dependencies.append((pr_func.name, dep.name)) - + # Compare parameter counts base_params = list(base_func.parameters) pr_params = list(pr_func.parameters) - + param_changes = [] removed_params = [] added_params = [] - + # Find removed parameters for base_param in base_params: - if not any(pr_param.name == base_param.name for pr_param in pr_params if hasattr(pr_param, 'name')): - removed_params.append(base_param.name if hasattr(base_param, 'name') else str(base_param)) - + if not any( + pr_param.name == base_param.name + for pr_param in pr_params + if hasattr(pr_param, "name") + ): + removed_params.append( + base_param.name + if hasattr(base_param, "name") + else str(base_param) + ) + # Find added parameters for pr_param in pr_params: - if not any(base_param.name == pr_param.name for base_param in base_params if hasattr(base_param, 'name')): - added_params.append(pr_param.name if hasattr(pr_param, 'name') else str(pr_param)) - + if not any( + base_param.name == pr_param.name + for base_param in base_params + if hasattr(base_param, "name") + ): + added_params.append( + pr_param.name + if hasattr(pr_param, "name") + else str(pr_param) + ) + symbol_comparison["parameter_changes"] = { "removed": removed_params, - "added": added_params + "added": added_params, } - + # Check for parameter type changes for base_param in base_params: for pr_param in pr_params: - if (hasattr(base_param, 'name') and hasattr(pr_param, 'name') and - base_param.name == pr_param.name): - - base_type = str(base_param.type) if hasattr(base_param, 'type') and base_param.type else None - pr_type = str(pr_param.type) if hasattr(pr_param, 'type') and pr_param.type else None - + if ( + hasattr(base_param, "name") + and hasattr(pr_param, "name") + and base_param.name == pr_param.name + ): + base_type = ( + str(base_param.type) + if hasattr(base_param, "type") and base_param.type + else None + ) + pr_type = ( + str(pr_param.type) + if hasattr(pr_param, "type") and pr_param.type + else None + ) + if base_type != pr_type: param_changes.append({ "param": 
base_param.name, "old_type": base_type, - "new_type": pr_type + "new_type": pr_type, }) - + if param_changes: symbol_comparison["type_changes"] = param_changes - + # Check if return type changed - base_return_type = str(base_func.return_type) if hasattr(base_func, 'return_type') and base_func.return_type else None - pr_return_type = str(pr_func.return_type) if hasattr(pr_func, 'return_type') and pr_func.return_type else None - + base_return_type = ( + str(base_func.return_type) + if hasattr(base_func, "return_type") and base_func.return_type + else None + ) + pr_return_type = ( + str(pr_func.return_type) + if hasattr(pr_func, "return_type") and pr_func.return_type + else None + ) + if base_return_type != pr_return_type: symbol_comparison["return_type_change"] = { "old": base_return_type, - "new": pr_return_type + "new": pr_return_type, } - + # Check call site compatibility - if hasattr(base_func, 'call_sites') and hasattr(pr_func, 'call_sites'): + if hasattr(base_func, "call_sites") and hasattr( + pr_func, "call_sites" + ): base_call_sites = list(base_func.call_sites) call_site_issues = [] - + # For each call site in base, check if it's still compatible with PR function for call_site in base_call_sites: - if len(removed_params) > 0 and not all(param.has_default for param in base_params if hasattr(param, 'name') and param.name in removed_params): + if len(removed_params) > 0 and not all( + param.has_default + for param in base_params + if hasattr(param, "name") + and param.name in removed_params + ): # Required parameter was removed - file_path = call_site.file.file_path if hasattr(call_site, 'file') and hasattr(call_site.file, 'file_path') else "unknown" - line = call_site.line if hasattr(call_site, 'line') else None - + file_path = ( + call_site.file.file_path + if hasattr(call_site, "file") + and hasattr(call_site.file, "file_path") + else "unknown" + ) + line = ( + call_site.line + if hasattr(call_site, "line") + else None + ) + call_site_issues.append({ "file": file_path, "line": line, - "issue": "Required parameter was removed, call site may be broken" + "issue": "Required parameter was removed, call site may be broken", }) - + # Add issue - self.issues.append(Issue( - file=file_path, - line=line, - message=f"Call to {symbol_name} may be broken due to signature change", - severity=IssueSeverity.ERROR, - symbol=symbol_name, - suggestion="Update call site to match new function signature" - )) - + self.issues.append( + Issue( + file=file_path, + line=line, + message=f"Call to {symbol_name} may be broken due to signature change", + severity=IssueSeverity.ERROR, + symbol=symbol_name, + suggestion="Update call site to match new function signature", + ) + ) + if call_site_issues: symbol_comparison["call_site_issues"] = call_site_issues - + detailed_comparison.append(symbol_comparison) - + comparison["symbol_comparison"] = detailed_comparison - + # Compare overall codebase stats base_stats = { "files": len(list(self.base_codebase.files)), - "functions": len(list(self.base_codebase.functions)) if hasattr(self.base_codebase, 'functions') else 0, - "classes": len(list(self.base_codebase.classes)) if hasattr(self.base_codebase, 'classes') else 0, - "imports": len(list(self.base_codebase.imports)) if hasattr(self.base_codebase, 'imports') else 0, + "functions": len(list(self.base_codebase.functions)) + if hasattr(self.base_codebase, "functions") + else 0, + "classes": len(list(self.base_codebase.classes)) + if hasattr(self.base_codebase, "classes") + else 0, + "imports": 
len(list(self.base_codebase.imports)) + if hasattr(self.base_codebase, "imports") + else 0, } - + pr_stats = { "files": len(list(self.pr_codebase.files)), - "functions": len(list(self.pr_codebase.functions)) if hasattr(self.pr_codebase, 'functions') else 0, - "classes": len(list(self.pr_codebase.classes)) if hasattr(self.pr_codebase, 'classes') else 0, - "imports": len(list(self.pr_codebase.imports)) if hasattr(self.pr_codebase, 'imports') else 0, + "functions": len(list(self.pr_codebase.functions)) + if hasattr(self.pr_codebase, "functions") + else 0, + "classes": len(list(self.pr_codebase.classes)) + if hasattr(self.pr_codebase, "classes") + else 0, + "imports": len(list(self.pr_codebase.imports)) + if hasattr(self.pr_codebase, "imports") + else 0, } - + comparison["stats_comparison"] = { "base": base_stats, "pr": pr_stats, @@ -942,412 +1078,503 @@ def _compare_codebases(self) -> Dict[str, Any]: "functions": pr_stats["functions"] - base_stats["functions"], "classes": pr_stats["classes"] - base_stats["classes"], "imports": pr_stats["imports"] - base_stats["imports"], - } + }, } - + return comparison - - def _find_dead_code(self, codebase: Codebase) -> Dict[str, Any]: + + def _find_dead_code(self, codebase: Codebase) -> dict[str, Any]: """Find unused code (dead code) in the codebase.""" dead_code = { "unused_functions": [], "unused_classes": [], "unused_variables": [], - "unused_imports": [] + "unused_imports": [], } - + # Find unused functions (no call sites) - if hasattr(codebase, 'functions'): + if hasattr(codebase, "functions"): for func in codebase.functions: - if not hasattr(func, 'call_sites'): + if not hasattr(func, "call_sites"): continue - + if len(func.call_sites) == 0: # Skip magic methods and main functions - if (hasattr(func, 'is_magic') and func.is_magic) or (hasattr(func, 'name') and func.name in ['main', '__main__']): + if (hasattr(func, "is_magic") and func.is_magic) or ( + hasattr(func, "name") and func.name in ["main", "__main__"] + ): continue - + # Get file and name safely - file_path = func.file.file_path if hasattr(func, 'file') and hasattr(func.file, 'file_path') else "unknown" - func_name = func.name if hasattr(func, 'name') else str(func) - + file_path = ( + func.file.file_path + if hasattr(func, "file") and hasattr(func.file, "file_path") + else "unknown" + ) + func_name = func.name if hasattr(func, "name") else str(func) + # Add to dead code list and issues dead_code["unused_functions"].append({ "name": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None + "line": func.line if hasattr(func, "line") else None, }) - - self.issues.append(Issue( - file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"Unused function: {func_name}", - severity=IssueSeverity.WARNING, - symbol=func_name, - suggestion="Consider removing or using this function" - )) - + + self.issues.append( + Issue( + file=file_path, + line=func.line if hasattr(func, "line") else None, + message=f"Unused function: {func_name}", + severity=IssueSeverity.WARNING, + symbol=func_name, + suggestion="Consider removing or using this function", + ) + ) + # Find unused classes (no symbol usages) - if hasattr(codebase, 'classes'): + if hasattr(codebase, "classes"): for cls in codebase.classes: - if not hasattr(cls, 'symbol_usages'): + if not hasattr(cls, "symbol_usages"): continue - + if len(cls.symbol_usages) == 0: # Get file and name safely - file_path = cls.file.file_path if hasattr(cls, 'file') and hasattr(cls.file, 'file_path') else "unknown" - 
cls_name = cls.name if hasattr(cls, 'name') else str(cls) - + file_path = ( + cls.file.file_path + if hasattr(cls, "file") and hasattr(cls.file, "file_path") + else "unknown" + ) + cls_name = cls.name if hasattr(cls, "name") else str(cls) + # Add to dead code list and issues dead_code["unused_classes"].append({ "name": cls_name, "file": file_path, - "line": cls.line if hasattr(cls, 'line') else None + "line": cls.line if hasattr(cls, "line") else None, }) - - self.issues.append(Issue( - file=file_path, - line=cls.line if hasattr(cls, 'line') else None, - message=f"Unused class: {cls_name}", - severity=IssueSeverity.WARNING, - symbol=cls_name, - suggestion="Consider removing or using this class" - )) - + + self.issues.append( + Issue( + file=file_path, + line=cls.line if hasattr(cls, "line") else None, + message=f"Unused class: {cls_name}", + severity=IssueSeverity.WARNING, + symbol=cls_name, + suggestion="Consider removing or using this class", + ) + ) + # Find unused variables - if hasattr(codebase, 'global_vars'): + if hasattr(codebase, "global_vars"): for var in codebase.global_vars: - if not hasattr(var, 'symbol_usages'): + if not hasattr(var, "symbol_usages"): continue - + if len(var.symbol_usages) == 0: # Get file and name safely - file_path = var.file.file_path if hasattr(var, 'file') and hasattr(var.file, 'file_path') else "unknown" - var_name = var.name if hasattr(var, 'name') else str(var) - + file_path = ( + var.file.file_path + if hasattr(var, "file") and hasattr(var.file, "file_path") + else "unknown" + ) + var_name = var.name if hasattr(var, "name") else str(var) + # Add to dead code list and issues dead_code["unused_variables"].append({ "name": var_name, "file": file_path, - "line": var.line if hasattr(var, 'line') else None + "line": var.line if hasattr(var, "line") else None, }) - - self.issues.append(Issue( - file=file_path, - line=var.line if hasattr(var, 'line') else None, - message=f"Unused variable: {var_name}", - severity=IssueSeverity.INFO, - symbol=var_name, - suggestion="Consider removing this unused variable" - )) - + + self.issues.append( + Issue( + file=file_path, + line=var.line if hasattr(var, "line") else None, + message=f"Unused variable: {var_name}", + severity=IssueSeverity.INFO, + symbol=var_name, + suggestion="Consider removing this unused variable", + ) + ) + # Find unused imports for file in codebase.files: - if hasattr(file, 'is_binary') and file.is_binary: + if hasattr(file, "is_binary") and file.is_binary: continue - - if not hasattr(file, 'imports'): + + if not hasattr(file, "imports"): continue - - file_path = file.file_path if hasattr(file, 'file_path') else str(file) - + + file_path = file.file_path if hasattr(file, "file_path") else str(file) + for imp in file.imports: - if not hasattr(imp, 'usages'): + if not hasattr(imp, "usages"): continue - + if len(imp.usages) == 0: # Get import source safely - import_source = imp.source if hasattr(imp, 'source') else str(imp) - + import_source = imp.source if hasattr(imp, "source") else str(imp) + # Add to dead code list and issues dead_code["unused_imports"].append({ "import": import_source, "file": file_path, - "line": imp.line if hasattr(imp, 'line') else None + "line": imp.line if hasattr(imp, "line") else None, }) - - self.issues.append(Issue( - file=file_path, - line=imp.line if hasattr(imp, 'line') else None, - message=f"Unused import: {import_source}", - severity=IssueSeverity.INFO, - code=import_source, - suggestion="Remove this unused import" - )) - + + self.issues.append( + Issue( + 
file=file_path, + line=imp.line if hasattr(imp, "line") else None, + message=f"Unused import: {import_source}", + severity=IssueSeverity.INFO, + code=import_source, + suggestion="Remove this unused import", + ) + ) + # Add total counts dead_code["counts"] = { "unused_functions": len(dead_code["unused_functions"]), "unused_classes": len(dead_code["unused_classes"]), "unused_variables": len(dead_code["unused_variables"]), "unused_imports": len(dead_code["unused_imports"]), - "total": len(dead_code["unused_functions"]) + len(dead_code["unused_classes"]) + - len(dead_code["unused_variables"]) + len(dead_code["unused_imports"]), + "total": len(dead_code["unused_functions"]) + + len(dead_code["unused_classes"]) + + len(dead_code["unused_variables"]) + + len(dead_code["unused_imports"]), } - + return dead_code - - def _check_function_parameters(self, codebase: Codebase) -> Dict[str, Any]: + + def _check_function_parameters(self, codebase: Codebase) -> dict[str, Any]: """Check function parameters for issues.""" parameter_issues = { "missing_types": [], "inconsistent_types": [], - "unused_parameters": [] + "unused_parameters": [], } - - if not hasattr(codebase, 'functions'): + + if not hasattr(codebase, "functions"): return parameter_issues - + for func in codebase.functions: - if not hasattr(func, 'parameters'): + if not hasattr(func, "parameters"): continue - - file_path = func.file.file_path if hasattr(func, 'file') and hasattr(func.file, 'file_path') else "unknown" - func_name = func.name if hasattr(func, 'name') else str(func) - + + file_path = ( + func.file.file_path + if hasattr(func, "file") and hasattr(func.file, "file_path") + else "unknown" + ) + func_name = func.name if hasattr(func, "name") else str(func) + # Check for missing type annotations missing_types = [] for param in func.parameters: - if not hasattr(param, 'name'): + if not hasattr(param, "name"): continue - - if not hasattr(param, 'type') or not param.type: + + if not hasattr(param, "type") or not param.type: missing_types.append(param.name) - + if missing_types: parameter_issues["missing_types"].append({ "function": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None, - "parameters": missing_types + "line": func.line if hasattr(func, "line") else None, + "parameters": missing_types, }) - - self.issues.append(Issue( - file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"Function {func_name} has parameters without type annotations: {', '.join(missing_types)}", - severity=IssueSeverity.WARNING, - symbol=func_name, - suggestion="Add type annotations to all parameters" - )) - + + self.issues.append( + Issue( + file=file_path, + line=func.line if hasattr(func, "line") else None, + message=f"Function {func_name} has parameters without type annotations: {', '.join(missing_types)}", + severity=IssueSeverity.WARNING, + symbol=func_name, + suggestion="Add type annotations to all parameters", + ) + ) + # Check for unused parameters - if hasattr(func, 'source'): + if hasattr(func, "source"): # This is a simple check that looks for parameter names in the function body # A more sophisticated check would analyze the AST unused_params = [] for param in func.parameters: - if not hasattr(param, 'name'): + if not hasattr(param, "name"): continue - + # Skip self/cls parameter in methods - if param.name in ['self', 'cls'] and hasattr(func, 'parent') and func.parent: + if ( + param.name in ["self", "cls"] + and hasattr(func, "parent") + and func.parent + ): continue - + # Check if 
parameter name appears in function body # This is a simple heuristic and may produce false positives - param_regex = r'\b' + re.escape(param.name) + r'\b' - body_lines = func.source.split('\n')[1:] if func.source.count('\n') > 0 else [] - body_text = '\n'.join(body_lines) - + param_regex = r"\b" + re.escape(param.name) + r"\b" + body_lines = ( + func.source.split("\n")[1:] + if func.source.count("\n") > 0 + else [] + ) + body_text = "\n".join(body_lines) + if not re.search(param_regex, body_text): unused_params.append(param.name) - + if unused_params: parameter_issues["unused_parameters"].append({ "function": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None, - "parameters": unused_params + "line": func.line if hasattr(func, "line") else None, + "parameters": unused_params, }) - - self.issues.append(Issue( - file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"Function {func_name} has potentially unused parameters: {', '.join(unused_params)}", - severity=IssueSeverity.INFO, - symbol=func_name, - suggestion="Check if these parameters are actually used" - )) - + + self.issues.append( + Issue( + file=file_path, + line=func.line if hasattr(func, "line") else None, + message=f"Function {func_name} has potentially unused parameters: {', '.join(unused_params)}", + severity=IssueSeverity.INFO, + symbol=func_name, + suggestion="Check if these parameters are actually used", + ) + ) + # Check for consistent parameter types across overloaded functions - if hasattr(codebase, 'functions'): + if hasattr(codebase, "functions"): # Find functions with the same name - overloads = [f for f in codebase.functions if hasattr(f, 'name') and f.name == func_name and f != func] - + overloads = [ + f + for f in codebase.functions + if hasattr(f, "name") and f.name == func_name and f != func + ] + if overloads: for overload in overloads: # Check if the same parameter name has different types - if not hasattr(overload, 'parameters'): + if not hasattr(overload, "parameters"): continue - + inconsistent_types = [] for param in func.parameters: - if not hasattr(param, 'name') or not hasattr(param, 'type'): + if not hasattr(param, "name") or not hasattr(param, "type"): continue - + # Find matching parameter in overload - matching_params = [p for p in overload.parameters if hasattr(p, 'name') and p.name == param.name] - + matching_params = [ + p + for p in overload.parameters + if hasattr(p, "name") and p.name == param.name + ] + for matching_param in matching_params: - if (hasattr(matching_param, 'type') and matching_param.type and - str(matching_param.type) != str(param.type)): - + if ( + hasattr(matching_param, "type") + and matching_param.type + and str(matching_param.type) != str(param.type) + ): inconsistent_types.append({ "parameter": param.name, "type1": str(param.type), "type2": str(matching_param.type), "function1": f"{func_name} at {file_path}:{func.line if hasattr(func, 'line') else '?'}", - "function2": f"{overload.name} at {overload.file.file_path if hasattr(overload, 'file') and hasattr(overload.file, 'file_path') else 'unknown'}:{overload.line if hasattr(overload, 'line') else '?'}" + "function2": f"{overload.name} at {overload.file.file_path if hasattr(overload, 'file') and hasattr(overload.file, 'file_path') else 'unknown'}:{overload.line if hasattr(overload, 'line') else '?'}", }) - + if inconsistent_types: - parameter_issues["inconsistent_types"].extend(inconsistent_types) - + parameter_issues["inconsistent_types"].extend( + inconsistent_types 
+ ) + for issue in inconsistent_types: - self.issues.append(Issue( - file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"Inconsistent parameter types for {issue['parameter']}: {issue['type1']} vs {issue['type2']}", - severity=IssueSeverity.ERROR, - symbol=func_name, - suggestion="Use consistent parameter types across function overloads" - )) - + self.issues.append( + Issue( + file=file_path, + line=func.line + if hasattr(func, "line") + else None, + message=f"Inconsistent parameter types for {issue['parameter']}: {issue['type1']} vs {issue['type2']}", + severity=IssueSeverity.ERROR, + symbol=func_name, + suggestion="Use consistent parameter types across function overloads", + ) + ) + # Add total counts parameter_issues["counts"] = { "missing_types": len(parameter_issues["missing_types"]), "inconsistent_types": len(parameter_issues["inconsistent_types"]), "unused_parameters": len(parameter_issues["unused_parameters"]), - "total": len(parameter_issues["missing_types"]) + len(parameter_issues["inconsistent_types"]) + - len(parameter_issues["unused_parameters"]), + "total": len(parameter_issues["missing_types"]) + + len(parameter_issues["inconsistent_types"]) + + len(parameter_issues["unused_parameters"]), } - + return parameter_issues - - def _check_error_handling(self, codebase: Codebase) -> Dict[str, Any]: + + def _check_error_handling(self, codebase: Codebase) -> dict[str, Any]: """Check for error handling issues.""" error_handling = { "bare_excepts": [], "pass_in_except": [], - "errors_not_raised": [] + "errors_not_raised": [], } - - if not hasattr(codebase, 'functions'): + + if not hasattr(codebase, "functions"): return error_handling - + for func in codebase.functions: - if not hasattr(func, 'source'): + if not hasattr(func, "source"): continue - - file_path = func.file.file_path if hasattr(func, 'file') and hasattr(func.file, 'file_path') else "unknown" - func_name = func.name if hasattr(func, 'name') else str(func) - + + file_path = ( + func.file.file_path + if hasattr(func, "file") and hasattr(func.file, "file_path") + else "unknown" + ) + func_name = func.name if hasattr(func, "name") else str(func) + # Check for bare except clauses - if re.search(r'except\s*:', func.source): + if re.search(r"except\s*:", func.source): error_handling["bare_excepts"].append({ "function": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None, + "line": func.line if hasattr(func, "line") else None, }) - - self.issues.append(Issue( - file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"Function {func_name} uses bare 'except:' clause", - severity=IssueSeverity.WARNING, - symbol=func_name, - suggestion="Specify exception types to catch" - )) - + + self.issues.append( + Issue( + file=file_path, + line=func.line if hasattr(func, "line") else None, + message=f"Function {func_name} uses bare 'except:' clause", + severity=IssueSeverity.WARNING, + symbol=func_name, + suggestion="Specify exception types to catch", + ) + ) + # Check for 'pass' in except blocks - if re.search(r'except[^:]*:.*\bpass\b', func.source, re.DOTALL): + if re.search(r"except[^:]*:.*\bpass\b", func.source, re.DOTALL): error_handling["pass_in_except"].append({ "function": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None, + "line": func.line if hasattr(func, "line") else None, }) - - self.issues.append(Issue( - file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"Function {func_name} silently 
ignores exceptions with 'pass'", - severity=IssueSeverity.WARNING, - symbol=func_name, - suggestion="Add proper error handling or logging" - )) - + + self.issues.append( + Issue( + file=file_path, + line=func.line if hasattr(func, "line") else None, + message=f"Function {func_name} silently ignores exceptions with 'pass'", + severity=IssueSeverity.WARNING, + symbol=func_name, + suggestion="Add proper error handling or logging", + ) + ) + # Check for error classes that aren't raised - if hasattr(func, 'symbol_type') and func.symbol_type == SymbolType.Class: + if hasattr(func, "symbol_type") and func.symbol_type == SymbolType.Class: # Check if class name contains 'Error' or 'Exception' - if hasattr(func, 'name') and ('Error' in func.name or 'Exception' in func.name): + if hasattr(func, "name") and ( + "Error" in func.name or "Exception" in func.name + ): cls = cast(Class, func) - + # Check if class extends Exception is_exception = False - if hasattr(cls, 'superclasses'): - superclass_names = [sc.name for sc in cls.superclasses if hasattr(sc, 'name')] - if any(name in ['Exception', 'BaseException'] for name in superclass_names): + if hasattr(cls, "superclasses"): + superclass_names = [ + sc.name for sc in cls.superclasses if hasattr(sc, "name") + ] + if any( + name in ["Exception", "BaseException"] + for name in superclass_names + ): is_exception = True - - if is_exception and hasattr(cls, 'symbol_usages') and not any('raise' in str(usage) for usage in cls.symbol_usages): + + if ( + is_exception + and hasattr(cls, "symbol_usages") + and not any( + "raise" in str(usage) for usage in cls.symbol_usages + ) + ): error_handling["errors_not_raised"].append({ "class": cls.name, "file": file_path, - "line": cls.line if hasattr(cls, 'line') else None, + "line": cls.line if hasattr(cls, "line") else None, }) - - self.issues.append(Issue( - file=file_path, - line=cls.line if hasattr(cls, 'line') else None, - message=f"Exception class {cls.name} is defined but never raised", - severity=IssueSeverity.INFO, - symbol=cls.name, - suggestion="Either use this exception or remove it" - )) - + + self.issues.append( + Issue( + file=file_path, + line=cls.line if hasattr(cls, "line") else None, + message=f"Exception class {cls.name} is defined but never raised", + severity=IssueSeverity.INFO, + symbol=cls.name, + suggestion="Either use this exception or remove it", + ) + ) + # Add total counts error_handling["counts"] = { "bare_excepts": len(error_handling["bare_excepts"]), "pass_in_except": len(error_handling["pass_in_except"]), "errors_not_raised": len(error_handling["errors_not_raised"]), - "total": len(error_handling["bare_excepts"]) + len(error_handling["pass_in_except"]) + - len(error_handling["errors_not_raised"]), + "total": len(error_handling["bare_excepts"]) + + len(error_handling["pass_in_except"]) + + len(error_handling["errors_not_raised"]), } - + return error_handling - - def _check_call_sites(self, codebase: Codebase) -> Dict[str, Any]: + + def _check_call_sites(self, codebase: Codebase) -> dict[str, Any]: """Check for issues with function call sites.""" - call_site_issues = { - "wrong_parameter_count": [], - "wrong_return_type_usage": [] - } - - if not hasattr(codebase, 'functions'): + call_site_issues = {"wrong_parameter_count": [], "wrong_return_type_usage": []} + + if not hasattr(codebase, "functions"): return call_site_issues - + for func in codebase.functions: - if not hasattr(func, 'call_sites'): + if not hasattr(func, "call_sites"): continue - - file_path = func.file.file_path if 
hasattr(func, 'file') and hasattr(func.file, 'file_path') else "unknown" - func_name = func.name if hasattr(func, 'name') else str(func) - + + ( + func.file.file_path + if hasattr(func, "file") and hasattr(func.file, "file_path") + else "unknown" + ) + func_name = func.name if hasattr(func, "name") else str(func) + # Get required parameter count (excluding those with defaults) required_count = 0 - if hasattr(func, 'parameters'): - required_count = sum(1 for p in func.parameters if not hasattr(p, 'has_default') or not p.has_default) - + if hasattr(func, "parameters"): + required_count = sum( + 1 + for p in func.parameters + if not hasattr(p, "has_default") or not p.has_default + ) + # Check each call site for call_site in func.call_sites: - if not hasattr(call_site, 'args'): + if not hasattr(call_site, "args"): continue - + # Get call site file info - call_file = call_site.file.file_path if hasattr(call_site, 'file') and hasattr(call_site.file, 'file_path') else "unknown" - call_line = call_site.line if hasattr(call_site, 'line') else None - + call_file = ( + call_site.file.file_path + if hasattr(call_site, "file") + and hasattr(call_site.file, "file_path") + else "unknown" + ) + call_line = call_site.line if hasattr(call_site, "line") else None + # Check parameter count arg_count = len(call_site.args) if arg_count < required_count: @@ -1356,422 +1583,515 @@ def _check_call_sites(self, codebase: Codebase) -> Dict[str, Any]: "caller_file": call_file, "caller_line": call_line, "required_count": required_count, - "provided_count": arg_count + "provided_count": arg_count, }) - - self.issues.append(Issue( - file=call_file, - line=call_line, - message=f"Call to {func_name} has too few arguments ({arg_count} provided, {required_count} required)", - severity=IssueSeverity.ERROR, - symbol=func_name, - suggestion=f"Provide all required arguments to {func_name}" - )) - + + self.issues.append( + Issue( + file=call_file, + line=call_line, + message=f"Call to {func_name} has too few arguments ({arg_count} provided, {required_count} required)", + severity=IssueSeverity.ERROR, + symbol=func_name, + suggestion=f"Provide all required arguments to {func_name}", + ) + ) + # Add total counts call_site_issues["counts"] = { "wrong_parameter_count": len(call_site_issues["wrong_parameter_count"]), "wrong_return_type_usage": len(call_site_issues["wrong_return_type_usage"]), - "total": len(call_site_issues["wrong_parameter_count"]) + len(call_site_issues["wrong_return_type_usage"]), + "total": len(call_site_issues["wrong_parameter_count"]) + + len(call_site_issues["wrong_return_type_usage"]), } - + return call_site_issues - - def _check_imports(self, codebase: Codebase) -> Dict[str, Any]: + + def _check_imports(self, codebase: Codebase) -> dict[str, Any]: """Check for import issues.""" - import_issues = { - "circular_imports": [], - "wildcard_imports": [] - } - + import_issues = {"circular_imports": [], "wildcard_imports": []} + # Check for circular imports try: # Build dependency graph dependency_map = {} - + for file in codebase.files: - if hasattr(file, 'is_binary') and file.is_binary: + if hasattr(file, "is_binary") and file.is_binary: continue - - if not hasattr(file, 'imports'): + + if not hasattr(file, "imports"): continue - - file_path = file.file_path if hasattr(file, 'file_path') else str(file) + + file_path = file.file_path if hasattr(file, "file_path") else str(file) imports = [] - + for imp in file.imports: if hasattr(imp, "imported_symbol") and imp.imported_symbol: imported_symbol = imp.imported_symbol 
if hasattr(imported_symbol, "file") and imported_symbol.file: - imported_file_path = imported_symbol.file.file_path if hasattr(imported_symbol.file, 'file_path') else str(imported_symbol.file) + imported_file_path = ( + imported_symbol.file.file_path + if hasattr(imported_symbol.file, "file_path") + else str(imported_symbol.file) + ) imports.append(imported_file_path) - + dependency_map[file_path] = imports - + # Create a directed graph import networkx as nx + G = nx.DiGraph() - + # Add nodes and edges for file_path, imports in dependency_map.items(): G.add_node(file_path) for imp in imports: - if imp in dependency_map: # Only add edges for files that exist in our dependency map + if ( + imp in dependency_map + ): # Only add edges for files that exist in our dependency map G.add_edge(file_path, imp) - + # Find cycles try: cycles = list(nx.simple_cycles(G)) - + for cycle in cycles: import_issues["circular_imports"].append({ "cycle": cycle, - "length": len(cycle) + "length": len(cycle), }) - + # Create an issue for each file in the cycle for file_path in cycle: - self.issues.append(Issue( - file=file_path, - line=None, - message=f"Circular import detected: {' -> '.join(cycle)}", - severity=IssueSeverity.ERROR, - suggestion="Refactor imports to break circular dependency" - )) + self.issues.append( + Issue( + file=file_path, + line=None, + message=f"Circular import detected: {' -> '.join(cycle)}", + severity=IssueSeverity.ERROR, + suggestion="Refactor imports to break circular dependency", + ) + ) except nx.NetworkXNoCycle: pass # No cycles found - + except Exception as e: - logger.error(f"Error detecting circular imports: {e}") - + logger.exception(f"Error detecting circular imports: {e}") + # Check for wildcard imports for file in codebase.files: - if hasattr(file, 'is_binary') and file.is_binary: + if hasattr(file, "is_binary") and file.is_binary: continue - - if not hasattr(file, 'imports'): + + if not hasattr(file, "imports"): continue - - file_path = file.file_path if hasattr(file, 'file_path') else str(file) - + + file_path = file.file_path if hasattr(file, "file_path") else str(file) + for imp in file.imports: - if not hasattr(imp, 'source'): + if not hasattr(imp, "source"): continue - + # Check for wildcard imports (from module import *) - if re.search(r'from\s+[\w.]+\s+import\s+\*', imp.source): + if re.search(r"from\s+[\w.]+\s+import\s+\*", imp.source): import_issues["wildcard_imports"].append({ "file": file_path, - "line": imp.line if hasattr(imp, 'line') else None, - "import": imp.source + "line": imp.line if hasattr(imp, "line") else None, + "import": imp.source, }) - - self.issues.append(Issue( - file=file_path, - line=imp.line if hasattr(imp, 'line') else None, - message=f"Wildcard import: {imp.source}", - severity=IssueSeverity.WARNING, - code=imp.source, - suggestion="Import specific symbols instead of using wildcard imports" - )) - + + self.issues.append( + Issue( + file=file_path, + line=imp.line if hasattr(imp, "line") else None, + message=f"Wildcard import: {imp.source}", + severity=IssueSeverity.WARNING, + code=imp.source, + suggestion="Import specific symbols instead of using wildcard imports", + ) + ) + # Add total counts import_issues["counts"] = { "circular_imports": len(import_issues["circular_imports"]), "wildcard_imports": len(import_issues["wildcard_imports"]), - "total": len(import_issues["circular_imports"]) + len(import_issues["wildcard_imports"]), + "total": len(import_issues["circular_imports"]) + + len(import_issues["wildcard_imports"]), } - + return 
import_issues - - def _check_inheritance(self, codebase: Codebase) -> Dict[str, Any]: + + def _check_inheritance(self, codebase: Codebase) -> dict[str, Any]: """Check for inheritance issues.""" inheritance_issues = { "deep_inheritance": [], "multiple_inheritance": [], - "inconsistent_interfaces": [] + "inconsistent_interfaces": [], } - - if not hasattr(codebase, 'classes'): + + if not hasattr(codebase, "classes"): return inheritance_issues - + for cls in codebase.classes: - if not hasattr(cls, 'superclasses'): + if not hasattr(cls, "superclasses"): continue - - file_path = cls.file.file_path if hasattr(cls, 'file') and hasattr(cls.file, 'file_path') else "unknown" - cls_name = cls.name if hasattr(cls, 'name') else str(cls) - + + file_path = ( + cls.file.file_path + if hasattr(cls, "file") and hasattr(cls.file, "file_path") + else "unknown" + ) + cls_name = cls.name if hasattr(cls, "name") else str(cls) + # Check inheritance depth inheritance_depth = len(cls.superclasses) if inheritance_depth > 3: # Arbitrary threshold for deep inheritance inheritance_issues["deep_inheritance"].append({ "class": cls_name, "file": file_path, - "line": cls.line if hasattr(cls, 'line') else None, + "line": cls.line if hasattr(cls, "line") else None, "depth": inheritance_depth, - "hierarchy": [sc.name if hasattr(sc, 'name') else str(sc) for sc in cls.superclasses] + "hierarchy": [ + sc.name if hasattr(sc, "name") else str(sc) + for sc in cls.superclasses + ], }) - - self.issues.append(Issue( - file=file_path, - line=cls.line if hasattr(cls, 'line') else None, - message=f"Deep inheritance detected for class {cls_name} (depth: {inheritance_depth})", - severity=IssueSeverity.WARNING, - symbol=cls_name, - suggestion="Consider composition over inheritance or flattening the hierarchy" - )) - + + self.issues.append( + Issue( + file=file_path, + line=cls.line if hasattr(cls, "line") else None, + message=f"Deep inheritance detected for class {cls_name} (depth: {inheritance_depth})", + severity=IssueSeverity.WARNING, + symbol=cls_name, + suggestion="Consider composition over inheritance or flattening the hierarchy", + ) + ) + # Check multiple inheritance if inheritance_depth > 1: inheritance_issues["multiple_inheritance"].append({ "class": cls_name, "file": file_path, - "line": cls.line if hasattr(cls, 'line') else None, - "superclasses": [sc.name if hasattr(sc, 'name') else str(sc) for sc in cls.superclasses] + "line": cls.line if hasattr(cls, "line") else None, + "superclasses": [ + sc.name if hasattr(sc, "name") else str(sc) + for sc in cls.superclasses + ], }) - + # We don't create an issue for this by default, as multiple inheritance is not always bad - + # Add total counts inheritance_issues["counts"] = { "deep_inheritance": len(inheritance_issues["deep_inheritance"]), "multiple_inheritance": len(inheritance_issues["multiple_inheritance"]), - "inconsistent_interfaces": len(inheritance_issues["inconsistent_interfaces"]), - "total": len(inheritance_issues["deep_inheritance"]) + len(inheritance_issues["multiple_inheritance"]) + - len(inheritance_issues["inconsistent_interfaces"]), + "inconsistent_interfaces": len( + inheritance_issues["inconsistent_interfaces"] + ), + "total": len(inheritance_issues["deep_inheritance"]) + + len(inheritance_issues["multiple_inheritance"]) + + len(inheritance_issues["inconsistent_interfaces"]), } - + return inheritance_issues - - def _analyze_code_complexity(self, codebase: Codebase) -> Dict[str, Any]: + + def _analyze_code_complexity(self, codebase: Codebase) -> dict[str, Any]: 
"""Analyze code complexity.""" complexity = { "complex_functions": [], "long_functions": [], - "deeply_nested_code": [] + "deeply_nested_code": [], } - - if not hasattr(codebase, 'functions'): + + if not hasattr(codebase, "functions"): return complexity - + for func in codebase.functions: - if not hasattr(func, 'source'): + if not hasattr(func, "source"): continue - - file_path = func.file.file_path if hasattr(func, 'file') and hasattr(func.file, 'file_path') else "unknown" - func_name = func.name if hasattr(func, 'name') else str(func) - + + file_path = ( + func.file.file_path + if hasattr(func, "file") and hasattr(func.file, "file_path") + else "unknown" + ) + func_name = func.name if hasattr(func, "name") else str(func) + # Check function length - func_lines = func.source.count('\n') + 1 + func_lines = func.source.count("\n") + 1 if func_lines > 50: # Arbitrary threshold for long functions complexity["long_functions"].append({ "function": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None, - "length": func_lines + "line": func.line if hasattr(func, "line") else None, + "length": func_lines, }) - - self.issues.append(Issue( - file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"Function {func_name} is too long ({func_lines} lines)", - severity=IssueSeverity.WARNING, - symbol=func_name, - suggestion="Consider breaking this function into smaller functions" - )) - + + self.issues.append( + Issue( + file=file_path, + line=func.line if hasattr(func, "line") else None, + message=f"Function {func_name} is too long ({func_lines} lines)", + severity=IssueSeverity.WARNING, + symbol=func_name, + suggestion="Consider breaking this function into smaller functions", + ) + ) + # Check cyclomatic complexity (approximate) # Count branch points (if, for, while, case, etc.) 
branch_points = ( - func.source.count('if ') + - func.source.count('elif ') + - func.source.count('for ') + - func.source.count('while ') + - func.source.count('case ') + - func.source.count('except ') + - func.source.count(' and ') + - func.source.count(' or ') + func.source.count("if ") + + func.source.count("elif ") + + func.source.count("for ") + + func.source.count("while ") + + func.source.count("case ") + + func.source.count("except ") + + func.source.count(" and ") + + func.source.count(" or ") ) - + if branch_points > 10: # Arbitrary threshold for complex functions complexity["complex_functions"].append({ "function": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None, - "branch_points": branch_points + "line": func.line if hasattr(func, "line") else None, + "branch_points": branch_points, }) - - self.issues.append(Issue( - file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"Function {func_name} is complex (branch points: {branch_points})", - severity=IssueSeverity.WARNING, - symbol=func_name, - suggestion="Refactor to reduce complexity" - )) - + + self.issues.append( + Issue( + file=file_path, + line=func.line if hasattr(func, "line") else None, + message=f"Function {func_name} is complex (branch points: {branch_points})", + severity=IssueSeverity.WARNING, + symbol=func_name, + suggestion="Refactor to reduce complexity", + ) + ) + # Check nesting depth - lines = func.source.split('\n') + lines = func.source.split("\n") max_indent = 0 for line in lines: indent = len(line) - len(line.lstrip()) max_indent = max(max_indent, indent) - + # Estimate nesting depth (rough approximation) est_nesting_depth = max_indent // 4 # Assuming 4 spaces per indent level - + if est_nesting_depth > 4: # Arbitrary threshold for deeply nested code complexity["deeply_nested_code"].append({ "function": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None, - "estimated_nesting_depth": est_nesting_depth + "line": func.line if hasattr(func, "line") else None, + "estimated_nesting_depth": est_nesting_depth, }) - - self.issues.append(Issue( - file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"Function {func_name} has deeply nested code (est. depth: {est_nesting_depth})", - severity=IssueSeverity.WARNING, - symbol=func_name, - suggestion="Refactor to reduce nesting by extracting methods or using early returns" - )) - + + self.issues.append( + Issue( + file=file_path, + line=func.line if hasattr(func, "line") else None, + message=f"Function {func_name} has deeply nested code (est. 
depth: {est_nesting_depth})", + severity=IssueSeverity.WARNING, + symbol=func_name, + suggestion="Refactor to reduce nesting by extracting methods or using early returns", + ) + ) + # Add total counts complexity["counts"] = { "complex_functions": len(complexity["complex_functions"]), "long_functions": len(complexity["long_functions"]), "deeply_nested_code": len(complexity["deeply_nested_code"]), - "total": len(complexity["complex_functions"]) + len(complexity["long_functions"]) + - len(complexity["deeply_nested_code"]), + "total": len(complexity["complex_functions"]) + + len(complexity["long_functions"]) + + len(complexity["deeply_nested_code"]), } - + return complexity - + def _check_file_issues(self, file: SourceFile) -> None: """Check a file for issues.""" # Skip binary files - if hasattr(file, 'is_binary') and file.is_binary: + if hasattr(file, "is_binary") and file.is_binary: return - - file_path = file.file_path if hasattr(file, 'file_path') else str(file) - + + file_path = file.file_path if hasattr(file, "file_path") else str(file) + # Check file size - if hasattr(file, 'content'): + if hasattr(file, "content"): file_size = len(file.content) if file_size > 500 * 1024: # 500 KB - self.issues.append(Issue( + self.issues.append( + Issue( + file=file_path, + line=None, + message=f"File is very large ({file_size / 1024:.1f} KB)", + severity=IssueSeverity.WARNING, + suggestion="Consider breaking this file into smaller modules", + ) + ) + + # Check for too many imports + if hasattr(file, "imports") and len(file.imports) > 30: # Arbitrary threshold + self.issues.append( + Issue( file=file_path, line=None, - message=f"File is very large ({file_size / 1024:.1f} KB)", + message=f"File has too many imports ({len(file.imports)})", severity=IssueSeverity.WARNING, - suggestion="Consider breaking this file into smaller modules" - )) - - # Check for too many imports - if hasattr(file, 'imports') and len(file.imports) > 30: # Arbitrary threshold - self.issues.append(Issue( - file=file_path, - line=None, - message=f"File has too many imports ({len(file.imports)})", - severity=IssueSeverity.WARNING, - suggestion="Consider refactoring to reduce the number of imports" - )) - + suggestion="Consider refactoring to reduce the number of imports", + ) + ) + # Check for file-level issues in symbol definitions - if hasattr(file, 'symbols'): + if hasattr(file, "symbols"): # Check for mixing class and function definitions at the top level - toplevel_classes = [s for s in file.symbols if hasattr(s, 'symbol_type') and s.symbol_type == SymbolType.Class] - toplevel_functions = [s for s in file.symbols if hasattr(s, 'symbol_type') and s.symbol_type == SymbolType.Function] - + toplevel_classes = [ + s + for s in file.symbols + if hasattr(s, "symbol_type") and s.symbol_type == SymbolType.Class + ] + toplevel_functions = [ + s + for s in file.symbols + if hasattr(s, "symbol_type") and s.symbol_type == SymbolType.Function + ] + if len(toplevel_classes) > 0 and len(toplevel_functions) > 5: - self.issues.append(Issue( - file=file_path, - line=None, - message=f"File mixes classes and many functions at the top level", - severity=IssueSeverity.INFO, - suggestion="Consider separating classes and functions into different modules" - )) - + self.issues.append( + Issue( + file=file_path, + line=None, + message="File mixes classes and many functions at the top level", + severity=IssueSeverity.INFO, + suggestion="Consider separating classes and functions into different modules", + ) + ) + def _check_function_for_issues(self, func: 
Function) -> None: """Check a function for issues.""" - file_path = func.file.file_path if hasattr(func, 'file') and hasattr(func.file, 'file_path') else "unknown" - func_name = func.name if hasattr(func, 'name') else str(func) - + file_path = ( + func.file.file_path + if hasattr(func, "file") and hasattr(func.file, "file_path") + else "unknown" + ) + func_name = func.name if hasattr(func, "name") else str(func) + # Check for return type - if not hasattr(func, 'return_type') or not func.return_type: - self.issues.append(Issue( - file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"Function {func_name} lacks a return type annotation", - severity=IssueSeverity.WARNING, - symbol=func_name, - suggestion="Add a return type annotation" - )) - - # Check parameters for types - if hasattr(func, 'parameters'): - missing_types = [p.name for p in func.parameters if hasattr(p, 'name') and (not hasattr(p, 'type') or not p.type)] - if missing_types: - self.issues.append(Issue( + if not hasattr(func, "return_type") or not func.return_type: + self.issues.append( + Issue( file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"Function {func_name} has parameters without type annotations: {', '.join(missing_types)}", + line=func.line if hasattr(func, "line") else None, + message=f"Function {func_name} lacks a return type annotation", severity=IssueSeverity.WARNING, symbol=func_name, - suggestion="Add type annotations to all parameters" - )) - + suggestion="Add a return type annotation", + ) + ) + + # Check parameters for types + if hasattr(func, "parameters"): + missing_types = [ + p.name + for p in func.parameters + if hasattr(p, "name") and (not hasattr(p, "type") or not p.type) + ] + if missing_types: + self.issues.append( + Issue( + file=file_path, + line=func.line if hasattr(func, "line") else None, + message=f"Function {func_name} has parameters without type annotations: {', '.join(missing_types)}", + severity=IssueSeverity.WARNING, + symbol=func_name, + suggestion="Add type annotations to all parameters", + ) + ) + # Check for docstring - if hasattr(func, 'source'): - lines = func.source.split('\n') + if hasattr(func, "source"): + lines = func.source.split("\n") if len(lines) > 1: # Check if second line starts a docstring - if not any(line.strip().startswith('"""') or line.strip().startswith("'''") for line in lines[:3]): - self.issues.append(Issue( + if not any( + line.strip().startswith('"""') or line.strip().startswith("'''") + for line in lines[:3] + ): + self.issues.append( + Issue( + file=file_path, + line=func.line if hasattr(func, "line") else None, + message=f"Function {func_name} lacks a docstring", + severity=IssueSeverity.INFO, + symbol=func_name, + suggestion="Add a docstring describing the function's purpose, parameters, and return value", + ) + ) + + # Check for error handling in async functions + if hasattr(func, "is_async") and func.is_async and hasattr(func, "source"): + if "await" in func.source and "try" not in func.source: + self.issues.append( + Issue( file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"Function {func_name} lacks a docstring", - severity=IssueSeverity.INFO, + line=func.line if hasattr(func, "line") else None, + message=f"Async function {func_name} has awaits without try/except", + severity=IssueSeverity.WARNING, symbol=func_name, - suggestion="Add a docstring describing the function's purpose, parameters, and return value" - )) - - # Check for error handling in async functions - if 
hasattr(func, 'is_async') and func.is_async and hasattr(func, 'source'): - if 'await' in func.source and 'try' not in func.source: - self.issues.append(Issue( - file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"Async function {func_name} has awaits without try/except", - severity=IssueSeverity.WARNING, - symbol=func_name, - suggestion="Add error handling for await expressions" - )) + suggestion="Add error handling for await expressions", + ) + ) + def main(): """Main entry point for the codebase analyzer.""" - parser = argparse.ArgumentParser(description="Comprehensive Codebase and PR Analyzer") - + parser = argparse.ArgumentParser( + description="Comprehensive Codebase and PR Analyzer" + ) + # Repository source options source_group = parser.add_mutually_exclusive_group(required=True) source_group.add_argument("--repo-url", help="URL of the repository to analyze") - source_group.add_argument("--repo-path", help="Local path to the repository to analyze") - + source_group.add_argument( + "--repo-path", help="Local path to the repository to analyze" + ) + # Analysis options - parser.add_argument("--analysis-type", choices=["codebase", "pr", "comparison"], default="codebase", - help="Type of analysis to perform (default: codebase)") - parser.add_argument("--language", choices=["python", "typescript"], help="Programming language (auto-detected if not provided)") - parser.add_argument("--base-branch", default="main", help="Base branch for PR comparison (default: main)") + parser.add_argument( + "--analysis-type", + choices=["codebase", "pr", "comparison"], + default="codebase", + help="Type of analysis to perform (default: codebase)", + ) + parser.add_argument( + "--language", + choices=["python", "typescript"], + help="Programming language (auto-detected if not provided)", + ) + parser.add_argument( + "--base-branch", + default="main", + help="Base branch for PR comparison (default: main)", + ) parser.add_argument("--pr-number", type=int, help="PR number to analyze") - + # Output options - parser.add_argument("--output-format", choices=["json", "html", "console"], default="json", help="Output format") + parser.add_argument( + "--output-format", + choices=["json", "html", "console"], + default="json", + help="Output format", + ) parser.add_argument("--output-file", help="Path to the output file") - + args = parser.parse_args() - + try: # Initialize the analyzer analyzer = CodebaseAnalyzer( @@ -1779,17 +2099,17 @@ def main(): repo_path=args.repo_path, base_branch=args.base_branch, pr_number=args.pr_number, - language=args.language + language=args.language, ) - + # Perform the analysis analysis_type = AnalysisType(args.analysis_type) results = analyzer.analyze(analysis_type) - + # Output the results if args.output_format == "json": if args.output_file: - with open(args.output_file, 'w') as f: + with open(args.output_file, "w") as f: json.dump(results, f, indent=2) print(f"Analysis results saved to {args.output_file}") else: @@ -1798,8 +2118,8 @@ def main(): # Create a simple HTML report if not args.output_file: args.output_file = "codebase_analysis_report.html" - - with open(args.output_file, 'w') as f: + + with open(args.output_file, "w") as f: f.write(f""" @@ -1829,47 +2149,51 @@ def main():
  • Info: {results["issue_counts"]["by_severity"]["info"]}

    Issues


    Detailed Analysis

    """)
    -                
    +
                     # Add detailed analysis as formatted JSON
                     f.write(json.dumps(results, indent=2))
    -                
    +
                     f.write("""
             
    """) - + print(f"HTML report saved to {args.output_file}") - + elif args.output_format == "console": - print(f"===== Codebase Analysis Report =====") + print("===== Codebase Analysis Report =====") print(f"Repository: {results['metadata']['repo_name']}") print(f"Language: {results['metadata']['language']}") print(f"Analysis Type: {results['metadata']['analysis_type']}") @@ -1878,24 +2202,29 @@ def main(): print(f" Errors: {results['issue_counts']['by_severity']['error']}") print(f" Warnings: {results['issue_counts']['by_severity']['warning']}") print(f" Info: {results['issue_counts']['by_severity']['info']}") - + print("\n===== Issues =====") for issue in results["issues"]: severity = issue["severity"].upper() - location = f"{issue['file']}:{issue['line']}" if issue['line'] else issue['file'] + location = ( + f"{issue['file']}:{issue['line']}" + if issue["line"] + else issue["file"] + ) print(f"[{severity}] {location}: {issue['message']}") - if issue['symbol']: + if issue["symbol"]: print(f" Symbol: {issue['symbol']}") - if issue['suggestion']: + if issue["suggestion"]: print(f" Suggestion: {issue['suggestion']}") print() - + except Exception as e: print(f"Error: {e}") import traceback + traceback.print_exc() sys.exit(1) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/codegen-on-oss/codegen_on_oss/analyzers/codebase_context.py b/codegen-on-oss/codegen_on_oss/analyzers/codebase_context.py index bb1cd1bb4..68c00661a 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/codebase_context.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/codebase_context.py @@ -7,22 +7,23 @@ visualization, and PR comparison. It serves as the central data model for analysis. """ -import os -import sys import logging -import networkx as nx -from typing import Dict, List, Set, Tuple, Any, Optional, Union, Callable, TypeVar, cast +import sys from enum import Enum -from pathlib import Path +from typing import Any + +import networkx as nx try: + from codegen.sdk.codebase.codebase_context import ( + CodebaseContext as SDKCodebaseContext, + ) + from codegen.sdk.core.class_definition import Class from codegen.sdk.core.codebase import Codebase - from codegen.sdk.codebase.codebase_context import CodebaseContext as SDKCodebaseContext - from codegen.sdk.core.file import SourceFile from codegen.sdk.core.directory import Directory - from codegen.sdk.core.symbol import Symbol + from codegen.sdk.core.file import SourceFile from codegen.sdk.core.function import Function - from codegen.sdk.core.class_definition import Class + from codegen.sdk.core.symbol import Symbol from codegen.sdk.enums import EdgeType, SymbolType except ImportError: print("Codegen SDK not found. 
Please install it first.") @@ -32,16 +33,15 @@ from codegen_on_oss.analyzers.context.file import FileContext from codegen_on_oss.analyzers.context.function import FunctionContext from codegen_on_oss.analyzers.context.graph import ( - build_dependency_graph, + calculate_centrality, find_circular_dependencies, - calculate_centrality ) # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) @@ -62,8 +62,10 @@ ".vscode", ] + class NodeType(str, Enum): """Types of nodes in the graph.""" + FILE = "file" DIRECTORY = "directory" FUNCTION = "function" @@ -72,6 +74,7 @@ class NodeType(str, Enum): VARIABLE = "variable" UNKNOWN = "unknown" + def get_node_type(node: Any) -> NodeType: """Determine the type of a node.""" if isinstance(node, SourceFile): @@ -85,26 +88,27 @@ def get_node_type(node: Any) -> NodeType: else: return NodeType.UNKNOWN + class CodebaseContext: """ Graph-based representation of a codebase for advanced analysis. - + This class provides a unified graph representation of a codebase, including files, directories, functions, classes, and their relationships. It serves as the central data model for all analysis operations. """ - + def __init__( self, codebase: Codebase, - base_path: Optional[str] = None, - pr_branch: Optional[str] = None, + base_path: str | None = None, + pr_branch: str | None = None, base_branch: str = "main", - file_ignore_list: Optional[List[str]] = None + file_ignore_list: list[str] | None = None, ): """ Initialize the CodebaseContext. - + Args: codebase: The codebase to analyze base_path: Base path of the codebase @@ -117,303 +121,314 @@ def __init__( self.pr_branch = pr_branch self.base_branch = base_branch self.file_ignore_list = file_ignore_list or GLOBAL_FILE_IGNORE_LIST - + # Initialize graph self._graph = nx.DiGraph() - + # File and symbol context caches self._file_contexts = {} self._function_contexts = {} - + # Build the graph self._build_graph() - + def _build_graph(self): """Build the codebase graph.""" logger.info("Building codebase graph...") - + # Add nodes for files for file in self.codebase.files: # Skip ignored files if self._should_ignore_file(file): continue - + # Add file node - file_path = file.file_path if hasattr(file, 'file_path') else str(file) - self._graph.add_node(file, - type=NodeType.FILE, - path=file_path) - + file_path = file.file_path if hasattr(file, "file_path") else str(file) + self._graph.add_node(file, type=NodeType.FILE, path=file_path) + # Add nodes for functions in the file - if hasattr(file, 'functions'): + if hasattr(file, "functions"): for func in file.functions: # Create function node - func_name = func.name if hasattr(func, 'name') else str(func) - self._graph.add_node(func, - type=NodeType.FUNCTION, - name=func_name, - file=file) - + func_name = func.name if hasattr(func, "name") else str(func) + self._graph.add_node( + func, type=NodeType.FUNCTION, name=func_name, file=file + ) + # Add edge from file to function self._graph.add_edge(file, func, type=EdgeType.CONTAINS) - + # Add nodes for classes in the file - if hasattr(file, 'classes'): + if hasattr(file, "classes"): for cls in file.classes: # Create class node - cls_name = cls.name if hasattr(cls, 'name') else str(cls) - self._graph.add_node(cls, - type=NodeType.CLASS, - name=cls_name, - file=file) - + cls_name = cls.name 
if hasattr(cls, "name") else str(cls) + self._graph.add_node( + cls, type=NodeType.CLASS, name=cls_name, file=file + ) + # Add edge from file to class self._graph.add_edge(file, cls, type=EdgeType.CONTAINS) - + # Add nodes for methods in the class - if hasattr(cls, 'methods'): + if hasattr(cls, "methods"): for method in cls.methods: # Create method node - method_name = method.name if hasattr(method, 'name') else str(method) - self._graph.add_node(method, - type=NodeType.FUNCTION, - name=method_name, - file=file, - class_name=cls_name) - + method_name = ( + method.name if hasattr(method, "name") else str(method) + ) + self._graph.add_node( + method, + type=NodeType.FUNCTION, + name=method_name, + file=file, + class_name=cls_name, + ) + # Add edge from class to method self._graph.add_edge(cls, method, type=EdgeType.CONTAINS) - + # Add edges for imports for file in self.codebase.files: # Skip ignored files if self._should_ignore_file(file): continue - + # Add import edges - if hasattr(file, 'imports'): + if hasattr(file, "imports"): for imp in file.imports: # Get imported file imported_file = None - - if hasattr(imp, 'resolved_file'): + + if hasattr(imp, "resolved_file"): imported_file = imp.resolved_file - elif hasattr(imp, 'resolved_symbol') and hasattr(imp.resolved_symbol, 'file'): + elif hasattr(imp, "resolved_symbol") and hasattr( + imp.resolved_symbol, "file" + ): imported_file = imp.resolved_symbol.file - + if imported_file and imported_file in self._graph: # Add edge from file to imported file self._graph.add_edge(file, imported_file, type=EdgeType.IMPORTS) - + # Add edges for function calls - for func in [n for n in self._graph.nodes if get_node_type(n) == NodeType.FUNCTION]: - if hasattr(func, 'call_sites'): + for func in [ + n for n in self._graph.nodes if get_node_type(n) == NodeType.FUNCTION + ]: + if hasattr(func, "call_sites"): for call_site in func.call_sites: - if hasattr(call_site, 'called_function') and call_site.called_function in self._graph: + if ( + hasattr(call_site, "called_function") + and call_site.called_function in self._graph + ): # Add edge from function to called function - self._graph.add_edge(func, call_site.called_function, type=EdgeType.CALLS) - + self._graph.add_edge( + func, call_site.called_function, type=EdgeType.CALLS + ) + # Add edges for class inheritance for cls in [n for n in self._graph.nodes if get_node_type(n) == NodeType.CLASS]: - if hasattr(cls, 'superclasses'): + if hasattr(cls, "superclasses"): for superclass in cls.superclasses: if superclass in self._graph: # Add edge from class to superclass - self._graph.add_edge(cls, superclass, type=EdgeType.INHERITS_FROM) - - logger.info(f"Graph built with {len(self._graph.nodes)} nodes and {len(self._graph.edges)} edges") - + self._graph.add_edge( + cls, superclass, type=EdgeType.INHERITS_FROM + ) + + logger.info( + f"Graph built with {len(self._graph.nodes)} nodes and {len(self._graph.edges)} edges" + ) + def _should_ignore_file(self, file) -> bool: """Check if a file should be ignored.""" - if hasattr(file, 'is_binary') and file.is_binary: + if hasattr(file, "is_binary") and file.is_binary: return True - - file_path = file.file_path if hasattr(file, 'file_path') else str(file) - + + file_path = file.file_path if hasattr(file, "file_path") else str(file) + # Check against ignore list - for pattern in self.file_ignore_list: - if pattern in file_path: - return True - - return False - - def get_file_context(self, file: Union[SourceFile, str]) -> FileContext: + return any(pattern in file_path for 
pattern in self.file_ignore_list) + + def get_file_context(self, file: SourceFile | str) -> FileContext: """ Get context for a specific file. - + Args: file: File object or file path - + Returns: FileContext for the specified file """ # If file is a string, find the corresponding file object if isinstance(file, str): for f in self.codebase.files: - file_path = f.file_path if hasattr(f, 'file_path') else str(f) + file_path = f.file_path if hasattr(f, "file_path") else str(f) if file_path == file: file = f break else: raise ValueError(f"File not found: {file}") - + # Get file path - file_path = file.file_path if hasattr(file, 'file_path') else str(file) - + file_path = file.file_path if hasattr(file, "file_path") else str(file) + # Return cached context if available if file_path in self._file_contexts: return self._file_contexts[file_path] - + # Create and cache new context context = FileContext(file) self._file_contexts[file_path] = context - + return context - - def get_function_context(self, function: Union[Function, str]) -> FunctionContext: + + def get_function_context(self, function: Function | str) -> FunctionContext: """ Get context for a specific function. - + Args: function: Function object or function name - + Returns: FunctionContext for the specified function """ # If function is a string, find the corresponding function object if isinstance(function, str): for f in self.get_functions(): - if hasattr(f, 'name') and f.name == function: + if hasattr(f, "name") and f.name == function: function = f break else: raise ValueError(f"Function not found: {function}") - + # Get function name - func_name = function.name if hasattr(function, 'name') else str(function) - + func_name = function.name if hasattr(function, "name") else str(function) + # Return cached context if available if func_name in self._function_contexts: return self._function_contexts[func_name] - + # Create and cache new context context = FunctionContext(function) self._function_contexts[func_name] = context - + return context - + @property def graph(self) -> nx.DiGraph: """Get the codebase graph.""" return self._graph - + @property - def nodes(self) -> List[Any]: + def nodes(self) -> list[Any]: """Get all nodes in the graph.""" return list(self._graph.nodes) - - def get_node(self, name: str) -> Optional[Any]: + + def get_node(self, name: str) -> Any | None: """ Get a node by name. - + Args: name: Name of the node to get - + Returns: The node, or None if not found """ for node in self._graph.nodes: - if (hasattr(node, 'name') and node.name == name) or str(node) == name: + if (hasattr(node, "name") and node.name == name) or str(node) == name: return node return None - - def predecessors(self, node: Any) -> List[Any]: + + def predecessors(self, node: Any) -> list[Any]: """ Get predecessors of a node. - + Args: node: Node to get predecessors for - + Returns: List of predecessor nodes """ return list(self._graph.predecessors(node)) - - def successors(self, node: Any) -> List[Any]: + + def successors(self, node: Any) -> list[Any]: """ Get successors of a node. - + Args: node: Node to get successors for - + Returns: List of successor nodes """ return list(self._graph.successors(node)) - - def get_nodes_by_type(self, node_type: NodeType) -> List[Any]: + + def get_nodes_by_type(self, node_type: NodeType) -> list[Any]: """ Get nodes by type. 
- + Args: node_type: Type of nodes to get - + Returns: List of nodes of the specified type """ return [n for n in self._graph.nodes if get_node_type(n) == node_type] - - def get_files(self) -> List[SourceFile]: + + def get_files(self) -> list[SourceFile]: """ Get all files in the codebase. - + Returns: List of files """ return self.get_nodes_by_type(NodeType.FILE) - - def get_functions(self) -> List[Function]: + + def get_functions(self) -> list[Function]: """ Get all functions in the codebase. - + Returns: List of functions """ return self.get_nodes_by_type(NodeType.FUNCTION) - - def get_classes(self) -> List[Class]: + + def get_classes(self) -> list[Class]: """ Get all classes in the codebase. - + Returns: List of classes """ return self.get_nodes_by_type(NodeType.CLASS) - - def find_paths(self, source: Any, target: Any, cutoff: Optional[int] = None) -> List[List[Any]]: + + def find_paths( + self, source: Any, target: Any, cutoff: int | None = None + ) -> list[list[Any]]: """ Find all paths between two nodes. - + Args: source: Source node target: Target node cutoff: Maximum path length - + Returns: List of paths from source to target """ if source not in self._graph or target not in self._graph: return [] - + try: return list(nx.all_simple_paths(self._graph, source, target, cutoff=cutoff)) except nx.NetworkXError: return [] - - def find_cycles(self) -> List[List[Any]]: + + def find_cycles(self) -> list[list[Any]]: """ Find cycles in the graph. - + Returns: List of cycles in the graph """ @@ -421,118 +436,121 @@ def find_cycles(self) -> List[List[Any]]: return list(nx.simple_cycles(self._graph)) except nx.NetworkXNoCycle: return [] - + def get_import_graph(self) -> nx.DiGraph: """ Get the import dependency graph. - + Returns: NetworkX DiGraph representing import dependencies """ # Create a subgraph with only file nodes files = self.get_files() subgraph = self._graph.subgraph(files) - + # Create a new graph with only import edges import_graph = nx.DiGraph() - + for source, target, data in subgraph.edges(data=True): - if 'type' in data and data['type'] == EdgeType.IMPORTS: + if "type" in data and data["type"] == EdgeType.IMPORTS: # Get file paths - source_path = source.file_path if hasattr(source, 'file_path') else str(source) - target_path = target.file_path if hasattr(target, 'file_path') else str(target) - + source_path = ( + source.file_path if hasattr(source, "file_path") else str(source) + ) + target_path = ( + target.file_path if hasattr(target, "file_path") else str(target) + ) + # Add edge to import graph import_graph.add_edge(source_path, target_path) - + return import_graph - + def get_call_graph(self) -> nx.DiGraph: """ Get the function call graph. 
- + Returns: NetworkX DiGraph representing function calls """ # Create a subgraph with only function nodes functions = self.get_functions() subgraph = self._graph.subgraph(functions) - + # Create a new graph with only call edges call_graph = nx.DiGraph() - + for source, target, data in subgraph.edges(data=True): - if 'type' in data and data['type'] == EdgeType.CALLS: + if "type" in data and data["type"] == EdgeType.CALLS: # Get function names - source_name = source.name if hasattr(source, 'name') else str(source) - target_name = target.name if hasattr(target, 'name') else str(target) - + source_name = source.name if hasattr(source, "name") else str(source) + target_name = target.name if hasattr(target, "name") else str(target) + # Add edge to call graph call_graph.add_edge(source_name, target_name) - + return call_graph - + def get_inheritance_graph(self) -> nx.DiGraph: """ Get the class inheritance graph. - + Returns: NetworkX DiGraph representing class inheritance """ # Create a subgraph with only class nodes classes = self.get_classes() subgraph = self._graph.subgraph(classes) - + # Create a new graph with only inheritance edges inheritance_graph = nx.DiGraph() - + for source, target, data in subgraph.edges(data=True): - if 'type' in data and data['type'] == EdgeType.INHERITS_FROM: + if "type" in data and data["type"] == EdgeType.INHERITS_FROM: # Get class names - source_name = source.name if hasattr(source, 'name') else str(source) - target_name = target.name if hasattr(target, 'name') else str(target) - + source_name = source.name if hasattr(source, "name") else str(source) + target_name = target.name if hasattr(target, "name") else str(target) + # Add edge to inheritance graph inheritance_graph.add_edge(source_name, target_name) - + return inheritance_graph - - def analyze_dependencies(self) -> Dict[str, Any]: + + def analyze_dependencies(self) -> dict[str, Any]: """ Analyze dependencies in the codebase. - + Returns: Dictionary containing dependency analysis results """ # Get import graph import_graph = self.get_import_graph() - + # Find circular dependencies circular_deps = find_circular_dependencies(import_graph) - + # Calculate centrality centrality = calculate_centrality(import_graph) - + # Find hub modules (most central) hub_modules = sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:10] - + return { "circular_dependencies": [ - {"cycle": cycle, "length": len(cycle)} - for cycle in circular_deps + {"cycle": cycle, "length": len(cycle)} for cycle in circular_deps ], "hub_modules": [ {"module": module, "centrality": centrality} for module, centrality in hub_modules ], "dependency_count": len(import_graph.edges), - "module_count": len(import_graph.nodes) + "module_count": len(import_graph.nodes), } - - def analyze_code_structure(self) -> Dict[str, Any]: + + def analyze_code_structure(self) -> dict[str, Any]: """ Analyze code structure. - + Returns: Dictionary containing code structure analysis results """ @@ -543,115 +561,117 @@ def analyze_code_structure(self) -> Dict[str, Any]: "average_file_size": self._calculate_average_file_size(), "average_function_size": self._calculate_average_function_size(), "most_complex_files": self._find_most_complex_files(10), - "most_complex_functions": self._find_most_complex_functions(10) + "most_complex_functions": self._find_most_complex_functions(10), } - + def _calculate_average_file_size(self) -> float: """ Calculate average file size in lines. 
- + Returns: Average file size in lines """ files = self.get_files() - + if not files: return 0 - + total_lines = 0 file_count = 0 - + for file in files: - if hasattr(file, 'content'): - lines = len(file.content.split('\n')) + if hasattr(file, "content"): + lines = len(file.content.split("\n")) total_lines += lines file_count += 1 - + return total_lines / file_count if file_count > 0 else 0 - + def _calculate_average_function_size(self) -> float: """ Calculate average function size in lines. - + Returns: Average function size in lines """ functions = self.get_functions() - + if not functions: return 0 - + total_lines = 0 function_count = 0 - + for func in functions: - if hasattr(func, 'source'): - lines = len(func.source.split('\n')) + if hasattr(func, "source"): + lines = len(func.source.split("\n")) total_lines += lines function_count += 1 - + return total_lines / function_count if function_count > 0 else 0 - - def _find_most_complex_files(self, limit: int = 10) -> List[Dict[str, Any]]: + + def _find_most_complex_files(self, limit: int = 10) -> list[dict[str, Any]]: """ Find the most complex files. - + Args: limit: Maximum number of files to return - + Returns: List of complex files with complexity metrics """ files = self.get_files() file_complexity = [] - + for file in files: file_context = self.get_file_context(file) complexity = file_context.analyze_complexity() - + file_complexity.append({ "file": file_context.path, - "complexity": complexity + "complexity": complexity, }) - + # Sort by complexity - file_complexity.sort(key=lambda x: x["complexity"].get("total_complexity", 0), reverse=True) - + file_complexity.sort( + key=lambda x: x["complexity"].get("total_complexity", 0), reverse=True + ) + return file_complexity[:limit] - - def _find_most_complex_functions(self, limit: int = 10) -> List[Dict[str, Any]]: + + def _find_most_complex_functions(self, limit: int = 10) -> list[dict[str, Any]]: """ Find the most complex functions. - + Args: limit: Maximum number of functions to return - + Returns: List of complex functions with complexity metrics """ functions = self.get_functions() function_complexity = [] - + for func in functions: function_context = self.get_function_context(func) complexity = function_context.analyze_complexity() - + function_complexity.append({ "function": function_context.name, "file": function_context.file_path, "line": function_context.line, - "complexity": complexity["cyclomatic_complexity"] + "complexity": complexity["cyclomatic_complexity"], }) - + # Sort by complexity function_complexity.sort(key=lambda x: x["complexity"], reverse=True) - + return function_complexity[:limit] - - def export_to_dict(self) -> Dict[str, Any]: + + def export_to_dict(self) -> dict[str, Any]: """ Export the codebase context to a dictionary. 
- + Returns: Dictionary representation of the codebase context """ @@ -661,27 +681,31 @@ def export_to_dict(self) -> Dict[str, Any]: "id": str(id(node)), "type": get_node_type(node).value, } - - if hasattr(node, 'name'): + + if hasattr(node, "name"): node_data["name"] = node.name - - if hasattr(node, 'file') and hasattr(node.file, 'file_path'): + + if hasattr(node, "file") and hasattr(node.file, "file_path"): node_data["file"] = node.file.file_path - + nodes.append(node_data) - + edges = [] for source, target, data in self._graph.edges(data=True): edge_data = { "source": str(id(source)), "target": str(id(target)), } - + if "type" in data: - edge_data["type"] = data["type"].value if isinstance(data["type"], Enum) else str(data["type"]) - + edge_data["type"] = ( + data["type"].value + if isinstance(data["type"], Enum) + else str(data["type"]) + ) + edges.append(edge_data) - + return { "nodes": nodes, "edges": edges, @@ -689,6 +713,6 @@ def export_to_dict(self) -> Dict[str, Any]: "file_count": len(self.get_files()), "function_count": len(self.get_functions()), "class_count": len(self.get_classes()), - "edge_count": len(self._graph.edges) - } - } \ No newline at end of file + "edge_count": len(self._graph.edges), + }, + } diff --git a/codegen-on-oss/codegen_on_oss/analyzers/codebase_visualizer.py b/codegen-on-oss/codegen_on_oss/analyzers/codebase_visualizer.py index 0e7a47b7a..2cea2331b 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/codebase_visualizer.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/codebase_visualizer.py @@ -8,54 +8,67 @@ developers understand codebase architecture and identify potential problems. """ -import os -import sys import json import logging -import tempfile -import math -from enum import Enum -from pathlib import Path -from typing import Dict, List, Set, Tuple, Any, Optional, Union, TypeVar, cast, Callable -from datetime import datetime +import os +import sys from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any try: - import networkx as nx import matplotlib.pyplot as plt + import networkx as nx from matplotlib.colors import LinearSegmentedColormap except ImportError: - print("Visualization dependencies not found. Please install them with: pip install networkx matplotlib") + print( + "Visualization dependencies not found. 
Please install them with: pip install networkx matplotlib" + ) sys.exit(1) try: - from codegen.sdk.core.codebase import Codebase - from codegen.sdk.core.symbol import Symbol - from codegen.sdk.core.function import Function from codegen.sdk.core.class_definition import Class + from codegen.sdk.core.codebase import Codebase + from codegen.sdk.core.detached_symbols.function_call import FunctionCall from codegen.sdk.core.file import SourceFile + from codegen.sdk.core.function import Function from codegen.sdk.core.import_resolution import Import + from codegen.sdk.core.symbol import Symbol from codegen.sdk.enums import EdgeType, SymbolType - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - + + from codegen_on_oss.codebase_analyzer import ( + AnalysisType, + CodebaseAnalyzer, + Issue, + IssueSeverity, + ) + # Import custom modules - from codegen_on_oss.context_codebase import CodebaseContext, get_node_classes, GLOBAL_FILE_IGNORE_LIST - from codegen_on_oss.codebase_analyzer import CodebaseAnalyzer, Issue, IssueSeverity, AnalysisType + from codegen_on_oss.context_codebase import ( + GLOBAL_FILE_IGNORE_LIST, + CodebaseContext, + get_node_classes, + ) from codegen_on_oss.current_code_codebase import get_selected_codebase except ImportError: - print("Codegen SDK or custom modules not found. Please ensure all dependencies are installed.") + print( + "Codegen SDK or custom modules not found. Please ensure all dependencies are installed." + ) sys.exit(1) # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) + class VisualizationType(str, Enum): """Types of visualizations supported by this module.""" + CALL_GRAPH = "call_graph" DEPENDENCY_GRAPH = "dependency_graph" BLAST_RADIUS = "blast_radius" @@ -66,61 +79,68 @@ class VisualizationType(str, Enum): ISSUES_HEATMAP = "issues_heatmap" PR_COMPARISON = "pr_comparison" + class OutputFormat(str, Enum): """Output formats for visualizations.""" + JSON = "json" PNG = "png" SVG = "svg" HTML = "html" DOT = "dot" + @dataclass class VisualizationConfig: """Configuration for visualization generation.""" + max_depth: int = 5 ignore_external: bool = True ignore_tests: bool = True node_size_base: int = 300 edge_width_base: float = 1.0 - filename_filter: Optional[List[str]] = None - symbol_filter: Optional[List[str]] = None + filename_filter: list[str] | None = None + symbol_filter: list[str] | None = None output_format: OutputFormat = OutputFormat.JSON - output_directory: Optional[str] = None + output_directory: str | None = None layout_algorithm: str = "spring" - highlight_nodes: List[str] = field(default_factory=list) + highlight_nodes: list[str] = field(default_factory=list) highlight_color: str = "#ff5555" - color_palette: Dict[str, str] = field(default_factory=lambda: { - "Function": "#a277ff", # Purple - "Class": "#ffca85", # Orange - "File": "#80CBC4", # Teal - "Module": "#81D4FA", # Light Blue - "Variable": "#B39DDB", # Light Purple - "Root": "#ef5350", # Red - "Warning": "#FFCA28", # Amber - "Error": "#EF5350", # Red - "Dead": "#78909C", # Gray - "External": "#B0BEC5", # Light Gray - }) + color_palette: dict[str, str] = field( + default_factory=lambda: { + "Function": "#a277ff", # Purple + "Class": "#ffca85", # Orange + "File": "#80CBC4", # Teal + "Module": "#81D4FA", # Light Blue + 
"Variable": "#B39DDB", # Light Purple + "Root": "#ef5350", # Red + "Warning": "#FFCA28", # Amber + "Error": "#EF5350", # Red + "Dead": "#78909C", # Gray + "External": "#B0BEC5", # Light Gray + } + ) + class CodebaseVisualizer: """ Visualizer for codebase structures and analytics. - + This class provides methods to generate various visualizations of a codebase, including call graphs, dependency graphs, complexity heatmaps, and more. It integrates with CodebaseAnalyzer to visualize analysis results. """ - + def __init__( self, - analyzer: Optional[CodebaseAnalyzer] = None, - codebase: Optional[Codebase] = None, - context: Optional[CodebaseContext] = None, - config: Optional[VisualizationConfig] = None + analyzer: CodebaseAnalyzer | None = None, + codebase: Codebase | None = None, + context: CodebaseContext | None = None, + config: VisualizationConfig | None = None, ): """ Initialize the CodebaseVisualizer. - + Args: analyzer: Optional CodebaseAnalyzer instance with analysis results codebase: Optional Codebase instance to visualize @@ -131,37 +151,40 @@ def __init__( self.codebase = codebase or (analyzer.base_codebase if analyzer else None) self.context = context or (analyzer.base_context if analyzer else None) self.config = config or VisualizationConfig() - + # Create visualization directory if specified if self.config.output_directory: os.makedirs(self.config.output_directory, exist_ok=True) - + # Initialize graph for visualization self.graph = nx.DiGraph() - + # Initialize codebase if needed if not self.codebase and not self.context: - logger.info("No codebase or context provided, initializing from current directory") + logger.info( + "No codebase or context provided, initializing from current directory" + ) self.codebase = get_selected_codebase() self.context = CodebaseContext( - codebase=self.codebase, - base_path=os.getcwd() + codebase=self.codebase, base_path=os.getcwd() ) elif self.codebase and not self.context: logger.info("Creating context from provided codebase") self.context = CodebaseContext( codebase=self.codebase, - base_path=os.getcwd() if not hasattr(self.codebase, 'base_path') else self.codebase.base_path + base_path=os.getcwd() + if not hasattr(self.codebase, "base_path") + else self.codebase.base_path, ) - + def _initialize_graph(self): """Initialize a fresh graph for visualization.""" self.graph = nx.DiGraph() - + def _add_node(self, node: Any, **attrs): """ Add a node to the visualization graph with attributes. - + Args: node: Node object to add **attrs: Node attributes @@ -169,10 +192,10 @@ def _add_node(self, node: Any, **attrs): # Skip if node already exists if self.graph.has_node(node): return - + # Generate node ID (memory address for unique identification) node_id = id(node) - + # Get node name if "name" in attrs: node_name = attrs["name"] @@ -182,11 +205,11 @@ def _add_node(self, node: Any, **attrs): node_name = str(node.path).split("/")[-1] else: node_name = str(node) - + # Determine node type and color node_type = node.__class__.__name__ color = attrs.get("color", self.config.color_palette.get(node_type, "#BBBBBB")) - + # Add node with attributes self.graph.add_node( node_id, @@ -194,15 +217,15 @@ def _add_node(self, node: Any, **attrs): name=node_name, type=node_type, color=color, - **attrs + **attrs, ) - + return node_id - + def _add_edge(self, source: Any, target: Any, **attrs): """ Add an edge to the visualization graph with attributes. 
- + Args: source: Source node target: Target node @@ -211,71 +234,78 @@ def _add_edge(self, source: Any, target: Any, **attrs): # Get node IDs source_id = id(source) target_id = id(target) - + # Add edge with attributes - self.graph.add_edge( - source_id, - target_id, - **attrs - ) - - def _generate_filename(self, visualization_type: VisualizationType, entity_name: str): + self.graph.add_edge(source_id, target_id, **attrs) + + def _generate_filename( + self, visualization_type: VisualizationType, entity_name: str + ): """ Generate a filename for the visualization. - + Args: visualization_type: Type of visualization entity_name: Name of the entity being visualized - + Returns: Generated filename """ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - sanitized_name = entity_name.replace("/", "_").replace("\\", "_").replace(".", "_") + sanitized_name = ( + entity_name.replace("/", "_").replace("\\", "_").replace(".", "_") + ) return f"{visualization_type.value}_{sanitized_name}_{timestamp}.{self.config.output_format.value}" - - def _save_visualization(self, visualization_type: VisualizationType, entity_name: str, data: Any): + + def _save_visualization( + self, visualization_type: VisualizationType, entity_name: str, data: Any + ): """ Save a visualization to file or return it. - + Args: visualization_type: Type of visualization entity_name: Name of the entity being visualized data: Visualization data to save - + Returns: Path to saved file or visualization data """ filename = self._generate_filename(visualization_type, entity_name) - + if self.config.output_directory: filepath = os.path.join(self.config.output_directory, filename) else: filepath = filename - + if self.config.output_format == OutputFormat.JSON: - with open(filepath, 'w') as f: + with open(filepath, "w") as f: json.dump(data, f, indent=2) elif self.config.output_format in [OutputFormat.PNG, OutputFormat.SVG]: # Save matplotlib figure - plt.savefig(filepath, format=self.config.output_format.value, bbox_inches='tight') + plt.savefig( + filepath, format=self.config.output_format.value, bbox_inches="tight" + ) plt.close() elif self.config.output_format == OutputFormat.DOT: # Save as DOT file for Graphviz try: from networkx.drawing.nx_agraph import write_dot + write_dot(self.graph, filepath) except ImportError: - logger.error("networkx.drawing.nx_agraph not available. Install pygraphviz for DOT format.") + logger.exception( + "networkx.drawing.nx_agraph not available. Install pygraphviz for DOT format." + ) return None - + logger.info(f"Visualization saved to {filepath}") return filepath - + def _convert_graph_to_json(self): """ Convert the networkx graph to a JSON-serializable dictionary. 
- + Returns: Dictionary representation of the graph """ @@ -288,19 +318,22 @@ def _convert_graph_to_json(self): "type": attrs.get("type", ""), "color": attrs.get("color", "#BBBBBB"), } - + # Add file path if available if "file_path" in attrs: node_data["file_path"] = attrs["file_path"] - + # Add other attributes for key, value in attrs.items(): if key not in ["name", "type", "color", "file_path", "original_node"]: - if isinstance(value, (str, int, float, bool, list, dict)) or value is None: + if ( + isinstance(value, str | int | float | bool | list | dict) + or value is None + ): node_data[key] = value - + nodes.append(node_data) - + edges = [] for source, target, attrs in self.graph.edges(data=True): # Create a serializable edge @@ -308,14 +341,17 @@ def _convert_graph_to_json(self): "source": source, "target": target, } - + # Add other attributes for key, value in attrs.items(): - if isinstance(value, (str, int, float, bool, list, dict)) or value is None: + if ( + isinstance(value, str | int | float | bool | list | dict) + or value is None + ): edge_data[key] = value - + edges.append(edge_data) - + return { "nodes": nodes, "edges": edges, @@ -325,18 +361,18 @@ def _convert_graph_to_json(self): "timestamp": datetime.now().isoformat(), "node_count": len(nodes), "edge_count": len(edges), - } + }, } - + def _plot_graph(self): """ Plot the graph using matplotlib. - + Returns: Matplotlib figure """ plt.figure(figsize=(12, 10)) - + # Extract node positions using specified layout algorithm if self.config.layout_algorithm == "spring": pos = nx.spring_layout(self.graph, seed=42) @@ -347,578 +383,653 @@ def _plot_graph(self): else: # Default to spring layout pos = nx.spring_layout(self.graph, seed=42) - + # Extract node colors - node_colors = [attrs.get("color", "#BBBBBB") for _, attrs in self.graph.nodes(data=True)] - + node_colors = [ + attrs.get("color", "#BBBBBB") for _, attrs in self.graph.nodes(data=True) + ] + # Extract node sizes (can be based on some metric) node_sizes = [self.config.node_size_base for _ in self.graph.nodes()] - + # Draw nodes nx.draw_networkx_nodes( - self.graph, pos, - node_color=node_colors, - node_size=node_sizes, - alpha=0.8 + self.graph, pos, node_color=node_colors, node_size=node_sizes, alpha=0.8 ) - + # Draw edges nx.draw_networkx_edges( - self.graph, pos, + self.graph, + pos, width=self.config.edge_width_base, alpha=0.6, arrows=True, - arrowsize=10 + arrowsize=10, ) - + # Draw labels nx.draw_networkx_labels( - self.graph, pos, - labels={node: attrs.get("name", "") for node, attrs in self.graph.nodes(data=True)}, + self.graph, + pos, + labels={ + node: attrs.get("name", "") + for node, attrs in self.graph.nodes(data=True) + }, font_size=8, - font_weight="bold" + font_weight="bold", ) - + plt.title(f"{self.current_visualization_type} - {self.current_entity_name}") plt.axis("off") - + return plt.gcf() - - def visualize_call_graph(self, function_name: str, max_depth: Optional[int] = None): + + def visualize_call_graph(self, function_name: str, max_depth: int | None = None): """ Generate a call graph visualization for a function. 
- + Args: function_name: Name of the function to visualize max_depth: Maximum depth of the call graph (overrides config) - + Returns: Visualization data or path to saved file """ self.current_visualization_type = VisualizationType.CALL_GRAPH self.current_entity_name = function_name - + # Set max depth - current_max_depth = max_depth if max_depth is not None else self.config.max_depth - + current_max_depth = ( + max_depth if max_depth is not None else self.config.max_depth + ) + # Initialize graph self._initialize_graph() - + # Find the function in the codebase function = None for func in self.codebase.functions: if func.name == function_name: function = func break - + if not function: logger.error(f"Function {function_name} not found in codebase") return None - + # Add root node - root_id = self._add_node( + self._add_node( function, name=function_name, color=self.config.color_palette.get("Root"), - is_root=True + is_root=True, ) - + # Recursively add call relationships - visited = set([function]) - + visited = {function} + def add_calls(func, depth=0): if depth >= current_max_depth: return - + # Skip if no function calls attribute if not hasattr(func, "function_calls"): return - + for call in func.function_calls: # Skip recursive calls if call.name == func.name: continue - + # Get the called function called_func = call.function_definition if not called_func: continue - + # Skip external modules if configured - if self.config.ignore_external and hasattr(called_func, "is_external") and called_func.is_external: + if ( + self.config.ignore_external + and hasattr(called_func, "is_external") + and called_func.is_external + ): continue - + # Generate name for display - if hasattr(called_func, "is_method") and called_func.is_method and hasattr(called_func, "parent_class"): + if ( + hasattr(called_func, "is_method") + and called_func.is_method + and hasattr(called_func, "parent_class") + ): called_name = f"{called_func.parent_class.name}.{called_func.name}" else: called_name = called_func.name - + # Add node for called function - called_id = self._add_node( + self._add_node( called_func, name=called_name, color=self.config.color_palette.get("Function"), - file_path=called_func.file.path if hasattr(called_func, "file") and hasattr(called_func.file, "path") else None + file_path=called_func.file.path + if hasattr(called_func, "file") + and hasattr(called_func.file, "path") + else None, ) - + # Add edge for call relationship self._add_edge( function, called_func, type="call", file_path=call.filepath if hasattr(call, "filepath") else None, - line=call.line if hasattr(call, "line") else None + line=call.line if hasattr(call, "line") else None, ) - + # Recursively process called function if isinstance(called_func, Function) and called_func not in visited: visited.add(called_func) add_calls(called_func, depth + 1) - + # Start from the root function add_calls(function) - + # Generate visualization data if self.config.output_format == OutputFormat.JSON: data = self._convert_graph_to_json() - return self._save_visualization(VisualizationType.CALL_GRAPH, function_name, data) + return self._save_visualization( + VisualizationType.CALL_GRAPH, function_name, data + ) else: fig = self._plot_graph() - return self._save_visualization(VisualizationType.CALL_GRAPH, function_name, fig) - - def visualize_dependency_graph(self, symbol_name: str, max_depth: Optional[int] = None): + return self._save_visualization( + VisualizationType.CALL_GRAPH, function_name, fig + ) + + def visualize_dependency_graph( + self, 
symbol_name: str, max_depth: int | None = None + ): """ Generate a dependency graph visualization for a symbol. - + Args: symbol_name: Name of the symbol to visualize max_depth: Maximum depth of the dependency graph (overrides config) - + Returns: Visualization data or path to saved file """ self.current_visualization_type = VisualizationType.DEPENDENCY_GRAPH self.current_entity_name = symbol_name - + # Set max depth - current_max_depth = max_depth if max_depth is not None else self.config.max_depth - + current_max_depth = ( + max_depth if max_depth is not None else self.config.max_depth + ) + # Initialize graph self._initialize_graph() - + # Find the symbol in the codebase symbol = None for sym in self.codebase.symbols: if hasattr(sym, "name") and sym.name == symbol_name: symbol = sym break - + if not symbol: logger.error(f"Symbol {symbol_name} not found in codebase") return None - + # Add root node - root_id = self._add_node( + self._add_node( symbol, name=symbol_name, color=self.config.color_palette.get("Root"), - is_root=True + is_root=True, ) - + # Recursively add dependencies - visited = set([symbol]) - + visited = {symbol} + def add_dependencies(sym, depth=0): if depth >= current_max_depth: return - + # Skip if no dependencies attribute if not hasattr(sym, "dependencies"): return - + for dep in sym.dependencies: dep_symbol = None - + if isinstance(dep, Symbol): dep_symbol = dep elif isinstance(dep, Import) and hasattr(dep, "resolved_symbol"): dep_symbol = dep.resolved_symbol - + if not dep_symbol: continue - + # Skip external modules if configured - if self.config.ignore_external and hasattr(dep_symbol, "is_external") and dep_symbol.is_external: + if ( + self.config.ignore_external + and hasattr(dep_symbol, "is_external") + and dep_symbol.is_external + ): continue - + # Add node for dependency - dep_id = self._add_node( + self._add_node( dep_symbol, - name=dep_symbol.name if hasattr(dep_symbol, "name") else str(dep_symbol), - color=self.config.color_palette.get(dep_symbol.__class__.__name__, "#BBBBBB"), - file_path=dep_symbol.file.path if hasattr(dep_symbol, "file") and hasattr(dep_symbol.file, "path") else None + name=dep_symbol.name + if hasattr(dep_symbol, "name") + else str(dep_symbol), + color=self.config.color_palette.get( + dep_symbol.__class__.__name__, "#BBBBBB" + ), + file_path=dep_symbol.file.path + if hasattr(dep_symbol, "file") and hasattr(dep_symbol.file, "path") + else None, ) - + # Add edge for dependency relationship - self._add_edge( - sym, - dep_symbol, - type="depends_on" - ) - + self._add_edge(sym, dep_symbol, type="depends_on") + # Recursively process dependency if dep_symbol not in visited: visited.add(dep_symbol) add_dependencies(dep_symbol, depth + 1) - + # Start from the root symbol add_dependencies(symbol) - + # Generate visualization data if self.config.output_format == OutputFormat.JSON: data = self._convert_graph_to_json() - return self._save_visualization(VisualizationType.DEPENDENCY_GRAPH, symbol_name, data) + return self._save_visualization( + VisualizationType.DEPENDENCY_GRAPH, symbol_name, data + ) else: fig = self._plot_graph() - return self._save_visualization(VisualizationType.DEPENDENCY_GRAPH, symbol_name, fig) - - def visualize_blast_radius(self, symbol_name: str, max_depth: Optional[int] = None): + return self._save_visualization( + VisualizationType.DEPENDENCY_GRAPH, symbol_name, fig + ) + + def visualize_blast_radius(self, symbol_name: str, max_depth: int | None = None): """ Generate a blast radius visualization for a symbol. 
- + Args: symbol_name: Name of the symbol to visualize max_depth: Maximum depth of the blast radius (overrides config) - + Returns: Visualization data or path to saved file """ self.current_visualization_type = VisualizationType.BLAST_RADIUS self.current_entity_name = symbol_name - + # Set max depth - current_max_depth = max_depth if max_depth is not None else self.config.max_depth - + current_max_depth = ( + max_depth if max_depth is not None else self.config.max_depth + ) + # Initialize graph self._initialize_graph() - + # Find the symbol in the codebase symbol = None for sym in self.codebase.symbols: if hasattr(sym, "name") and sym.name == symbol_name: symbol = sym break - + if not symbol: logger.error(f"Symbol {symbol_name} not found in codebase") return None - + # Add root node - root_id = self._add_node( + self._add_node( symbol, name=symbol_name, color=self.config.color_palette.get("Root"), - is_root=True + is_root=True, ) - + # Recursively add usages (reverse dependencies) - visited = set([symbol]) - + visited = {symbol} + def add_usages(sym, depth=0): if depth >= current_max_depth: return - + # Skip if no usages attribute if not hasattr(sym, "usages"): return - + for usage in sym.usages: # Skip if no usage symbol if not hasattr(usage, "usage_symbol"): continue - + usage_symbol = usage.usage_symbol - + # Skip external modules if configured - if self.config.ignore_external and hasattr(usage_symbol, "is_external") and usage_symbol.is_external: + if ( + self.config.ignore_external + and hasattr(usage_symbol, "is_external") + and usage_symbol.is_external + ): continue - + # Add node for usage - usage_id = self._add_node( + self._add_node( usage_symbol, - name=usage_symbol.name if hasattr(usage_symbol, "name") else str(usage_symbol), - color=self.config.color_palette.get(usage_symbol.__class__.__name__, "#BBBBBB"), - file_path=usage_symbol.file.path if hasattr(usage_symbol, "file") and hasattr(usage_symbol.file, "path") else None + name=usage_symbol.name + if hasattr(usage_symbol, "name") + else str(usage_symbol), + color=self.config.color_palette.get( + usage_symbol.__class__.__name__, "#BBBBBB" + ), + file_path=usage_symbol.file.path + if hasattr(usage_symbol, "file") + and hasattr(usage_symbol.file, "path") + else None, ) - + # Add edge for usage relationship - self._add_edge( - sym, - usage_symbol, - type="used_by" - ) - + self._add_edge(sym, usage_symbol, type="used_by") + # Recursively process usage if usage_symbol not in visited: visited.add(usage_symbol) add_usages(usage_symbol, depth + 1) - + # Start from the root symbol add_usages(symbol) - + # Generate visualization data if self.config.output_format == OutputFormat.JSON: data = self._convert_graph_to_json() - return self._save_visualization(VisualizationType.BLAST_RADIUS, symbol_name, data) + return self._save_visualization( + VisualizationType.BLAST_RADIUS, symbol_name, data + ) else: fig = self._plot_graph() - return self._save_visualization(VisualizationType.BLAST_RADIUS, symbol_name, fig) - + return self._save_visualization( + VisualizationType.BLAST_RADIUS, symbol_name, fig + ) + def visualize_class_methods(self, class_name: str): """ Generate a class methods visualization. 
- + Args: class_name: Name of the class to visualize - + Returns: Visualization data or path to saved file """ self.current_visualization_type = VisualizationType.CLASS_METHODS self.current_entity_name = class_name - + # Initialize graph self._initialize_graph() - + # Find the class in the codebase class_obj = None for cls in self.codebase.classes: if cls.name == class_name: class_obj = cls break - + if not class_obj: logger.error(f"Class {class_name} not found in codebase") return None - + # Add class node - class_id = self._add_node( + self._add_node( class_obj, name=class_name, color=self.config.color_palette.get("Class"), - is_root=True + is_root=True, ) - + # Skip if no methods attribute if not hasattr(class_obj, "methods"): logger.error(f"Class {class_name} has no methods attribute") return None - + # Add method nodes and connections method_ids = {} for method in class_obj.methods: method_name = f"{class_name}.{method.name}" - + # Add method node method_id = self._add_node( method, name=method_name, color=self.config.color_palette.get("Function"), - file_path=method.file.path if hasattr(method, "file") and hasattr(method.file, "path") else None + file_path=method.file.path + if hasattr(method, "file") and hasattr(method.file, "path") + else None, ) - + method_ids[method.name] = method_id - + # Add edge from class to method - self._add_edge( - class_obj, - method, - type="contains" - ) - + self._add_edge(class_obj, method, type="contains") + # Add call relationships between methods for method in class_obj.methods: # Skip if no function calls attribute if not hasattr(method, "function_calls"): continue - + for call in method.function_calls: # Get the called function called_func = call.function_definition if not called_func: continue - + # Only add edges between methods of this class - if hasattr(called_func, "is_method") and called_func.is_method and \ - hasattr(called_func, "parent_class") and called_func.parent_class == class_obj: + if ( + hasattr(called_func, "is_method") + and called_func.is_method + and hasattr(called_func, "parent_class") + and called_func.parent_class == class_obj + ): self._add_edge( method, called_func, type="calls", - line=call.line if hasattr(call, "line") else None + line=call.line if hasattr(call, "line") else None, ) - + # Generate visualization data if self.config.output_format == OutputFormat.JSON: data = self._convert_graph_to_json() - return self._save_visualization(VisualizationType.CLASS_METHODS, class_name, data) + return self._save_visualization( + VisualizationType.CLASS_METHODS, class_name, data + ) else: fig = self._plot_graph() - return self._save_visualization(VisualizationType.CLASS_METHODS, class_name, fig) - + return self._save_visualization( + VisualizationType.CLASS_METHODS, class_name, fig + ) + def visualize_module_dependencies(self, module_path: str): """ Generate a module dependencies visualization. 
- + Args: module_path: Path to the module to visualize - + Returns: Visualization data or path to saved file """ self.current_visualization_type = VisualizationType.MODULE_DEPENDENCIES self.current_entity_name = module_path - + # Initialize graph self._initialize_graph() - + # Get all files in the module module_files = [] for file in self.codebase.files: if hasattr(file, "path") and str(file.path).startswith(module_path): module_files.append(file) - + if not module_files: logger.error(f"No files found in module {module_path}") return None - + # Add file nodes module_node_ids = {} for file in module_files: file_name = str(file.path).split("/")[-1] file_module = "/".join(str(file.path).split("/")[:-1]) - + # Add file node file_id = self._add_node( file, name=file_name, module=file_module, color=self.config.color_palette.get("File"), - file_path=str(file.path) + file_path=str(file.path), ) - + module_node_ids[str(file.path)] = file_id - + # Add import relationships for file in module_files: # Skip if no imports attribute if not hasattr(file, "imports"): continue - + for imp in file.imports: imported_file = None - + # Try to get imported file if hasattr(imp, "resolved_file"): imported_file = imp.resolved_file - elif hasattr(imp, "resolved_symbol") and hasattr(imp.resolved_symbol, "file"): + elif hasattr(imp, "resolved_symbol") and hasattr( + imp.resolved_symbol, "file" + ): imported_file = imp.resolved_symbol.file - + if not imported_file: continue - + # Skip external modules if configured - if self.config.ignore_external and hasattr(imported_file, "is_external") and imported_file.is_external: + if ( + self.config.ignore_external + and hasattr(imported_file, "is_external") + and imported_file.is_external + ): continue - + # Add node for imported file if not already added - imported_path = str(imported_file.path) if hasattr(imported_file, "path") else "" - + imported_path = ( + str(imported_file.path) if hasattr(imported_file, "path") else "" + ) + if imported_path not in module_node_ids: imported_name = imported_path.split("/")[-1] imported_module = "/".join(imported_path.split("/")[:-1]) - + imported_id = self._add_node( imported_file, name=imported_name, module=imported_module, - color=self.config.color_palette.get("External" if imported_path.startswith(module_path) else "File"), - file_path=imported_path + color=self.config.color_palette.get( + "External" + if imported_path.startswith(module_path) + else "File" + ), + file_path=imported_path, ) - + module_node_ids[imported_path] = imported_id - + # Add edge for import relationship self._add_edge( file, imported_file, type="imports", - import_name=imp.name if hasattr(imp, "name") else "" + import_name=imp.name if hasattr(imp, "name") else "", ) - + # Generate visualization data if self.config.output_format == OutputFormat.JSON: data = self._convert_graph_to_json() - return self._save_visualization(VisualizationType.MODULE_DEPENDENCIES, module_path, data) + return self._save_visualization( + VisualizationType.MODULE_DEPENDENCIES, module_path, data + ) else: fig = self._plot_graph() - return self._save_visualization(VisualizationType.MODULE_DEPENDENCIES, module_path, fig) - - def visualize_dead_code(self, path_filter: Optional[str] = None): + return self._save_visualization( + VisualizationType.MODULE_DEPENDENCIES, module_path, fig + ) + + def visualize_dead_code(self, path_filter: str | None = None): """ Generate a visualization of dead (unused) code in the codebase. 
- + Args: path_filter: Optional path to filter files - + Returns: Visualization data or path to saved file """ self.current_visualization_type = VisualizationType.DEAD_CODE self.current_entity_name = path_filter or "codebase" - + # Initialize graph self._initialize_graph() - + # Initialize analyzer if needed if not self.analyzer: logger.info("Initializing analyzer for dead code detection") self.analyzer = CodebaseAnalyzer( codebase=self.codebase, - repo_path=self.context.base_path if hasattr(self.context, "base_path") else None + repo_path=self.context.base_path + if hasattr(self.context, "base_path") + else None, ) - + # Perform analysis if not already done if not hasattr(self.analyzer, "results") or not self.analyzer.results: logger.info("Running code analysis") self.analyzer.analyze(AnalysisType.CODEBASE) - + # Extract dead code information from analysis results if not hasattr(self.analyzer, "results"): logger.error("Analysis results not available") return None - + dead_code = {} - if "static_analysis" in self.analyzer.results and "dead_code" in self.analyzer.results["static_analysis"]: + if ( + "static_analysis" in self.analyzer.results + and "dead_code" in self.analyzer.results["static_analysis"] + ): dead_code = self.analyzer.results["static_analysis"]["dead_code"] - + if not dead_code: logger.warning("No dead code detected in analysis results") return None - + # Create file nodes for containing dead code file_nodes = {} - + # Process unused functions if "unused_functions" in dead_code: for unused_func in dead_code["unused_functions"]: file_path = unused_func.get("file", "") - + # Skip if path filter is specified and doesn't match if path_filter and not file_path.startswith(path_filter): continue - + # Add file node if not already added if file_path not in file_nodes: # Find file in codebase @@ -927,51 +1038,54 @@ def visualize_dead_code(self, path_filter: Optional[str] = None): if hasattr(file, "path") and str(file.path) == file_path: file_obj = file break - + if file_obj: file_name = file_path.split("/")[-1] - file_id = self._add_node( + self._add_node( file_obj, name=file_name, color=self.config.color_palette.get("File"), - file_path=file_path + file_path=file_path, ) - + file_nodes[file_path] = file_obj - + # Add unused function node func_name = unused_func.get("name", "") func_line = unused_func.get("line", None) - + # Create a placeholder for the function (we don't have the actual object) - func_obj = {"name": func_name, "file_path": file_path, "line": func_line, "type": "Function"} - - func_id = self._add_node( + func_obj = { + "name": func_name, + "file_path": file_path, + "line": func_line, + "type": "Function", + } + + self._add_node( func_obj, name=func_name, color=self.config.color_palette.get("Dead"), file_path=file_path, line=func_line, - is_dead=True + is_dead=True, ) - + # Add edge from file to function if file_path in file_nodes: self._add_edge( - file_nodes[file_path], - func_obj, - type="contains_dead" + file_nodes[file_path], func_obj, type="contains_dead" ) - + # Process unused variables if "unused_variables" in dead_code: for unused_var in dead_code["unused_variables"]: file_path = unused_var.get("file", "") - + # Skip if path filter is specified and doesn't match if path_filter and not file_path.startswith(path_filter): continue - + # Add file node if not already added if file_path not in file_nodes: # Find file in codebase @@ -980,272 +1094,320 @@ def visualize_dead_code(self, path_filter: Optional[str] = None): if hasattr(file, "path") and str(file.path) == 
file_path: file_obj = file break - + if file_obj: file_name = file_path.split("/")[-1] - file_id = self._add_node( + self._add_node( file_obj, name=file_name, color=self.config.color_palette.get("File"), - file_path=file_path + file_path=file_path, ) - + file_nodes[file_path] = file_obj - + # Add unused variable node var_name = unused_var.get("name", "") var_line = unused_var.get("line", None) - + # Create a placeholder for the variable - var_obj = {"name": var_name, "file_path": file_path, "line": var_line, "type": "Variable"} - - var_id = self._add_node( + var_obj = { + "name": var_name, + "file_path": file_path, + "line": var_line, + "type": "Variable", + } + + self._add_node( var_obj, name=var_name, color=self.config.color_palette.get("Dead"), file_path=file_path, line=var_line, - is_dead=True + is_dead=True, ) - + # Add edge from file to variable if file_path in file_nodes: - self._add_edge( - file_nodes[file_path], - var_obj, - type="contains_dead" - ) - + self._add_edge(file_nodes[file_path], var_obj, type="contains_dead") + # Generate visualization data if self.config.output_format == OutputFormat.JSON: data = self._convert_graph_to_json() - return self._save_visualization(VisualizationType.DEAD_CODE, self.current_entity_name, data) + return self._save_visualization( + VisualizationType.DEAD_CODE, self.current_entity_name, data + ) else: fig = self._plot_graph() - return self._save_visualization(VisualizationType.DEAD_CODE, self.current_entity_name, fig) - - def visualize_cyclomatic_complexity(self, path_filter: Optional[str] = None): + return self._save_visualization( + VisualizationType.DEAD_CODE, self.current_entity_name, fig + ) + + def visualize_cyclomatic_complexity(self, path_filter: str | None = None): """ Generate a heatmap visualization of cyclomatic complexity. 
- + Args: path_filter: Optional path to filter files - + Returns: Visualization data or path to saved file """ self.current_visualization_type = VisualizationType.CYCLOMATIC_COMPLEXITY self.current_entity_name = path_filter or "codebase" - + # Initialize analyzer if needed if not self.analyzer: logger.info("Initializing analyzer for complexity analysis") self.analyzer = CodebaseAnalyzer( codebase=self.codebase, - repo_path=self.context.base_path if hasattr(self.context, "base_path") else None + repo_path=self.context.base_path + if hasattr(self.context, "base_path") + else None, ) - + # Perform analysis if not already done if not hasattr(self.analyzer, "results") or not self.analyzer.results: logger.info("Running code analysis") self.analyzer.analyze(AnalysisType.CODEBASE) - + # Extract complexity information from analysis results if not hasattr(self.analyzer, "results"): logger.error("Analysis results not available") return None - + complexity_data = {} - if "static_analysis" in self.analyzer.results and "code_complexity" in self.analyzer.results["static_analysis"]: - complexity_data = self.analyzer.results["static_analysis"]["code_complexity"] - + if ( + "static_analysis" in self.analyzer.results + and "code_complexity" in self.analyzer.results["static_analysis"] + ): + complexity_data = self.analyzer.results["static_analysis"][ + "code_complexity" + ] + if not complexity_data: logger.warning("No complexity data found in analysis results") return None - + # Extract function complexities functions = [] if "function_complexity" in complexity_data: for func_data in complexity_data["function_complexity"]: # Skip if path filter is specified and doesn't match - if path_filter and not func_data.get("file", "").startswith(path_filter): + if path_filter and not func_data.get("file", "").startswith( + path_filter + ): continue - + functions.append({ "name": func_data.get("name", ""), "file": func_data.get("file", ""), "complexity": func_data.get("complexity", 1), - "line": func_data.get("line", None) + "line": func_data.get("line", None), }) - + # Sort functions by complexity (descending) functions.sort(key=lambda x: x.get("complexity", 0), reverse=True) - + # Generate heatmap visualization plt.figure(figsize=(12, 10)) - + # Extract data for heatmap - func_names = [f"{func['name']} ({func['file'].split('/')[-1]})" for func in functions[:30]] + func_names = [ + f"{func['name']} ({func['file'].split('/')[-1]})" for func in functions[:30] + ] complexities = [func.get("complexity", 0) for func in functions[:30]] - + # Create horizontal bar chart bars = plt.barh(func_names, complexities) - + # Color bars by complexity norm = plt.Normalize(1, max(10, max(complexities))) - cmap = plt.cm.get_cmap('YlOrRd') - + cmap = plt.cm.get_cmap("YlOrRd") + for i, bar in enumerate(bars): complexity = complexities[i] bar.set_color(cmap(norm(complexity))) - + # Add labels and title - plt.xlabel('Cyclomatic Complexity') - plt.title('Top Functions by Cyclomatic Complexity') - plt.grid(axis='x', linestyle='--', alpha=0.6) - + plt.xlabel("Cyclomatic Complexity") + plt.title("Top Functions by Cyclomatic Complexity") + plt.grid(axis="x", linestyle="--", alpha=0.6) + # Add colorbar - plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), label='Complexity') - + plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), label="Complexity") + # Save and return visualization - return self._save_visualization(VisualizationType.CYCLOMATIC_COMPLEXITY, self.current_entity_name, plt.gcf()) - - def visualize_issues_heatmap(self, 
severity: Optional[IssueSeverity] = None, path_filter: Optional[str] = None): + return self._save_visualization( + VisualizationType.CYCLOMATIC_COMPLEXITY, self.current_entity_name, plt.gcf() + ) + + def visualize_issues_heatmap( + self, + severity: IssueSeverity | None = None, + path_filter: str | None = None, + ): """ Generate a heatmap visualization of issues in the codebase. - + Args: severity: Optional severity level to filter issues path_filter: Optional path to filter files - + Returns: Visualization data or path to saved file """ self.current_visualization_type = VisualizationType.ISSUES_HEATMAP self.current_entity_name = f"{severity.value if severity else 'all'}_issues" - + # Initialize analyzer if needed if not self.analyzer: logger.info("Initializing analyzer for issues analysis") self.analyzer = CodebaseAnalyzer( codebase=self.codebase, - repo_path=self.context.base_path if hasattr(self.context, "base_path") else None + repo_path=self.context.base_path + if hasattr(self.context, "base_path") + else None, ) - + # Perform analysis if not already done if not hasattr(self.analyzer, "results") or not self.analyzer.results: logger.info("Running code analysis") self.analyzer.analyze(AnalysisType.CODEBASE) - + # Extract issues from analysis results - if not hasattr(self.analyzer, "results") or "issues" not in self.analyzer.results: + if ( + not hasattr(self.analyzer, "results") + or "issues" not in self.analyzer.results + ): logger.error("Issues not available in analysis results") return None - + issues = self.analyzer.results["issues"] - + # Filter issues by severity if specified if severity: issues = [issue for issue in issues if issue.get("severity") == severity] - + # Filter issues by path if specified if path_filter: - issues = [issue for issue in issues if issue.get("file", "").startswith(path_filter)] - + issues = [ + issue + for issue in issues + if issue.get("file", "").startswith(path_filter) + ] + if not issues: logger.warning("No issues found matching the criteria") return None - + # Group issues by file file_issues = {} for issue in issues: file_path = issue.get("file", "") if file_path not in file_issues: file_issues[file_path] = [] - + file_issues[file_path].append(issue) - + # Generate heatmap visualization plt.figure(figsize=(12, 10)) - + # Extract data for heatmap files = list(file_issues.keys()) file_names = [file_path.split("/")[-1] for file_path in files] issue_counts = [len(file_issues[file_path]) for file_path in files] - + # Sort by issue count - sorted_data = sorted(zip(file_names, issue_counts, files), key=lambda x: x[1], reverse=True) - file_names, issue_counts, files = zip(*sorted_data) - + sorted_data = sorted( + zip(file_names, issue_counts, files, strict=False), + key=lambda x: x[1], + reverse=True, + ) + file_names, issue_counts, files = zip(*sorted_data, strict=False) + # Create horizontal bar chart bars = plt.barh(file_names[:20], issue_counts[:20]) - + # Color bars by issue count norm = plt.Normalize(1, max(5, max(issue_counts[:20]))) - cmap = plt.cm.get_cmap('OrRd') - + cmap = plt.cm.get_cmap("OrRd") + for i, bar in enumerate(bars): count = issue_counts[i] bar.set_color(cmap(norm(count))) - + # Add labels and title - plt.xlabel('Number of Issues') + plt.xlabel("Number of Issues") severity_text = f" ({severity.value})" if severity else "" - plt.title(f'Files with the Most Issues{severity_text}') - plt.grid(axis='x', linestyle='--', alpha=0.6) - + plt.title(f"Files with the Most Issues{severity_text}") + plt.grid(axis="x", linestyle="--", alpha=0.6) + 
# Add colorbar - plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), label='Issue Count') - + plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), label="Issue Count") + # Save and return visualization - return self._save_visualization(VisualizationType.ISSUES_HEATMAP, self.current_entity_name, plt.gcf()) - + return self._save_visualization( + VisualizationType.ISSUES_HEATMAP, self.current_entity_name, plt.gcf() + ) + def visualize_pr_comparison(self): """ Generate a visualization comparing base branch with PR. - + Returns: Visualization data or path to saved file """ self.current_visualization_type = VisualizationType.PR_COMPARISON - + # Check if analyzer has PR data - if not self.analyzer or not self.analyzer.pr_codebase or not self.analyzer.base_codebase: + if ( + not self.analyzer + or not self.analyzer.pr_codebase + or not self.analyzer.base_codebase + ): logger.error("PR comparison requires analyzer with PR data") return None - - self.current_entity_name = f"pr_{self.analyzer.pr_number}" if self.analyzer.pr_number else "pr_comparison" - + + self.current_entity_name = ( + f"pr_{self.analyzer.pr_number}" + if self.analyzer.pr_number + else "pr_comparison" + ) + # Perform comparison analysis if not already done if not hasattr(self.analyzer, "results") or not self.analyzer.results: logger.info("Running PR comparison analysis") self.analyzer.analyze(AnalysisType.COMPARISON) - + # Extract comparison data from analysis results - if not hasattr(self.analyzer, "results") or "comparison" not in self.analyzer.results: + if ( + not hasattr(self.analyzer, "results") + or "comparison" not in self.analyzer.results + ): logger.error("Comparison data not available in analysis results") return None - + comparison = self.analyzer.results["comparison"] - + # Initialize graph self._initialize_graph() - + # Process symbol comparison data if "symbol_comparison" in comparison: for symbol_data in comparison["symbol_comparison"]: symbol_name = symbol_data.get("name", "") in_base = symbol_data.get("in_base", False) in_pr = symbol_data.get("in_pr", False) - + # Create a placeholder for the symbol symbol_obj = { "name": symbol_name, "in_base": in_base, "in_pr": in_pr, - "type": "Symbol" + "type": "Symbol", } - + # Determine node color based on presence in base and PR if in_base and in_pr: color = "#A5D6A7" # Light green (modified) @@ -1253,230 +1415,199 @@ def visualize_pr_comparison(self): color = "#EF9A9A" # Light red (removed) else: color = "#90CAF9" # Light blue (added) - + # Add node for symbol - symbol_id = self._add_node( + self._add_node( symbol_obj, name=symbol_name, color=color, in_base=in_base, - in_pr=in_pr + in_pr=in_pr, ) - + # Process parameter changes if available if "parameter_changes" in symbol_data: param_changes = symbol_data["parameter_changes"] - + # Process removed parameters for param in param_changes.get("removed", []): param_obj = { "name": param, "change_type": "removed", - "type": "Parameter" + "type": "Parameter", } - - param_id = self._add_node( + + self._add_node( param_obj, name=param, color="#EF9A9A", # Light red (removed) - change_type="removed" + change_type="removed", ) - - self._add_edge( - symbol_obj, - param_obj, - type="removed_parameter" - ) - + + self._add_edge(symbol_obj, param_obj, type="removed_parameter") + # Process added parameters for param in param_changes.get("added", []): param_obj = { "name": param, "change_type": "added", - "type": "Parameter" + "type": "Parameter", } - - param_id = self._add_node( + + self._add_node( param_obj, name=param, 
color="#90CAF9", # Light blue (added) - change_type="added" + change_type="added", ) - - self._add_edge( - symbol_obj, - param_obj, - type="added_parameter" - ) - + + self._add_edge(symbol_obj, param_obj, type="added_parameter") + # Process return type changes if available if "return_type_change" in symbol_data: return_type_change = symbol_data["return_type_change"] old_type = return_type_change.get("old", "None") new_type = return_type_change.get("new", "None") - + return_obj = { "name": f"{old_type} -> {new_type}", "old_type": old_type, "new_type": new_type, - "type": "ReturnType" + "type": "ReturnType", } - - return_id = self._add_node( + + self._add_node( return_obj, name=f"{old_type} -> {new_type}", color="#FFD54F", # Amber (changed) old_type=old_type, - new_type=new_type + new_type=new_type, ) - - self._add_edge( - symbol_obj, - return_obj, - type="return_type_change" - ) - + + self._add_edge(symbol_obj, return_obj, type="return_type_change") + # Process call site issues if available if "call_site_issues" in symbol_data: for issue in symbol_data["call_site_issues"]: issue_file = issue.get("file", "") issue_line = issue.get("line", None) issue_text = issue.get("issue", "") - + # Create a placeholder for the issue issue_obj = { "name": issue_text, "file": issue_file, "line": issue_line, - "type": "Issue" + "type": "Issue", } - - issue_id = self._add_node( + + self._add_node( issue_obj, name=f"{issue_file.split('/')[-1]}:{issue_line}", color="#EF5350", # Red (error) file_path=issue_file, line=issue_line, - issue_text=issue_text - ) - - self._add_edge( - symbol_obj, - issue_obj, - type="call_site_issue" + issue_text=issue_text, ) - + + self._add_edge(symbol_obj, issue_obj, type="call_site_issue") + # Generate visualization data if self.config.output_format == OutputFormat.JSON: data = self._convert_graph_to_json() - return self._save_visualization(VisualizationType.PR_COMPARISON, self.current_entity_name, data) + return self._save_visualization( + VisualizationType.PR_COMPARISON, self.current_entity_name, data + ) else: fig = self._plot_graph() - return self._save_visualization(VisualizationType.PR_COMPARISON, self.current_entity_name, fig) + return self._save_visualization( + VisualizationType.PR_COMPARISON, self.current_entity_name, fig + ) + # Command-line interface def main(): """ Command-line interface for the codebase visualizer. - + This function parses command-line arguments and generates visualizations based on the specified parameters. """ parser = argparse.ArgumentParser( description="Generate visualizations of codebase structure and analysis." 
) - + # Repository options repo_group = parser.add_argument_group("Repository Options") + repo_group.add_argument("--repo-url", help="URL of the repository to analyze") repo_group.add_argument( - "--repo-url", - help="URL of the repository to analyze" - ) - repo_group.add_argument( - "--repo-path", - help="Local path to the repository to analyze" + "--repo-path", help="Local path to the repository to analyze" ) - repo_group.add_argument( - "--language", - help="Programming language of the codebase" - ) - + repo_group.add_argument("--language", help="Programming language of the codebase") + # Visualization options viz_group = parser.add_argument_group("Visualization Options") viz_group.add_argument( "--type", choices=[t.value for t in VisualizationType], required=True, - help="Type of visualization to generate" + help="Type of visualization to generate", ) viz_group.add_argument( - "--entity", - help="Name of the entity to visualize (function, class, file, etc.)" + "--entity", help="Name of the entity to visualize (function, class, file, etc.)" ) viz_group.add_argument( "--max-depth", type=int, default=5, - help="Maximum depth for recursive visualizations" + help="Maximum depth for recursive visualizations", ) viz_group.add_argument( - "--ignore-external", - action="store_true", - help="Ignore external dependencies" + "--ignore-external", action="store_true", help="Ignore external dependencies" ) viz_group.add_argument( "--severity", choices=[s.value for s in IssueSeverity], - help="Filter issues by severity" + help="Filter issues by severity", ) - viz_group.add_argument( - "--path-filter", - help="Filter by file path" - ) - + viz_group.add_argument("--path-filter", help="Filter by file path") + # PR options pr_group = parser.add_argument_group("PR Options") + pr_group.add_argument("--pr-number", type=int, help="PR number to analyze") pr_group.add_argument( - "--pr-number", - type=int, - help="PR number to analyze" - ) - pr_group.add_argument( - "--base-branch", - default="main", - help="Base branch for comparison" + "--base-branch", default="main", help="Base branch for comparison" ) - + # Output options output_group = parser.add_argument_group("Output Options") output_group.add_argument( "--output-format", choices=[f.value for f in OutputFormat], default="json", - help="Output format for the visualization" + help="Output format for the visualization", ) output_group.add_argument( - "--output-directory", - help="Directory to save visualizations" + "--output-directory", help="Directory to save visualizations" ) output_group.add_argument( "--layout", choices=["spring", "kamada_kawai", "spectral"], default="spring", - help="Layout algorithm for graph visualization" + help="Layout algorithm for graph visualization", ) - + args = parser.parse_args() - + # Create visualizer configuration config = VisualizationConfig( max_depth=args.max_depth, ignore_external=args.ignore_external, output_format=OutputFormat(args.output_format), output_directory=args.output_directory, - layout_algorithm=args.layout + layout_algorithm=args.layout, ) - + # Create codebase analyzer if needed for PR comparison analyzer = None if args.type == VisualizationType.PR_COMPARISON.value or args.pr_number: @@ -1485,71 +1616,68 @@ def main(): repo_path=args.repo_path, base_branch=args.base_branch, pr_number=args.pr_number, - language=args.language + language=args.language, ) - + # Create visualizer - visualizer = CodebaseVisualizer( - analyzer=analyzer, - config=config - ) - + visualizer = CodebaseVisualizer(analyzer=analyzer, 
config=config) + # Generate visualization based on type viz_type = VisualizationType(args.type) result = None - + if viz_type == VisualizationType.CALL_GRAPH: if not args.entity: logger.error("Entity name required for call graph visualization") sys.exit(1) - + result = visualizer.visualize_call_graph(args.entity) - + elif viz_type == VisualizationType.DEPENDENCY_GRAPH: if not args.entity: logger.error("Entity name required for dependency graph visualization") sys.exit(1) - + result = visualizer.visualize_dependency_graph(args.entity) - + elif viz_type == VisualizationType.BLAST_RADIUS: if not args.entity: logger.error("Entity name required for blast radius visualization") sys.exit(1) - + result = visualizer.visualize_blast_radius(args.entity) - + elif viz_type == VisualizationType.CLASS_METHODS: if not args.entity: logger.error("Class name required for class methods visualization") sys.exit(1) - + result = visualizer.visualize_class_methods(args.entity) - + elif viz_type == VisualizationType.MODULE_DEPENDENCIES: if not args.entity: logger.error("Module path required for module dependencies visualization") sys.exit(1) - + result = visualizer.visualize_module_dependencies(args.entity) - + elif viz_type == VisualizationType.DEAD_CODE: result = visualizer.visualize_dead_code(args.path_filter) - + elif viz_type == VisualizationType.CYCLOMATIC_COMPLEXITY: result = visualizer.visualize_cyclomatic_complexity(args.path_filter) - + elif viz_type == VisualizationType.ISSUES_HEATMAP: severity = IssueSeverity(args.severity) if args.severity else None result = visualizer.visualize_issues_heatmap(severity, args.path_filter) - + elif viz_type == VisualizationType.PR_COMPARISON: if not args.pr_number: logger.error("PR number required for PR comparison visualization") sys.exit(1) - + result = visualizer.visualize_pr_comparison() - + # Output result if result: logger.info(f"Visualization completed: {result}") @@ -1557,5 +1685,6 @@ def main(): logger.error("Failed to generate visualization") sys.exit(1) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/codegen-on-oss/codegen_on_oss/analyzers/context/__init__.py b/codegen-on-oss/codegen_on_oss/analyzers/context/__init__.py index 497fad744..e8ade95ca 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/context/__init__.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/context/__init__.py @@ -10,7 +10,7 @@ from codegen_on_oss.analyzers.context.function import FunctionContext __all__ = [ - 'CodebaseContext', - 'FileContext', - 'FunctionContext', -] \ No newline at end of file + "CodebaseContext", + "FileContext", + "FunctionContext", +] diff --git a/codegen-on-oss/codegen_on_oss/analyzers/context/codebase.py b/codegen-on-oss/codegen_on_oss/analyzers/context/codebase.py index 51e98c64e..3ce03415c 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/context/codebase.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/context/codebase.py @@ -7,22 +7,23 @@ code structure visualization, and PR comparison. 
""" -import os -import sys import logging -import networkx as nx -from typing import Dict, List, Set, Tuple, Any, Optional, Union, Callable, TypeVar, cast +import sys from enum import Enum -from pathlib import Path +from typing import Any + +import networkx as nx try: + from codegen.sdk.codebase.codebase_context import ( + CodebaseContext as SDKCodebaseContext, + ) + from codegen.sdk.core.class_definition import Class from codegen.sdk.core.codebase import Codebase - from codegen.sdk.codebase.codebase_context import CodebaseContext as SDKCodebaseContext - from codegen.sdk.core.file import SourceFile from codegen.sdk.core.directory import Directory - from codegen.sdk.core.symbol import Symbol + from codegen.sdk.core.file import SourceFile from codegen.sdk.core.function import Function - from codegen.sdk.core.class_definition import Class + from codegen.sdk.core.symbol import Symbol from codegen.sdk.enums import EdgeType, SymbolType except ImportError: print("Codegen SDK not found. Please install it first.") @@ -31,8 +32,8 @@ # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) @@ -53,8 +54,10 @@ ".vscode", ] + class NodeType(str, Enum): """Types of nodes in the graph.""" + FILE = "file" DIRECTORY = "directory" FUNCTION = "function" @@ -63,6 +66,7 @@ class NodeType(str, Enum): VARIABLE = "variable" UNKNOWN = "unknown" + def get_node_type(node: Any) -> NodeType: """Determine the type of a node.""" if isinstance(node, SourceFile): @@ -76,6 +80,7 @@ def get_node_type(node: Any) -> NodeType: else: return NodeType.UNKNOWN + def get_node_classes(): """Get a dictionary mapping node types to their classes.""" return { @@ -85,27 +90,28 @@ def get_node_classes(): NodeType.CLASS: Class, } + class CodebaseContext: """ Graph-based representation of a codebase for advanced analysis. - + This class provides a graph representation of a codebase, including files, directories, functions, classes, and their relationships. It supports advanced analysis capabilities such as dependency analysis, code structure visualization, and PR comparison. """ - + def __init__( self, codebase: Codebase, - base_path: Optional[str] = None, - pr_branch: Optional[str] = None, + base_path: str | None = None, + pr_branch: str | None = None, base_branch: str = "main", - file_ignore_list: Optional[List[str]] = None + file_ignore_list: list[str] | None = None, ): """ Initialize the CodebaseContext. 
- + Args: codebase: The codebase to analyze base_path: Base path of the codebase @@ -118,269 +124,292 @@ def __init__( self.pr_branch = pr_branch self.base_branch = base_branch self.file_ignore_list = file_ignore_list or GLOBAL_FILE_IGNORE_LIST - + # Initialize graph self._graph = nx.DiGraph() - + # Build the graph self._build_graph() - + def _build_graph(self): """Build the codebase graph.""" logger.info("Building codebase graph...") - + # Add nodes for files for file in self.codebase.files: # Skip ignored files if self._should_ignore_file(file): continue - + # Add file node - self._graph.add_node(file, - type=NodeType.FILE, - path=file.file_path if hasattr(file, 'file_path') else str(file)) - + self._graph.add_node( + file, + type=NodeType.FILE, + path=file.file_path if hasattr(file, "file_path") else str(file), + ) + # Add nodes for functions in the file - if hasattr(file, 'functions'): + if hasattr(file, "functions"): for func in file.functions: - self._graph.add_node(func, - type=NodeType.FUNCTION, - name=func.name if hasattr(func, 'name') else str(func), - file=file) - + self._graph.add_node( + func, + type=NodeType.FUNCTION, + name=func.name if hasattr(func, "name") else str(func), + file=file, + ) + # Add edge from file to function self._graph.add_edge(file, func, type=EdgeType.CONTAINS) - + # Add nodes for classes in the file - if hasattr(file, 'classes'): + if hasattr(file, "classes"): for cls in file.classes: - self._graph.add_node(cls, - type=NodeType.CLASS, - name=cls.name if hasattr(cls, 'name') else str(cls), - file=file) - + self._graph.add_node( + cls, + type=NodeType.CLASS, + name=cls.name if hasattr(cls, "name") else str(cls), + file=file, + ) + # Add edge from file to class self._graph.add_edge(file, cls, type=EdgeType.CONTAINS) - + # Add nodes for methods in the class - if hasattr(cls, 'methods'): + if hasattr(cls, "methods"): for method in cls.methods: - self._graph.add_node(method, - type=NodeType.FUNCTION, - name=method.name if hasattr(method, 'name') else str(method), - file=file, - class_name=cls.name if hasattr(cls, 'name') else str(cls)) - + self._graph.add_node( + method, + type=NodeType.FUNCTION, + name=method.name + if hasattr(method, "name") + else str(method), + file=file, + class_name=cls.name + if hasattr(cls, "name") + else str(cls), + ) + # Add edge from class to method self._graph.add_edge(cls, method, type=EdgeType.CONTAINS) - + # Add edges for imports for file in self.codebase.files: # Skip ignored files if self._should_ignore_file(file): continue - + # Add import edges - if hasattr(file, 'imports'): + if hasattr(file, "imports"): for imp in file.imports: # Get imported file imported_file = None - - if hasattr(imp, 'resolved_file'): + + if hasattr(imp, "resolved_file"): imported_file = imp.resolved_file - elif hasattr(imp, 'resolved_symbol') and hasattr(imp.resolved_symbol, 'file'): + elif hasattr(imp, "resolved_symbol") and hasattr( + imp.resolved_symbol, "file" + ): imported_file = imp.resolved_symbol.file - + if imported_file and imported_file in self._graph: # Add edge from file to imported file self._graph.add_edge(file, imported_file, type=EdgeType.IMPORTS) - + # Add edges for function calls - for func in [n for n in self._graph.nodes if get_node_type(n) == NodeType.FUNCTION]: - if hasattr(func, 'call_sites'): + for func in [ + n for n in self._graph.nodes if get_node_type(n) == NodeType.FUNCTION + ]: + if hasattr(func, "call_sites"): for call_site in func.call_sites: - if hasattr(call_site, 'called_function') and call_site.called_function in 
self._graph: + if ( + hasattr(call_site, "called_function") + and call_site.called_function in self._graph + ): # Add edge from function to called function - self._graph.add_edge(func, call_site.called_function, type=EdgeType.CALLS) - + self._graph.add_edge( + func, call_site.called_function, type=EdgeType.CALLS + ) + # Add edges for class inheritance for cls in [n for n in self._graph.nodes if get_node_type(n) == NodeType.CLASS]: - if hasattr(cls, 'superclasses'): + if hasattr(cls, "superclasses"): for superclass in cls.superclasses: if superclass in self._graph: # Add edge from class to superclass - self._graph.add_edge(cls, superclass, type=EdgeType.INHERITS_FROM) - - logger.info(f"Graph built with {len(self._graph.nodes)} nodes and {len(self._graph.edges)} edges") - + self._graph.add_edge( + cls, superclass, type=EdgeType.INHERITS_FROM + ) + + logger.info( + f"Graph built with {len(self._graph.nodes)} nodes and {len(self._graph.edges)} edges" + ) + def _should_ignore_file(self, file) -> bool: """Check if a file should be ignored.""" - if hasattr(file, 'is_binary') and file.is_binary: + if hasattr(file, "is_binary") and file.is_binary: return True - - file_path = file.file_path if hasattr(file, 'file_path') else str(file) - + + file_path = file.file_path if hasattr(file, "file_path") else str(file) + # Check against ignore list - for pattern in self.file_ignore_list: - if pattern in file_path: - return True - - return False - + return any(pattern in file_path for pattern in self.file_ignore_list) + @property def graph(self) -> nx.DiGraph: """Get the codebase graph.""" return self._graph - + @property - def nodes(self) -> List[Any]: + def nodes(self) -> list[Any]: """Get all nodes in the graph.""" return list(self._graph.nodes) - - def get_node(self, name: str) -> Optional[Any]: + + def get_node(self, name: str) -> Any | None: """ Get a node by name. - + Args: name: Name of the node to get - + Returns: The node, or None if not found """ for node in self._graph.nodes: - if (hasattr(node, 'name') and node.name == name) or str(node) == name: + if (hasattr(node, "name") and node.name == name) or str(node) == name: return node return None - - def predecessors(self, node: Any) -> List[Any]: + + def predecessors(self, node: Any) -> list[Any]: """ Get predecessors of a node. - + Args: node: Node to get predecessors for - + Returns: List of predecessor nodes """ return list(self._graph.predecessors(node)) - - def successors(self, node: Any) -> List[Any]: + + def successors(self, node: Any) -> list[Any]: """ Get successors of a node. - + Args: node: Node to get successors for - + Returns: List of successor nodes """ return list(self._graph.successors(node)) - - def in_edges(self, node: Any, data: bool = False) -> List[Any]: + + def in_edges(self, node: Any, data: bool = False) -> list[Any]: """ Get incoming edges of a node. - + Args: node: Node to get edges for data: Whether to include edge data - + Returns: List of incoming edges """ return list(self._graph.in_edges(node, data=data)) - - def out_edges(self, node: Any, data: bool = False) -> List[Any]: + + def out_edges(self, node: Any, data: bool = False) -> list[Any]: """ Get outgoing edges of a node. - + Args: node: Node to get edges for data: Whether to include edge data - + Returns: List of outgoing edges """ return list(self._graph.out_edges(node, data=data)) - - def edges(self, data: bool = False) -> List[Any]: + + def edges(self, data: bool = False) -> list[Any]: """ Get all edges in the graph. 
- + Args: data: Whether to include edge data - + Returns: List of edges """ return list(self._graph.edges(data=data)) - - def get_nodes_by_type(self, node_type: NodeType) -> List[Any]: + + def get_nodes_by_type(self, node_type: NodeType) -> list[Any]: """ Get nodes by type. - + Args: node_type: Type of nodes to get - + Returns: List of nodes of the specified type """ return [n for n in self._graph.nodes if get_node_type(n) == node_type] - - def build_subgraph(self, nodes: List[Any]) -> nx.DiGraph: + + def build_subgraph(self, nodes: list[Any]) -> nx.DiGraph: """ Build a subgraph from a list of nodes. - + Args: nodes: List of nodes to include in the subgraph - + Returns: Subgraph containing the specified nodes """ return self._graph.subgraph(nodes) - - def find_paths(self, source: Any, target: Any, cutoff: Optional[int] = None) -> List[List[Any]]: + + def find_paths( + self, source: Any, target: Any, cutoff: int | None = None + ) -> list[list[Any]]: """ Find all paths between two nodes. - + Args: source: Source node target: Target node cutoff: Maximum path length - + Returns: List of paths from source to target """ if source not in self._graph or target not in self._graph: return [] - + try: return list(nx.all_simple_paths(self._graph, source, target, cutoff=cutoff)) except nx.NetworkXError: return [] - - def find_shortest_path(self, source: Any, target: Any) -> Optional[List[Any]]: + + def find_shortest_path(self, source: Any, target: Any) -> list[Any] | None: """ Find the shortest path between two nodes. - + Args: source: Source node target: Target node - + Returns: Shortest path from source to target, or None if no path exists """ if source not in self._graph or target not in self._graph: return None - + try: return nx.shortest_path(self._graph, source, target) except nx.NetworkXNoPath: return None - - def find_cycles(self) -> List[List[Any]]: + + def find_cycles(self) -> list[list[Any]]: """ Find cycles in the graph. - + Returns: List of cycles in the graph """ @@ -388,47 +417,47 @@ def find_cycles(self) -> List[List[Any]]: return list(nx.simple_cycles(self._graph)) except nx.NetworkXNoCycle: return [] - - def get_files(self) -> List[SourceFile]: + + def get_files(self) -> list[SourceFile]: """ Get all files in the codebase. - + Returns: List of files """ return self.get_nodes_by_type(NodeType.FILE) - - def get_functions(self) -> List[Function]: + + def get_functions(self) -> list[Function]: """ Get all functions in the codebase. - + Returns: List of functions """ return self.get_nodes_by_type(NodeType.FUNCTION) - - def get_classes(self) -> List[Class]: + + def get_classes(self) -> list[Class]: """ Get all classes in the codebase. - + Returns: List of classes """ return self.get_nodes_by_type(NodeType.CLASS) - + def export_to_networkx(self) -> nx.DiGraph: """ Export the graph to a NetworkX graph. - + Returns: NetworkX graph representation of the codebase """ return self._graph.copy() - - def export_to_dict(self) -> Dict[str, Any]: + + def export_to_dict(self) -> dict[str, Any]: """ Export the graph to a dictionary. 
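The graph query helpers defined above (`get_files`, `get_functions`, `find_cycles`, `get_node`, `find_shortest_path`) are meant to be called on a built context. The sketch below shows that pattern under stated assumptions: the class name `CodebaseContext`, the import path, and `Codebase.from_repo(...)` are illustrative assumptions, not something this diff confirms.

```python
# Hedged usage sketch -- class name, import path, and Codebase.from_repo are assumed.
from codegen.sdk.core.codebase import Codebase

from codegen_on_oss.analyzers.context.codebase import CodebaseContext  # assumed path

codebase = Codebase.from_repo("owner/repo")  # assumed constructor
ctx = CodebaseContext(codebase)

# Inventory of the graph built in _build_graph().
print(len(ctx.get_files()), "files,", len(ctx.get_functions()), "functions")

# Cycles across the CONTAINS/IMPORTS/CALLS/INHERITS_FROM edges.
for cycle in ctx.find_cycles():
    print("cycle:", [getattr(node, "name", str(node)) for node in cycle])

# Trace a call/import path between two symbols (names are hypothetical).
src = ctx.get_node("parse_config")
dst = ctx.get_node("load_file")
if src is not None and dst is not None:
    print(ctx.find_shortest_path(src, dst))
```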
- + Returns: Dictionary representation of the codebase graph """ @@ -438,28 +467,29 @@ def export_to_dict(self) -> Dict[str, Any]: "id": str(id(node)), "type": get_node_type(node).value, } - - if hasattr(node, 'name'): + + if hasattr(node, "name"): node_data["name"] = node.name - - if hasattr(node, 'file') and hasattr(node.file, 'file_path'): + + if hasattr(node, "file") and hasattr(node.file, "file_path"): node_data["file"] = node.file.file_path - + nodes.append(node_data) - + edges = [] for source, target, data in self._graph.edges(data=True): edge_data = { "source": str(id(source)), "target": str(id(target)), } - + if "type" in data: - edge_data["type"] = data["type"].value if isinstance(data["type"], Enum) else str(data["type"]) - + edge_data["type"] = ( + data["type"].value + if isinstance(data["type"], Enum) + else str(data["type"]) + ) + edges.append(edge_data) - - return { - "nodes": nodes, - "edges": edges - } \ No newline at end of file + + return {"nodes": nodes, "edges": edges} diff --git a/codegen-on-oss/codegen_on_oss/analyzers/context/file.py b/codegen-on-oss/codegen_on_oss/analyzers/context/file.py index 191573b95..eeefcdd98 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/context/file.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/context/file.py @@ -6,17 +6,15 @@ including structure, imports, exports, and symbols within a file. """ -import os -import sys import logging -from typing import Dict, List, Set, Tuple, Any, Optional, Union, cast -from pathlib import Path +import sys +from typing import Any try: + from codegen.sdk.core.class_definition import Class from codegen.sdk.core.file import SourceFile - from codegen.sdk.core.symbol import Symbol from codegen.sdk.core.function import Function - from codegen.sdk.core.class_definition import Class + from codegen.sdk.core.symbol import Symbol from codegen.sdk.enums import EdgeType, SymbolType except ImportError: print("Codegen SDK not found. Please install it first.") @@ -25,161 +23,166 @@ # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) + class FileContext: """ Context for file-level analysis. - + This class provides specialized analysis capabilities for a single file, including structure analysis, import/export analysis, and symbol analysis. """ - + def __init__(self, file: SourceFile): """ Initialize the FileContext. 
- + Args: file: The file to analyze """ self.file = file - self.path = file.file_path if hasattr(file, 'file_path') else str(file) - self.functions = list(file.functions) if hasattr(file, 'functions') else [] - self.classes = list(file.classes) if hasattr(file, 'classes') else [] - self.imports = list(file.imports) if hasattr(file, 'imports') else [] - self.exports = list(file.exports) if hasattr(file, 'exports') else [] - + self.path = file.file_path if hasattr(file, "file_path") else str(file) + self.functions = list(file.functions) if hasattr(file, "functions") else [] + self.classes = list(file.classes) if hasattr(file, "classes") else [] + self.imports = list(file.imports) if hasattr(file, "imports") else [] + self.exports = list(file.exports) if hasattr(file, "exports") else [] + # Collect symbols - self.symbols: List[Symbol] = [] + self.symbols: list[Symbol] = [] self.symbols.extend(self.functions) self.symbols.extend(self.classes) - + # Add symbols from file.symbols if available - if hasattr(file, 'symbols'): + if hasattr(file, "symbols"): for symbol in file.symbols: if symbol not in self.symbols: self.symbols.append(symbol) - - def get_symbol(self, name: str) -> Optional[Symbol]: + + def get_symbol(self, name: str) -> Symbol | None: """ Get a symbol by name. - + Args: name: Name of the symbol to get - + Returns: The symbol, or None if not found """ for symbol in self.symbols: - if hasattr(symbol, 'name') and symbol.name == name: + if hasattr(symbol, "name") and symbol.name == name: return symbol return None - - def get_function(self, name: str) -> Optional[Function]: + + def get_function(self, name: str) -> Function | None: """ Get a function by name. - + Args: name: Name of the function to get - + Returns: The function, or None if not found """ for func in self.functions: - if hasattr(func, 'name') and func.name == name: + if hasattr(func, "name") and func.name == name: return func return None - - def get_class(self, name: str) -> Optional[Class]: + + def get_class(self, name: str) -> Class | None: """ Get a class by name. - + Args: name: Name of the class to get - + Returns: The class, or None if not found """ for cls in self.classes: - if hasattr(cls, 'name') and cls.name == name: + if hasattr(cls, "name") and cls.name == name: return cls return None - - def get_import(self, name: str) -> Optional[Any]: + + def get_import(self, name: str) -> Any | None: """ Get an import by name. - + Args: name: Name of the import to get - + Returns: The import, or None if not found """ for imp in self.imports: - if hasattr(imp, 'name') and imp.name == name: + if hasattr(imp, "name") and imp.name == name: return imp return None - - def get_export(self, name: str) -> Optional[Any]: + + def get_export(self, name: str) -> Any | None: """ Get an export by name. - + Args: name: Name of the export to get - + Returns: The export, or None if not found """ for exp in self.exports: - if hasattr(exp, 'name') and exp.name == name: + if hasattr(exp, "name") and exp.name == name: return exp return None - - def get_symbols_by_type(self, symbol_type: SymbolType) -> List[Symbol]: + + def get_symbols_by_type(self, symbol_type: SymbolType) -> list[Symbol]: """ Get symbols by type. 
- + Args: symbol_type: Type of symbols to get - + Returns: List of symbols of the specified type """ - return [s for s in self.symbols if hasattr(s, 'symbol_type') and s.symbol_type == symbol_type] - - def get_imported_modules(self) -> List[str]: + return [ + s + for s in self.symbols + if hasattr(s, "symbol_type") and s.symbol_type == symbol_type + ] + + def get_imported_modules(self) -> list[str]: """ Get imported module names. - + Returns: List of imported module names """ modules = [] for imp in self.imports: - if hasattr(imp, 'module_name'): + if hasattr(imp, "module_name"): modules.append(imp.module_name) return modules - - def get_exported_symbols(self) -> List[str]: + + def get_exported_symbols(self) -> list[str]: """ Get exported symbol names. - + Returns: List of exported symbol names """ symbols = [] for exp in self.exports: - if hasattr(exp, 'name'): + if hasattr(exp, "name"): symbols.append(exp.name) return symbols - - def analyze_complexity(self) -> Dict[str, Any]: + + def analyze_complexity(self) -> dict[str, Any]: """ Analyze code complexity in the file. - + Returns: Dictionary containing complexity metrics """ @@ -187,70 +190,74 @@ def analyze_complexity(self) -> Dict[str, Any]: "functions": {}, "average_complexity": 0, "max_complexity": 0, - "total_complexity": 0 + "total_complexity": 0, } - + total_complexity = 0 max_complexity = 0 function_count = 0 - + for func in self.functions: # Calculate cyclomatic complexity complexity = self._calculate_cyclomatic_complexity(func) - + # Update metrics total_complexity += complexity max_complexity = max(max_complexity, complexity) function_count += 1 - + # Add function metrics - func_name = func.name if hasattr(func, 'name') else str(func) + func_name = func.name if hasattr(func, "name") else str(func) result["functions"][func_name] = { "complexity": complexity, - "line_count": len(func.source.split('\n')) if hasattr(func, 'source') else 0 + "line_count": len(func.source.split("\n")) + if hasattr(func, "source") + else 0, } - + # Update summary metrics - result["average_complexity"] = total_complexity / function_count if function_count > 0 else 0 + result["average_complexity"] = ( + total_complexity / function_count if function_count > 0 else 0 + ) result["max_complexity"] = max_complexity result["total_complexity"] = total_complexity - + return result - + def _calculate_cyclomatic_complexity(self, function) -> int: """ Calculate cyclomatic complexity for a function. 
- + Args: function: Function to analyze - + Returns: Cyclomatic complexity score """ complexity = 1 # Base complexity - - if not hasattr(function, 'source'): + + if not hasattr(function, "source"): return complexity - + source = function.source - + # Count branching statements - complexity += source.count('if ') - complexity += source.count('elif ') - complexity += source.count('for ') - complexity += source.count('while ') - complexity += source.count('except:') - complexity += source.count('except ') - complexity += source.count(' and ') - complexity += source.count(' or ') - complexity += source.count('case ') - + complexity += source.count("if ") + complexity += source.count("elif ") + complexity += source.count("for ") + complexity += source.count("while ") + complexity += source.count("except:") + complexity += source.count("except ") + complexity += source.count(" and ") + complexity += source.count(" or ") + complexity += source.count("case ") + return complexity - - def analyze_imports(self) -> Dict[str, Any]: + + def analyze_imports(self) -> dict[str, Any]: """ Analyze imports in the file. - + Returns: Dictionary containing import analysis """ @@ -260,43 +267,42 @@ def analyze_imports(self) -> Dict[str, Any]: "unresolved_imports": 0, "external_imports": 0, "internal_imports": 0, - "import_details": [] + "import_details": [], } - + for imp in self.imports: import_info = { - "name": imp.name if hasattr(imp, 'name') else str(imp), - "module": imp.module_name if hasattr(imp, 'module_name') else "unknown", + "name": imp.name if hasattr(imp, "name") else str(imp), + "module": imp.module_name if hasattr(imp, "module_name") else "unknown", "is_resolved": False, - "is_external": False + "is_external": False, } - + # Check if import is resolved - if hasattr(imp, 'resolved_file') and imp.resolved_file: - import_info["is_resolved"] = True - result["resolved_imports"] += 1 - elif hasattr(imp, 'resolved_symbol') and imp.resolved_symbol: + if (hasattr(imp, "resolved_file") and imp.resolved_file) or ( + hasattr(imp, "resolved_symbol") and imp.resolved_symbol + ): import_info["is_resolved"] = True result["resolved_imports"] += 1 else: result["unresolved_imports"] += 1 - + # Check if import is external - if hasattr(imp, 'is_external'): + if hasattr(imp, "is_external"): import_info["is_external"] = imp.is_external if imp.is_external: result["external_imports"] += 1 else: result["internal_imports"] += 1 - + result["import_details"].append(import_info) - + return result - - def analyze_structure(self) -> Dict[str, Any]: + + def analyze_structure(self) -> dict[str, Any]: """ Analyze file structure. - + Returns: Dictionary containing structure analysis """ @@ -306,27 +312,38 @@ def analyze_structure(self) -> Dict[str, Any]: "function_count": len(self.functions), "class_count": len(self.classes), "import_count": len(self.imports), - "export_count": len(self.exports) + "export_count": len(self.exports), } - + # Count lines of code - if hasattr(self.file, 'content'): - result["line_count"] = len(self.file.content.split('\n')) - + if hasattr(self.file, "content"): + result["line_count"] = len(self.file.content.split("\n")) + return result - - def to_dict(self) -> Dict[str, Any]: + + def to_dict(self) -> dict[str, Any]: """ Convert the file context to a dictionary. 
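`_calculate_cyclomatic_complexity` above is a plain-text heuristic rather than an AST-based metric: it starts from a base complexity of 1 and adds one per occurrence of each branching marker in the source string. Re-running the same counting logic on a small snippet makes the behaviour concrete, including the quirk that `"elif "` is also counted by the `"if "` pattern:

```python
# Standalone rerun of the keyword-counting heuristic used above; no SDK objects needed.
SOURCE = '''
def classify(x):
    if x > 10 and x < 100:
        return "mid"
    elif x >= 100:
        return "high"
    for _ in range(3):
        pass
    return "low"
'''

complexity = 1  # base complexity, as in the method above
for marker in ("if ", "elif ", "for ", "while ",
               "except:", "except ", " and ", " or ", "case "):
    complexity += SOURCE.count(marker)

# "if " matches twice (the second hit is inside "elif "), plus "elif ",
# "for ", and " and " -- so the heuristic reports 6 for this function.
print(complexity)  # 6
```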
- + Returns: Dictionary representation of the file context """ return { "path": self.path, - "functions": [func.name if hasattr(func, 'name') else str(func) for func in self.functions], - "classes": [cls.name if hasattr(cls, 'name') else str(cls) for cls in self.classes], - "imports": [imp.name if hasattr(imp, 'name') else str(imp) for imp in self.imports], - "exports": [exp.name if hasattr(exp, 'name') else str(exp) for exp in self.exports], - "symbols": [sym.name if hasattr(sym, 'name') else str(sym) for sym in self.symbols] - } \ No newline at end of file + "functions": [ + func.name if hasattr(func, "name") else str(func) + for func in self.functions + ], + "classes": [ + cls.name if hasattr(cls, "name") else str(cls) for cls in self.classes + ], + "imports": [ + imp.name if hasattr(imp, "name") else str(imp) for imp in self.imports + ], + "exports": [ + exp.name if hasattr(exp, "name") else str(exp) for exp in self.exports + ], + "symbols": [ + sym.name if hasattr(sym, "name") else str(sym) for sym in self.symbols + ], + } diff --git a/codegen-on-oss/codegen_on_oss/analyzers/context/function.py b/codegen-on-oss/codegen_on_oss/analyzers/context/function.py index 26b453f62..2a74e2a28 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/context/function.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/context/function.py @@ -6,12 +6,9 @@ including parameters, return types, complexity, and call relationships. """ -import os -import sys import logging -import re -from typing import Dict, List, Set, Tuple, Any, Optional, Union, cast -from pathlib import Path +import sys +from typing import Any try: from codegen.sdk.core.function import Function @@ -24,89 +21,102 @@ # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) + class FunctionContext: """ Context for function-level analysis. - + This class provides specialized analysis capabilities for a single function, including parameter analysis, return type analysis, complexity analysis, and call relationship analysis. """ - + def __init__(self, function: Function): """ Initialize the FunctionContext. 
- + Args: function: The function to analyze """ self.function = function - self.name = function.name if hasattr(function, 'name') else str(function) - self.file = function.file if hasattr(function, 'file') else None - self.file_path = function.file.file_path if hasattr(function, 'file') and hasattr(function.file, 'file_path') else "unknown" - self.line = function.line if hasattr(function, 'line') else None - self.parameters = list(function.parameters) if hasattr(function, 'parameters') else [] - self.return_type = function.return_type if hasattr(function, 'return_type') else None - self.is_async = function.is_async if hasattr(function, 'is_async') else False - self.source = function.source if hasattr(function, 'source') else "" - self.call_sites = list(function.call_sites) if hasattr(function, 'call_sites') else [] + self.name = function.name if hasattr(function, "name") else str(function) + self.file = function.file if hasattr(function, "file") else None + self.file_path = ( + function.file.file_path + if hasattr(function, "file") and hasattr(function.file, "file_path") + else "unknown" + ) + self.line = function.line if hasattr(function, "line") else None + self.parameters = ( + list(function.parameters) if hasattr(function, "parameters") else [] + ) + self.return_type = ( + function.return_type if hasattr(function, "return_type") else None + ) + self.is_async = function.is_async if hasattr(function, "is_async") else False + self.source = function.source if hasattr(function, "source") else "" + self.call_sites = ( + list(function.call_sites) if hasattr(function, "call_sites") else [] + ) self.locals = [] - + # Extract local variables if available - if hasattr(function, 'code_block') and hasattr(function.code_block, 'local_var_assignments'): + if hasattr(function, "code_block") and hasattr( + function.code_block, "local_var_assignments" + ): self.locals = list(function.code_block.local_var_assignments) - - def get_parameter(self, name: str) -> Optional[Any]: + + def get_parameter(self, name: str) -> Any | None: """ Get a parameter by name. - + Args: name: Name of the parameter to get - + Returns: The parameter, or None if not found """ for param in self.parameters: - if hasattr(param, 'name') and param.name == name: + if hasattr(param, "name") and param.name == name: return param return None - - def get_parameter_types(self) -> Dict[str, Any]: + + def get_parameter_types(self) -> dict[str, Any]: """ Get parameter types. - + Returns: Dictionary mapping parameter names to types """ result = {} for param in self.parameters: - if hasattr(param, 'name'): - param_type = param.type if hasattr(param, 'type') else None + if hasattr(param, "name"): + param_type = param.type if hasattr(param, "type") else None result[param.name] = str(param_type) if param_type else None return result - - def get_called_functions(self) -> List[Any]: + + def get_called_functions(self) -> list[Any]: """ Get functions called by this function. - + Returns: List of called functions """ result = [] for call_site in self.call_sites: - if hasattr(call_site, 'called_function'): + if hasattr(call_site, "called_function"): result.append(call_site.called_function) return result - - def analyze_complexity(self) -> Dict[str, Any]: + + def analyze_complexity(self) -> dict[str, Any]: """ Analyze function complexity. 
- + Returns: Dictionary containing complexity metrics """ @@ -115,65 +125,65 @@ def analyze_complexity(self) -> Dict[str, Any]: "file": self.file_path, "line": self.line, "cyclomatic_complexity": self._calculate_cyclomatic_complexity(), - "line_count": len(self.source.split('\n')) if self.source else 0, + "line_count": len(self.source.split("\n")) if self.source else 0, "parameter_count": len(self.parameters), - "nesting_depth": self._calculate_nesting_depth() + "nesting_depth": self._calculate_nesting_depth(), } - + return result - + def _calculate_cyclomatic_complexity(self) -> int: """ Calculate cyclomatic complexity of the function. - + Returns: Cyclomatic complexity score """ if not self.source: return 1 - + complexity = 1 # Base complexity - + # Count branching statements - complexity += self.source.count('if ') - complexity += self.source.count('elif ') - complexity += self.source.count('for ') - complexity += self.source.count('while ') - complexity += self.source.count('except:') - complexity += self.source.count('except ') - complexity += self.source.count(' and ') - complexity += self.source.count(' or ') - complexity += self.source.count('case ') - + complexity += self.source.count("if ") + complexity += self.source.count("elif ") + complexity += self.source.count("for ") + complexity += self.source.count("while ") + complexity += self.source.count("except:") + complexity += self.source.count("except ") + complexity += self.source.count(" and ") + complexity += self.source.count(" or ") + complexity += self.source.count("case ") + return complexity - + def _calculate_nesting_depth(self) -> int: """ Calculate the maximum nesting depth of the function. - + Returns: Maximum nesting depth """ if not self.source: return 0 - - lines = self.source.split('\n') + + lines = self.source.split("\n") max_indent = 0 - + for line in lines: if line.strip(): # Skip empty lines indent = len(line) - len(line.lstrip()) max_indent = max(max_indent, indent) - + # Estimate nesting depth (rough approximation) est_nesting_depth = max_indent // 4 # Assuming 4 spaces per indent level - + return est_nesting_depth - - def analyze_parameters(self) -> Dict[str, Any]: + + def analyze_parameters(self) -> dict[str, Any]: """ Analyze function parameters. 
- + Returns: Dictionary containing parameter analysis """ @@ -182,55 +192,59 @@ def analyze_parameters(self) -> Dict[str, Any]: "typed_parameters": 0, "untyped_parameters": 0, "default_parameters": 0, - "parameter_details": [] + "parameter_details": [], } - + for param in self.parameters: param_info = { - "name": param.name if hasattr(param, 'name') else str(param), - "type": str(param.type) if hasattr(param, 'type') and param.type else None, - "has_default": param.has_default if hasattr(param, 'has_default') else False, - "position": param.position if hasattr(param, 'position') else None + "name": param.name if hasattr(param, "name") else str(param), + "type": str(param.type) + if hasattr(param, "type") and param.type + else None, + "has_default": param.has_default + if hasattr(param, "has_default") + else False, + "position": param.position if hasattr(param, "position") else None, } - + # Update counts if param_info["type"]: result["typed_parameters"] += 1 else: result["untyped_parameters"] += 1 - + if param_info["has_default"]: result["default_parameters"] += 1 - + result["parameter_details"].append(param_info) - + return result - - def analyze_return_type(self) -> Dict[str, Any]: + + def analyze_return_type(self) -> dict[str, Any]: """ Analyze function return type. - + Returns: Dictionary containing return type analysis """ return { "has_return_type": self.return_type is not None, "return_type": str(self.return_type) if self.return_type else None, - "return_type_category": self._categorize_return_type() + "return_type_category": self._categorize_return_type(), } - + def _categorize_return_type(self) -> str: """ Categorize the return type. - + Returns: Category of the return type """ if not self.return_type: return "untyped" - + type_str = str(self.return_type).lower() - + if "none" in type_str: return "none" elif "bool" in type_str: @@ -251,73 +265,84 @@ def _categorize_return_type(self) -> str: return "callable" else: return "complex" - - def analyze_call_sites(self) -> Dict[str, Any]: + + def analyze_call_sites(self) -> dict[str, Any]: """ Analyze function call sites. - + Returns: Dictionary containing call site analysis """ result = { "total_call_sites": len(self.call_sites), "calls_by_function": {}, - "calls_by_file": {} + "calls_by_file": {}, } - + for call_site in self.call_sites: # Get called function called_function = None - if hasattr(call_site, 'called_function'): + if hasattr(call_site, "called_function"): called_function = call_site.called_function - + # Skip if no called function if not called_function: continue - + # Get function name - func_name = called_function.name if hasattr(called_function, 'name') else str(called_function) - + func_name = ( + called_function.name + if hasattr(called_function, "name") + else str(called_function) + ) + # Update calls by function if func_name not in result["calls_by_function"]: result["calls_by_function"][func_name] = 0 result["calls_by_function"][func_name] += 1 - + # Get file file_path = "unknown" - if hasattr(call_site, 'file') and hasattr(call_site.file, 'file_path'): + if hasattr(call_site, "file") and hasattr(call_site.file, "file_path"): file_path = call_site.file.file_path - + # Update calls by file if file_path not in result["calls_by_file"]: result["calls_by_file"][file_path] = 0 result["calls_by_file"][file_path] += 1 - + return result - - def analyze_usage_patterns(self) -> Dict[str, Any]: + + def analyze_usage_patterns(self) -> dict[str, Any]: """ Analyze function usage patterns. 
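The per-function analyses above (`analyze_complexity`, `analyze_parameters`, `analyze_return_type`, `analyze_call_sites`) return plain dictionaries, so a caller typically loops over a codebase's functions and filters or aggregates the results. A hedged sketch of that pattern follows; `Codebase.from_repo(...)` and the import path are assumptions:

```python
# Hedged sketch: flag complex or loosely typed functions.
from codegen.sdk.core.codebase import Codebase

from codegen_on_oss.analyzers.context.function import FunctionContext  # assumed path

codebase = Codebase.from_repo("owner/repo")  # assumed constructor

for function in codebase.functions:
    ctx = FunctionContext(function)
    complexity = ctx.analyze_complexity()
    params = ctx.analyze_parameters()
    if complexity["cyclomatic_complexity"] > 10 or params["untyped_parameters"] > 0:
        print(
            f"{ctx.file_path}:{ctx.line} {ctx.name} "
            f"complexity={complexity['cyclomatic_complexity']} "
            f"untyped_params={params['untyped_parameters']}"
        )
```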
- + Returns: Dictionary containing usage pattern analysis """ result = { "uses_async_await": self.is_async or "await " in self.source, - "uses_exceptions": "try:" in self.source or "except:" in self.source or "except " in self.source, + "uses_exceptions": "try:" in self.source + or "except:" in self.source + or "except " in self.source, "uses_loops": "for " in self.source or "while " in self.source, - "uses_conditionals": "if " in self.source or "elif " in self.source or "else:" in self.source, - "uses_comprehensions": "[" in self.source and "for" in self.source and "]" in self.source, + "uses_conditionals": "if " in self.source + or "elif " in self.source + or "else:" in self.source, + "uses_comprehensions": "[" in self.source + and "for" in self.source + and "]" in self.source, "uses_generators": "yield " in self.source, - "uses_decorators": hasattr(self.function, 'decorators') and bool(self.function.decorators) + "uses_decorators": hasattr(self.function, "decorators") + and bool(self.function.decorators), } - + return result - - def to_dict(self) -> Dict[str, Any]: + + def to_dict(self) -> dict[str, Any]: """ Convert the function context to a dictionary. - + Returns: Dictionary representation of the function context """ @@ -326,11 +351,17 @@ def to_dict(self) -> Dict[str, Any]: "file_path": self.file_path, "line": self.line, "is_async": self.is_async, - "parameters": [param.name if hasattr(param, 'name') else str(param) for param in self.parameters], + "parameters": [ + param.name if hasattr(param, "name") else str(param) + for param in self.parameters + ], "return_type": str(self.return_type) if self.return_type else None, "complexity": self._calculate_cyclomatic_complexity(), - "line_count": len(self.source.split('\n')) if self.source else 0, + "line_count": len(self.source.split("\n")) if self.source else 0, "nesting_depth": self._calculate_nesting_depth(), - "local_variables": [local.name if hasattr(local, 'name') else str(local) for local in self.locals], - "call_sites_count": len(self.call_sites) - } \ No newline at end of file + "local_variables": [ + local.name if hasattr(local, "name") else str(local) + for local in self.locals + ], + "call_sites_count": len(self.call_sites), + } diff --git a/codegen-on-oss/codegen_on_oss/analyzers/context/graph/__init__.py b/codegen-on-oss/codegen_on_oss/analyzers/context/graph/__init__.py index 99d6cc83f..979afe76f 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/context/graph/__init__.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/context/graph/__init__.py @@ -5,40 +5,42 @@ of code, including building, traversing, exporting, and visualizing graphs. """ -from typing import Dict, List, Any, Optional +from typing import Any, Dict, List, Optional import networkx as nx -def build_dependency_graph(edges: List[Dict[str, Any]]) -> nx.DiGraph: + +def build_dependency_graph(edges: list[dict[str, Any]]) -> nx.DiGraph: """ Build a dependency graph from a list of edges. 
- + Args: edges: List of edges, where each edge is a dictionary with 'source', 'target', and optional 'type' keys - + Returns: NetworkX DiGraph representing the dependencies """ graph = nx.DiGraph() - + for edge in edges: - source = edge.get('source') - target = edge.get('target') - edge_type = edge.get('type', 'unknown') - + source = edge.get("source") + target = edge.get("target") + edge_type = edge.get("type", "unknown") + if source and target: graph.add_edge(source, target, type=edge_type) - + return graph -def find_circular_dependencies(graph: nx.DiGraph) -> List[List[str]]: + +def find_circular_dependencies(graph: nx.DiGraph) -> list[list[str]]: """ Find circular dependencies in a graph. - + Args: graph: NetworkX DiGraph to analyze - + Returns: List of cycles, where each cycle is a list of node names """ @@ -47,38 +49,42 @@ def find_circular_dependencies(graph: nx.DiGraph) -> List[List[str]]: except nx.NetworkXNoCycle: return [] -def find_hub_nodes(graph: nx.DiGraph, threshold: int = 5) -> List[str]: + +def find_hub_nodes(graph: nx.DiGraph, threshold: int = 5) -> list[str]: """ Find hub nodes in a graph (nodes with many connections). - + Args: graph: NetworkX DiGraph to analyze threshold: Minimum number of connections to be considered a hub - + Returns: List of hub node names """ hubs = [] - + for node in graph.nodes(): # Count both incoming and outgoing connections connection_count = graph.in_degree(node) + graph.out_degree(node) - + if connection_count >= threshold: hubs.append(node) - + # Sort by connection count in descending order - hubs.sort(key=lambda node: graph.in_degree(node) + graph.out_degree(node), reverse=True) - + hubs.sort( + key=lambda node: graph.in_degree(node) + graph.out_degree(node), reverse=True + ) + return hubs -def calculate_centrality(graph: nx.DiGraph) -> Dict[str, float]: + +def calculate_centrality(graph: nx.DiGraph) -> dict[str, float]: """ Calculate centrality for each node in the graph. - + Args: graph: NetworkX DiGraph to analyze - + Returns: Dictionary mapping node names to centrality scores """ @@ -88,14 +94,15 @@ def calculate_centrality(graph: nx.DiGraph) -> Dict[str, float]: # Fall back to degree centrality if betweenness fails return nx.degree_centrality(graph) -def export_to_dot(graph: nx.DiGraph, filename: Optional[str] = None) -> str: + +def export_to_dot(graph: nx.DiGraph, filename: str | None = None) -> str: """ Export a graph to DOT format. 
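The helpers above (`build_dependency_graph`, `find_circular_dependencies`, `find_hub_nodes`, `calculate_centrality`) operate on plain edge dictionaries and a NetworkX `DiGraph`, so they can be exercised without any SDK objects. A minimal runnable example; the import path follows the file being patched here and assumes the `codegen_on_oss` package is importable:

```python
from codegen_on_oss.analyzers.context.graph import (  # package assumed to be on the path
    build_dependency_graph,
    calculate_centrality,
    find_circular_dependencies,
    find_hub_nodes,
)

edges = [
    {"source": "app", "target": "services", "type": "imports"},
    {"source": "services", "target": "models", "type": "imports"},
    {"source": "models", "target": "app", "type": "imports"},  # closes a cycle
    {"source": "cli", "target": "services", "type": "imports"},
]

graph = build_dependency_graph(edges)
print(find_circular_dependencies(graph))   # one cycle: app -> services -> models (order may rotate)
print(find_hub_nodes(graph, threshold=3))  # ['services'] -- 2 incoming + 1 outgoing edge
print(calculate_centrality(graph))         # betweenness centrality score per node
```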
- + Args: graph: NetworkX DiGraph to export filename: File to write DOT to, or None to return as string - + Returns: DOT representation of the graph if filename is None, otherwise returns empty string @@ -103,7 +110,7 @@ def export_to_dot(graph: nx.DiGraph, filename: Optional[str] = None) -> str: try: import pydot from networkx.drawing.nx_pydot import write_dot - + if filename: write_dot(graph, filename) return "" @@ -111,69 +118,72 @@ def export_to_dot(graph: nx.DiGraph, filename: Optional[str] = None) -> str: # Convert to pydot pydot_graph = nx.nx_pydot.to_pydot(graph) return pydot_graph.to_string() - + except ImportError: # Fallback to basic DOT export if pydot is not available dot = ["digraph G {"] - + # Add nodes for node in graph.nodes(): dot.append(f' "{node}";') - + # Add edges for u, v, data in graph.edges(data=True): - edge_type = data.get('type', '') + edge_type = data.get("type", "") edge_str = f' "{u}" -> "{v}"' - + if edge_type: edge_str += f' [label="{edge_type}"]' - - edge_str += ';' + + edge_str += ";" dot.append(edge_str) - + dot.append("}") dot_str = "\n".join(dot) - + if filename: - with open(filename, 'w') as f: + with open(filename, "w") as f: f.write(dot_str) return "" else: return dot_str -def calculate_cohesion(graph: nx.DiGraph, module_nodes: Dict[str, List[str]]) -> Dict[str, float]: + +def calculate_cohesion( + graph: nx.DiGraph, module_nodes: dict[str, list[str]] +) -> dict[str, float]: """ Calculate cohesion for modules in the graph. - + Args: graph: NetworkX DiGraph to analyze module_nodes: Dictionary mapping module names to lists of node names - + Returns: Dictionary mapping module names to cohesion scores """ cohesion = {} - + for module, nodes in module_nodes.items(): if not nodes: cohesion[module] = 0.0 continue - + # Create subgraph for this module module_subgraph = graph.subgraph(nodes) - + # Count internal edges internal_edges = module_subgraph.number_of_edges() - + # Count external edges external_edges = 0 for node in nodes: for _, target in graph.out_edges(node): if target not in nodes: external_edges += 1 - + # Calculate cohesion as ratio of internal to total edges total_edges = internal_edges + external_edges cohesion[module] = internal_edges / total_edges if total_edges > 0 else 0.0 - - return cohesion \ No newline at end of file + + return cohesion diff --git a/codegen-on-oss/codegen_on_oss/analyzers/context_codebase.py b/codegen-on-oss/codegen_on_oss/analyzers/context_codebase.py index 935752aa0..f43b89ccc 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/context_codebase.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/context_codebase.py @@ -7,34 +7,35 @@ and codebase vs. PR comparisons. 
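`calculate_cohesion` treats each module as the set of node names listed for it and reports, per module, the fraction of edges starting in the module that stay inside it (internal edges divided by internal plus outgoing external edges). A self-contained example, again assuming the `codegen_on_oss` package is importable:

```python
import networkx as nx

from codegen_on_oss.analyzers.context.graph import calculate_cohesion  # assumed importable

graph = nx.DiGraph()
graph.add_edges_from([
    ("pkg_a.core", "pkg_a.utils"),   # internal to pkg_a
    ("pkg_a.utils", "pkg_a.core"),   # internal to pkg_a
    ("pkg_a.core", "pkg_b.api"),     # crosses into pkg_b
    ("pkg_b.api", "pkg_b.models"),   # internal to pkg_b
])

modules = {
    "pkg_a": ["pkg_a.core", "pkg_a.utils"],
    "pkg_b": ["pkg_b.api", "pkg_b.models"],
}

print(calculate_cohesion(graph, modules))
# pkg_a: 2 of 3 edges stay internal (~0.67); pkg_b: 1 of 1 (1.0)
```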
""" +import logging import os -import sys -import tempfile -import shutil import re -import logging +import sys +from collections.abc import Callable from pathlib import Path -from typing import Dict, List, Set, Tuple, Any, Optional, Union, TypeVar, cast, Callable -from enum import Enum +from typing import Any + import networkx as nx try: - from codegen.sdk.core.codebase import Codebase - from codegen.sdk.codebase.codebase_context import CodebaseContext as SDKCodebaseContext from codegen.configs.models.codebase import CodebaseConfig from codegen.configs.models.secrets import SecretsConfig - from codegen.sdk.codebase.config import ProjectConfig - from codegen.git.schemas.repo_config import RepoConfig from codegen.git.repo_operator.repo_operator import RepoOperator - from codegen.shared.enums.programming_language import ProgrammingLanguage - from codegen.sdk.core.file import SourceFile + from codegen.git.schemas.repo_config import RepoConfig + from codegen.sdk.codebase.codebase_context import ( + CodebaseContext as SDKCodebaseContext, + ) + from codegen.sdk.codebase.config import ProjectConfig + from codegen.sdk.codebase.transaction_manager import TransactionManager + from codegen.sdk.codebase.transactions import Transaction + from codegen.sdk.core.class_definition import Class + from codegen.sdk.core.codebase import Codebase from codegen.sdk.core.directory import Directory - from codegen.sdk.core.symbol import Symbol + from codegen.sdk.core.file import SourceFile from codegen.sdk.core.function import Function - from codegen.sdk.core.class_definition import Class + from codegen.sdk.core.symbol import Symbol from codegen.sdk.enums import EdgeType, SymbolType - from codegen.sdk.codebase.transactions import Transaction - from codegen.sdk.codebase.transaction_manager import TransactionManager + from codegen.shared.enums.programming_language import ProgrammingLanguage except ImportError: print("Codegen SDK not found. Please install it first.") sys.exit(1) @@ -42,8 +43,8 @@ # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) @@ -72,29 +73,31 @@ "*.log", ] + def get_node_classes(): """Return a tuple of classes that represent nodes in the codebase graph.""" return (Symbol, Function, Class, Directory, SourceFile) + class CodebaseContext: """ Enhanced context for codebase analysis, providing graph manipulation and codebase comparison capabilities. - + This class extends the functionality of the SDK's CodebaseContext with additional methods for PR analysis and codebase comparison. """ - + def __init__( self, codebase: Codebase, - base_path: Optional[str] = None, - pr_branch: Optional[str] = None, + base_path: str | None = None, + pr_branch: str | None = None, base_branch: str = "main", ): """ Initialize the CodebaseContext. 
- + Args: codebase: Codebase instance to analyze base_path: Base path of the codebase @@ -105,35 +108,35 @@ def __init__( self.base_path = base_path or "" self.pr_branch = pr_branch self.base_branch = base_branch - + # Graph for storing codebase structure self._graph = nx.DiGraph() - + # Transaction management self.transaction_manager = TransactionManager() - + # Cache for nodes and files self._node_cache = {} self._file_cache = {} self._directory_cache = {} - + # Initialize the graph self.build_graph() - + def __repr__(self) -> str: """String representation of the CodebaseContext.""" return f"CodebaseContext(nodes={len(self.nodes)}, edges={len(self.edges)}, files={len(self._file_cache)})" - + @property def _graph(self) -> nx.DiGraph: """Get the graph.""" return self.__graph - + @_graph.setter def _graph(self, graph: nx.DiGraph) -> None: """Set the graph.""" self.__graph = graph - + def build_graph(self) -> None: """Build the codebase graph.""" # Clear existing graph and caches @@ -141,191 +144,189 @@ def build_graph(self) -> None: self._node_cache = {} self._file_cache = {} self._directory_cache = {} - + # Add files to the graph for file in self.codebase.files: if any(re.match(pattern, file.path) for pattern in GLOBAL_FILE_IGNORE_LIST): continue - + self.add_node(file) - + # Cache file for faster access self._file_cache[str(file.path)] = file - + # Add symbols to the graph for symbol in self.codebase.symbols: self.add_node(symbol) - + # Connect symbol to its file - if hasattr(symbol, 'file') and symbol.file: + if hasattr(symbol, "file") and symbol.file: self.add_edge(symbol.file, symbol, EdgeType.CONTAINS) - + # Connect class members to their class - if hasattr(symbol, 'parent') and symbol.parent: + if hasattr(symbol, "parent") and symbol.parent: self.add_edge(symbol.parent, symbol, EdgeType.CONTAINS) - + # Build directory tree self.build_directory_tree() - + # Compute dependencies self._compute_dependencies() - - def apply_diffs(self, diffs: Dict[str, Any]) -> None: + + def apply_diffs(self, diffs: dict[str, Any]) -> None: """ Apply diffs to the codebase. - + Args: diffs: Dictionary of file paths to diff content """ for file_path, diff in diffs.items(): # Process each file's diff self._process_diff_files({file_path: diff}) - + # Rebuild the graph with the applied diffs self.build_graph() - + def _reset_files(self) -> None: """Reset any modified files to their original state.""" # Clear file cache self._file_cache = {} - + # Re-populate cache from codebase for file in self.codebase.files: self._file_cache[str(file.path)] = file - + def reset_codebase(self) -> None: """Reset the codebase to its original state.""" # Reset files self._reset_files() - + # Rebuild the graph self.build_graph() - + def undo_applied_diffs(self) -> None: """Undo all applied diffs.""" self._revert_diffs() self.build_graph() - + def _revert_diffs(self) -> None: """Revert any applied diffs.""" # Use transaction manager to revert all transactions self.transaction_manager.revert_all() - + # Reset files self._reset_files() - + def save_commit(self, message: str) -> str: """ Save changes as a commit. 
- + Args: message: Commit message - + Returns: Commit hash """ # Use repo operator to commit changes - if hasattr(self.codebase, 'repo_operator'): + if hasattr(self.codebase, "repo_operator"): return self.codebase.repo_operator.commit(message) return "" - + def prune_graph(self) -> None: """Remove any nodes that no longer exist in the codebase.""" nodes_to_remove = [] - + for node in self.nodes: - if hasattr(node, 'path'): + if hasattr(node, "path"): path = str(node.path) - + # Check if file still exists - if isinstance(node, SourceFile) and path not in self._file_cache: - nodes_to_remove.append(node) - - # Check if directory still exists - elif isinstance(node, Directory) and path not in self._directory_cache: + if (isinstance(node, SourceFile) and path not in self._file_cache) or ( + isinstance(node, Directory) and path not in self._directory_cache + ): nodes_to_remove.append(node) - + # Check if symbol's file still exists - elif hasattr(node, 'file') and node.file: + elif hasattr(node, "file") and node.file: file_path = str(node.file.path) if file_path not in self._file_cache: nodes_to_remove.append(node) - + # Remove nodes for node in nodes_to_remove: self.remove_node(node) - + def build_directory_tree(self) -> None: """Build the directory tree from the files.""" directories = {} - + for file in self._file_cache.values(): path = file.path parent_dir = path.parent - + # Create directory nodes current_dir = parent_dir while str(current_dir) != ".": dir_path = str(current_dir) - + if dir_path not in directories: dir_node = Directory(current_dir) directories[dir_path] = dir_node self.add_node(dir_node) self._directory_cache[dir_path] = dir_node - + # Connect to parent directory parent_path = str(current_dir.parent) if parent_path != "." and parent_path in directories: parent_node = directories[parent_path] self.add_edge(parent_node, dir_node, EdgeType.CONTAINS) - + # Connect file to directory if str(current_dir) == str(parent_dir): self.add_edge(directories[dir_path], file, EdgeType.CONTAINS) - + current_dir = current_dir.parent if str(current_dir) == ".": break - - def get_directory(self, path: Union[str, Path]) -> Optional[Directory]: + + def get_directory(self, path: str | Path) -> Directory | None: """ Get a directory node from the graph. - + Args: path: Directory path - + Returns: Directory node or None if not found """ path_str = str(path) - + # Check cache first if path_str in self._directory_cache: return self._directory_cache[path_str] - + # Search for the directory in the graph for node in self.nodes: if isinstance(node, Directory) and str(node.path) == path_str: self._directory_cache[path_str] = node return node - + return None - - def _process_diff_files(self, diff_files: Dict[str, Any]) -> None: + + def _process_diff_files(self, diff_files: dict[str, Any]) -> None: """ Process diff files and apply changes to the codebase. 
- + Args: diff_files: Dictionary mapping file paths to diff content """ for file_path, diff_content in diff_files.items(): file = self.get_file(file_path) - + if file: # Create a transaction for this change transaction = Transaction(file, diff_content) - + # Apply the transaction self.transaction_manager.apply(transaction) else: @@ -333,233 +334,245 @@ def _process_diff_files(self, diff_files: Dict[str, Any]) -> None: if isinstance(diff_content, str): # Create new file new_file = self.add_single_file(file_path, diff_content) - + if new_file: # Add to cache self._file_cache[file_path] = new_file - + def _compute_dependencies(self) -> None: """Compute dependencies between symbols.""" # Process imports to create dependency edges for file in self._file_cache.values(): - if hasattr(file, 'imports'): + if hasattr(file, "imports"): for import_item in file.imports: imported_symbol = None - + # Try to resolve the import - if hasattr(import_item, 'resolved_symbol') and import_item.resolved_symbol: + if ( + hasattr(import_item, "resolved_symbol") + and import_item.resolved_symbol + ): imported_symbol = import_item.resolved_symbol - elif hasattr(import_item, 'name'): + elif hasattr(import_item, "name"): # Try to find the symbol by name for symbol in self.codebase.symbols: - if hasattr(symbol, 'name') and symbol.name == import_item.name: + if ( + hasattr(symbol, "name") + and symbol.name == import_item.name + ): imported_symbol = symbol break - + if imported_symbol: # Create dependency edge self.add_edge(file, imported_symbol, EdgeType.IMPORTS) - + # Process function calls to create call edges for func in self.codebase.functions: - if hasattr(func, 'calls'): + if hasattr(func, "calls"): for call in func.calls: called_func = None - + # Try to resolve the call - if hasattr(call, 'resolved_symbol') and call.resolved_symbol: + if hasattr(call, "resolved_symbol") and call.resolved_symbol: called_func = call.resolved_symbol - elif hasattr(call, 'name'): + elif hasattr(call, "name"): # Try to find the function by name for other_func in self.codebase.functions: - if hasattr(other_func, 'name') and other_func.name == call.name: + if ( + hasattr(other_func, "name") + and other_func.name == call.name + ): called_func = other_func break - + if called_func: # Create call edge self.add_edge(func, called_func, EdgeType.CALLS) - - def build_subgraph(self, nodes: List[Any]) -> nx.DiGraph: + + def build_subgraph(self, nodes: list[Any]) -> nx.DiGraph: """ Build a subgraph containing only the specified nodes. - + Args: nodes: List of nodes to include in the subgraph - + Returns: Subgraph as a new DiGraph """ subgraph = nx.DiGraph() - + # Add nodes for node in nodes: if self.has_node(node): subgraph.add_node(node) - + # Add edges for u, v, data in self.edges(data=True): if subgraph.has_node(u) and subgraph.has_node(v): subgraph.add_edge(u, v, **data) - + return subgraph - - def get_node(self, id_or_obj: Any) -> Optional[Any]: + + def get_node(self, id_or_obj: Any) -> Any | None: """ Get a node from the graph by ID or object. 
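The query surface defined above (`get_nodes`, `get_edges`, `get_file`) lets callers filter the enhanced `CodebaseContext` graph by node class or edge type. A hedged sketch follows; `Codebase.from_repo(...)` and the example path are assumptions, while the class and module names come from this file:

```python
# Hedged query sketch for the enhanced CodebaseContext in context_codebase.py.
from codegen.sdk.core.codebase import Codebase
from codegen.sdk.core.file import SourceFile
from codegen.sdk.enums import EdgeType

from codegen_on_oss.analyzers.context_codebase import CodebaseContext  # package assumed importable

codebase = Codebase.from_repo("owner/repo")  # assumed constructor
ctx = CodebaseContext(codebase, base_path="src", pr_branch="feature/x")

files = ctx.get_nodes(SourceFile)               # every file node in the graph
import_edges = ctx.get_edges(EdgeType.IMPORTS)  # only IMPORTS edges, as (u, v, data) tuples
main = ctx.get_file("src/app/main.py")          # cached lookup; path is hypothetical

print(len(files), len(import_edges), main is not None)
```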
- + Args: id_or_obj: Node ID or object - + Returns: Node or None if not found """ if self.has_node(id_or_obj): return id_or_obj - + # Check if it's a string path if isinstance(id_or_obj, str): # Try to find file or directory if id_or_obj in self._file_cache: return self._file_cache[id_or_obj] - + if id_or_obj in self._directory_cache: return self._directory_cache[id_or_obj] - + # Try to find by name for node in self.nodes: - if hasattr(node, 'name') and node.name == id_or_obj: + if hasattr(node, "name") and node.name == id_or_obj: return node - - if hasattr(node, 'path') and str(node.path) == id_or_obj: + + if hasattr(node, "path") and str(node.path) == id_or_obj: return node - + return None - - def get_nodes(self, node_type: Optional[Any] = None) -> List[Any]: + + def get_nodes(self, node_type: Any | None = None) -> list[Any]: """ Get all nodes of a specific type. - + Args: node_type: Type of nodes to return - + Returns: List of nodes """ if node_type is None: return list(self.nodes) - + return [node for node in self.nodes if isinstance(node, node_type)] - - def get_edges(self, edge_type: Optional[Any] = None) -> List[Tuple[Any, Any, Dict[str, Any]]]: + + def get_edges( + self, edge_type: Any | None = None + ) -> list[tuple[Any, Any, dict[str, Any]]]: """ Get all edges of a specific type. - + Args: edge_type: Type of edges to return - + Returns: List of edges as (u, v, data) tuples """ edges = list(self.edges(data=True)) - + if edge_type is None: return edges - + return [ - (u, v, data) for u, v, data in edges - if 'type' in data and data['type'] == edge_type + (u, v, data) + for u, v, data in edges + if "type" in data and data["type"] == edge_type ] - - def get_file(self, path: Union[str, Path]) -> Optional[SourceFile]: + + def get_file(self, path: str | Path) -> SourceFile | None: """ Get a file from the codebase. - + Args: path: File path - + Returns: SourceFile or None if not found """ path_str = str(path) - + # Check cache first if path_str in self._file_cache: return self._file_cache[path_str] - + # Try to get raw file file = self._get_raw_file_from_path(path_str) - + if file: self._file_cache[path_str] = file - + return file - - def _get_raw_file_from_path(self, path: str) -> Optional[SourceFile]: + + def _get_raw_file_from_path(self, path: str) -> SourceFile | None: """ Get a file from the codebase by its path. - + Args: path: File path - + Returns: SourceFile or None if not found """ # Try to get file from codebase - if hasattr(self.codebase, 'get_file'): + if hasattr(self.codebase, "get_file"): return self.codebase.get_file(path) - + # Fallback to searching in files for file in self.codebase.files: if str(file.path) == path: return file - + return None - - def get_external_module(self, name: str) -> Optional[Any]: + + def get_external_module(self, name: str) -> Any | None: """ Get an external module from the codebase. - + Args: name: Module name - + Returns: External module or None if not found """ - if hasattr(self.codebase, 'get_external_module'): + if hasattr(self.codebase, "get_external_module"): return self.codebase.get_external_module(name) - + # Fallback: search through external modules - if hasattr(self.codebase, 'external_modules'): + if hasattr(self.codebase, "external_modules"): for module in self.codebase.external_modules: - if hasattr(module, 'name') and module.name == name: + if hasattr(module, "name") and module.name == name: return module - + return None - + def add_node(self, node: Any) -> None: """ Add a node to the graph. 
- + Args: node: Node to add """ if not self.has_node(node): self._graph.add_node(node) - + # Add to cache if applicable - if hasattr(node, 'path'): + if hasattr(node, "path"): path_str = str(node.path) - + if isinstance(node, SourceFile): self._file_cache[path_str] = node elif isinstance(node, Directory): self._directory_cache[path_str] = node - - def add_child(self, parent: Any, child: Any, edge_type: Optional[Any] = None) -> None: + + def add_child(self, parent: Any, child: Any, edge_type: Any | None = None) -> None: """ Add a child node to a parent node. - + Args: parent: Parent node child: Child node @@ -567,42 +580,42 @@ def add_child(self, parent: Any, child: Any, edge_type: Optional[Any] = None) -> """ self.add_node(parent) self.add_node(child) - + edge_data = {} if edge_type is not None: - edge_data['type'] = edge_type - + edge_data["type"] = edge_type + self.add_edge(parent, child, edge_type) - + def has_node(self, node: Any) -> bool: """ Check if a node exists in the graph. - + Args: node: Node to check - + Returns: True if the node exists, False otherwise """ return self._graph.has_node(node) - + def has_edge(self, u: Any, v: Any) -> bool: """ Check if an edge exists in the graph. - + Args: u: Source node v: Target node - + Returns: True if the edge exists, False otherwise """ return self._graph.has_edge(u, v) - - def add_edge(self, u: Any, v: Any, edge_type: Optional[Any] = None) -> None: + + def add_edge(self, u: Any, v: Any, edge_type: Any | None = None) -> None: """ Add an edge to the graph. - + Args: u: Source node v: Target node @@ -610,303 +623,303 @@ def add_edge(self, u: Any, v: Any, edge_type: Optional[Any] = None) -> None: """ if not self.has_node(u): self.add_node(u) - + if not self.has_node(v): self.add_node(v) - + edge_data = {} if edge_type is not None: - edge_data['type'] = edge_type - + edge_data["type"] = edge_type + self._graph.add_edge(u, v, **edge_data) - - def add_edges(self, edge_list: List[Tuple[Any, Any, Dict[str, Any]]]) -> None: + + def add_edges(self, edge_list: list[tuple[Any, Any, dict[str, Any]]]) -> None: """ Add multiple edges to the graph. - + Args: edge_list: List of (u, v, data) tuples """ for u, v, data in edge_list: if not self.has_node(u): self.add_node(u) - + if not self.has_node(v): self.add_node(v) - + self._graph.add_edge(u, v, **data) - + @property - def nodes(self) -> List[Any]: + def nodes(self) -> list[Any]: """Get all nodes in the graph.""" return list(self._graph.nodes()) - + @property def edges(self) -> Callable: """Get all edges in the graph.""" return self._graph.edges - - def predecessor(self, node: Any) -> Optional[Any]: + + def predecessor(self, node: Any) -> Any | None: """ Get the predecessor of a node. - + Args: node: Node to get predecessor for - + Returns: Predecessor node or None if not found """ preds = list(self.predecessors(node)) return preds[0] if preds else None - - def predecessors(self, node: Any) -> List[Any]: + + def predecessors(self, node: Any) -> list[Any]: """ Get all predecessors of a node. - + Args: node: Node to get predecessors for - + Returns: List of predecessor nodes """ if not self.has_node(node): return [] - + return list(self._graph.predecessors(node)) - - def successors(self, node: Any) -> List[Any]: + + def successors(self, node: Any) -> list[Any]: """ Get all successors of a node. 
- + Args: node: Node to get successors for - + Returns: List of successor nodes """ if not self.has_node(node): return [] - + return list(self._graph.successors(node)) - - def get_edge_data(self, u: Any, v: Any) -> Dict[str, Any]: + + def get_edge_data(self, u: Any, v: Any) -> dict[str, Any]: """ Get the data for an edge. - + Args: u: Source node v: Target node - + Returns: Edge data dictionary """ if not self.has_edge(u, v): return {} - + return self._graph.get_edge_data(u, v) - - def in_edges(self, node: Any, data: bool = False) -> List[Any]: + + def in_edges(self, node: Any, data: bool = False) -> list[Any]: """ Get all incoming edges for a node. - + Args: node: Node to get incoming edges for data: Whether to include edge data - + Returns: List of incoming edges """ if not self.has_node(node): return [] - + return list(self._graph.in_edges(node, data=data)) - - def out_edges(self, node: Any, data: bool = False) -> List[Any]: + + def out_edges(self, node: Any, data: bool = False) -> list[Any]: """ Get all outgoing edges for a node. - + Args: node: Node to get outgoing edges for data: Whether to include edge data - + Returns: List of outgoing edges """ if not self.has_node(node): return [] - + return list(self._graph.out_edges(node, data=data)) - + def remove_node(self, node: Any) -> None: """ Remove a node from the graph. - + Args: node: Node to remove """ if self.has_node(node): self._graph.remove_node(node) - + # Remove from cache if applicable - if hasattr(node, 'path'): + if hasattr(node, "path"): path_str = str(node.path) - + if isinstance(node, SourceFile) and path_str in self._file_cache: del self._file_cache[path_str] elif isinstance(node, Directory) and path_str in self._directory_cache: del self._directory_cache[path_str] - + def remove_edge(self, u: Any, v: Any) -> None: """ Remove an edge from the graph. - + Args: u: Source node v: Target node """ if self.has_edge(u, v): self._graph.remove_edge(u, v) - - def to_absolute(self, path: Union[str, Path]) -> str: + + def to_absolute(self, path: str | Path) -> str: """ Convert a relative path to an absolute path. - + Args: path: Relative path - + Returns: Absolute path """ path_str = str(path) - + if os.path.isabs(path_str): return path_str - + return os.path.join(self.base_path, path_str) - - def to_relative(self, path: Union[str, Path]) -> str: + + def to_relative(self, path: str | Path) -> str: """ Convert an absolute path to a relative path. - + Args: path: Absolute path - + Returns: Relative path """ path_str = str(path) - + if not os.path.isabs(path_str): return path_str - + return os.path.relpath(path_str, self.base_path) - - def is_subdir(self, parent: Union[str, Path], child: Union[str, Path]) -> bool: + + def is_subdir(self, parent: str | Path, child: str | Path) -> bool: """ Check if a directory is a subdirectory of another. - + Args: parent: Parent directory child: Child directory - + Returns: True if child is a subdirectory of parent, False otherwise """ parent_str = str(parent) child_str = str(child) - + parent_abs = os.path.abspath(parent_str) child_abs = os.path.abspath(child_str) - + return child_abs.startswith(parent_abs) - + def commit_transactions(self, message: str) -> str: """ Commit all pending transactions. 
- + Args: message: Commit message - + Returns: Commit hash """ # Apply all transactions and commit self.transaction_manager.apply_all() - + return self.save_commit(message) - - def add_single_file(self, path: str, content: str) -> Optional[SourceFile]: + + def add_single_file(self, path: str, content: str) -> SourceFile | None: """ Add a single file to the codebase. - + Args: path: File path content: File content - + Returns: SourceFile or None if creation failed """ # Add file to the transaction manager transaction = Transaction.create_new_file(path, content) self.transaction_manager.add(transaction) - + # Initialize file in codebase - if hasattr(self.codebase, 'add_file'): + if hasattr(self.codebase, "add_file"): return self.codebase.add_file(path, content) - + return None - + @property def session(self) -> Any: """Get the transaction session.""" return self.transaction_manager.session - - def remove_directory(self, path: Union[str, Path]) -> None: + + def remove_directory(self, path: str | Path) -> None: """ Remove a directory and all its contents from the codebase. - + Args: path: Directory path """ path_str = str(path) dir_node = self.get_directory(path_str) - + if not dir_node: return - + # Get all files in the directory files_to_remove = [] for file in self._file_cache.values(): if self.is_subdir(path_str, file.path): files_to_remove.append(file) - + # Remove files for file in files_to_remove: file_path = str(file.path) - + # Create transaction for removal transaction = Transaction.delete_file(file_path) self.transaction_manager.add(transaction) - + # Remove from cache if file_path in self._file_cache: del self._file_cache[file_path] - + # Remove from graph if self.has_node(file): self.remove_node(file) - + # Remove directory from cache if path_str in self._directory_cache: del self._directory_cache[path_str] - + # Remove directory node from graph if self.has_node(dir_node): self.remove_node(dir_node) - + @property - def ts_declassify(self) -> Optional[Callable]: + def ts_declassify(self) -> Callable | None: """Get TypeScript declassify function if available.""" - if hasattr(self.codebase, 'ts_declassify'): + if hasattr(self.codebase, "ts_declassify"): return self.codebase.ts_declassify - return None \ No newline at end of file + return None diff --git a/codegen-on-oss/codegen_on_oss/analyzers/current_code_codebase.py b/codegen-on-oss/codegen_on_oss/analyzers/current_code_codebase.py index 137081efe..07eb5795f 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/current_code_codebase.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/current_code_codebase.py @@ -1,15 +1,21 @@ import importlib import os from pathlib import Path -from typing import Optional, TypedDict, Union, List +from typing import TypedDict -from codegen.shared.decorators.docs import DocumentedObject, apidoc_objects, no_apidoc_objects, py_apidoc_objects, ts_apidoc_objects -from codegen.sdk.core.codebase import Codebase, CodebaseType -from codegen.sdk.codebase.config import ProjectConfig from codegen.configs.models.codebase import CodebaseConfig from codegen.configs.models.secrets import SecretsConfig from codegen.git.repo_operator.repo_operator import RepoOperator from codegen.git.schemas.repo_config import RepoConfig +from codegen.sdk.codebase.config import ProjectConfig +from codegen.sdk.core.codebase import Codebase, CodebaseType +from codegen.shared.decorators.docs import ( + DocumentedObject, + apidoc_objects, + no_apidoc_objects, + py_apidoc_objects, + ts_apidoc_objects, +) from 
codegen.shared.enums.programming_language import ProgrammingLanguage from codegen.shared.logging.get_logger import get_logger @@ -35,15 +41,15 @@ def get_base_path(repo_path: str) -> str: def get_selected_codebase( - repo_path: Optional[str] = None, - base_path: Optional[str] = None, - config: Optional[CodebaseConfig] = None, - secrets: Optional[SecretsConfig] = None, - subdirectories: Optional[List[str]] = None, - programming_language: Optional[ProgrammingLanguage] = None + repo_path: str | None = None, + base_path: str | None = None, + config: CodebaseConfig | None = None, + secrets: SecretsConfig | None = None, + subdirectories: list[str] | None = None, + programming_language: ProgrammingLanguage | None = None, ) -> CodebaseType: """Returns a Codebase instance for the selected repository. - + Parameters: repo_path: Path to the repository base_path: Base directory within the repository where code is located @@ -51,42 +57,44 @@ def get_selected_codebase( secrets: SecretsConfig for any credentials needed subdirectories: List of subdirectories to include in the analysis programming_language: Primary programming language of the codebase - + Returns: A Codebase instance initialized with the provided parameters """ if not repo_path: repo_path = get_repo_path() - + if not base_path: base_path = get_base_path(repo_path) - - logger.info(f"Creating codebase from repo at: {repo_path} with base_path {base_path}") - + + logger.info( + f"Creating codebase from repo at: {repo_path} with base_path {base_path}" + ) + # Set up repository config repo_config = RepoConfig.from_repo_path(repo_path) repo_config.respect_gitignore = True # Respect gitignore by default op = RepoOperator(repo_config=repo_config, bot_commit=False) - + # Use provided config or create a new one config = (config or CodebaseConfig()).model_copy(update={"base_path": base_path}) - + # Determine the programming language if not provided if not programming_language: # Default to Python, but try to detect from files programming_language = ProgrammingLanguage.PYTHON # TODO: Add language detection logic if needed - + # Create project config projects = [ ProjectConfig( repo_operator=op, programming_language=programming_language, subdirectories=subdirectories, - base_path=base_path + base_path=base_path, ) ] - + # Create and return codebase codebase = Codebase(projects=projects, config=config, secrets=secrets) return codebase @@ -94,9 +102,9 @@ def get_selected_codebase( def import_modules_from_path(directory_path: str, package_prefix: str = ""): """Imports all Python modules from the given directory path. - + This is used to collect all documented objects from the modules. - + Parameters: directory_path: Path to the directory containing Python modules package_prefix: Prefix to use for module imports (e.g., 'mypackage.') @@ -105,25 +113,28 @@ def import_modules_from_path(directory_path: str, package_prefix: str = ""): if not directory.exists() or not directory.is_dir(): logger.warning(f"Directory does not exist: {directory_path}") return - + for file in directory.rglob("*.py"): if "__init__" in file.name or "braintrust_evaluator" in file.name: continue - + try: # Convert path to module name relative_path = file.relative_to(directory) - module_name = package_prefix + str(relative_path).replace("/", ".").removesuffix(".py") - + module_name = package_prefix + str(relative_path).replace( + "/", "." 
+ ).removesuffix(".py") + # Import the module importlib.import_module(module_name) logger.debug(f"Successfully imported module: {module_name}") except Exception as e: - logger.error(f"Error importing {module_name}: {e}") + logger.exception(f"Error importing {module_name}: {e}") class DocumentedObjects(TypedDict): """Type definition for the documented objects collection.""" + apidoc: list[DocumentedObject] ts_apidoc: list[DocumentedObject] py_apidoc: list[DocumentedObject] @@ -131,71 +142,93 @@ class DocumentedObjects(TypedDict): def get_documented_objects( - repo_path: Optional[str] = None, + repo_path: str | None = None, package_prefix: str = "", - import_paths: Optional[List[str]] = None + import_paths: list[str] | None = None, ) -> DocumentedObjects: """Get all objects decorated with API documentation decorators. - + This function imports modules from the specified paths and collects objects decorated with apidoc, py_apidoc, ts_apidoc, and no_apidoc. - + Parameters: repo_path: Path to the repository root package_prefix: Prefix to use for importing modules import_paths: List of paths to import from - + Returns: A dictionary containing the collected documented objects """ if not repo_path: repo_path = get_repo_path() - + if not import_paths: # Default to importing from common directories base_path = get_base_path(repo_path) import_paths = [ os.path.join(repo_path, base_path), - os.path.join(repo_path, base_path, "codegen") if base_path else os.path.join(repo_path, "codegen"), - os.path.join(repo_path, base_path, "sdk") if base_path else os.path.join(repo_path, "sdk"), + os.path.join(repo_path, base_path, "codegen") + if base_path + else os.path.join(repo_path, "codegen"), + os.path.join(repo_path, base_path, "sdk") + if base_path + else os.path.join(repo_path, "sdk"), ] - + # Import all modules to populate the documented objects lists for path in import_paths: if os.path.exists(path) and os.path.isdir(path): import_modules_from_path(path, package_prefix) - + # Add core types if they aren't already added from codegen.sdk.core.codebase import CodebaseType, PyCodebaseType, TSCodebaseType - + if CodebaseType not in apidoc_objects: - apidoc_objects.append(DocumentedObject(name="CodebaseType", module="codegen.sdk.core.codebase", object=CodebaseType)) + apidoc_objects.append( + DocumentedObject( + name="CodebaseType", + module="codegen.sdk.core.codebase", + object=CodebaseType, + ) + ) if PyCodebaseType not in apidoc_objects: - apidoc_objects.append(DocumentedObject(name="PyCodebaseType", module="codegen.sdk.core.codebase", object=PyCodebaseType)) + apidoc_objects.append( + DocumentedObject( + name="PyCodebaseType", + module="codegen.sdk.core.codebase", + object=PyCodebaseType, + ) + ) if TSCodebaseType not in apidoc_objects: - apidoc_objects.append(DocumentedObject(name="TSCodebaseType", module="codegen.sdk.core.codebase", object=TSCodebaseType)) - + apidoc_objects.append( + DocumentedObject( + name="TSCodebaseType", + module="codegen.sdk.core.codebase", + object=TSCodebaseType, + ) + ) + # Return the collected objects return { "apidoc": apidoc_objects, "py_apidoc": py_apidoc_objects, "ts_apidoc": ts_apidoc_objects, - "no_apidoc": no_apidoc_objects + "no_apidoc": no_apidoc_objects, } def get_codebase_with_docs( - repo_path: Optional[str] = None, - base_path: Optional[str] = None, - config: Optional[CodebaseConfig] = None, - secrets: Optional[SecretsConfig] = None, - subdirectories: Optional[List[str]] = None, - programming_language: Optional[ProgrammingLanguage] = None, + repo_path: str | None 
= None, + base_path: str | None = None, + config: CodebaseConfig | None = None, + secrets: SecretsConfig | None = None, + subdirectories: list[str] | None = None, + programming_language: ProgrammingLanguage | None = None, package_prefix: str = "", - import_paths: Optional[List[str]] = None + import_paths: list[str] | None = None, ) -> tuple[CodebaseType, DocumentedObjects]: """Convenience function to get both a codebase and its documented objects. - + Parameters: repo_path: Path to the repository base_path: Base directory within the repository @@ -205,26 +238,24 @@ def get_codebase_with_docs( programming_language: Primary programming language of the codebase package_prefix: Prefix for importing modules import_paths: List of paths to import from - + Returns: A tuple containing the Codebase instance and the documented objects """ if not repo_path: repo_path = get_repo_path() - + codebase = get_selected_codebase( repo_path=repo_path, base_path=base_path, config=config, secrets=secrets, subdirectories=subdirectories, - programming_language=programming_language + programming_language=programming_language, ) - + documented_objects = get_documented_objects( - repo_path=repo_path, - package_prefix=package_prefix, - import_paths=import_paths + repo_path=repo_path, package_prefix=package_prefix, import_paths=import_paths ) - - return codebase, documented_objects \ No newline at end of file + + return codebase, documented_objects diff --git a/codegen-on-oss/codegen_on_oss/analyzers/dependencies.py b/codegen-on-oss/codegen_on_oss/analyzers/dependencies.py index f16e43718..82c65806e 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/dependencies.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/dependencies.py @@ -7,26 +7,34 @@ dependencies analysis. """ -import os -import sys import logging -import networkx as nx -from datetime import datetime -from typing import Dict, List, Set, Tuple, Any, Optional, Union, TypeVar, cast -from pathlib import Path +import sys from dataclasses import dataclass, field +from typing import Any + +import networkx as nx try: from codegen.sdk.core.codebase import Codebase from codegen.sdk.core.file import SourceFile - from codegen.sdk.core.symbol import Symbol from codegen.sdk.core.function import Function + from codegen.sdk.core.symbol import Symbol from codegen.sdk.enums import EdgeType, SymbolType - - # Import from our own modules - from codegen_on_oss.analyzers.issues import Issue, IssueCollection, IssueSeverity, IssueCategory, CodeLocation - from codegen_on_oss.analyzers.models.analysis_result import AnalysisResult, DependencyResult + from codegen_on_oss.analyzers.codebase_context import CodebaseContext + + # Import from our own modules + from codegen_on_oss.analyzers.issues import ( + CodeLocation, + Issue, + IssueCategory, + IssueCollection, + IssueSeverity, + ) + from codegen_on_oss.analyzers.models.analysis_result import ( + AnalysisResult, + DependencyResult, + ) except ImportError: print("Codegen SDK or required modules not found.") sys.exit(1) @@ -34,54 +42,65 @@ # Configure logging logger = logging.getLogger(__name__) + @dataclass class ImportDependency: """Represents an import dependency between files or modules.""" + source: str target: str - import_name: Optional[str] = None + import_name: str | None = None is_external: bool = False is_relative: bool = False - line_number: Optional[int] = None + line_number: int | None = None + @dataclass class ModuleDependency: """Represents a dependency between modules.""" + source_module: str target_module: str 
imports_count: int = 1 is_circular: bool = False + @dataclass class CircularDependency: """Represents a circular dependency in the codebase.""" - files: List[str] - modules: List[str] + + files: list[str] + modules: list[str] length: int cycle_type: str = "import" # Either "import" or "function_call" + @dataclass class ModuleCoupling: """Represents coupling metrics for a module.""" + module: str file_count: int - imported_modules: List[str] + imported_modules: list[str] import_count: int coupling_ratio: float - exported_symbols: List[str] = field(default_factory=list) + exported_symbols: list[str] = field(default_factory=list) + @dataclass class ExternalDependency: """Represents an external dependency.""" + module_name: str usage_count: int - importing_files: List[str] = field(default_factory=list) - imported_symbols: List[str] = field(default_factory=list) + importing_files: list[str] = field(default_factory=list) + imported_symbols: list[str] = field(default_factory=list) + class DependencyAnalyzer: """ Analyzer for codebase dependencies. - + This analyzer provides comprehensive dependency analysis, including: 1. Import dependencies analysis 2. Circular dependencies detection @@ -89,16 +108,16 @@ class DependencyAnalyzer: 4. External dependencies analysis 5. Call graph analysis """ - + def __init__( self, - codebase: Optional[Codebase] = None, - context: Optional[CodebaseContext] = None, - issue_collection: Optional[IssueCollection] = None + codebase: Codebase | None = None, + context: CodebaseContext | None = None, + issue_collection: IssueCollection | None = None, ): """ Initialize the DependencyAnalyzer. - + Args: codebase: Codebase instance to analyze context: CodebaseContext for advanced graph analysis @@ -107,31 +126,31 @@ def __init__( self.codebase = codebase self.context = context self.issues = issue_collection or IssueCollection() - + # Analysis results - self.import_dependencies: List[ImportDependency] = [] - self.module_dependencies: List[ModuleDependency] = [] - self.circular_dependencies: List[CircularDependency] = [] - self.module_coupling: Dict[str, ModuleCoupling] = {} - self.external_dependencies: Dict[str, ExternalDependency] = {} - + self.import_dependencies: list[ImportDependency] = [] + self.module_dependencies: list[ModuleDependency] = [] + self.circular_dependencies: list[CircularDependency] = [] + self.module_coupling: dict[str, ModuleCoupling] = {} + self.external_dependencies: dict[str, ExternalDependency] = {} + # Analysis graphs self.import_graph = nx.DiGraph() self.module_graph = nx.DiGraph() self.call_graph = nx.DiGraph() self.class_hierarchy_graph = nx.DiGraph() - + # Initialize context if needed if self.codebase and not self.context: try: self.context = CodebaseContext(codebase=self.codebase) except Exception as e: - logger.error(f"Error initializing context: {e}") - + logger.exception(f"Error initializing context: {e}") + def analyze(self) -> DependencyResult: """ Perform comprehensive dependency analysis on the codebase. 
- + Returns: DependencyResult containing all dependency analysis results """ @@ -141,13 +160,13 @@ def analyze(self) -> DependencyResult: self.circular_dependencies = [] self.module_coupling = {} self.external_dependencies = {} - + # Initialize graphs self.import_graph = nx.DiGraph() self.module_graph = nx.DiGraph() self.call_graph = nx.DiGraph() self.class_hierarchy_graph = nx.DiGraph() - + # Perform analysis self._analyze_import_dependencies() self._find_circular_dependencies() @@ -155,10 +174,10 @@ def analyze(self) -> DependencyResult: self._analyze_external_dependencies() self._analyze_call_graph() self._analyze_class_hierarchy() - + # Return structured results return self._create_result() - + def _create_result(self) -> DependencyResult: """Create a structured result object from the analysis results.""" # Organize import dependencies @@ -170,7 +189,7 @@ def _create_result(self) -> DependencyResult: "import_name": dep.import_name, "is_external": dep.is_external, "is_relative": dep.is_relative, - "line_number": dep.line_number + "line_number": dep.line_number, } for dep in self.import_dependencies ], @@ -179,18 +198,24 @@ def _create_result(self) -> DependencyResult: "source_module": dep.source_module, "target_module": dep.target_module, "imports_count": dep.imports_count, - "is_circular": dep.is_circular + "is_circular": dep.is_circular, } for dep in self.module_dependencies ], "stats": { "total_imports": len(self.import_dependencies), - "internal_imports": sum(1 for dep in self.import_dependencies if not dep.is_external), - "external_imports": sum(1 for dep in self.import_dependencies if dep.is_external), - "relative_imports": sum(1 for dep in self.import_dependencies if dep.is_relative) - } + "internal_imports": sum( + 1 for dep in self.import_dependencies if not dep.is_external + ), + "external_imports": sum( + 1 for dep in self.import_dependencies if dep.is_external + ), + "relative_imports": sum( + 1 for dep in self.import_dependencies if dep.is_relative + ), + }, } - + # Organize circular dependencies circular_deps = { "circular_imports": [ @@ -198,18 +223,16 @@ def _create_result(self) -> DependencyResult: "files": dep.files, "modules": dep.modules, "length": dep.length, - "cycle_type": dep.cycle_type + "cycle_type": dep.cycle_type, } for dep in self.circular_dependencies ], "circular_dependencies_count": len(self.circular_dependencies), - "affected_modules": list(set( - module - for dep in self.circular_dependencies - for module in dep.modules - )) + "affected_modules": list({ + module for dep in self.circular_dependencies for module in dep.modules + }), } - + # Organize module coupling coupling = { "high_coupling_modules": [ @@ -218,7 +241,7 @@ def _create_result(self) -> DependencyResult: "coupling_ratio": data.coupling_ratio, "import_count": data.import_count, "file_count": data.file_count, - "imported_modules": data.imported_modules + "imported_modules": data.imported_modules, } for module, data in self.module_coupling.items() if data.coupling_ratio > 3 # Threshold for high coupling @@ -229,17 +252,20 @@ def _create_result(self) -> DependencyResult: "coupling_ratio": data.coupling_ratio, "import_count": data.import_count, "file_count": data.file_count, - "imported_modules": data.imported_modules + "imported_modules": data.imported_modules, } for module, data in self.module_coupling.items() - if data.coupling_ratio < 0.5 and data.file_count > 1 # Threshold for low coupling + if data.coupling_ratio < 0.5 + and data.file_count > 1 # Threshold for low coupling ], 
"average_coupling": ( - sum(data.coupling_ratio for data in self.module_coupling.values()) / - len(self.module_coupling) if self.module_coupling else 0 - ) + sum(data.coupling_ratio for data in self.module_coupling.values()) + / len(self.module_coupling) + if self.module_coupling + else 0 + ), } - + # Organize external dependencies external_deps = { "external_modules": list(self.external_dependencies.keys()), @@ -247,17 +273,17 @@ def _create_result(self) -> DependencyResult: { "module": module, "usage_count": data.usage_count, - "importing_files": data.importing_files[:10] # Limit to 10 files + "importing_files": data.importing_files[:10], # Limit to 10 files } for module, data in sorted( self.external_dependencies.items(), key=lambda x: x[1].usage_count, - reverse=True + reverse=True, )[:10] # Top 10 most used ], - "total_external_modules": len(self.external_dependencies) + "total_external_modules": len(self.external_dependencies), } - + # Create result object return DependencyResult( import_dependencies=import_deps, @@ -265,307 +291,361 @@ def _create_result(self) -> DependencyResult: module_coupling=coupling, external_dependencies=external_deps, call_graph=self._export_call_graph(), - class_hierarchy=self._export_class_hierarchy() + class_hierarchy=self._export_class_hierarchy(), ) - + def _analyze_import_dependencies(self) -> None: """Analyze import dependencies in the codebase.""" if not self.codebase: logger.error("Codebase not initialized") return - + # Process all files to extract import information for file in self.codebase.files: # Skip if no imports - if not hasattr(file, 'imports') or not file.imports: + if not hasattr(file, "imports") or not file.imports: continue - + # Get file path - file_path = str(file.file_path if hasattr(file, 'file_path') else - file.path if hasattr(file, 'path') else file) - + file_path = str( + file.file_path + if hasattr(file, "file_path") + else file.path + if hasattr(file, "path") + else file + ) + # Extract module name from file path - file_parts = file_path.split('/') - module_name = '/'.join(file_parts[:-1]) if len(file_parts) > 1 else file_parts[0] - + file_parts = file_path.split("/") + module_name = ( + "/".join(file_parts[:-1]) if len(file_parts) > 1 else file_parts[0] + ) + # Initialize module info in module graph if not self.module_graph.has_node(module_name): - self.module_graph.add_node(module_name, files=set([file_path])) + self.module_graph.add_node(module_name, files={file_path}) else: - self.module_graph.nodes[module_name]['files'].add(file_path) - + self.module_graph.nodes[module_name]["files"].add(file_path) + # Process imports for imp in file.imports: # Get import information - import_name = imp.name if hasattr(imp, 'name') else "unknown" - line_number = imp.line if hasattr(imp, 'line') else None - is_relative = hasattr(imp, 'is_relative') and imp.is_relative - + import_name = imp.name if hasattr(imp, "name") else "unknown" + line_number = imp.line if hasattr(imp, "line") else None + is_relative = hasattr(imp, "is_relative") and imp.is_relative + # Try to get imported file imported_file = None - if hasattr(imp, 'resolved_file'): + if hasattr(imp, "resolved_file"): imported_file = imp.resolved_file - elif hasattr(imp, 'resolved_symbol') and hasattr(imp.resolved_symbol, 'file'): + elif hasattr(imp, "resolved_symbol") and hasattr( + imp.resolved_symbol, "file" + ): imported_file = imp.resolved_symbol.file - + # Get imported file path and module if imported_file: # Get imported file path - imported_path = str(imported_file.file_path 
if hasattr(imported_file, 'file_path') else - imported_file.path if hasattr(imported_file, 'path') else imported_file) - + imported_path = str( + imported_file.file_path + if hasattr(imported_file, "file_path") + else imported_file.path + if hasattr(imported_file, "path") + else imported_file + ) + # Extract imported module name - imported_parts = imported_path.split('/') - imported_module = '/'.join(imported_parts[:-1]) if len(imported_parts) > 1 else imported_parts[0] - + imported_parts = imported_path.split("/") + imported_module = ( + "/".join(imported_parts[:-1]) + if len(imported_parts) > 1 + else imported_parts[0] + ) + # Check if external - is_external = hasattr(imported_file, 'is_external') and imported_file.is_external - + is_external = ( + hasattr(imported_file, "is_external") + and imported_file.is_external + ) + # Add to import dependencies - self.import_dependencies.append(ImportDependency( - source=file_path, - target=imported_path, - import_name=import_name, - is_external=is_external, - is_relative=is_relative, - line_number=line_number - )) - + self.import_dependencies.append( + ImportDependency( + source=file_path, + target=imported_path, + import_name=import_name, + is_external=is_external, + is_relative=is_relative, + line_number=line_number, + ) + ) + # Add to import graph - self.import_graph.add_edge(file_path, imported_path, - name=import_name, - external=is_external, - relative=is_relative) - + self.import_graph.add_edge( + file_path, + imported_path, + name=import_name, + external=is_external, + relative=is_relative, + ) + # Add to module graph if not is_external: # Initialize imported module if needed if not self.module_graph.has_node(imported_module): - self.module_graph.add_node(imported_module, files=set([imported_path])) + self.module_graph.add_node( + imported_module, files={imported_path} + ) else: - self.module_graph.nodes[imported_module]['files'].add(imported_path) - + self.module_graph.nodes[imported_module]["files"].add( + imported_path + ) + # Add module dependency if module_name != imported_module: # Skip self-imports if self.module_graph.has_edge(module_name, imported_module): # Increment count for existing edge - self.module_graph[module_name][imported_module]['count'] += 1 + self.module_graph[module_name][imported_module][ + "count" + ] += 1 else: # Add new edge - self.module_graph.add_edge(module_name, imported_module, count=1) + self.module_graph.add_edge( + module_name, imported_module, count=1 + ) else: # Handle external import that couldn't be resolved # Extract module name from import - if hasattr(imp, 'module_name') and imp.module_name: + if hasattr(imp, "module_name") and imp.module_name: external_module = imp.module_name is_external = True - + # Add to import dependencies - self.import_dependencies.append(ImportDependency( - source=file_path, - target=external_module, - import_name=import_name, - is_external=True, - is_relative=is_relative, - line_number=line_number - )) - + self.import_dependencies.append( + ImportDependency( + source=file_path, + target=external_module, + import_name=import_name, + is_external=True, + is_relative=is_relative, + line_number=line_number, + ) + ) + # Track external dependency - self._track_external_dependency(external_module, file_path, import_name) - + self._track_external_dependency( + external_module, file_path, import_name + ) + # Extract module dependencies from module graph for source, target, data in self.module_graph.edges(data=True): - self.module_dependencies.append(ModuleDependency( - 
source_module=source, - target_module=target, - imports_count=data.get('count', 1) - )) - + self.module_dependencies.append( + ModuleDependency( + source_module=source, + target_module=target, + imports_count=data.get("count", 1), + ) + ) + def _find_circular_dependencies(self) -> None: """Find circular dependencies in the codebase.""" # Find circular dependencies at the file level try: file_cycles = list(nx.simple_cycles(self.import_graph)) - + for cycle in file_cycles: if len(cycle) < 2: continue - + # Get the modules involved in the cycle modules = [] for file_path in cycle: - parts = file_path.split('/') - module = '/'.join(parts[:-1]) if len(parts) > 1 else parts[0] + parts = file_path.split("/") + module = "/".join(parts[:-1]) if len(parts) > 1 else parts[0] modules.append(module) - + # Create circular dependency circular_dep = CircularDependency( - files=cycle, - modules=modules, - length=len(cycle), - cycle_type="import" + files=cycle, modules=modules, length=len(cycle), cycle_type="import" ) - + self.circular_dependencies.append(circular_dep) - + # Create issue for this circular dependency - self.issues.add(Issue( - message=f"Circular import dependency detected between {len(cycle)} files", - severity=IssueSeverity.ERROR, - category=IssueCategory.DEPENDENCY_CYCLE, - location=CodeLocation( - file=cycle[0], - line=None - ), - suggestion="Refactor the code to break the circular dependency, potentially by extracting shared code into a separate module" - )) - + self.issues.add( + Issue( + message=f"Circular import dependency detected between {len(cycle)} files", + severity=IssueSeverity.ERROR, + category=IssueCategory.DEPENDENCY_CYCLE, + location=CodeLocation(file=cycle[0], line=None), + suggestion="Refactor the code to break the circular dependency, potentially by extracting shared code into a separate module", + ) + ) + # Mark modules as circular in module dependencies for i in range(len(modules)): source = modules[i] - target = modules[(i+1) % len(modules)] - + target = modules[(i + 1) % len(modules)] + for dep in self.module_dependencies: if dep.source_module == source and dep.target_module == target: dep.is_circular = True - + except Exception as e: - logger.error(f"Error finding circular dependencies: {e}") - + logger.exception(f"Error finding circular dependencies: {e}") + # Find circular dependencies at the module level try: module_cycles = list(nx.simple_cycles(self.module_graph)) - + for cycle in module_cycles: if len(cycle) < 2: continue - + # Find files for these modules files = [] for module in cycle: - if self.module_graph.has_node(module) and 'files' in self.module_graph.nodes[module]: - module_files = self.module_graph.nodes[module]['files'] + if ( + self.module_graph.has_node(module) + and "files" in self.module_graph.nodes[module] + ): + module_files = self.module_graph.nodes[module]["files"] if module_files: files.append(next(iter(module_files))) # Take first file - + # Only add if we haven't already found this cycle at the file level - if not any(set(cycle) == set(dep.modules) for dep in self.circular_dependencies): + if not any( + set(cycle) == set(dep.modules) for dep in self.circular_dependencies + ): circular_dep = CircularDependency( files=files, modules=cycle, length=len(cycle), - cycle_type="import" + cycle_type="import", ) - + self.circular_dependencies.append(circular_dep) - + # Create issue for this circular dependency - self.issues.add(Issue( - message=f"Circular dependency detected between modules: {', '.join(cycle)}", - severity=IssueSeverity.ERROR, - 
category=IssueCategory.DEPENDENCY_CYCLE, - location=CodeLocation( - file=files[0] if files else cycle[0], - line=None - ), - suggestion="Refactor the code to break the circular dependency" - )) - + self.issues.add( + Issue( + message=f"Circular dependency detected between modules: {', '.join(cycle)}", + severity=IssueSeverity.ERROR, + category=IssueCategory.DEPENDENCY_CYCLE, + location=CodeLocation( + file=files[0] if files else cycle[0], line=None + ), + suggestion="Refactor the code to break the circular dependency", + ) + ) + except Exception as e: - logger.error(f"Error finding circular module dependencies: {e}") - + logger.exception(f"Error finding circular module dependencies: {e}") + # If we have context, also find circular function call dependencies - if self.context and hasattr(self.context, '_graph'): + if self.context and hasattr(self.context, "_graph"): try: # Try to find function call cycles - function_nodes = [node for node in self.context.nodes if isinstance(node, Function)] - + function_nodes = [ + node for node in self.context.nodes if isinstance(node, Function) + ] + # Build function call graph call_graph = nx.DiGraph() - + for func in function_nodes: call_graph.add_node(func) - + # Add call edges for _, target, data in self.context.out_edges(func, data=True): - if isinstance(target, Function) and data.get('type') == EdgeType.CALLS: + if ( + isinstance(target, Function) + and data.get("type") == EdgeType.CALLS + ): call_graph.add_edge(func, target) - + # Find cycles func_cycles = list(nx.simple_cycles(call_graph)) - + for cycle in func_cycles: if len(cycle) < 2: continue - + # Get files and function names files = [] function_names = [] - + for func in cycle: - function_names.append(func.name if hasattr(func, 'name') else str(func)) - if hasattr(func, 'file') and hasattr(func.file, 'file_path'): + function_names.append( + func.name if hasattr(func, "name") else str(func) + ) + if hasattr(func, "file") and hasattr(func.file, "file_path"): files.append(str(func.file.file_path)) - + # Get modules modules = [] for file_path in files: - parts = file_path.split('/') - module = '/'.join(parts[:-1]) if len(parts) > 1 else parts[0] + parts = file_path.split("/") + module = "/".join(parts[:-1]) if len(parts) > 1 else parts[0] modules.append(module) - + # Create circular dependency circular_dep = CircularDependency( files=files, modules=modules, length=len(cycle), - cycle_type="function_call" + cycle_type="function_call", ) - + self.circular_dependencies.append(circular_dep) - + # Create issue for this circular dependency - self.issues.add(Issue( - message=f"Circular function call dependency detected: {' -> '.join(function_names)}", - severity=IssueSeverity.ERROR if len(cycle) > 2 else IssueSeverity.WARNING, - category=IssueCategory.DEPENDENCY_CYCLE, - location=CodeLocation( - file=files[0] if files else "unknown", - line=None - ), - suggestion="Refactor the code to eliminate the circular function calls" - )) - + self.issues.add( + Issue( + message=f"Circular function call dependency detected: {' -> '.join(function_names)}", + severity=IssueSeverity.ERROR + if len(cycle) > 2 + else IssueSeverity.WARNING, + category=IssueCategory.DEPENDENCY_CYCLE, + location=CodeLocation( + file=files[0] if files else "unknown", line=None + ), + suggestion="Refactor the code to eliminate the circular function calls", + ) + ) + except Exception as e: - logger.error(f"Error finding circular function call dependencies: {e}") - + logger.exception( + f"Error finding circular function call dependencies: 
{e}" + ) + def _analyze_module_coupling(self) -> None: """Analyze module coupling in the codebase.""" # Use module graph to calculate coupling metrics for module in self.module_graph.nodes(): # Get files in this module - files = self.module_graph.nodes[module].get('files', set()) + files = self.module_graph.nodes[module].get("files", set()) file_count = len(files) - + # Get imported modules imported_modules = [] for _, target in self.module_graph.out_edges(module): imported_modules.append(target) - + # Calculate metrics import_count = len(imported_modules) coupling_ratio = import_count / file_count if file_count > 0 else 0 - + # Find exported symbols if we have the context exported_symbols = [] if self.context: for file_path in files: file = self.context.get_file(file_path) - if file and hasattr(file, 'exports'): + if file and hasattr(file, "exports"): for export in file.exports: - if hasattr(export, 'name'): + if hasattr(export, "name"): exported_symbols.append(export.name) - + # Create module coupling data self.module_coupling[module] = ModuleCoupling( module=module, @@ -573,22 +653,23 @@ def _analyze_module_coupling(self) -> None: imported_modules=imported_modules, import_count=import_count, coupling_ratio=coupling_ratio, - exported_symbols=exported_symbols + exported_symbols=exported_symbols, ) - + # Check for high coupling if coupling_ratio > 3 and file_count > 1: # Threshold for high coupling - self.issues.add(Issue( - message=f"High module coupling: {module} has a coupling ratio of {coupling_ratio:.2f}", - severity=IssueSeverity.WARNING, - category=IssueCategory.DEPENDENCY_CYCLE, - location=CodeLocation( - file=next(iter(files)) if files else module, - line=None - ), - suggestion="Consider refactoring to reduce the number of dependencies" - )) - + self.issues.add( + Issue( + message=f"High module coupling: {module} has a coupling ratio of {coupling_ratio:.2f}", + severity=IssueSeverity.WARNING, + category=IssueCategory.DEPENDENCY_CYCLE, + location=CodeLocation( + file=next(iter(files)) if files else module, line=None + ), + suggestion="Consider refactoring to reduce the number of dependencies", + ) + ) + def _analyze_external_dependencies(self) -> None: """Analyze external dependencies in the codebase.""" # Collect external dependencies from import dependencies @@ -597,264 +678,331 @@ def _analyze_external_dependencies(self) -> None: external_name = dep.target import_name = dep.import_name file_path = dep.source - + self._track_external_dependency(external_name, file_path, import_name) - - def _track_external_dependency(self, module_name: str, file_path: str, import_name: Optional[str] = None) -> None: + + def _track_external_dependency( + self, module_name: str, file_path: str, import_name: str | None = None + ) -> None: """Track an external dependency.""" if module_name not in self.external_dependencies: self.external_dependencies[module_name] = ExternalDependency( module_name=module_name, usage_count=1, importing_files=[file_path], - imported_symbols=[import_name] if import_name else [] + imported_symbols=[import_name] if import_name else [], ) else: # Update existing dependency self.external_dependencies[module_name].usage_count += 1 - + if file_path not in self.external_dependencies[module_name].importing_files: - self.external_dependencies[module_name].importing_files.append(file_path) - - if import_name and import_name not in self.external_dependencies[module_name].imported_symbols: - self.external_dependencies[module_name].imported_symbols.append(import_name) - + 
self.external_dependencies[module_name].importing_files.append( + file_path + ) + + if ( + import_name + and import_name + not in self.external_dependencies[module_name].imported_symbols + ): + self.external_dependencies[module_name].imported_symbols.append( + import_name + ) + def _analyze_call_graph(self) -> None: """Analyze function call relationships.""" # Skip if we don't have context if not self.context: return - + # Find all functions functions = [node for node in self.context.nodes if isinstance(node, Function)] - + # Build call graph for func in functions: - func_name = func.name if hasattr(func, 'name') else str(func) - func_path = str(func.file.file_path) if hasattr(func, 'file') and hasattr(func.file, 'file_path') else "unknown" - + func_name = func.name if hasattr(func, "name") else str(func) + func_path = ( + str(func.file.file_path) + if hasattr(func, "file") and hasattr(func.file, "file_path") + else "unknown" + ) + # Add node to call graph if not self.call_graph.has_node(func_name): self.call_graph.add_node(func_name, path=func_path, function=func) - + # Process outgoing calls - if hasattr(func, 'calls'): + if hasattr(func, "calls"): for call in func.calls: called_func = None - + # Try to resolve the call - if hasattr(call, 'resolved_symbol') and call.resolved_symbol: + if hasattr(call, "resolved_symbol") and call.resolved_symbol: called_func = call.resolved_symbol - elif hasattr(call, 'name'): + elif hasattr(call, "name"): # Try to find by name for other_func in functions: - if hasattr(other_func, 'name') and other_func.name == call.name: + if ( + hasattr(other_func, "name") + and other_func.name == call.name + ): called_func = other_func break - + if called_func: - called_name = called_func.name if hasattr(called_func, 'name') else str(called_func) - called_path = str(called_func.file.file_path) if hasattr(called_func, 'file') and hasattr(called_func.file, 'file_path') else "unknown" - + called_name = ( + called_func.name + if hasattr(called_func, "name") + else str(called_func) + ) + called_path = ( + str(called_func.file.file_path) + if hasattr(called_func, "file") + and hasattr(called_func.file, "file_path") + else "unknown" + ) + # Add target node if needed if not self.call_graph.has_node(called_name): - self.call_graph.add_node(called_name, path=called_path, function=called_func) - + self.call_graph.add_node( + called_name, path=called_path, function=called_func + ) + # Add edge to call graph - self.call_graph.add_edge(func_name, called_name, source_path=func_path, target_path=called_path) - + self.call_graph.add_edge( + func_name, + called_name, + source_path=func_path, + target_path=called_path, + ) + # Check for recursive calls if self.call_graph.has_edge(func_name, func_name): - self.issues.add(Issue( - message=f"Recursive function: {func_name}", - severity=IssueSeverity.INFO, - category=IssueCategory.DEPENDENCY_CYCLE, - location=CodeLocation( - file=func_path, - line=func.line if hasattr(func, 'line') else None - ), - symbol=func_name - )) - + self.issues.add( + Issue( + message=f"Recursive function: {func_name}", + severity=IssueSeverity.INFO, + category=IssueCategory.DEPENDENCY_CYCLE, + location=CodeLocation( + file=func_path, + line=func.line if hasattr(func, "line") else None, + ), + symbol=func_name, + ) + ) + # Analyze call chains self._analyze_deep_call_chains() - + def _analyze_deep_call_chains(self) -> None: """Analyze deep call chains in the call graph.""" # Find entry points (functions not called by others) - entry_points = [node for node in 
self.call_graph.nodes() - if self.call_graph.in_degree(node) == 0] - + entry_points = [ + node + for node in self.call_graph.nodes() + if self.call_graph.in_degree(node) == 0 + ] + # Find leaf functions (functions that don't call others) - leaf_functions = [node for node in self.call_graph.nodes() - if self.call_graph.out_degree(node) == 0] - + leaf_functions = [ + node + for node in self.call_graph.nodes() + if self.call_graph.out_degree(node) == 0 + ] + # Look for long paths long_chains = [] - + for entry in entry_points: for leaf in leaf_functions: try: if nx.has_path(self.call_graph, entry, leaf): path = nx.shortest_path(self.call_graph, entry, leaf) - + if len(path) > 5: # Threshold for "deep" call chains long_chains.append({ "entry_point": entry, "length": len(path), - "path": path + "path": path, }) - + # Create issue for very deep call chains if len(path) > 8: # Threshold for concerning depth - entry_path = self.call_graph.nodes[entry].get('path', 'unknown') - - self.issues.add(Issue( - message=f"Deep call chain starting from {entry} ({len(path)} levels deep)", - severity=IssueSeverity.WARNING, - category=IssueCategory.COMPLEXITY, - location=CodeLocation( - file=entry_path, - line=None - ), - suggestion="Consider refactoring to reduce call depth" - )) + entry_path = self.call_graph.nodes[entry].get( + "path", "unknown" + ) + + self.issues.add( + Issue( + message=f"Deep call chain starting from {entry} ({len(path)} levels deep)", + severity=IssueSeverity.WARNING, + category=IssueCategory.COMPLEXITY, + location=CodeLocation( + file=entry_path, line=None + ), + suggestion="Consider refactoring to reduce call depth", + ) + ) except nx.NetworkXNoPath: pass - + # Sort chains by length - long_chains.sort(key=lambda x: x['length'], reverse=True) - + long_chains.sort(key=lambda x: x["length"], reverse=True) + # Store top 10 longest chains self.long_call_chains = long_chains[:10] - + def _analyze_class_hierarchy(self) -> None: """Analyze class inheritance hierarchy.""" # Skip if we don't have context if not self.context: return - + # Find all classes classes = [node for node in self.context.nodes if isinstance(node, Class)] - + # Build inheritance graph for cls in classes: - cls_name = cls.name if hasattr(cls, 'name') else str(cls) - cls_path = str(cls.file.file_path) if hasattr(cls, 'file') and hasattr(cls.file, 'file_path') else "unknown" - + cls_name = cls.name if hasattr(cls, "name") else str(cls) + cls_path = ( + str(cls.file.file_path) + if hasattr(cls, "file") and hasattr(cls.file, "file_path") + else "unknown" + ) + # Add node to class graph if not self.class_hierarchy_graph.has_node(cls_name): - self.class_hierarchy_graph.add_node(cls_name, path=cls_path, class_obj=cls) - + self.class_hierarchy_graph.add_node( + cls_name, path=cls_path, class_obj=cls + ) + # Process superclasses - if hasattr(cls, 'superclasses'): + if hasattr(cls, "superclasses"): for superclass in cls.superclasses: - super_name = superclass.name if hasattr(superclass, 'name') else str(superclass) - super_path = str(superclass.file.file_path) if hasattr(superclass, 'file') and hasattr(superclass.file, 'file_path') else "unknown" - + super_name = ( + superclass.name + if hasattr(superclass, "name") + else str(superclass) + ) + super_path = ( + str(superclass.file.file_path) + if hasattr(superclass, "file") + and hasattr(superclass.file, "file_path") + else "unknown" + ) + # Add superclass node if needed if not self.class_hierarchy_graph.has_node(super_name): - self.class_hierarchy_graph.add_node(super_name, 
path=super_path, class_obj=superclass) - + self.class_hierarchy_graph.add_node( + super_name, path=super_path, class_obj=superclass + ) + # Add inheritance edge self.class_hierarchy_graph.add_edge(cls_name, super_name) - + # Check for deep inheritance for cls_name in self.class_hierarchy_graph.nodes(): # Calculate inheritance depth depth = 0 current = cls_name - + while self.class_hierarchy_graph.out_degree(current) > 0: depth += 1 successors = list(self.class_hierarchy_graph.successors(current)) if not successors: break current = successors[0] # Follow first superclass - + # Check if depth exceeds threshold if depth > 3: # Threshold for deep inheritance - cls_path = self.class_hierarchy_graph.nodes[cls_name].get('path', 'unknown') - - self.issues.add(Issue( - message=f"Deep inheritance: {cls_name} has an inheritance depth of {depth}", - severity=IssueSeverity.WARNING, - category=IssueCategory.DEPENDENCY_CYCLE, - location=CodeLocation( - file=cls_path, - line=None - ), - suggestion="Consider using composition instead of deep inheritance" - )) - - def _export_call_graph(self) -> Dict[str, Any]: + cls_path = self.class_hierarchy_graph.nodes[cls_name].get( + "path", "unknown" + ) + + self.issues.add( + Issue( + message=f"Deep inheritance: {cls_name} has an inheritance depth of {depth}", + severity=IssueSeverity.WARNING, + category=IssueCategory.DEPENDENCY_CYCLE, + location=CodeLocation(file=cls_path, line=None), + suggestion="Consider using composition instead of deep inheritance", + ) + ) + + def _export_call_graph(self) -> dict[str, Any]: """Export the call graph for the analysis result.""" nodes = [] edges = [] - + # Add nodes for node in self.call_graph.nodes(): node_data = self.call_graph.nodes[node] - nodes.append({ - "id": node, - "path": node_data.get('path', 'unknown') - }) - + nodes.append({"id": node, "path": node_data.get("path", "unknown")}) + # Add edges for source, target in self.call_graph.edges(): edge_data = self.call_graph.get_edge_data(source, target) edges.append({ "source": source, "target": target, - "source_path": edge_data.get('source_path', 'unknown'), - "target_path": edge_data.get('target_path', 'unknown') + "source_path": edge_data.get("source_path", "unknown"), + "target_path": edge_data.get("target_path", "unknown"), }) - + # Find entry points and leaf functions - entry_points = [node for node in self.call_graph.nodes() - if self.call_graph.in_degree(node) == 0] - - leaf_functions = [node for node in self.call_graph.nodes() - if self.call_graph.out_degree(node) == 0] - + entry_points = [ + node + for node in self.call_graph.nodes() + if self.call_graph.in_degree(node) == 0 + ] + + leaf_functions = [ + node + for node in self.call_graph.nodes() + if self.call_graph.out_degree(node) == 0 + ] + return { "nodes": nodes, "edges": edges, "entry_points": entry_points, "leaf_functions": leaf_functions, - "deep_call_chains": self.long_call_chains if hasattr(self, 'long_call_chains') else [] + "deep_call_chains": self.long_call_chains + if hasattr(self, "long_call_chains") + else [], } - - def _export_class_hierarchy(self) -> Dict[str, Any]: + + def _export_class_hierarchy(self) -> dict[str, Any]: """Export the class hierarchy for the analysis result.""" nodes = [] edges = [] - + # Add nodes for node in self.class_hierarchy_graph.nodes(): node_data = self.class_hierarchy_graph.nodes[node] - nodes.append({ - "id": node, - "path": node_data.get('path', 'unknown') - }) - + nodes.append({"id": node, "path": node_data.get("path", "unknown")}) + # Add edges for source, target 
in self.class_hierarchy_graph.edges(): - edges.append({ - "source": source, - "target": target - }) - + edges.append({"source": source, "target": target}) + # Find root classes (no superclasses) and leaf classes (no subclasses) - root_classes = [node for node in self.class_hierarchy_graph.nodes() - if self.class_hierarchy_graph.out_degree(node) == 0] - - leaf_classes = [node for node in self.class_hierarchy_graph.nodes() - if self.class_hierarchy_graph.in_degree(node) == 0] - + root_classes = [ + node + for node in self.class_hierarchy_graph.nodes() + if self.class_hierarchy_graph.out_degree(node) == 0 + ] + + leaf_classes = [ + node + for node in self.class_hierarchy_graph.nodes() + if self.class_hierarchy_graph.in_degree(node) == 0 + ] + return { "nodes": nodes, "edges": edges, "root_classes": root_classes, - "leaf_classes": leaf_classes - } \ No newline at end of file + "leaf_classes": leaf_classes, + } diff --git a/codegen-on-oss/codegen_on_oss/analyzers/dependency_analyzer.py b/codegen-on-oss/codegen_on_oss/analyzers/dependency_analyzer.py index 56eff1440..a18c23d2d 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/dependency_analyzer.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/dependency_analyzer.py @@ -6,52 +6,61 @@ import relationships, circular dependencies, and module coupling. """ -import os -import sys import logging +from typing import Any + import networkx as nx -from typing import Dict, List, Set, Tuple, Any, Optional, Union from codegen_on_oss.analyzers.base_analyzer import BaseCodeAnalyzer -from codegen_on_oss.analyzers.issue_types import Issue, IssueSeverity, AnalysisType, IssueCategory +from codegen_on_oss.analyzers.issue_types import ( + AnalysisType, + Issue, + IssueCategory, + IssueSeverity, +) # Configure logging logger = logging.getLogger(__name__) + class DependencyAnalyzer(BaseCodeAnalyzer): """ Analyzer for codebase dependencies. - + This analyzer detects issues related to dependencies, including import relationships, circular dependencies, and module coupling. """ - - def analyze(self, analysis_type: AnalysisType = AnalysisType.DEPENDENCY) -> Dict[str, Any]: + + def analyze( + self, analysis_type: AnalysisType = AnalysisType.DEPENDENCY + ) -> dict[str, Any]: """ Perform dependency analysis on the codebase. 
- + Args: analysis_type: Type of analysis to perform - + Returns: Dictionary containing analysis results """ if not self.base_codebase: raise ValueError("Codebase not initialized") - + result = { "metadata": { "analysis_time": str(datetime.now()), "analysis_type": analysis_type, - "repo_name": getattr(self.base_codebase.ctx, 'repo_name', None), - "language": str(getattr(self.base_codebase.ctx, 'programming_language', None)), + "repo_name": getattr(self.base_codebase.ctx, "repo_name", None), + "language": str( + getattr(self.base_codebase.ctx, "programming_language", None) + ), }, "summary": {}, } - + # Reset issues list self.issues = [] - + # Perform appropriate analysis based on type if analysis_type == AnalysisType.DEPENDENCY: # Run all dependency checks @@ -59,33 +68,47 @@ def analyze(self, analysis_type: AnalysisType = AnalysisType.DEPENDENCY) -> Dict result["circular_dependencies"] = self._find_circular_dependencies() result["module_coupling"] = self._analyze_module_coupling() result["external_dependencies"] = self._analyze_external_dependencies() - + # Add issues to the result result["issues"] = [issue.to_dict() for issue in self.issues] result["issue_counts"] = { "total": len(self.issues), "by_severity": { - "critical": sum(1 for issue in self.issues if issue.severity == IssueSeverity.CRITICAL), - "error": sum(1 for issue in self.issues if issue.severity == IssueSeverity.ERROR), - "warning": sum(1 for issue in self.issues if issue.severity == IssueSeverity.WARNING), - "info": sum(1 for issue in self.issues if issue.severity == IssueSeverity.INFO), + "critical": sum( + 1 + for issue in self.issues + if issue.severity == IssueSeverity.CRITICAL + ), + "error": sum( + 1 for issue in self.issues if issue.severity == IssueSeverity.ERROR + ), + "warning": sum( + 1 + for issue in self.issues + if issue.severity == IssueSeverity.WARNING + ), + "info": sum( + 1 for issue in self.issues if issue.severity == IssueSeverity.INFO + ), }, "by_category": { - category.value: sum(1 for issue in self.issues if issue.category == category) + category.value: sum( + 1 for issue in self.issues if issue.category == category + ) for category in IssueCategory if any(issue.category == category for issue in self.issues) - } + }, } - + # Store results self.results = result - + return result - - def _analyze_import_dependencies(self) -> Dict[str, Any]: + + def _analyze_import_dependencies(self) -> dict[str, Any]: """ Analyze import dependencies in the codebase. 
- + Returns: Dictionary containing import dependencies analysis results """ @@ -98,210 +121,251 @@ def _analyze_import_dependencies(self) -> Dict[str, Any]: "total_imports": 0, "internal_imports": 0, "external_imports": 0, - "relative_imports": 0 - } + "relative_imports": 0, + }, } - + # Create a directed graph for module dependencies G = nx.DiGraph() - + # Track import counts module_imports = {} # modules importing others module_imported = {} # modules being imported - + # Process all files to extract import information for file in self.base_codebase.files: # Skip if no imports - if not hasattr(file, 'imports') or not file.imports: + if not hasattr(file, "imports") or not file.imports: continue - + # Get file path - file_path = file.filepath if hasattr(file, 'filepath') else str(file.path) if hasattr(file, 'path') else str(file) - + file_path = ( + file.filepath + if hasattr(file, "filepath") + else str(file.path) + if hasattr(file, "path") + else str(file) + ) + # Extract module name from file path - file_parts = file_path.split('/') - module_name = '/'.join(file_parts[:-1]) if len(file_parts) > 1 else file_parts[0] - + file_parts = file_path.split("/") + module_name = ( + "/".join(file_parts[:-1]) if len(file_parts) > 1 else file_parts[0] + ) + # Initialize import counts if module_name not in module_imports: module_imports[module_name] = 0 - + # Process imports for imp in file.imports: import_deps["dependency_stats"]["total_imports"] += 1 - + # Get imported module information imported_file = None imported_module = "unknown" is_external = False - - if hasattr(imp, 'resolved_file'): + + if hasattr(imp, "resolved_file"): imported_file = imp.resolved_file - elif hasattr(imp, 'resolved_symbol') and hasattr(imp.resolved_symbol, 'file'): + elif hasattr(imp, "resolved_symbol") and hasattr( + imp.resolved_symbol, "file" + ): imported_file = imp.resolved_symbol.file - + if imported_file: # Get imported file path - imported_path = imported_file.filepath if hasattr(imported_file, 'filepath') else str(imported_file.path) if hasattr(imported_file, 'path') else str(imported_file) - + imported_path = ( + imported_file.filepath + if hasattr(imported_file, "filepath") + else str(imported_file.path) + if hasattr(imported_file, "path") + else str(imported_file) + ) + # Extract imported module name - imported_parts = imported_path.split('/') - imported_module = '/'.join(imported_parts[:-1]) if len(imported_parts) > 1 else imported_parts[0] - + imported_parts = imported_path.split("/") + imported_module = ( + "/".join(imported_parts[:-1]) + if len(imported_parts) > 1 + else imported_parts[0] + ) + # Check if external - is_external = hasattr(imported_file, 'is_external') and imported_file.is_external + is_external = ( + hasattr(imported_file, "is_external") + and imported_file.is_external + ) else: # If we couldn't resolve the import, use the import name - imported_module = imp.name if hasattr(imp, 'name') else "unknown" - + imported_module = imp.name if hasattr(imp, "name") else "unknown" + # Assume external if we couldn't resolve is_external = True - + # Update import type counts if is_external: import_deps["dependency_stats"]["external_imports"] += 1 else: import_deps["dependency_stats"]["internal_imports"] += 1 - + # Check if relative import - if hasattr(imp, 'is_relative') and imp.is_relative: + if hasattr(imp, "is_relative") and imp.is_relative: import_deps["dependency_stats"]["relative_imports"] += 1 - + # Update module import counts module_imports[module_name] += 1 - + if imported_module not in 
module_imported: module_imported[imported_module] = 0 module_imported[imported_module] += 1 - + # Add to dependency graph if module_name != imported_module: # Skip self-imports G.add_edge(module_name, imported_module) - + # Add to file dependencies list import_deps["file_dependencies"].append({ "source_file": file_path, "target_file": imported_path if imported_file else "unknown", - "import_name": imp.name if hasattr(imp, 'name') else "unknown", - "is_external": is_external + "import_name": imp.name if hasattr(imp, "name") else "unknown", + "is_external": is_external, }) - + # Extract module dependencies from graph for source, target in G.edges(): import_deps["module_dependencies"].append({ "source_module": source, - "target_module": target + "target_module": target, }) - + # Find most imported modules most_imported = sorted( - [(module, count) for module, count in module_imported.items()], + module_imported.items(), key=lambda x: x[1], - reverse=True + reverse=True, ) - + for module, count in most_imported[:10]: # Top 10 import_deps["most_imported_modules"].append({ "module": module, - "import_count": count + "import_count": count, }) - + # Find modules that import the most most_importing = sorted( - [(module, count) for module, count in module_imports.items()], + module_imports.items(), key=lambda x: x[1], - reverse=True + reverse=True, ) - + for module, count in most_importing[:10]: # Top 10 import_deps["most_importing_modules"].append({ "module": module, - "import_count": count + "import_count": count, }) - + return import_deps - - def _find_circular_dependencies(self) -> Dict[str, Any]: + + def _find_circular_dependencies(self) -> dict[str, Any]: """ Find circular dependencies in the codebase. - + Returns: Dictionary containing circular dependencies analysis results """ circular_deps = { "circular_imports": [], "circular_dependencies_count": 0, - "affected_modules": set() + "affected_modules": set(), } - + # Create dependency graph if not already available G = nx.DiGraph() - + # Process all files to build dependency graph for file in self.base_codebase.files: # Skip if no imports - if not hasattr(file, 'imports') or not file.imports: + if not hasattr(file, "imports") or not file.imports: continue - + # Get file path - file_path = file.filepath if hasattr(file, 'filepath') else str(file.path) if hasattr(file, 'path') else str(file) - + file_path = ( + file.filepath + if hasattr(file, "filepath") + else str(file.path) + if hasattr(file, "path") + else str(file) + ) + # Process imports for imp in file.imports: # Get imported file imported_file = None - - if hasattr(imp, 'resolved_file'): + + if hasattr(imp, "resolved_file"): imported_file = imp.resolved_file - elif hasattr(imp, 'resolved_symbol') and hasattr(imp.resolved_symbol, 'file'): + elif hasattr(imp, "resolved_symbol") and hasattr( + imp.resolved_symbol, "file" + ): imported_file = imp.resolved_symbol.file - + if imported_file: # Get imported file path - imported_path = imported_file.filepath if hasattr(imported_file, 'filepath') else str(imported_file.path) if hasattr(imported_file, 'path') else str(imported_file) - + imported_path = ( + imported_file.filepath + if hasattr(imported_file, "filepath") + else str(imported_file.path) + if hasattr(imported_file, "path") + else str(imported_file) + ) + # Add edge to graph G.add_edge(file_path, imported_path) - + # Find cycles in the graph try: cycles = list(nx.simple_cycles(G)) - + for cycle in cycles: circular_deps["circular_imports"].append({ "files": cycle, - "length": len(cycle) + 
"length": len(cycle), }) - + # Add affected modules to set for file_path in cycle: - module_path = '/'.join(file_path.split('/')[:-1]) + module_path = "/".join(file_path.split("/")[:-1]) circular_deps["affected_modules"].add(module_path) - + # Add issue if len(cycle) >= 2: - self.add_issue(Issue( - file=cycle[0], - line=None, - message=f"Circular dependency detected between {len(cycle)} files", - severity=IssueSeverity.ERROR, - category=IssueCategory.DEPENDENCY_CYCLE, - suggestion="Break the circular dependency by refactoring the code" - )) - + self.add_issue( + Issue( + file=cycle[0], + line=None, + message=f"Circular dependency detected between {len(cycle)} files", + severity=IssueSeverity.ERROR, + category=IssueCategory.DEPENDENCY_CYCLE, + suggestion="Break the circular dependency by refactoring the code", + ) + ) + except Exception as e: - logger.error(f"Error finding circular dependencies: {e}") - + logger.exception(f"Error finding circular dependencies: {e}") + # Update cycle count - circular_deps["circular_dependencies_count"] = len(circular_deps["circular_imports"]) + circular_deps["circular_dependencies_count"] = len( + circular_deps["circular_imports"] + ) circular_deps["affected_modules"] = list(circular_deps["affected_modules"]) - + return circular_deps - - def _analyze_module_coupling(self) -> Dict[str, Any]: + + def _analyze_module_coupling(self) -> dict[str, Any]: """ Analyze module coupling in the codebase. - + Returns: Dictionary containing module coupling analysis results """ @@ -309,176 +373,212 @@ def _analyze_module_coupling(self) -> Dict[str, Any]: "high_coupling_modules": [], "low_coupling_modules": [], "coupling_metrics": {}, - "average_coupling": 0.0 + "average_coupling": 0.0, } - + # Create module dependency graphs modules = {} # Module name -> set of imported modules module_files = {} # Module name -> list of files - + # Process all files to extract module information for file in self.base_codebase.files: # Get file path - file_path = file.filepath if hasattr(file, 'filepath') else str(file.path) if hasattr(file, 'path') else str(file) - + file_path = ( + file.filepath + if hasattr(file, "filepath") + else str(file.path) + if hasattr(file, "path") + else str(file) + ) + # Extract module name from file path - module_parts = file_path.split('/') - module_name = '/'.join(module_parts[:-1]) if len(module_parts) > 1 else module_parts[0] - + module_parts = file_path.split("/") + module_name = ( + "/".join(module_parts[:-1]) + if len(module_parts) > 1 + else module_parts[0] + ) + # Initialize module structures if module_name not in modules: modules[module_name] = set() module_files[module_name] = [] - + module_files[module_name].append(file_path) - + # Skip if no imports - if not hasattr(file, 'imports') or not file.imports: + if not hasattr(file, "imports") or not file.imports: continue - + # Process imports for imp in file.imports: # Get imported file imported_file = None - - if hasattr(imp, 'resolved_file'): + + if hasattr(imp, "resolved_file"): imported_file = imp.resolved_file - elif hasattr(imp, 'resolved_symbol') and hasattr(imp.resolved_symbol, 'file'): + elif hasattr(imp, "resolved_symbol") and hasattr( + imp.resolved_symbol, "file" + ): imported_file = imp.resolved_symbol.file - + if imported_file: # Get imported file path - imported_path = imported_file.filepath if hasattr(imported_file, 'filepath') else str(imported_file.path) if hasattr(imported_file, 'path') else str(imported_file) - + imported_path = ( + imported_file.filepath + if hasattr(imported_file, 
"filepath") + else str(imported_file.path) + if hasattr(imported_file, "path") + else str(imported_file) + ) + # Extract imported module name - imported_parts = imported_path.split('/') - imported_module = '/'.join(imported_parts[:-1]) if len(imported_parts) > 1 else imported_parts[0] - + imported_parts = imported_path.split("/") + imported_module = ( + "/".join(imported_parts[:-1]) + if len(imported_parts) > 1 + else imported_parts[0] + ) + # Skip self-imports if imported_module != module_name: modules[module_name].add(imported_module) - + # Calculate coupling metrics for each module total_coupling = 0.0 module_count = 0 - + for module_name, imported_modules in modules.items(): # Calculate metrics file_count = len(module_files[module_name]) import_count = len(imported_modules) - + # Calculate coupling ratio (imports per file) coupling_ratio = import_count / file_count if file_count > 0 else 0 - + # Add to metrics coupling["coupling_metrics"][module_name] = { "files": file_count, "imported_modules": list(imported_modules), "import_count": import_count, - "coupling_ratio": coupling_ratio + "coupling_ratio": coupling_ratio, } - + # Track total for average total_coupling += coupling_ratio module_count += 1 - + # Categorize coupling if coupling_ratio > 3: # Threshold for "high coupling" coupling["high_coupling_modules"].append({ "module": module_name, "coupling_ratio": coupling_ratio, "import_count": import_count, - "file_count": file_count + "file_count": file_count, }) - + # Add issue - self.add_issue(Issue( - file=module_files[module_name][0] if module_files[module_name] else module_name, - line=None, - message=f"High module coupling: {coupling_ratio:.2f} imports per file", - severity=IssueSeverity.WARNING, - category=IssueCategory.DEPENDENCY_CYCLE, - suggestion="Consider refactoring to reduce coupling between modules" - )) - elif coupling_ratio < 0.5 and file_count > 1: # Threshold for "low coupling" + self.add_issue( + Issue( + file=module_files[module_name][0] + if module_files[module_name] + else module_name, + line=None, + message=f"High module coupling: {coupling_ratio:.2f} imports per file", + severity=IssueSeverity.WARNING, + category=IssueCategory.DEPENDENCY_CYCLE, + suggestion="Consider refactoring to reduce coupling between modules", + ) + ) + elif ( + coupling_ratio < 0.5 and file_count > 1 + ): # Threshold for "low coupling" coupling["low_coupling_modules"].append({ "module": module_name, "coupling_ratio": coupling_ratio, "import_count": import_count, - "file_count": file_count + "file_count": file_count, }) - + # Calculate average coupling - coupling["average_coupling"] = total_coupling / module_count if module_count > 0 else 0.0 - + coupling["average_coupling"] = ( + total_coupling / module_count if module_count > 0 else 0.0 + ) + # Sort coupling lists - coupling["high_coupling_modules"].sort(key=lambda x: x["coupling_ratio"], reverse=True) + coupling["high_coupling_modules"].sort( + key=lambda x: x["coupling_ratio"], reverse=True + ) coupling["low_coupling_modules"].sort(key=lambda x: x["coupling_ratio"]) - + return coupling - - def _analyze_external_dependencies(self) -> Dict[str, Any]: + + def _analyze_external_dependencies(self) -> dict[str, Any]: """ Analyze external dependencies in the codebase. 
- + Returns: Dictionary containing external dependencies analysis results """ external_deps = { "external_modules": [], "external_module_usage": {}, - "most_used_external_modules": [] + "most_used_external_modules": [], } - + # Track external module usage external_usage = {} # Module name -> usage count - + # Process all imports to find external dependencies for file in self.base_codebase.files: # Skip if no imports - if not hasattr(file, 'imports') or not file.imports: + if not hasattr(file, "imports") or not file.imports: continue - + # Process imports for imp in file.imports: # Check if external import is_external = False external_name = None - - if hasattr(imp, 'module_name'): + + if hasattr(imp, "module_name"): external_name = imp.module_name - + # Check if this is an external module - if hasattr(imp, 'is_external'): + if hasattr(imp, "is_external"): is_external = imp.is_external - elif external_name and '.' not in external_name and '/' not in external_name: + elif ( + external_name + and "." not in external_name + and "/" not in external_name + ): # Simple heuristic: single-word module names without dots or slashes # are likely external modules is_external = True - + if is_external and external_name: # Add to external modules list if not already there if external_name not in external_usage: external_usage[external_name] = 0 external_deps["external_modules"].append(external_name) - + external_usage[external_name] += 1 - + # Add usage counts for module, count in external_usage.items(): external_deps["external_module_usage"][module] = count - + # Find most used external modules most_used = sorted( - [(module, count) for module, count in external_usage.items()], + external_usage.items(), key=lambda x: x[1], - reverse=True + reverse=True, ) - + for module, count in most_used[:10]: # Top 10 external_deps["most_used_external_modules"].append({ "module": module, - "usage_count": count + "usage_count": count, }) - - return external_deps \ No newline at end of file + + return external_deps diff --git a/codegen-on-oss/codegen_on_oss/analyzers/diff_lite.py b/codegen-on-oss/codegen_on_oss/analyzers/diff_lite.py new file mode 100644 index 000000000..934b68d70 --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analyzers/diff_lite.py @@ -0,0 +1,163 @@ +from enum import IntEnum, auto +from os import PathLike +from pathlib import Path +from typing import NamedTuple, Self + +from git import Diff +from watchfiles import Change + + +class ChangeType(IntEnum): + """ + Enumeration of change types for tracking file modifications. + + Attributes: + Modified: File content has been modified + Removed: File has been deleted + Renamed: File has been renamed + Added: New file has been added + """ + + Modified = auto() + Removed = auto() + Renamed = auto() + Added = auto() + + @staticmethod + def from_watch_change_type(change_type: Change) -> "ChangeType": + """ + Convert watchfiles Change type to ChangeType. + + Args: + change_type: The watchfiles Change enum value + + Returns: + Corresponding ChangeType enum value + """ + if change_type is Change.added: + return ChangeType.Added + elif change_type is Change.deleted: + return ChangeType.Removed + elif change_type is Change.modified: + return ChangeType.Modified + + msg = f"Unsupported watch change type: {change_type}" + raise ValueError(msg) + + @staticmethod + def from_git_change_type(change_type: str | None) -> "ChangeType": + """ + Convert git change type string to ChangeType. 
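
A short, hedged sketch of how these converters are meant to be called (the import path is taken from this diff; only the letter codes and watchfiles events handled by these methods are shown):

```python
# Illustrative only: mapping git status letters and watchfiles events to ChangeType.
from watchfiles import Change

from codegen_on_oss.analyzers.diff_lite import ChangeType

assert ChangeType.from_git_change_type("M") is ChangeType.Modified
assert ChangeType.from_git_change_type("A") is ChangeType.Added
assert ChangeType.from_watch_change_type(Change.deleted) is ChangeType.Removed
```
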
+ + Args: + change_type: Git change type string ('M', 'D', 'R', 'A') + + Returns: + Corresponding ChangeType enum value + + Raises: + ValueError: If the change type is not supported + """ + if change_type == "M": + return ChangeType.Modified + if change_type == "D": + return ChangeType.Removed + if change_type == "R": + return ChangeType.Renamed + if change_type == "A": + return ChangeType.Added + + msg = f"Invalid git change type: {change_type}" + raise ValueError(msg) + + +class DiffLite(NamedTuple): + """ + Simple diff implementation for tracking file changes during code analysis. + + This lightweight diff implementation provides support for tracking file changes, + including modifications, removals, renames, and additions. + + Attributes: + change_type: Type of change (Modified, Removed, Renamed, Added) + path: Path to the file + rename_from: Original path for renamed files (None for non-renamed files) + rename_to: New path for renamed files (None for non-renamed files) + old_content: Previous content of the file (None if not available) + """ + + change_type: ChangeType + path: Path + rename_from: Path | None = None + rename_to: Path | None = None + old_content: bytes | None = None + + @classmethod + def from_watch_change(cls, change: Change, path: PathLike) -> Self: + """ + Create a DiffLite instance from a watchfiles Change. + + Args: + change: The watchfiles Change enum value + path: Path to the file + + Returns: + DiffLite instance representing the change + """ + return cls( + change_type=ChangeType.from_watch_change_type(change), + path=Path(path), + ) + + @classmethod + def from_git_diff(cls, git_diff: Diff) -> Self: + """ + Create a DiffLite instance from a git Diff object. + + Args: + git_diff: Git Diff object + + Returns: + DiffLite instance representing the git diff + """ + old = None + if git_diff.a_blob: + old = git_diff.a_blob.data_stream.read() + + return cls( + change_type=ChangeType.from_git_change_type(git_diff.change_type), + path=Path(git_diff.a_path) if git_diff.a_path else None, + rename_from=Path(git_diff.rename_from) if git_diff.rename_from else None, + rename_to=Path(git_diff.rename_to) if git_diff.rename_to else None, + old_content=old, + ) + + @classmethod + def from_reverse_diff(cls, diff_lite: "DiffLite") -> Self: + """ + Create a DiffLite instance that represents the reverse of another DiffLite. + + This is useful for undoing changes or representing the opposite operation. 
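
A minimal sketch of the reversal behaviour (paths are hypothetical): reversing an `Added` change yields a `Removed` one, and reversing a rename swaps `rename_from` and `rename_to`.

```python
from pathlib import Path

from codegen_on_oss.analyzers.diff_lite import ChangeType, DiffLite

added = DiffLite(change_type=ChangeType.Added, path=Path("new_module.py"))
assert DiffLite.from_reverse_diff(added).change_type is ChangeType.Removed

rename = DiffLite(
    change_type=ChangeType.Renamed,
    path=Path("old_name.py"),
    rename_from=Path("old_name.py"),
    rename_to=Path("new_name.py"),
)
reverse = DiffLite.from_reverse_diff(rename)
assert reverse.rename_from == Path("new_name.py")
assert reverse.rename_to == Path("old_name.py")
```
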
+ + Args: + diff_lite: Original DiffLite instance + + Returns: + DiffLite instance representing the reverse change + """ + if diff_lite.change_type == ChangeType.Added: + change_type = ChangeType.Removed + elif diff_lite.change_type == ChangeType.Removed: + change_type = ChangeType.Added + else: + change_type = diff_lite.change_type + + if diff_lite.change_type == ChangeType.Renamed: + return cls( + change_type=change_type, + path=diff_lite.path, + rename_from=diff_lite.rename_to, + rename_to=diff_lite.rename_from, + ) + + return cls(change_type=change_type, path=diff_lite.path) diff --git a/codegen-on-oss/codegen_on_oss/analyzers/error_analyzer.py b/codegen-on-oss/codegen_on_oss/analyzers/error_analyzer.py index 104b72633..051ce2e15 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/error_analyzer.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/error_analyzer.py @@ -11,34 +11,41 @@ - codegen_on_oss.analyzers.dependency_analyzer.DependencyAnalyzer """ -import os -import sys import json import logging +import sys import warnings -from typing import Dict, List, Set, Tuple, Any, Optional, Union # Import from our new analyzers try: from codegen_on_oss.analyzers.base_analyzer import BaseCodeAnalyzer from codegen_on_oss.analyzers.code_quality_analyzer import CodeQualityAnalyzer from codegen_on_oss.analyzers.dependency_analyzer import DependencyAnalyzer - from codegen_on_oss.analyzers.issue_types import Issue, IssueSeverity, AnalysisType, IssueCategory - from codegen_on_oss.codebase_visualizer import CodebaseVisualizer, VisualizationType, OutputFormat + from codegen_on_oss.analyzers.issue_types import ( + AnalysisType, + Issue, + IssueCategory, + IssueSeverity, + ) + from codegen_on_oss.codebase_visualizer import ( + CodebaseVisualizer, + OutputFormat, + VisualizationType, + ) except ImportError: print("Error loading analyzer modules. Please make sure they are installed.") sys.exit(1) # Import codegen SDK try: - from codegen.sdk.core.codebase import Codebase from codegen.configs.models.codebase import CodebaseConfig from codegen.configs.models.secrets import SecretsConfig - from codegen.sdk.codebase.config import ProjectConfig - from codegen.git.schemas.repo_config import RepoConfig from codegen.git.repo_operator.repo_operator import RepoOperator - from codegen.shared.enums.programming_language import ProgrammingLanguage + from codegen.git.schemas.repo_config import RepoConfig from codegen.sdk.codebase.codebase_analysis import get_codebase_summary + from codegen.sdk.codebase.config import ProjectConfig + from codegen.sdk.core.codebase import Codebase + from codegen.shared.enums.programming_language import ProgrammingLanguage except ImportError: print("Codegen SDK not found. Please install it first.") sys.exit(1) @@ -46,8 +53,8 @@ # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) @@ -55,26 +62,27 @@ warnings.warn( "error_analyzer.py is deprecated. Please use analyzers directly from codegen_on_oss.analyzers package.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) + class CodebaseAnalyzer: """ Legacy interface to the new analyzer modules. - + This class provides backwards compatibility with code that used the old CodebaseAnalyzer class from error_analyzer.py. 
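
A hedged usage sketch of this legacy wrapper (the repository path is hypothetical; the constructor and `analyze` signatures are the ones defined below):

```python
# Minimal sketch only; the wrapper forwards to CodeQualityAnalyzer and
# DependencyAnalyzer internally, so results mirror their output.
from codegen_on_oss.analyzers.error_analyzer import CodebaseAnalyzer

analyzer = CodebaseAnalyzer(repo_path="/tmp/example-repo", language="python")
results = analyzer.analyze(categories=["code_quality"], output_format="json")
print(results["metadata"]["repo_name"])
```
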
""" - + def __init__( self, - repo_url: Optional[str] = None, - repo_path: Optional[str] = None, - language: Optional[str] = None + repo_url: str | None = None, + repo_path: str | None = None, + language: str | None = None, ): """ Initialize the CodebaseAnalyzer. - + Args: repo_url: URL of the repository to analyze repo_path: Local path to the repository to analyze @@ -82,44 +90,45 @@ def __init__( """ # Create instances of the new analyzers self.quality_analyzer = CodeQualityAnalyzer( - repo_url=repo_url, - repo_path=repo_path, - language=language + repo_url=repo_url, repo_path=repo_path, language=language ) - + self.dependency_analyzer = DependencyAnalyzer( - repo_url=repo_url, - repo_path=repo_path, - language=language + repo_url=repo_url, repo_path=repo_path, language=language ) - + # Set up legacy attributes self.repo_url = repo_url self.repo_path = repo_path self.language = language self.codebase = self.quality_analyzer.base_codebase self.results = {} - + # Initialize visualizer - self.visualizer = CodebaseVisualizer( - codebase=self.codebase - ) - - def analyze(self, categories: List[str] = None, output_format: str = "json", output_file: Optional[str] = None): + self.visualizer = CodebaseVisualizer(codebase=self.codebase) + + def analyze( + self, + categories: list[str] | None = None, + output_format: str = "json", + output_file: str | None = None, + ): """ Perform a comprehensive analysis of the codebase. - + Args: categories: List of categories to analyze. If None, all categories are analyzed. output_format: Format of the output (json, html, console) output_file: Path to the output file - + Returns: Dict containing the analysis results """ if not self.codebase: - raise ValueError("Codebase not initialized. Please initialize the codebase first.") - + raise ValueError( + "Codebase not initialized. Please initialize the codebase first." 
+ ) + # Map old category names to new analyzers category_map = { "codebase_structure": "dependency", @@ -128,20 +137,22 @@ def analyze(self, categories: List[str] = None, output_format: str = "json", out "code_quality": "code_quality", "visualization": "visualization", "language_specific": "code_quality", - "code_metrics": "code_quality" + "code_metrics": "code_quality", } - + # Initialize results with metadata self.results = { "metadata": { - "repo_name": getattr(self.codebase.ctx, 'repo_name', None), + "repo_name": getattr(self.codebase.ctx, "repo_name", None), "analysis_time": str(datetime.now()), - "language": str(getattr(self.codebase.ctx, 'programming_language', None)), - "codebase_summary": get_codebase_summary(self.codebase) + "language": str( + getattr(self.codebase.ctx, "programming_language", None) + ), + "codebase_summary": get_codebase_summary(self.codebase), }, - "categories": {} + "categories": {}, } - + # Determine categories to analyze if not categories: # If no categories are specified, run all analysis types @@ -152,65 +163,102 @@ def analyze(self, categories: List[str] = None, output_format: str = "json", out for category in categories: if category in category_map: analysis_types.add(category_map[category]) - + # Run each analysis type if "code_quality" in analysis_types: quality_results = self.quality_analyzer.analyze(AnalysisType.CODE_QUALITY) - + # Add results to the legacy format - for category in ["code_quality", "symbol_level", "language_specific", "code_metrics"]: + for category in [ + "code_quality", + "symbol_level", + "language_specific", + "code_metrics", + ]: if category in categories or not categories: self.results["categories"][category] = {} - + # Map new results to old category structure if category == "code_quality": self.results["categories"][category].update({ - "unused_functions": quality_results.get("dead_code", {}).get("unused_functions", []), - "unused_classes": quality_results.get("dead_code", {}).get("unused_classes", []), - "unused_variables": quality_results.get("dead_code", {}).get("unused_variables", []), - "unused_imports": quality_results.get("dead_code", {}).get("unused_imports", []), - "cyclomatic_complexity": quality_results.get("complexity", {}), - "cognitive_complexity": quality_results.get("complexity", {}), - "function_size_metrics": quality_results.get("style_issues", {}).get("long_functions", []) + "unused_functions": quality_results.get( + "dead_code", {} + ).get("unused_functions", []), + "unused_classes": quality_results.get("dead_code", {}).get( + "unused_classes", [] + ), + "unused_variables": quality_results.get( + "dead_code", {} + ).get("unused_variables", []), + "unused_imports": quality_results.get("dead_code", {}).get( + "unused_imports", [] + ), + "cyclomatic_complexity": quality_results.get( + "complexity", {} + ), + "cognitive_complexity": quality_results.get( + "complexity", {} + ), + "function_size_metrics": quality_results.get( + "style_issues", {} + ).get("long_functions", []), }) elif category == "symbol_level": self.results["categories"][category].update({ "function_parameter_analysis": [], - "function_complexity_metrics": quality_results.get("complexity", {}).get("function_complexity", []) + "function_complexity_metrics": quality_results.get( + "complexity", {} + ).get("function_complexity", []), }) elif category == "code_metrics": self.results["categories"][category].update({ - "calculate_cyclomatic_complexity": quality_results.get("complexity", {}), - "calculate_maintainability_index": 
quality_results.get("maintainability", {}) + "calculate_cyclomatic_complexity": quality_results.get( + "complexity", {} + ), + "calculate_maintainability_index": quality_results.get( + "maintainability", {} + ), }) - + if "dependency" in analysis_types: - dependency_results = self.dependency_analyzer.analyze(AnalysisType.DEPENDENCY) - + dependency_results = self.dependency_analyzer.analyze( + AnalysisType.DEPENDENCY + ) + # Add results to the legacy format for category in ["codebase_structure", "dependency_flow"]: if category in categories or not categories: self.results["categories"][category] = {} - + # Map new results to old category structure if category == "codebase_structure": self.results["categories"][category].update({ - "import_dependency_map": dependency_results.get("import_dependencies", {}).get("module_dependencies", []), - "circular_imports": dependency_results.get("circular_dependencies", {}).get("circular_imports", []), - "module_coupling_metrics": dependency_results.get("module_coupling", {}), - "module_dependency_graph": dependency_results.get("import_dependencies", {}).get("module_dependencies", []) + "import_dependency_map": dependency_results.get( + "import_dependencies", {} + ).get("module_dependencies", []), + "circular_imports": dependency_results.get( + "circular_dependencies", {} + ).get("circular_imports", []), + "module_coupling_metrics": dependency_results.get( + "module_coupling", {} + ), + "module_dependency_graph": dependency_results.get( + "import_dependencies", {} + ).get("module_dependencies", []), }) elif category == "dependency_flow": self.results["categories"][category].update({ "function_call_relationships": [], "entry_point_analysis": [], - "dead_code_detection": quality_results.get("dead_code", {}) if "code_quality" in analysis_types else {} + "dead_code_detection": quality_results.get("dead_code", {}) + if "code_quality" in analysis_types + else {}, }) - + # Output the results if output_format == "json": if output_file: - with open(output_file, 'w') as f: + with open(output_file, "w") as f: json.dump(self.results, f, indent=2) logger.info(f"Results saved to {output_file}") else: @@ -219,13 +267,13 @@ def analyze(self, categories: List[str] = None, output_format: str = "json", out self._generate_html_report(output_file) elif output_format == "console": self._print_console_report() - + return self.results - - def _generate_html_report(self, output_file: Optional[str] = None): + + def _generate_html_report(self, output_file: str | None = None): """ Generate an HTML report of the analysis results. - + Args: output_file: Path to the output file """ @@ -250,34 +298,36 @@ def _generate_html_report(self, output_file: Optional[str] = None):

    Codebase Analysis Report

    Metadata

-    Repository: {self.results['metadata'].get('repo_name', 'Unknown')}
-    Analysis Time: {self.results['metadata'].get('analysis_time', 'Unknown')}
-    Language: {self.results['metadata'].get('language', 'Unknown')}
+    Repository: {self.results["metadata"].get("repo_name", "Unknown")}
+    Analysis Time: {self.results["metadata"].get("analysis_time", "Unknown")}
+    Language: {self.results["metadata"].get("language", "Unknown")}
    """
-
+
        # Add issues section
        html_content += """
    Issues

    """ - + # Collect all issues all_issues = [] - if hasattr(self.quality_analyzer, 'issues'): + if hasattr(self.quality_analyzer, "issues"): all_issues.extend(self.quality_analyzer.issues) - if hasattr(self.dependency_analyzer, 'issues'): + if hasattr(self.dependency_analyzer, "issues"): all_issues.extend(self.dependency_analyzer.issues) - + # Sort issues by severity - all_issues.sort(key=lambda x: { - IssueSeverity.CRITICAL: 0, - IssueSeverity.ERROR: 1, - IssueSeverity.WARNING: 2, - IssueSeverity.INFO: 3 - }.get(x.severity, 4)) - + all_issues.sort( + key=lambda x: { + IssueSeverity.CRITICAL: 0, + IssueSeverity.ERROR: 1, + IssueSeverity.WARNING: 2, + IssueSeverity.INFO: 3, + }.get(x.severity, 4) + ) + # Add issues to HTML for issue in all_issues: severity_class = issue.severity.value @@ -285,134 +335,147 @@ def _generate_html_report(self, output_file: Optional[str] = None):

    {issue.severity.value.upper()}: {issue.message}
    File: {issue.file} {f"(Line {issue.line})" if issue.line else ""}
-    Symbol: {issue.symbol or 'N/A'}
-    Suggestion: {issue.suggestion or 'N/A'}
+    Symbol: {issue.symbol or "N/A"}
+    Suggestion: {issue.suggestion or "N/A"}

    """ - + html_content += """
    """ - + # Add summary of results html_content += """

    Analysis Results

    """ - - for category, results in self.results.get('categories', {}).items(): + + for category, results in self.results.get("categories", {}).items(): html_content += f"""

    {category}

    {json.dumps(results, indent=2)}
    """ - + html_content += """
    """ - + # Save HTML to file or print to console if output_file: - with open(output_file, 'w') as f: + with open(output_file, "w") as f: f.write(html_content) logger.info(f"HTML report saved to {output_file}") else: print(html_content) - + def _print_console_report(self): """Print a summary of the analysis results to the console.""" print("\nšŸ“Š Codebase Analysis Report šŸ“Š") print("=" * 50) - + # Print metadata - print(f"\nšŸ“Œ Repository: {self.results['metadata'].get('repo_name', 'Unknown')}") - print(f"šŸ“† Analysis Time: {self.results['metadata'].get('analysis_time', 'Unknown')}") + print( + f"\nšŸ“Œ Repository: {self.results['metadata'].get('repo_name', 'Unknown')}" + ) + print( + f"šŸ“† Analysis Time: {self.results['metadata'].get('analysis_time', 'Unknown')}" + ) print(f"šŸ”¤ Language: {self.results['metadata'].get('language', 'Unknown')}") - + # Print summary of issues print("\n🚨 Issues Summary") print("-" * 50) - + # Collect all issues all_issues = [] - if hasattr(self.quality_analyzer, 'issues'): + if hasattr(self.quality_analyzer, "issues"): all_issues.extend(self.quality_analyzer.issues) - if hasattr(self.dependency_analyzer, 'issues'): + if hasattr(self.dependency_analyzer, "issues"): all_issues.extend(self.dependency_analyzer.issues) - + # Print issue counts by severity severity_counts = { IssueSeverity.CRITICAL: 0, IssueSeverity.ERROR: 0, IssueSeverity.WARNING: 0, - IssueSeverity.INFO: 0 + IssueSeverity.INFO: 0, } - + for issue in all_issues: severity_counts[issue.severity] += 1 - + print(f"Critical: {severity_counts[IssueSeverity.CRITICAL]}") print(f"Errors: {severity_counts[IssueSeverity.ERROR]}") print(f"Warnings: {severity_counts[IssueSeverity.WARNING]}") print(f"Info: {severity_counts[IssueSeverity.INFO]}") print(f"Total: {len(all_issues)}") - + # Print top issues by severity if all_issues: print("\nšŸ” Top Issues") print("-" * 50) - + # Sort issues by severity - all_issues.sort(key=lambda x: { - IssueSeverity.CRITICAL: 0, - IssueSeverity.ERROR: 1, - IssueSeverity.WARNING: 2, - IssueSeverity.INFO: 3 - }.get(x.severity, 4)) - + all_issues.sort( + key=lambda x: { + IssueSeverity.CRITICAL: 0, + IssueSeverity.ERROR: 1, + IssueSeverity.WARNING: 2, + IssueSeverity.INFO: 3, + }.get(x.severity, 4) + ) + # Print top 10 issues for i, issue in enumerate(all_issues[:10]): - print(f"{i+1}. [{issue.severity.value.upper()}] {issue.message}") - print(f" File: {issue.file} {f'(Line {issue.line})' if issue.line else ''}") + print(f"{i + 1}. 
[{issue.severity.value.upper()}] {issue.message}") + print( + f" File: {issue.file} {f'(Line {issue.line})' if issue.line else ''}" + ) print(f" Symbol: {issue.symbol or 'N/A'}") print(f" Suggestion: {issue.suggestion or 'N/A'}") print() - + # Print summary of results by category - for category, results in self.results.get('categories', {}).items(): + for category, results in self.results.get("categories", {}).items(): print(f"\nšŸ“‹ {category.replace('_', ' ').title()}") print("-" * 50) - + # Print key statistics for each category if category == "code_quality": unused_funcs = len(results.get("unused_functions", [])) unused_vars = len(results.get("unused_variables", [])) print(f"Unused Functions: {unused_funcs}") print(f"Unused Variables: {unused_vars}") - + # Print complexity stats if available complexity = results.get("cyclomatic_complexity", {}) if "function_complexity" in complexity: - high_complexity = [f for f in complexity["function_complexity"] if f.get("complexity", 0) > 10] + high_complexity = [ + f + for f in complexity["function_complexity"] + if f.get("complexity", 0) > 10 + ] print(f"High Complexity Functions: {len(high_complexity)}") - + elif category == "codebase_structure": circular_imports = len(results.get("circular_imports", [])) print(f"Circular Imports: {circular_imports}") - + module_deps = results.get("module_dependency_graph", []) print(f"Module Dependencies: {len(module_deps)}") - + elif category == "dependency_flow": dead_code = results.get("dead_code_detection", {}) total_dead = ( - len(dead_code.get("unused_functions", [])) + - len(dead_code.get("unused_classes", [])) + - len(dead_code.get("unused_variables", [])) + len(dead_code.get("unused_functions", [])) + + len(dead_code.get("unused_classes", [])) + + len(dead_code.get("unused_variables", [])) ) print(f"Dead Code Items: {total_dead}") + # For backwards compatibility, expose the CodebaseAnalyzer class as the main interface -__all__ = ['CodebaseAnalyzer'] \ No newline at end of file +__all__ = ["CodebaseAnalyzer"] diff --git a/codegen-on-oss/codegen_on_oss/analyzers/issue_analyzer.py b/codegen-on-oss/codegen_on_oss/analyzers/issue_analyzer.py index 213db9bb0..d1eb8a6c7 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/issue_analyzer.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/issue_analyzer.py @@ -2,33 +2,38 @@ """ Issue Analyzer Module -This module provides common functionality for detecting and tracking issues +This module provides common functionality for detecting and tracking issues across different types of code analyzers. It provides standardized issue handling and categorization to ensure consistent issue reporting. """ -import os import logging -from typing import Dict, List, Set, Any, Optional, Union, Callable +from collections.abc import Callable +from typing import Any from codegen_on_oss.analyzers.base_analyzer import BaseCodeAnalyzer -from codegen_on_oss.analyzers.issue_types import Issue, IssueSeverity, AnalysisType, IssueCategory +from codegen_on_oss.analyzers.issue_types import ( + Issue, + IssueCategory, + IssueSeverity, +) # Configure logging logger = logging.getLogger(__name__) + class IssueAnalyzer(BaseCodeAnalyzer): """ Base class for analyzers that detect and report issues. - + This class builds on the BaseCodeAnalyzer to add standardized issue tracking, categorization, and reporting capabilities. """ - + def __init__(self, **kwargs): """ Initialize the issue analyzer. 
- + Args: **kwargs: Arguments to pass to the BaseCodeAnalyzer """ @@ -37,45 +42,48 @@ def __init__(self, **kwargs): self.issue_handlers = {} self.issue_categories = set() self.register_default_filters() - + def register_default_filters(self): """Register default issue filters.""" # Filter out issues in test files by default - self.add_issue_filter(lambda issue: "test" in issue.file.lower(), - "Skip issues in test files") - + self.add_issue_filter( + lambda issue: "test" in issue.file.lower(), "Skip issues in test files" + ) + # Filter out issues in generated files by default - self.add_issue_filter(lambda issue: "generated" in issue.file.lower(), - "Skip issues in generated files") - + self.add_issue_filter( + lambda issue: "generated" in issue.file.lower(), + "Skip issues in generated files", + ) + def add_issue_filter(self, filter_func: Callable[[Issue], bool], description: str): """ Add a filter function that determines if an issue should be skipped. - + Args: filter_func: Function that returns True if issue should be skipped description: Description of the filter """ self.issue_filters.append((filter_func, description)) - + def register_issue_handler(self, category: IssueCategory, handler: Callable): """ Register a handler function for a specific issue category. - + Args: category: Issue category to handle handler: Function that will detect issues of this category """ self.issue_handlers[category] = handler self.issue_categories.add(category) - + def should_skip_issue(self, issue: Issue) -> bool: """ Check if an issue should be skipped based on registered filters. - + Args: issue: Issue to check - + Returns: True if the issue should be skipped, False otherwise """ @@ -85,129 +93,154 @@ def should_skip_issue(self, issue: Issue) -> bool: return True except Exception as e: logger.debug(f"Error applying issue filter: {e}") - + return False - + def add_issue(self, issue: Issue): """ Add an issue to the list if it passes all filters. - + Args: issue: Issue to add """ if self.should_skip_issue(issue): return - + super().add_issue(issue) - - def detect_issues(self, categories: Optional[List[IssueCategory]] = None) -> Dict[IssueCategory, List[Issue]]: + + def detect_issues( + self, categories: list[IssueCategory] | None = None + ) -> dict[IssueCategory, list[Issue]]: """ Detect issues across specified categories. - + Args: categories: Categories of issues to detect (defaults to all registered categories) - + Returns: Dictionary mapping categories to lists of issues """ result = {} - + # Use all registered categories if none specified if not categories: categories = list(self.issue_categories) - + # Process each requested category for category in categories: if category in self.issue_handlers: # Clear existing issues of this category self.issues = [i for i in self.issues if i.category != category] - + # Run the handler to detect issues try: handler = self.issue_handlers[category] handler_result = handler() result[category] = handler_result except Exception as e: - logger.error(f"Error detecting issues for category {category}: {e}") + logger.exception( + f"Error detecting issues for category {category}: {e}" + ) result[category] = [] else: logger.warning(f"No handler registered for issue category: {category}") result[category] = [] - + return result - - def get_issues_by_category(self) -> Dict[IssueCategory, List[Issue]]: + + def get_issues_by_category(self) -> dict[IssueCategory, list[Issue]]: """ Group issues by category. 
- + Returns: Dictionary mapping categories to lists of issues """ result = {} - + for issue in self.issues: if issue.category: if issue.category not in result: result[issue.category] = [] result[issue.category].append(issue) - + return result - - def get_issue_statistics(self) -> Dict[str, Any]: + + def get_issue_statistics(self) -> dict[str, Any]: """ Get statistics about detected issues. - + Returns: Dictionary with issue statistics """ issues_by_category = self.get_issues_by_category() - + return { "total": len(self.issues), "by_severity": { - "critical": sum(1 for issue in self.issues if issue.severity == IssueSeverity.CRITICAL), - "error": sum(1 for issue in self.issues if issue.severity == IssueSeverity.ERROR), - "warning": sum(1 for issue in self.issues if issue.severity == IssueSeverity.WARNING), - "info": sum(1 for issue in self.issues if issue.severity == IssueSeverity.INFO), + "critical": sum( + 1 + for issue in self.issues + if issue.severity == IssueSeverity.CRITICAL + ), + "error": sum( + 1 for issue in self.issues if issue.severity == IssueSeverity.ERROR + ), + "warning": sum( + 1 + for issue in self.issues + if issue.severity == IssueSeverity.WARNING + ), + "info": sum( + 1 for issue in self.issues if issue.severity == IssueSeverity.INFO + ), }, "by_category": { category.value: len(issues) for category, issues in issues_by_category.items() - } + }, } - + def format_issues_report(self) -> str: """ Format issues as a readable report. - + Returns: Formatted string with issue report """ report_lines = [ "==== Issues Report ====", f"Total issues: {len(self.issues)}", - "" + "", ] - + # Group by severity issues_by_severity = {} for issue in self.issues: if issue.severity not in issues_by_severity: issues_by_severity[issue.severity] = [] issues_by_severity[issue.severity].append(issue) - + # Add severity sections - for severity in [IssueSeverity.CRITICAL, IssueSeverity.ERROR, IssueSeverity.WARNING, IssueSeverity.INFO]: + for severity in [ + IssueSeverity.CRITICAL, + IssueSeverity.ERROR, + IssueSeverity.WARNING, + IssueSeverity.INFO, + ]: if severity in issues_by_severity: - report_lines.append(f"==== {severity.value.upper()} ({len(issues_by_severity[severity])}) ====") - + report_lines.append( + f"==== {severity.value.upper()} ({len(issues_by_severity[severity])}) ====" + ) + for issue in issues_by_severity[severity]: - location = f"{issue.file}:{issue.line}" if issue.line else issue.file + location = ( + f"{issue.file}:{issue.line}" if issue.line else issue.file + ) category = f"[{issue.category.value}]" if issue.category else "" report_lines.append(f"{location} {category} {issue.message}") if issue.suggestion: report_lines.append(f" Suggestion: {issue.suggestion}") - + report_lines.append("") - - return "\n".join(report_lines) \ No newline at end of file + + return "\n".join(report_lines) diff --git a/codegen-on-oss/codegen_on_oss/analyzers/issue_types.py b/codegen-on-oss/codegen_on_oss/analyzers/issue_types.py index a474d5f74..07b8877d8 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/issue_types.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/issue_types.py @@ -8,10 +8,12 @@ from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Set, Tuple, Any, Optional, Union +from typing import Any + class AnalysisType(str, Enum): """Types of analysis that can be performed.""" + CODEBASE = "codebase" PR = "pr" COMPARISON = "comparison" @@ -21,15 +23,19 @@ class AnalysisType(str, Enum): DEPENDENCY = "dependency" TYPE_CHECKING = "type_checking" + class 
IssueSeverity(str, Enum): """Severity levels for issues.""" + CRITICAL = "critical" ERROR = "error" WARNING = "warning" INFO = "info" + class IssueCategory(str, Enum): """Categories of issues that can be detected.""" + DEAD_CODE = "dead_code" COMPLEXITY = "complexity" TYPE_ERROR = "type_error" @@ -42,19 +48,21 @@ class IssueCategory(str, Enum): STYLE_ISSUE = "style_issue" DOCUMENTATION = "documentation" + @dataclass class Issue: """Represents an issue found during analysis.""" + file: str - line: Optional[int] + line: int | None message: str severity: IssueSeverity - category: Optional[IssueCategory] = None - symbol: Optional[str] = None - code: Optional[str] = None - suggestion: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: + category: IssueCategory | None = None + symbol: str | None = None + code: str | None = None + suggestion: str | None = None + + def to_dict(self) -> dict[str, Any]: """Convert issue to dictionary representation.""" return { "file": self.file, @@ -64,11 +72,11 @@ def to_dict(self) -> Dict[str, Any]: "category": self.category, "symbol": self.symbol, "code": self.code, - "suggestion": self.suggestion + "suggestion": self.suggestion, } - + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'Issue': + def from_dict(cls, data: dict[str, Any]) -> "Issue": """Create an issue from a dictionary representation.""" return cls( file=data["file"], @@ -78,5 +86,5 @@ def from_dict(cls, data: Dict[str, Any]) -> 'Issue': category=IssueCategory(data["category"]) if "category" in data else None, symbol=data.get("symbol"), code=data.get("code"), - suggestion=data.get("suggestion") - ) \ No newline at end of file + suggestion=data.get("suggestion"), + ) diff --git a/codegen-on-oss/codegen_on_oss/analyzers/issues.py b/codegen-on-oss/codegen_on_oss/analyzers/issues.py index f7880126c..c20ddc3ea 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/issues.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/issues.py @@ -6,23 +6,25 @@ It provides a standardized way to represent and manage issues across different analyzers. 
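
As a quick, hedged illustration of the model defined in this module (the file path, line number and message are made up), an issue can be built with the `create_issue` helper and aggregated in an `IssueCollection`:

```python
from codegen_on_oss.analyzers.issues import (
    IssueCollection,
    IssueSeverity,
    create_issue,
)

issue = create_issue(
    message="Function `unused_helper` is never called",
    severity=IssueSeverity.WARNING,
    file="src/example.py",
    line=42,
    category="dead_code",
    suggestion="Remove the function or add a caller",
)

collection = IssueCollection([issue])
stats = collection.statistics()
assert stats["total"] == 1
assert stats["by_severity"]["warning"] == 1
```
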
""" -import os import json import logging -from dataclasses import dataclass, field, asdict +from collections.abc import Callable +from dataclasses import asdict, dataclass, field from enum import Enum -from typing import Dict, List, Set, Tuple, Any, Optional, Union, Callable +from typing import Any # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) + class AnalysisType(str, Enum): """Types of analysis that can be performed.""" + CODEBASE = "codebase" PR = "pr" COMPARISON = "comparison" @@ -32,71 +34,79 @@ class AnalysisType(str, Enum): PERFORMANCE = "performance" TYPE_CHECKING = "type_checking" + class IssueSeverity(str, Enum): """Severity levels for issues.""" + CRITICAL = "critical" # Must be fixed immediately, blocks functionality - ERROR = "error" # Must be fixed, causes errors or undefined behavior - WARNING = "warning" # Should be fixed, may cause problems in future - INFO = "info" # Informational, could be improved but not critical + ERROR = "error" # Must be fixed, causes errors or undefined behavior + WARNING = "warning" # Should be fixed, may cause problems in future + INFO = "info" # Informational, could be improved but not critical + class IssueCategory(str, Enum): """Categories of issues that can be detected.""" + # Code Quality Issues - DEAD_CODE = "dead_code" # Unused variables, functions, etc. - COMPLEXITY = "complexity" # Code too complex, needs refactoring - STYLE_ISSUE = "style_issue" # Code style issues (line length, etc.) - DOCUMENTATION = "documentation" # Missing or incomplete documentation - + DEAD_CODE = "dead_code" # Unused variables, functions, etc. + COMPLEXITY = "complexity" # Code too complex, needs refactoring + STYLE_ISSUE = "style_issue" # Code style issues (line length, etc.) 
+ DOCUMENTATION = "documentation" # Missing or incomplete documentation + # Type and Parameter Issues - TYPE_ERROR = "type_error" # Type errors or inconsistencies + TYPE_ERROR = "type_error" # Type errors or inconsistencies PARAMETER_MISMATCH = "parameter_mismatch" # Parameter type or count mismatch - RETURN_TYPE_ERROR = "return_type_error" # Return type error or mismatch - + RETURN_TYPE_ERROR = "return_type_error" # Return type error or mismatch + # Implementation Issues IMPLEMENTATION_ERROR = "implementation_error" # Incorrect implementation MISSING_IMPLEMENTATION = "missing_implementation" # Missing implementation - + # Dependency Issues - IMPORT_ERROR = "import_error" # Import errors or issues + IMPORT_ERROR = "import_error" # Import errors or issues DEPENDENCY_CYCLE = "dependency_cycle" # Circular dependency - MODULE_COUPLING = "module_coupling" # High coupling between modules - + MODULE_COUPLING = "module_coupling" # High coupling between modules + # API Issues - API_CHANGE = "api_change" # API has changed in a breaking way - API_USAGE_ERROR = "api_usage_error" # Incorrect API usage - + API_CHANGE = "api_change" # API has changed in a breaking way + API_USAGE_ERROR = "api_usage_error" # Incorrect API usage + # Security Issues SECURITY_VULNERABILITY = "security_vulnerability" # Security vulnerability - + # Performance Issues PERFORMANCE_ISSUE = "performance_issue" # Performance issue + class IssueStatus(str, Enum): """Status of an issue.""" - OPEN = "open" # Issue is open and needs to be fixed - FIXED = "fixed" # Issue has been fixed + + OPEN = "open" # Issue is open and needs to be fixed + FIXED = "fixed" # Issue has been fixed WONTFIX = "wontfix" # Issue will not be fixed INVALID = "invalid" # Issue is invalid or not applicable DUPLICATE = "duplicate" # Issue is a duplicate of another + @dataclass class CodeLocation: """Location of an issue in code.""" + file: str - line: Optional[int] = None - column: Optional[int] = None - end_line: Optional[int] = None - end_column: Optional[int] = None - - def to_dict(self) -> Dict[str, Any]: + line: int | None = None + column: int | None = None + end_line: int | None = None + end_column: int | None = None + + def to_dict(self) -> dict[str, Any]: """Convert to dictionary representation.""" return {k: v for k, v in asdict(self).items() if v is not None} - + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'CodeLocation': + def from_dict(cls, data: dict[str, Any]) -> "CodeLocation": """Create from dictionary representation.""" return cls(**{k: v for k, v in data.items() if k in cls.__annotations__}) - + def __str__(self) -> str: """Convert to string representation.""" if self.line is not None: @@ -105,51 +115,54 @@ def __str__(self) -> str: return f"{self.file}:{self.line}" return self.file + @dataclass class Issue: """Represents an issue found during analysis.""" + # Core fields message: str severity: IssueSeverity location: CodeLocation - + # Classification fields - category: Optional[IssueCategory] = None - analysis_type: Optional[AnalysisType] = None + category: IssueCategory | None = None + analysis_type: AnalysisType | None = None status: IssueStatus = IssueStatus.OPEN - + # Context fields - symbol: Optional[str] = None - code: Optional[str] = None - suggestion: Optional[str] = None - related_symbols: List[str] = field(default_factory=list) - related_locations: List[CodeLocation] = field(default_factory=list) - + symbol: str | None = None + code: str | None = None + suggestion: str | None = None + related_symbols: list[str] = 
field(default_factory=list) + related_locations: list[CodeLocation] = field(default_factory=list) + # Metadata fields - id: Optional[str] = None - hash: Optional[str] = None - metadata: Dict[str, Any] = field(default_factory=dict) - + id: str | None = None + hash: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + def __post_init__(self): """Initialize derived fields.""" # Generate an ID if not provided if self.id is None: import hashlib + # Create a hash based on location and message hash_input = f"{self.location.file}:{self.location.line}:{self.message}" self.id = hashlib.md5(hash_input.encode()).hexdigest()[:12] - + @property def file(self) -> str: """Get the file path.""" return self.location.file - + @property - def line(self) -> Optional[int]: + def line(self) -> int | None: """Get the line number.""" return self.location.line - - def to_dict(self) -> Dict[str, Any]: + + def to_dict(self) -> dict[str, Any]: """Convert to dictionary representation.""" result = { "id": self.id, @@ -158,203 +171,211 @@ def to_dict(self) -> Dict[str, Any]: "location": self.location.to_dict(), "status": self.status.value, } - + # Add optional fields if present if self.category: result["category"] = self.category.value - + if self.analysis_type: result["analysis_type"] = self.analysis_type.value - + if self.symbol: result["symbol"] = self.symbol - + if self.code: result["code"] = self.code - + if self.suggestion: result["suggestion"] = self.suggestion - + if self.related_symbols: result["related_symbols"] = self.related_symbols - + if self.related_locations: - result["related_locations"] = [loc.to_dict() for loc in self.related_locations] - + result["related_locations"] = [ + loc.to_dict() for loc in self.related_locations + ] + if self.metadata: result["metadata"] = self.metadata - + return result - + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'Issue': + def from_dict(cls, data: dict[str, Any]) -> "Issue": """Create from dictionary representation.""" # Convert string enums to actual enum values if "severity" in data and isinstance(data["severity"], str): data["severity"] = IssueSeverity(data["severity"]) - + if "category" in data and isinstance(data["category"], str): data["category"] = IssueCategory(data["category"]) - + if "analysis_type" in data and isinstance(data["analysis_type"], str): data["analysis_type"] = AnalysisType(data["analysis_type"]) - + if "status" in data and isinstance(data["status"], str): data["status"] = IssueStatus(data["status"]) - + # Convert location dict to CodeLocation if "location" in data and isinstance(data["location"], dict): data["location"] = CodeLocation.from_dict(data["location"]) - + # Convert related_locations dicts to CodeLocation objects if "related_locations" in data and isinstance(data["related_locations"], list): data["related_locations"] = [ CodeLocation.from_dict(loc) if isinstance(loc, dict) else loc for loc in data["related_locations"] ] - + return cls(**{k: v for k, v in data.items() if k in cls.__annotations__}) + class IssueCollection: """Collection of issues with filtering and grouping capabilities.""" - - def __init__(self, issues: Optional[List[Issue]] = None): + + def __init__(self, issues: list[Issue] | None = None): """ Initialize the issue collection. - + Args: issues: Initial list of issues """ self.issues = issues or [] self._filters = [] - + def add_issue(self, issue: Issue): """ Add an issue to the collection. 
- + Args: issue: Issue to add """ self.issues.append(issue) - - def add_issues(self, issues: List[Issue]): + + def add_issues(self, issues: list[Issue]): """ Add multiple issues to the collection. - + Args: issues: Issues to add """ self.issues.extend(issues) - + def add_filter(self, filter_func: Callable[[Issue], bool], description: str = ""): """ Add a filter function. - + Args: filter_func: Function that returns True if issue should be included description: Description of the filter """ self._filters.append((filter_func, description)) - + def get_issues( self, - severity: Optional[IssueSeverity] = None, - category: Optional[IssueCategory] = None, - status: Optional[IssueStatus] = None, - file_path: Optional[str] = None, - symbol: Optional[str] = None - ) -> List[Issue]: + severity: IssueSeverity | None = None, + category: IssueCategory | None = None, + status: IssueStatus | None = None, + file_path: str | None = None, + symbol: str | None = None, + ) -> list[Issue]: """ Get issues matching the specified criteria. - + Args: severity: Severity to filter by category: Category to filter by status: Status to filter by file_path: File path to filter by symbol: Symbol name to filter by - + Returns: List of matching issues """ filtered_issues = self.issues - + # Apply custom filters for filter_func, _ in self._filters: filtered_issues = [i for i in filtered_issues if filter_func(i)] - + # Apply standard filters if severity: filtered_issues = [i for i in filtered_issues if i.severity == severity] - + if category: filtered_issues = [i for i in filtered_issues if i.category == category] - + if status: filtered_issues = [i for i in filtered_issues if i.status == status] - + if file_path: - filtered_issues = [i for i in filtered_issues if i.location.file == file_path] - + filtered_issues = [ + i for i in filtered_issues if i.location.file == file_path + ] + if symbol: filtered_issues = [ - i for i in filtered_issues - if (i.symbol == symbol or - (i.related_symbols and symbol in i.related_symbols)) + i + for i in filtered_issues + if ( + i.symbol == symbol + or (i.related_symbols and symbol in i.related_symbols) + ) ] - + return filtered_issues - - def group_by_severity(self) -> Dict[IssueSeverity, List[Issue]]: + + def group_by_severity(self) -> dict[IssueSeverity, list[Issue]]: """ Group issues by severity. - + Returns: Dictionary mapping severities to lists of issues """ result = {severity: [] for severity in IssueSeverity} - + for issue in self.issues: result[issue.severity].append(issue) - + return result - - def group_by_category(self) -> Dict[IssueCategory, List[Issue]]: + + def group_by_category(self) -> dict[IssueCategory, list[Issue]]: """ Group issues by category. - + Returns: Dictionary mapping categories to lists of issues """ result = {category: [] for category in IssueCategory} - + for issue in self.issues: if issue.category: result[issue.category].append(issue) - + return result - - def group_by_file(self) -> Dict[str, List[Issue]]: + + def group_by_file(self) -> dict[str, list[Issue]]: """ Group issues by file. - + Returns: Dictionary mapping file paths to lists of issues """ result = {} - + for issue in self.issues: if issue.location.file not in result: result[issue.location.file] = [] - + result[issue.location.file].append(issue) - + return result - - def statistics(self) -> Dict[str, Any]: + + def statistics(self) -> dict[str, Any]: """ Get statistics about the issues. 
- + Returns: Dictionary with issue statistics """ @@ -363,12 +384,11 @@ def statistics(self) -> Dict[str, Any]: by_status = {status: [] for status in IssueStatus} for issue in self.issues: by_status[issue.status].append(issue) - + return { "total": len(self.issues), "by_severity": { - severity.value: len(issues) - for severity, issues in by_severity.items() + severity.value: len(issues) for severity, issues in by_severity.items() }, "by_category": { category.value: len(issues) @@ -376,50 +396,49 @@ def statistics(self) -> Dict[str, Any]: if len(issues) > 0 # Only include non-empty categories }, "by_status": { - status.value: len(issues) - for status, issues in by_status.items() + status.value: len(issues) for status, issues in by_status.items() }, - "file_count": len(self.group_by_file()) + "file_count": len(self.group_by_file()), } - - def to_dict(self) -> Dict[str, Any]: + + def to_dict(self) -> dict[str, Any]: """ Convert to dictionary representation. - + Returns: Dictionary representation of the issue collection """ return { "issues": [issue.to_dict() for issue in self.issues], "statistics": self.statistics(), - "filters": [desc for _, desc in self._filters if desc] + "filters": [desc for _, desc in self._filters if desc], } - + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'IssueCollection': + def from_dict(cls, data: dict[str, Any]) -> "IssueCollection": """ Create from dictionary representation. - + Args: data: Dictionary representation - + Returns: Issue collection """ collection = cls() - + if "issues" in data and isinstance(data["issues"], list): collection.add_issues([ Issue.from_dict(issue) if isinstance(issue, dict) else issue for issue in data["issues"] ]) - + return collection - + def save_to_file(self, file_path: str, format: str = "json"): """ Save to file. - + Args: file_path: Path to save to format: Format to save in @@ -429,36 +448,36 @@ def save_to_file(self, file_path: str, format: str = "json"): json.dump(self.to_dict(), f, indent=2) else: raise ValueError(f"Unsupported format: {format}") - + @classmethod - def load_from_file(cls, file_path: str) -> 'IssueCollection': + def load_from_file(cls, file_path: str) -> "IssueCollection": """ Load from file. - + Args: file_path: Path to load from - + Returns: Issue collection """ - with open(file_path, "r") as f: + with open(file_path) as f: data = json.load(f) - + return cls.from_dict(data) def create_issue( message: str, - severity: Union[str, IssueSeverity], + severity: str | IssueSeverity, file: str, - line: Optional[int] = None, - category: Optional[Union[str, IssueCategory]] = None, - symbol: Optional[str] = None, - suggestion: Optional[str] = None + line: int | None = None, + category: str | IssueCategory | None = None, + symbol: str | None = None, + suggestion: str | None = None, ) -> Issue: """ Create an issue with simplified parameters. 
- + Args: message: Issue message severity: Issue severity @@ -467,21 +486,21 @@ def create_issue( category: Issue category symbol: Symbol name suggestion: Suggested fix - + Returns: Issue object """ # Convert string severity to enum if isinstance(severity, str): severity = IssueSeverity(severity) - + # Convert string category to enum if isinstance(category, str) and category: category = IssueCategory(category) - + # Create location location = CodeLocation(file=file, line=line) - + # Create issue return Issue( message=message, @@ -489,5 +508,5 @@ def create_issue( location=location, category=category, symbol=symbol, - suggestion=suggestion - ) \ No newline at end of file + suggestion=suggestion, + ) diff --git a/codegen-on-oss/codegen_on_oss/analyzers/models/analysis_result.py b/codegen-on-oss/codegen_on_oss/analyzers/models/analysis_result.py index 0cd012609..140bb73ae 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/models/analysis_result.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/models/analysis_result.py @@ -7,122 +7,137 @@ """ import json -from dataclasses import dataclass, field, asdict -from enum import Enum -from typing import Dict, List, Set, Any, Optional, Union +from dataclasses import asdict, dataclass, field from datetime import datetime +from typing import Any from codegen_on_oss.analyzers.issues import AnalysisType, IssueCollection + @dataclass class AnalysisSummary: """Summary statistics for an analysis.""" + total_files: int = 0 total_functions: int = 0 total_classes: int = 0 total_issues: int = 0 analysis_time: str = field(default_factory=lambda: datetime.now().isoformat()) - analysis_duration_ms: Optional[int] = None - - def to_dict(self) -> Dict[str, Any]: + analysis_duration_ms: int | None = None + + def to_dict(self) -> dict[str, Any]: """Convert to dictionary representation.""" return {k: v for k, v in asdict(self).items() if v is not None} + @dataclass class CodeQualityResult: """Results of code quality analysis.""" - dead_code: Dict[str, Any] = field(default_factory=dict) - complexity: Dict[str, Any] = field(default_factory=dict) - parameter_issues: Dict[str, Any] = field(default_factory=dict) - style_issues: Dict[str, Any] = field(default_factory=dict) - implementation_issues: Dict[str, Any] = field(default_factory=dict) - maintainability: Dict[str, Any] = field(default_factory=dict) - - def to_dict(self) -> Dict[str, Any]: + + dead_code: dict[str, Any] = field(default_factory=dict) + complexity: dict[str, Any] = field(default_factory=dict) + parameter_issues: dict[str, Any] = field(default_factory=dict) + style_issues: dict[str, Any] = field(default_factory=dict) + implementation_issues: dict[str, Any] = field(default_factory=dict) + maintainability: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: """Convert to dictionary representation.""" - return {k: v for k, v in asdict(self).items()} + return dict(asdict(self).items()) + @dataclass class DependencyResult: """Results of dependency analysis.""" - import_dependencies: Dict[str, Any] = field(default_factory=dict) - circular_dependencies: Dict[str, Any] = field(default_factory=dict) - module_coupling: Dict[str, Any] = field(default_factory=dict) - external_dependencies: Dict[str, Any] = field(default_factory=dict) - call_graph: Dict[str, Any] = field(default_factory=dict) - class_hierarchy: Dict[str, Any] = field(default_factory=dict) - - def to_dict(self) -> Dict[str, Any]: + + import_dependencies: dict[str, Any] = field(default_factory=dict) + 
circular_dependencies: dict[str, Any] = field(default_factory=dict) + module_coupling: dict[str, Any] = field(default_factory=dict) + external_dependencies: dict[str, Any] = field(default_factory=dict) + call_graph: dict[str, Any] = field(default_factory=dict) + class_hierarchy: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: """Convert to dictionary representation.""" - return {k: v for k, v in asdict(self).items()} + return dict(asdict(self).items()) + @dataclass class PrAnalysisResult: """Results of PR analysis.""" - modified_symbols: List[Dict[str, Any]] = field(default_factory=list) - added_symbols: List[Dict[str, Any]] = field(default_factory=list) - removed_symbols: List[Dict[str, Any]] = field(default_factory=list) - signature_changes: List[Dict[str, Any]] = field(default_factory=list) - impact: Dict[str, Any] = field(default_factory=dict) - - def to_dict(self) -> Dict[str, Any]: + + modified_symbols: list[dict[str, Any]] = field(default_factory=list) + added_symbols: list[dict[str, Any]] = field(default_factory=list) + removed_symbols: list[dict[str, Any]] = field(default_factory=list) + signature_changes: list[dict[str, Any]] = field(default_factory=list) + impact: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: """Convert to dictionary representation.""" - return {k: v for k, v in asdict(self).items()} + return dict(asdict(self).items()) + @dataclass class SecurityResult: """Results of security analysis.""" - vulnerabilities: List[Dict[str, Any]] = field(default_factory=list) - secrets: List[Dict[str, Any]] = field(default_factory=list) - injection_risks: List[Dict[str, Any]] = field(default_factory=list) - - def to_dict(self) -> Dict[str, Any]: + + vulnerabilities: list[dict[str, Any]] = field(default_factory=list) + secrets: list[dict[str, Any]] = field(default_factory=list) + injection_risks: list[dict[str, Any]] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: """Convert to dictionary representation.""" - return {k: v for k, v in asdict(self).items()} + return dict(asdict(self).items()) + @dataclass class PerformanceResult: """Results of performance analysis.""" - bottlenecks: List[Dict[str, Any]] = field(default_factory=list) - optimization_opportunities: List[Dict[str, Any]] = field(default_factory=list) - memory_issues: List[Dict[str, Any]] = field(default_factory=list) - - def to_dict(self) -> Dict[str, Any]: + + bottlenecks: list[dict[str, Any]] = field(default_factory=list) + optimization_opportunities: list[dict[str, Any]] = field(default_factory=list) + memory_issues: list[dict[str, Any]] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: """Convert to dictionary representation.""" - return {k: v for k, v in asdict(self).items()} + return dict(asdict(self).items()) + @dataclass class MetadataEntry: """Metadata about an analysis.""" + key: str value: Any - - def to_dict(self) -> Dict[str, Any]: + + def to_dict(self) -> dict[str, Any]: """Convert to dictionary representation.""" return {"key": self.key, "value": self.value} + @dataclass class AnalysisResult: """Comprehensive analysis result.""" + # Core data - analysis_types: List[AnalysisType] + analysis_types: list[AnalysisType] summary: AnalysisSummary = field(default_factory=AnalysisSummary) issues: IssueCollection = field(default_factory=IssueCollection) - + # Analysis results - code_quality: Optional[CodeQualityResult] = None - dependencies: Optional[DependencyResult] = None - pr_analysis: 
Optional[PrAnalysisResult] = None - security: Optional[SecurityResult] = None - performance: Optional[PerformanceResult] = None - + code_quality: CodeQualityResult | None = None + dependencies: DependencyResult | None = None + pr_analysis: PrAnalysisResult | None = None + security: SecurityResult | None = None + performance: PerformanceResult | None = None + # Metadata - metadata: Dict[str, Any] = field(default_factory=dict) - repo_name: Optional[str] = None - repo_path: Optional[str] = None - language: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: + metadata: dict[str, Any] = field(default_factory=dict) + repo_name: str | None = None + repo_path: str | None = None + language: str | None = None + + def to_dict(self) -> dict[str, Any]: """Convert to dictionary representation.""" result = { "analysis_types": [at.value for at in self.analysis_types], @@ -130,54 +145,54 @@ def to_dict(self) -> Dict[str, Any]: "issues": self.issues.to_dict(), "metadata": self.metadata, } - + # Add optional sections if present if self.repo_name: result["repo_name"] = self.repo_name - + if self.repo_path: result["repo_path"] = self.repo_path - + if self.language: result["language"] = self.language - + # Add analysis results if present if self.code_quality: result["code_quality"] = self.code_quality.to_dict() - + if self.dependencies: result["dependencies"] = self.dependencies.to_dict() - + if self.pr_analysis: result["pr_analysis"] = self.pr_analysis.to_dict() - + if self.security: result["security"] = self.security.to_dict() - + if self.performance: result["performance"] = self.performance.to_dict() - + return result - + def save_to_file(self, file_path: str, indent: int = 2): """ Save analysis result to a file. - + Args: file_path: Path to save to indent: JSON indentation level """ - with open(file_path, 'w') as f: + with open(file_path, "w") as f: json.dump(self.to_dict(), f, indent=indent) - + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'AnalysisResult': + def from_dict(cls, data: dict[str, Any]) -> "AnalysisResult": """ Create analysis result from dictionary. 
- + Args: data: Dictionary representation - + Returns: Analysis result object """ @@ -186,13 +201,21 @@ def from_dict(cls, data: Dict[str, Any]) -> 'AnalysisResult': AnalysisType(at) if isinstance(at, str) else at for at in data.get("analysis_types", []) ] - + # Create summary - summary = AnalysisSummary(**data.get("summary", {})) if "summary" in data else AnalysisSummary() - + summary = ( + AnalysisSummary(**data.get("summary", {})) + if "summary" in data + else AnalysisSummary() + ) + # Create issues collection - issues = IssueCollection.from_dict(data.get("issues", {})) if "issues" in data else IssueCollection() - + issues = ( + IssueCollection.from_dict(data.get("issues", {})) + if "issues" in data + else IssueCollection() + ) + # Create result object result = cls( analysis_types=analysis_types, @@ -201,79 +224,91 @@ def from_dict(cls, data: Dict[str, Any]) -> 'AnalysisResult': repo_name=data.get("repo_name"), repo_path=data.get("repo_path"), language=data.get("language"), - metadata=data.get("metadata", {}) + metadata=data.get("metadata", {}), ) - + # Add analysis results if present if "code_quality" in data: result.code_quality = CodeQualityResult(**data["code_quality"]) - + if "dependencies" in data: result.dependencies = DependencyResult(**data["dependencies"]) - + if "pr_analysis" in data: result.pr_analysis = PrAnalysisResult(**data["pr_analysis"]) - + if "security" in data: result.security = SecurityResult(**data["security"]) - + if "performance" in data: result.performance = PerformanceResult(**data["performance"]) - + return result - + @classmethod - def load_from_file(cls, file_path: str) -> 'AnalysisResult': + def load_from_file(cls, file_path: str) -> "AnalysisResult": """ Load analysis result from file. - + Args: file_path: Path to load from - + Returns: Analysis result object """ - with open(file_path, 'r') as f: + with open(file_path) as f: data = json.load(f) - + return cls.from_dict(data) - - def get_issue_count(self, severity: Optional[str] = None, category: Optional[str] = None) -> int: + + def get_issue_count( + self, severity: str | None = None, category: str | None = None + ) -> int: """ Get count of issues matching criteria. - + Args: severity: Optional severity to filter by category: Optional category to filter by - + Returns: Count of matching issues """ issues_dict = self.issues.to_dict() - + if severity and category: # Count issues with specific severity and category return sum( - 1 for issue in issues_dict.get("issues", []) - if issue.get("severity") == severity and issue.get("category") == category + 1 + for issue in issues_dict.get("issues", []) + if issue.get("severity") == severity + and issue.get("category") == category ) elif severity: # Count issues with specific severity - return issues_dict.get("statistics", {}).get("by_severity", {}).get(severity, 0) + return ( + issues_dict.get("statistics", {}) + .get("by_severity", {}) + .get(severity, 0) + ) elif category: # Count issues with specific category - return issues_dict.get("statistics", {}).get("by_category", {}).get(category, 0) + return ( + issues_dict.get("statistics", {}) + .get("by_category", {}) + .get(category, 0) + ) else: # Total issues return issues_dict.get("statistics", {}).get("total", 0) - - def merge(self, other: 'AnalysisResult') -> 'AnalysisResult': + + def merge(self, other: "AnalysisResult") -> "AnalysisResult": """ Merge with another analysis result. 
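As a usage sketch of the `AnalysisResult` container defined above, assuming the module path from the diff header (`codegen_on_oss.analyzers.models.analysis_result`) and that an empty `analysis_types` list is acceptable for a minimal result:

```python
from codegen_on_oss.analyzers.models.analysis_result import AnalysisResult

# Two partial results; empty analysis_types keeps the sketch minimal.
first = AnalysisResult(analysis_types=[], repo_name="example/repo")
second = AnalysisResult(analysis_types=[], language="Python")

# merge() keeps non-None sub-results, concatenates issues and metadata,
# and rebuilds the summary, as shown in the hunk above.
combined = first.merge(second)
combined.save_to_file("analysis.json")

# Round-trip through JSON and query issue counts by severity value.
loaded = AnalysisResult.load_from_file("analysis.json")
print(loaded.repo_name, loaded.language)
print(loaded.get_issue_count(severity="error"))
```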
- + Args: other: Analysis result to merge with - + Returns: New merged analysis result """ @@ -284,28 +319,30 @@ def merge(self, other: 'AnalysisResult') -> 'AnalysisResult': repo_path=self.repo_path or other.repo_path, language=self.language or other.language, ) - + # Merge issues merged.issues.add_issues(self.issues.issues) merged.issues.add_issues(other.issues.issues) - + # Merge metadata merged.metadata = {**self.metadata, **other.metadata} - + # Merge analysis results (take non-None values) merged.code_quality = self.code_quality or other.code_quality merged.dependencies = self.dependencies or other.dependencies merged.pr_analysis = self.pr_analysis or other.pr_analysis merged.security = self.security or other.security merged.performance = self.performance or other.performance - + # Update summary merged.summary = AnalysisSummary( total_files=max(self.summary.total_files, other.summary.total_files), - total_functions=max(self.summary.total_functions, other.summary.total_functions), + total_functions=max( + self.summary.total_functions, other.summary.total_functions + ), total_classes=max(self.summary.total_classes, other.summary.total_classes), total_issues=len(merged.issues.issues), - analysis_time=datetime.now().isoformat() + analysis_time=datetime.now().isoformat(), ) - - return merged \ No newline at end of file + + return merged diff --git a/codegen-on-oss/codegen_on_oss/analyzers/resolution/resolution_manager.py b/codegen-on-oss/codegen_on_oss/analyzers/resolution/resolution_manager.py index 96f583358..05cbdfc59 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/resolution/resolution_manager.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/resolution/resolution_manager.py @@ -7,54 +7,57 @@ apply automated fixes and track issue resolution. """ -import os import logging -import sys +import os from enum import Enum -from typing import Dict, List, Set, Tuple, Any, Optional, Union, Callable # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) + class ResolutionStrategy(str, Enum): """Strategies for resolving issues.""" + AUTO_FIX = "auto_fix" SUGGESTION = "suggestion" MANUAL = "manual" IGNORE = "ignore" + class ResolutionStatus(str, Enum): """Status of resolution attempts.""" + PENDING = "pending" IN_PROGRESS = "in_progress" RESOLVED = "resolved" FAILED = "failed" IGNORED = "ignored" + class ResolutionManager: """ Manager for resolving code issues identified during analysis. - + This class provides functionality to track, apply, and validate resolutions for issues found in the codebase. """ - + def __init__( self, analyzer=None, codebase=None, context=None, auto_apply: bool = False, - strategies: Optional[Dict[str, ResolutionStrategy]] = None + strategies: dict[str, ResolutionStrategy] | None = None, ): """ Initialize the ResolutionManager. 
- + Args: analyzer: Optional analyzer with analysis results codebase: Optional codebase to resolve issues for @@ -69,11 +72,11 @@ def __init__( self.strategies = strategies or {} self.resolutions = {} self.resolution_history = [] - + # Initialize strategies if not provided if not self.strategies: self._init_default_strategies() - + def _init_default_strategies(self): """Initialize default resolution strategies for common issue types.""" self.strategies = { @@ -87,114 +90,122 @@ def _init_default_strategies(self): "dead_code": ResolutionStrategy.SUGGESTION, "security_issue": ResolutionStrategy.MANUAL, } - + def load_issues(self): """ Load issues from the analyzer. - + Returns: List of issues """ if not self.analyzer: logger.error("No analyzer available") return [] - + if not hasattr(self.analyzer, "results") or not self.analyzer.results: logger.error("No analysis results available") return [] - + if "issues" not in self.analyzer.results: logger.error("No issues found in analysis results") return [] - + issues = self.analyzer.results["issues"] - + # Initialize resolutions tracking for issue in issues: issue_id = issue.get("id") if not issue_id: continue - + self.resolutions[issue_id] = { "issue": issue, "status": ResolutionStatus.PENDING, - "strategy": self.strategies.get(issue.get("type"), ResolutionStrategy.MANUAL), + "strategy": self.strategies.get( + issue.get("type"), ResolutionStrategy.MANUAL + ), "resolution_data": None, "applied": False, - "validation_result": None + "validation_result": None, } - + return issues - - def get_resolution_candidates(self, filter_strategy: Optional[ResolutionStrategy] = None): + + def get_resolution_candidates( + self, filter_strategy: ResolutionStrategy | None = None + ): """ Get issues that can be resolved with the specified strategy. - + Args: filter_strategy: Optional strategy to filter issues by - + Returns: List of issues that can be resolved with the specified strategy """ candidates = [] - - for issue_id, resolution in self.resolutions.items(): + + for _issue_id, resolution in self.resolutions.items(): if filter_strategy and resolution["strategy"] != filter_strategy: continue - + if resolution["status"] == ResolutionStatus.PENDING: candidates.append(resolution["issue"]) - + return candidates - + def generate_resolutions(self): """ Generate resolutions for all pending issues. - + Returns: Number of resolutions generated """ count = 0 - + # Process auto-fix issues first - auto_fix_candidates = self.get_resolution_candidates(ResolutionStrategy.AUTO_FIX) + auto_fix_candidates = self.get_resolution_candidates( + ResolutionStrategy.AUTO_FIX + ) for issue in auto_fix_candidates: if self._generate_resolution(issue): count += 1 - + # Process suggestion issues next - suggestion_candidates = self.get_resolution_candidates(ResolutionStrategy.SUGGESTION) + suggestion_candidates = self.get_resolution_candidates( + ResolutionStrategy.SUGGESTION + ) for issue in suggestion_candidates: if self._generate_resolution(issue): count += 1 - + # Skip manual issues as they require human intervention - + return count - + def _generate_resolution(self, issue): """ Generate a resolution for a specific issue. 
- + Args: issue: Issue to generate a resolution for - + Returns: True if a resolution was generated, False otherwise """ issue_id = issue.get("id") if not issue_id or issue_id not in self.resolutions: return False - + resolution = self.resolutions[issue_id] resolution["status"] = ResolutionStatus.IN_PROGRESS - + try: # Generate resolution based on issue type issue_type = issue.get("type") - issue_file = issue.get("file") - issue_line = issue.get("line") - + issue.get("file") + issue.get("line") + # Special handling for common issue types if issue_type == "unused_import": resolution_data = self._resolve_unused_import(issue) @@ -216,126 +227,135 @@ def _generate_resolution(self, issue): # No specific handler for this issue type resolution["status"] = ResolutionStatus.PENDING return False - + if not resolution_data: resolution["status"] = ResolutionStatus.FAILED return False - + resolution["resolution_data"] = resolution_data resolution["status"] = ResolutionStatus.RESOLVED - + # Auto-apply if configured - if self.auto_apply and resolution["strategy"] == ResolutionStrategy.AUTO_FIX: + if ( + self.auto_apply + and resolution["strategy"] == ResolutionStrategy.AUTO_FIX + ): self.apply_resolution(issue_id) - + return True except Exception as e: - logger.error(f"Error generating resolution for issue {issue_id}: {str(e)}") + logger.exception(f"Error generating resolution for issue {issue_id}: {e!s}") resolution["status"] = ResolutionStatus.FAILED return False - + def apply_resolution(self, issue_id): """ Apply a resolution to the codebase. - + Args: issue_id: ID of the issue to apply the resolution for - + Returns: True if the resolution was applied, False otherwise """ if issue_id not in self.resolutions: logger.error(f"Issue {issue_id} not found") return False - + resolution = self.resolutions[issue_id] if resolution["status"] != ResolutionStatus.RESOLVED: logger.error(f"Resolution for issue {issue_id} is not ready to apply") return False - + if resolution["applied"]: logger.warning(f"Resolution for issue {issue_id} already applied") return True - + try: # Apply the resolution issue = resolution["issue"] resolution_data = resolution["resolution_data"] - + issue_type = issue.get("type") issue_file = issue.get("file") - + if not issue_file or not os.path.isfile(issue_file): logger.error(f"Issue file not found: {issue_file}") return False - + # Special handling based on issue type - if issue_type == "unused_import" or issue_type == "unused_variable" or issue_type == "unused_function": + if ( + ( + issue_type == "unused_import" + or issue_type == "unused_variable" + or issue_type == "unused_function" + ) + or issue_type == "missing_return_type" + or issue_type == "parameter_type_mismatch" + ): if "code_changes" in resolution_data: - self._apply_code_changes(issue_file, resolution_data["code_changes"]) - elif issue_type == "missing_return_type": - if "code_changes" in resolution_data: - self._apply_code_changes(issue_file, resolution_data["code_changes"]) - elif issue_type == "parameter_type_mismatch": - if "code_changes" in resolution_data: - self._apply_code_changes(issue_file, resolution_data["code_changes"]) + self._apply_code_changes( + issue_file, resolution_data["code_changes"] + ) elif issue_type == "circular_dependency": if "code_changes" in resolution_data: for file_path, changes in resolution_data["code_changes"].items(): self._apply_code_changes(file_path, changes) else: - logger.warning(f"No implementation for applying resolution of type {issue_type}") + logger.warning( + f"No 
implementation for applying resolution of type {issue_type}" + ) return False - + # Record the application resolution["applied"] = True self.resolution_history.append({ "issue_id": issue_id, "timestamp": datetime.now().isoformat(), "action": "apply", - "success": True + "success": True, }) - + return True except Exception as e: - logger.error(f"Error applying resolution for issue {issue_id}: {str(e)}") + logger.exception(f"Error applying resolution for issue {issue_id}: {e!s}") self.resolution_history.append({ "issue_id": issue_id, "timestamp": datetime.now().isoformat(), "action": "apply", "success": False, - "error": str(e) + "error": str(e), }) return False - + def validate_resolution(self, issue_id): """ Validate a resolution after it has been applied. - + Args: issue_id: ID of the issue to validate the resolution for - + Returns: True if the resolution is valid, False otherwise """ if issue_id not in self.resolutions: logger.error(f"Issue {issue_id} not found") return False - + resolution = self.resolutions[issue_id] if not resolution["applied"]: logger.error(f"Resolution for issue {issue_id} has not been applied") return False - + try: # Validate the resolution - issue = resolution["issue"] - resolution_data = resolution["resolution_data"] - + resolution["issue"] + resolution["resolution_data"] + # Rerun the analyzer to check if the issue is fixed if self.analyzer: self.analyzer.analyze() - + # Check if the issue still exists if "issues" in self.analyzer.results: for current_issue in self.analyzer.results["issues"]: @@ -343,55 +363,53 @@ def validate_resolution(self, issue_id): # Issue still exists, resolution is invalid resolution["validation_result"] = { "valid": False, - "reason": "Issue still exists after resolution" + "reason": "Issue still exists after resolution", } return False - + # Issue no longer exists, resolution is valid - resolution["validation_result"] = { - "valid": True - } + resolution["validation_result"] = {"valid": True} return True else: logger.warning("No analyzer available for validation") return True except Exception as e: - logger.error(f"Error validating resolution for issue {issue_id}: {str(e)}") + logger.exception(f"Error validating resolution for issue {issue_id}: {e!s}") resolution["validation_result"] = { "valid": False, - "reason": f"Error during validation: {str(e)}" + "reason": f"Error during validation: {e!s}", } return False - + def rollback_resolution(self, issue_id): """ Rollback a resolution that has been applied. 
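To make the resolution workflow concrete, here is a minimal sketch (not part of the diff) that drives `ResolutionManager` with a stub analyzer; the issue-dict keys (`id`, `type`, `file`, `line`, `symbol`) follow what the resolution generators above read, and everything else is made up for illustration. One hedged observation: `apply_resolution()` and the history tracking call `datetime.now()`, yet no `datetime` import is visible in this hunk, so applying fixes may additionally require that import.

```python
import tempfile
from pathlib import Path
from types import SimpleNamespace

from codegen_on_oss.analyzers.resolution.resolution_manager import (
    ResolutionManager,
    ResolutionStrategy,
)

# Throwaway file containing an unused import for the resolver to target.
target = Path(tempfile.mkdtemp()) / "example.py"
target.write_text("import os\n\nprint('hello')\n")

# Stub analyzer: .results carries one issue in the dict shape the
# generators above read (id/type/file/line/symbol).
stub_analyzer = SimpleNamespace(
    results={
        "issues": [
            {
                "id": "issue-1",
                "type": "unused_import",
                "file": str(target),
                "line": 1,
                "symbol": "os",
            }
        ]
    }
)

manager = ResolutionManager(
    analyzer=stub_analyzer,
    auto_apply=False,  # only generate fixes; applying them is a separate step
    strategies={"unused_import": ResolutionStrategy.AUTO_FIX},
)
manager.load_issues()
print(manager.generate_resolutions())                      # -> 1
print(manager.get_resolution_status("issue-1")["status"])  # ResolutionStatus.RESOLVED
```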
- + Args: issue_id: ID of the issue to rollback the resolution for - + Returns: True if the resolution was rolled back, False otherwise """ if issue_id not in self.resolutions: logger.error(f"Issue {issue_id} not found") return False - + resolution = self.resolutions[issue_id] if not resolution["applied"]: logger.error(f"Resolution for issue {issue_id} has not been applied") return False - + try: # Rollback the resolution issue = resolution["issue"] resolution_data = resolution["resolution_data"] - + if "original_code" in resolution_data: issue_file = issue.get("file") with open(issue_file, "w") as f: f.write(resolution_data["original_code"]) - + # Record the rollback resolution["applied"] = False resolution["validation_result"] = None @@ -399,59 +417,61 @@ def rollback_resolution(self, issue_id): "issue_id": issue_id, "timestamp": datetime.now().isoformat(), "action": "rollback", - "success": True + "success": True, }) - + return True except Exception as e: - logger.error(f"Error rolling back resolution for issue {issue_id}: {str(e)}") + logger.exception( + f"Error rolling back resolution for issue {issue_id}: {e!s}" + ) self.resolution_history.append({ "issue_id": issue_id, "timestamp": datetime.now().isoformat(), "action": "rollback", "success": False, - "error": str(e) + "error": str(e), }) return False - + def ignore_issue(self, issue_id, reason: str = ""): """ Mark an issue as ignored. - + Args: issue_id: ID of the issue to ignore reason: Reason for ignoring the issue - + Returns: True if the issue was marked as ignored, False otherwise """ if issue_id not in self.resolutions: logger.error(f"Issue {issue_id} not found") return False - + resolution = self.resolutions[issue_id] resolution["status"] = ResolutionStatus.IGNORED resolution["resolution_data"] = { "reason": reason, - "timestamp": datetime.now().isoformat() + "timestamp": datetime.now().isoformat(), } - + self.resolution_history.append({ "issue_id": issue_id, "timestamp": datetime.now().isoformat(), "action": "ignore", - "reason": reason + "reason": reason, }) - + return True - + def get_resolution_status(self, issue_id=None): """ Get the status of resolutions. - + Args: issue_id: Optional ID of the issue to get the status for - + Returns: Resolution status information """ @@ -459,7 +479,7 @@ def get_resolution_status(self, issue_id=None): if issue_id not in self.resolutions: logger.error(f"Issue {issue_id} not found") return None - + return self.resolutions[issue_id] else: # Get summary of all resolutions @@ -472,9 +492,9 @@ def get_resolution_status(self, issue_id=None): "failed": 0, "ignored": 0, "valid": 0, - "invalid": 0 + "invalid": 0, } - + for resolution in self.resolutions.values(): if resolution["status"] == ResolutionStatus.PENDING: summary["pending"] += 1 @@ -484,7 +504,9 @@ def get_resolution_status(self, issue_id=None): summary["resolved"] += 1 if resolution["applied"]: summary["applied"] += 1 - if resolution["validation_result"] and resolution["validation_result"].get("valid"): + if resolution["validation_result"] and resolution[ + "validation_result" + ].get("valid"): summary["valid"] += 1 elif resolution["validation_result"]: summary["invalid"] += 1 @@ -492,30 +514,30 @@ def get_resolution_status(self, issue_id=None): summary["failed"] += 1 elif resolution["status"] == ResolutionStatus.IGNORED: summary["ignored"] += 1 - + return summary - + def _apply_code_changes(self, file_path, changes): """ Apply code changes to a file. 
- + Args: file_path: Path to the file to apply changes to changes: List of changes to apply - + Returns: True if changes were applied, False otherwise """ try: # Read the file - with open(file_path, "r") as f: + with open(file_path) as f: lines = f.readlines() - + # Apply the changes for change in changes: if "line" in change and "action" in change: line_idx = change["line"] - 1 # Convert to 0-indexed - + if change["action"] == "remove": if 0 <= line_idx < len(lines): lines[line_idx] = "" @@ -525,24 +547,24 @@ def _apply_code_changes(self, file_path, changes): elif change["action"] == "insert" and "new_text" in change: if 0 <= line_idx <= len(lines): lines.insert(line_idx, change["new_text"] + "\n") - + # Write the changes back to the file with open(file_path, "w") as f: f.writelines(lines) - + return True except Exception as e: - logger.error(f"Error applying code changes to {file_path}: {str(e)}") + logger.exception(f"Error applying code changes to {file_path}: {e!s}") return False - + # Resolution generators for specific issue types def _resolve_unused_import(self, issue): """ Generate a resolution for an unused import issue. - + Args: issue: Issue to generate a resolution for - + Returns: Resolution data or None if no resolution could be generated """ @@ -550,148 +572,152 @@ def _resolve_unused_import(self, issue): issue_file = issue.get("file") issue_line = issue.get("line") import_name = issue.get("symbol") - - if not issue_file or not os.path.isfile(issue_file) or not issue_line or not import_name: + + if ( + not issue_file + or not os.path.isfile(issue_file) + or not issue_line + or not import_name + ): return None - + # Read the file - with open(issue_file, "r") as f: + with open(issue_file) as f: lines = f.readlines() original_code = "".join(lines) - + # Find the import line if 0 <= issue_line - 1 < len(lines): import_line = lines[issue_line - 1] - + # Check if it's a single import or part of a multi-import - if f"import {import_name}" in import_line or f"from " in import_line and f" import {import_name}" in import_line: + if f"import {import_name}" in import_line or ( + "from " in import_line and f" import {import_name}" in import_line + ): # Generate change return { "original_code": original_code, - "code_changes": [ - { - "line": issue_line, - "action": "remove" - } - ] + "code_changes": [{"line": issue_line, "action": "remove"}], } - + return None except Exception as e: - logger.error(f"Error resolving unused import: {str(e)}") + logger.exception(f"Error resolving unused import: {e!s}") return None - + def _resolve_unused_variable(self, issue): """Resolution generator for unused variable issues.""" try: issue_file = issue.get("file") issue_line = issue.get("line") var_name = issue.get("symbol") - - if not issue_file or not os.path.isfile(issue_file) or not issue_line or not var_name: + + if ( + not issue_file + or not os.path.isfile(issue_file) + or not issue_line + or not var_name + ): return None - + # Read the file - with open(issue_file, "r") as f: + with open(issue_file) as f: lines = f.readlines() original_code = "".join(lines) - + # Find the variable declaration line if 0 <= issue_line - 1 < len(lines): var_line = lines[issue_line - 1] - + # Check if it's a variable assignment if f"{var_name} =" in var_line or f"{var_name}=" in var_line: # Generate change return { "original_code": original_code, - "code_changes": [ - { - "line": issue_line, - "action": "remove" - } - ] + "code_changes": [{"line": issue_line, "action": "remove"}], } - + return None except Exception as 
e: - logger.error(f"Error resolving unused variable: {str(e)}") + logger.exception(f"Error resolving unused variable: {e!s}") return None - + def _resolve_unused_function(self, issue): """Resolution generator for unused function issues.""" try: issue_file = issue.get("file") issue_line = issue.get("line") func_name = issue.get("symbol") - - if not issue_file or not os.path.isfile(issue_file) or not issue_line or not func_name: + + if ( + not issue_file + or not os.path.isfile(issue_file) + or not issue_line + or not func_name + ): return None - + # Read the file - with open(issue_file, "r") as f: + with open(issue_file) as f: lines = f.readlines() original_code = "".join(lines) - + # Find the function declaration line if 0 <= issue_line - 1 < len(lines): func_line = lines[issue_line - 1] - + # Check if it's a function declaration if f"def {func_name}" in func_line: # Find the end of the function end_line = issue_line indent_level = None - + # Get indentation level of the function for i, char in enumerate(func_line): if char != " " and char != "\t": indent_level = i break - + if indent_level is None: return None - + # Find all lines of the function function_lines = [] for i in range(issue_line - 1, len(lines)): # Skip empty lines if not lines[i].strip(): continue - + # Check indentation current_indent = 0 for j, char in enumerate(lines[i]): if char != " " and char != "\t": current_indent = j break - + # If indentation is less than function, we've reached the end if current_indent <= indent_level and i > issue_line - 1: end_line = i break - + function_lines.append(lines[i]) - + # Generate change changes = [] for i in range(issue_line - 1, end_line): - changes.append({ - "line": i + 1, - "action": "remove" - }) - + changes.append({"line": i + 1, "action": "remove"}) + return { "original_code": original_code, "code_changes": changes, - "function_text": "".join(function_lines) + "function_text": "".join(function_lines), } - + return None except Exception as e: - logger.error(f"Error resolving unused function: {str(e)}") + logger.exception(f"Error resolving unused function: {e!s}") return None - + def _resolve_missing_return_type(self, issue): """Resolution generator for missing return type issues.""" try: @@ -699,29 +725,42 @@ def _resolve_missing_return_type(self, issue): issue_line = issue.get("line") func_name = issue.get("symbol") suggested_type = issue.get("suggested_type", "Any") - - if not issue_file or not os.path.isfile(issue_file) or not issue_line or not func_name: + + if ( + not issue_file + or not os.path.isfile(issue_file) + or not issue_line + or not func_name + ): return None - + # Read the file - with open(issue_file, "r") as f: + with open(issue_file) as f: lines = f.readlines() original_code = "".join(lines) - + # Find the function declaration line if 0 <= issue_line - 1 < len(lines): func_line = lines[issue_line - 1] - + # Check if it's a function declaration and doesn't have a return type if f"def {func_name}" in func_line and "->" not in func_line: # Find the closing parenthesis close_paren_idx = func_line.rfind(")") colon_idx = func_line.rfind(":") - - if close_paren_idx != -1 and colon_idx != -1 and close_paren_idx < colon_idx: + + if ( + close_paren_idx != -1 + and colon_idx != -1 + and close_paren_idx < colon_idx + ): # Insert return type - new_line = func_line[:close_paren_idx + 1] + f" -> {suggested_type}" + func_line[close_paren_idx + 1:] - + new_line = ( + func_line[: close_paren_idx + 1] + + f" -> {suggested_type}" + + func_line[close_paren_idx + 1 :] + ) + # 
Generate change return { "original_code": original_code, @@ -729,33 +768,33 @@ def _resolve_missing_return_type(self, issue): { "line": issue_line, "action": "replace", - "new_text": new_line.rstrip() + "new_text": new_line.rstrip(), } - ] + ], } - + return None except Exception as e: - logger.error(f"Error resolving missing return type: {str(e)}") + logger.exception(f"Error resolving missing return type: {e!s}") return None - + def _resolve_parameter_type_mismatch(self, issue): """Resolution generator for parameter type mismatch issues.""" # Implementation would depend on the specific issue structure return None - + def _resolve_circular_dependency(self, issue): """Resolution generator for circular dependency issues.""" # Implementation would involve analyzing the dependency graph # and suggesting module reorganization return None - + def _resolve_complex_function(self, issue): """Resolution generator for complex function issues.""" # Implementation would involve suggesting function refactoring return None - + def _resolve_dead_code(self, issue): """Resolution generator for dead code issues.""" # Similar to unused function resolution - return None \ No newline at end of file + return None diff --git a/codegen-on-oss/codegen_on_oss/analyzers/snapshot/snapshot_manager.py b/codegen-on-oss/codegen_on_oss/analyzers/snapshot/snapshot_manager.py index adb9c82b4..a09a54a81 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/snapshot/snapshot_manager.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/snapshot/snapshot_manager.py @@ -7,77 +7,79 @@ consistency between versions. """ -import os -import sys +import hashlib import json import logging +import os import tempfile -import shutil -import hashlib -from typing import Dict, List, Set, Tuple, Any, Optional, Union -from datetime import datetime -from pathlib import Path from dataclasses import dataclass, field +from datetime import datetime +from typing import Any # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) + @dataclass class SnapshotMetadata: """Metadata for a codebase snapshot.""" + snapshot_id: str timestamp: str description: str creator: str base_path: str - commit_hash: Optional[str] = None - branch: Optional[str] = None - tag: Optional[str] = None + commit_hash: str | None = None + branch: str | None = None + tag: str | None = None file_count: int = 0 total_lines: int = 0 - language_stats: Dict[str, int] = field(default_factory=dict) - extra: Dict[str, Any] = field(default_factory=dict) + language_stats: dict[str, int] = field(default_factory=dict) + extra: dict[str, Any] = field(default_factory=dict) + @dataclass class FileSnapshot: """Snapshot of a file in the codebase.""" + path: str relative_path: str hash: str size: int lines: int - language: Optional[str] = None - content_hash: Optional[str] = None - ast_hash: Optional[str] = None - last_modified: Optional[str] = None - metadata: Dict[str, Any] = field(default_factory=dict) + language: str | None = None + content_hash: str | None = None + ast_hash: str | None = None + last_modified: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + class CodebaseSnapshot: """ Codebase snapshot representation. 
- + This class stores a complete snapshot of a codebase at a point in time, including all files and their metadata. """ - + def __init__( self, base_path: str, description: str = "", creator: str = "snapshot_manager", - include_patterns: List[str] = None, - exclude_patterns: List[str] = None, - snapshot_id: Optional[str] = None, - store_content: bool = False + include_patterns: list[str] | None = None, + exclude_patterns: list[str] | None = None, + snapshot_id: str | None = None, + store_content: bool = False, ): """ Initialize a codebase snapshot. - + Args: base_path: Base path of the codebase description: Description of the snapshot @@ -95,112 +97,112 @@ def __init__( self.snapshot_id = snapshot_id or self._generate_id() self.store_content = store_content self.timestamp = datetime.now().isoformat() - + # Initialize data structures - self.files: Dict[str, FileSnapshot] = {} - self.content: Dict[str, str] = {} - self.language_stats: Dict[str, int] = {} - + self.files: dict[str, FileSnapshot] = {} + self.content: dict[str, str] = {} + self.language_stats: dict[str, int] = {} + # Get git information if available self.commit_hash = self._get_git_commit_hash() self.branch = self._get_git_branch() self.tag = self._get_git_tag() - + def _generate_id(self) -> str: """ Generate a unique ID for the snapshot. - + Returns: Generated ID """ timestamp = datetime.now().strftime("%Y%m%d%H%M%S") random_suffix = hashlib.md5(os.urandom(16)).hexdigest()[:8] return f"snapshot_{timestamp}_{random_suffix}" - - def _get_git_commit_hash(self) -> Optional[str]: + + def _get_git_commit_hash(self) -> str | None: """ Get the current Git commit hash. - + Returns: Commit hash if available, None otherwise """ try: import subprocess + result = subprocess.run( ["git", "rev-parse", "HEAD"], cwd=self.base_path, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + capture_output=True, text=True, - check=False + check=False, ) - + if result.returncode == 0: return result.stdout.strip() return None except Exception: return None - - def _get_git_branch(self) -> Optional[str]: + + def _get_git_branch(self) -> str | None: """ Get the current Git branch. - + Returns: Branch name if available, None otherwise """ try: import subprocess + result = subprocess.run( ["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=self.base_path, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + capture_output=True, text=True, - check=False + check=False, ) - + if result.returncode == 0: return result.stdout.strip() return None except Exception: return None - - def _get_git_tag(self) -> Optional[str]: + + def _get_git_tag(self) -> str | None: """ Get the current Git tag. - + Returns: Tag name if available, None otherwise """ try: import subprocess + result = subprocess.run( ["git", "describe", "--tags", "--exact-match"], cwd=self.base_path, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + capture_output=True, text=True, - check=False + check=False, ) - + if result.returncode == 0: return result.stdout.strip() return None except Exception: return None - - def _get_file_language(self, file_path: str) -> Optional[str]: + + def _get_file_language(self, file_path: str) -> str | None: """ Determine the programming language of a file based on its extension. 
- + Args: file_path: Path to the file - + Returns: Language name if recognized, None otherwise """ extension = os.path.splitext(file_path)[1].lower() - + language_map = { ".py": "Python", ".js": "JavaScript", @@ -234,43 +236,43 @@ def _get_file_language(self, file_path: str) -> Optional[str]: ".bat": "Batch", ".ps1": "PowerShell", } - + return language_map.get(extension) - + def _should_include_file(self, file_path: str) -> bool: """ Check if a file should be included in the snapshot. - + Args: file_path: Path to the file - + Returns: True if the file should be included, False otherwise """ import fnmatch - + # Convert to relative path rel_path = os.path.relpath(file_path, self.base_path) - + # Check exclude patterns first for pattern in self.exclude_patterns: if fnmatch.fnmatch(rel_path, pattern): return False - + # Then check include patterns for pattern in self.include_patterns: if fnmatch.fnmatch(rel_path, pattern): return True - + return False - + def _compute_file_hash(self, file_path: str) -> str: """ Compute a hash of a file's content. - + Args: file_path: Path to the file - + Returns: Hash of the file content """ @@ -279,70 +281,72 @@ def _compute_file_hash(self, file_path: str) -> str: for chunk in iter(lambda: f.read(4096), b""): hash_md5.update(chunk) return hash_md5.hexdigest() - + def _count_lines(self, file_path: str) -> int: """ Count the number of lines in a file. - + Args: file_path: Path to the file - + Returns: Number of lines in the file """ try: - with open(file_path, "r", encoding="utf-8", errors="ignore") as f: + with open(file_path, encoding="utf-8", errors="ignore") as f: return sum(1 for _ in f) except Exception: # Fallback for binary files return 0 - + def create(self): """ Create a snapshot of the codebase. - + This method scans the codebase, collects file metadata, and optionally stores file content. 
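For orientation, a direct-usage sketch of `CodebaseSnapshot` (assuming the module path from the diff header; include/exclude patterns are passed explicitly because their defaults are not visible in this hunk):

```python
from codegen_on_oss.analyzers.snapshot.snapshot_manager import CodebaseSnapshot

# Snapshot only Python sources under the current checkout.
snapshot = CodebaseSnapshot(
    base_path=".",
    description="pre-refactor baseline",
    include_patterns=["*.py"],       # explicit; defaults not shown in this hunk
    exclude_patterns=[".git/*"],
    store_content=False,             # keep metadata and hashes only
)
snapshot.create()

meta = snapshot.get_metadata()
print(meta.file_count, meta.total_lines, meta.language_stats)
saved_path = snapshot.save()         # a temp file is used when no path is given
print(saved_path)
```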
""" if not os.path.isdir(self.base_path): logger.error(f"Base path not found: {self.base_path}") return - + # Reset data structures self.files = {} self.content = {} self.language_stats = {} - + total_files = 0 total_lines = 0 - + # Walk the directory tree for root, _, files in os.walk(self.base_path): for file in files: file_path = os.path.join(root, file) - + # Skip if file should not be included if not self._should_include_file(file_path): continue - + try: # Get file stats file_stats = os.stat(file_path) file_size = file_stats.st_size - file_modified = datetime.fromtimestamp(file_stats.st_mtime).isoformat() - + file_modified = datetime.fromtimestamp( + file_stats.st_mtime + ).isoformat() + # Get file language language = self._get_file_language(file_path) - + # Count lines line_count = self._count_lines(file_path) - + # Compute hash file_hash = self._compute_file_hash(file_path) - + # Get relative path rel_path = os.path.relpath(file_path, self.base_path) - + # Create file snapshot file_snapshot = FileSnapshot( path=file_path, @@ -351,37 +355,45 @@ def create(self): size=file_size, lines=line_count, language=language, - last_modified=file_modified + last_modified=file_modified, ) - + # Store file content if requested if self.store_content: try: - with open(file_path, "r", encoding="utf-8", errors="ignore") as f: + with open( + file_path, encoding="utf-8", errors="ignore" + ) as f: file_content = f.read() self.content[rel_path] = file_content except Exception as e: - logger.warning(f"Could not read content of {file_path}: {str(e)}") - + logger.warning( + f"Could not read content of {file_path}: {e!s}" + ) + # Store file snapshot self.files[rel_path] = file_snapshot - + # Update language stats if language: - self.language_stats[language] = self.language_stats.get(language, 0) + 1 - + self.language_stats[language] = ( + self.language_stats.get(language, 0) + 1 + ) + # Update totals total_files += 1 total_lines += line_count except Exception as e: - logger.warning(f"Error processing file {file_path}: {str(e)}") - - logger.info(f"Created snapshot with {total_files} files and {total_lines} lines") - + logger.warning(f"Error processing file {file_path}: {e!s}") + + logger.info( + f"Created snapshot with {total_files} files and {total_lines} lines" + ) + def get_metadata(self) -> SnapshotMetadata: """ Get metadata for the snapshot. - + Returns: Snapshot metadata """ @@ -396,16 +408,16 @@ def get_metadata(self) -> SnapshotMetadata: tag=self.tag, file_count=len(self.files), total_lines=sum(file.lines for file in self.files.values()), - language_stats=self.language_stats + language_stats=self.language_stats, ) - - def save(self, output_path: Optional[str] = None) -> str: + + def save(self, output_path: str | None = None) -> str: """ Save the snapshot to disk. 
- + Args: output_path: Optional path to save the snapshot to - + Returns: Path to the saved snapshot """ @@ -413,57 +425,57 @@ def save(self, output_path: Optional[str] = None) -> str: if not output_path: output_dir = tempfile.mkdtemp(prefix="codebase_snapshot_") output_path = os.path.join(output_dir, f"{self.snapshot_id}.json") - + # Create output directory if it doesn't exist os.makedirs(os.path.dirname(output_path), exist_ok=True) - + # Convert snapshot to JSON snapshot_data = { "metadata": self.get_metadata().__dict__, "files": {rel_path: file.__dict__ for rel_path, file in self.files.items()}, - "content": self.content if self.store_content else {} + "content": self.content if self.store_content else {}, } - + # Save to disk with open(output_path, "w") as f: json.dump(snapshot_data, f, indent=2) - + logger.info(f"Saved snapshot to {output_path}") return output_path - + @classmethod - def load(cls, snapshot_path: str) -> 'CodebaseSnapshot': + def load(cls, snapshot_path: str) -> "CodebaseSnapshot": """ Load a snapshot from disk. - + Args: snapshot_path: Path to the snapshot file - + Returns: Loaded snapshot """ - with open(snapshot_path, "r") as f: + with open(snapshot_path) as f: snapshot_data = json.load(f) - + # Extract metadata metadata = snapshot_data["metadata"] - + # Create snapshot instance snapshot = cls( base_path=metadata["base_path"], description=metadata["description"], creator=metadata["creator"], - snapshot_id=metadata["snapshot_id"] + snapshot_id=metadata["snapshot_id"], ) - + # Set timestamp snapshot.timestamp = metadata["timestamp"] - + # Set Git information snapshot.commit_hash = metadata.get("commit_hash") snapshot.branch = metadata.get("branch") snapshot.tag = metadata.get("tag") - + # Load files snapshot.files = {} for rel_path, file_data in snapshot_data["files"].items(): @@ -475,47 +487,47 @@ def load(cls, snapshot_path: str) -> 'CodebaseSnapshot': lines=file_data["lines"], language=file_data.get("language"), last_modified=file_data.get("last_modified"), - metadata=file_data.get("metadata", {}) + metadata=file_data.get("metadata", {}), ) - + # Load content if available snapshot.content = snapshot_data.get("content", {}) snapshot.store_content = bool(snapshot.content) - + # Load language stats snapshot.language_stats = metadata.get("language_stats", {}) - + logger.info(f"Loaded snapshot from {snapshot_path}") return snapshot - - def diff(self, other: 'CodebaseSnapshot') -> Dict[str, Any]: + + def diff(self, other: "CodebaseSnapshot") -> dict[str, Any]: """ Compare this snapshot with another snapshot. 
- + Args: other: Snapshot to compare with - + Returns: Diff between the snapshots """ # Get sets of file paths self_files = set(self.files.keys()) other_files = set(other.files.keys()) - + # Find added, deleted, and common files added_files = other_files - self_files deleted_files = self_files - other_files common_files = self_files & other_files - + # Find modified files modified_files = [] for file_path in common_files: self_file = self.files[file_path] other_file = other.files[file_path] - + if self_file.hash != other_file.hash: modified_files.append(file_path) - + # Calculate content diff for modified files if content is available content_diff = {} if self.store_content and other.store_content: @@ -524,36 +536,47 @@ def diff(self, other: 'CodebaseSnapshot') -> Dict[str, Any]: try: # Use difflib to generate unified diff import difflib + diff = difflib.unified_diff( self.content[file_path].splitlines(keepends=True), other.content[file_path].splitlines(keepends=True), fromfile=f"a/{file_path}", - tofile=f"b/{file_path}" + tofile=f"b/{file_path}", ) content_diff[file_path] = "".join(diff) except Exception as e: - logger.warning(f"Error generating diff for {file_path}: {str(e)}") - + logger.warning(f"Error generating diff for {file_path}: {e!s}") + # Calculate statistics diff_stats = { "files_added": len(added_files), "files_deleted": len(deleted_files), "files_modified": len(modified_files), "files_unchanged": len(common_files) - len(modified_files), - "lines_added": sum(other.files[file_path].lines for file_path in added_files), - "lines_deleted": sum(self.files[file_path].lines for file_path in deleted_files), - "lines_modified": sum(other.files[file_path].lines - self.files[file_path].lines for file_path in modified_files if file_path in other.files and file_path in self.files), + "lines_added": sum( + other.files[file_path].lines for file_path in added_files + ), + "lines_deleted": sum( + self.files[file_path].lines for file_path in deleted_files + ), + "lines_modified": sum( + other.files[file_path].lines - self.files[file_path].lines + for file_path in modified_files + if file_path in other.files and file_path in self.files + ), } - + # Calculate language stats diff language_diff = {} - for language in set(self.language_stats.keys()) | set(other.language_stats.keys()): + for language in set(self.language_stats.keys()) | set( + other.language_stats.keys() + ): self_count = self.language_stats.get(language, 0) other_count = other.language_stats.get(language, 0) - + if self_count != other_count: language_diff[language] = other_count - self_count - + return { "added_files": list(added_files), "deleted_files": list(deleted_files), @@ -563,70 +586,77 @@ def diff(self, other: 'CodebaseSnapshot') -> Dict[str, Any]: "content_diff": content_diff, "from_snapshot": self.snapshot_id, "to_snapshot": other.snapshot_id, - "timestamp": datetime.now().isoformat() + "timestamp": datetime.now().isoformat(), } + class SnapshotManager: """ Manager for codebase snapshots. - + This class provides functionality to create, store, load, and compare codebase snapshots. """ - - def __init__(self, storage_dir: Optional[str] = None): + + def __init__(self, storage_dir: str | None = None): """ Initialize the snapshot manager. 
- + Args: storage_dir: Directory to store snapshots in """ - self.storage_dir = storage_dir or os.path.join(tempfile.gettempdir(), "codebase_snapshots") + self.storage_dir = storage_dir or os.path.join( + tempfile.gettempdir(), "codebase_snapshots" + ) os.makedirs(self.storage_dir, exist_ok=True) - + # Initialize data structures - self.snapshots: Dict[str, SnapshotMetadata] = {} + self.snapshots: dict[str, SnapshotMetadata] = {} self.load_index() - + def load_index(self): """Load the snapshot index.""" index_path = os.path.join(self.storage_dir, "index.json") - + if os.path.isfile(index_path): try: - with open(index_path, "r") as f: + with open(index_path) as f: data = json.load(f) - + self.snapshots = {} for snapshot_id, metadata in data.items(): self.snapshots[snapshot_id] = SnapshotMetadata(**metadata) except Exception as e: - logger.error(f"Error loading snapshot index: {str(e)}") + logger.exception(f"Error loading snapshot index: {e!s}") self.snapshots = {} - + def save_index(self): """Save the snapshot index.""" index_path = os.path.join(self.storage_dir, "index.json") - + try: with open(index_path, "w") as f: - json.dump({id: metadata.__dict__ for id, metadata in self.snapshots.items()}, f, indent=2) + json.dump( + {id: metadata.__dict__ for id, metadata in self.snapshots.items()}, + f, + indent=2, + ) except Exception as e: - logger.error(f"Error saving snapshot index: {str(e)}") - + logger.exception(f"Error saving snapshot index: {e!s}") + def create_snapshot( self, base_path: str, description: str = "", creator: str = "snapshot_manager", - include_patterns: List[str] = None, - exclude_patterns: List[str] = None, - snapshot_id: Optional[str] = None, - store_content: bool = False + include_patterns: list[str] | None = None, + exclude_patterns: list[str] | None = None, + snapshot_id: str | None = None, + store_content: bool = False, ) -> str: """ Create a new snapshot of a codebase. - + Args: base_path: Base path of the codebase description: Description of the snapshot @@ -635,7 +665,7 @@ def create_snapshot( exclude_patterns: Patterns of files to exclude snapshot_id: Optional ID for the snapshot store_content: Whether to store file content - + Returns: ID of the created snapshot """ @@ -647,134 +677,140 @@ def create_snapshot( include_patterns=include_patterns, exclude_patterns=exclude_patterns, snapshot_id=snapshot_id, - store_content=store_content + store_content=store_content, ) - + # Generate the snapshot snapshot.create() - + # Save the snapshot snapshot_path = os.path.join(self.storage_dir, f"{snapshot.snapshot_id}.json") snapshot.save(snapshot_path) - + # Update the index self.snapshots[snapshot.snapshot_id] = snapshot.get_metadata() self.save_index() - + return snapshot.snapshot_id - - def get_snapshot(self, snapshot_id: str) -> Optional[CodebaseSnapshot]: + + def get_snapshot(self, snapshot_id: str) -> CodebaseSnapshot | None: """ Get a snapshot by ID. - + Args: snapshot_id: ID of the snapshot - + Returns: Snapshot if found, None otherwise """ if snapshot_id not in self.snapshots: logger.error(f"Snapshot not found: {snapshot_id}") return None - + snapshot_path = os.path.join(self.storage_dir, f"{snapshot_id}.json") - + if not os.path.isfile(snapshot_path): logger.error(f"Snapshot file not found: {snapshot_path}") return None - + return CodebaseSnapshot.load(snapshot_path) - + def delete_snapshot(self, snapshot_id: str) -> bool: """ Delete a snapshot. 
- + Args: snapshot_id: ID of the snapshot - + Returns: True if the snapshot was deleted, False otherwise """ if snapshot_id not in self.snapshots: logger.error(f"Snapshot not found: {snapshot_id}") return False - + snapshot_path = os.path.join(self.storage_dir, f"{snapshot_id}.json") - + if os.path.isfile(snapshot_path): try: os.remove(snapshot_path) except Exception as e: - logger.error(f"Error deleting snapshot file: {str(e)}") + logger.exception(f"Error deleting snapshot file: {e!s}") return False - + # Update the index del self.snapshots[snapshot_id] self.save_index() - + return True - - def compare_snapshots(self, snapshot_id1: str, snapshot_id2: str) -> Optional[Dict[str, Any]]: + + def compare_snapshots( + self, snapshot_id1: str, snapshot_id2: str + ) -> dict[str, Any] | None: """ Compare two snapshots. - + Args: snapshot_id1: ID of the first snapshot snapshot_id2: ID of the second snapshot - + Returns: Diff between the snapshots if both exist, None otherwise """ snapshot1 = self.get_snapshot(snapshot_id1) snapshot2 = self.get_snapshot(snapshot_id2) - + if not snapshot1 or not snapshot2: return None - + return snapshot1.diff(snapshot2) - - def get_latest_snapshot(self, base_path: Optional[str] = None) -> Optional[str]: + + def get_latest_snapshot(self, base_path: str | None = None) -> str | None: """ Get the latest snapshot ID. - + Args: base_path: Optional base path to filter snapshots - + Returns: ID of the latest snapshot if any exist, None otherwise """ if not self.snapshots: return None - + filtered_snapshots = self.snapshots - + if base_path: filtered_snapshots = { - id: metadata for id, metadata in self.snapshots.items() + id: metadata + for id, metadata in self.snapshots.items() if metadata.base_path == base_path } - + if not filtered_snapshots: return None - + # Sort by timestamp and get the latest - latest_id = max(filtered_snapshots.keys(), key=lambda id: filtered_snapshots[id].timestamp) + latest_id = max( + filtered_snapshots.keys(), key=lambda id: filtered_snapshots[id].timestamp + ) return latest_id - - def list_snapshots(self, base_path: Optional[str] = None) -> List[SnapshotMetadata]: + + def list_snapshots(self, base_path: str | None = None) -> list[SnapshotMetadata]: """ List all snapshots. - + Args: base_path: Optional base path to filter snapshots - + Returns: List of snapshot metadata """ if base_path: return [ - metadata for metadata in self.snapshots.values() + metadata + for metadata in self.snapshots.values() if metadata.base_path == base_path ] else: - return list(self.snapshots.values()) \ No newline at end of file + return list(self.snapshots.values()) diff --git a/codegen-on-oss/codegen_on_oss/analyzers/unified_analyzer.py b/codegen-on-oss/codegen_on_oss/analyzers/unified_analyzer.py index bf204f042..55c8bac16 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/unified_analyzer.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/unified_analyzer.py @@ -4,40 +4,51 @@ This module consolidates various analyzer functionalities into a cohesive architecture, reducing code duplication and providing a standard interface for all types of codebase analysis. -It enables comprehensive analysis of codebases including code quality, dependencies, +It enables comprehensive analysis of codebases including code quality, dependencies, structural patterns, and issue detection. 
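Stepping back to the manager-level API just defined: a brief sketch (not part of the diff) of taking and comparing snapshots through `SnapshotManager`, assuming the same module path; only keys visible in the `diff()` return above are read from the comparison result.

```python
from codegen_on_oss.analyzers.snapshot.snapshot_manager import SnapshotManager

manager = SnapshotManager(storage_dir="./.snapshots")

# Snapshot the same tree before and after making edits, then compare.
before_id = manager.create_snapshot(
    ".", description="before refactor", include_patterns=["*.py"]
)
# ... edit some files here ...
after_id = manager.create_snapshot(
    ".", description="after refactor", include_patterns=["*.py"]
)

comparison = manager.compare_snapshots(before_id, after_id)
if comparison:
    print(len(comparison["added_files"]), "added,",
          len(comparison["deleted_files"]), "deleted")

print(manager.get_latest_snapshot())
for metadata in manager.list_snapshots():
    print(metadata.snapshot_id, metadata.file_count)
```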
""" -import os -import sys import json import logging +import sys import tempfile -import networkx as nx from datetime import datetime -from pathlib import Path -from typing import Dict, List, Set, Tuple, Any, Optional, Union, Type -from enum import Enum +from typing import Any + +import networkx as nx try: - from codegen.sdk.core.codebase import Codebase from codegen.configs.models.codebase import CodebaseConfig from codegen.configs.models.secrets import SecretsConfig - from codegen.sdk.codebase.config import ProjectConfig - from codegen.git.schemas.repo_config import RepoConfig from codegen.git.repo_operator.repo_operator import RepoOperator - from codegen.shared.enums.programming_language import ProgrammingLanguage - from codegen.sdk.codebase.codebase_analysis import get_codebase_summary, get_file_summary + from codegen.git.schemas.repo_config import RepoConfig + from codegen.sdk.codebase.codebase_analysis import ( + get_codebase_summary, + get_file_summary, + ) + from codegen.sdk.codebase.config import ProjectConfig + from codegen.sdk.core.class_definition import Class + from codegen.sdk.core.codebase import Codebase from codegen.sdk.core.file import SourceFile - from codegen.sdk.enums import EdgeType, SymbolType - from codegen.sdk.core.symbol import Symbol from codegen.sdk.core.function import Function - from codegen.sdk.core.class_definition import Class - + from codegen.sdk.core.symbol import Symbol + from codegen.sdk.enums import EdgeType, SymbolType + from codegen.shared.enums.programming_language import ProgrammingLanguage + + from codegen_on_oss.analyzers.issue_types import ( + AnalysisType, + Issue, + IssueCategory, + IssueSeverity, + ) + # Import from our own modules - from codegen_on_oss.context_codebase import CodebaseContext, get_node_classes, GLOBAL_FILE_IGNORE_LIST + from codegen_on_oss.context_codebase import ( + GLOBAL_FILE_IGNORE_LIST, + CodebaseContext, + get_node_classes, + ) from codegen_on_oss.current_code_codebase import get_selected_codebase - from codegen_on_oss.analyzers.issue_types import Issue, IssueSeverity, AnalysisType, IssueCategory except ImportError: print("Codegen SDK or required modules not found.") sys.exit(1) @@ -45,205 +56,227 @@ # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) + class AnalyzerRegistry: """ Registry of analyzer plugins. - + This singleton maintains a registry of all analyzer plugins and their associated analysis types. """ + _instance = None - + def __new__(cls): if cls._instance is None: - cls._instance = super(AnalyzerRegistry, cls).__new__(cls) + cls._instance = super().__new__(cls) cls._instance._analyzers = {} return cls._instance - - def register(self, analysis_type: AnalysisType, analyzer_class: Type['AnalyzerPlugin']): + + def register( + self, analysis_type: AnalysisType, analyzer_class: type["AnalyzerPlugin"] + ): """ Register an analyzer plugin for a specific analysis type. 
- + Args: analysis_type: Type of analysis the plugin handles analyzer_class: Class of the analyzer plugin """ self._analyzers[analysis_type] = analyzer_class - - def get_analyzer(self, analysis_type: AnalysisType) -> Optional[Type['AnalyzerPlugin']]: + + def get_analyzer( + self, analysis_type: AnalysisType + ) -> type["AnalyzerPlugin"] | None: """ Get the analyzer plugin for a specific analysis type. - + Args: analysis_type: Type of analysis to get plugin for - + Returns: The analyzer plugin class, or None if not found """ return self._analyzers.get(analysis_type) - - def list_analyzers(self) -> Dict[AnalysisType, Type['AnalyzerPlugin']]: + + def list_analyzers(self) -> dict[AnalysisType, type["AnalyzerPlugin"]]: """ Get all registered analyzers. - + Returns: Dictionary mapping analysis types to analyzer plugin classes """ return self._analyzers.copy() + class AnalyzerPlugin: """ Base class for analyzer plugins. - + Analyzer plugins implement specific analysis functionality for different types of codebase analysis. """ - - def __init__(self, analyzer: 'UnifiedCodeAnalyzer'): + + def __init__(self, analyzer: "UnifiedCodeAnalyzer"): """ Initialize the analyzer plugin. - + Args: analyzer: Parent analyzer that owns this plugin """ self.analyzer = analyzer self.issues = [] - - def analyze(self) -> Dict[str, Any]: + + def analyze(self) -> dict[str, Any]: """ Perform analysis using this plugin. - + Returns: Dictionary containing analysis results """ raise NotImplementedError("Analyzer plugins must implement analyze()") - + def add_issue(self, issue: Issue): """ Add an issue to the list. - + Args: issue: Issue to add """ self.analyzer.add_issue(issue) self.issues.append(issue) + class CodeQualityAnalyzerPlugin(AnalyzerPlugin): """ Plugin for code quality analysis. - + This plugin detects issues related to code quality, including dead code, complexity, style, and maintainability. """ - - def analyze(self) -> Dict[str, Any]: + + def analyze(self) -> dict[str, Any]: """ Perform code quality analysis. 
- + Returns: Dictionary containing code quality analysis results """ result = {} - + # Perform code quality checks result["dead_code"] = self._find_dead_code() result["complexity"] = self._analyze_code_complexity() result["style_issues"] = self._check_style_issues() result["maintainability"] = self._calculate_maintainability() - + return result - - def _find_dead_code(self) -> Dict[str, Any]: + + def _find_dead_code(self) -> dict[str, Any]: """Find unused code (dead code) in the codebase.""" codebase = self.analyzer.base_codebase - + dead_code = { "unused_functions": [], "unused_classes": [], "unused_variables": [], - "unused_imports": [] + "unused_imports": [], } - + # Find unused functions - if hasattr(codebase, 'functions'): + if hasattr(codebase, "functions"): for func in codebase.functions: # Skip if function should be excluded if self.analyzer.should_skip_symbol(func): continue - + # Skip decorated functions (as they might be used indirectly) - if hasattr(func, 'decorators') and func.decorators: + if hasattr(func, "decorators") and func.decorators: continue - + # Check if function has no call sites or usages - has_call_sites = hasattr(func, 'call_sites') and len(func.call_sites) > 0 - has_usages = hasattr(func, 'usages') and len(func.usages) > 0 - + has_call_sites = ( + hasattr(func, "call_sites") and len(func.call_sites) > 0 + ) + has_usages = hasattr(func, "usages") and len(func.usages) > 0 + if not has_call_sites and not has_usages: # Get file path and name safely - file_path = func.file.file_path if hasattr(func, 'file') and hasattr(func.file, 'file_path') else "unknown" - func_name = func.name if hasattr(func, 'name') else str(func) - + file_path = ( + func.file.file_path + if hasattr(func, "file") and hasattr(func.file, "file_path") + else "unknown" + ) + func_name = func.name if hasattr(func, "name") else str(func) + # Skip main entry points if func_name in ["main", "__main__"]: continue - + # Add to dead code list dead_code["unused_functions"].append({ "name": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None + "line": func.line if hasattr(func, "line") else None, }) - + # Add issue - self.add_issue(Issue( - file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"Unused function: {func_name}", - severity=IssueSeverity.WARNING, - category=IssueCategory.DEAD_CODE, - symbol=func_name, - suggestion="Consider removing this unused function or documenting why it's needed" - )) - + self.add_issue( + Issue( + file=file_path, + line=func.line if hasattr(func, "line") else None, + message=f"Unused function: {func_name}", + severity=IssueSeverity.WARNING, + category=IssueCategory.DEAD_CODE, + symbol=func_name, + suggestion="Consider removing this unused function or documenting why it's needed", + ) + ) + # Find unused classes - if hasattr(codebase, 'classes'): + if hasattr(codebase, "classes"): for cls in codebase.classes: # Skip if class should be excluded if self.analyzer.should_skip_symbol(cls): continue - + # Check if class has no usages - has_usages = hasattr(cls, 'usages') and len(cls.usages) > 0 - + has_usages = hasattr(cls, "usages") and len(cls.usages) > 0 + if not has_usages: # Get file path and name safely - file_path = cls.file.file_path if hasattr(cls, 'file') and hasattr(cls.file, 'file_path') else "unknown" - cls_name = cls.name if hasattr(cls, 'name') else str(cls) - + file_path = ( + cls.file.file_path + if hasattr(cls, "file") and hasattr(cls.file, "file_path") + else "unknown" + ) + cls_name = cls.name if 
hasattr(cls, "name") else str(cls) + # Add to dead code list dead_code["unused_classes"].append({ "name": cls_name, "file": file_path, - "line": cls.line if hasattr(cls, 'line') else None + "line": cls.line if hasattr(cls, "line") else None, }) - + # Add issue - self.add_issue(Issue( - file=file_path, - line=cls.line if hasattr(cls, 'line') else None, - message=f"Unused class: {cls_name}", - severity=IssueSeverity.WARNING, - category=IssueCategory.DEAD_CODE, - symbol=cls_name, - suggestion="Consider removing this unused class or documenting why it's needed" - )) - + self.add_issue( + Issue( + file=file_path, + line=cls.line if hasattr(cls, "line") else None, + message=f"Unused class: {cls_name}", + severity=IssueSeverity.WARNING, + category=IssueCategory.DEAD_CODE, + symbol=cls_name, + suggestion="Consider removing this unused class or documenting why it's needed", + ) + ) + # Summarize findings dead_code["summary"] = { "unused_functions_count": len(dead_code["unused_functions"]), @@ -251,19 +284,19 @@ def _find_dead_code(self) -> Dict[str, Any]: "unused_variables_count": len(dead_code["unused_variables"]), "unused_imports_count": len(dead_code["unused_imports"]), "total_dead_code_count": ( - len(dead_code["unused_functions"]) + - len(dead_code["unused_classes"]) + - len(dead_code["unused_variables"]) + - len(dead_code["unused_imports"]) - ) + len(dead_code["unused_functions"]) + + len(dead_code["unused_classes"]) + + len(dead_code["unused_variables"]) + + len(dead_code["unused_imports"]) + ), } - + return dead_code - - def _analyze_code_complexity(self) -> Dict[str, Any]: + + def _analyze_code_complexity(self) -> dict[str, Any]: """Analyze code complexity.""" codebase = self.analyzer.base_codebase - + complexity_result = { "function_complexity": [], "high_complexity_functions": [], @@ -272,43 +305,47 @@ def _analyze_code_complexity(self) -> Dict[str, Any]: "low": 0, "medium": 0, "high": 0, - "very_high": 0 - } + "very_high": 0, + }, } - + # Process all functions to calculate complexity total_complexity = 0 function_count = 0 - - if hasattr(codebase, 'functions'): + + if hasattr(codebase, "functions"): for func in codebase.functions: # Skip if function should be excluded if self.analyzer.should_skip_symbol(func): continue - + # Skip if no code block - if not hasattr(func, 'code_block'): + if not hasattr(func, "code_block"): continue - + # Calculate cyclomatic complexity complexity = self._calculate_cyclomatic_complexity(func) - + # Get file path and name safely - file_path = func.file.file_path if hasattr(func, 'file') and hasattr(func.file, 'file_path') else "unknown" - func_name = func.name if hasattr(func, 'name') else str(func) - + file_path = ( + func.file.file_path + if hasattr(func, "file") and hasattr(func.file, "file_path") + else "unknown" + ) + func_name = func.name if hasattr(func, "name") else str(func) + # Add to complexity list complexity_result["function_complexity"].append({ "name": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None, - "complexity": complexity + "line": func.line if hasattr(func, "line") else None, + "complexity": complexity, }) - + # Track total complexity total_complexity += complexity function_count += 1 - + # Categorize complexity if complexity <= 5: complexity_result["complexity_distribution"]["low"] += 1 @@ -318,84 +355,96 @@ def _analyze_code_complexity(self) -> Dict[str, Any]: complexity_result["complexity_distribution"]["high"] += 1 else: complexity_result["complexity_distribution"]["very_high"] += 1 - + # Flag 
high complexity functions if complexity > 10: complexity_result["high_complexity_functions"].append({ "name": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None, - "complexity": complexity + "line": func.line if hasattr(func, "line") else None, + "complexity": complexity, }) - + # Add issue - severity = IssueSeverity.WARNING if complexity <= 15 else IssueSeverity.ERROR - self.add_issue(Issue( - file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"High cyclomatic complexity: {complexity}", - severity=severity, - category=IssueCategory.COMPLEXITY, - symbol=func_name, - suggestion="Consider refactoring this function to reduce complexity" - )) - + severity = ( + IssueSeverity.WARNING + if complexity <= 15 + else IssueSeverity.ERROR + ) + self.add_issue( + Issue( + file=file_path, + line=func.line if hasattr(func, "line") else None, + message=f"High cyclomatic complexity: {complexity}", + severity=severity, + category=IssueCategory.COMPLEXITY, + symbol=func_name, + suggestion="Consider refactoring this function to reduce complexity", + ) + ) + # Calculate average complexity - complexity_result["average_complexity"] = total_complexity / function_count if function_count > 0 else 0.0 - + complexity_result["average_complexity"] = ( + total_complexity / function_count if function_count > 0 else 0.0 + ) + # Sort high complexity functions by complexity - complexity_result["high_complexity_functions"].sort(key=lambda x: x["complexity"], reverse=True) - + complexity_result["high_complexity_functions"].sort( + key=lambda x: x["complexity"], reverse=True + ) + return complexity_result - + def _calculate_cyclomatic_complexity(self, function) -> int: """Calculate cyclomatic complexity for a function.""" complexity = 1 # Base complexity - + def analyze_statement(statement): nonlocal complexity - + # Check for if statements (including elif branches) - if hasattr(statement, 'if_clause'): + if hasattr(statement, "if_clause"): complexity += 1 - + # Count elif branches - if hasattr(statement, 'elif_statements'): + if hasattr(statement, "elif_statements"): complexity += len(statement.elif_statements) - + # Count else branches - if hasattr(statement, 'else_clause') and statement.else_clause: + if hasattr(statement, "else_clause") and statement.else_clause: complexity += 1 - + # Count for loops - if hasattr(statement, 'is_for_loop') and statement.is_for_loop: + if hasattr(statement, "is_for_loop") and statement.is_for_loop: complexity += 1 - + # Count while loops - if hasattr(statement, 'is_while_loop') and statement.is_while_loop: + if hasattr(statement, "is_while_loop") and statement.is_while_loop: complexity += 1 - + # Count try/except blocks (each except adds a path) - if hasattr(statement, 'is_try_block') and statement.is_try_block: - if hasattr(statement, 'except_clauses'): + if hasattr(statement, "is_try_block") and statement.is_try_block: + if hasattr(statement, "except_clauses"): complexity += len(statement.except_clauses) - + # Recursively process nested statements - if hasattr(statement, 'statements'): + if hasattr(statement, "statements"): for nested_stmt in statement.statements: analyze_statement(nested_stmt) - + # Process all statements in the function's code block - if hasattr(function, 'code_block') and hasattr(function.code_block, 'statements'): + if hasattr(function, "code_block") and hasattr( + function.code_block, "statements" + ): for statement in function.code_block.statements: analyze_statement(statement) - + return complexity - - 
def _check_style_issues(self) -> Dict[str, Any]: + + def _check_style_issues(self) -> dict[str, Any]: """Check for code style issues.""" codebase = self.analyzer.base_codebase - + style_result = { "long_functions": [], "long_lines": [], @@ -403,201 +452,230 @@ def _check_style_issues(self) -> Dict[str, Any]: "summary": { "long_functions_count": 0, "long_lines_count": 0, - "inconsistent_naming_count": 0 - } + "inconsistent_naming_count": 0, + }, } - + # Check for long functions (too many lines) - if hasattr(codebase, 'functions'): + if hasattr(codebase, "functions"): for func in codebase.functions: # Skip if function should be excluded if self.analyzer.should_skip_symbol(func): continue - + # Get function code - if hasattr(func, 'source'): + if hasattr(func, "source"): code = func.source - lines = code.split('\n') - + lines = code.split("\n") + # Check function length if len(lines) > 50: # Threshold for "too long" # Get file path and name safely - file_path = func.file.file_path if hasattr(func, 'file') and hasattr(func.file, 'file_path') else "unknown" - func_name = func.name if hasattr(func, 'name') else str(func) - + file_path = ( + func.file.file_path + if hasattr(func, "file") and hasattr(func.file, "file_path") + else "unknown" + ) + func_name = func.name if hasattr(func, "name") else str(func) + # Add to long functions list style_result["long_functions"].append({ "name": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None, - "line_count": len(lines) + "line": func.line if hasattr(func, "line") else None, + "line_count": len(lines), }) - + # Add issue - self.add_issue(Issue( - file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"Long function: {len(lines)} lines", - severity=IssueSeverity.INFO, - category=IssueCategory.STYLE_ISSUE, - symbol=func_name, - suggestion="Consider breaking this function into smaller, more focused functions" - )) - + self.add_issue( + Issue( + file=file_path, + line=func.line if hasattr(func, "line") else None, + message=f"Long function: {len(lines)} lines", + severity=IssueSeverity.INFO, + category=IssueCategory.STYLE_ISSUE, + symbol=func_name, + suggestion="Consider breaking this function into smaller, more focused functions", + ) + ) + # Update summary - style_result["summary"]["long_functions_count"] = len(style_result["long_functions"]) + style_result["summary"]["long_functions_count"] = len( + style_result["long_functions"] + ) style_result["summary"]["long_lines_count"] = len(style_result["long_lines"]) - style_result["summary"]["inconsistent_naming_count"] = len(style_result["inconsistent_naming"]) - + style_result["summary"]["inconsistent_naming_count"] = len( + style_result["inconsistent_naming"] + ) + return style_result - - def _calculate_maintainability(self) -> Dict[str, Any]: + + def _calculate_maintainability(self) -> dict[str, Any]: """Calculate maintainability metrics.""" import math + codebase = self.analyzer.base_codebase - + maintainability_result = { "function_maintainability": [], "low_maintainability_functions": [], "average_maintainability": 0.0, - "maintainability_distribution": { - "high": 0, - "medium": 0, - "low": 0 - } + "maintainability_distribution": {"high": 0, "medium": 0, "low": 0}, } - + # Process all functions to calculate maintainability total_maintainability = 0 function_count = 0 - - if hasattr(codebase, 'functions'): + + if hasattr(codebase, "functions"): for func in codebase.functions: # Skip if function should be excluded if 
self.analyzer.should_skip_symbol(func): continue - + # Skip if no code block - if not hasattr(func, 'code_block'): + if not hasattr(func, "code_block"): continue - + # Calculate metrics complexity = self._calculate_cyclomatic_complexity(func) - + # Calculate Halstead volume (approximation) operators = 0 operands = 0 - - if hasattr(func, 'source'): + + if hasattr(func, "source"): code = func.source # Simple approximation of operators and operands - operators = len([c for c in code if c in '+-*/=<>!&|^~%']) + operators = len([c for c in code if c in "+-*/=<>!&|^~%"]) # Counting words as potential operands import re - operands = len(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', code)) - - halstead_volume = operators * operands * math.log2(operators + operands) if operators + operands > 0 else 0 - + + operands = len(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", code)) + + halstead_volume = ( + operators * operands * math.log2(operators + operands) + if operators + operands > 0 + else 0 + ) + # Count lines of code - loc = len(func.source.split('\n')) if hasattr(func, 'source') else 0 - + loc = len(func.source.split("\n")) if hasattr(func, "source") else 0 + # Calculate maintainability index # Formula: 171 - 5.2 * ln(Halstead Volume) - 0.23 * (Cyclomatic Complexity) - 16.2 * ln(LOC) - halstead_term = 5.2 * math.log(max(1, halstead_volume)) if halstead_volume > 0 else 0 + halstead_term = ( + 5.2 * math.log(max(1, halstead_volume)) + if halstead_volume > 0 + else 0 + ) complexity_term = 0.23 * complexity loc_term = 16.2 * math.log(max(1, loc)) if loc > 0 else 0 - + maintainability = 171 - halstead_term - complexity_term - loc_term - + # Normalize to 0-100 scale maintainability = max(0, min(100, maintainability * 100 / 171)) - + # Get file path and name safely - file_path = func.file.file_path if hasattr(func, 'file') and hasattr(func.file, 'file_path') else "unknown" - func_name = func.name if hasattr(func, 'name') else str(func) - + file_path = ( + func.file.file_path + if hasattr(func, "file") and hasattr(func.file, "file_path") + else "unknown" + ) + func_name = func.name if hasattr(func, "name") else str(func) + # Add to maintainability list maintainability_result["function_maintainability"].append({ "name": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None, + "line": func.line if hasattr(func, "line") else None, "maintainability": maintainability, "complexity": complexity, "halstead_volume": halstead_volume, - "loc": loc + "loc": loc, }) - + # Track total maintainability total_maintainability += maintainability function_count += 1 - + # Categorize maintainability if maintainability >= 70: maintainability_result["maintainability_distribution"]["high"] += 1 elif maintainability >= 50: - maintainability_result["maintainability_distribution"]["medium"] += 1 + maintainability_result["maintainability_distribution"][ + "medium" + ] += 1 else: maintainability_result["maintainability_distribution"]["low"] += 1 - + # Flag low maintainability functions maintainability_result["low_maintainability_functions"].append({ "name": func_name, "file": file_path, - "line": func.line if hasattr(func, 'line') else None, + "line": func.line if hasattr(func, "line") else None, "maintainability": maintainability, "complexity": complexity, "halstead_volume": halstead_volume, - "loc": loc + "loc": loc, }) - + # Add issue - self.add_issue(Issue( - file=file_path, - line=func.line if hasattr(func, 'line') else None, - message=f"Low maintainability index: {maintainability:.1f}", - 
severity=IssueSeverity.WARNING, - category=IssueCategory.COMPLEXITY, - symbol=func_name, - suggestion="Consider refactoring this function to improve maintainability" - )) - + self.add_issue( + Issue( + file=file_path, + line=func.line if hasattr(func, "line") else None, + message=f"Low maintainability index: {maintainability:.1f}", + severity=IssueSeverity.WARNING, + category=IssueCategory.COMPLEXITY, + symbol=func_name, + suggestion="Consider refactoring this function to improve maintainability", + ) + ) + # Calculate average maintainability - maintainability_result["average_maintainability"] = total_maintainability / function_count if function_count > 0 else 0.0 - + maintainability_result["average_maintainability"] = ( + total_maintainability / function_count if function_count > 0 else 0.0 + ) + # Sort low maintainability functions - maintainability_result["low_maintainability_functions"].sort(key=lambda x: x["maintainability"]) - + maintainability_result["low_maintainability_functions"].sort( + key=lambda x: x["maintainability"] + ) + return maintainability_result + class DependencyAnalyzerPlugin(AnalyzerPlugin): """ Plugin for dependency analysis. - + This plugin detects issues related to dependencies, including import relationships, circular dependencies, and module coupling. """ - - def analyze(self) -> Dict[str, Any]: + + def analyze(self) -> dict[str, Any]: """ Perform dependency analysis. - + Returns: Dictionary containing dependency analysis results """ result = {} - + # Perform dependency checks result["import_dependencies"] = self._analyze_import_dependencies() result["circular_dependencies"] = self._find_circular_dependencies() result["module_coupling"] = self._analyze_module_coupling() result["external_dependencies"] = self._analyze_external_dependencies() - + return result - - def _analyze_import_dependencies(self) -> Dict[str, Any]: + + def _analyze_import_dependencies(self) -> dict[str, Any]: """Analyze import dependencies in the codebase.""" codebase = self.analyzer.base_codebase - + import_deps = { "module_dependencies": [], "file_dependencies": [], @@ -607,419 +685,497 @@ def _analyze_import_dependencies(self) -> Dict[str, Any]: "total_imports": 0, "internal_imports": 0, "external_imports": 0, - "relative_imports": 0 - } + "relative_imports": 0, + }, } - + # Create a directed graph for module dependencies G = nx.DiGraph() - + # Track import counts module_imports = {} # modules importing others module_imported = {} # modules being imported - + # Process all files to extract import information for file in codebase.files: # Skip if no imports - if not hasattr(file, 'imports') or not file.imports: + if not hasattr(file, "imports") or not file.imports: continue - + # Skip if file should be excluded if self.analyzer.should_skip_file(file): continue - + # Get file path - file_path = file.file_path if hasattr(file, 'file_path') else str(file.path) if hasattr(file, 'path') else str(file) - + file_path = ( + file.file_path + if hasattr(file, "file_path") + else str(file.path) + if hasattr(file, "path") + else str(file) + ) + # Extract module name from file path - file_parts = file_path.split('/') - module_name = '/'.join(file_parts[:-1]) if len(file_parts) > 1 else file_parts[0] - + file_parts = file_path.split("/") + module_name = ( + "/".join(file_parts[:-1]) if len(file_parts) > 1 else file_parts[0] + ) + # Initialize import counts if module_name not in module_imports: module_imports[module_name] = 0 - + # Process imports for imp in file.imports: 
import_deps["dependency_stats"]["total_imports"] += 1 - + # Get imported module information imported_file = None imported_module = "unknown" is_external = False - - if hasattr(imp, 'resolved_file'): + + if hasattr(imp, "resolved_file"): imported_file = imp.resolved_file - elif hasattr(imp, 'resolved_symbol') and hasattr(imp.resolved_symbol, 'file'): + elif hasattr(imp, "resolved_symbol") and hasattr( + imp.resolved_symbol, "file" + ): imported_file = imp.resolved_symbol.file - + if imported_file: # Get imported file path - imported_path = imported_file.file_path if hasattr(imported_file, 'file_path') else str(imported_file.path) if hasattr(imported_file, 'path') else str(imported_file) - + imported_path = ( + imported_file.file_path + if hasattr(imported_file, "file_path") + else str(imported_file.path) + if hasattr(imported_file, "path") + else str(imported_file) + ) + # Extract imported module name - imported_parts = imported_path.split('/') - imported_module = '/'.join(imported_parts[:-1]) if len(imported_parts) > 1 else imported_parts[0] - + imported_parts = imported_path.split("/") + imported_module = ( + "/".join(imported_parts[:-1]) + if len(imported_parts) > 1 + else imported_parts[0] + ) + # Check if external - is_external = hasattr(imported_file, 'is_external') and imported_file.is_external + is_external = ( + hasattr(imported_file, "is_external") + and imported_file.is_external + ) else: # If we couldn't resolve the import, use the import name - imported_module = imp.name if hasattr(imp, 'name') else "unknown" - + imported_module = imp.name if hasattr(imp, "name") else "unknown" + # Assume external if we couldn't resolve is_external = True - + # Update import type counts if is_external: import_deps["dependency_stats"]["external_imports"] += 1 else: import_deps["dependency_stats"]["internal_imports"] += 1 - + # Check if relative import - if hasattr(imp, 'is_relative') and imp.is_relative: + if hasattr(imp, "is_relative") and imp.is_relative: import_deps["dependency_stats"]["relative_imports"] += 1 - + # Update module import counts module_imports[module_name] += 1 - + if imported_module not in module_imported: module_imported[imported_module] = 0 module_imported[imported_module] += 1 - + # Add to dependency graph if module_name != imported_module: # Skip self-imports G.add_edge(module_name, imported_module) - + # Add to file dependencies list import_deps["file_dependencies"].append({ "source_file": file_path, "target_file": imported_path if imported_file else "unknown", - "import_name": imp.name if hasattr(imp, 'name') else "unknown", - "is_external": is_external + "import_name": imp.name if hasattr(imp, "name") else "unknown", + "is_external": is_external, }) - + # Extract module dependencies from graph for source, target in G.edges(): import_deps["module_dependencies"].append({ "source_module": source, - "target_module": target + "target_module": target, }) - + # Find most imported modules most_imported = sorted( - [(module, count) for module, count in module_imported.items()], + module_imported.items(), key=lambda x: x[1], - reverse=True + reverse=True, ) - + for module, count in most_imported[:10]: # Top 10 import_deps["most_imported_modules"].append({ "module": module, - "import_count": count + "import_count": count, }) - + # Find modules that import the most most_importing = sorted( - [(module, count) for module, count in module_imports.items()], + module_imports.items(), key=lambda x: x[1], - reverse=True + reverse=True, ) - + for module, count in most_importing[:10]: # Top 
10 import_deps["most_importing_modules"].append({ "module": module, - "import_count": count + "import_count": count, }) - + return import_deps - - def _find_circular_dependencies(self) -> Dict[str, Any]: + + def _find_circular_dependencies(self) -> dict[str, Any]: """Find circular dependencies in the codebase.""" codebase = self.analyzer.base_codebase - + circular_deps = { "circular_imports": [], "circular_dependencies_count": 0, - "affected_modules": set() + "affected_modules": set(), } - + # Create dependency graph if not already available G = nx.DiGraph() - + # Process all files to build dependency graph for file in codebase.files: # Skip if no imports - if not hasattr(file, 'imports') or not file.imports: + if not hasattr(file, "imports") or not file.imports: continue - + # Skip if file should be excluded if self.analyzer.should_skip_file(file): continue - + # Get file path - file_path = file.file_path if hasattr(file, 'file_path') else str(file.path) if hasattr(file, 'path') else str(file) - + file_path = ( + file.file_path + if hasattr(file, "file_path") + else str(file.path) + if hasattr(file, "path") + else str(file) + ) + # Process imports for imp in file.imports: # Get imported file imported_file = None - - if hasattr(imp, 'resolved_file'): + + if hasattr(imp, "resolved_file"): imported_file = imp.resolved_file - elif hasattr(imp, 'resolved_symbol') and hasattr(imp.resolved_symbol, 'file'): + elif hasattr(imp, "resolved_symbol") and hasattr( + imp.resolved_symbol, "file" + ): imported_file = imp.resolved_symbol.file - + if imported_file: # Get imported file path - imported_path = imported_file.file_path if hasattr(imported_file, 'file_path') else str(imported_file.path) if hasattr(imported_file, 'path') else str(imported_file) - + imported_path = ( + imported_file.file_path + if hasattr(imported_file, "file_path") + else str(imported_file.path) + if hasattr(imported_file, "path") + else str(imported_file) + ) + # Add edge to graph G.add_edge(file_path, imported_path) - + # Find cycles in the graph try: cycles = list(nx.simple_cycles(G)) - + for cycle in cycles: circular_deps["circular_imports"].append({ "files": cycle, - "length": len(cycle) + "length": len(cycle), }) - + # Add affected modules to set for file_path in cycle: - module_path = '/'.join(file_path.split('/')[:-1]) + module_path = "/".join(file_path.split("/")[:-1]) circular_deps["affected_modules"].add(module_path) - + # Add issue if len(cycle) >= 2: - self.add_issue(Issue( - file=cycle[0], - line=None, - message=f"Circular dependency detected between {len(cycle)} files", - severity=IssueSeverity.ERROR, - category=IssueCategory.DEPENDENCY_CYCLE, - suggestion="Break the circular dependency by refactoring the code" - )) - + self.add_issue( + Issue( + file=cycle[0], + line=None, + message=f"Circular dependency detected between {len(cycle)} files", + severity=IssueSeverity.ERROR, + category=IssueCategory.DEPENDENCY_CYCLE, + suggestion="Break the circular dependency by refactoring the code", + ) + ) + except Exception as e: - logger.error(f"Error finding circular dependencies: {e}") - + logger.exception(f"Error finding circular dependencies: {e}") + # Update cycle count - circular_deps["circular_dependencies_count"] = len(circular_deps["circular_imports"]) + circular_deps["circular_dependencies_count"] = len( + circular_deps["circular_imports"] + ) circular_deps["affected_modules"] = list(circular_deps["affected_modules"]) - + return circular_deps - - def _analyze_module_coupling(self) -> Dict[str, Any]: + + def 
_analyze_module_coupling(self) -> dict[str, Any]: """Analyze module coupling in the codebase.""" codebase = self.analyzer.base_codebase - + coupling = { "high_coupling_modules": [], "low_coupling_modules": [], "coupling_metrics": {}, - "average_coupling": 0.0 + "average_coupling": 0.0, } - + # Create module dependency graphs modules = {} # Module name -> set of imported modules module_files = {} # Module name -> list of files - + # Process all files to extract module information for file in codebase.files: # Skip if file should be excluded if self.analyzer.should_skip_file(file): continue - + # Get file path - file_path = file.file_path if hasattr(file, 'file_path') else str(file.path) if hasattr(file, 'path') else str(file) - + file_path = ( + file.file_path + if hasattr(file, "file_path") + else str(file.path) + if hasattr(file, "path") + else str(file) + ) + # Extract module name from file path - module_parts = file_path.split('/') - module_name = '/'.join(module_parts[:-1]) if len(module_parts) > 1 else module_parts[0] - + module_parts = file_path.split("/") + module_name = ( + "/".join(module_parts[:-1]) + if len(module_parts) > 1 + else module_parts[0] + ) + # Initialize module structures if module_name not in modules: modules[module_name] = set() module_files[module_name] = [] - + module_files[module_name].append(file_path) - + # Skip if no imports - if not hasattr(file, 'imports') or not file.imports: + if not hasattr(file, "imports") or not file.imports: continue - + # Process imports for imp in file.imports: # Get imported file imported_file = None - - if hasattr(imp, 'resolved_file'): + + if hasattr(imp, "resolved_file"): imported_file = imp.resolved_file - elif hasattr(imp, 'resolved_symbol') and hasattr(imp.resolved_symbol, 'file'): + elif hasattr(imp, "resolved_symbol") and hasattr( + imp.resolved_symbol, "file" + ): imported_file = imp.resolved_symbol.file - + if imported_file: # Get imported file path - imported_path = imported_file.file_path if hasattr(imported_file, 'file_path') else str(imported_file.path) if hasattr(imported_file, 'path') else str(imported_file) - + imported_path = ( + imported_file.file_path + if hasattr(imported_file, "file_path") + else str(imported_file.path) + if hasattr(imported_file, "path") + else str(imported_file) + ) + # Extract imported module name - imported_parts = imported_path.split('/') - imported_module = '/'.join(imported_parts[:-1]) if len(imported_parts) > 1 else imported_parts[0] - + imported_parts = imported_path.split("/") + imported_module = ( + "/".join(imported_parts[:-1]) + if len(imported_parts) > 1 + else imported_parts[0] + ) + # Skip self-imports if imported_module != module_name: modules[module_name].add(imported_module) - + # Calculate coupling metrics for each module total_coupling = 0.0 module_count = 0 - + for module_name, imported_modules in modules.items(): # Calculate metrics file_count = len(module_files[module_name]) import_count = len(imported_modules) - + # Calculate coupling ratio (imports per file) coupling_ratio = import_count / file_count if file_count > 0 else 0 - + # Add to metrics coupling["coupling_metrics"][module_name] = { "files": file_count, "imported_modules": list(imported_modules), "import_count": import_count, - "coupling_ratio": coupling_ratio + "coupling_ratio": coupling_ratio, } - + # Track total for average total_coupling += coupling_ratio module_count += 1 - + # Categorize coupling if coupling_ratio > 3: # Threshold for "high coupling" coupling["high_coupling_modules"].append({ "module": 
module_name, "coupling_ratio": coupling_ratio, "import_count": import_count, - "file_count": file_count + "file_count": file_count, }) - + # Add issue - self.add_issue(Issue( - file=module_files[module_name][0] if module_files[module_name] else module_name, - line=None, - message=f"High module coupling: {coupling_ratio:.2f} imports per file", - severity=IssueSeverity.WARNING, - category=IssueCategory.DEPENDENCY_CYCLE, - suggestion="Consider refactoring to reduce coupling between modules" - )) - elif coupling_ratio < 0.5 and file_count > 1: # Threshold for "low coupling" + self.add_issue( + Issue( + file=module_files[module_name][0] + if module_files[module_name] + else module_name, + line=None, + message=f"High module coupling: {coupling_ratio:.2f} imports per file", + severity=IssueSeverity.WARNING, + category=IssueCategory.DEPENDENCY_CYCLE, + suggestion="Consider refactoring to reduce coupling between modules", + ) + ) + elif ( + coupling_ratio < 0.5 and file_count > 1 + ): # Threshold for "low coupling" coupling["low_coupling_modules"].append({ "module": module_name, "coupling_ratio": coupling_ratio, "import_count": import_count, - "file_count": file_count + "file_count": file_count, }) - + # Calculate average coupling - coupling["average_coupling"] = total_coupling / module_count if module_count > 0 else 0.0 - + coupling["average_coupling"] = ( + total_coupling / module_count if module_count > 0 else 0.0 + ) + # Sort coupling lists - coupling["high_coupling_modules"].sort(key=lambda x: x["coupling_ratio"], reverse=True) + coupling["high_coupling_modules"].sort( + key=lambda x: x["coupling_ratio"], reverse=True + ) coupling["low_coupling_modules"].sort(key=lambda x: x["coupling_ratio"]) - + return coupling - - def _analyze_external_dependencies(self) -> Dict[str, Any]: + + def _analyze_external_dependencies(self) -> dict[str, Any]: """Analyze external dependencies in the codebase.""" codebase = self.analyzer.base_codebase - + external_deps = { "external_modules": [], "external_module_usage": {}, - "most_used_external_modules": [] + "most_used_external_modules": [], } - + # Track external module usage external_usage = {} # Module name -> usage count - + # Process all imports to find external dependencies for file in codebase.files: # Skip if no imports - if not hasattr(file, 'imports') or not file.imports: + if not hasattr(file, "imports") or not file.imports: continue - + # Skip if file should be excluded if self.analyzer.should_skip_file(file): continue - + # Process imports for imp in file.imports: # Check if external import is_external = False external_name = None - - if hasattr(imp, 'module_name'): + + if hasattr(imp, "module_name"): external_name = imp.module_name - + # Check if this is an external module - if hasattr(imp, 'is_external'): + if hasattr(imp, "is_external"): is_external = imp.is_external - elif external_name and '.' not in external_name and '/' not in external_name: + elif ( + external_name + and "." 
not in external_name + and "/" not in external_name + ): # Simple heuristic: single-word module names without dots or slashes # are likely external modules is_external = True - + if is_external and external_name: # Add to external modules list if not already there if external_name not in external_usage: external_usage[external_name] = 0 external_deps["external_modules"].append(external_name) - + external_usage[external_name] += 1 - + # Add usage counts for module, count in external_usage.items(): external_deps["external_module_usage"][module] = count - + # Find most used external modules most_used = sorted( - [(module, count) for module, count in external_usage.items()], + external_usage.items(), key=lambda x: x[1], - reverse=True + reverse=True, ) - + for module, count in most_used[:10]: # Top 10 external_deps["most_used_external_modules"].append({ "module": module, - "usage_count": count + "usage_count": count, }) - + return external_deps + class UnifiedCodeAnalyzer: """ Unified Codebase Analyzer. - + This class provides a comprehensive framework for analyzing codebases, with support for pluggable analyzers for different types of analysis. """ - + def __init__( self, - repo_url: Optional[str] = None, - repo_path: Optional[str] = None, + repo_url: str | None = None, + repo_path: str | None = None, base_branch: str = "main", - pr_number: Optional[int] = None, - language: Optional[str] = None, - file_ignore_list: Optional[List[str]] = None, - config: Optional[Dict[str, Any]] = None + pr_number: int | None = None, + language: str | None = None, + file_ignore_list: list[str] | None = None, + config: dict[str, Any] | None = None, ): """ Initialize the unified analyzer. - + Args: repo_url: URL of the repository to analyze repo_path: Local path to the repository to analyze @@ -1034,101 +1190,101 @@ def __init__( self.base_branch = base_branch self.pr_number = pr_number self.language = language - + # Use custom ignore list or default global list self.file_ignore_list = file_ignore_list or GLOBAL_FILE_IGNORE_LIST - + # Configuration options self.config = config or {} - + # Codebase and context objects self.base_codebase = None self.pr_codebase = None self.base_context = None self.pr_context = None - + # Analysis results self.issues = [] self.results = {} - + # PR comparison data self.pr_diff = None self.commit_shas = None self.modified_symbols = None self.pr_branch = None - + # Initialize codebase(s) based on provided parameters if repo_url: self._init_from_url(repo_url, language) elif repo_path: self._init_from_path(repo_path, language) - + # If PR number is provided, initialize PR-specific data if self.pr_number is not None and self.base_codebase is not None: self._init_pr_data(self.pr_number) - + # Initialize contexts self._init_contexts() - + # Initialize analyzers self._init_analyzers() - - def _init_from_url(self, repo_url: str, language: Optional[str] = None): + + def _init_from_url(self, repo_url: str, language: str | None = None): """ Initialize codebase from a repository URL. 
- + Args: repo_url: URL of the repository language: Programming language of the codebase """ try: # Extract repository information - if repo_url.endswith('.git'): + if repo_url.endswith(".git"): repo_url = repo_url[:-4] - - parts = repo_url.rstrip('/').split('/') + + parts = repo_url.rstrip("/").split("/") repo_name = parts[-1] owner = parts[-2] repo_full_name = f"{owner}/{repo_name}" - + # Create temporary directory for cloning tmp_dir = tempfile.mkdtemp(prefix="analyzer_") - + # Set up configuration config = CodebaseConfig( debug=False, allow_external=True, py_resolve_syspath=True, ) - + secrets = SecretsConfig() - + # Determine programming language prog_lang = None if language: prog_lang = ProgrammingLanguage(language.upper()) - + # Initialize the codebase logger.info(f"Initializing codebase from {repo_url}") - + self.base_codebase = Codebase.from_github( repo_full_name=repo_full_name, tmp_dir=tmp_dir, language=prog_lang, config=config, - secrets=secrets + secrets=secrets, ) - + logger.info(f"Successfully initialized codebase from {repo_url}") - + except Exception as e: - logger.error(f"Error initializing codebase from URL: {e}") + logger.exception(f"Error initializing codebase from URL: {e}") raise - - def _init_from_path(self, repo_path: str, language: Optional[str] = None): + + def _init_from_path(self, repo_path: str, language: str | None = None): """ Initialize codebase from a local repository path. - + Args: repo_path: Path to the repository language: Programming language of the codebase @@ -1140,95 +1296,93 @@ def _init_from_path(self, repo_path: str, language: Optional[str] = None): allow_external=True, py_resolve_syspath=True, ) - + secrets = SecretsConfig() - + # Initialize the codebase logger.info(f"Initializing codebase from {repo_path}") - + # Determine programming language prog_lang = None if language: prog_lang = ProgrammingLanguage(language.upper()) - + # Set up repository configuration repo_config = RepoConfig.from_repo_path(repo_path) repo_config.respect_gitignore = False repo_operator = RepoOperator(repo_config=repo_config, bot_commit=False) - + # Create project configuration project_config = ProjectConfig( repo_operator=repo_operator, - programming_language=prog_lang if prog_lang else None + programming_language=prog_lang if prog_lang else None, ) - + # Initialize codebase self.base_codebase = Codebase( - projects=[project_config], - config=config, - secrets=secrets + projects=[project_config], config=config, secrets=secrets ) - + logger.info(f"Successfully initialized codebase from {repo_path}") - + except Exception as e: - logger.error(f"Error initializing codebase from path: {e}") + logger.exception(f"Error initializing codebase from path: {e}") raise - + def _init_pr_data(self, pr_number: int): """ Initialize PR-specific data. 
- + Args: pr_number: PR number to analyze """ try: logger.info(f"Fetching PR #{pr_number} data") result = self.base_codebase.get_modified_symbols_in_pr(pr_number) - + # Unpack the result tuple if len(result) >= 3: self.pr_diff, self.commit_shas, self.modified_symbols = result[:3] if len(result) >= 4: self.pr_branch = result[3] - + logger.info(f"Found {len(self.modified_symbols)} modified symbols in PR") - + # Initialize PR codebase self._init_pr_codebase() - + except Exception as e: - logger.error(f"Error initializing PR data: {e}") + logger.exception(f"Error initializing PR data: {e}") raise - + def _init_pr_codebase(self): """Initialize PR codebase by checking out the PR branch.""" if not self.base_codebase or not self.pr_number: logger.error("Base codebase or PR number not initialized") return - + try: # Get PR data if not already fetched if not self.pr_branch: self._init_pr_data(self.pr_number) - + if not self.pr_branch: logger.error("Failed to get PR branch") return - + # Clone the base codebase self.pr_codebase = self.base_codebase - + # Checkout PR branch logger.info(f"Checking out PR branch: {self.pr_branch}") self.pr_codebase.checkout(self.pr_branch) - + logger.info("Successfully initialized PR codebase") - + except Exception as e: - logger.error(f"Error initializing PR codebase: {e}") + logger.exception(f"Error initializing PR codebase: {e}") raise - + def _init_contexts(self): """Initialize CodebaseContext objects for both base and PR codebases.""" if self.base_codebase: @@ -1237,225 +1391,249 @@ def _init_contexts(self): codebase=self.base_codebase, base_path=self.repo_path, pr_branch=None, - base_branch=self.base_branch + base_branch=self.base_branch, ) logger.info("Successfully initialized base context") except Exception as e: - logger.error(f"Error initializing base context: {e}") - + logger.exception(f"Error initializing base context: {e}") + if self.pr_codebase: try: self.pr_context = CodebaseContext( codebase=self.pr_codebase, base_path=self.repo_path, pr_branch=self.pr_branch, - base_branch=self.base_branch + base_branch=self.base_branch, ) logger.info("Successfully initialized PR context") except Exception as e: - logger.error(f"Error initializing PR context: {e}") - + logger.exception(f"Error initializing PR context: {e}") + def _init_analyzers(self): """Initialize analyzer plugins.""" # Register default analyzers registry = AnalyzerRegistry() registry.register(AnalysisType.CODE_QUALITY, CodeQualityAnalyzerPlugin) registry.register(AnalysisType.DEPENDENCY, DependencyAnalyzerPlugin) - + def add_issue(self, issue: Issue): """ Add an issue to the list of detected issues. - + Args: issue: Issue to add """ # Check if issue should be skipped if self.should_skip_issue(issue): return - + self.issues.append(issue) - + def should_skip_issue(self, issue: Issue) -> bool: """ Check if an issue should be skipped based on file patterns. - + Args: issue: Issue to check - + Returns: True if the issue should be skipped, False otherwise """ # Skip issues in ignored files file_path = issue.file - + # Check against ignore list for pattern in self.file_ignore_list: if pattern in file_path: return True - + # Check if the file is a test file if "test" in file_path.lower() or "tests" in file_path.lower(): # Skip low-severity issues in test files if issue.severity in [IssueSeverity.INFO, IssueSeverity.WARNING]: return True - + return False - + def should_skip_file(self, file) -> bool: """ Check if a file should be skipped during analysis. 
- + Args: file: File to check - + Returns: True if the file should be skipped, False otherwise """ # Skip binary files - if hasattr(file, 'is_binary') and file.is_binary: + if hasattr(file, "is_binary") and file.is_binary: return True - + # Get file path - file_path = file.file_path if hasattr(file, 'file_path') else str(file.path) if hasattr(file, 'path') else str(file) - + file_path = ( + file.file_path + if hasattr(file, "file_path") + else str(file.path) + if hasattr(file, "path") + else str(file) + ) + # Check against ignore list - for pattern in self.file_ignore_list: - if pattern in file_path: - return True - - return False - + return any(pattern in file_path for pattern in self.file_ignore_list) + def should_skip_symbol(self, symbol) -> bool: """ Check if a symbol should be skipped during analysis. - + Args: symbol: Symbol to check - + Returns: True if the symbol should be skipped, False otherwise """ # Skip symbols without a file - if not hasattr(symbol, 'file'): + if not hasattr(symbol, "file"): return True - + # Skip symbols in skipped files return self.should_skip_file(symbol.file) - - def get_issues(self, severity: Optional[IssueSeverity] = None, category: Optional[IssueCategory] = None) -> List[Issue]: + + def get_issues( + self, + severity: IssueSeverity | None = None, + category: IssueCategory | None = None, + ) -> list[Issue]: """ Get all issues matching the specified criteria. - + Args: severity: Optional severity level to filter by category: Optional category to filter by - + Returns: List of matching issues """ filtered_issues = self.issues - + if severity: filtered_issues = [i for i in filtered_issues if i.severity == severity] - + if category: filtered_issues = [i for i in filtered_issues if i.category == category] - + return filtered_issues - - def analyze(self, analysis_types: Optional[List[AnalysisType]] = None) -> Dict[str, Any]: + + def analyze( + self, analysis_types: list[AnalysisType] | None = None + ) -> dict[str, Any]: """ Perform analysis on the codebase. - + Args: analysis_types: List of analysis types to perform. If None, performs CODE_QUALITY and DEPENDENCY analysis. 
- + Returns: Dictionary containing analysis results """ if not self.base_codebase: raise ValueError("Codebase not initialized") - + # Default to code quality and dependency analysis if analysis_types is None: analysis_types = [AnalysisType.CODE_QUALITY, AnalysisType.DEPENDENCY] - + # Initialize results self.results = { "metadata": { "analysis_time": datetime.now().isoformat(), "analysis_types": [t.value for t in analysis_types], - "repo_name": getattr(self.base_codebase.ctx, 'repo_name', None), - "language": str(getattr(self.base_codebase.ctx, 'programming_language', None)), + "repo_name": getattr(self.base_codebase.ctx, "repo_name", None), + "language": str( + getattr(self.base_codebase.ctx, "programming_language", None) + ), }, "summary": get_codebase_summary(self.base_codebase), - "results": {} + "results": {}, } - + # Clear issues self.issues = [] - + # Run each analyzer registry = AnalyzerRegistry() - + for analysis_type in analysis_types: analyzer_class = registry.get_analyzer(analysis_type) - + if analyzer_class: logger.info(f"Running {analysis_type.value} analysis") analyzer = analyzer_class(self) analysis_result = analyzer.analyze() - + # Add results to unified results self.results["results"][analysis_type.value] = analysis_result else: logger.warning(f"No analyzer found for {analysis_type.value}") - + # Add issues to results self.results["issues"] = [issue.to_dict() for issue in self.issues] - + # Add issue statistics self.results["issue_stats"] = { "total": len(self.issues), "by_severity": { - "critical": sum(1 for issue in self.issues if issue.severity == IssueSeverity.CRITICAL), - "error": sum(1 for issue in self.issues if issue.severity == IssueSeverity.ERROR), - "warning": sum(1 for issue in self.issues if issue.severity == IssueSeverity.WARNING), - "info": sum(1 for issue in self.issues if issue.severity == IssueSeverity.INFO), + "critical": sum( + 1 + for issue in self.issues + if issue.severity == IssueSeverity.CRITICAL + ), + "error": sum( + 1 for issue in self.issues if issue.severity == IssueSeverity.ERROR + ), + "warning": sum( + 1 + for issue in self.issues + if issue.severity == IssueSeverity.WARNING + ), + "info": sum( + 1 for issue in self.issues if issue.severity == IssueSeverity.INFO + ), }, "by_category": { - category.value: sum(1 for issue in self.issues if issue.category == category) + category.value: sum( + 1 for issue in self.issues if issue.category == category + ) for category in IssueCategory if any(issue.category == category for issue in self.issues) - } + }, } - + return self.results - + def save_results(self, output_file: str, format: str = "json"): """ Save analysis results to a file. - + Args: output_file: Path to the output file format: Output format (json, html, or console) """ if format == "json": - with open(output_file, 'w') as f: + with open(output_file, "w") as f: json.dump(self.results, f, indent=2) elif format == "html": self._generate_html_report(output_file) else: # Default to JSON - with open(output_file, 'w') as f: + with open(output_file, "w") as f: json.dump(self.results, f, indent=2) - + logger.info(f"Results saved to {output_file}") - + def _generate_html_report(self, output_file: str): """ Generate an HTML report of the analysis results. - + Args: output_file: Path to the output file """ @@ -1483,99 +1661,123 @@ def _generate_html_report(self, output_file: str):

    Codebase Analysis Report

    Summary

-    Repository: {self.results['metadata'].get('repo_name', 'Unknown')}
-    Language: {self.results['metadata'].get('language', 'Unknown')}
-    Analysis Time: {self.results['metadata'].get('analysis_time', 'Unknown')}
-    Analysis Types: {', '.join(self.results['metadata'].get('analysis_types', []))}
+    Repository: {self.results["metadata"].get("repo_name", "Unknown")}
+    Language: {self.results["metadata"].get("language", "Unknown")}
+    Analysis Time: {self.results["metadata"].get("analysis_time", "Unknown")}
+    Analysis Types: {", ".join(self.results["metadata"].get("analysis_types", []))}

    Total Issues: {len(self.issues)}

    Issues

    """ - + # Add issues grouped by severity - for severity in [IssueSeverity.CRITICAL, IssueSeverity.ERROR, IssueSeverity.WARNING, IssueSeverity.INFO]: - severity_issues = [issue for issue in self.issues if issue.severity == severity] - + for severity in [ + IssueSeverity.CRITICAL, + IssueSeverity.ERROR, + IssueSeverity.WARNING, + IssueSeverity.INFO, + ]: + severity_issues = [ + issue for issue in self.issues if issue.severity == severity + ] + if severity_issues: html_content += f"""

    {severity.value.upper()} Issues ({len(severity_issues)})

    """ - + for issue in severity_issues: - location = f"{issue.file}:{issue.line}" if issue.line else issue.file + location = ( + f"{issue.file}:{issue.line}" if issue.line else issue.file + ) category = f"[{issue.category.value}]" if issue.category else "" - + html_content += f"""

    {location} {category} {issue.message}

    {issue.suggestion}

    """ - + html_content += """
    """ - + # Add detailed analysis sections html_content += """

    Detailed Analysis

    """ - - for analysis_type, results in self.results.get('results', {}).items(): + + for analysis_type, results in self.results.get("results", {}).items(): html_content += f"""

    {analysis_type}

    {json.dumps(results, indent=2)}
    """ - + html_content += """
    """ - - with open(output_file, 'w') as f: + + with open(output_file, "w") as f: f.write(html_content) + def main(): """Command-line entry point for the unified analyzer.""" import argparse - + parser = argparse.ArgumentParser(description="Unified Codebase Analyzer") - + # Repository source options source_group = parser.add_mutually_exclusive_group(required=True) source_group.add_argument("--repo-url", help="URL of the repository to analyze") - source_group.add_argument("--repo-path", help="Local path to the repository to analyze") - + source_group.add_argument( + "--repo-path", help="Local path to the repository to analyze" + ) + # Analysis options - parser.add_argument("--analysis-types", nargs="+", choices=[at.value for at in AnalysisType], - default=["code_quality", "dependency"], - help="Types of analysis to perform") - parser.add_argument("--language", choices=["python", "typescript"], - help="Programming language (auto-detected if not provided)") - parser.add_argument("--base-branch", default="main", - help="Base branch for PR comparison (default: main)") - parser.add_argument("--pr-number", type=int, - help="PR number to analyze") - + parser.add_argument( + "--analysis-types", + nargs="+", + choices=[at.value for at in AnalysisType], + default=["code_quality", "dependency"], + help="Types of analysis to perform", + ) + parser.add_argument( + "--language", + choices=["python", "typescript"], + help="Programming language (auto-detected if not provided)", + ) + parser.add_argument( + "--base-branch", + default="main", + help="Base branch for PR comparison (default: main)", + ) + parser.add_argument("--pr-number", type=int, help="PR number to analyze") + # Output options - parser.add_argument("--output-format", choices=["json", "html", "console"], default="json", - help="Output format") - parser.add_argument("--output-file", - help="Path to the output file") - + parser.add_argument( + "--output-format", + choices=["json", "html", "console"], + default="json", + help="Output format", + ) + parser.add_argument("--output-file", help="Path to the output file") + args = parser.parse_args() - + try: # Initialize the analyzer analyzer = UnifiedCodeAnalyzer( @@ -1583,13 +1785,13 @@ def main(): repo_path=args.repo_path, base_branch=args.base_branch, pr_number=args.pr_number, - language=args.language + language=args.language, ) - + # Perform the analysis analysis_types = [AnalysisType(at) for at in args.analysis_types] results = analyzer.analyze(analysis_types) - + # Output the results if args.output_format == "json": if args.output_file: @@ -1601,33 +1803,43 @@ def main(): analyzer.save_results(output_file, "html") elif args.output_format == "console": # Print summary to console - print(f"\n===== Codebase Analysis Report =====") + print("\n===== Codebase Analysis Report =====") print(f"Repository: {results['metadata'].get('repo_name', 'Unknown')}") print(f"Language: {results['metadata'].get('language', 'Unknown')}") - print(f"Analysis Time: {results['metadata'].get('analysis_time', 'Unknown')}") - print(f"Analysis Types: {', '.join(results['metadata'].get('analysis_types', []))}") - - print(f"\n===== Issues Summary =====") + print( + f"Analysis Time: {results['metadata'].get('analysis_time', 'Unknown')}" + ) + print( + f"Analysis Types: {', '.join(results['metadata'].get('analysis_types', []))}" + ) + + print("\n===== Issues Summary =====") print(f"Total: {results['issue_stats']['total']}") - print(f"Critical: {results['issue_stats']['by_severity'].get('critical', 0)}") + print( + 
f"Critical: {results['issue_stats']['by_severity'].get('critical', 0)}" + ) print(f"Errors: {results['issue_stats']['by_severity'].get('error', 0)}") - print(f"Warnings: {results['issue_stats']['by_severity'].get('warning', 0)}") + print( + f"Warnings: {results['issue_stats']['by_severity'].get('warning', 0)}" + ) print(f"Info: {results['issue_stats']['by_severity'].get('info', 0)}") - - print(f"\n===== Top Issues =====") + + print("\n===== Top Issues =====") for i, issue in enumerate(analyzer.issues[:10]): severity = issue.severity.value.upper() location = f"{issue.file}:{issue.line}" if issue.line else issue.file category = f"[{issue.category.value}]" if issue.category else "" - print(f"{i+1}. [{severity}] {location} {category} {issue.message}") + print(f"{i + 1}. [{severity}] {location} {category} {issue.message}") print(f" Suggestion: {issue.suggestion}") print() - + except Exception as e: import traceback + print(f"Error: {e}") traceback.print_exc() sys.exit(1) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/analysis_visualizer.py b/codegen-on-oss/codegen_on_oss/analyzers/visualization/analysis_visualizer.py index 3d7ea333a..a7c2a3f77 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/visualization/analysis_visualizer.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/analysis_visualizer.py @@ -7,32 +7,33 @@ """ import logging -from typing import Dict, List, Optional, Any, Union -from enum import Enum -from .visualizer import BaseVisualizer, VisualizationType, OutputFormat +from .visualizer import BaseVisualizer, OutputFormat, VisualizationType try: - import networkx as nx import matplotlib.pyplot as plt + import networkx as nx from matplotlib.colors import LinearSegmentedColormap except ImportError: - logging.warning("Visualization dependencies not found. Please install them with: pip install networkx matplotlib") + logging.warning( + "Visualization dependencies not found. Please install them with: pip install networkx matplotlib" + ) logger = logging.getLogger(__name__) + class AnalysisVisualizer(BaseVisualizer): """ Visualizer for code analysis results. - + This class provides methods to visualize analysis results such as dead code detection, cyclomatic complexity, and issue heatmaps. """ - + def __init__(self, analyzer=None, codebase=None, context=None, **kwargs): """ Initialize the AnalysisVisualizer. - + Args: analyzer: Analyzer with analysis results codebase: Codebase instance to visualize @@ -43,53 +44,56 @@ def __init__(self, analyzer=None, codebase=None, context=None, **kwargs): self.analyzer = analyzer self.codebase = codebase or (analyzer.base_codebase if analyzer else None) self.context = context or (analyzer.base_context if analyzer else None) - - def visualize_dead_code(self, path_filter: Optional[str] = None): + + def visualize_dead_code(self, path_filter: str | None = None): """ Generate a visualization of dead (unused) code in the codebase. 
- + Args: path_filter: Optional path to filter files - + Returns: Visualization data or path to saved file """ entity_name = path_filter or "codebase" - + # Initialize graph self._initialize_graph() - + # Check for analyzer if not self.analyzer: logger.error("Analyzer required for dead code visualization") return None - + # Check for analysis results if not hasattr(self.analyzer, "results") or not self.analyzer.results: logger.error("Analysis results not available") return None - + # Extract dead code information from analysis results dead_code = {} - if "static_analysis" in self.analyzer.results and "dead_code" in self.analyzer.results["static_analysis"]: + if ( + "static_analysis" in self.analyzer.results + and "dead_code" in self.analyzer.results["static_analysis"] + ): dead_code = self.analyzer.results["static_analysis"]["dead_code"] - + if not dead_code: logger.warning("No dead code detected in analysis results") return None - + # Create file nodes for containing dead code file_nodes = {} - + # Process unused functions if "unused_functions" in dead_code: for unused_func in dead_code["unused_functions"]: file_path = unused_func.get("file", "") - + # Skip if path filter is specified and doesn't match if path_filter and not file_path.startswith(path_filter): continue - + # Add file node if not already added if file_path not in file_nodes: # Find file in codebase @@ -98,51 +102,54 @@ def visualize_dead_code(self, path_filter: Optional[str] = None): if hasattr(file, "path") and str(file.path) == file_path: file_obj = file break - + if file_obj: file_name = file_path.split("/")[-1] - file_id = self._add_node( + self._add_node( file_obj, name=file_name, color=self.config.color_palette.get("File"), - file_path=file_path + file_path=file_path, ) - + file_nodes[file_path] = file_obj - + # Add unused function node func_name = unused_func.get("name", "") func_line = unused_func.get("line", None) - + # Create a placeholder for the function (we don't have the actual object) - func_obj = {"name": func_name, "file_path": file_path, "line": func_line, "type": "Function"} - - func_id = self._add_node( + func_obj = { + "name": func_name, + "file_path": file_path, + "line": func_line, + "type": "Function", + } + + self._add_node( func_obj, name=func_name, color=self.config.color_palette.get("Dead"), file_path=file_path, line=func_line, - is_dead=True + is_dead=True, ) - + # Add edge from file to function if file_path in file_nodes: self._add_edge( - file_nodes[file_path], - func_obj, - type="contains_dead" + file_nodes[file_path], func_obj, type="contains_dead" ) - + # Process unused variables if "unused_variables" in dead_code: for unused_var in dead_code["unused_variables"]: file_path = unused_var.get("file", "") - + # Skip if path filter is specified and doesn't match if path_filter and not file_path.startswith(path_filter): continue - + # Add file node if not already added if file_path not in file_nodes: # Find file in codebase @@ -151,248 +158,289 @@ def visualize_dead_code(self, path_filter: Optional[str] = None): if hasattr(file, "path") and str(file.path) == file_path: file_obj = file break - + if file_obj: file_name = file_path.split("/")[-1] - file_id = self._add_node( + self._add_node( file_obj, name=file_name, color=self.config.color_palette.get("File"), - file_path=file_path + file_path=file_path, ) - + file_nodes[file_path] = file_obj - + # Add unused variable node var_name = unused_var.get("name", "") var_line = unused_var.get("line", None) - + # Create a placeholder for the variable - 
var_obj = {"name": var_name, "file_path": file_path, "line": var_line, "type": "Variable"} - - var_id = self._add_node( + var_obj = { + "name": var_name, + "file_path": file_path, + "line": var_line, + "type": "Variable", + } + + self._add_node( var_obj, name=var_name, color=self.config.color_palette.get("Dead"), file_path=file_path, line=var_line, - is_dead=True + is_dead=True, ) - + # Add edge from file to variable if file_path in file_nodes: - self._add_edge( - file_nodes[file_path], - var_obj, - type="contains_dead" - ) - + self._add_edge(file_nodes[file_path], var_obj, type="contains_dead") + # Generate visualization data if self.config.output_format == OutputFormat.JSON: data = self._convert_graph_to_json() - return self._save_visualization(VisualizationType.DEAD_CODE, entity_name, data) + return self._save_visualization( + VisualizationType.DEAD_CODE, entity_name, data + ) else: fig = self._plot_graph() - return self._save_visualization(VisualizationType.DEAD_CODE, entity_name, fig) - - def visualize_cyclomatic_complexity(self, path_filter: Optional[str] = None): + return self._save_visualization( + VisualizationType.DEAD_CODE, entity_name, fig + ) + + def visualize_cyclomatic_complexity(self, path_filter: str | None = None): """ Generate a heatmap visualization of cyclomatic complexity. - + Args: path_filter: Optional path to filter files - + Returns: Visualization data or path to saved file """ entity_name = path_filter or "codebase" - + # Check for analyzer if not self.analyzer: logger.error("Analyzer required for complexity visualization") return None - + # Check for analysis results if not hasattr(self.analyzer, "results") or not self.analyzer.results: logger.error("Analysis results not available") return None - + # Extract complexity information from analysis results complexity_data = {} - if "static_analysis" in self.analyzer.results and "code_complexity" in self.analyzer.results["static_analysis"]: - complexity_data = self.analyzer.results["static_analysis"]["code_complexity"] - + if ( + "static_analysis" in self.analyzer.results + and "code_complexity" in self.analyzer.results["static_analysis"] + ): + complexity_data = self.analyzer.results["static_analysis"][ + "code_complexity" + ] + if not complexity_data: logger.warning("No complexity data found in analysis results") return None - + # Extract function complexities functions = [] if "function_complexity" in complexity_data: for func_data in complexity_data["function_complexity"]: # Skip if path filter is specified and doesn't match - if path_filter and not func_data.get("file", "").startswith(path_filter): + if path_filter and not func_data.get("file", "").startswith( + path_filter + ): continue - + functions.append({ "name": func_data.get("name", ""), "file": func_data.get("file", ""), "complexity": func_data.get("complexity", 1), - "line": func_data.get("line", None) + "line": func_data.get("line", None), }) - + # Sort functions by complexity (descending) functions.sort(key=lambda x: x.get("complexity", 0), reverse=True) - + # Generate heatmap visualization plt.figure(figsize=(12, 10)) - + # Extract data for heatmap - func_names = [f"{func['name']} ({func['file'].split('/')[-1]})" for func in functions[:30]] + func_names = [ + f"{func['name']} ({func['file'].split('/')[-1]})" for func in functions[:30] + ] complexities = [func.get("complexity", 0) for func in functions[:30]] - + # Create horizontal bar chart bars = plt.barh(func_names, complexities) - + # Color bars by complexity norm = plt.Normalize(1, max(10, 
max(complexities))) - cmap = plt.cm.get_cmap('YlOrRd') - + cmap = plt.cm.get_cmap("YlOrRd") + for i, bar in enumerate(bars): complexity = complexities[i] bar.set_color(cmap(norm(complexity))) - + # Add labels and title - plt.xlabel('Cyclomatic Complexity') - plt.title('Top Functions by Cyclomatic Complexity') - plt.grid(axis='x', linestyle='--', alpha=0.6) - + plt.xlabel("Cyclomatic Complexity") + plt.title("Top Functions by Cyclomatic Complexity") + plt.grid(axis="x", linestyle="--", alpha=0.6) + # Add colorbar - plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), label='Complexity') - + plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), label="Complexity") + # Save and return visualization - return self._save_visualization(VisualizationType.CYCLOMATIC_COMPLEXITY, entity_name, plt.gcf()) - - def visualize_issues_heatmap(self, severity=None, path_filter: Optional[str] = None): + return self._save_visualization( + VisualizationType.CYCLOMATIC_COMPLEXITY, entity_name, plt.gcf() + ) + + def visualize_issues_heatmap(self, severity=None, path_filter: str | None = None): """ Generate a heatmap visualization of issues in the codebase. - + Args: severity: Optional severity level to filter issues path_filter: Optional path to filter files - + Returns: Visualization data or path to saved file """ entity_name = f"{severity.value if severity else 'all'}_issues" - + # Check for analyzer if not self.analyzer: logger.error("Analyzer required for issues visualization") return None - + # Check for analysis results - if not hasattr(self.analyzer, "results") or "issues" not in self.analyzer.results: + if ( + not hasattr(self.analyzer, "results") + or "issues" not in self.analyzer.results + ): logger.error("Issues not available in analysis results") return None - + issues = self.analyzer.results["issues"] - + # Filter issues by severity if specified if severity: issues = [issue for issue in issues if issue.get("severity") == severity] - + # Filter issues by path if specified if path_filter: - issues = [issue for issue in issues if issue.get("file", "").startswith(path_filter)] - + issues = [ + issue + for issue in issues + if issue.get("file", "").startswith(path_filter) + ] + if not issues: logger.warning("No issues found matching the criteria") return None - + # Group issues by file file_issues = {} for issue in issues: file_path = issue.get("file", "") if file_path not in file_issues: file_issues[file_path] = [] - + file_issues[file_path].append(issue) - + # Generate heatmap visualization plt.figure(figsize=(12, 10)) - + # Extract data for heatmap files = list(file_issues.keys()) file_names = [file_path.split("/")[-1] for file_path in files] issue_counts = [len(file_issues[file_path]) for file_path in files] - + # Sort by issue count - sorted_data = sorted(zip(file_names, issue_counts, files), key=lambda x: x[1], reverse=True) - file_names, issue_counts, files = zip(*sorted_data) - + sorted_data = sorted( + zip(file_names, issue_counts, files, strict=False), + key=lambda x: x[1], + reverse=True, + ) + file_names, issue_counts, files = zip(*sorted_data, strict=False) + # Create horizontal bar chart bars = plt.barh(file_names[:20], issue_counts[:20]) - + # Color bars by issue count norm = plt.Normalize(1, max(5, max(issue_counts[:20]))) - cmap = plt.cm.get_cmap('OrRd') - + cmap = plt.cm.get_cmap("OrRd") + for i, bar in enumerate(bars): count = issue_counts[i] bar.set_color(cmap(norm(count))) - + # Add labels and title - plt.xlabel('Number of Issues') + plt.xlabel("Number of Issues") 
severity_text = f" ({severity.value})" if severity else "" - plt.title(f'Files with the Most Issues{severity_text}') - plt.grid(axis='x', linestyle='--', alpha=0.6) - + plt.title(f"Files with the Most Issues{severity_text}") + plt.grid(axis="x", linestyle="--", alpha=0.6) + # Add colorbar - plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), label='Issue Count') - + plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), label="Issue Count") + # Save and return visualization - return self._save_visualization(VisualizationType.ISSUES_HEATMAP, entity_name, plt.gcf()) - + return self._save_visualization( + VisualizationType.ISSUES_HEATMAP, entity_name, plt.gcf() + ) + def visualize_pr_comparison(self): """ Generate a visualization comparing base branch with PR. - + Returns: Visualization data or path to saved file """ # Check for analyzer with PR data - if not self.analyzer or not hasattr(self.analyzer, "pr_codebase") or not self.analyzer.pr_codebase or not self.analyzer.base_codebase: + if ( + not self.analyzer + or not hasattr(self.analyzer, "pr_codebase") + or not self.analyzer.pr_codebase + or not self.analyzer.base_codebase + ): logger.error("PR comparison requires analyzer with PR data") return None - - entity_name = f"pr_{self.analyzer.pr_number}" if hasattr(self.analyzer, "pr_number") and self.analyzer.pr_number else "pr_comparison" - + + entity_name = ( + f"pr_{self.analyzer.pr_number}" + if hasattr(self.analyzer, "pr_number") and self.analyzer.pr_number + else "pr_comparison" + ) + # Check for analysis results - if not hasattr(self.analyzer, "results") or "comparison" not in self.analyzer.results: + if ( + not hasattr(self.analyzer, "results") + or "comparison" not in self.analyzer.results + ): logger.error("Comparison data not available in analysis results") return None - + comparison = self.analyzer.results["comparison"] - + # Initialize graph self._initialize_graph() - + # Process symbol comparison data if "symbol_comparison" in comparison: for symbol_data in comparison["symbol_comparison"]: symbol_name = symbol_data.get("name", "") in_base = symbol_data.get("in_base", False) in_pr = symbol_data.get("in_pr", False) - + # Create a placeholder for the symbol symbol_obj = { "name": symbol_name, "in_base": in_base, "in_pr": in_pr, - "type": "Symbol" + "type": "Symbol", } - + # Determine node color based on presence in base and PR if in_base and in_pr: color = "#A5D6A7" # Light green (modified) @@ -400,123 +448,111 @@ def visualize_pr_comparison(self): color = "#EF9A9A" # Light red (removed) else: color = "#90CAF9" # Light blue (added) - + # Add node for symbol - symbol_id = self._add_node( + self._add_node( symbol_obj, name=symbol_name, color=color, in_base=in_base, - in_pr=in_pr + in_pr=in_pr, ) - + # Process parameter changes if available if "parameter_changes" in symbol_data: param_changes = symbol_data["parameter_changes"] - + # Process removed parameters for param in param_changes.get("removed", []): param_obj = { "name": param, "change_type": "removed", - "type": "Parameter" + "type": "Parameter", } - - param_id = self._add_node( + + self._add_node( param_obj, name=param, color="#EF9A9A", # Light red (removed) - change_type="removed" - ) - - self._add_edge( - symbol_obj, - param_obj, - type="removed_parameter" + change_type="removed", ) - + + self._add_edge(symbol_obj, param_obj, type="removed_parameter") + # Process added parameters for param in param_changes.get("added", []): param_obj = { "name": param, "change_type": "added", - "type": "Parameter" + "type": 
"Parameter", } - - param_id = self._add_node( + + self._add_node( param_obj, name=param, color="#90CAF9", # Light blue (added) - change_type="added" + change_type="added", ) - - self._add_edge( - symbol_obj, - param_obj, - type="added_parameter" - ) - + + self._add_edge(symbol_obj, param_obj, type="added_parameter") + # Process return type changes if available if "return_type_change" in symbol_data: return_type_change = symbol_data["return_type_change"] old_type = return_type_change.get("old", "None") new_type = return_type_change.get("new", "None") - + return_obj = { "name": f"{old_type} -> {new_type}", "old_type": old_type, "new_type": new_type, - "type": "ReturnType" + "type": "ReturnType", } - - return_id = self._add_node( + + self._add_node( return_obj, name=f"{old_type} -> {new_type}", color="#FFD54F", # Amber (changed) old_type=old_type, - new_type=new_type + new_type=new_type, ) - - self._add_edge( - symbol_obj, - return_obj, - type="return_type_change" - ) - + + self._add_edge(symbol_obj, return_obj, type="return_type_change") + # Process call site issues if available if "call_site_issues" in symbol_data: for issue in symbol_data["call_site_issues"]: issue_file = issue.get("file", "") issue_line = issue.get("line", None) issue_text = issue.get("issue", "") - + # Create a placeholder for the issue issue_obj = { "name": issue_text, "file": issue_file, "line": issue_line, - "type": "Issue" + "type": "Issue", } - - issue_id = self._add_node( + + self._add_node( issue_obj, name=f"{issue_file.split('/')[-1]}:{issue_line}", color="#EF5350", # Red (error) file_path=issue_file, line=issue_line, - issue_text=issue_text - ) - - self._add_edge( - symbol_obj, - issue_obj, - type="call_site_issue" + issue_text=issue_text, ) - + + self._add_edge(symbol_obj, issue_obj, type="call_site_issue") + # Generate visualization data if self.config.output_format == OutputFormat.JSON: data = self._convert_graph_to_json() - return self._save_visualization(VisualizationType.PR_COMPARISON, entity_name, data) + return self._save_visualization( + VisualizationType.PR_COMPARISON, entity_name, data + ) else: fig = self._plot_graph() - return self._save_visualization(VisualizationType.PR_COMPARISON, entity_name, fig) \ No newline at end of file + return self._save_visualization( + VisualizationType.PR_COMPARISON, entity_name, fig + ) diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/code_visualizer.py b/codegen-on-oss/codegen_on_oss/analyzers/visualization/code_visualizer.py index b6b196b7a..98c462643 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/visualization/code_visualizer.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/code_visualizer.py @@ -7,30 +7,32 @@ """ import logging -from typing import Dict, List, Set, Tuple, Any, Optional, Union -from .visualizer import BaseVisualizer, VisualizationType, OutputFormat +from .visualizer import BaseVisualizer, OutputFormat, VisualizationType try: - import networkx as nx import matplotlib.pyplot as plt + import networkx as nx except ImportError: - logging.warning("Visualization dependencies not found. Please install them with: pip install networkx matplotlib") + logging.warning( + "Visualization dependencies not found. Please install them with: pip install networkx matplotlib" + ) logger = logging.getLogger(__name__) + class CodeVisualizer(BaseVisualizer): """ Visualizer for code structures such as call graphs and dependencies. 
- + This class provides methods to visualize relationships between code entities including functions, classes, and modules. """ - + def __init__(self, codebase=None, context=None, **kwargs): """ Initialize the CodeVisualizer. - + Args: codebase: Codebase instance to visualize context: Context providing graph representation @@ -39,466 +41,528 @@ def __init__(self, codebase=None, context=None, **kwargs): super().__init__(**kwargs) self.codebase = codebase self.context = context - + # Initialize codebase if needed - if not self.codebase and not self.context and 'analyzer' in kwargs: - self.codebase = kwargs['analyzer'].base_codebase - self.context = kwargs['analyzer'].base_context - - def visualize_call_graph(self, function_name: str, max_depth: Optional[int] = None): + if not self.codebase and not self.context and "analyzer" in kwargs: + self.codebase = kwargs["analyzer"].base_codebase + self.context = kwargs["analyzer"].base_context + + def visualize_call_graph(self, function_name: str, max_depth: int | None = None): """ Generate a call graph visualization for a function. - + Args: function_name: Name of the function to visualize max_depth: Maximum depth of the call graph (overrides config) - + Returns: Visualization data or path to saved file """ # Set max depth - current_max_depth = max_depth if max_depth is not None else self.config.max_depth - + current_max_depth = ( + max_depth if max_depth is not None else self.config.max_depth + ) + # Initialize graph self._initialize_graph() - + # Find the function in the codebase function = None for func in self.codebase.functions: if func.name == function_name: function = func break - + if not function: logger.error(f"Function {function_name} not found in codebase") return None - + # Add root node - root_id = self._add_node( + self._add_node( function, name=function_name, color=self.config.color_palette.get("Root"), - is_root=True + is_root=True, ) - + # Recursively add call relationships - visited = set([function]) - + visited = {function} + def add_calls(func, depth=0): if depth >= current_max_depth: return - + # Skip if no function calls attribute if not hasattr(func, "function_calls"): return - + for call in func.function_calls: # Skip recursive calls if call.name == func.name: continue - + # Get the called function called_func = call.function_definition if not called_func: continue - + # Skip external modules if configured - if self.config.ignore_external and hasattr(called_func, "is_external") and called_func.is_external: + if ( + self.config.ignore_external + and hasattr(called_func, "is_external") + and called_func.is_external + ): continue - + # Generate name for display - if hasattr(called_func, "is_method") and called_func.is_method and hasattr(called_func, "parent_class"): + if ( + hasattr(called_func, "is_method") + and called_func.is_method + and hasattr(called_func, "parent_class") + ): called_name = f"{called_func.parent_class.name}.{called_func.name}" else: called_name = called_func.name - + # Add node for called function - called_id = self._add_node( + self._add_node( called_func, name=called_name, color=self.config.color_palette.get("Function"), - file_path=called_func.file.path if hasattr(called_func, "file") and hasattr(called_func.file, "path") else None + file_path=called_func.file.path + if hasattr(called_func, "file") + and hasattr(called_func.file, "path") + else None, ) - + # Add edge for call relationship self._add_edge( function, called_func, type="call", file_path=call.filepath if hasattr(call, "filepath") else None, 
- line=call.line if hasattr(call, "line") else None + line=call.line if hasattr(call, "line") else None, ) - + # Recursively process called function if called_func not in visited: visited.add(called_func) add_calls(called_func, depth + 1) - + # Start from the root function add_calls(function) - + # Generate visualization data if self.config.output_format == OutputFormat.JSON: data = self._convert_graph_to_json() - return self._save_visualization(VisualizationType.CALL_GRAPH, function_name, data) + return self._save_visualization( + VisualizationType.CALL_GRAPH, function_name, data + ) else: fig = self._plot_graph() - return self._save_visualization(VisualizationType.CALL_GRAPH, function_name, fig) - - def visualize_dependency_graph(self, symbol_name: str, max_depth: Optional[int] = None): + return self._save_visualization( + VisualizationType.CALL_GRAPH, function_name, fig + ) + + def visualize_dependency_graph( + self, symbol_name: str, max_depth: int | None = None + ): """ Generate a dependency graph visualization for a symbol. - + Args: symbol_name: Name of the symbol to visualize max_depth: Maximum depth of the dependency graph (overrides config) - + Returns: Visualization data or path to saved file """ # Set max depth - current_max_depth = max_depth if max_depth is not None else self.config.max_depth - + current_max_depth = ( + max_depth if max_depth is not None else self.config.max_depth + ) + # Initialize graph self._initialize_graph() - + # Find the symbol in the codebase symbol = None for sym in self.codebase.symbols: if hasattr(sym, "name") and sym.name == symbol_name: symbol = sym break - + if not symbol: logger.error(f"Symbol {symbol_name} not found in codebase") return None - + # Add root node - root_id = self._add_node( + self._add_node( symbol, name=symbol_name, color=self.config.color_palette.get("Root"), - is_root=True + is_root=True, ) - + # Recursively add dependencies - visited = set([symbol]) - + visited = {symbol} + def add_dependencies(sym, depth=0): if depth >= current_max_depth: return - + # Skip if no dependencies attribute if not hasattr(sym, "dependencies"): return - + for dep in sym.dependencies: dep_symbol = None - + if hasattr(dep, "__class__") and dep.__class__.__name__ == "Symbol": dep_symbol = dep elif hasattr(dep, "resolved_symbol"): dep_symbol = dep.resolved_symbol - + if not dep_symbol: continue - + # Skip external modules if configured - if self.config.ignore_external and hasattr(dep_symbol, "is_external") and dep_symbol.is_external: + if ( + self.config.ignore_external + and hasattr(dep_symbol, "is_external") + and dep_symbol.is_external + ): continue - + # Add node for dependency - dep_id = self._add_node( + self._add_node( dep_symbol, - name=dep_symbol.name if hasattr(dep_symbol, "name") else str(dep_symbol), - color=self.config.color_palette.get(dep_symbol.__class__.__name__, "#BBBBBB"), - file_path=dep_symbol.file.path if hasattr(dep_symbol, "file") and hasattr(dep_symbol.file, "path") else None + name=dep_symbol.name + if hasattr(dep_symbol, "name") + else str(dep_symbol), + color=self.config.color_palette.get( + dep_symbol.__class__.__name__, "#BBBBBB" + ), + file_path=dep_symbol.file.path + if hasattr(dep_symbol, "file") and hasattr(dep_symbol.file, "path") + else None, ) - + # Add edge for dependency relationship - self._add_edge( - sym, - dep_symbol, - type="depends_on" - ) - + self._add_edge(sym, dep_symbol, type="depends_on") + # Recursively process dependency if dep_symbol not in visited: visited.add(dep_symbol) 
add_dependencies(dep_symbol, depth + 1) - + # Start from the root symbol add_dependencies(symbol) - + # Generate visualization data if self.config.output_format == OutputFormat.JSON: data = self._convert_graph_to_json() - return self._save_visualization(VisualizationType.DEPENDENCY_GRAPH, symbol_name, data) + return self._save_visualization( + VisualizationType.DEPENDENCY_GRAPH, symbol_name, data + ) else: fig = self._plot_graph() - return self._save_visualization(VisualizationType.DEPENDENCY_GRAPH, symbol_name, fig) - - def visualize_blast_radius(self, symbol_name: str, max_depth: Optional[int] = None): + return self._save_visualization( + VisualizationType.DEPENDENCY_GRAPH, symbol_name, fig + ) + + def visualize_blast_radius(self, symbol_name: str, max_depth: int | None = None): """ Generate a blast radius visualization for a symbol. - + Args: symbol_name: Name of the symbol to visualize max_depth: Maximum depth of the blast radius (overrides config) - + Returns: Visualization data or path to saved file """ # Set max depth - current_max_depth = max_depth if max_depth is not None else self.config.max_depth - + current_max_depth = ( + max_depth if max_depth is not None else self.config.max_depth + ) + # Initialize graph self._initialize_graph() - + # Find the symbol in the codebase symbol = None for sym in self.codebase.symbols: if hasattr(sym, "name") and sym.name == symbol_name: symbol = sym break - + if not symbol: logger.error(f"Symbol {symbol_name} not found in codebase") return None - + # Add root node - root_id = self._add_node( + self._add_node( symbol, name=symbol_name, color=self.config.color_palette.get("Root"), - is_root=True + is_root=True, ) - + # Recursively add usages (reverse dependencies) - visited = set([symbol]) - + visited = {symbol} + def add_usages(sym, depth=0): if depth >= current_max_depth: return - + # Skip if no usages attribute if not hasattr(sym, "usages"): return - + for usage in sym.usages: # Skip if no usage symbol if not hasattr(usage, "usage_symbol"): continue - + usage_symbol = usage.usage_symbol - + # Skip external modules if configured - if self.config.ignore_external and hasattr(usage_symbol, "is_external") and usage_symbol.is_external: + if ( + self.config.ignore_external + and hasattr(usage_symbol, "is_external") + and usage_symbol.is_external + ): continue - + # Add node for usage - usage_id = self._add_node( + self._add_node( usage_symbol, - name=usage_symbol.name if hasattr(usage_symbol, "name") else str(usage_symbol), - color=self.config.color_palette.get(usage_symbol.__class__.__name__, "#BBBBBB"), - file_path=usage_symbol.file.path if hasattr(usage_symbol, "file") and hasattr(usage_symbol.file, "path") else None + name=usage_symbol.name + if hasattr(usage_symbol, "name") + else str(usage_symbol), + color=self.config.color_palette.get( + usage_symbol.__class__.__name__, "#BBBBBB" + ), + file_path=usage_symbol.file.path + if hasattr(usage_symbol, "file") + and hasattr(usage_symbol.file, "path") + else None, ) - + # Add edge for usage relationship - self._add_edge( - sym, - usage_symbol, - type="used_by" - ) - + self._add_edge(sym, usage_symbol, type="used_by") + # Recursively process usage if usage_symbol not in visited: visited.add(usage_symbol) add_usages(usage_symbol, depth + 1) - + # Start from the root symbol add_usages(symbol) - + # Generate visualization data if self.config.output_format == OutputFormat.JSON: data = self._convert_graph_to_json() - return self._save_visualization(VisualizationType.BLAST_RADIUS, symbol_name, data) + return 
self._save_visualization( + VisualizationType.BLAST_RADIUS, symbol_name, data + ) else: fig = self._plot_graph() - return self._save_visualization(VisualizationType.BLAST_RADIUS, symbol_name, fig) - + return self._save_visualization( + VisualizationType.BLAST_RADIUS, symbol_name, fig + ) + def visualize_class_methods(self, class_name: str): """ Generate a class methods visualization. - + Args: class_name: Name of the class to visualize - + Returns: Visualization data or path to saved file """ # Initialize graph self._initialize_graph() - + # Find the class in the codebase class_obj = None for cls in self.codebase.classes: if cls.name == class_name: class_obj = cls break - + if not class_obj: logger.error(f"Class {class_name} not found in codebase") return None - + # Add class node - class_id = self._add_node( + self._add_node( class_obj, name=class_name, color=self.config.color_palette.get("Class"), - is_root=True + is_root=True, ) - + # Skip if no methods attribute if not hasattr(class_obj, "methods"): logger.error(f"Class {class_name} has no methods attribute") return None - + # Add method nodes and connections method_ids = {} for method in class_obj.methods: method_name = f"{class_name}.{method.name}" - + # Add method node method_id = self._add_node( method, name=method_name, color=self.config.color_palette.get("Function"), - file_path=method.file.path if hasattr(method, "file") and hasattr(method.file, "path") else None + file_path=method.file.path + if hasattr(method, "file") and hasattr(method.file, "path") + else None, ) - + method_ids[method.name] = method_id - + # Add edge from class to method - self._add_edge( - class_obj, - method, - type="contains" - ) - + self._add_edge(class_obj, method, type="contains") + # Add call relationships between methods for method in class_obj.methods: # Skip if no function calls attribute if not hasattr(method, "function_calls"): continue - + for call in method.function_calls: # Get the called function called_func = call.function_definition if not called_func: continue - + # Only add edges between methods of this class - if hasattr(called_func, "is_method") and called_func.is_method and \ - hasattr(called_func, "parent_class") and called_func.parent_class == class_obj: + if ( + hasattr(called_func, "is_method") + and called_func.is_method + and hasattr(called_func, "parent_class") + and called_func.parent_class == class_obj + ): self._add_edge( method, called_func, type="calls", - line=call.line if hasattr(call, "line") else None + line=call.line if hasattr(call, "line") else None, ) - + # Generate visualization data if self.config.output_format == OutputFormat.JSON: data = self._convert_graph_to_json() - return self._save_visualization(VisualizationType.CLASS_METHODS, class_name, data) + return self._save_visualization( + VisualizationType.CLASS_METHODS, class_name, data + ) else: fig = self._plot_graph() - return self._save_visualization(VisualizationType.CLASS_METHODS, class_name, fig) - + return self._save_visualization( + VisualizationType.CLASS_METHODS, class_name, fig + ) + def visualize_module_dependencies(self, module_path: str): """ Generate a module dependencies visualization. 
- + Args: module_path: Path to the module to visualize - + Returns: Visualization data or path to saved file """ # Initialize graph self._initialize_graph() - + # Get all files in the module module_files = [] for file in self.codebase.files: if hasattr(file, "path") and str(file.path).startswith(module_path): module_files.append(file) - + if not module_files: logger.error(f"No files found in module {module_path}") return None - + # Add file nodes module_node_ids = {} for file in module_files: file_name = str(file.path).split("/")[-1] file_module = "/".join(str(file.path).split("/")[:-1]) - + # Add file node file_id = self._add_node( file, name=file_name, module=file_module, color=self.config.color_palette.get("File"), - file_path=str(file.path) + file_path=str(file.path), ) - + module_node_ids[str(file.path)] = file_id - + # Add import relationships for file in module_files: # Skip if no imports attribute if not hasattr(file, "imports"): continue - + for imp in file.imports: imported_file = None - + # Try to get imported file if hasattr(imp, "resolved_file"): imported_file = imp.resolved_file - elif hasattr(imp, "resolved_symbol") and hasattr(imp.resolved_symbol, "file"): + elif hasattr(imp, "resolved_symbol") and hasattr( + imp.resolved_symbol, "file" + ): imported_file = imp.resolved_symbol.file - + if not imported_file: continue - + # Skip external modules if configured - if self.config.ignore_external and hasattr(imported_file, "is_external") and imported_file.is_external: + if ( + self.config.ignore_external + and hasattr(imported_file, "is_external") + and imported_file.is_external + ): continue - + # Add node for imported file if not already added - imported_path = str(imported_file.path) if hasattr(imported_file, "path") else "" - + imported_path = ( + str(imported_file.path) if hasattr(imported_file, "path") else "" + ) + if imported_path not in module_node_ids: imported_name = imported_path.split("/")[-1] imported_module = "/".join(imported_path.split("/")[:-1]) - + imported_id = self._add_node( imported_file, name=imported_name, module=imported_module, - color=self.config.color_palette.get("External" if imported_path.startswith(module_path) else "File"), - file_path=imported_path + color=self.config.color_palette.get( + "External" + if imported_path.startswith(module_path) + else "File" + ), + file_path=imported_path, ) - + module_node_ids[imported_path] = imported_id - + # Add edge for import relationship self._add_edge( file, imported_file, type="imports", - import_name=imp.name if hasattr(imp, "name") else "" - ) \ No newline at end of file + import_name=imp.name if hasattr(imp, "name") else "", + ) diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/codebase_visualizer.py b/codegen-on-oss/codegen_on_oss/analyzers/visualization/codebase_visualizer.py index a7198f9a3..52f77eade 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/visualization/codebase_visualizer.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/codebase_visualizer.py @@ -7,42 +7,40 @@ easy-to-use API for generating various types of visualizations. 
""" +import argparse +import logging import os import sys -import logging -import argparse -from typing import Dict, List, Optional, Any, Union -from .visualizer import BaseVisualizer, VisualizationType, OutputFormat, VisualizationConfig -from .code_visualizer import CodeVisualizer from .analysis_visualizer import AnalysisVisualizer +from .code_visualizer import CodeVisualizer +from .visualizer import ( + OutputFormat, + VisualizationConfig, + VisualizationType, +) # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) + class CodebaseVisualizer: """ Main visualizer class providing a unified interface to all visualization capabilities. - + This class acts as a facade to the specialized visualizers, simplifying the generation of different types of visualizations for codebases. """ - - def __init__( - self, - analyzer=None, - codebase=None, - context=None, - config=None - ): + + def __init__(self, analyzer=None, codebase=None, context=None, config=None): """ Initialize the CodebaseVisualizer. - + Args: analyzer: Optional analyzer with analysis results codebase: Optional codebase to visualize @@ -53,55 +51,58 @@ def __init__( self.codebase = codebase or (analyzer.base_codebase if analyzer else None) self.context = context or (analyzer.base_context if analyzer else None) self.config = config or VisualizationConfig() - + # Initialize specialized visualizers self.code_visualizer = CodeVisualizer( analyzer=analyzer, codebase=self.codebase, context=self.context, - config=self.config + config=self.config, ) - + self.analysis_visualizer = AnalysisVisualizer( analyzer=analyzer, codebase=self.codebase, context=self.context, - config=self.config + config=self.config, ) - + # Create visualization directory if specified if self.config.output_directory: os.makedirs(self.config.output_directory, exist_ok=True) - + # Initialize codebase if needed if not self.codebase and not self.context: try: - from codegen_on_oss.current_code_codebase import get_selected_codebase from codegen_on_oss.analyzers.context_codebase import CodebaseContext - - logger.info("No codebase or context provided, initializing from current directory") + from codegen_on_oss.current_code_codebase import get_selected_codebase + + logger.info( + "No codebase or context provided, initializing from current directory" + ) self.codebase = get_selected_codebase() self.context = CodebaseContext( - codebase=self.codebase, - base_path=os.getcwd() + codebase=self.codebase, base_path=os.getcwd() ) - + # Update specialized visualizers self.code_visualizer.codebase = self.codebase self.code_visualizer.context = self.context self.analysis_visualizer.codebase = self.codebase self.analysis_visualizer.context = self.context except ImportError: - logger.error("Could not automatically initialize codebase. Please provide a codebase or context.") - + logger.exception( + "Could not automatically initialize codebase. Please provide a codebase or context." + ) + def visualize(self, visualization_type: VisualizationType, **kwargs): """ Generate a visualization of the specified type. 
- + Args: visualization_type: Type of visualization to generate **kwargs: Additional arguments for the specific visualization - + Returns: Visualization data or path to saved file """ @@ -111,7 +112,7 @@ def visualize(self, visualization_type: VisualizationType, **kwargs): VisualizationType.DEPENDENCY_GRAPH, VisualizationType.BLAST_RADIUS, VisualizationType.CLASS_METHODS, - VisualizationType.MODULE_DEPENDENCIES + VisualizationType.MODULE_DEPENDENCIES, ]: # Code structure visualizations return self._visualize_code_structure(visualization_type, **kwargs) @@ -119,39 +120,38 @@ def visualize(self, visualization_type: VisualizationType, **kwargs): VisualizationType.DEAD_CODE, VisualizationType.CYCLOMATIC_COMPLEXITY, VisualizationType.ISSUES_HEATMAP, - VisualizationType.PR_COMPARISON + VisualizationType.PR_COMPARISON, ]: # Analysis result visualizations return self._visualize_analysis_results(visualization_type, **kwargs) else: logger.error(f"Unsupported visualization type: {visualization_type}") return None - - def _visualize_code_structure(self, visualization_type: VisualizationType, **kwargs): + + def _visualize_code_structure( + self, visualization_type: VisualizationType, **kwargs + ): """ Generate a code structure visualization. - + Args: visualization_type: Type of visualization to generate **kwargs: Additional arguments for the specific visualization - + Returns: Visualization data or path to saved file """ if visualization_type == VisualizationType.CALL_GRAPH: return self.code_visualizer.visualize_call_graph( - function_name=kwargs.get("entity"), - max_depth=kwargs.get("max_depth") + function_name=kwargs.get("entity"), max_depth=kwargs.get("max_depth") ) elif visualization_type == VisualizationType.DEPENDENCY_GRAPH: return self.code_visualizer.visualize_dependency_graph( - symbol_name=kwargs.get("entity"), - max_depth=kwargs.get("max_depth") + symbol_name=kwargs.get("entity"), max_depth=kwargs.get("max_depth") ) elif visualization_type == VisualizationType.BLAST_RADIUS: return self.code_visualizer.visualize_blast_radius( - symbol_name=kwargs.get("entity"), - max_depth=kwargs.get("max_depth") + symbol_name=kwargs.get("entity"), max_depth=kwargs.get("max_depth") ) elif visualization_type == VisualizationType.CLASS_METHODS: return self.code_visualizer.visualize_class_methods( @@ -161,22 +161,24 @@ def _visualize_code_structure(self, visualization_type: VisualizationType, **kwa return self.code_visualizer.visualize_module_dependencies( module_path=kwargs.get("entity") ) - - def _visualize_analysis_results(self, visualization_type: VisualizationType, **kwargs): + + def _visualize_analysis_results( + self, visualization_type: VisualizationType, **kwargs + ): """ Generate an analysis results visualization. 
- + Args: visualization_type: Type of visualization to generate **kwargs: Additional arguments for the specific visualization - + Returns: Visualization data or path to saved file """ if not self.analyzer: logger.error(f"Analyzer required for {visualization_type} visualization") return None - + if visualization_type == VisualizationType.DEAD_CODE: return self.analysis_visualizer.visualize_dead_code( path_filter=kwargs.get("path_filter") @@ -187,234 +189,204 @@ def _visualize_analysis_results(self, visualization_type: VisualizationType, **k ) elif visualization_type == VisualizationType.ISSUES_HEATMAP: return self.analysis_visualizer.visualize_issues_heatmap( - severity=kwargs.get("severity"), - path_filter=kwargs.get("path_filter") + severity=kwargs.get("severity"), path_filter=kwargs.get("path_filter") ) elif visualization_type == VisualizationType.PR_COMPARISON: return self.analysis_visualizer.visualize_pr_comparison() - + # Convenience methods for common visualizations - def visualize_call_graph(self, function_name: str, max_depth: Optional[int] = None): + def visualize_call_graph(self, function_name: str, max_depth: int | None = None): """Convenience method for call graph visualization.""" return self.visualize( - VisualizationType.CALL_GRAPH, - entity=function_name, - max_depth=max_depth + VisualizationType.CALL_GRAPH, entity=function_name, max_depth=max_depth ) - - def visualize_dependency_graph(self, symbol_name: str, max_depth: Optional[int] = None): + + def visualize_dependency_graph( + self, symbol_name: str, max_depth: int | None = None + ): """Convenience method for dependency graph visualization.""" return self.visualize( - VisualizationType.DEPENDENCY_GRAPH, - entity=symbol_name, - max_depth=max_depth + VisualizationType.DEPENDENCY_GRAPH, entity=symbol_name, max_depth=max_depth ) - - def visualize_blast_radius(self, symbol_name: str, max_depth: Optional[int] = None): + + def visualize_blast_radius(self, symbol_name: str, max_depth: int | None = None): """Convenience method for blast radius visualization.""" return self.visualize( - VisualizationType.BLAST_RADIUS, - entity=symbol_name, - max_depth=max_depth + VisualizationType.BLAST_RADIUS, entity=symbol_name, max_depth=max_depth ) - + def visualize_class_methods(self, class_name: str): """Convenience method for class methods visualization.""" - return self.visualize( - VisualizationType.CLASS_METHODS, - entity=class_name - ) - + return self.visualize(VisualizationType.CLASS_METHODS, entity=class_name) + def visualize_module_dependencies(self, module_path: str): """Convenience method for module dependencies visualization.""" - return self.visualize( - VisualizationType.MODULE_DEPENDENCIES, - entity=module_path - ) - - def visualize_dead_code(self, path_filter: Optional[str] = None): + return self.visualize(VisualizationType.MODULE_DEPENDENCIES, entity=module_path) + + def visualize_dead_code(self, path_filter: str | None = None): """Convenience method for dead code visualization.""" - return self.visualize( - VisualizationType.DEAD_CODE, - path_filter=path_filter - ) - - def visualize_cyclomatic_complexity(self, path_filter: Optional[str] = None): + return self.visualize(VisualizationType.DEAD_CODE, path_filter=path_filter) + + def visualize_cyclomatic_complexity(self, path_filter: str | None = None): """Convenience method for cyclomatic complexity visualization.""" return self.visualize( - VisualizationType.CYCLOMATIC_COMPLEXITY, - path_filter=path_filter + VisualizationType.CYCLOMATIC_COMPLEXITY, path_filter=path_filter ) - 
- def visualize_issues_heatmap(self, severity=None, path_filter: Optional[str] = None): + + def visualize_issues_heatmap(self, severity=None, path_filter: str | None = None): """Convenience method for issues heatmap visualization.""" return self.visualize( - VisualizationType.ISSUES_HEATMAP, - severity=severity, - path_filter=path_filter + VisualizationType.ISSUES_HEATMAP, severity=severity, path_filter=path_filter ) - + def visualize_pr_comparison(self): """Convenience method for PR comparison visualization.""" - return self.visualize( - VisualizationType.PR_COMPARISON - ) + return self.visualize(VisualizationType.PR_COMPARISON) + # Command-line interface def main(): """ Command-line interface for the codebase visualizer. - + This function parses command-line arguments and generates visualizations based on the specified parameters. """ parser = argparse.ArgumentParser( description="Generate visualizations of codebase structure and analysis." ) - + # Repository options repo_group = parser.add_argument_group("Repository Options") + repo_group.add_argument("--repo-url", help="URL of the repository to analyze") repo_group.add_argument( - "--repo-url", - help="URL of the repository to analyze" - ) - repo_group.add_argument( - "--repo-path", - help="Local path to the repository to analyze" - ) - repo_group.add_argument( - "--language", - help="Programming language of the codebase" + "--repo-path", help="Local path to the repository to analyze" ) - + repo_group.add_argument("--language", help="Programming language of the codebase") + # Visualization options viz_group = parser.add_argument_group("Visualization Options") viz_group.add_argument( "--type", choices=[t.value for t in VisualizationType], required=True, - help="Type of visualization to generate" + help="Type of visualization to generate", ) viz_group.add_argument( - "--entity", - help="Name of the entity to visualize (function, class, file, etc.)" + "--entity", help="Name of the entity to visualize (function, class, file, etc.)" ) viz_group.add_argument( "--max-depth", type=int, default=5, - help="Maximum depth for recursive visualizations" - ) - viz_group.add_argument( - "--ignore-external", - action="store_true", - help="Ignore external dependencies" + help="Maximum depth for recursive visualizations", ) viz_group.add_argument( - "--severity", - help="Filter issues by severity" + "--ignore-external", action="store_true", help="Ignore external dependencies" ) - viz_group.add_argument( - "--path-filter", - help="Filter by file path" - ) - + viz_group.add_argument("--severity", help="Filter issues by severity") + viz_group.add_argument("--path-filter", help="Filter by file path") + # PR options pr_group = parser.add_argument_group("PR Options") + pr_group.add_argument("--pr-number", type=int, help="PR number to analyze") pr_group.add_argument( - "--pr-number", - type=int, - help="PR number to analyze" - ) - pr_group.add_argument( - "--base-branch", - default="main", - help="Base branch for comparison" + "--base-branch", default="main", help="Base branch for comparison" ) - + # Output options output_group = parser.add_argument_group("Output Options") output_group.add_argument( "--output-format", choices=[f.value for f in OutputFormat], default="json", - help="Output format for the visualization" + help="Output format for the visualization", ) output_group.add_argument( - "--output-directory", - help="Directory to save visualizations" + "--output-directory", help="Directory to save visualizations" ) output_group.add_argument( "--layout", 
choices=["spring", "kamada_kawai", "spectral"], default="spring", - help="Layout algorithm for graph visualization" + help="Layout algorithm for graph visualization", ) - + args = parser.parse_args() - + # Create visualizer configuration config = VisualizationConfig( max_depth=args.max_depth, ignore_external=args.ignore_external, output_format=OutputFormat(args.output_format), output_directory=args.output_directory, - layout_algorithm=args.layout + layout_algorithm=args.layout, ) - + try: # Import analyzer only if needed - if args.type in ["pr_comparison", "dead_code", "cyclomatic_complexity", "issues_heatmap"] or args.pr_number: + if ( + args.type + in ["pr_comparison", "dead_code", "cyclomatic_complexity", "issues_heatmap"] + or args.pr_number + ): from codegen_on_oss.analyzers.codebase_analyzer import CodebaseAnalyzer - + # Create analyzer analyzer = CodebaseAnalyzer( repo_url=args.repo_url, repo_path=args.repo_path, base_branch=args.base_branch, pr_number=args.pr_number, - language=args.language + language=args.language, ) else: analyzer = None except ImportError: - logger.warning("CodebaseAnalyzer not available. Some visualizations may not work.") + logger.warning( + "CodebaseAnalyzer not available. Some visualizations may not work." + ) analyzer = None - + # Create visualizer - visualizer = CodebaseVisualizer( - analyzer=analyzer, - config=config - ) - + visualizer = CodebaseVisualizer(analyzer=analyzer, config=config) + # Generate visualization based on type viz_type = VisualizationType(args.type) result = None - + # Process specific requirements for each visualization type - if viz_type in [ - VisualizationType.CALL_GRAPH, - VisualizationType.DEPENDENCY_GRAPH, - VisualizationType.BLAST_RADIUS, - VisualizationType.CLASS_METHODS, - VisualizationType.MODULE_DEPENDENCIES - ] and not args.entity: + if ( + viz_type + in [ + VisualizationType.CALL_GRAPH, + VisualizationType.DEPENDENCY_GRAPH, + VisualizationType.BLAST_RADIUS, + VisualizationType.CLASS_METHODS, + VisualizationType.MODULE_DEPENDENCIES, + ] + and not args.entity + ): logger.error(f"Entity name required for {viz_type} visualization") sys.exit(1) - - if viz_type == VisualizationType.PR_COMPARISON and not args.pr_number and not (analyzer and hasattr(analyzer, "pr_number")): + + if ( + viz_type == VisualizationType.PR_COMPARISON + and not args.pr_number + and not (analyzer and hasattr(analyzer, "pr_number")) + ): logger.error("PR number required for PR comparison visualization") sys.exit(1) - + # Generate visualization result = visualizer.visualize( viz_type, entity=args.entity, max_depth=args.max_depth, severity=args.severity, - path_filter=args.path_filter + path_filter=args.path_filter, ) - + # Output result if result: logger.info(f"Visualization completed: {result}") @@ -422,5 +394,6 @@ def main(): logger.error("Failed to generate visualization") sys.exit(1) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/codegen-on-oss/codegen_on_oss/analyzers/visualization/visualizer.py b/codegen-on-oss/codegen_on_oss/analyzers/visualization/visualizer.py index 7614dfaf5..81f4f61be 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/visualization/visualizer.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/visualization/visualizer.py @@ -7,25 +7,27 @@ of code structure, dependencies, and issues. 
""" -import os -import sys import json import logging -from enum import Enum -from pathlib import Path -from typing import Dict, List, Set, Tuple, Any, Optional, Union, TypeVar, cast, Callable -from datetime import datetime +import os from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any try: - import networkx as nx import matplotlib.pyplot as plt + import networkx as nx from matplotlib.colors import LinearSegmentedColormap except ImportError: - logging.warning("Visualization dependencies not found. Please install them with: pip install networkx matplotlib") + logging.warning( + "Visualization dependencies not found. Please install them with: pip install networkx matplotlib" + ) + class VisualizationType(str, Enum): """Types of visualizations supported by this module.""" + CALL_GRAPH = "call_graph" DEPENDENCY_GRAPH = "dependency_graph" BLAST_RADIUS = "blast_radius" @@ -36,81 +38,85 @@ class VisualizationType(str, Enum): ISSUES_HEATMAP = "issues_heatmap" PR_COMPARISON = "pr_comparison" + class OutputFormat(str, Enum): """Output formats for visualizations.""" + JSON = "json" PNG = "png" SVG = "svg" HTML = "html" DOT = "dot" + @dataclass class VisualizationConfig: """Configuration for visualization generation.""" + max_depth: int = 5 ignore_external: bool = True ignore_tests: bool = True node_size_base: int = 300 edge_width_base: float = 1.0 - filename_filter: Optional[List[str]] = None - symbol_filter: Optional[List[str]] = None + filename_filter: list[str] | None = None + symbol_filter: list[str] | None = None output_format: OutputFormat = OutputFormat.JSON - output_directory: Optional[str] = None + output_directory: str | None = None layout_algorithm: str = "spring" - highlight_nodes: List[str] = field(default_factory=list) + highlight_nodes: list[str] = field(default_factory=list) highlight_color: str = "#ff5555" - color_palette: Dict[str, str] = field(default_factory=lambda: { - "Function": "#a277ff", # Purple - "Class": "#ffca85", # Orange - "File": "#80CBC4", # Teal - "Module": "#81D4FA", # Light Blue - "Variable": "#B39DDB", # Light Purple - "Root": "#ef5350", # Red - "Warning": "#FFCA28", # Amber - "Error": "#EF5350", # Red - "Dead": "#78909C", # Gray - "External": "#B0BEC5", # Light Gray - }) + color_palette: dict[str, str] = field( + default_factory=lambda: { + "Function": "#a277ff", # Purple + "Class": "#ffca85", # Orange + "File": "#80CBC4", # Teal + "Module": "#81D4FA", # Light Blue + "Variable": "#B39DDB", # Light Purple + "Root": "#ef5350", # Red + "Warning": "#FFCA28", # Amber + "Error": "#EF5350", # Red + "Dead": "#78909C", # Gray + "External": "#B0BEC5", # Light Gray + } + ) + class BaseVisualizer: """ Base visualizer providing common functionality for different visualization types. - + This class implements the core operations needed for visualization, including graph creation, node and edge management, and output generation. """ - - def __init__( - self, - config: Optional[VisualizationConfig] = None - ): + + def __init__(self, config: VisualizationConfig | None = None): """ Initialize the BaseVisualizer. 
- + Args: config: Visualization configuration options """ self.config = config or VisualizationConfig() - + # Create visualization directory if specified if self.config.output_directory: os.makedirs(self.config.output_directory, exist_ok=True) - + # Initialize graph for visualization self.graph = nx.DiGraph() - + # Tracking current visualization self.current_visualization_type = None self.current_entity_name = None - + def _initialize_graph(self): """Initialize a fresh graph for visualization.""" self.graph = nx.DiGraph() - + def _add_node(self, node: Any, **attrs): """ Add a node to the visualization graph with attributes. - + Args: node: Node object to add **attrs: Node attributes @@ -118,10 +124,10 @@ def _add_node(self, node: Any, **attrs): # Skip if node already exists if self.graph.has_node(node): return - + # Generate node ID (memory address for unique identification) node_id = id(node) - + # Get node name if "name" in attrs: node_name = attrs["name"] @@ -131,11 +137,11 @@ def _add_node(self, node: Any, **attrs): node_name = str(node.path).split("/")[-1] else: node_name = str(node) - + # Determine node type and color node_type = node.__class__.__name__ color = attrs.get("color", self.config.color_palette.get(node_type, "#BBBBBB")) - + # Add node with attributes self.graph.add_node( node_id, @@ -143,15 +149,15 @@ def _add_node(self, node: Any, **attrs): name=node_name, type=node_type, color=color, - **attrs + **attrs, ) - + return node_id - + def _add_edge(self, source: Any, target: Any, **attrs): """ Add an edge to the visualization graph with attributes. - + Args: source: Source node target: Target node @@ -160,74 +166,81 @@ def _add_edge(self, source: Any, target: Any, **attrs): # Get node IDs source_id = id(source) target_id = id(target) - + # Add edge with attributes - self.graph.add_edge( - source_id, - target_id, - **attrs - ) - - def _generate_filename(self, visualization_type: VisualizationType, entity_name: str): + self.graph.add_edge(source_id, target_id, **attrs) + + def _generate_filename( + self, visualization_type: VisualizationType, entity_name: str + ): """ Generate a filename for the visualization. - + Args: visualization_type: Type of visualization entity_name: Name of the entity being visualized - + Returns: Generated filename """ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - sanitized_name = entity_name.replace("/", "_").replace("\\", "_").replace(".", "_") + sanitized_name = ( + entity_name.replace("/", "_").replace("\\", "_").replace(".", "_") + ) return f"{visualization_type.value}_{sanitized_name}_{timestamp}.{self.config.output_format.value}" - - def _save_visualization(self, visualization_type: VisualizationType, entity_name: str, data: Any): + + def _save_visualization( + self, visualization_type: VisualizationType, entity_name: str, data: Any + ): """ Save a visualization to file or return it. 
- + Args: visualization_type: Type of visualization entity_name: Name of the entity being visualized data: Visualization data to save - + Returns: Path to saved file or visualization data """ self.current_visualization_type = visualization_type self.current_entity_name = entity_name - + filename = self._generate_filename(visualization_type, entity_name) - + if self.config.output_directory: filepath = os.path.join(self.config.output_directory, filename) else: filepath = filename - + if self.config.output_format == OutputFormat.JSON: - with open(filepath, 'w') as f: + with open(filepath, "w") as f: json.dump(data, f, indent=2) elif self.config.output_format in [OutputFormat.PNG, OutputFormat.SVG]: # Save matplotlib figure - plt.savefig(filepath, format=self.config.output_format.value, bbox_inches='tight') + plt.savefig( + filepath, format=self.config.output_format.value, bbox_inches="tight" + ) plt.close() elif self.config.output_format == OutputFormat.DOT: # Save as DOT file for Graphviz try: from networkx.drawing.nx_agraph import write_dot + write_dot(self.graph, filepath) except ImportError: - logging.error("networkx.drawing.nx_agraph not available. Install pygraphviz for DOT format.") + logging.exception( + "networkx.drawing.nx_agraph not available. Install pygraphviz for DOT format." + ) return None - + logging.info(f"Visualization saved to {filepath}") return filepath - + def _convert_graph_to_json(self): """ Convert the networkx graph to a JSON-serializable dictionary. - + Returns: Dictionary representation of the graph """ @@ -240,19 +253,22 @@ def _convert_graph_to_json(self): "type": attrs.get("type", ""), "color": attrs.get("color", "#BBBBBB"), } - + # Add file path if available if "file_path" in attrs: node_data["file_path"] = attrs["file_path"] - + # Add other attributes for key, value in attrs.items(): if key not in ["name", "type", "color", "file_path", "original_node"]: - if isinstance(value, (str, int, float, bool, list, dict)) or value is None: + if ( + isinstance(value, str | int | float | bool | list | dict) + or value is None + ): node_data[key] = value - + nodes.append(node_data) - + edges = [] for source, target, attrs in self.graph.edges(data=True): # Create a serializable edge @@ -260,14 +276,17 @@ def _convert_graph_to_json(self): "source": source, "target": target, } - + # Add other attributes for key, value in attrs.items(): - if isinstance(value, (str, int, float, bool, list, dict)) or value is None: + if ( + isinstance(value, str | int | float | bool | list | dict) + or value is None + ): edge_data[key] = value - + edges.append(edge_data) - + return { "nodes": nodes, "edges": edges, @@ -277,18 +296,18 @@ def _convert_graph_to_json(self): "timestamp": datetime.now().isoformat(), "node_count": len(nodes), "edge_count": len(edges), - } + }, } - + def _plot_graph(self): """ Plot the graph using matplotlib. 
- + Returns: Matplotlib figure """ plt.figure(figsize=(12, 10)) - + # Extract node positions using specified layout algorithm if self.config.layout_algorithm == "spring": pos = nx.spring_layout(self.graph, seed=42) @@ -299,39 +318,43 @@ def _plot_graph(self): else: # Default to spring layout pos = nx.spring_layout(self.graph, seed=42) - + # Extract node colors - node_colors = [attrs.get("color", "#BBBBBB") for _, attrs in self.graph.nodes(data=True)] - + node_colors = [ + attrs.get("color", "#BBBBBB") for _, attrs in self.graph.nodes(data=True) + ] + # Extract node sizes (can be based on some metric) node_sizes = [self.config.node_size_base for _ in self.graph.nodes()] - + # Draw nodes nx.draw_networkx_nodes( - self.graph, pos, - node_color=node_colors, - node_size=node_sizes, - alpha=0.8 + self.graph, pos, node_color=node_colors, node_size=node_sizes, alpha=0.8 ) - + # Draw edges nx.draw_networkx_edges( - self.graph, pos, + self.graph, + pos, width=self.config.edge_width_base, alpha=0.6, arrows=True, - arrowsize=10 + arrowsize=10, ) - + # Draw labels nx.draw_networkx_labels( - self.graph, pos, - labels={node: attrs.get("name", "") for node, attrs in self.graph.nodes(data=True)}, + self.graph, + pos, + labels={ + node: attrs.get("name", "") + for node, attrs in self.graph.nodes(data=True) + }, font_size=8, - font_weight="bold" + font_weight="bold", ) - + plt.title(f"{self.current_visualization_type} - {self.current_entity_name}") plt.axis("off") - - return plt.gcf() \ No newline at end of file + + return plt.gcf() diff --git a/codegen-on-oss/codegen_on_oss/error_analyzer.py b/codegen-on-oss/codegen_on_oss/error_analyzer.py index 2c3355d18..a38bf9de0 100644 --- a/codegen-on-oss/codegen_on_oss/error_analyzer.py +++ b/codegen-on-oss/codegen_on_oss/error_analyzer.py @@ -7,29 +7,31 @@ dependencies, code quality, and more. """ -import os -import sys -import json -import time -import logging import argparse -import tempfile import datetime -import re +import json +import logging import math +import re +import sys +import tempfile +from typing import Any + import networkx as nx -from pathlib import Path -from typing import Dict, List, Set, Tuple, Any, Optional, Union, Callable -from collections import Counter, defaultdict -import matplotlib.pyplot as plt from rich.console import Console +from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TextColumn, + TimeElapsedColumn, +) from rich.table import Table -from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn try: - from codegen.sdk.core.codebase import Codebase from codegen.configs.models.codebase import CodebaseConfig from codegen.configs.models.secrets import SecretsConfig + from codegen.sdk.core.codebase import Codebase from codegen.shared.enums.programming_language import ProgrammingLanguage except ImportError: print("Codegen SDK not found. Please install it first.") @@ -38,8 +40,8 @@ # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) @@ -152,21 +154,27 @@ "count_lines", "calculate_maintainability_index", "get_maintainability_rank", - ] + ], } + class CodebaseAnalyzer: """ Comprehensive codebase analyzer using Codegen SDK. 
- + This class provides methods to analyze a codebase and extract detailed information about its structure, dependencies, code quality, and more. """ - - def __init__(self, repo_url: str = None, repo_path: str = None, language: str = None): + + def __init__( + self, + repo_url: str | None = None, + repo_path: str | None = None, + language: str | None = None, + ): """ Initialize the CodebaseAnalyzer. - + Args: repo_url: URL of the repository to analyze repo_path: Local path to the repository to analyze @@ -178,60 +186,66 @@ def __init__(self, repo_url: str = None, repo_path: str = None, language: str = self.codebase = None self.console = Console() self.results = {} - + # Initialize the codebase if repo_url: self._init_from_url(repo_url, language) elif repo_path: self._init_from_path(repo_path, language) - - def _init_from_url(self, repo_url: str, language: str = None): + + def _init_from_url(self, repo_url: str, language: str | None = None): """Initialize codebase from a repository URL.""" try: # Extract owner and repo name from URL - if repo_url.endswith('.git'): + if repo_url.endswith(".git"): repo_url = repo_url[:-4] - - parts = repo_url.rstrip('/').split('/') + + parts = repo_url.rstrip("/").split("/") repo_name = parts[-1] owner = parts[-2] repo_full_name = f"{owner}/{repo_name}" - + # Create a temporary directory for cloning tmp_dir = tempfile.mkdtemp(prefix="codebase_analyzer_") - + # Configure the codebase config = CodebaseConfig( debug=False, allow_external=True, py_resolve_syspath=True, ) - + secrets = SecretsConfig() - + # Initialize the codebase - self.console.print(f"[bold green]Initializing codebase from {repo_url}...[/bold green]") - + self.console.print( + f"[bold green]Initializing codebase from {repo_url}...[/bold green]" + ) + prog_lang = None if language: prog_lang = ProgrammingLanguage(language.upper()) - + self.codebase = Codebase.from_github( repo_full_name=repo_full_name, tmp_dir=tmp_dir, language=prog_lang, config=config, secrets=secrets, - full_history=True + full_history=True, + ) + + self.console.print( + f"[bold green]Successfully initialized codebase from {repo_url}[/bold green]" ) - - self.console.print(f"[bold green]Successfully initialized codebase from {repo_url}[/bold green]") - + except Exception as e: - self.console.print(f"[bold red]Error initializing codebase from URL: {e}[/bold red]") + self.console.print( + f"[bold red]Error initializing codebase from URL: {e}[/bold red]" + ) raise - - def _init_from_path(self, repo_path: str, language: str = None): + + def _init_from_path(self, repo_path: str, language: str | None = None): """Initialize codebase from a local repository path.""" try: # Configure the codebase @@ -240,48 +254,58 @@ def _init_from_path(self, repo_path: str, language: str = None): allow_external=True, py_resolve_syspath=True, ) - + secrets = SecretsConfig() - + # Initialize the codebase - self.console.print(f"[bold green]Initializing codebase from {repo_path}...[/bold green]") - + self.console.print( + f"[bold green]Initializing codebase from {repo_path}...[/bold green]" + ) + prog_lang = None if language: prog_lang = ProgrammingLanguage(language.upper()) - + self.codebase = Codebase( - repo_path=repo_path, - language=prog_lang, - config=config, - secrets=secrets + repo_path=repo_path, language=prog_lang, config=config, secrets=secrets ) - - self.console.print(f"[bold green]Successfully initialized codebase from {repo_path}[/bold green]") - + + self.console.print( + f"[bold green]Successfully initialized codebase from {repo_path}[/bold 
green]" + ) + except Exception as e: - self.console.print(f"[bold red]Error initializing codebase from path: {e}[/bold red]") + self.console.print( + f"[bold red]Error initializing codebase from path: {e}[/bold red]" + ) raise - - def analyze(self, categories: List[str] = None, output_format: str = "json", output_file: str = None): + + def analyze( + self, + categories: list[str] | None = None, + output_format: str = "json", + output_file: str | None = None, + ): """ Perform a comprehensive analysis of the codebase. - + Args: categories: List of categories to analyze. If None, all categories are analyzed. output_format: Format of the output (json, html, console) output_file: Path to the output file - + Returns: Dict containing the analysis results """ if not self.codebase: - raise ValueError("Codebase not initialized. Please initialize the codebase first.") - + raise ValueError( + "Codebase not initialized. Please initialize the codebase first." + ) + # If no categories specified, analyze all if not categories: categories = list(METRICS_CATEGORIES.keys()) - + # Initialize results dictionary self.results = { "metadata": { @@ -289,9 +313,9 @@ def analyze(self, categories: List[str] = None, output_format: str = "json", out "analysis_time": datetime.datetime.now().isoformat(), "language": str(self.codebase.ctx.programming_language), }, - "categories": {} + "categories": {}, } - + # Analyze each category with Progress( SpinnerColumn(), @@ -300,20 +324,24 @@ def analyze(self, categories: List[str] = None, output_format: str = "json", out TextColumn("[bold green]{task.completed}/{task.total}"), TimeElapsedColumn(), ) as progress: - task = progress.add_task("[bold green]Analyzing codebase...", total=len(categories)) - + task = progress.add_task( + "[bold green]Analyzing codebase...", total=len(categories) + ) + for category in categories: if category not in METRICS_CATEGORIES: - self.console.print(f"[bold yellow]Warning: Unknown category '{category}'. Skipping.[/bold yellow]") + self.console.print( + f"[bold yellow]Warning: Unknown category '{category}'. 
Skipping.[/bold yellow]" + ) progress.update(task, advance=1) continue - + self.console.print(f"[bold blue]Analyzing {category}...[/bold blue]") - + # Get the metrics for this category metrics = METRICS_CATEGORIES[category] category_results = {} - + # Run each metric for metric in metrics: try: @@ -322,78 +350,82 @@ def analyze(self, categories: List[str] = None, output_format: str = "json", out result = method() category_results[metric] = result else: - category_results[metric] = {"error": f"Method {metric} not implemented"} + category_results[metric] = { + "error": f"Method {metric} not implemented" + } except Exception as e: category_results[metric] = {"error": str(e)} - + # Add the results to the main results dictionary self.results["categories"][category] = category_results - + progress.update(task, advance=1) - + # Output the results if output_format == "json": if output_file: - with open(output_file, 'w') as f: + with open(output_file, "w") as f: json.dump(self.results, f, indent=2) - self.console.print(f"[bold green]Results saved to {output_file}[/bold green]") + self.console.print( + f"[bold green]Results saved to {output_file}[/bold green]" + ) else: return self.results elif output_format == "html": self._generate_html_report(output_file) elif output_format == "console": self._print_console_report() - + return self.results # # Codebase Structure Analysis Methods # - - def get_file_count(self) -> Dict[str, int]: + + def get_file_count(self) -> dict[str, int]: """Get the total number of files in the codebase.""" files = list(self.codebase.files) return { "total_files": len(files), - "source_files": len([f for f in files if not f.is_binary]) + "source_files": len([f for f in files if not f.is_binary]), } - - def get_files_by_language(self) -> Dict[str, int]: + + def get_files_by_language(self) -> dict[str, int]: """Get the distribution of files by language/extension.""" files = list(self.codebase.files) extensions = {} - + for file in files: if file.is_binary: continue - + ext = file.extension if not ext: ext = "(no extension)" - + if ext in extensions: extensions[ext] += 1 else: extensions[ext] = 1 - + return extensions - - def get_file_size_distribution(self) -> Dict[str, int]: + + def get_file_size_distribution(self) -> dict[str, int]: """Get the distribution of file sizes.""" files = list(self.codebase.files) size_ranges = { "small (< 1KB)": 0, "medium (1KB - 10KB)": 0, "large (10KB - 100KB)": 0, - "very large (> 100KB)": 0 + "very large (> 100KB)": 0, } - + for file in files: if file.is_binary: continue - + size = len(file.content) - + if size < 1024: size_ranges["small (< 1KB)"] += 1 elif size < 10240: @@ -402,124 +434,127 @@ def get_file_size_distribution(self) -> Dict[str, int]: size_ranges["large (10KB - 100KB)"] += 1 else: size_ranges["very large (> 100KB)"] += 1 - + return size_ranges - - def get_directory_structure(self) -> Dict[str, Any]: + + def get_directory_structure(self) -> dict[str, Any]: """Get the directory structure of the codebase.""" directories = {} - + for directory in self.codebase.directories: path = str(directory.path) - parent_path = str(directory.path.parent) if directory.path.parent != self.codebase.repo_path else "/" - + parent_path = ( + str(directory.path.parent) + if directory.path.parent != self.codebase.repo_path + else "/" + ) + if parent_path not in directories: directories[parent_path] = [] - + directories[parent_path].append({ "name": directory.path.name, "path": path, "files": len(directory.files), - "subdirectories": 
len(directory.subdirectories) + "subdirectories": len(directory.subdirectories), }) - + return directories - - def get_symbol_count(self) -> Dict[str, int]: + + def get_symbol_count(self) -> dict[str, int]: """Get the total count of symbols in the codebase.""" return { "total_symbols": len(list(self.codebase.symbols)), "classes": len(list(self.codebase.classes)), "functions": len(list(self.codebase.functions)), "global_vars": len(list(self.codebase.global_vars)), - "interfaces": len(list(self.codebase.interfaces)) + "interfaces": len(list(self.codebase.interfaces)), } - - def get_symbol_type_distribution(self) -> Dict[str, int]: + + def get_symbol_type_distribution(self) -> dict[str, int]: """Get the distribution of symbol types.""" symbols = list(self.codebase.symbols) distribution = {} - + for symbol in symbols: symbol_type = str(symbol.symbol_type) - + if symbol_type in distribution: distribution[symbol_type] += 1 else: distribution[symbol_type] = 1 - + return distribution - - def get_symbol_hierarchy(self) -> Dict[str, Any]: + + def get_symbol_hierarchy(self) -> dict[str, Any]: """Get the hierarchy of symbols in the codebase.""" classes = list(self.codebase.classes) hierarchy = {} - + for cls in classes: class_name = cls.name parent_classes = [] - + # Get parent classes if available if hasattr(cls, "parent_class_names"): parent_classes = cls.parent_class_names - + hierarchy[class_name] = { "parent_classes": parent_classes, "methods": [method.name for method in cls.methods], - "attributes": [attr.name for attr in cls.attributes] if hasattr(cls, "attributes") else [] + "attributes": [attr.name for attr in cls.attributes] + if hasattr(cls, "attributes") + else [], } - + return hierarchy - - def get_top_level_vs_nested_symbols(self) -> Dict[str, int]: + + def get_top_level_vs_nested_symbols(self) -> dict[str, int]: """Get the count of top-level vs nested symbols.""" symbols = list(self.codebase.symbols) top_level = 0 nested = 0 - + for symbol in symbols: if hasattr(symbol, "is_top_level") and symbol.is_top_level: top_level += 1 else: nested += 1 - - return { - "top_level": top_level, - "nested": nested - } - - def get_import_dependency_map(self) -> Dict[str, List[str]]: + + return {"top_level": top_level, "nested": nested} + + def get_import_dependency_map(self) -> dict[str, list[str]]: """Get a map of import dependencies.""" files = list(self.codebase.files) dependency_map = {} - + for file in files: if file.is_binary: continue - + file_path = file.file_path imports = [] - + for imp in file.imports: if hasattr(imp, "imported_symbol") and imp.imported_symbol: imported_symbol = imp.imported_symbol if hasattr(imported_symbol, "file") and imported_symbol.file: imports.append(imported_symbol.file.file_path) - + dependency_map[file_path] = imports - + return dependency_map - - def get_external_vs_internal_dependencies(self) -> Dict[str, int]: + + def get_external_vs_internal_dependencies(self) -> dict[str, int]: """Get the count of external vs internal dependencies.""" files = list(self.codebase.files) internal = 0 external = 0 - + for file in files: if file.is_binary: continue - + for imp in file.imports: if hasattr(imp, "imported_symbol") and imp.imported_symbol: imported_symbol = imp.imported_symbol @@ -529,197 +564,201 @@ def get_external_vs_internal_dependencies(self) -> Dict[str, int]: external += 1 else: external += 1 - - return { - "internal": internal, - "external": external - } - - def get_circular_imports(self) -> List[List[str]]: + + return {"internal": internal, "external": 
external} + + def get_circular_imports(self) -> list[list[str]]: """Detect circular imports in the codebase.""" files = list(self.codebase.files) dependency_map = {} - + # Build dependency graph for file in files: if file.is_binary: continue - + file_path = file.file_path imports = [] - + for imp in file.imports: if hasattr(imp, "imported_symbol") and imp.imported_symbol: imported_symbol = imp.imported_symbol if hasattr(imported_symbol, "file") and imported_symbol.file: imports.append(imported_symbol.file.file_path) - + dependency_map[file_path] = imports - + # Create a directed graph G = nx.DiGraph() - + # Add nodes and edges for file_path, imports in dependency_map.items(): G.add_node(file_path) for imp in imports: G.add_edge(file_path, imp) - + # Find cycles cycles = list(nx.simple_cycles(G)) - + return cycles - - def get_unused_imports(self) -> List[Dict[str, str]]: + + def get_unused_imports(self) -> list[dict[str, str]]: """Get a list of unused imports.""" files = list(self.codebase.files) unused_imports = [] - + for file in files: if file.is_binary: continue - + for imp in file.imports: if hasattr(imp, "usages") and len(imp.usages) == 0: unused_imports.append({ "file": file.file_path, - "import": imp.source + "import": imp.source, }) - + return unused_imports - - def get_module_coupling_metrics(self) -> Dict[str, float]: + + def get_module_coupling_metrics(self) -> dict[str, float]: """Calculate module coupling metrics.""" files = list(self.codebase.files) dependency_map = {} - + # Build dependency graph for file in files: if file.is_binary: continue - + file_path = file.file_path imports = [] - + for imp in file.imports: if hasattr(imp, "imported_symbol") and imp.imported_symbol: imported_symbol = imp.imported_symbol if hasattr(imported_symbol, "file") and imported_symbol.file: imports.append(imported_symbol.file.file_path) - + dependency_map[file_path] = imports - + # Calculate metrics total_files = len(dependency_map) total_dependencies = sum(len(deps) for deps in dependency_map.values()) - + if total_files == 0: return { "average_dependencies_per_file": 0, "max_dependencies": 0, - "coupling_factor": 0 + "coupling_factor": 0, } - - max_dependencies = max(len(deps) for deps in dependency_map.values()) if dependency_map else 0 - coupling_factor = total_dependencies / (total_files * (total_files - 1)) if total_files > 1 else 0 - + + max_dependencies = ( + max(len(deps) for deps in dependency_map.values()) if dependency_map else 0 + ) + coupling_factor = ( + total_dependencies / (total_files * (total_files - 1)) + if total_files > 1 + else 0 + ) + return { "average_dependencies_per_file": total_dependencies / total_files, "max_dependencies": max_dependencies, - "coupling_factor": coupling_factor + "coupling_factor": coupling_factor, } - - def get_module_cohesion_analysis(self) -> Dict[str, float]: + + def get_module_cohesion_analysis(self) -> dict[str, float]: """Analyze module cohesion.""" files = list(self.codebase.files) cohesion_metrics = {} - + for file in files: if file.is_binary: continue - + symbols = list(file.symbols) total_symbols = len(symbols) - + if total_symbols <= 1: continue - + # Count internal references internal_refs = 0 - + for symbol in symbols: if hasattr(symbol, "symbol_usages"): for usage in symbol.symbol_usages: if hasattr(usage, "file") and usage.file == file: internal_refs += 1 - + max_possible_refs = total_symbols * (total_symbols - 1) cohesion = internal_refs / max_possible_refs if max_possible_refs > 0 else 0 - + cohesion_metrics[file.file_path] = 
cohesion - + # Calculate average cohesion if cohesion_metrics: avg_cohesion = sum(cohesion_metrics.values()) / len(cohesion_metrics) else: avg_cohesion = 0 - - return { - "average_cohesion": avg_cohesion, - "file_cohesion": cohesion_metrics - } - - def get_package_structure(self) -> Dict[str, Any]: + + return {"average_cohesion": avg_cohesion, "file_cohesion": cohesion_metrics} + + def get_package_structure(self) -> dict[str, Any]: """Get the package structure of the codebase.""" directories = {} - + for directory in self.codebase.directories: path = str(directory.path) - parent_path = str(directory.path.parent) if directory.path.parent != self.codebase.repo_path else "/" - + parent_path = ( + str(directory.path.parent) + if directory.path.parent != self.codebase.repo_path + else "/" + ) + if parent_path not in directories: directories[parent_path] = [] - + # Check if this is a package (has __init__.py) is_package = any(f.name == "__init__.py" for f in directory.files) - + directories[parent_path].append({ "name": directory.path.name, "path": path, "is_package": is_package, "files": len(directory.files), - "subdirectories": len(directory.subdirectories) + "subdirectories": len(directory.subdirectories), }) - + return directories - - def get_module_dependency_graph(self) -> Dict[str, List[str]]: + + def get_module_dependency_graph(self) -> dict[str, list[str]]: """Get the module dependency graph.""" files = list(self.codebase.files) dependency_graph = {} - + for file in files: if file.is_binary: continue - + file_path = file.file_path imports = [] - + for imp in file.imports: if hasattr(imp, "imported_symbol") and imp.imported_symbol: imported_symbol = imp.imported_symbol if hasattr(imported_symbol, "file") and imported_symbol.file: imports.append(imported_symbol.file.file_path) - + dependency_graph[file_path] = imports - + return dependency_graph # # Symbol-Level Analysis Methods # - - def get_function_parameter_analysis(self) -> Dict[str, Any]: + + def get_function_parameter_analysis(self) -> dict[str, Any]: """Analyze function parameters.""" functions = list(self.codebase.functions) parameter_stats = { @@ -728,106 +767,116 @@ def get_function_parameter_analysis(self) -> Dict[str, Any]: "functions_with_no_parameters": 0, "functions_with_many_parameters": 0, # > 5 parameters "parameter_type_coverage": 0, - "functions_with_default_params": 0 + "functions_with_default_params": 0, } - + if not functions: return parameter_stats - + total_params = 0 functions_with_types = 0 functions_with_defaults = 0 - + for func in functions: params = func.parameters param_count = len(params) total_params += param_count - + if param_count == 0: parameter_stats["functions_with_no_parameters"] += 1 elif param_count > 5: parameter_stats["functions_with_many_parameters"] += 1 - + # Check for type annotations has_type_annotations = all(hasattr(p, "type") and p.type for p in params) if has_type_annotations: functions_with_types += 1 - + # Check for default values has_defaults = any(hasattr(p, "default") and p.default for p in params) if has_defaults: functions_with_defaults += 1 - + parameter_stats["total_parameters"] = total_params parameter_stats["avg_parameters_per_function"] = total_params / len(functions) - parameter_stats["parameter_type_coverage"] = functions_with_types / len(functions) if functions else 0 + parameter_stats["parameter_type_coverage"] = ( + functions_with_types / len(functions) if functions else 0 + ) parameter_stats["functions_with_default_params"] = functions_with_defaults - + return 
parameter_stats - - def get_return_type_analysis(self) -> Dict[str, Any]: + + def get_return_type_analysis(self) -> dict[str, Any]: """Analyze function return types.""" functions = list(self.codebase.functions) return_type_stats = { "functions_with_return_type": 0, "return_type_coverage": 0, - "common_return_types": {} + "common_return_types": {}, } - + if not functions: return return_type_stats - + functions_with_return_type = 0 return_types = {} - + for func in functions: if hasattr(func, "return_type") and func.return_type: functions_with_return_type += 1 - - return_type = str(func.return_type.source) if hasattr(func.return_type, "source") else str(func.return_type) - + + return_type = ( + str(func.return_type.source) + if hasattr(func.return_type, "source") + else str(func.return_type) + ) + if return_type in return_types: return_types[return_type] += 1 else: return_types[return_type] = 1 - + return_type_stats["functions_with_return_type"] = functions_with_return_type - return_type_stats["return_type_coverage"] = functions_with_return_type / len(functions) - + return_type_stats["return_type_coverage"] = functions_with_return_type / len( + functions + ) + # Get the most common return types sorted_types = sorted(return_types.items(), key=lambda x: x[1], reverse=True) - return_type_stats["common_return_types"] = dict(sorted_types[:10]) # Top 10 return types - + return_type_stats["common_return_types"] = dict( + sorted_types[:10] + ) # Top 10 return types + return return_type_stats - - def get_function_complexity_metrics(self) -> Dict[str, Any]: + + def get_function_complexity_metrics(self) -> dict[str, Any]: """Calculate function complexity metrics.""" functions = list(self.codebase.functions) complexity_metrics = { "avg_function_length": 0, "max_function_length": 0, "functions_by_complexity": { - "simple": 0, # < 10 lines - "moderate": 0, # 10-30 lines - "complex": 0, # 30-100 lines - "very_complex": 0 # > 100 lines - } + "simple": 0, # < 10 lines + "moderate": 0, # 10-30 lines + "complex": 0, # 30-100 lines + "very_complex": 0, # > 100 lines + }, } - + if not functions: return complexity_metrics - + total_length = 0 max_length = 0 - + for func in functions: # Calculate function length in lines func_source = func.source - func_lines = func_source.count('\n') + 1 - + func_lines = func_source.count("\n") + 1 + total_length += func_lines max_length = max(max_length, func_lines) - + # Categorize by complexity if func_lines < 10: complexity_metrics["functions_by_complexity"]["simple"] += 1 @@ -837,133 +886,141 @@ def get_function_complexity_metrics(self) -> Dict[str, Any]: complexity_metrics["functions_by_complexity"]["complex"] += 1 else: complexity_metrics["functions_by_complexity"]["very_complex"] += 1 - + complexity_metrics["avg_function_length"] = total_length / len(functions) complexity_metrics["max_function_length"] = max_length - + return complexity_metrics - - def get_call_site_tracking(self) -> Dict[str, Any]: + + def get_call_site_tracking(self) -> dict[str, Any]: """Track function call sites.""" functions = list(self.codebase.functions) call_site_stats = { "functions_with_no_calls": 0, "functions_with_many_calls": 0, # > 10 calls "avg_call_sites_per_function": 0, - "most_called_functions": [] + "most_called_functions": [], } - + if not functions: return call_site_stats - + function_calls = {} total_calls = 0 - + for func in functions: if hasattr(func, "call_sites"): call_count = len(func.call_sites) total_calls += call_count - + if call_count == 0: 
call_site_stats["functions_with_no_calls"] += 1 elif call_count > 10: call_site_stats["functions_with_many_calls"] += 1 - + function_calls[func.name] = call_count - + call_site_stats["avg_call_sites_per_function"] = total_calls / len(functions) - + # Get the most called functions - sorted_functions = sorted(function_calls.items(), key=lambda x: x[1], reverse=True) - call_site_stats["most_called_functions"] = [{"name": name, "calls": calls} for name, calls in sorted_functions[:10]] - + sorted_functions = sorted( + function_calls.items(), key=lambda x: x[1], reverse=True + ) + call_site_stats["most_called_functions"] = [ + {"name": name, "calls": calls} for name, calls in sorted_functions[:10] + ] + return call_site_stats - - def get_async_function_detection(self) -> Dict[str, Any]: + + def get_async_function_detection(self) -> dict[str, Any]: """Detect async functions.""" functions = list(self.codebase.functions) async_stats = { "total_async_functions": 0, "async_function_percentage": 0, - "async_functions": [] + "async_functions": [], } - + if not functions: return async_stats - + async_functions = [] - + for func in functions: if hasattr(func, "is_async") and func.is_async: async_functions.append({ "name": func.name, - "file": func.file.file_path if hasattr(func, "file") else "Unknown" + "file": func.file.file_path if hasattr(func, "file") else "Unknown", }) - + async_stats["total_async_functions"] = len(async_functions) async_stats["async_function_percentage"] = len(async_functions) / len(functions) async_stats["async_functions"] = async_functions - + return async_stats - - def get_function_overload_analysis(self) -> Dict[str, Any]: + + def get_function_overload_analysis(self) -> dict[str, Any]: """Analyze function overloads.""" functions = list(self.codebase.functions) overload_stats = { "total_overloaded_functions": 0, "overloaded_function_percentage": 0, - "overloaded_functions": [] + "overloaded_functions": [], } - + if not functions: return overload_stats - + overloaded_functions = [] function_names = {} - + for func in functions: name = func.name - + if name in function_names: function_names[name].append(func) else: function_names[name] = [func] - + for name, funcs in function_names.items(): if len(funcs) > 1: overloaded_functions.append({ "name": name, "overloads": len(funcs), - "file": funcs[0].file.file_path if hasattr(funcs[0], "file") else "Unknown" + "file": funcs[0].file.file_path + if hasattr(funcs[0], "file") + else "Unknown", }) - + overload_stats["total_overloaded_functions"] = len(overloaded_functions) - overload_stats["overloaded_function_percentage"] = len(overloaded_functions) / len(function_names) if function_names else 0 + overload_stats["overloaded_function_percentage"] = ( + len(overloaded_functions) / len(function_names) if function_names else 0 + ) overload_stats["overloaded_functions"] = overloaded_functions - + return overload_stats - - def get_inheritance_hierarchy(self) -> Dict[str, Any]: + + def get_inheritance_hierarchy(self) -> dict[str, Any]: """Get the inheritance hierarchy of classes.""" classes = list(self.codebase.classes) hierarchy = {} - + for cls in classes: class_name = cls.name parent_classes = [] - + # Get parent classes if available if hasattr(cls, "parent_class_names"): parent_classes = cls.parent_class_names - + hierarchy[class_name] = { "parent_classes": parent_classes, - "file": cls.file.file_path if hasattr(cls, "file") else "Unknown" + "file": cls.file.file_path if hasattr(cls, "file") else "Unknown", } - + # Build inheritance tree 
inheritance_tree = {} - + for class_name, info in hierarchy.items(): if not info["parent_classes"]: if class_name not in inheritance_tree: @@ -973,13 +1030,10 @@ def get_inheritance_hierarchy(self) -> Dict[str, Any]: if parent not in inheritance_tree: inheritance_tree[parent] = [] inheritance_tree[parent].append(class_name) - - return { - "class_hierarchy": hierarchy, - "inheritance_tree": inheritance_tree - } - - def get_method_analysis(self) -> Dict[str, Any]: + + return {"class_hierarchy": hierarchy, "inheritance_tree": inheritance_tree} + + def get_method_analysis(self) -> dict[str, Any]: """Analyze class methods.""" classes = list(self.codebase.classes) method_stats = { @@ -987,29 +1041,24 @@ def get_method_analysis(self) -> Dict[str, Any]: "avg_methods_per_class": 0, "classes_with_no_methods": 0, "classes_with_many_methods": 0, # > 10 methods - "method_types": { - "instance": 0, - "static": 0, - "class": 0, - "property": 0 - } + "method_types": {"instance": 0, "static": 0, "class": 0, "property": 0}, } - + if not classes: return method_stats - + total_methods = 0 - + for cls in classes: methods = cls.methods if hasattr(cls, "methods") else [] method_count = len(methods) total_methods += method_count - + if method_count == 0: method_stats["classes_with_no_methods"] += 1 elif method_count > 10: method_stats["classes_with_many_methods"] += 1 - + # Analyze method types for method in methods: if hasattr(method, "is_static") and method.is_static: @@ -1020,13 +1069,15 @@ def get_method_analysis(self) -> Dict[str, Any]: method_stats["method_types"]["property"] += 1 else: method_stats["method_types"]["instance"] += 1 - + method_stats["total_methods"] = total_methods - method_stats["avg_methods_per_class"] = total_methods / len(classes) if classes else 0 - + method_stats["avg_methods_per_class"] = ( + total_methods / len(classes) if classes else 0 + ) + return method_stats - - def get_attribute_analysis(self) -> Dict[str, Any]: + + def get_attribute_analysis(self) -> dict[str, Any]: """Analyze class attributes.""" classes = list(self.codebase.classes) attribute_stats = { @@ -1034,113 +1085,134 @@ def get_attribute_analysis(self) -> Dict[str, Any]: "avg_attributes_per_class": 0, "classes_with_no_attributes": 0, "classes_with_many_attributes": 0, # > 10 attributes - "attribute_types": {} + "attribute_types": {}, } - + if not classes: return attribute_stats - + total_attributes = 0 attribute_types = {} - + for cls in classes: attributes = cls.attributes if hasattr(cls, "attributes") else [] attr_count = len(attributes) total_attributes += attr_count - + if attr_count == 0: attribute_stats["classes_with_no_attributes"] += 1 elif attr_count > 10: attribute_stats["classes_with_many_attributes"] += 1 - + # Analyze attribute types for attr in attributes: if hasattr(attr, "type") and attr.type: - attr_type = str(attr.type.source) if hasattr(attr.type, "source") else str(attr.type) - + attr_type = ( + str(attr.type.source) + if hasattr(attr.type, "source") + else str(attr.type) + ) + if attr_type in attribute_types: attribute_types[attr_type] += 1 else: attribute_types[attr_type] = 1 - + attribute_stats["total_attributes"] = total_attributes - attribute_stats["avg_attributes_per_class"] = total_attributes / len(classes) if classes else 0 + attribute_stats["avg_attributes_per_class"] = ( + total_attributes / len(classes) if classes else 0 + ) attribute_stats["attribute_types"] = attribute_types - + return attribute_stats - - def get_constructor_analysis(self) -> Dict[str, Any]: + + def 
get_constructor_analysis(self) -> dict[str, Any]: """Analyze class constructors.""" classes = list(self.codebase.classes) constructor_stats = { "classes_with_constructor": 0, "constructor_percentage": 0, - "avg_constructor_params": 0 + "avg_constructor_params": 0, } - + if not classes: return constructor_stats - + classes_with_constructor = 0 total_constructor_params = 0 - + for cls in classes: constructor = None - + # Find constructor for method in cls.methods: if hasattr(method, "is_constructor") and method.is_constructor: constructor = method break - + if constructor: classes_with_constructor += 1 - param_count = len(constructor.parameters) if hasattr(constructor, "parameters") else 0 + param_count = ( + len(constructor.parameters) + if hasattr(constructor, "parameters") + else 0 + ) total_constructor_params += param_count - + constructor_stats["classes_with_constructor"] = classes_with_constructor - constructor_stats["constructor_percentage"] = classes_with_constructor / len(classes) - constructor_stats["avg_constructor_params"] = total_constructor_params / classes_with_constructor if classes_with_constructor else 0 - + constructor_stats["constructor_percentage"] = classes_with_constructor / len( + classes + ) + constructor_stats["avg_constructor_params"] = ( + total_constructor_params / classes_with_constructor + if classes_with_constructor + else 0 + ) + return constructor_stats - - def get_interface_implementation_verification(self) -> Dict[str, Any]: + + def get_interface_implementation_verification(self) -> dict[str, Any]: """Verify interface implementations.""" classes = list(self.codebase.classes) interfaces = list(self.codebase.interfaces) implementation_stats = { "total_interfaces": len(interfaces), "classes_implementing_interfaces": 0, - "interface_implementations": {} + "interface_implementations": {}, } - + if not interfaces or not classes: return implementation_stats - + # Map interfaces to implementing classes interface_implementations = {} - + for interface in interfaces: interface_name = interface.name implementing_classes = [] - + for cls in classes: - if hasattr(cls, "parent_class_names") and interface_name in cls.parent_class_names: + if ( + hasattr(cls, "parent_class_names") + and interface_name in cls.parent_class_names + ): implementing_classes.append(cls.name) - + interface_implementations[interface_name] = implementing_classes - + # Count classes implementing interfaces classes_implementing = set() for implementers in interface_implementations.values(): classes_implementing.update(implementers) - - implementation_stats["classes_implementing_interfaces"] = len(classes_implementing) + + implementation_stats["classes_implementing_interfaces"] = len( + classes_implementing + ) implementation_stats["interface_implementations"] = interface_implementations - + return implementation_stats - - def get_access_modifier_usage(self) -> Dict[str, Any]: + + def get_access_modifier_usage(self) -> dict[str, Any]: """Analyze access modifier usage.""" symbols = list(self.codebase.symbols) access_stats = { @@ -1148,9 +1220,9 @@ def get_access_modifier_usage(self) -> Dict[str, Any]: "private": 0, "protected": 0, "internal": 0, - "unknown": 0 + "unknown": 0, } - + for symbol in symbols: if hasattr(symbol, "is_private") and symbol.is_private: access_stats["private"] += 1 @@ -1162,242 +1234,254 @@ def get_access_modifier_usage(self) -> Dict[str, Any]: access_stats["public"] += 1 else: access_stats["unknown"] += 1 - + return access_stats # # Code Quality Analysis Methods # - - def 
get_unused_functions(self) -> List[Dict[str, str]]: + + def get_unused_functions(self) -> list[dict[str, str]]: """Get a list of unused functions.""" functions = list(self.codebase.functions) unused_functions = [] - + for func in functions: if hasattr(func, "call_sites") and len(func.call_sites) == 0: # Skip special methods like __init__, __str__, etc. if hasattr(func, "is_magic") and func.is_magic: continue - + # Skip entry points and main functions if func.name in ["main", "__main__"]: continue - + unused_functions.append({ "name": func.name, - "file": func.file.file_path if hasattr(func, "file") else "Unknown" + "file": func.file.file_path if hasattr(func, "file") else "Unknown", }) - + return unused_functions - - def get_unused_classes(self) -> List[Dict[str, str]]: + + def get_unused_classes(self) -> list[dict[str, str]]: """Get a list of unused classes.""" classes = list(self.codebase.classes) unused_classes = [] - + for cls in classes: if hasattr(cls, "symbol_usages") and len(cls.symbol_usages) == 0: unused_classes.append({ "name": cls.name, - "file": cls.file.file_path if hasattr(cls, "file") else "Unknown" + "file": cls.file.file_path if hasattr(cls, "file") else "Unknown", }) - + return unused_classes - - def get_unused_variables(self) -> List[Dict[str, str]]: + + def get_unused_variables(self) -> list[dict[str, str]]: """Get a list of unused variables.""" global_vars = list(self.codebase.global_vars) unused_vars = [] - + for var in global_vars: if hasattr(var, "symbol_usages") and len(var.symbol_usages) == 0: unused_vars.append({ "name": var.name, - "file": var.file.file_path if hasattr(var, "file") else "Unknown" + "file": var.file.file_path if hasattr(var, "file") else "Unknown", }) - + return unused_vars - - def get_unused_imports(self) -> List[Dict[str, str]]: + + def get_unused_imports(self) -> list[dict[str, str]]: """Get a list of unused imports.""" files = list(self.codebase.files) unused_imports = [] - + for file in files: if file.is_binary: continue - + for imp in file.imports: if hasattr(imp, "usages") and len(imp.usages) == 0: unused_imports.append({ "file": file.file_path, - "import": imp.source + "import": imp.source, }) - + return unused_imports - - def get_similar_function_detection(self) -> List[Dict[str, Any]]: + + def get_similar_function_detection(self) -> list[dict[str, Any]]: """Detect similar functions.""" functions = list(self.codebase.functions) similar_functions = [] - + # Group functions by name function_groups = {} - + for func in functions: name = func.name - + if name in function_groups: function_groups[name].append(func) else: function_groups[name] = [func] - + # Find similar functions for name, funcs in function_groups.items(): if len(funcs) > 1: similar_functions.append({ "name": name, "count": len(funcs), - "files": [func.file.file_path if hasattr(func, "file") else "Unknown" for func in funcs] + "files": [ + func.file.file_path if hasattr(func, "file") else "Unknown" + for func in funcs + ], }) - + return similar_functions - - def get_repeated_code_patterns(self) -> Dict[str, Any]: + + def get_repeated_code_patterns(self) -> dict[str, Any]: """Detect repeated code patterns.""" functions = list(self.codebase.functions) - + # This is a simplified implementation that looks for functions with similar structure # A more advanced implementation would use code clone detection algorithms - + # Group functions by length (in lines) functions_by_length = {} - + for func in functions: func_source = func.source - func_lines = func_source.count('\n') + 1 - + 
func_lines = func_source.count("\n") + 1 + if func_lines in functions_by_length: functions_by_length[func_lines].append(func) else: functions_by_length[func_lines] = [func] - + # Find potential code clones (functions with same length) potential_clones = {} - + for length, funcs in functions_by_length.items(): if len(funcs) > 1: potential_clones[length] = [func.name for func in funcs] - - return { - "potential_code_clones": potential_clones - } - - def get_refactoring_opportunities(self) -> Dict[str, Any]: + + return {"potential_code_clones": potential_clones} + + def get_refactoring_opportunities(self) -> dict[str, Any]: """Identify refactoring opportunities.""" refactoring_opportunities = { "long_functions": [], "large_classes": [], "high_coupling_files": [], - "low_cohesion_files": [] + "low_cohesion_files": [], } - + # Find long functions functions = list(self.codebase.functions) for func in functions: func_source = func.source - func_lines = func_source.count('\n') + 1 - + func_lines = func_source.count("\n") + 1 + if func_lines > 50: # Threshold for long functions refactoring_opportunities["long_functions"].append({ "name": func.name, "file": func.file.file_path if hasattr(func, "file") else "Unknown", - "lines": func_lines + "lines": func_lines, }) - + # Find large classes classes = list(self.codebase.classes) for cls in classes: methods = cls.methods if hasattr(cls, "methods") else [] attributes = cls.attributes if hasattr(cls, "attributes") else [] - + if len(methods) + len(attributes) > 20: # Threshold for large classes refactoring_opportunities["large_classes"].append({ "name": cls.name, "file": cls.file.file_path if hasattr(cls, "file") else "Unknown", "methods": len(methods), - "attributes": len(attributes) + "attributes": len(attributes), }) - + # Find high coupling files files = list(self.codebase.files) for file in files: if file.is_binary: continue - + imports = file.imports if len(imports) > 15: # Threshold for high coupling refactoring_opportunities["high_coupling_files"].append({ "file": file.file_path, - "imports": len(imports) + "imports": len(imports), }) - + # Find low cohesion files cohesion_metrics = self.get_module_cohesion_analysis() file_cohesion = cohesion_metrics.get("file_cohesion", {}) - + for file_path, cohesion in file_cohesion.items(): if cohesion < 0.3: # Threshold for low cohesion refactoring_opportunities["low_cohesion_files"].append({ "file": file_path, - "cohesion": cohesion + "cohesion": cohesion, }) - + return refactoring_opportunities - - def calculate_cyclomatic_complexity(self) -> Dict[str, Any]: + + def calculate_cyclomatic_complexity(self) -> dict[str, Any]: """Calculate cyclomatic complexity for functions.""" functions = list(self.codebase.functions) complexity_results = { "avg_complexity": 0, "max_complexity": 0, "complexity_distribution": { - "low": 0, # 1-5 - "moderate": 0, # 6-10 - "high": 0, # 11-20 - "very_high": 0 # > 20 + "low": 0, # 1-5 + "moderate": 0, # 6-10 + "high": 0, # 11-20 + "very_high": 0, # > 20 }, - "complex_functions": [] + "complex_functions": [], } - + if not functions: return complexity_results - + total_complexity = 0 max_complexity = 0 complex_functions = [] - + for func in functions: # A simple approximation of cyclomatic complexity # In a real implementation, we would parse the AST and count decision points source = func.source - + # Count decision points if_count = source.count("if ") + source.count("elif ") for_count = source.count("for ") while_count = source.count("while ") - case_count = source.count("case ") 
+ source.count("switch ") + source.count("match ") + case_count = ( + source.count("case ") + source.count("switch ") + source.count("match ") + ) catch_count = source.count("catch ") + source.count("except ") and_count = source.count(" && ") + source.count(" and ") or_count = source.count(" || ") + source.count(" or ") - + # Calculate complexity - complexity = 1 + if_count + for_count + while_count + case_count + catch_count + and_count + or_count - + complexity = ( + 1 + + if_count + + for_count + + while_count + + case_count + + catch_count + + and_count + + or_count + ) + total_complexity += complexity max_complexity = max(max_complexity, complexity) - + # Categorize complexity if complexity <= 5: complexity_results["complexity_distribution"]["low"] += 1 @@ -1407,26 +1491,28 @@ def calculate_cyclomatic_complexity(self) -> Dict[str, Any]: complexity_results["complexity_distribution"]["high"] += 1 else: complexity_results["complexity_distribution"]["very_high"] += 1 - + # Track complex functions if complexity > 10: complex_functions.append({ "name": func.name, "file": func.file.file_path if hasattr(func, "file") else "Unknown", - "complexity": complexity + "complexity": complexity, }) - + complexity_results["avg_complexity"] = total_complexity / len(functions) complexity_results["max_complexity"] = max_complexity - complexity_results["complex_functions"] = sorted(complex_functions, key=lambda x: x["complexity"], reverse=True)[:10] # Top 10 most complex - + complexity_results["complex_functions"] = sorted( + complex_functions, key=lambda x: x["complexity"], reverse=True + )[:10] # Top 10 most complex + return complexity_results - - def cc_rank(self) -> Dict[str, str]: + + def cc_rank(self) -> dict[str, str]: """Rank the codebase based on cyclomatic complexity.""" complexity_results = self.calculate_cyclomatic_complexity() avg_complexity = complexity_results["avg_complexity"] - + if avg_complexity < 5: rank = "A" description = "Excellent: Low complexity, highly maintainable code" @@ -1435,39 +1521,75 @@ def cc_rank(self) -> Dict[str, str]: description = "Good: Moderate complexity, maintainable code" elif avg_complexity < 15: rank = "C" - description = "Fair: Moderate to high complexity, some maintenance challenges" + description = ( + "Fair: Moderate to high complexity, some maintenance challenges" + ) elif avg_complexity < 20: rank = "D" description = "Poor: High complexity, difficult to maintain" else: rank = "F" - description = "Very Poor: Very high complexity, extremely difficult to maintain" - + description = ( + "Very Poor: Very high complexity, extremely difficult to maintain" + ) + return { "rank": rank, "description": description, - "avg_complexity": avg_complexity + "avg_complexity": avg_complexity, } - - def get_operators_and_operands(self) -> Dict[str, Any]: + + def get_operators_and_operands(self) -> dict[str, Any]: """Get operators and operands for Halstead metrics.""" files = list(self.codebase.files) - + # Define common operators - operators = ["+", "-", "*", "/", "%", "=", "==", "!=", "<", ">", "<=", ">=", - "&&", "||", "!", "&", "|", "^", "~", "<<", ">>", "++", "--", - "+=", "-=", "*=", "/=", "%=", "&=", "|=", "^=", "<<=", ">>="] - + operators = [ + "+", + "-", + "*", + "/", + "%", + "=", + "==", + "!=", + "<", + ">", + "<=", + ">=", + "&&", + "||", + "!", + "&", + "|", + "^", + "~", + "<<", + ">>", + "++", + "--", + "+=", + "-=", + "*=", + "/=", + "%=", + "&=", + "|=", + "^=", + "<<=", + ">>=", + ] + # Count operators and operands operator_count = {} operand_count = {} - 
+ for file in files: if file.is_binary: continue - + content = file.content - + # Count operators for op in operators: count = content.count(op) @@ -1476,37 +1598,62 @@ def get_operators_and_operands(self) -> Dict[str, Any]: operator_count[op] += count else: operator_count[op] = count - + # Simplified operand counting (this is a rough approximation) # In a real implementation, we would parse the AST and extract identifiers - words = re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', content) + words = re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", content) for word in words: - if word not in ["if", "else", "for", "while", "return", "break", "continue", - "class", "def", "function", "import", "from", "as", "try", - "except", "finally", "with", "in", "is", "not", "and", "or"]: + if word not in [ + "if", + "else", + "for", + "while", + "return", + "break", + "continue", + "class", + "def", + "function", + "import", + "from", + "as", + "try", + "except", + "finally", + "with", + "in", + "is", + "not", + "and", + "or", + ]: if word in operand_count: operand_count[word] += 1 else: operand_count[word] = 1 - + return { "unique_operators": len(operator_count), "total_operators": sum(operator_count.values()), "unique_operands": len(operand_count), "total_operands": sum(operand_count.values()), - "top_operators": dict(sorted(operator_count.items(), key=lambda x: x[1], reverse=True)[:10]), - "top_operands": dict(sorted(operand_count.items(), key=lambda x: x[1], reverse=True)[:10]) + "top_operators": dict( + sorted(operator_count.items(), key=lambda x: x[1], reverse=True)[:10] + ), + "top_operands": dict( + sorted(operand_count.items(), key=lambda x: x[1], reverse=True)[:10] + ), } - - def calculate_halstead_volume(self) -> Dict[str, float]: + + def calculate_halstead_volume(self) -> dict[str, float]: """Calculate Halstead volume metrics.""" operators_and_operands = self.get_operators_and_operands() - + n1 = operators_and_operands["unique_operators"] n2 = operators_and_operands["unique_operands"] N1 = operators_and_operands["total_operators"] N2 = operators_and_operands["total_operands"] - + # Calculate Halstead metrics vocabulary = n1 + n2 length = N1 + N2 @@ -1514,8 +1661,10 @@ def calculate_halstead_volume(self) -> Dict[str, float]: difficulty = (n1 / 2) * (N2 / n2) if n2 > 0 else 0 effort = volume * difficulty time = effort / 18 # Time in seconds (18 is a constant from empirical studies) - bugs = volume / 3000 # Estimated bugs (3000 is a constant from empirical studies) - + bugs = ( + volume / 3000 + ) # Estimated bugs (3000 is a constant from empirical studies) + return { "vocabulary": vocabulary, "length": length, @@ -1523,71 +1672,80 @@ def calculate_halstead_volume(self) -> Dict[str, float]: "difficulty": difficulty, "effort": effort, "time": time, # in seconds - "bugs": bugs + "bugs": bugs, } - - def count_lines(self) -> Dict[str, int]: + + def count_lines(self) -> dict[str, int]: """Count lines of code.""" files = list(self.codebase.files) - + total_lines = 0 code_lines = 0 comment_lines = 0 blank_lines = 0 - + for file in files: if file.is_binary: continue - + content = file.content - lines = content.split('\n') - + lines = content.split("\n") + total_lines += len(lines) - + for line in lines: line = line.strip() - + if not line: blank_lines += 1 - elif line.startswith('#') or line.startswith('//') or line.startswith('/*') or line.startswith('*'): + elif ( + line.startswith("#") + or line.startswith("//") + or line.startswith("/*") + or line.startswith("*") + ): comment_lines += 1 else: code_lines += 1 - 
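        # Note: the classification above is purely lexical. Inline comments
        # ("x = 1  # note") and the bodies of triple-quoted docstrings are
        # counted as code lines, so the comment_ratio reported below
        # under-counts documentation written as docstrings.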
+ return { "total_lines": total_lines, "code_lines": code_lines, "comment_lines": comment_lines, "blank_lines": blank_lines, - "comment_ratio": comment_lines / code_lines if code_lines > 0 else 0 + "comment_ratio": comment_lines / code_lines if code_lines > 0 else 0, } - - def calculate_maintainability_index(self) -> Dict[str, float]: + + def calculate_maintainability_index(self) -> dict[str, float]: """Calculate maintainability index.""" halstead = self.calculate_halstead_volume() complexity = self.calculate_cyclomatic_complexity() lines = self.count_lines() - + # Calculate maintainability index # MI = 171 - 5.2 * ln(V) - 0.23 * CC - 16.2 * ln(LOC) volume = halstead["volume"] avg_complexity = complexity["avg_complexity"] loc = lines["code_lines"] - - mi = 171 - 5.2 * math.log(volume) - 0.23 * avg_complexity - 16.2 * math.log(loc) if volume > 0 and loc > 0 else 0 - + + mi = ( + 171 - 5.2 * math.log(volume) - 0.23 * avg_complexity - 16.2 * math.log(loc) + if volume > 0 and loc > 0 + else 0 + ) + # Normalize to 0-100 scale normalized_mi = max(0, min(100, mi * 100 / 171)) - + return { "maintainability_index": mi, - "normalized_maintainability_index": normalized_mi + "normalized_maintainability_index": normalized_mi, } - - def get_maintainability_rank(self) -> Dict[str, str]: + + def get_maintainability_rank(self) -> dict[str, str]: """Rank the codebase based on maintainability index.""" mi = self.calculate_maintainability_index()["normalized_maintainability_index"] - + if mi >= 85: rank = "A" description = "Highly maintainable" @@ -1603,68 +1761,64 @@ def get_maintainability_rank(self) -> Dict[str, str]: else: rank = "F" description = "Very difficult to maintain" - - return { - "rank": rank, - "description": description, - "maintainability_index": mi - } - - def get_cognitive_complexity(self) -> Dict[str, Any]: + + return {"rank": rank, "description": description, "maintainability_index": mi} + + def get_cognitive_complexity(self) -> dict[str, Any]: """Calculate cognitive complexity for functions.""" functions = list(self.codebase.functions) complexity_results = { "avg_complexity": 0, "max_complexity": 0, "complexity_distribution": { - "low": 0, # 0-5 - "moderate": 0, # 6-10 - "high": 0, # 11-20 - "very_high": 0 # > 20 + "low": 0, # 0-5 + "moderate": 0, # 6-10 + "high": 0, # 11-20 + "very_high": 0, # > 20 }, - "complex_functions": [] + "complex_functions": [], } - + if not functions: return complexity_results - + total_complexity = 0 max_complexity = 0 complex_functions = [] - + for func in functions: # A simple approximation of cognitive complexity # In a real implementation, we would parse the AST and analyze control flow source = func.source - + # Count decision points with nesting nesting_level = 0 cognitive_complexity = 0 - - lines = source.split('\n') + + lines = source.split("\n") for line in lines: line = line.strip() - + # Increase nesting level - if re.search(r'\b(if|for|while|switch|case|catch|try)\b', line): + if re.search(r"\b(if|for|while|switch|case|catch|try)\b", line): cognitive_complexity += 1 + nesting_level nesting_level += 1 - + # Decrease nesting level - if line.startswith('}') or line.endswith(':'): + if line.startswith("}") or line.endswith(":"): nesting_level = max(0, nesting_level - 1) - + # Add complexity for boolean operators cognitive_complexity += line.count(" && ") + line.count(" and ") cognitive_complexity += line.count(" || ") + line.count(" or ") - + # Add complexity for jumps - if re.search(r'\b(break|continue|goto|return)\b', line): + if 
re.search(r"\b(break|continue|goto|return)\b", line): cognitive_complexity += 1 - + total_complexity += cognitive_complexity max_complexity = max(max_complexity, cognitive_complexity) - + # Categorize complexity if cognitive_complexity <= 5: complexity_results["complexity_distribution"]["low"] += 1 @@ -1674,65 +1828,69 @@ def get_cognitive_complexity(self) -> Dict[str, Any]: complexity_results["complexity_distribution"]["high"] += 1 else: complexity_results["complexity_distribution"]["very_high"] += 1 - + # Track complex functions if cognitive_complexity > 10: complex_functions.append({ "name": func.name, "file": func.file.file_path if hasattr(func, "file") else "Unknown", - "complexity": cognitive_complexity + "complexity": cognitive_complexity, }) - + complexity_results["avg_complexity"] = total_complexity / len(functions) complexity_results["max_complexity"] = max_complexity - complexity_results["complex_functions"] = sorted(complex_functions, key=lambda x: x["complexity"], reverse=True)[:10] # Top 10 most complex - + complexity_results["complex_functions"] = sorted( + complex_functions, key=lambda x: x["complexity"], reverse=True + )[:10] # Top 10 most complex + return complexity_results - - def get_nesting_depth_analysis(self) -> Dict[str, Any]: + + def get_nesting_depth_analysis(self) -> dict[str, Any]: """Analyze nesting depth in functions.""" functions = list(self.codebase.functions) nesting_results = { "avg_max_nesting": 0, "max_nesting": 0, "nesting_distribution": { - "low": 0, # 0-2 - "moderate": 0, # 3-4 - "high": 0, # 5-6 - "very_high": 0 # > 6 + "low": 0, # 0-2 + "moderate": 0, # 3-4 + "high": 0, # 5-6 + "very_high": 0, # > 6 }, - "deeply_nested_functions": [] + "deeply_nested_functions": [], } - + if not functions: return nesting_results - + total_max_nesting = 0 max_nesting_overall = 0 deeply_nested_functions = [] - + for func in functions: source = func.source - lines = source.split('\n') - + lines = source.split("\n") + max_nesting = 0 current_nesting = 0 - + for line in lines: line = line.strip() - + # Increase nesting level - if re.search(r'\b(if|for|while|switch|case|catch|try)\b', line) and not line.startswith('}'): + if re.search( + r"\b(if|for|while|switch|case|catch|try)\b", line + ) and not line.startswith("}"): current_nesting += 1 max_nesting = max(max_nesting, current_nesting) - + # Decrease nesting level - if line.startswith('}'): + if line.startswith("}"): current_nesting = max(0, current_nesting - 1) - + total_max_nesting += max_nesting max_nesting_overall = max(max_nesting_overall, max_nesting) - + # Categorize nesting if max_nesting <= 2: nesting_results["nesting_distribution"]["low"] += 1 @@ -1742,50 +1900,52 @@ def get_nesting_depth_analysis(self) -> Dict[str, Any]: nesting_results["nesting_distribution"]["high"] += 1 else: nesting_results["nesting_distribution"]["very_high"] += 1 - + # Track deeply nested functions if max_nesting > 4: deeply_nested_functions.append({ "name": func.name, "file": func.file.file_path if hasattr(func, "file") else "Unknown", - "max_nesting": max_nesting + "max_nesting": max_nesting, }) - + nesting_results["avg_max_nesting"] = total_max_nesting / len(functions) nesting_results["max_nesting"] = max_nesting_overall - nesting_results["deeply_nested_functions"] = sorted(deeply_nested_functions, key=lambda x: x["max_nesting"], reverse=True)[:10] # Top 10 most nested - + nesting_results["deeply_nested_functions"] = sorted( + deeply_nested_functions, key=lambda x: x["max_nesting"], reverse=True + )[:10] # Top 10 most nested + return 
nesting_results - - def get_function_size_metrics(self) -> Dict[str, Any]: + + def get_function_size_metrics(self) -> dict[str, Any]: """Get function size metrics.""" functions = list(self.codebase.functions) size_metrics = { "avg_function_length": 0, "max_function_length": 0, "function_size_distribution": { - "small": 0, # < 10 lines - "medium": 0, # 10-30 lines - "large": 0, # 30-100 lines - "very_large": 0 # > 100 lines + "small": 0, # < 10 lines + "medium": 0, # 10-30 lines + "large": 0, # 30-100 lines + "very_large": 0, # > 100 lines }, - "largest_functions": [] + "largest_functions": [], } - + if not functions: return size_metrics - + total_length = 0 max_length = 0 largest_functions = [] - + for func in functions: func_source = func.source - func_lines = func_source.count('\n') + 1 - + func_lines = func_source.count("\n") + 1 + total_length += func_lines max_length = max(max_length, func_lines) - + # Categorize by size if func_lines < 10: size_metrics["function_size_distribution"]["small"] += 1 @@ -1795,30 +1955,32 @@ def get_function_size_metrics(self) -> Dict[str, Any]: size_metrics["function_size_distribution"]["large"] += 1 else: size_metrics["function_size_distribution"]["very_large"] += 1 - + # Track large functions if func_lines > 30: largest_functions.append({ "name": func.name, "file": func.file.file_path if hasattr(func, "file") else "Unknown", - "lines": func_lines + "lines": func_lines, }) - + size_metrics["avg_function_length"] = total_length / len(functions) size_metrics["max_function_length"] = max_length - size_metrics["largest_functions"] = sorted(largest_functions, key=lambda x: x["lines"], reverse=True)[:10] # Top 10 largest - + size_metrics["largest_functions"] = sorted( + largest_functions, key=lambda x: x["lines"], reverse=True + )[:10] # Top 10 largest + return size_metrics # # Visualization and Output Methods # - + def _generate_html_report(self, output_file: str) -> None: """Generate an HTML report of the analysis results.""" if not output_file: output_file = "codebase_analysis_report.html" - + # Simple HTML template html = f""" @@ -1847,14 +2009,14 @@ def _generate_html_report(self, output_file: str) -> None:

        Language: {self.results["metadata"]["language"]}

        """
-
+
        # Add each category
        for category, metrics in self.results["categories"].items():
            html += f"""

            {category.replace("_", " ").title()}

            """
-
+
            for metric_name, metric_value in metrics.items():
                html += f"""
@@ -1862,82 +2024,96 @@ def _generate_html_report(self, output_file: str) -> None:
                {json.dumps(metric_value, indent=2)}
                """
-
+
            html += "
    " - + html += """ """ - + with open(output_file, "w") as f: f.write(html) - - self.console.print(f"[bold green]HTML report saved to {output_file}[/bold green]") - + + self.console.print( + f"[bold green]HTML report saved to {output_file}[/bold green]" + ) + def _print_console_report(self) -> None: """Print a summary report to the console.""" - self.console.print(f"[bold blue]Codebase Analysis Report for {self.results['metadata']['repo_name']}[/bold blue]") - self.console.print(f"[bold]Analysis Time:[/bold] {self.results['metadata']['analysis_time']}") - self.console.print(f"[bold]Language:[/bold] {self.results['metadata']['language']}") - + self.console.print( + f"[bold blue]Codebase Analysis Report for {self.results['metadata']['repo_name']}[/bold blue]" + ) + self.console.print( + f"[bold]Analysis Time:[/bold] {self.results['metadata']['analysis_time']}" + ) + self.console.print( + f"[bold]Language:[/bold] {self.results['metadata']['language']}" + ) + for category, metrics in self.results["categories"].items(): - self.console.print(f"\n[bold green]{category.replace('_', ' ').title()}[/bold green]") - + self.console.print( + f"\n[bold green]{category.replace('_', ' ').title()}[/bold green]" + ) + for metric_name, metric_value in metrics.items(): - self.console.print(f"[bold]{metric_name.replace('_', ' ').title()}:[/bold]") - + self.console.print( + f"[bold]{metric_name.replace('_', ' ').title()}:[/bold]" + ) + if isinstance(metric_value, dict): table = Table(show_header=True) table.add_column("Key") table.add_column("Value") - + for k, v in metric_value.items(): if isinstance(v, dict): table.add_row(k, str(v)) else: table.add_row(str(k), str(v)) - + self.console.print(table) elif isinstance(metric_value, list): if len(metric_value) > 0 and isinstance(metric_value[0], dict): if len(metric_value) > 0: table = Table(show_header=True) - for key in metric_value[0].keys(): + for key in metric_value[0]: table.add_column(key) - + for item in metric_value[:10]: # Show only first 10 items table.add_row(*[str(v) for v in item.values()]) - + self.console.print(table) if len(metric_value) > 10: - self.console.print(f"... and {len(metric_value) - 10} more items") + self.console.print( + f"... 
and {len(metric_value) - 10} more items" + ) else: self.console.print(str(metric_value)) else: self.console.print(str(metric_value)) - - def get_monthly_commits(self) -> Dict[str, int]: + + def get_monthly_commits(self) -> dict[str, int]: """Get the number of commits per month.""" try: # Get commit history commits = list(self.codebase.github.repo.get_commits()) - + # Group commits by month commits_by_month = {} - + for commit in commits: date = commit.commit.author.date month_key = f"{date.year}-{date.month:02d}" - + if month_key in commits_by_month: commits_by_month[month_key] += 1 else: commits_by_month[month_key] = 1 - + # Sort by month sorted_commits = dict(sorted(commits_by_month.items())) - + return sorted_commits except Exception as e: return {"error": str(e)} @@ -1946,46 +2122,59 @@ def get_monthly_commits(self) -> Dict[str, int]: def main(): """Main entry point for the codebase analyzer.""" parser = argparse.ArgumentParser(description="Comprehensive Codebase Analyzer") - + # Repository source source_group = parser.add_mutually_exclusive_group(required=True) source_group.add_argument("--repo-url", help="URL of the repository to analyze") - source_group.add_argument("--repo-path", help="Local path to the repository to analyze") - + source_group.add_argument( + "--repo-path", help="Local path to the repository to analyze" + ) + # Analysis options - parser.add_argument("--language", help="Programming language of the codebase (auto-detected if not provided)") - parser.add_argument("--categories", nargs="+", help="Categories to analyze (default: all)") - + parser.add_argument( + "--language", + help="Programming language of the codebase (auto-detected if not provided)", + ) + parser.add_argument( + "--categories", nargs="+", help="Categories to analyze (default: all)" + ) + # Output options - parser.add_argument("--output-format", choices=["json", "html", "console"], default="console", help="Output format") + parser.add_argument( + "--output-format", + choices=["json", "html", "console"], + default="console", + help="Output format", + ) parser.add_argument("--output-file", help="Path to the output file") - + args = parser.parse_args() - + try: # Initialize the analyzer analyzer = CodebaseAnalyzer( - repo_url=args.repo_url, - repo_path=args.repo_path, - language=args.language + repo_url=args.repo_url, repo_path=args.repo_path, language=args.language ) - + # Perform the analysis - results = analyzer.analyze( + analyzer.analyze( categories=args.categories, output_format=args.output_format, - output_file=args.output_file + output_file=args.output_file, ) - + # Print success message if args.output_format == "json" and args.output_file: print(f"Analysis results saved to {args.output_file}") elif args.output_format == "html": - print(f"HTML report saved to {args.output_file or 'codebase_analysis_report.html'}") - + print( + f"HTML report saved to {args.output_file or 'codebase_analysis_report.html'}" + ) + except Exception as e: print(f"Error: {e}") import traceback + traceback.print_exc() sys.exit(1) diff --git a/codegen-on-oss/tests/unit/analyzers/test_diff_lite.py b/codegen-on-oss/tests/unit/analyzers/test_diff_lite.py new file mode 100644 index 000000000..b022d5412 --- /dev/null +++ b/codegen-on-oss/tests/unit/analyzers/test_diff_lite.py @@ -0,0 +1,129 @@ +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +from watchfiles import Change + +from codegen_on_oss.analyzers.diff_lite import ChangeType, DiffLite + + +class TestChangeType(unittest.TestCase): + 
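    # The tests below cover both conversion helpers on ChangeType:
    # watchfiles.Change values map onto enum members (added -> Added,
    # deleted -> Removed, modified -> Modified), single-letter git status
    # codes ("A", "M", "D", "R") map onto the same enum, and unrecognised
    # inputs raise ValueError in both cases.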
def test_from_watch_change_type_added(self): + self.assertEqual( + ChangeType.from_watch_change_type(Change.added), ChangeType.Added + ) + + def test_from_watch_change_type_deleted(self): + self.assertEqual( + ChangeType.from_watch_change_type(Change.deleted), ChangeType.Removed + ) + + def test_from_watch_change_type_modified(self): + self.assertEqual( + ChangeType.from_watch_change_type(Change.modified), ChangeType.Modified + ) + + def test_from_watch_change_type_invalid(self): + # Create a mock Change that doesn't match any of the expected values + invalid_change = MagicMock() + with self.assertRaises(ValueError): + ChangeType.from_watch_change_type(invalid_change) + + def test_from_git_change_type_modified(self): + self.assertEqual(ChangeType.from_git_change_type("M"), ChangeType.Modified) + + def test_from_git_change_type_removed(self): + self.assertEqual(ChangeType.from_git_change_type("D"), ChangeType.Removed) + + def test_from_git_change_type_renamed(self): + self.assertEqual(ChangeType.from_git_change_type("R"), ChangeType.Renamed) + + def test_from_git_change_type_added(self): + self.assertEqual(ChangeType.from_git_change_type("A"), ChangeType.Added) + + def test_from_git_change_type_invalid(self): + with self.assertRaises(ValueError): + ChangeType.from_git_change_type("X") + + +class TestDiffLite(unittest.TestCase): + def test_from_watch_change(self): + path = "test/path.py" + diff = DiffLite.from_watch_change(Change.added, path) + + self.assertEqual(diff.change_type, ChangeType.Added) + self.assertEqual(diff.path, Path(path)) + self.assertIsNone(diff.rename_from) + self.assertIsNone(diff.rename_to) + self.assertIsNone(diff.old_content) + + @patch("git.Diff") + def test_from_git_diff_modified(self, mock_diff): + mock_diff.change_type = "M" + mock_diff.a_path = "test/path.py" + mock_diff.rename_from = None + mock_diff.rename_to = None + + # Mock the blob and data stream + mock_blob = MagicMock() + mock_blob.data_stream.read.return_value = b"old content" + mock_diff.a_blob = mock_blob + + diff = DiffLite.from_git_diff(mock_diff) + + self.assertEqual(diff.change_type, ChangeType.Modified) + self.assertEqual(diff.path, Path("test/path.py")) + self.assertIsNone(diff.rename_from) + self.assertIsNone(diff.rename_to) + self.assertEqual(diff.old_content, b"old content") + + @patch("git.Diff") + def test_from_git_diff_renamed(self, mock_diff): + mock_diff.change_type = "R" + mock_diff.a_path = "test/old_path.py" + mock_diff.rename_from = "test/old_path.py" + mock_diff.rename_to = "test/new_path.py" + mock_diff.a_blob = None + + diff = DiffLite.from_git_diff(mock_diff) + + self.assertEqual(diff.change_type, ChangeType.Renamed) + self.assertEqual(diff.path, Path("test/old_path.py")) + self.assertEqual(diff.rename_from, Path("test/old_path.py")) + self.assertEqual(diff.rename_to, Path("test/new_path.py")) + self.assertIsNone(diff.old_content) + + def test_from_reverse_diff_added_to_removed(self): + original = DiffLite(change_type=ChangeType.Added, path=Path("test/path.py")) + + reversed_diff = DiffLite.from_reverse_diff(original) + + self.assertEqual(reversed_diff.change_type, ChangeType.Removed) + self.assertEqual(reversed_diff.path, Path("test/path.py")) + + def test_from_reverse_diff_removed_to_added(self): + original = DiffLite(change_type=ChangeType.Removed, path=Path("test/path.py")) + + reversed_diff = DiffLite.from_reverse_diff(original) + + self.assertEqual(reversed_diff.change_type, ChangeType.Added) + self.assertEqual(reversed_diff.path, Path("test/path.py")) + + def 
test_from_reverse_diff_renamed(self): + original = DiffLite( + change_type=ChangeType.Renamed, + path=Path("test/old_path.py"), + rename_from=Path("test/old_path.py"), + rename_to=Path("test/new_path.py"), + ) + + reversed_diff = DiffLite.from_reverse_diff(original) + + self.assertEqual(reversed_diff.change_type, ChangeType.Renamed) + self.assertEqual(reversed_diff.path, Path("test/old_path.py")) + self.assertEqual(reversed_diff.rename_from, Path("test/new_path.py")) + self.assertEqual(reversed_diff.rename_to, Path("test/old_path.py")) + + +if __name__ == "__main__": + unittest.main() diff --git a/docs/mint.json b/docs/mint.json index 737c098ee..e83d7a630 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -1,396 +1,394 @@ { - "$schema": "https://mintlify.com/schema.json", - "name": "Codegen", - "logo": { - "dark": "https://cdn.prod.website-files.com/67070304751b9b01bf6a161c/679bcf45a3e32761c42b324b_Codegen_Logomark_Dark.svg", - "light": "https://cdn.prod.website-files.com/67070304751b9b01bf6a161c/679bcf45bf55446746125835_Codegen_Logomark_Light.svg" - }, - "modeToggle": { - "default": "dark" - }, - "metadata": { - "og:site_name": "Codegen", - "og:title": "Codegen - Manipulate Code at Scale", - "og:description": "A scriptable interface to a powerful, multi-lingual language server built on top of Tree-sitter.", - "og:url": "https://docs.codegen.com", - "og:locale": "en_US", - "og:logo": "https://i.imgur.com/f4OVOqI.png", - "article:publisher": "Codegen, Inc.", - "twitter:site": "@codegen" - }, - "favicon": "/favicon.svg", - "colors": { - "primary": "#a277ff", - "light": "#a277ff", - "dark": "#a277ff", - "anchors": { - "from": "#61ffca", - "to": "#61ffca" - } - }, - "theme": "prism", - "background": { - "style": "gradient" - }, - "analytics": { - "posthog": { - "apiKey": "phc_GLxaINoQJnuyCyxDmTciQqzdKBYFVDkY7bRBO4bDdso" - } - }, - "feedback": { - "thumbsRating": true - }, - "topbarCtaButton": { - "name": "GitHub", - "url": "https://github.com/codegen-sh/codegen-sdk" - }, - "tabs": [ - { - "name": "API Reference", - "url": "/api-reference" - }, - { - "name": "CLI", - "url": "/cli" - }, - { - "name": "Blog", - "url": "/blog" - }, - { - "name": "Changelog", - "url": "/changelog" - }, - { - "name": "codegen", - "url": "/gen" - } - ], - "navigation": [ - { - "group": "Introduction", - "pages": [ - "introduction/overview", - "introduction/getting-started", - "introduction/installation", - "introduction/ide-usage", - "introduction/work-with-ai", - "introduction/how-it-works", - "introduction/advanced-settings", - "introduction/guiding-principles", - "introduction/community", - "introduction/about", - "introduction/faq" - ] - }, - { - "group": "Tutorials", - "pages": [ - "tutorials/at-a-glance", - "tutorials/build-code-agent", - "tutorials/slack-bot", - "tutorials/github-review-bot", - "tutorials/deep-code-research", - "tutorials/codebase-analytics-dashboard", - "tutorials/training-data", - "tutorials/codebase-visualization", - "tutorials/migrating-apis", - "tutorials/organize-your-codebase", - "tutorials/promise-to-async-await", - "tutorials/modularity", - "tutorials/manage-feature-flags", - "tutorials/deleting-dead-code", - "tutorials/increase-type-coverage", - "tutorials/managing-typescript-exports", - "tutorials/converting-default-exports", - "tutorials/creating-documentation", - "tutorials/react-modernization", - "tutorials/unittest-to-pytest", - "tutorials/sqlalchemy-1.6-to-2.0", - "tutorials/fixing-import-loops-in-pytorch", - "tutorials/python2-to-python3", - "tutorials/flask-to-fastapi", - 
"tutorials/build-mcp", - "tutorials/neo4j-graph", - "tutorials/attributions" - ] - }, - { - "group": "Building with Codegen", - "pages": [ - "building-with-codegen/at-a-glance", - "building-with-codegen/parsing-codebases", - "building-with-codegen/reusable-codemods", - "building-with-codegen/dot-codegen", - "building-with-codegen/function-decorator", - "building-with-codegen/language-support", - "building-with-codegen/commit-and-reset", - "building-with-codegen/git-operations", - "building-with-codegen/files-and-directories", - "building-with-codegen/the-editable-api", - "building-with-codegen/symbol-api", - "building-with-codegen/class-api", - "building-with-codegen/imports", - "building-with-codegen/exports", - "building-with-codegen/inheritable-behaviors", - "building-with-codegen/statements-and-code-blocks", - "building-with-codegen/dependencies-and-usages", - "building-with-codegen/function-calls-and-callsites", - "building-with-codegen/variable-assignments", - "building-with-codegen/local-variables", - "building-with-codegen/comments-and-docstrings", - "building-with-codegen/external-modules", - "building-with-codegen/type-annotations", - "building-with-codegen/moving-symbols", - "building-with-codegen/collections", - "building-with-codegen/traversing-the-call-graph", - "building-with-codegen/react-and-jsx", - "building-with-codegen/codebase-visualization", - "building-with-codegen/flagging-symbols", - "building-with-codegen/calling-out-to-llms", - "building-with-codegen/semantic-code-search", - "building-with-codegen/reducing-conditions" - ] - }, - { - "group": "CLI", - "pages": [ - "cli/about", - "cli/init", - "cli/notebook", - "cli/create", - "cli/run", - "cli/reset", - "cli/expert" - ] - }, - { - "group": "Changelog", - "pages": [ - "changelog/changelog" - ] - }, - { - "group": "Blog", - "pages": [ - "blog/posts", - "blog/devin", - "blog/act-via-code", - "blog/promise-to-async-await-twilio", - "blog/fixing-import-loops" - ] - }, - { - "group": "codegen", - "pages": [ - "gen/introduction", - "gen/capabilities", - "gen/integrations", - "gen/faq" - ] - }, - { - "group": "API Reference", - "pages": [ - "api-reference/index", - { - "group": "Core", - "icon": "code", - "pages": [ - "api-reference/core/Argument", - "api-reference/core/Assignment", - "api-reference/core/AssignmentStatement", - "api-reference/core/Attribute", - "api-reference/core/AwaitExpression", - "api-reference/core/BinaryExpression", - "api-reference/core/BlockStatement", - "api-reference/core/Boolean", - "api-reference/core/Callable", - "api-reference/core/CatchStatement", - "api-reference/core/ChainedAttribute", - "api-reference/core/Class", - "api-reference/core/CodeBlock", - "api-reference/core/CodeOwner", - "api-reference/core/Codebase", - "api-reference/core/Comment", - "api-reference/core/CommentGroup", - "api-reference/core/ComparisonExpression", - "api-reference/core/Decorator", - "api-reference/core/Dict", - "api-reference/core/Directory", - "api-reference/core/Editable", - "api-reference/core/Export", - "api-reference/core/ExportStatement", - "api-reference/core/Exportable", - "api-reference/core/Expression", - "api-reference/core/ExpressionGroup", - "api-reference/core/ExpressionStatement", - "api-reference/core/ExternalModule", - "api-reference/core/File", - "api-reference/core/FlagKwargs", - "api-reference/core/ForLoopStatement", - "api-reference/core/Function", - "api-reference/core/FunctionCall", - "api-reference/core/GenericType", - "api-reference/core/HasBlock", - "api-reference/core/HasName", - 
"api-reference/core/HasValue", - "api-reference/core/IfBlockStatement", - "api-reference/core/Import", - "api-reference/core/ImportStatement", - "api-reference/core/ImportType", - "api-reference/core/Importable", - "api-reference/core/Interface", - "api-reference/core/List", - "api-reference/core/MessageType", - "api-reference/core/MultiExpression", - "api-reference/core/MultiLineCollection", - "api-reference/core/Name", - "api-reference/core/NamedType", - "api-reference/core/NoneType", - "api-reference/core/Number", - "api-reference/core/Pair", - "api-reference/core/Parameter", - "api-reference/core/ParenthesizedExpression", - "api-reference/core/Placeholder", - "api-reference/core/PlaceholderType", - "api-reference/core/RaiseStatement", - "api-reference/core/ReturnStatement", - "api-reference/core/SourceFile", - "api-reference/core/Span", - "api-reference/core/Statement", - "api-reference/core/StatementType", - "api-reference/core/String", - "api-reference/core/StubPlaceholder", - "api-reference/core/SubscriptExpression", - "api-reference/core/SwitchCase", - "api-reference/core/SwitchStatement", - "api-reference/core/Symbol", - "api-reference/core/SymbolGroup", - "api-reference/core/SymbolStatement", - "api-reference/core/TernaryExpression", - "api-reference/core/TryCatchStatement", - "api-reference/core/Tuple", - "api-reference/core/TupleType", - "api-reference/core/Type", - "api-reference/core/TypeAlias", - "api-reference/core/TypePlaceholder", - "api-reference/core/Typeable", - "api-reference/core/UnaryExpression", - "api-reference/core/UnionType", - "api-reference/core/Unpack", - "api-reference/core/Unwrappable", - "api-reference/core/Usable", - "api-reference/core/Usage", - "api-reference/core/UsageKind", - "api-reference/core/UsageType", - "api-reference/core/Value", - "api-reference/core/WhileStatement", - "api-reference/core/WithStatement" - ] - }, - { - "group": "Python", - "icon": "python", - "pages": [ - "api-reference/python/PyAssignment", - "api-reference/python/PyAssignmentStatement", - "api-reference/python/PyAttribute", - "api-reference/python/PyBlockStatement", - "api-reference/python/PyBreakStatement", - "api-reference/python/PyCatchStatement", - "api-reference/python/PyChainedAttribute", - "api-reference/python/PyClass", - "api-reference/python/PyCodeBlock", - "api-reference/python/PyComment", - "api-reference/python/PyCommentGroup", - "api-reference/python/PyCommentType", - "api-reference/python/PyConditionalExpression", - "api-reference/python/PyDecorator", - "api-reference/python/PyFile", - "api-reference/python/PyForLoopStatement", - "api-reference/python/PyFunction", - "api-reference/python/PyGenericType", - "api-reference/python/PyHasBlock", - "api-reference/python/PyIfBlockStatement", - "api-reference/python/PyImport", - "api-reference/python/PyImportStatement", - "api-reference/python/PyMatchCase", - "api-reference/python/PyMatchStatement", - "api-reference/python/PyNamedType", - "api-reference/python/PyParameter", - "api-reference/python/PyPassStatement", - "api-reference/python/PyReturnTypePlaceholder", - "api-reference/python/PyString", - "api-reference/python/PySymbol", - "api-reference/python/PyTryCatchStatement", - "api-reference/python/PyUnionType", - "api-reference/python/PyWhileStatement" - ] - }, - { - "group": "Typescript", - "icon": "js", - "pages": [ - "api-reference/typescript/JSXElement", - "api-reference/typescript/JSXExpression", - "api-reference/typescript/JSXProp", - "api-reference/typescript/TSArrayType", - 
"api-reference/typescript/TSAssignment", - "api-reference/typescript/TSAssignmentStatement", - "api-reference/typescript/TSAttribute", - "api-reference/typescript/TSBlockStatement", - "api-reference/typescript/TSCatchStatement", - "api-reference/typescript/TSChainedAttribute", - "api-reference/typescript/TSClass", - "api-reference/typescript/TSCodeBlock", - "api-reference/typescript/TSComment", - "api-reference/typescript/TSCommentGroup", - "api-reference/typescript/TSCommentType", - "api-reference/typescript/TSConditionalType", - "api-reference/typescript/TSConfig", - "api-reference/typescript/TSDecorator", - "api-reference/typescript/TSDict", - "api-reference/typescript/TSEnum", - "api-reference/typescript/TSExport", - "api-reference/typescript/TSExpressionType", - "api-reference/typescript/TSFile", - "api-reference/typescript/TSForLoopStatement", - "api-reference/typescript/TSFunction", - "api-reference/typescript/TSFunctionType", - "api-reference/typescript/TSGenericType", - "api-reference/typescript/TSHasBlock", - "api-reference/typescript/TSIfBlockStatement", - "api-reference/typescript/TSImport", - "api-reference/typescript/TSImportStatement", - "api-reference/typescript/TSInterface", - "api-reference/typescript/TSLabeledStatement", - "api-reference/typescript/TSLookupType", - "api-reference/typescript/TSNamedType", - "api-reference/typescript/TSNamespace", - "api-reference/typescript/TSObjectType", - "api-reference/typescript/TSPair", - "api-reference/typescript/TSParameter", - "api-reference/typescript/TSQueryType", - "api-reference/typescript/TSReadonlyType", - "api-reference/typescript/TSReturnTypePlaceholder", - "api-reference/typescript/TSString", - "api-reference/typescript/TSSwitchCase", - "api-reference/typescript/TSSwitchStatement", - "api-reference/typescript/TSSymbol", - "api-reference/typescript/TSTernaryExpression", - "api-reference/typescript/TSTryCatchStatement", - "api-reference/typescript/TSTypeAlias", - "api-reference/typescript/TSUndefinedType", - "api-reference/typescript/TSUnionType", - "api-reference/typescript/TSWhileStatement" - ] - } - ] - } - ], - "footerSocials": { - "x": "https://x.com/codegen", - "linkedin": "https://linkedin.com/company/codegen-dot-com" - } -} \ No newline at end of file + "$schema": "https://mintlify.com/schema.json", + "name": "Codegen", + "logo": { + "dark": "https://cdn.prod.website-files.com/67070304751b9b01bf6a161c/679bcf45a3e32761c42b324b_Codegen_Logomark_Dark.svg", + "light": "https://cdn.prod.website-files.com/67070304751b9b01bf6a161c/679bcf45bf55446746125835_Codegen_Logomark_Light.svg" + }, + "modeToggle": { + "default": "dark" + }, + "metadata": { + "og:site_name": "Codegen", + "og:title": "Codegen - Manipulate Code at Scale", + "og:description": "A scriptable interface to a powerful, multi-lingual language server built on top of Tree-sitter.", + "og:url": "https://docs.codegen.com", + "og:locale": "en_US", + "og:logo": "https://i.imgur.com/f4OVOqI.png", + "article:publisher": "Codegen, Inc.", + "twitter:site": "@codegen" + }, + "favicon": "/favicon.svg", + "colors": { + "primary": "#a277ff", + "light": "#a277ff", + "dark": "#a277ff", + "anchors": { + "from": "#61ffca", + "to": "#61ffca" + } + }, + "theme": "prism", + "background": { + "style": "gradient" + }, + "analytics": { + "posthog": { + "apiKey": "phc_GLxaINoQJnuyCyxDmTciQqzdKBYFVDkY7bRBO4bDdso" + } + }, + "feedback": { + "thumbsRating": true + }, + "topbarCtaButton": { + "name": "GitHub", + "url": "https://github.com/codegen-sh/codegen-sdk" + }, + "tabs": [ + { + 
"name": "API Reference", + "url": "/api-reference" + }, + { + "name": "CLI", + "url": "/cli" + }, + { + "name": "Blog", + "url": "/blog" + }, + { + "name": "Changelog", + "url": "/changelog" + }, + { + "name": "codegen", + "url": "/gen" + } + ], + "navigation": [ + { + "group": "Introduction", + "pages": [ + "introduction/overview", + "introduction/getting-started", + "introduction/installation", + "introduction/ide-usage", + "introduction/work-with-ai", + "introduction/how-it-works", + "introduction/advanced-settings", + "introduction/guiding-principles", + "introduction/community", + "introduction/about", + "introduction/faq" + ] + }, + { + "group": "Tutorials", + "pages": [ + "tutorials/at-a-glance", + "tutorials/build-code-agent", + "tutorials/slack-bot", + "tutorials/github-review-bot", + "tutorials/deep-code-research", + "tutorials/codebase-analytics-dashboard", + "tutorials/training-data", + "tutorials/codebase-visualization", + "tutorials/migrating-apis", + "tutorials/organize-your-codebase", + "tutorials/promise-to-async-await", + "tutorials/modularity", + "tutorials/manage-feature-flags", + "tutorials/deleting-dead-code", + "tutorials/increase-type-coverage", + "tutorials/managing-typescript-exports", + "tutorials/converting-default-exports", + "tutorials/creating-documentation", + "tutorials/react-modernization", + "tutorials/unittest-to-pytest", + "tutorials/sqlalchemy-1.6-to-2.0", + "tutorials/fixing-import-loops-in-pytorch", + "tutorials/python2-to-python3", + "tutorials/flask-to-fastapi", + "tutorials/build-mcp", + "tutorials/neo4j-graph", + "tutorials/attributions" + ] + }, + { + "group": "Building with Codegen", + "pages": [ + "building-with-codegen/at-a-glance", + "building-with-codegen/parsing-codebases", + "building-with-codegen/reusable-codemods", + "building-with-codegen/dot-codegen", + "building-with-codegen/function-decorator", + "building-with-codegen/language-support", + "building-with-codegen/commit-and-reset", + "building-with-codegen/git-operations", + "building-with-codegen/files-and-directories", + "building-with-codegen/the-editable-api", + "building-with-codegen/symbol-api", + "building-with-codegen/class-api", + "building-with-codegen/imports", + "building-with-codegen/exports", + "building-with-codegen/inheritable-behaviors", + "building-with-codegen/statements-and-code-blocks", + "building-with-codegen/dependencies-and-usages", + "building-with-codegen/function-calls-and-callsites", + "building-with-codegen/variable-assignments", + "building-with-codegen/local-variables", + "building-with-codegen/comments-and-docstrings", + "building-with-codegen/external-modules", + "building-with-codegen/type-annotations", + "building-with-codegen/moving-symbols", + "building-with-codegen/collections", + "building-with-codegen/traversing-the-call-graph", + "building-with-codegen/react-and-jsx", + "building-with-codegen/codebase-visualization", + "building-with-codegen/flagging-symbols", + "building-with-codegen/calling-out-to-llms", + "building-with-codegen/semantic-code-search", + "building-with-codegen/reducing-conditions" + ] + }, + { + "group": "CLI", + "pages": [ + "cli/about", + "cli/init", + "cli/notebook", + "cli/create", + "cli/run", + "cli/reset", + "cli/expert" + ] + }, + { + "group": "Changelog", + "pages": ["changelog/changelog"] + }, + { + "group": "Blog", + "pages": [ + "blog/posts", + "blog/devin", + "blog/act-via-code", + "blog/promise-to-async-await-twilio", + "blog/fixing-import-loops" + ] + }, + { + "group": "codegen", + "pages": [ + 
"gen/introduction", + "gen/capabilities", + "gen/integrations", + "gen/faq" + ] + }, + { + "group": "API Reference", + "pages": [ + "api-reference/index", + { + "group": "Core", + "icon": "code", + "pages": [ + "api-reference/core/Argument", + "api-reference/core/Assignment", + "api-reference/core/AssignmentStatement", + "api-reference/core/Attribute", + "api-reference/core/AwaitExpression", + "api-reference/core/BinaryExpression", + "api-reference/core/BlockStatement", + "api-reference/core/Boolean", + "api-reference/core/Callable", + "api-reference/core/CatchStatement", + "api-reference/core/ChainedAttribute", + "api-reference/core/Class", + "api-reference/core/CodeBlock", + "api-reference/core/CodeOwner", + "api-reference/core/Codebase", + "api-reference/core/Comment", + "api-reference/core/CommentGroup", + "api-reference/core/ComparisonExpression", + "api-reference/core/Decorator", + "api-reference/core/Dict", + "api-reference/core/Directory", + "api-reference/core/Editable", + "api-reference/core/Export", + "api-reference/core/ExportStatement", + "api-reference/core/Exportable", + "api-reference/core/Expression", + "api-reference/core/ExpressionGroup", + "api-reference/core/ExpressionStatement", + "api-reference/core/ExternalModule", + "api-reference/core/File", + "api-reference/core/FlagKwargs", + "api-reference/core/ForLoopStatement", + "api-reference/core/Function", + "api-reference/core/FunctionCall", + "api-reference/core/GenericType", + "api-reference/core/HasBlock", + "api-reference/core/HasName", + "api-reference/core/HasValue", + "api-reference/core/IfBlockStatement", + "api-reference/core/Import", + "api-reference/core/ImportStatement", + "api-reference/core/ImportType", + "api-reference/core/Importable", + "api-reference/core/Interface", + "api-reference/core/List", + "api-reference/core/MessageType", + "api-reference/core/MultiExpression", + "api-reference/core/MultiLineCollection", + "api-reference/core/Name", + "api-reference/core/NamedType", + "api-reference/core/NoneType", + "api-reference/core/Number", + "api-reference/core/Pair", + "api-reference/core/Parameter", + "api-reference/core/ParenthesizedExpression", + "api-reference/core/Placeholder", + "api-reference/core/PlaceholderType", + "api-reference/core/RaiseStatement", + "api-reference/core/ReturnStatement", + "api-reference/core/SourceFile", + "api-reference/core/Span", + "api-reference/core/Statement", + "api-reference/core/StatementType", + "api-reference/core/String", + "api-reference/core/StubPlaceholder", + "api-reference/core/SubscriptExpression", + "api-reference/core/SwitchCase", + "api-reference/core/SwitchStatement", + "api-reference/core/Symbol", + "api-reference/core/SymbolGroup", + "api-reference/core/SymbolStatement", + "api-reference/core/TernaryExpression", + "api-reference/core/TryCatchStatement", + "api-reference/core/Tuple", + "api-reference/core/TupleType", + "api-reference/core/Type", + "api-reference/core/TypeAlias", + "api-reference/core/TypePlaceholder", + "api-reference/core/Typeable", + "api-reference/core/UnaryExpression", + "api-reference/core/UnionType", + "api-reference/core/Unpack", + "api-reference/core/Unwrappable", + "api-reference/core/Usable", + "api-reference/core/Usage", + "api-reference/core/UsageKind", + "api-reference/core/UsageType", + "api-reference/core/Value", + "api-reference/core/WhileStatement", + "api-reference/core/WithStatement" + ] + }, + { + "group": "Python", + "icon": "python", + "pages": [ + "api-reference/python/PyAssignment", + 
"api-reference/python/PyAssignmentStatement", + "api-reference/python/PyAttribute", + "api-reference/python/PyBlockStatement", + "api-reference/python/PyBreakStatement", + "api-reference/python/PyCatchStatement", + "api-reference/python/PyChainedAttribute", + "api-reference/python/PyClass", + "api-reference/python/PyCodeBlock", + "api-reference/python/PyComment", + "api-reference/python/PyCommentGroup", + "api-reference/python/PyCommentType", + "api-reference/python/PyConditionalExpression", + "api-reference/python/PyDecorator", + "api-reference/python/PyFile", + "api-reference/python/PyForLoopStatement", + "api-reference/python/PyFunction", + "api-reference/python/PyGenericType", + "api-reference/python/PyHasBlock", + "api-reference/python/PyIfBlockStatement", + "api-reference/python/PyImport", + "api-reference/python/PyImportStatement", + "api-reference/python/PyMatchCase", + "api-reference/python/PyMatchStatement", + "api-reference/python/PyNamedType", + "api-reference/python/PyParameter", + "api-reference/python/PyPassStatement", + "api-reference/python/PyReturnTypePlaceholder", + "api-reference/python/PyString", + "api-reference/python/PySymbol", + "api-reference/python/PyTryCatchStatement", + "api-reference/python/PyUnionType", + "api-reference/python/PyWhileStatement" + ] + }, + { + "group": "Typescript", + "icon": "js", + "pages": [ + "api-reference/typescript/JSXElement", + "api-reference/typescript/JSXExpression", + "api-reference/typescript/JSXProp", + "api-reference/typescript/TSArrayType", + "api-reference/typescript/TSAssignment", + "api-reference/typescript/TSAssignmentStatement", + "api-reference/typescript/TSAttribute", + "api-reference/typescript/TSBlockStatement", + "api-reference/typescript/TSCatchStatement", + "api-reference/typescript/TSChainedAttribute", + "api-reference/typescript/TSClass", + "api-reference/typescript/TSCodeBlock", + "api-reference/typescript/TSComment", + "api-reference/typescript/TSCommentGroup", + "api-reference/typescript/TSCommentType", + "api-reference/typescript/TSConditionalType", + "api-reference/typescript/TSConfig", + "api-reference/typescript/TSDecorator", + "api-reference/typescript/TSDict", + "api-reference/typescript/TSEnum", + "api-reference/typescript/TSExport", + "api-reference/typescript/TSExpressionType", + "api-reference/typescript/TSFile", + "api-reference/typescript/TSForLoopStatement", + "api-reference/typescript/TSFunction", + "api-reference/typescript/TSFunctionType", + "api-reference/typescript/TSGenericType", + "api-reference/typescript/TSHasBlock", + "api-reference/typescript/TSIfBlockStatement", + "api-reference/typescript/TSImport", + "api-reference/typescript/TSImportStatement", + "api-reference/typescript/TSInterface", + "api-reference/typescript/TSLabeledStatement", + "api-reference/typescript/TSLookupType", + "api-reference/typescript/TSNamedType", + "api-reference/typescript/TSNamespace", + "api-reference/typescript/TSObjectType", + "api-reference/typescript/TSPair", + "api-reference/typescript/TSParameter", + "api-reference/typescript/TSQueryType", + "api-reference/typescript/TSReadonlyType", + "api-reference/typescript/TSReturnTypePlaceholder", + "api-reference/typescript/TSString", + "api-reference/typescript/TSSwitchCase", + "api-reference/typescript/TSSwitchStatement", + "api-reference/typescript/TSSymbol", + "api-reference/typescript/TSTernaryExpression", + "api-reference/typescript/TSTryCatchStatement", + "api-reference/typescript/TSTypeAlias", + "api-reference/typescript/TSUndefinedType", + 
"api-reference/typescript/TSUnionType", + "api-reference/typescript/TSWhileStatement" + ] + } + ] + } + ], + "footerSocials": { + "x": "https://x.com/codegen", + "linkedin": "https://linkedin.com/company/codegen-dot-com" + } +} diff --git a/organize_codebase.py b/organize_codebase.py index d12d4f660..8dd4dd3fc 100644 --- a/organize_codebase.py +++ b/organize_codebase.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -""" -Codebase Organizer Script +"""Codebase Organizer Script This script helps organize a codebase by analyzing file contents and moving related files into appropriate directories based on their functionality. @@ -9,163 +8,150 @@ import os import re import shutil -from pathlib import Path -from typing import Dict, List, Set, Tuple # Define categories and their related patterns CATEGORIES = { - "analyzers": [ - r"analyzer", r"analysis", r"analyze" - ], - "code_quality": [ - r"code_quality", r"quality", r"lint" - ], - "context": [ - r"context", r"codebase_context" - ], - "dependencies": [ - r"dependenc", r"import" - ], - "issues": [ - r"issue", r"error" - ], - "visualization": [ - r"visual", r"display", r"render" - ], + "analyzers": [r"analyzer", r"analysis", r"analyze"], + "code_quality": [r"code_quality", r"quality", r"lint"], + "context": [r"context", r"codebase_context"], + "dependencies": [r"dependenc", r"import"], + "issues": [r"issue", r"error"], + "visualization": [r"visual", r"display", r"render"], } + def read_file_content(file_path: str) -> str: """Read the content of a file.""" try: - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, encoding="utf-8") as f: return f.read() except Exception as e: print(f"Error reading {file_path}: {e}") return "" -def categorize_file(file_path: str, categories: Dict[str, List[str]]) -> List[str]: + +def categorize_file(file_path: str, categories: dict[str, list[str]]) -> list[str]: """Categorize a file based on its content and name.""" file_categories = [] content = read_file_content(file_path) filename = os.path.basename(file_path) - + # Check filename and content against category patterns for category, patterns in categories.items(): for pattern in patterns: if re.search(pattern, filename, re.IGNORECASE) or re.search(pattern, content, re.IGNORECASE): file_categories.append(category) break - + return file_categories -def analyze_imports(file_path: str) -> Set[str]: + +def analyze_imports(file_path: str) -> set[str]: """Analyze imports in a Python file.""" imports = set() content = read_file_content(file_path) - + # Find import statements - import_patterns = [ - r'import\s+([a-zA-Z0-9_\.]+)', - r'from\s+([a-zA-Z0-9_\.]+)\s+import' - ] - + import_patterns = [r"import\s+([a-zA-Z0-9_\.]+)", r"from\s+([a-zA-Z0-9_\.]+)\s+import"] + for pattern in import_patterns: for match in re.finditer(pattern, content): imports.add(match.group(1)) - + return imports -def build_dependency_graph(files: List[str]) -> Dict[str, Set[str]]: + +def build_dependency_graph(files: list[str]) -> dict[str, set[str]]: """Build a dependency graph for the files.""" graph = {} module_to_file = {} - + # Map module names to files for file_path in files: - if not file_path.endswith('.py'): + if not file_path.endswith(".py"): continue - + module_name = os.path.splitext(os.path.basename(file_path))[0] module_to_file[module_name] = file_path - + # Build the graph for file_path in files: - if not file_path.endswith('.py'): + if not file_path.endswith(".py"): continue - + imports = analyze_imports(file_path) graph[file_path] = set() - + for imp in imports: # 
Check if this is a local import - parts = imp.split('.') + parts = imp.split(".") if parts[0] in module_to_file: graph[file_path].add(module_to_file[parts[0]]) - + return graph -def find_related_files(graph: Dict[str, Set[str]], file_path: str) -> Set[str]: + +def find_related_files(graph: dict[str, set[str]], file_path: str) -> set[str]: """Find files related to the given file based on the dependency graph.""" related = set() - + # Files that this file imports if file_path in graph: related.update(graph[file_path]) - + # Files that import this file for other_file, deps in graph.items(): if file_path in deps: related.add(other_file) - + return related -def organize_files(directory: str, dry_run: bool = True) -> Dict[str, List[str]]: - """ - Organize files in the directory into categories. - + +def organize_files(directory: str, dry_run: bool = True) -> dict[str, list[str]]: + """Organize files in the directory into categories. + Args: directory: The directory containing the files to organize dry_run: If True, only print the planned changes without making them - + Returns: A dictionary mapping categories to lists of files """ # Get all Python files - py_files = [os.path.join(directory, f) for f in os.listdir(directory) - if f.endswith('.py') and os.path.isfile(os.path.join(directory, f))] - + py_files = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(".py") and os.path.isfile(os.path.join(directory, f))] + # Build dependency graph graph = build_dependency_graph(py_files) - + # Categorize files categorized_files = {} for category in CATEGORIES: categorized_files[category] = [] - + # Special case for README and init files categorized_files["root"] = [] - + for file_path in py_files: filename = os.path.basename(file_path) - + # Keep some files in the root directory - if filename in ['__init__.py', 'README.md']: + if filename in ["__init__.py", "README.md"]: categorized_files["root"].append(file_path) continue - + # Categorize the file categories = categorize_file(file_path, CATEGORIES) - + if not categories: # If no category found, use related files to determine category related = find_related_files(graph, file_path) for related_file in related: related_categories = categorize_file(related_file, CATEGORIES) categories.extend(related_categories) - + # Remove duplicates categories = list(set(categories)) - + if not categories: # If still no category, put in a default category based on filename if "analyzer" in filename: @@ -183,52 +169,53 @@ def organize_files(directory: str, dry_run: bool = True) -> Dict[str, List[str]] else: # Default to analyzers if nothing else matches categories = ["analyzers"] - + # Use the first category (most relevant) primary_category = categories[0] categorized_files[primary_category].append(file_path) - + # Print and execute the organization plan for category, files in categorized_files.items(): if not files: continue - + print(f"\nCategory: {category}") for file_path in files: print(f" - {os.path.basename(file_path)}") - + if not dry_run and category != "root": # Create the category directory if it doesn't exist category_dir = os.path.join(directory, category) os.makedirs(category_dir, exist_ok=True) - + # Move files to the category directory for file_path in files: if category != "root": dest_path = os.path.join(category_dir, os.path.basename(file_path)) shutil.move(file_path, dest_path) print(f" Moved to {dest_path}") - + return categorized_files + def main(): """Main function to organize the codebase.""" import argparse - - parser = 
argparse.ArgumentParser(description='Organize a codebase by categorizing files.') - parser.add_argument('directory', help='The directory containing the files to organize') - parser.add_argument('--execute', action='store_true', help='Execute the organization plan (default is dry run)') - + + parser = argparse.ArgumentParser(description="Organize a codebase by categorizing files.") + parser.add_argument("directory", help="The directory containing the files to organize") + parser.add_argument("--execute", action="store_true", help="Execute the organization plan (default is dry run)") + args = parser.parse_args() - + print(f"Analyzing files in {args.directory}...") organize_files(args.directory, dry_run=not args.execute) - + if not args.execute: print("\nThis was a dry run. Use --execute to actually move the files.") else: print("\nFiles have been organized.") + if __name__ == "__main__": main() - diff --git a/organize_specific_codebase.py b/organize_specific_codebase.py index cfe8f534d..22d6b6a6d 100644 --- a/organize_specific_codebase.py +++ b/organize_specific_codebase.py @@ -1,100 +1,69 @@ #!/usr/bin/env python3 -""" -Specific Codebase Organizer +"""Specific Codebase Organizer This script organizes the specific codebase structure shown in the screenshot, with 5 folders and 21 Python files in the root directory. """ import os -import re import shutil -from pathlib import Path -from typing import Dict, List, Set # Define the organization structure based on the files in the screenshot ORGANIZATION_PLAN = { - "analyzers": [ - "analyzer.py", - "analyzer_manager.py", - "base_analyzer.py", - "code_quality_analyzer.py", - "codebase_analyzer.py", - "dependency_analyzer.py", - "error_analyzer.py", - "unified_analyzer.py" - ], - "code_quality": [ - "code_quality.py" - ], - "context": [ - "codebase_context.py", - "context_codebase.py", - "current_code_codebase.py" - ], - "issues": [ - "issue_analyzer.py", - "issue_types.py", - "issues.py" - ], - "dependencies": [ - "dependencies.py" - ], + "analyzers": ["analyzer.py", "analyzer_manager.py", "base_analyzer.py", "code_quality_analyzer.py", "codebase_analyzer.py", "dependency_analyzer.py", "error_analyzer.py", "unified_analyzer.py"], + "code_quality": ["code_quality.py"], + "context": ["codebase_context.py", "context_codebase.py", "current_code_codebase.py"], + "issues": ["issue_analyzer.py", "issue_types.py", "issues.py"], + "dependencies": ["dependencies.py"], # Files to keep in root - "root": [ - "__init__.py", - "api.py", - "README.md" - ] + "root": ["__init__.py", "api.py", "README.md"], } + def organize_specific_codebase(directory: str, dry_run: bool = True) -> None: - """ - Organize the specific codebase structure. - + """Organize the specific codebase structure. 
+ Args: directory: The directory containing the files to organize dry_run: If True, only print the planned changes without making them """ print(f"Organizing codebase in {directory}...") - + # Create directories if they don't exist (unless dry run) if not dry_run: for category in ORGANIZATION_PLAN: if category != "root": os.makedirs(os.path.join(directory, category), exist_ok=True) - + # Process each file according to the plan for category, files in ORGANIZATION_PLAN.items(): print(f"\nCategory: {category}") - + for filename in files: source_path = os.path.join(directory, filename) - + # Skip if file doesn't exist if not os.path.exists(source_path): print(f" - {filename} (not found, skipping)") continue - + print(f" - {filename}") - + # Move the file if not a dry run and not in root category if not dry_run and category != "root": dest_path = os.path.join(directory, category, filename) shutil.move(source_path, dest_path) print(f" Moved to {dest_path}") - + # Handle any remaining Python files not explicitly categorized all_planned_files = [f for files in ORGANIZATION_PLAN.values() for f in files] - remaining_files = [f for f in os.listdir(directory) - if f.endswith('.py') and os.path.isfile(os.path.join(directory, f)) - and f not in all_planned_files] - + remaining_files = [f for f in os.listdir(directory) if f.endswith(".py") and os.path.isfile(os.path.join(directory, f)) and f not in all_planned_files] + if remaining_files: print("\nRemaining Python files (not categorized):") for filename in remaining_files: print(f" - {filename}") - + # Try to categorize based on filename if "analyzer" in filename.lower(): category = "analyzers" @@ -111,9 +80,9 @@ def organize_specific_codebase(directory: str, dry_run: bool = True) -> None: else: # Default to analyzers category = "analyzers" - + print(f" Suggested category: {category}") - + # Move the file if not a dry run if not dry_run: os.makedirs(os.path.join(directory, category), exist_ok=True) @@ -121,36 +90,37 @@ def organize_specific_codebase(directory: str, dry_run: bool = True) -> None: shutil.move(os.path.join(directory, filename), dest_path) print(f" Moved to {dest_path}") + def main(): """Main function to organize the specific codebase.""" import argparse - - parser = argparse.ArgumentParser(description='Organize the specific codebase structure.') - parser.add_argument('directory', help='The directory containing the files to organize') - parser.add_argument('--execute', action='store_true', help='Execute the organization plan (default is dry run)') - + + parser = argparse.ArgumentParser(description="Organize the specific codebase structure.") + parser.add_argument("directory", help="The directory containing the files to organize") + parser.add_argument("--execute", action="store_true", help="Execute the organization plan (default is dry run)") + args = parser.parse_args() - + organize_specific_codebase(args.directory, dry_run=not args.execute) - + if not args.execute: print("\nThis was a dry run. 
Use --execute to actually move the files.") else: print("\nFiles have been organized according to the plan.") - + print("\nAfter organizing, you may need to update imports in your code.") print("You can use the Codegen SDK to automatically update imports:") print(""" # Example code to update imports after moving files from codegen.sdk import Codebase - + # Initialize the codebase codebase = Codebase("path/to/your/codebase") - + # Commit the changes to ensure the codebase is up-to-date codebase.commit() """) + if __name__ == "__main__": main() - diff --git a/organize_with_codegen_sdk.py b/organize_with_codegen_sdk.py index 263947c1b..066366c6e 100644 --- a/organize_with_codegen_sdk.py +++ b/organize_with_codegen_sdk.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -""" -Codebase Organizer using Codegen SDK +"""Codebase Organizer using Codegen SDK This script uses the Codegen SDK to programmatically organize a codebase by moving symbols between files and updating imports automatically. @@ -8,7 +7,6 @@ import os import sys -from typing import Dict, List, Set, Optional try: from codegen.sdk import Codebase @@ -19,118 +17,87 @@ # Define the organization structure based on the files in the screenshot ORGANIZATION_PLAN = { - "analyzers": [ - "analyzer.py", - "analyzer_manager.py", - "base_analyzer.py", - "code_quality_analyzer.py", - "codebase_analyzer.py", - "dependency_analyzer.py", - "error_analyzer.py", - "unified_analyzer.py" - ], - "code_quality": [ - "code_quality.py" - ], - "context": [ - "codebase_context.py", - "context_codebase.py", - "current_code_codebase.py" - ], - "issues": [ - "issue_analyzer.py", - "issue_types.py", - "issues.py" - ], - "dependencies": [ - "dependencies.py" - ], + "analyzers": ["analyzer.py", "analyzer_manager.py", "base_analyzer.py", "code_quality_analyzer.py", "codebase_analyzer.py", "dependency_analyzer.py", "error_analyzer.py", "unified_analyzer.py"], + "code_quality": ["code_quality.py"], + "context": ["codebase_context.py", "context_codebase.py", "current_code_codebase.py"], + "issues": ["issue_analyzer.py", "issue_types.py", "issues.py"], + "dependencies": ["dependencies.py"], # Files to keep in root - "root": [ - "__init__.py", - "api.py", - "README.md" - ] + "root": ["__init__.py", "api.py", "README.md"], } + def organize_with_codegen_sdk(directory: str, dry_run: bool = True) -> None: - """ - Organize the codebase using Codegen SDK. - + """Organize the codebase using Codegen SDK. 
+ Args: directory: The directory containing the files to organize dry_run: If True, only print the planned changes without making them """ print(f"Organizing codebase in {directory} using Codegen SDK...") - + # Initialize the codebase codebase = Codebase(directory) - + # Create directories if they don't exist (unless dry run) if not dry_run: for category in ORGANIZATION_PLAN: if category != "root": os.makedirs(os.path.join(directory, category), exist_ok=True) - + # Process each file according to the plan for category, files in ORGANIZATION_PLAN.items(): if category == "root": continue # Skip files that should stay in root - + print(f"\nCategory: {category}") - + for filename in files: source_path = os.path.join(directory, filename) - + # Skip if file doesn't exist if not os.path.exists(source_path): print(f" - {filename} (not found, skipping)") continue - + print(f" - {filename}") - + # Move the file if not a dry run if not dry_run: try: # Get the source file source_file = codebase.get_file(filename) - + # Create the destination file path dest_path = os.path.join(category, filename) - + # Create the destination file if it doesn't exist if not os.path.exists(os.path.join(directory, dest_path)): dest_file = codebase.create_file(dest_path) else: dest_file = codebase.get_file(dest_path) - + # Move all symbols from source to destination for symbol in source_file.symbols: print(f" Moving symbol: {symbol.name}") - symbol.move_to_file( - dest_file, - include_dependencies=True, - strategy="update_all_imports" - ) - + symbol.move_to_file(dest_file, include_dependencies=True, strategy="update_all_imports") + # Commit changes to ensure the codebase is up-to-date codebase.commit() - + print(f" Moved to {dest_path} with imports updated") except Exception as e: print(f" Error moving {filename}: {e}") - + # Handle any remaining Python files not explicitly categorized all_planned_files = [f for files in ORGANIZATION_PLAN.values() for f in files] - remaining_files = [f for f in os.listdir(directory) - if f.endswith('.py') and os.path.isfile(os.path.join(directory, f)) - and f not in all_planned_files] - + remaining_files = [f for f in os.listdir(directory) if f.endswith(".py") and os.path.isfile(os.path.join(directory, f)) and f not in all_planned_files] + if remaining_files: print("\nRemaining Python files (not categorized):") for filename in remaining_files: print(f" - {filename}") - + # Try to categorize based on filename if "analyzer" in filename.lower(): category = "analyzers" @@ -147,57 +114,54 @@ def organize_with_codegen_sdk(directory: str, dry_run: bool = True) -> None: else: # Default to analyzers category = "analyzers" - + print(f" Suggested category: {category}") - + # Move the file if not a dry run if not dry_run: try: # Get the source file source_file = codebase.get_file(filename) - + # Create the destination file path dest_path = os.path.join(category, filename) - + # Create the destination file if it doesn't exist if not os.path.exists(os.path.join(directory, dest_path)): dest_file = codebase.create_file(dest_path) else: dest_file = codebase.get_file(dest_path) - + # Move all symbols from source to destination for symbol in source_file.symbols: print(f" Moving symbol: {symbol.name}") - symbol.move_to_file( - dest_file, - include_dependencies=True, - strategy="update_all_imports" - ) - + symbol.move_to_file(dest_file, include_dependencies=True, strategy="update_all_imports") + # Commit changes to ensure the codebase is up-to-date codebase.commit() - + print(f" Moved to {dest_path} with 
imports updated") except Exception as e: print(f" Error moving {filename}: {e}") + def main(): """Main function to organize the codebase using Codegen SDK.""" import argparse - - parser = argparse.ArgumentParser(description='Organize the codebase using Codegen SDK.') - parser.add_argument('directory', help='The directory containing the files to organize') - parser.add_argument('--execute', action='store_true', help='Execute the organization plan (default is dry run)') - + + parser = argparse.ArgumentParser(description="Organize the codebase using Codegen SDK.") + parser.add_argument("directory", help="The directory containing the files to organize") + parser.add_argument("--execute", action="store_true", help="Execute the organization plan (default is dry run)") + args = parser.parse_args() - + organize_with_codegen_sdk(args.directory, dry_run=not args.execute) - + if not args.execute: print("\nThis was a dry run. Use --execute to actually move the files.") else: print("\nFiles have been organized according to the plan.") + if __name__ == "__main__": main() - diff --git a/requirements.txt b/requirements.txt index 6a778d879..131fd3da4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,3 @@ codegen-sdk>=0.1.0 networkx>=2.6.3 matplotlib>=3.5.1 rich>=12.0.0 - diff --git a/src/codegen/agents/agent.py b/src/codegen/agents/agent.py index 3f50bb87f..24772b625 100644 --- a/src/codegen/agents/agent.py +++ b/src/codegen/agents/agent.py @@ -1,5 +1,5 @@ import os -from typing import Any, Optional +from typing import Any from codegen.agents.client.openapi_client.api.agents_api import AgentsApi from codegen.agents.client.openapi_client.api_client import ApiClient @@ -44,7 +44,7 @@ def refresh(self) -> None: class Agent: """API client for interacting with Codegen AI agents.""" - def __init__(self, token: str, org_id: Optional[int] = None, base_url: Optional[str] = CODEGEN_BASE_API_URL): + def __init__(self, token: str, org_id: int | None = None, base_url: str | None = CODEGEN_BASE_API_URL): """Initialize a new Agent client. Args: @@ -81,7 +81,7 @@ def run(self, prompt: str) -> AgentTask: self.current_job = job return job - def get_status(self) -> Optional[dict[str, Any]]: + def get_status(self) -> dict[str, Any] | None: """Get the status of the current job. Returns: diff --git a/src/codegen/agents/chat_agent.py b/src/codegen/agents/chat_agent.py index 24ecada26..08c36a74f 100644 --- a/src/codegen/agents/chat_agent.py +++ b/src/codegen/agents/chat_agent.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from uuid import uuid4 from langchain.tools import BaseTool @@ -13,7 +13,7 @@ class ChatAgent: """Agent for interacting with a codebase.""" - def __init__(self, codebase: "Codebase", model_provider: str = "anthropic", model_name: str = "claude-3-5-sonnet-latest", memory: bool = True, tools: Optional[list[BaseTool]] = None, **kwargs): + def __init__(self, codebase: "Codebase", model_provider: str = "anthropic", model_name: str = "claude-3-5-sonnet-latest", memory: bool = True, tools: list[BaseTool] | None = None, **kwargs): """Initialize a CodeAgent. 
Args: @@ -31,7 +31,7 @@ def __init__(self, codebase: "Codebase", model_provider: str = "anthropic", mode self.codebase = codebase self.agent = create_chat_agent(self.codebase, model_provider=model_provider, model_name=model_name, memory=memory, additional_tools=tools, **kwargs) - def run(self, prompt: str, thread_id: Optional[str] = None) -> str: + def run(self, prompt: str, thread_id: str | None = None) -> str: """Run the agent with a prompt. Args: @@ -59,7 +59,7 @@ def run(self, prompt: str, thread_id: Optional[str] = None) -> str: return s["final_answer"] - def chat(self, prompt: str, thread_id: Optional[str] = None) -> tuple[str, str]: + def chat(self, prompt: str, thread_id: str | None = None) -> tuple[str, str]: """Chat with the agent, maintaining conversation history. Args: diff --git a/src/codegen/agents/client/openapi_client/api/agents_api.py b/src/codegen/agents/client/openapi_client/api/agents_api.py index 73c570738..7b8e8f47e 100644 --- a/src/codegen/agents/client/openapi_client/api/agents_api.py +++ b/src/codegen/agents/client/openapi_client/api/agents_api.py @@ -8,7 +8,7 @@ Do not edit the class manually. """ -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any from pydantic import Field, StrictFloat, StrictInt, StrictStr, validate_call @@ -36,11 +36,11 @@ def create_agent_run_v1_organizations_org_id_agent_run_post( self, org_id: StrictInt, create_agent_run_input: CreateAgentRunInput, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> AgentRunResponse: """Create Agent Run @@ -84,7 +84,7 @@ def create_agent_run_v1_organizations_org_id_agent_run_post( _host_index=_host_index, ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "AgentRunResponse", "422": "HTTPValidationError", } @@ -100,11 +100,11 @@ def create_agent_run_v1_organizations_org_id_agent_run_post_with_http_info( self, org_id: StrictInt, create_agent_run_input: CreateAgentRunInput, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> ApiResponse[AgentRunResponse]: """Create Agent Run @@ -148,7 +148,7 @@ def 
create_agent_run_v1_organizations_org_id_agent_run_post_with_http_info( _host_index=_host_index, ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "AgentRunResponse", "422": "HTTPValidationError", } @@ -164,11 +164,11 @@ def create_agent_run_v1_organizations_org_id_agent_run_post_without_preload_cont self, org_id: StrictInt, create_agent_run_input: CreateAgentRunInput, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: """Create Agent Run @@ -212,7 +212,7 @@ def create_agent_run_v1_organizations_org_id_agent_run_post_without_preload_cont _host_index=_host_index, ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "AgentRunResponse", "422": "HTTPValidationError", } @@ -235,10 +235,10 @@ def _create_agent_run_v1_organizations_org_id_agent_run_post_serialize( _path_params: dict[str, str] = {} _query_params: list[tuple[str, str]] = [] - _header_params: dict[str, Optional[str]] = _headers or {} + _header_params: dict[str, str | None] = _headers or {} _form_params: list[tuple[str, str]] = [] - _files: dict[str, Union[str, bytes, list[str], list[bytes], list[tuple[str, bytes]]]] = {} - _body_params: Optional[bytes] = None + _files: dict[str, str | bytes | list[str] | list[bytes] | list[tuple[str, bytes]]] = {} + _body_params: bytes | None = None # process the path parameters if org_id is not None: @@ -287,11 +287,11 @@ def create_agent_run_v1_organizations_org_id_agent_run_post_0( self, org_id: StrictInt, create_agent_run_input: CreateAgentRunInput, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> AgentRunResponse: """Create Agent Run @@ -335,7 +335,7 @@ def create_agent_run_v1_organizations_org_id_agent_run_post_0( _host_index=_host_index, ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "AgentRunResponse", "422": "HTTPValidationError", } @@ -351,11 +351,11 @@ def create_agent_run_v1_organizations_org_id_agent_run_post_0_with_http_info( self, org_id: StrictInt, create_agent_run_input: CreateAgentRunInput, - authorization: Optional[Any] = 
None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> ApiResponse[AgentRunResponse]: """Create Agent Run @@ -399,7 +399,7 @@ def create_agent_run_v1_organizations_org_id_agent_run_post_0_with_http_info( _host_index=_host_index, ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "AgentRunResponse", "422": "HTTPValidationError", } @@ -415,11 +415,11 @@ def create_agent_run_v1_organizations_org_id_agent_run_post_0_without_preload_co self, org_id: StrictInt, create_agent_run_input: CreateAgentRunInput, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: """Create Agent Run @@ -463,7 +463,7 @@ def create_agent_run_v1_organizations_org_id_agent_run_post_0_without_preload_co _host_index=_host_index, ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "AgentRunResponse", "422": "HTTPValidationError", } @@ -486,10 +486,10 @@ def _create_agent_run_v1_organizations_org_id_agent_run_post_0_serialize( _path_params: dict[str, str] = {} _query_params: list[tuple[str, str]] = [] - _header_params: dict[str, Optional[str]] = _headers or {} + _header_params: dict[str, str | None] = _headers or {} _form_params: list[tuple[str, str]] = [] - _files: dict[str, Union[str, bytes, list[str], list[bytes], list[tuple[str, bytes]]]] = {} - _body_params: Optional[bytes] = None + _files: dict[str, str | bytes | list[str] | list[bytes] | list[tuple[str, bytes]]] = {} + _body_params: bytes | None = None # process the path parameters if org_id is not None: @@ -538,11 +538,11 @@ def create_agent_run_v1_organizations_org_id_agent_run_post_1( self, org_id: StrictInt, create_agent_run_input: CreateAgentRunInput, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | 
tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> AgentRunResponse: """Create Agent Run @@ -586,7 +586,7 @@ def create_agent_run_v1_organizations_org_id_agent_run_post_1( _host_index=_host_index, ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "AgentRunResponse", "422": "HTTPValidationError", } @@ -602,11 +602,11 @@ def create_agent_run_v1_organizations_org_id_agent_run_post_1_with_http_info( self, org_id: StrictInt, create_agent_run_input: CreateAgentRunInput, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> ApiResponse[AgentRunResponse]: """Create Agent Run @@ -650,7 +650,7 @@ def create_agent_run_v1_organizations_org_id_agent_run_post_1_with_http_info( _host_index=_host_index, ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "AgentRunResponse", "422": "HTTPValidationError", } @@ -666,11 +666,11 @@ def create_agent_run_v1_organizations_org_id_agent_run_post_1_without_preload_co self, org_id: StrictInt, create_agent_run_input: CreateAgentRunInput, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: """Create Agent Run @@ -714,7 +714,7 @@ def create_agent_run_v1_organizations_org_id_agent_run_post_1_without_preload_co _host_index=_host_index, ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "AgentRunResponse", "422": "HTTPValidationError", } @@ -737,10 +737,10 @@ def _create_agent_run_v1_organizations_org_id_agent_run_post_1_serialize( _path_params: dict[str, str] = {} _query_params: list[tuple[str, str]] = [] - _header_params: dict[str, Optional[str]] = _headers or {} + _header_params: dict[str, str | None] = _headers or {} _form_params: list[tuple[str, str]] = [] - _files: dict[str, Union[str, bytes, list[str], list[bytes], list[tuple[str, bytes]]]] = {} - _body_params: Optional[bytes] = 
None + _files: dict[str, str | bytes | list[str] | list[bytes] | list[tuple[str, bytes]]] = {} + _body_params: bytes | None = None # process the path parameters if org_id is not None: @@ -789,11 +789,11 @@ def get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get( self, agent_run_id: StrictInt, org_id: StrictInt, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> AgentRunResponse: """Get Agent Run @@ -831,7 +831,7 @@ def get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get( agent_run_id=agent_run_id, org_id=org_id, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "AgentRunResponse", "422": "HTTPValidationError", } @@ -847,11 +847,11 @@ def get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get_with_http_i self, agent_run_id: StrictInt, org_id: StrictInt, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> ApiResponse[AgentRunResponse]: """Get Agent Run @@ -889,7 +889,7 @@ def get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get_with_http_i agent_run_id=agent_run_id, org_id=org_id, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "AgentRunResponse", "422": "HTTPValidationError", } @@ -905,11 +905,11 @@ def get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get_without_pre self, agent_run_id: StrictInt, org_id: StrictInt, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], 
Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: """Get Agent Run @@ -947,7 +947,7 @@ def get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get_without_pre agent_run_id=agent_run_id, org_id=org_id, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "AgentRunResponse", "422": "HTTPValidationError", } @@ -970,10 +970,10 @@ def _get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get_serialize( _path_params: dict[str, str] = {} _query_params: list[tuple[str, str]] = [] - _header_params: dict[str, Optional[str]] = _headers or {} + _header_params: dict[str, str | None] = _headers or {} _form_params: list[tuple[str, str]] = [] - _files: dict[str, Union[str, bytes, list[str], list[bytes], list[tuple[str, bytes]]]] = {} - _body_params: Optional[bytes] = None + _files: dict[str, str | bytes | list[str] | list[bytes] | list[tuple[str, bytes]]] = {} + _body_params: bytes | None = None # process the path parameters if agent_run_id is not None: @@ -1014,11 +1014,11 @@ def get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get_0( self, agent_run_id: StrictInt, org_id: StrictInt, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> AgentRunResponse: """Get Agent Run @@ -1056,7 +1056,7 @@ def get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get_0( agent_run_id=agent_run_id, org_id=org_id, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "AgentRunResponse", "422": "HTTPValidationError", } @@ -1072,11 +1072,11 @@ def get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get_0_with_http self, agent_run_id: StrictInt, org_id: StrictInt, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | 
None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> ApiResponse[AgentRunResponse]: """Get Agent Run @@ -1114,7 +1114,7 @@ def get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get_0_with_http agent_run_id=agent_run_id, org_id=org_id, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "AgentRunResponse", "422": "HTTPValidationError", } @@ -1130,11 +1130,11 @@ def get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get_0_without_p self, agent_run_id: StrictInt, org_id: StrictInt, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: """Get Agent Run @@ -1172,7 +1172,7 @@ def get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get_0_without_p agent_run_id=agent_run_id, org_id=org_id, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "AgentRunResponse", "422": "HTTPValidationError", } @@ -1195,10 +1195,10 @@ def _get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get_0_serializ _path_params: dict[str, str] = {} _query_params: list[tuple[str, str]] = [] - _header_params: dict[str, Optional[str]] = _headers or {} + _header_params: dict[str, str | None] = _headers or {} _form_params: list[tuple[str, str]] = [] - _files: dict[str, Union[str, bytes, list[str], list[bytes], list[tuple[str, bytes]]]] = {} - _body_params: Optional[bytes] = None + _files: dict[str, str | bytes | list[str] | list[bytes] | list[tuple[str, bytes]]] = {} + _body_params: bytes | None = None # process the path parameters if agent_run_id is not None: @@ -1239,11 +1239,11 @@ def get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get_1( self, agent_run_id: StrictInt, org_id: StrictInt, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> AgentRunResponse: """Get Agent Run @@ -1281,7 +1281,7 @@ def 
get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get_1( agent_run_id=agent_run_id, org_id=org_id, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "AgentRunResponse", "422": "HTTPValidationError", } @@ -1297,11 +1297,11 @@ def get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get_1_with_http self, agent_run_id: StrictInt, org_id: StrictInt, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> ApiResponse[AgentRunResponse]: """Get Agent Run @@ -1339,7 +1339,7 @@ def get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get_1_with_http agent_run_id=agent_run_id, org_id=org_id, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "AgentRunResponse", "422": "HTTPValidationError", } @@ -1355,11 +1355,11 @@ def get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get_1_without_p self, agent_run_id: StrictInt, org_id: StrictInt, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: """Get Agent Run @@ -1397,7 +1397,7 @@ def get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get_1_without_p agent_run_id=agent_run_id, org_id=org_id, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "AgentRunResponse", "422": "HTTPValidationError", } @@ -1420,10 +1420,10 @@ def _get_agent_run_v1_organizations_org_id_agent_run_agent_run_id_get_1_serializ _path_params: dict[str, str] = {} _query_params: list[tuple[str, str]] = [] - _header_params: dict[str, Optional[str]] = _headers or {} + _header_params: dict[str, str | None] = _headers or {} _form_params: list[tuple[str, str]] = [] - _files: dict[str, Union[str, bytes, list[str], 
list[bytes], list[tuple[str, bytes]]]] = {} - _body_params: Optional[bytes] = None + _files: dict[str, str | bytes | list[str] | list[bytes] | list[tuple[str, bytes]]] = {} + _body_params: bytes | None = None # process the path parameters if agent_run_id is not None: diff --git a/src/codegen/agents/client/openapi_client/api/organizations_api.py b/src/codegen/agents/client/openapi_client/api/organizations_api.py index 091c830a1..237141ef0 100644 --- a/src/codegen/agents/client/openapi_client/api/organizations_api.py +++ b/src/codegen/agents/client/openapi_client/api/organizations_api.py @@ -8,7 +8,7 @@ Do not edit the class manually. """ -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any from pydantic import Field, StrictFloat, StrictInt, StrictStr, validate_call @@ -33,13 +33,13 @@ def __init__(self, api_client=None) -> None: @validate_call def get_organizations_v1_organizations_get( self, - skip: Optional[Annotated[int, Field(strict=True, ge=0)]] = None, - limit: Optional[Annotated[int, Field(le=100, strict=True, ge=1)]] = None, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + skip: Annotated[int, Field(strict=True, ge=0)] | None = None, + limit: Annotated[int, Field(le=100, strict=True, ge=1)] | None = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> PageOrganizationResponse: """Get Organizations @@ -77,7 +77,7 @@ def get_organizations_v1_organizations_get( skip=skip, limit=limit, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "PageOrganizationResponse", "422": "HTTPValidationError", } @@ -91,13 +91,13 @@ def get_organizations_v1_organizations_get( @validate_call def get_organizations_v1_organizations_get_with_http_info( self, - skip: Optional[Annotated[int, Field(strict=True, ge=0)]] = None, - limit: Optional[Annotated[int, Field(le=100, strict=True, ge=1)]] = None, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + skip: Annotated[int, Field(strict=True, ge=0)] | None = None, + limit: Annotated[int, Field(le=100, strict=True, ge=1)] | None = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] 
= 0, ) -> ApiResponse[PageOrganizationResponse]: """Get Organizations @@ -135,7 +135,7 @@ def get_organizations_v1_organizations_get_with_http_info( skip=skip, limit=limit, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "PageOrganizationResponse", "422": "HTTPValidationError", } @@ -149,13 +149,13 @@ def get_organizations_v1_organizations_get_with_http_info( @validate_call def get_organizations_v1_organizations_get_without_preload_content( self, - skip: Optional[Annotated[int, Field(strict=True, ge=0)]] = None, - limit: Optional[Annotated[int, Field(le=100, strict=True, ge=1)]] = None, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + skip: Annotated[int, Field(strict=True, ge=0)] | None = None, + limit: Annotated[int, Field(le=100, strict=True, ge=1)] | None = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: """Get Organizations @@ -193,7 +193,7 @@ def get_organizations_v1_organizations_get_without_preload_content( skip=skip, limit=limit, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "PageOrganizationResponse", "422": "HTTPValidationError", } @@ -216,10 +216,10 @@ def _get_organizations_v1_organizations_get_serialize( _path_params: dict[str, str] = {} _query_params: list[tuple[str, str]] = [] - _header_params: dict[str, Optional[str]] = _headers or {} + _header_params: dict[str, str | None] = _headers or {} _form_params: list[tuple[str, str]] = [] - _files: dict[str, Union[str, bytes, list[str], list[bytes], list[tuple[str, bytes]]]] = {} - _body_params: Optional[bytes] = None + _files: dict[str, str | bytes | list[str] | list[bytes] | list[tuple[str, bytes]]] = {} + _body_params: bytes | None = None # process the path parameters # process the query parameters @@ -260,13 +260,13 @@ def _get_organizations_v1_organizations_get_serialize( @validate_call def get_organizations_v1_organizations_get_0( self, - skip: Optional[Annotated[int, Field(strict=True, ge=0)]] = None, - limit: Optional[Annotated[int, Field(le=100, strict=True, ge=1)]] = None, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + skip: Annotated[int, Field(strict=True, ge=0)] | None = None, + limit: Annotated[int, Field(le=100, strict=True, ge=1)] | None = None, + authorization: Any | None = 
None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> PageOrganizationResponse: """Get Organizations @@ -304,7 +304,7 @@ def get_organizations_v1_organizations_get_0( skip=skip, limit=limit, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "PageOrganizationResponse", "422": "HTTPValidationError", } @@ -318,13 +318,13 @@ def get_organizations_v1_organizations_get_0( @validate_call def get_organizations_v1_organizations_get_0_with_http_info( self, - skip: Optional[Annotated[int, Field(strict=True, ge=0)]] = None, - limit: Optional[Annotated[int, Field(le=100, strict=True, ge=1)]] = None, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + skip: Annotated[int, Field(strict=True, ge=0)] | None = None, + limit: Annotated[int, Field(le=100, strict=True, ge=1)] | None = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> ApiResponse[PageOrganizationResponse]: """Get Organizations @@ -362,7 +362,7 @@ def get_organizations_v1_organizations_get_0_with_http_info( skip=skip, limit=limit, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "PageOrganizationResponse", "422": "HTTPValidationError", } @@ -376,13 +376,13 @@ def get_organizations_v1_organizations_get_0_with_http_info( @validate_call def get_organizations_v1_organizations_get_0_without_preload_content( self, - skip: Optional[Annotated[int, Field(strict=True, ge=0)]] = None, - limit: Optional[Annotated[int, Field(le=100, strict=True, ge=1)]] = None, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + skip: Annotated[int, Field(strict=True, ge=0)] | None = None, + limit: Annotated[int, Field(le=100, strict=True, ge=1)] | None = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: 
dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: """Get Organizations @@ -420,7 +420,7 @@ def get_organizations_v1_organizations_get_0_without_preload_content( skip=skip, limit=limit, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "PageOrganizationResponse", "422": "HTTPValidationError", } @@ -443,10 +443,10 @@ def _get_organizations_v1_organizations_get_0_serialize( _path_params: dict[str, str] = {} _query_params: list[tuple[str, str]] = [] - _header_params: dict[str, Optional[str]] = _headers or {} + _header_params: dict[str, str | None] = _headers or {} _form_params: list[tuple[str, str]] = [] - _files: dict[str, Union[str, bytes, list[str], list[bytes], list[tuple[str, bytes]]]] = {} - _body_params: Optional[bytes] = None + _files: dict[str, str | bytes | list[str] | list[bytes] | list[tuple[str, bytes]]] = {} + _body_params: bytes | None = None # process the path parameters # process the query parameters @@ -487,13 +487,13 @@ def _get_organizations_v1_organizations_get_0_serialize( @validate_call def get_organizations_v1_organizations_get_1( self, - skip: Optional[Annotated[int, Field(strict=True, ge=0)]] = None, - limit: Optional[Annotated[int, Field(le=100, strict=True, ge=1)]] = None, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + skip: Annotated[int, Field(strict=True, ge=0)] | None = None, + limit: Annotated[int, Field(le=100, strict=True, ge=1)] | None = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> PageOrganizationResponse: """Get Organizations @@ -531,7 +531,7 @@ def get_organizations_v1_organizations_get_1( skip=skip, limit=limit, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "PageOrganizationResponse", "422": "HTTPValidationError", } @@ -545,13 +545,13 @@ def get_organizations_v1_organizations_get_1( @validate_call def get_organizations_v1_organizations_get_1_with_http_info( self, - skip: Optional[Annotated[int, Field(strict=True, ge=0)]] = None, - limit: Optional[Annotated[int, Field(le=100, strict=True, ge=1)]] = None, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + skip: Annotated[int, Field(strict=True, ge=0)] | None = None, + limit: Annotated[int, Field(le=100, strict=True, ge=1)] 
| None = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> ApiResponse[PageOrganizationResponse]: """Get Organizations @@ -589,7 +589,7 @@ def get_organizations_v1_organizations_get_1_with_http_info( skip=skip, limit=limit, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "PageOrganizationResponse", "422": "HTTPValidationError", } @@ -603,13 +603,13 @@ def get_organizations_v1_organizations_get_1_with_http_info( @validate_call def get_organizations_v1_organizations_get_1_without_preload_content( self, - skip: Optional[Annotated[int, Field(strict=True, ge=0)]] = None, - limit: Optional[Annotated[int, Field(le=100, strict=True, ge=1)]] = None, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + skip: Annotated[int, Field(strict=True, ge=0)] | None = None, + limit: Annotated[int, Field(le=100, strict=True, ge=1)] | None = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: """Get Organizations @@ -647,7 +647,7 @@ def get_organizations_v1_organizations_get_1_without_preload_content( skip=skip, limit=limit, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "PageOrganizationResponse", "422": "HTTPValidationError", } @@ -670,10 +670,10 @@ def _get_organizations_v1_organizations_get_1_serialize( _path_params: dict[str, str] = {} _query_params: list[tuple[str, str]] = [] - _header_params: dict[str, Optional[str]] = _headers or {} + _header_params: dict[str, str | None] = _headers or {} _form_params: list[tuple[str, str]] = [] - _files: dict[str, Union[str, bytes, list[str], list[bytes], list[tuple[str, bytes]]]] = {} - _body_params: Optional[bytes] = None + _files: dict[str, str | bytes | list[str] | list[bytes] | list[tuple[str, bytes]]] = {} + _body_params: bytes | None = None # process the path parameters # process the query parameters diff --git a/src/codegen/agents/client/openapi_client/api/users_api.py b/src/codegen/agents/client/openapi_client/api/users_api.py index 2787cd5a4..fba76ac08 100644 --- a/src/codegen/agents/client/openapi_client/api/users_api.py +++ b/src/codegen/agents/client/openapi_client/api/users_api.py @@ -8,7 +8,7 @@ Do not edit the class manually. 
""" -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any from pydantic import Field, StrictFloat, StrictInt, StrictStr, validate_call @@ -36,11 +36,11 @@ def get_user_v1_organizations_org_id_users_user_id_get( self, org_id: StrictStr, user_id: StrictStr, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> UserResponse: """Get User @@ -78,7 +78,7 @@ def get_user_v1_organizations_org_id_users_user_id_get( org_id=org_id, user_id=user_id, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "UserResponse", "422": "HTTPValidationError", } @@ -94,11 +94,11 @@ def get_user_v1_organizations_org_id_users_user_id_get_with_http_info( self, org_id: StrictStr, user_id: StrictStr, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> ApiResponse[UserResponse]: """Get User @@ -136,7 +136,7 @@ def get_user_v1_organizations_org_id_users_user_id_get_with_http_info( org_id=org_id, user_id=user_id, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "UserResponse", "422": "HTTPValidationError", } @@ -152,11 +152,11 @@ def get_user_v1_organizations_org_id_users_user_id_get_without_preload_content( self, org_id: StrictStr, user_id: StrictStr, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None 
= None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: """Get User @@ -194,7 +194,7 @@ def get_user_v1_organizations_org_id_users_user_id_get_without_preload_content( org_id=org_id, user_id=user_id, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "UserResponse", "422": "HTTPValidationError", } @@ -217,10 +217,10 @@ def _get_user_v1_organizations_org_id_users_user_id_get_serialize( _path_params: dict[str, str] = {} _query_params: list[tuple[str, str]] = [] - _header_params: dict[str, Optional[str]] = _headers or {} + _header_params: dict[str, str | None] = _headers or {} _form_params: list[tuple[str, str]] = [] - _files: dict[str, Union[str, bytes, list[str], list[bytes], list[tuple[str, bytes]]]] = {} - _body_params: Optional[bytes] = None + _files: dict[str, str | bytes | list[str] | list[bytes] | list[tuple[str, bytes]]] = {} + _body_params: bytes | None = None # process the path parameters if org_id is not None: @@ -261,11 +261,11 @@ def get_user_v1_organizations_org_id_users_user_id_get_0( self, org_id: StrictStr, user_id: StrictStr, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> UserResponse: """Get User @@ -303,7 +303,7 @@ def get_user_v1_organizations_org_id_users_user_id_get_0( org_id=org_id, user_id=user_id, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "UserResponse", "422": "HTTPValidationError", } @@ -319,11 +319,11 @@ def get_user_v1_organizations_org_id_users_user_id_get_0_with_http_info( self, org_id: StrictStr, user_id: StrictStr, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> ApiResponse[UserResponse]: """Get User @@ -361,7 +361,7 @@ def get_user_v1_organizations_org_id_users_user_id_get_0_with_http_info( org_id=org_id, user_id=user_id, 
authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "UserResponse", "422": "HTTPValidationError", } @@ -377,11 +377,11 @@ def get_user_v1_organizations_org_id_users_user_id_get_0_without_preload_content self, org_id: StrictStr, user_id: StrictStr, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: """Get User @@ -419,7 +419,7 @@ def get_user_v1_organizations_org_id_users_user_id_get_0_without_preload_content org_id=org_id, user_id=user_id, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "UserResponse", "422": "HTTPValidationError", } @@ -442,10 +442,10 @@ def _get_user_v1_organizations_org_id_users_user_id_get_0_serialize( _path_params: dict[str, str] = {} _query_params: list[tuple[str, str]] = [] - _header_params: dict[str, Optional[str]] = _headers or {} + _header_params: dict[str, str | None] = _headers or {} _form_params: list[tuple[str, str]] = [] - _files: dict[str, Union[str, bytes, list[str], list[bytes], list[tuple[str, bytes]]]] = {} - _body_params: Optional[bytes] = None + _files: dict[str, str | bytes | list[str] | list[bytes] | list[tuple[str, bytes]]] = {} + _body_params: bytes | None = None # process the path parameters if org_id is not None: @@ -486,11 +486,11 @@ def get_user_v1_organizations_org_id_users_user_id_get_1( self, org_id: StrictStr, user_id: StrictStr, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> UserResponse: """Get User @@ -528,7 +528,7 @@ def get_user_v1_organizations_org_id_users_user_id_get_1( org_id=org_id, user_id=user_id, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "UserResponse", "422": 
"HTTPValidationError", } @@ -544,11 +544,11 @@ def get_user_v1_organizations_org_id_users_user_id_get_1_with_http_info( self, org_id: StrictStr, user_id: StrictStr, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> ApiResponse[UserResponse]: """Get User @@ -586,7 +586,7 @@ def get_user_v1_organizations_org_id_users_user_id_get_1_with_http_info( org_id=org_id, user_id=user_id, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "UserResponse", "422": "HTTPValidationError", } @@ -602,11 +602,11 @@ def get_user_v1_organizations_org_id_users_user_id_get_1_without_preload_content self, org_id: StrictStr, user_id: StrictStr, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: """Get User @@ -644,7 +644,7 @@ def get_user_v1_organizations_org_id_users_user_id_get_1_without_preload_content org_id=org_id, user_id=user_id, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "UserResponse", "422": "HTTPValidationError", } @@ -667,10 +667,10 @@ def _get_user_v1_organizations_org_id_users_user_id_get_1_serialize( _path_params: dict[str, str] = {} _query_params: list[tuple[str, str]] = [] - _header_params: dict[str, Optional[str]] = _headers or {} + _header_params: dict[str, str | None] = _headers or {} _form_params: list[tuple[str, str]] = [] - _files: dict[str, Union[str, bytes, list[str], list[bytes], list[tuple[str, bytes]]]] = {} - _body_params: Optional[bytes] = None + _files: dict[str, str | bytes | list[str] | list[bytes] | list[tuple[str, bytes]]] = {} + _body_params: bytes | None = None # process the path parameters if org_id is not None: @@ -710,13 +710,13 @@ def _get_user_v1_organizations_org_id_users_user_id_get_1_serialize( def get_users_v1_organizations_org_id_users_get( self, org_id: StrictStr, - skip: Optional[Annotated[int, 
Field(strict=True, ge=0)]] = None, - limit: Optional[Annotated[int, Field(le=100, strict=True, ge=1)]] = None, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + skip: Annotated[int, Field(strict=True, ge=0)] | None = None, + limit: Annotated[int, Field(le=100, strict=True, ge=1)] | None = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> PageUserResponse: """Get Users @@ -756,7 +756,7 @@ def get_users_v1_organizations_org_id_users_get( org_id=org_id, skip=skip, limit=limit, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "PageUserResponse", "422": "HTTPValidationError", } @@ -771,13 +771,13 @@ def get_users_v1_organizations_org_id_users_get( def get_users_v1_organizations_org_id_users_get_with_http_info( self, org_id: StrictStr, - skip: Optional[Annotated[int, Field(strict=True, ge=0)]] = None, - limit: Optional[Annotated[int, Field(le=100, strict=True, ge=1)]] = None, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + skip: Annotated[int, Field(strict=True, ge=0)] | None = None, + limit: Annotated[int, Field(le=100, strict=True, ge=1)] | None = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> ApiResponse[PageUserResponse]: """Get Users @@ -817,7 +817,7 @@ def get_users_v1_organizations_org_id_users_get_with_http_info( org_id=org_id, skip=skip, limit=limit, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "PageUserResponse", "422": "HTTPValidationError", } @@ -832,13 +832,13 @@ def get_users_v1_organizations_org_id_users_get_with_http_info( def get_users_v1_organizations_org_id_users_get_without_preload_content( self, org_id: StrictStr, - skip: Optional[Annotated[int, Field(strict=True, ge=0)]] = None, - limit: Optional[Annotated[int, Field(le=100, strict=True, ge=1)]] = None, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], 
Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + skip: Annotated[int, Field(strict=True, ge=0)] | None = None, + limit: Annotated[int, Field(le=100, strict=True, ge=1)] | None = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: """Get Users @@ -878,7 +878,7 @@ def get_users_v1_organizations_org_id_users_get_without_preload_content( org_id=org_id, skip=skip, limit=limit, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "PageUserResponse", "422": "HTTPValidationError", } @@ -902,10 +902,10 @@ def _get_users_v1_organizations_org_id_users_get_serialize( _path_params: dict[str, str] = {} _query_params: list[tuple[str, str]] = [] - _header_params: dict[str, Optional[str]] = _headers or {} + _header_params: dict[str, str | None] = _headers or {} _form_params: list[tuple[str, str]] = [] - _files: dict[str, Union[str, bytes, list[str], list[bytes], list[tuple[str, bytes]]]] = {} - _body_params: Optional[bytes] = None + _files: dict[str, str | bytes | list[str] | list[bytes] | list[tuple[str, bytes]]] = {} + _body_params: bytes | None = None # process the path parameters if org_id is not None: @@ -949,13 +949,13 @@ def _get_users_v1_organizations_org_id_users_get_serialize( def get_users_v1_organizations_org_id_users_get_0( self, org_id: StrictStr, - skip: Optional[Annotated[int, Field(strict=True, ge=0)]] = None, - limit: Optional[Annotated[int, Field(le=100, strict=True, ge=1)]] = None, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + skip: Annotated[int, Field(strict=True, ge=0)] | None = None, + limit: Annotated[int, Field(le=100, strict=True, ge=1)] | None = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> PageUserResponse: """Get Users @@ -995,7 +995,7 @@ def get_users_v1_organizations_org_id_users_get_0( org_id=org_id, skip=skip, limit=limit, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "PageUserResponse", "422": "HTTPValidationError", } @@ -1010,13 +1010,13 @@ def get_users_v1_organizations_org_id_users_get_0( def 
get_users_v1_organizations_org_id_users_get_0_with_http_info( self, org_id: StrictStr, - skip: Optional[Annotated[int, Field(strict=True, ge=0)]] = None, - limit: Optional[Annotated[int, Field(le=100, strict=True, ge=1)]] = None, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + skip: Annotated[int, Field(strict=True, ge=0)] | None = None, + limit: Annotated[int, Field(le=100, strict=True, ge=1)] | None = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> ApiResponse[PageUserResponse]: """Get Users @@ -1056,7 +1056,7 @@ def get_users_v1_organizations_org_id_users_get_0_with_http_info( org_id=org_id, skip=skip, limit=limit, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "PageUserResponse", "422": "HTTPValidationError", } @@ -1071,13 +1071,13 @@ def get_users_v1_organizations_org_id_users_get_0_with_http_info( def get_users_v1_organizations_org_id_users_get_0_without_preload_content( self, org_id: StrictStr, - skip: Optional[Annotated[int, Field(strict=True, ge=0)]] = None, - limit: Optional[Annotated[int, Field(le=100, strict=True, ge=1)]] = None, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + skip: Annotated[int, Field(strict=True, ge=0)] | None = None, + limit: Annotated[int, Field(le=100, strict=True, ge=1)] | None = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: """Get Users @@ -1117,7 +1117,7 @@ def get_users_v1_organizations_org_id_users_get_0_without_preload_content( org_id=org_id, skip=skip, limit=limit, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "PageUserResponse", "422": "HTTPValidationError", } @@ -1141,10 +1141,10 @@ def _get_users_v1_organizations_org_id_users_get_0_serialize( _path_params: dict[str, str] = {} _query_params: list[tuple[str, str]] = [] - _header_params: dict[str, Optional[str]] = _headers or {} + _header_params: dict[str, str | None] = _headers or {} _form_params: 
list[tuple[str, str]] = [] - _files: dict[str, Union[str, bytes, list[str], list[bytes], list[tuple[str, bytes]]]] = {} - _body_params: Optional[bytes] = None + _files: dict[str, str | bytes | list[str] | list[bytes] | list[tuple[str, bytes]]] = {} + _body_params: bytes | None = None # process the path parameters if org_id is not None: @@ -1188,13 +1188,13 @@ def _get_users_v1_organizations_org_id_users_get_0_serialize( def get_users_v1_organizations_org_id_users_get_1( self, org_id: StrictStr, - skip: Optional[Annotated[int, Field(strict=True, ge=0)]] = None, - limit: Optional[Annotated[int, Field(le=100, strict=True, ge=1)]] = None, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + skip: Annotated[int, Field(strict=True, ge=0)] | None = None, + limit: Annotated[int, Field(le=100, strict=True, ge=1)] | None = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> PageUserResponse: """Get Users @@ -1234,7 +1234,7 @@ def get_users_v1_organizations_org_id_users_get_1( org_id=org_id, skip=skip, limit=limit, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "PageUserResponse", "422": "HTTPValidationError", } @@ -1249,13 +1249,13 @@ def get_users_v1_organizations_org_id_users_get_1( def get_users_v1_organizations_org_id_users_get_1_with_http_info( self, org_id: StrictStr, - skip: Optional[Annotated[int, Field(strict=True, ge=0)]] = None, - limit: Optional[Annotated[int, Field(le=100, strict=True, ge=1)]] = None, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + skip: Annotated[int, Field(strict=True, ge=0)] | None = None, + limit: Annotated[int, Field(le=100, strict=True, ge=1)] | None = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> ApiResponse[PageUserResponse]: """Get Users @@ -1295,7 +1295,7 @@ def get_users_v1_organizations_org_id_users_get_1_with_http_info( org_id=org_id, skip=skip, limit=limit, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { 
"200": "PageUserResponse", "422": "HTTPValidationError", } @@ -1310,13 +1310,13 @@ def get_users_v1_organizations_org_id_users_get_1_with_http_info( def get_users_v1_organizations_org_id_users_get_1_without_preload_content( self, org_id: StrictStr, - skip: Optional[Annotated[int, Field(strict=True, ge=0)]] = None, - limit: Optional[Annotated[int, Field(le=100, strict=True, ge=1)]] = None, - authorization: Optional[Any] = None, - _request_timeout: Union[None, Annotated[StrictFloat, Field(gt=0)], tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]]] = None, - _request_auth: Optional[dict[StrictStr, Any]] = None, - _content_type: Optional[StrictStr] = None, - _headers: Optional[dict[StrictStr, Any]] = None, + skip: Annotated[int, Field(strict=True, ge=0)] | None = None, + limit: Annotated[int, Field(le=100, strict=True, ge=1)] | None = None, + authorization: Any | None = None, + _request_timeout: None | Annotated[StrictFloat, Field(gt=0)] | tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]] = None, + _request_auth: dict[StrictStr, Any] | None = None, + _content_type: StrictStr | None = None, + _headers: dict[StrictStr, Any] | None = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: """Get Users @@ -1356,7 +1356,7 @@ def get_users_v1_organizations_org_id_users_get_1_without_preload_content( org_id=org_id, skip=skip, limit=limit, authorization=authorization, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index ) - _response_types_map: dict[str, Optional[str]] = { + _response_types_map: dict[str, str | None] = { "200": "PageUserResponse", "422": "HTTPValidationError", } @@ -1380,10 +1380,10 @@ def _get_users_v1_organizations_org_id_users_get_1_serialize( _path_params: dict[str, str] = {} _query_params: list[tuple[str, str]] = [] - _header_params: dict[str, Optional[str]] = _headers or {} + _header_params: dict[str, str | None] = _headers or {} _form_params: list[tuple[str, str]] = [] - _files: dict[str, Union[str, bytes, list[str], list[bytes], list[tuple[str, bytes]]]] = {} - _body_params: Optional[bytes] = None + _files: dict[str, str | bytes | list[str] | list[bytes] | list[tuple[str, bytes]]] = {} + _body_params: bytes | None = None # process the path parameters if org_id is not None: diff --git a/src/codegen/agents/client/openapi_client/api_client.py b/src/codegen/agents/client/openapi_client/api_client.py index 02e3fdb6c..68dcad83a 100644 --- a/src/codegen/agents/client/openapi_client/api_client.py +++ b/src/codegen/agents/client/openapi_client/api_client.py @@ -15,13 +15,12 @@ import os import re import tempfile + +# Import for date parsing +from datetime import datetime as dt from enum import Enum -from typing import Optional, Union from urllib.parse import quote -from dateutil.parser import parse -from pydantic import SecretStr - import codegen.agents.client.openapi_client as openapi_client from codegen.agents.client.openapi_client import rest from codegen.agents.client.openapi_client.api_response import ApiResponse @@ -29,7 +28,7 @@ from codegen.agents.client.openapi_client.configuration import Configuration from codegen.agents.client.openapi_client.exceptions import ApiException, ApiValueError -RequestSerialized = tuple[str, str, dict[str, str], Optional[str], list[str]] +RequestSerialized = tuple[str, str, dict[str, str], str | None, list[str]] class ApiClient: @@ -42,8 +41,8 @@ class ApiClient: :param configuration: .Configuration object for 
this client :param header_name: a header to pass when making calls to the API. - :param header_value: a header value to pass when making calls to - the API. + :param header_value: a header value to pass when making calls + to the API. :param cookie: a cookie to include in the header when making calls to the API """ @@ -227,7 +226,7 @@ def call_api(self, method, url, header_params=None, body=None, post_params=None, return response_data - def response_deserialize(self, response_data: rest.RESTResponse, response_types_map: Optional[dict[str, ApiResponseT]] = None) -> ApiResponse[ApiResponseT]: + def response_deserialize(self, response_data: rest.RESTResponse, response_types_map: dict[str, ApiResponseT] | None = None) -> ApiResponse[ApiResponseT]: """Deserializes response into an object. :param response_data: RESTResponse object to be deserialized. :param response_types_map: dict of response types. @@ -295,7 +294,7 @@ def sanitize_for_serialization(self, obj): return [self.sanitize_for_serialization(sub_obj) for sub_obj in obj] elif isinstance(obj, tuple): return tuple(self.sanitize_for_serialization(sub_obj) for sub_obj in obj) - elif isinstance(obj, (datetime.datetime, datetime.date)): + elif isinstance(obj, datetime.datetime | datetime.date): return obj.isoformat() elif isinstance(obj, decimal.Decimal): return str(obj) @@ -315,7 +314,7 @@ def sanitize_for_serialization(self, obj): return {key: self.sanitize_for_serialization(val) for key, val in obj_dict.items()} - def deserialize(self, response_text: str, response_type: str, content_type: Optional[str]): + def deserialize(self, response_text: str, response_type: str, content_type: str | None): """Deserializes response into an object. :param response: RESTResponse object to be deserialized. @@ -430,7 +429,7 @@ def parameters_to_url_query(self, params, collection_formats): for k, v in params.items() if isinstance(params, dict) else params: if isinstance(v, bool): v = str(v).lower() - if isinstance(v, (int, float)): + if isinstance(v, int | float): v = str(v) if isinstance(v, dict): v = json.dumps(v) @@ -456,7 +455,7 @@ def parameters_to_url_query(self, params, collection_formats): def files_parameters( self, - files: dict[str, Union[str, bytes, list[str], list[bytes], tuple[str, bytes]]], + files: dict[str, str | bytes | list[str] | list[bytes] | tuple[str, bytes]], ): """Builds form parameters. @@ -485,7 +484,7 @@ def files_parameters( params.append(tuple([k, tuple([filename, filedata, mimetype])])) return params - def select_header_accept(self, accepts: list[str]) -> Optional[str]: + def select_header_accept(self, accepts: list[str]) -> str | None: """Returns `Accept` based on an array of accepts provided. :param accepts: List of headers. @@ -618,7 +617,8 @@ def __deserialize_date(self, string): :return: date. """ try: - return parse(string).date() + # Use datetime's own parsing instead of dateutil + return dt.fromisoformat(string.replace("Z", "+00:00")).date() except ImportError: return string except ValueError: @@ -633,7 +633,8 @@ def __deserialize_datetime(self, string): :return: datetime. 
""" try: - return parse(string) + # Use datetime's own parsing instead of dateutil + return dt.fromisoformat(string.replace("Z", "+00:00")) except ImportError: return string except ValueError: diff --git a/src/codegen/agents/client/openapi_client/configuration.py b/src/codegen/agents/client/openapi_client/configuration.py index 8c1b3f2e6..39fc424fe 100644 --- a/src/codegen/agents/client/openapi_client/configuration.py +++ b/src/codegen/agents/client/openapi_client/configuration.py @@ -14,10 +14,9 @@ import multiprocessing import sys from logging import FileHandler -from typing import Any, ClassVar, Literal, Optional, TypedDict, Union +from typing import Any, ClassVar, Literal, NotRequired, Self, TypedDict import urllib3 -from typing_extensions import NotRequired, Self JSON_SCHEMA_VALIDATION_KEYWORDS = {"multipleOf", "maximum", "exclusiveMaximum", "minimum", "exclusiveMinimum", "maxLength", "minLength", "pattern", "maxItems", "minItems"} @@ -51,7 +50,7 @@ "type": Literal["api_key"], "in": str, "key": str, - "value": Optional[str], + "value": str | None, }, ) @@ -62,7 +61,7 @@ "type": Literal["basic"], "in": Literal["header"], "key": Literal["Authorization"], - "value": Optional[str], + "value": str | None, }, ) @@ -152,26 +151,26 @@ class Configuration: """ - _default: ClassVar[Optional[Self]] = None + _default: ClassVar[Self | None] = None def __init__( self, - host: Optional[str] = None, - api_key: Optional[dict[str, str]] = None, - api_key_prefix: Optional[dict[str, str]] = None, - username: Optional[str] = None, - password: Optional[str] = None, - access_token: Optional[str] = None, - server_index: Optional[int] = None, - server_variables: Optional[ServerVariablesT] = None, - server_operation_index: Optional[dict[int, int]] = None, - server_operation_variables: Optional[dict[int, ServerVariablesT]] = None, + host: str | None = None, + api_key: dict[str, str] | None = None, + api_key_prefix: dict[str, str] | None = None, + username: str | None = None, + password: str | None = None, + access_token: str | None = None, + server_index: int | None = None, + server_variables: ServerVariablesT | None = None, + server_operation_index: dict[int, int] | None = None, + server_operation_variables: dict[int, ServerVariablesT] | None = None, ignore_operation_servers: bool = False, - ssl_ca_cert: Optional[str] = None, - retries: Optional[int] = None, - ca_cert_data: Optional[Union[str, bytes]] = None, + ssl_ca_cert: str | None = None, + retries: int | None = None, + ca_cert_data: str | bytes | None = None, *, - debug: Optional[bool] = None, + debug: bool | None = None, ) -> None: """Constructor""" self._base_path = "http://localhost" if host is None else host @@ -225,7 +224,7 @@ def __init__( self.logger_stream_handler = None """Log stream handler """ - self.logger_file_handler: Optional[FileHandler] = None + self.logger_file_handler: FileHandler | None = None """Log file handler """ self.logger_file = None @@ -272,7 +271,7 @@ def __init__( cpu_count * 5 is used as default value to increase performance. """ - self.proxy: Optional[str] = None + self.proxy: str | None = None """Proxy URL """ self.proxy_headers = None @@ -317,7 +316,7 @@ def __setattr__(self, name: str, value: Any) -> None: object.__setattr__(self, name, value) @classmethod - def set_default(cls, default: Optional[Self]) -> None: + def set_default(cls, default: Self | None) -> None: """Set default instance of configuration. 
It stores default configuration, which can be @@ -352,7 +351,7 @@ def get_default(cls) -> Self: return cls._default @property - def logger_file(self) -> Optional[str]: + def logger_file(self) -> str | None: """The logger file. If the logger_file is None, then add stream handler and remove file @@ -364,7 +363,7 @@ def logger_file(self) -> Optional[str]: return self.__logger_file @logger_file.setter - def logger_file(self, value: Optional[str]) -> None: + def logger_file(self, value: str | None) -> None: """The logger file. If the logger_file is None, then add stream handler and remove file @@ -436,7 +435,7 @@ def logger_format(self, value: str) -> None: self.__logger_format = value self.logger_formatter = logging.Formatter(self.__logger_format) - def get_api_key_with_prefix(self, identifier: str, alias: Optional[str] = None) -> Optional[str]: + def get_api_key_with_prefix(self, identifier: str, alias: str | None = None) -> str | None: """Gets API key (with prefix if set). :param identifier: The identifier of apiKey. @@ -455,7 +454,7 @@ def get_api_key_with_prefix(self, identifier: str, alias: Optional[str] = None) return None - def get_basic_auth_token(self) -> Optional[str]: + def get_basic_auth_token(self) -> str | None: """Gets HTTP basic authentication header (string). :return: The token for basic HTTP authentication. @@ -497,9 +496,9 @@ def get_host_settings(self) -> list[HostSetting]: def get_host_from_settings( self, - index: Optional[int], - variables: Optional[ServerVariablesT] = None, - servers: Optional[list[HostSetting]] = None, + index: int | None, + variables: ServerVariablesT | None = None, + servers: list[HostSetting] | None = None, ) -> str: """Gets host URL based on the index and variables :param index: array index of the host settings diff --git a/src/codegen/agents/client/openapi_client/exceptions.py b/src/codegen/agents/client/openapi_client/exceptions.py index 46e8604fa..cdca2dce3 100644 --- a/src/codegen/agents/client/openapi_client/exceptions.py +++ b/src/codegen/agents/client/openapi_client/exceptions.py @@ -8,9 +8,7 @@ Do not edit the class manually. 
""" -from typing import Any, Optional - -from typing_extensions import Self +from typing import Any, Self class OpenApiException(Exception): @@ -102,8 +100,8 @@ def __init__( reason=None, http_resp=None, *, - body: Optional[str] = None, - data: Optional[Any] = None, + body: str | None = None, + data: Any | None = None, ) -> None: self.status = status self.reason = reason @@ -128,8 +126,8 @@ def from_response( cls, *, http_resp, - body: Optional[str], - data: Optional[Any], + body: str | None, + data: Any | None, ) -> Self: if http_resp.status == 400: raise BadRequestException(http_resp=http_resp, body=body, data=data) diff --git a/src/codegen/agents/client/openapi_client/models/agent_run_response.py b/src/codegen/agents/client/openapi_client/models/agent_run_response.py index 10190e84f..ea92b6f72 100644 --- a/src/codegen/agents/client/openapi_client/models/agent_run_response.py +++ b/src/codegen/agents/client/openapi_client/models/agent_run_response.py @@ -13,10 +13,9 @@ import json import pprint import re # noqa: F401 -from typing import Any, ClassVar +from typing import Any, ClassVar, Self from pydantic import BaseModel, ConfigDict, StrictInt, StrictStr -from typing_extensions import Self class AgentRunResponse(BaseModel): diff --git a/src/codegen/agents/client/openapi_client/models/create_agent_run_input.py b/src/codegen/agents/client/openapi_client/models/create_agent_run_input.py index 74469d48f..29c7efbbf 100644 --- a/src/codegen/agents/client/openapi_client/models/create_agent_run_input.py +++ b/src/codegen/agents/client/openapi_client/models/create_agent_run_input.py @@ -13,10 +13,9 @@ import json import pprint import re # noqa: F401 -from typing import Any, ClassVar +from typing import Any, ClassVar, Self from pydantic import BaseModel, ConfigDict, StrictStr -from typing_extensions import Self class CreateAgentRunInput(BaseModel): diff --git a/src/codegen/agents/client/openapi_client/models/http_validation_error.py b/src/codegen/agents/client/openapi_client/models/http_validation_error.py index 1f9125486..ea51b64b6 100644 --- a/src/codegen/agents/client/openapi_client/models/http_validation_error.py +++ b/src/codegen/agents/client/openapi_client/models/http_validation_error.py @@ -13,10 +13,9 @@ import json import pprint import re # noqa: F401 -from typing import Any, ClassVar +from typing import Any, ClassVar, Self from pydantic import BaseModel, ConfigDict -from typing_extensions import Self from codegen.agents.client.openapi_client.models.validation_error import ValidationError diff --git a/src/codegen/agents/client/openapi_client/models/organization_response.py b/src/codegen/agents/client/openapi_client/models/organization_response.py index 5912da5f1..c33f6070e 100644 --- a/src/codegen/agents/client/openapi_client/models/organization_response.py +++ b/src/codegen/agents/client/openapi_client/models/organization_response.py @@ -13,10 +13,9 @@ import json import pprint import re # noqa: F401 -from typing import Any, ClassVar +from typing import Any, ClassVar, Self from pydantic import BaseModel, ConfigDict, StrictInt, StrictStr -from typing_extensions import Self from codegen.agents.client.openapi_client.models.organization_settings import OrganizationSettings diff --git a/src/codegen/agents/client/openapi_client/models/organization_settings.py b/src/codegen/agents/client/openapi_client/models/organization_settings.py index d7538126d..7c7ca8eba 100644 --- a/src/codegen/agents/client/openapi_client/models/organization_settings.py +++ 
b/src/codegen/agents/client/openapi_client/models/organization_settings.py @@ -13,10 +13,9 @@ import json import pprint import re # noqa: F401 -from typing import Any, ClassVar +from typing import Any, ClassVar, Self from pydantic import BaseModel, ConfigDict, StrictBool -from typing_extensions import Self class OrganizationSettings(BaseModel): diff --git a/src/codegen/agents/client/openapi_client/models/page_organization_response.py b/src/codegen/agents/client/openapi_client/models/page_organization_response.py index 1390ac802..7e4ff3ec4 100644 --- a/src/codegen/agents/client/openapi_client/models/page_organization_response.py +++ b/src/codegen/agents/client/openapi_client/models/page_organization_response.py @@ -13,10 +13,9 @@ import json import pprint import re # noqa: F401 -from typing import Any, ClassVar +from typing import Any, ClassVar, Self from pydantic import BaseModel, ConfigDict, StrictInt -from typing_extensions import Self from codegen.agents.client.openapi_client.models.organization_response import OrganizationResponse diff --git a/src/codegen/agents/client/openapi_client/models/page_user_response.py b/src/codegen/agents/client/openapi_client/models/page_user_response.py index b79c8b7e3..e1b5c08a3 100644 --- a/src/codegen/agents/client/openapi_client/models/page_user_response.py +++ b/src/codegen/agents/client/openapi_client/models/page_user_response.py @@ -13,10 +13,9 @@ import json import pprint import re # noqa: F401 -from typing import Any, ClassVar +from typing import Any, ClassVar, Self from pydantic import BaseModel, ConfigDict, StrictInt -from typing_extensions import Self from codegen.agents.client.openapi_client.models.user_response import UserResponse diff --git a/src/codegen/agents/client/openapi_client/models/user_response.py b/src/codegen/agents/client/openapi_client/models/user_response.py index 71aa4cdae..50ae56824 100644 --- a/src/codegen/agents/client/openapi_client/models/user_response.py +++ b/src/codegen/agents/client/openapi_client/models/user_response.py @@ -13,10 +13,9 @@ import json import pprint import re # noqa: F401 -from typing import Any, ClassVar +from typing import Any, ClassVar, Self from pydantic import BaseModel, ConfigDict, StrictInt, StrictStr -from typing_extensions import Self class UserResponse(BaseModel): diff --git a/src/codegen/agents/client/openapi_client/models/validation_error.py b/src/codegen/agents/client/openapi_client/models/validation_error.py index fab1db703..ae7c23f11 100644 --- a/src/codegen/agents/client/openapi_client/models/validation_error.py +++ b/src/codegen/agents/client/openapi_client/models/validation_error.py @@ -13,10 +13,9 @@ import json import pprint import re # noqa: F401 -from typing import Any, ClassVar +from typing import Any, ClassVar, Self from pydantic import BaseModel, ConfigDict, StrictStr -from typing_extensions import Self from codegen.agents.client.openapi_client.models.validation_error_loc_inner import ValidationErrorLocInner diff --git a/src/codegen/agents/client/openapi_client/models/validation_error_loc_inner.py b/src/codegen/agents/client/openapi_client/models/validation_error_loc_inner.py index 04a44443e..59f0e20d7 100644 --- a/src/codegen/agents/client/openapi_client/models/validation_error_loc_inner.py +++ b/src/codegen/agents/client/openapi_client/models/validation_error_loc_inner.py @@ -13,10 +13,9 @@ import json import pprint import re # noqa: F401 -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Self from pydantic import BaseModel, StrictInt, StrictStr, 
ValidationError, field_validator -from typing_extensions import Self VALIDATIONERRORLOCINNER_ANY_OF_SCHEMAS = ["int", "str"] diff --git a/src/codegen/agents/client/openapi_client/rest.py b/src/codegen/agents/client/openapi_client/rest.py index 445c0144c..004ec00b6 100644 --- a/src/codegen/agents/client/openapi_client/rest.py +++ b/src/codegen/agents/client/openapi_client/rest.py @@ -131,7 +131,7 @@ def request(self, method, url, headers=None, body=None, post_params=None, _reque timeout = None if _request_timeout: - if isinstance(_request_timeout, (int, float)): + if isinstance(_request_timeout, int | float): timeout = urllib3.Timeout(total=_request_timeout) elif isinstance(_request_timeout, tuple) and len(_request_timeout) == 2: timeout = urllib3.Timeout(connect=_request_timeout[0], read=_request_timeout[1]) diff --git a/src/codegen/agents/code_agent.py b/src/codegen/agents/code_agent.py index 693c0cd44..e1da384f1 100644 --- a/src/codegen/agents/code_agent.py +++ b/src/codegen/agents/code_agent.py @@ -1,5 +1,5 @@ import os -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from uuid import uuid4 from langchain.tools import BaseTool @@ -32,7 +32,7 @@ class CodeAgent: run_id: str | None = None instance_id: str | None = None difficulty: int | None = None - logger: Optional[ExternalLogger] = None + logger: ExternalLogger | None = None def __init__( self, @@ -40,12 +40,12 @@ def __init__( model_provider: str = "anthropic", model_name: str = "claude-3-7-sonnet-latest", memory: bool = True, - tools: Optional[list[BaseTool]] = None, - tags: Optional[list[str]] = [], - metadata: Optional[dict] = {}, - agent_config: Optional[AgentConfig] = None, - thread_id: Optional[str] = None, - logger: Optional[ExternalLogger] = None, + tools: list[BaseTool] | None = None, + tags: list[str] | None = [], + metadata: dict | None = {}, + agent_config: AgentConfig | None = None, + thread_id: str | None = None, + logger: ExternalLogger | None = None, **kwargs, ): """Initialize a CodeAgent. @@ -106,7 +106,7 @@ def __init__( **metadata, } - def run(self, prompt: str, image_urls: Optional[list[str]] = None) -> str: + def run(self, prompt: str, image_urls: list[str] | None = None) -> str: """Run the agent with a prompt and optional images. 
Args: diff --git a/src/codegen/agents/data.py b/src/codegen/agents/data.py index fab2283da..3b81d202b 100644 --- a/src/codegen/agents/data.py +++ b/src/codegen/agents/data.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from datetime import UTC, datetime -from typing import Literal, Optional, Union +from typing import Literal # Base dataclass for all message types @@ -31,9 +31,9 @@ class SystemMessageData(BaseMessage): class ToolCall: """Represents a tool call within an assistant message.""" - name: Optional[str] = None - arguments: Optional[str] = None - id: Optional[str] = None + name: str | None = None + arguments: str | None = None + id: str | None = None @dataclass @@ -49,10 +49,10 @@ class ToolMessageData(BaseMessage): """Represents a tool response message.""" type: Literal["tool"] = field(default="tool") - tool_name: Optional[str] = None - tool_response: Optional[str] = None - tool_id: Optional[str] = None - status: Optional[str] = None + tool_name: str | None = None + tool_response: str | None = None + tool_id: str | None = None + status: str | None = None @dataclass @@ -69,4 +69,4 @@ class UnknownMessage(BaseMessage): type: Literal["unknown"] = field(default="unknown") -type AgentRunMessage = Union[UserMessage, SystemMessageData, AssistantMessage, ToolMessageData, FunctionMessageData, UnknownMessage] +type AgentRunMessage = UserMessage | SystemMessageData | AssistantMessage | ToolMessageData | FunctionMessageData | UnknownMessage diff --git a/src/codegen/agents/tracer.py b/src/codegen/agents/tracer.py index ef711b9e9..4563da17d 100644 --- a/src/codegen/agents/tracer.py +++ b/src/codegen/agents/tracer.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from langchain.schema import AIMessage, HumanMessage from langchain.schema import FunctionMessage as LCFunctionMessage @@ -11,7 +11,7 @@ class MessageStreamTracer: - def __init__(self, logger: Optional[ExternalLogger] = None): + def __init__(self, logger: ExternalLogger | None = None): self.traces = [] self.logger = logger @@ -34,7 +34,7 @@ def process_stream(self, message_stream: Generator) -> Generator: # Pass through the chunk to maintain the original stream behavior yield chunk - def extract_structured_data(self, chunk: dict[str, Any]) -> Optional[BaseMessage]: + def extract_structured_data(self, chunk: dict[str, Any]) -> BaseMessage | None: """Extract structured data from a message chunk. Returns None if the chunk doesn't contain useful information. Returns a BaseMessage subclass instance based on the message type. 
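The annotation changes running through these generated client and agent modules follow one mechanical pattern: `Optional[X]` becomes `X | None`, `Union[A, B]` becomes `A | B`, `isinstance` checks take union types directly, and `Self`/`NotRequired` are imported from `typing` rather than `typing_extensions`. As a rough, standalone sketch of what the new syntax assumes (PEP 604 unions and `isinstance` with `|` need Python 3.10+; `typing.Self`, `typing.NotRequired`, and `fromisoformat` accepting a trailing `Z` need 3.11+), the snippet below is illustrative only and not part of the patch; its `page_params` helper is invented for the example.

```python
from datetime import datetime

# PEP 604 (Python 3.10+): "X | None" replaces Optional[X], "A | B" replaces Union[A, B].
def page_params(skip: int | None = None, limit: int | None = None) -> dict[str, int]:
    # Build query parameters, dropping values that were not supplied.
    return {k: v for k, v in {"skip": skip, "limit": limit}.items() if v is not None}

# isinstance() accepts a union type directly on Python 3.10+.
assert isinstance(2.5, int | float)

# datetime.fromisoformat() stands in for dateutil.parser.parse on ISO-8601 strings.
# Rewriting "Z" as "+00:00" keeps this working on Python 3.10, where
# fromisoformat() does not yet accept a trailing "Z".
parsed = datetime.fromisoformat("2024-01-01T12:00:00Z".replace("Z", "+00:00"))

print(page_params(skip=0, limit=100), parsed.isoformat())
```

One behavioral caveat: `fromisoformat` is stricter than `dateutil.parser.parse`, so non-ISO date strings that the old parser tolerated will now fall through to the surrounding `ValueError` handling in the deserializers.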
diff --git a/src/codegen/cli/commands/serve/main.py b/src/codegen/cli/commands/serve/main.py index 4c2dbc12e..0a074075e 100644 --- a/src/codegen/cli/commands/serve/main.py +++ b/src/codegen/cli/commands/serve/main.py @@ -4,7 +4,6 @@ import subprocess import sys from pathlib import Path -from typing import Optional import rich import rich_click as click @@ -89,7 +88,7 @@ def create_app_module(file_path: Path) -> str: return f"{module_name}:app" -def start_ngrok(port: int) -> Optional[str]: +def start_ngrok(port: int) -> str | None: """Start ngrok and return the public URL""" try: import requests diff --git a/src/codegen/extensions/attribution/git_history.py b/src/codegen/extensions/attribution/git_history.py index 39dfcc740..714dde359 100644 --- a/src/codegen/extensions/attribution/git_history.py +++ b/src/codegen/extensions/attribution/git_history.py @@ -1,7 +1,6 @@ import time from collections import defaultdict, deque from datetime import datetime -from typing import Optional import pygit2 from intervaltree import IntervalTree @@ -16,7 +15,7 @@ class GitAttributionTracker: """Tracks attribution information for code symbols based on git history.""" - def __init__(self, codebase: Codebase, ai_authors: Optional[list[str]] = None): + def __init__(self, codebase: Codebase, ai_authors: list[str] | None = None): """Initialize the attribution tracker. Args: @@ -43,7 +42,7 @@ def __init__(self, codebase: Codebase, ai_authors: Optional[list[str]] = None): self._commits: deque[Commit] - def build_history(self, max_commits: Optional[int] = None) -> None: + def build_history(self, max_commits: int | None = None) -> None: """Build the git history for the codebase. Args: @@ -325,7 +324,7 @@ def get_symbol_history(self, symbol: Symbol) -> list[dict]: symbol_id = f"{symbol.filepath}:{symbol.name}" return self._symbol_history.get(symbol_id, []) - def get_symbol_last_editor(self, symbol: Symbol) -> Optional[str]: + def get_symbol_last_editor(self, symbol: Symbol) -> str | None: """Get the last person who edited a symbol. Args: diff --git a/src/codegen/extensions/attribution/main.py b/src/codegen/extensions/attribution/main.py index a282fda89..c197a7a3b 100644 --- a/src/codegen/extensions/attribution/main.py +++ b/src/codegen/extensions/attribution/main.py @@ -1,10 +1,8 @@ -from typing import Optional - from codegen.extensions.attribution.git_history import GitAttributionTracker from codegen.sdk.core.codebase import Codebase -def analyze_ai_impact(codebase: Codebase, ai_authors: Optional[list[str]] = None, max_commits: Optional[int] = None) -> dict: +def analyze_ai_impact(codebase: Codebase, ai_authors: list[str] | None = None, max_commits: int | None = None) -> dict: """Analyze the impact of AI on a codebase. Args: @@ -57,7 +55,7 @@ def analyze_ai_impact(codebase: Codebase, ai_authors: Optional[list[str]] = None } -def add_attribution_to_symbols(codebase: Codebase, ai_authors: Optional[list[str]] = None) -> None: +def add_attribution_to_symbols(codebase: Codebase, ai_authors: list[str] | None = None) -> None: """Add attribution information to symbols in the codebase. 
This adds the following attributes to each symbol: diff --git a/src/codegen/extensions/events/codegen_app.py b/src/codegen/extensions/events/codegen_app.py index 024730168..4d3e5b415 100644 --- a/src/codegen/extensions/events/codegen_app.py +++ b/src/codegen/extensions/events/codegen_app.py @@ -1,5 +1,5 @@ import os -from typing import Any, Optional +from typing import Any from fastapi import FastAPI, Request from fastapi.responses import HTMLResponse @@ -23,7 +23,7 @@ class CodegenApp: linear: Linear slack: Slack - def __init__(self, name: str, repo: Optional[str] = None, tmp_dir: str = "/tmp/codegen", commit: str | None = "latest"): + def __init__(self, name: str, repo: str | None = None, tmp_dir: str = "/tmp/codegen", commit: str | None = "latest"): self.name = name self.tmp_dir = tmp_dir diff --git a/src/codegen/extensions/events/github.py b/src/codegen/extensions/events/github.py index d17b16aef..5f1733ff3 100644 --- a/src/codegen/extensions/events/github.py +++ b/src/codegen/extensions/events/github.py @@ -1,6 +1,7 @@ import logging import os -from typing import Any, Callable, TypeVar +from collections.abc import Callable +from typing import Any, TypeVar from fastapi import Request from github import Github diff --git a/src/codegen/extensions/events/github_types.py b/src/codegen/extensions/events/github_types.py index fd3f62536..18cce2c82 100644 --- a/src/codegen/extensions/events/github_types.py +++ b/src/codegen/extensions/events/github_types.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Optional class GitHubRepository: @@ -37,11 +36,11 @@ class GitHubInstallation: events: list[str] created_at: datetime updated_at: datetime - single_file_name: Optional[str] + single_file_name: str | None has_multiple_single_files: bool single_file_paths: list[str] - suspended_by: Optional[str] - suspended_at: Optional[datetime] + suspended_by: str | None + suspended_at: datetime | None class GitHubUser: @@ -58,5 +57,5 @@ class GitHubInstallationEvent: action: str installation: GitHubInstallation repositories: list[GitHubRepository] - requester: Optional[dict] + requester: dict | None sender: GitHubUser diff --git a/src/codegen/extensions/events/linear.py b/src/codegen/extensions/events/linear.py index 4fe5b2e91..813556f76 100644 --- a/src/codegen/extensions/events/linear.py +++ b/src/codegen/extensions/events/linear.py @@ -1,5 +1,6 @@ import logging -from typing import Any, Callable, TypeVar +from collections.abc import Callable +from typing import Any, TypeVar from pydantic import BaseModel diff --git a/src/codegen/extensions/github/types/pull_request.py b/src/codegen/extensions/github/types/pull_request.py index c4b58eed6..739c69776 100644 --- a/src/codegen/extensions/github/types/pull_request.py +++ b/src/codegen/extensions/github/types/pull_request.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel @@ -47,30 +47,30 @@ class PullRequest(BaseModel): locked: bool title: str user: GitHubUser - body: Optional[str] + body: str | None created_at: str updated_at: str - closed_at: Optional[str] - merged_at: Optional[str] - merge_commit_sha: Optional[str] - assignee: Optional[GitHubUser] + closed_at: str | None + merged_at: str | None + merge_commit_sha: str | None + assignee: GitHubUser | None assignees: list[GitHubUser] requested_reviewers: list[GitHubUser] requested_teams: list[dict] labels: list[Label] - milestone: Optional[dict] + milestone: dict | None draft: bool head: PullRequestRef base: PullRequestRef 
_links: PullRequestLinks author_association: str - auto_merge: Optional[dict] - active_lock_reason: Optional[str] + auto_merge: dict | None + active_lock_reason: str | None merged: bool - mergeable: Optional[bool] - rebaseable: Optional[bool] + mergeable: bool | None + rebaseable: bool | None mergeable_state: str - merged_by: Optional[GitHubUser] + merged_by: GitHubUser | None comments: int review_comments: int maintainer_can_modify: bool diff --git a/src/codegen/extensions/github/types/push.py b/src/codegen/extensions/github/types/push.py index 10f44f5e7..2cf615b6a 100644 --- a/src/codegen/extensions/github/types/push.py +++ b/src/codegen/extensions/github/types/push.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel from .base import GitHubRepository, GitHubUser @@ -23,7 +21,7 @@ class PushEvent(BaseModel): created: bool deleted: bool forced: bool - base_ref: Optional[str] + base_ref: str | None compare: str commits: list[GitHubCommit] head_commit: GitHubCommit diff --git a/src/codegen/extensions/graph/create_graph.py b/src/codegen/extensions/graph/create_graph.py index 442b2dcd6..9718f01eb 100644 --- a/src/codegen/extensions/graph/create_graph.py +++ b/src/codegen/extensions/graph/create_graph.py @@ -1,5 +1,3 @@ -from typing import Optional - from codegen.extensions.graph.utils import Node, NodeLabel, Relation, RelationLabel, SimpleGraph from codegen.sdk.code_generation.doc_utils.utils import safe_get_class from codegen.sdk.core.class_definition import Class @@ -16,7 +14,7 @@ def create_codebase_graph(codebase): # Track existing nodes by name to prevent duplicates node_registry = {} # name -> node_id mapping - def get_or_create_node(name: str, label: NodeLabel, parent_name: Optional[str] = None, properties: dict | None = None): + def get_or_create_node(name: str, label: NodeLabel, parent_name: str | None = None, properties: dict | None = None): """Get existing node or create new one if it doesn't exist.""" full_name = f"{parent_name}.{name}" if parent_name and parent_name != "Class" else name if full_name in node_registry: diff --git a/src/codegen/extensions/graph/neo4j_exporter.py b/src/codegen/extensions/graph/neo4j_exporter.py index 72a499636..48241c0d6 100644 --- a/src/codegen/extensions/graph/neo4j_exporter.py +++ b/src/codegen/extensions/graph/neo4j_exporter.py @@ -26,7 +26,7 @@ def export_graph(self, graph: SimpleGraph): with self.driver.session() as session: # Create nodes for node in graph.nodes.values(): - properties = {"name": node.name, "full_name": node.full_name, **{k: str(v) if isinstance(v, (dict, list)) else v for k, v in node.properties.items()}} + properties = {"name": node.name, "full_name": node.full_name, **{k: str(v) if isinstance(v, dict | list) else v for k, v in node.properties.items()}} query = f"CREATE (n:{node.label} {{{', '.join(f'{k}: ${k}' for k in properties.keys())}}})" session.run(query, properties) @@ -36,7 +36,7 @@ def export_graph(self, graph: SimpleGraph): source_node = graph.nodes[relation.source_id] target_node = graph.nodes[relation.target_id] - properties = {**{k: str(v) if isinstance(v, (dict, list)) else v for k, v in relation.properties.items()}} + properties = {**{k: str(v) if isinstance(v, dict | list) else v for k, v in relation.properties.items()}} query = ( f"MATCH (source:{source_node.label} {{full_name: $source_name}}), " diff --git a/src/codegen/extensions/index/code_index.py b/src/codegen/extensions/index/code_index.py index 4cf8a5de3..bd422b6a8 100644 --- a/src/codegen/extensions/index/code_index.py 
+++ b/src/codegen/extensions/index/code_index.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Optional, TypeVar +from typing import TypeVar import numpy as np @@ -34,9 +34,9 @@ def __init__(self, codebase: Codebase): codebase: The codebase to index """ self.codebase = codebase - self.E: Optional[np.ndarray] = None - self.items: Optional[np.ndarray] = None - self.commit_hash: Optional[str] = None + self.E: np.ndarray | None = None + self.items: np.ndarray | None = None + self.commit_hash: str | None = None @property @abstractmethod @@ -151,7 +151,7 @@ def update(self) -> None: # Update commit hash self.commit_hash = self._get_current_commit() - def save(self, save_path: Optional[str] = None) -> None: + def save(self, save_path: str | None = None) -> None: """Save the index to disk.""" if self.E is None or self.items is None: msg = "No embeddings to save. Call create() first." @@ -162,7 +162,7 @@ def save(self, save_path: Optional[str] = None) -> None: self._save_index(save_path) - def load(self, load_path: Optional[str] = None) -> None: + def load(self, load_path: str | None = None) -> None: """Load the index from disk.""" load_path = Path(load_path) if load_path else self._get_default_save_path() diff --git a/src/codegen/extensions/index/file_index.py b/src/codegen/extensions/index/file_index.py index a76e62d5e..dee164999 100644 --- a/src/codegen/extensions/index/file_index.py +++ b/src/codegen/extensions/index/file_index.py @@ -2,7 +2,6 @@ import pickle from pathlib import Path -from typing import Optional import modal import numpy as np @@ -87,7 +86,7 @@ def delete_modal_dict(self) -> bool: logger.exception(f"Failed to delete Modal Dict: {e}") return False - def modal_dict_exists(self, commit_hash: Optional[str] = None) -> bool: + def modal_dict_exists(self, commit_hash: str | None = None) -> bool: """Check if a Modal Dict exists for a specific commit. Args: diff --git a/src/codegen/extensions/langchain/graph.py b/src/codegen/extensions/langchain/graph.py index 436cc7d30..bb8bc9c74 100644 --- a/src/codegen/extensions/langchain/graph.py +++ b/src/codegen/extensions/langchain/graph.py @@ -1,7 +1,7 @@ """Demo implementation of an agent with Codegen tools.""" import uuid -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal import anthropic import openai @@ -27,7 +27,7 @@ from codegen.extensions.langchain.utils.utils import get_max_model_input_tokens -def manage_messages(existing: list[AnyMessage], updates: Union[list[AnyMessage], dict]) -> list[AnyMessage]: +def manage_messages(existing: list[AnyMessage], updates: list[AnyMessage] | dict) -> list[AnyMessage]: """Custom reducer for managing message history with summarization. 
Args: @@ -229,7 +229,7 @@ def should_continue(self, state: GraphState) -> Literal["tools", "summarize_conv return END # =================================== COMPILE GRAPH ==================================== - def create(self, checkpointer: Optional[MemorySaver] = None, debug: bool = False) -> CompiledGraph: + def create(self, checkpointer: MemorySaver | None = None, debug: bool = False) -> CompiledGraph: """Create and compile the graph.""" builder = StateGraph(GraphState) @@ -507,9 +507,9 @@ def create_react_agent( model: "LLM", tools: list[BaseTool], system_message: SystemMessage, - checkpointer: Optional[MemorySaver] = None, + checkpointer: MemorySaver | None = None, debug: bool = False, - config: Optional[dict[str, Any]] = None, + config: dict[str, Any] | None = None, ) -> CompiledGraph: """Create a reactive agent graph.""" graph = AgentGraph(model, tools, system_message, config=config) diff --git a/src/codegen/extensions/langchain/llm.py b/src/codegen/extensions/langchain/llm.py index dadcf6314..716fa9ead 100644 --- a/src/codegen/extensions/langchain/llm.py +++ b/src/codegen/extensions/langchain/llm.py @@ -2,7 +2,7 @@ import os from collections.abc import Sequence -from typing import Any, Optional +from typing import Any from langchain_anthropic import ChatAnthropic from langchain_core.callbacks import CallbackManagerForLLMRun @@ -26,11 +26,11 @@ class LLM(BaseChatModel): temperature: float = Field(default=0, description="Temperature parameter for the model.", ge=0, le=1) - top_p: Optional[float] = Field(default=None, description="Top-p sampling parameter.", ge=0, le=1) + top_p: float | None = Field(default=None, description="Top-p sampling parameter.", ge=0, le=1) - top_k: Optional[int] = Field(default=None, description="Top-k sampling parameter.", ge=1) + top_k: int | None = Field(default=None, description="Top-k sampling parameter.", ge=1) - max_tokens: Optional[int] = Field(default=None, description="Maximum number of tokens to generate.", ge=1) + max_tokens: int | None = Field(default=None, description="Maximum number of tokens to generate.", ge=1) def __init__(self, model_provider: str = "anthropic", model_name: str = "claude-3-5-sonnet-latest", **kwargs: Any) -> None: """Initialize the LLM. @@ -110,8 +110,8 @@ def _get_model(self) -> BaseChatModel: def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: """Generate chat completion using the underlying model. 
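
The hunks above and below repeatedly swap `typing.Optional[X]` for the PEP 604 union spelling `X | None`, and one hunk passes a union directly to `isinstance`. As a point of reference only, here is a minimal sketch of that syntax, assuming Python 3.10 or newer; the function names are hypothetical and not part of this patch.

```python
# Illustrative sketch of the PEP 604 pattern applied throughout this patch.
# Requires Python 3.10+ at runtime; names below are hypothetical examples.

def greet(name: str | None = None) -> str:
    # Equivalent to Optional[str]; the None default is still written explicitly.
    return f"Hello, {name or 'world'}"

def stringify(value: object) -> str:
    # isinstance() accepts X | Y unions on 3.10+, mirroring the
    # isinstance(v, dict | list) change in the neo4j exporter hunk above.
    return str(value) if isinstance(value, dict | list) else repr(value)
```
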
diff --git a/src/codegen/extensions/langchain/tools.py b/src/codegen/extensions/langchain/tools.py index 0749384a4..9d156ee5f 100644 --- a/src/codegen/extensions/langchain/tools.py +++ b/src/codegen/extensions/langchain/tools.py @@ -1,7 +1,7 @@ """Langchain tools for workspace operations.""" from collections.abc import Callable -from typing import Annotated, ClassVar, Literal, Optional +from typing import Annotated, ClassVar, Literal from langchain_core.messages import ToolMessage from langchain_core.stores import InMemoryBaseStore @@ -56,10 +56,10 @@ class ViewFileInput(BaseModel): """Input for viewing a file.""" filepath: str = Field(..., description="Path to the file relative to workspace root") - start_line: Optional[int] = Field(None, description="Starting line number to view (1-indexed, inclusive)") - end_line: Optional[int] = Field(None, description="Ending line number to view (1-indexed, inclusive)") - max_lines: Optional[int] = Field(None, description="Maximum number of lines to view at once, defaults to 500") - line_numbers: Optional[bool] = Field(True, description="If True, add line numbers to the content (1-indexed)") + start_line: int | None = Field(None, description="Starting line number to view (1-indexed, inclusive)") + end_line: int | None = Field(None, description="Ending line number to view (1-indexed, inclusive)") + max_lines: int | None = Field(None, description="Maximum number of lines to view at once, defaults to 500") + line_numbers: bool | None = Field(True, description="If True, add line numbers to the content (1-indexed)") tool_call_id: Annotated[str, InjectedToolCallId] @@ -80,10 +80,10 @@ def _run( self, tool_call_id: str, filepath: str, - start_line: Optional[int] = None, - end_line: Optional[int] = None, - max_lines: Optional[int] = None, - line_numbers: Optional[bool] = True, + start_line: int | None = None, + end_line: int | None = None, + max_lines: int | None = None, + line_numbers: bool | None = True, ) -> ToolMessage: result = view_file( self.codebase, @@ -146,7 +146,7 @@ class RipGrepTool(BaseTool): def __init__(self, codebase: Codebase) -> None: super().__init__(codebase=codebase) - def _run(self, tool_call_id: str, query: str, file_extensions: Optional[list[str]] = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> ToolMessage: + def _run(self, tool_call_id: str, query: str, file_extensions: list[str] | None = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> ToolMessage: result = search(self.codebase, query, file_extensions=file_extensions, page=page, files_per_page=files_per_page, use_regex=use_regex) return result.render(tool_call_id) diff --git a/src/codegen/extensions/langchain/utils/custom_tool_node.py b/src/codegen/extensions/langchain/utils/custom_tool_node.py index bdbe4ab0e..b3cac7211 100644 --- a/src/codegen/extensions/langchain/utils/custom_tool_node.py +++ b/src/codegen/extensions/langchain/utils/custom_tool_node.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from langchain_core.messages import ( AIMessage, @@ -15,12 +15,8 @@ class CustomToolNode(ToolNode): def _parse_input( self, - input: Union[ - list[AnyMessage], - dict[str, Any], - BaseModel, - ], - store: Optional[InMemoryBaseStore], + input: list[AnyMessage] | dict[str, Any] | BaseModel, + store: InMemoryBaseStore | None, ) -> tuple[list[ToolCall], Literal["list", "dict", "tool_calls"]]: """Parse the input and check for truncated tool calls.""" messages = input.get("messages", 
[]) diff --git a/src/codegen/extensions/langchain/utils/get_langsmith_url.py b/src/codegen/extensions/langchain/utils/get_langsmith_url.py index fb4fab0e7..645a8021d 100644 --- a/src/codegen/extensions/langchain/utils/get_langsmith_url.py +++ b/src/codegen/extensions/langchain/utils/get_langsmith_url.py @@ -1,10 +1,9 @@ import datetime -from typing import Optional from langsmith import Client -def get_langsmith_url(client: Client, run_id: str, project_name: Optional[str] = None) -> str: +def get_langsmith_url(client: Client, run_id: str, project_name: str | None = None) -> str: """Get the URL for a run in LangSmith. Args: @@ -35,7 +34,7 @@ def get_langsmith_url(client: Client, run_id: str, project_name: Optional[str] = return f"{host_url}/o/{tenant_id}/r/{run_id}?poll=true" -def find_and_print_langsmith_run_url(client: Client, project_name: Optional[str] = None) -> Optional[str]: +def find_and_print_langsmith_run_url(client: Client, project_name: str | None = None) -> str | None: """Find the most recent LangSmith run and print its URL. Args: diff --git a/src/codegen/extensions/linear/linear_client.py b/src/codegen/extensions/linear/linear_client.py index 0c3803153..4cca2c5e1 100644 --- a/src/codegen/extensions/linear/linear_client.py +++ b/src/codegen/extensions/linear/linear_client.py @@ -1,5 +1,4 @@ import os -from typing import Optional import requests from requests.adapters import HTTPAdapter @@ -15,7 +14,7 @@ class LinearClient: api_headers: dict api_endpoint = "https://api.linear.app/graphql" - def __init__(self, access_token: Optional[str] = None, team_id: Optional[str] = None, max_retries: int = 3, backoff_factor: float = 0.5): + def __init__(self, access_token: str | None = None, team_id: str | None = None, max_retries: int = 3, backoff_factor: float = 0.5): if not access_token: access_token = os.getenv("LINEAR_ACCESS_TOKEN") if not access_token: diff --git a/src/codegen/extensions/lsp/definition.py b/src/codegen/extensions/lsp/definition.py index acecc7256..ef8deff5c 100644 --- a/src/codegen/extensions/lsp/definition.py +++ b/src/codegen/extensions/lsp/definition.py @@ -28,7 +28,7 @@ def go_to_definition(node: Editable | None, uri: str, position: Position) -> Edi if resolved is None: logger.warning(f"No resolved value found for {node.name} at {uri}:{position}") return None - if isinstance(resolved, (HasName,)): + if isinstance(resolved, HasName): resolved = resolved.get_name() if isinstance(resolved.parent, Assignment) and resolved.parent.value == resolved: resolved = resolved.parent.get_name() diff --git a/src/codegen/extensions/lsp/execute.py b/src/codegen/extensions/lsp/execute.py index 5e34121d1..e72fafddc 100644 --- a/src/codegen/extensions/lsp/execute.py +++ b/src/codegen/extensions/lsp/execute.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING, Any, Callable +from collections.abc import Callable +from typing import TYPE_CHECKING, Any from lsprotocol import types from lsprotocol.types import Position, Range diff --git a/src/codegen/extensions/lsp/server.py b/src/codegen/extensions/lsp/server.py index 4d24cc7f2..7ba5d1f74 100644 --- a/src/codegen/extensions/lsp/server.py +++ b/src/codegen/extensions/lsp/server.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from lsprotocol import types from lsprotocol.types import Position, Range @@ -21,9 +21,9 @@ class CodegenLanguageServer(LanguageServer): - codebase: Optional[Codebase] - io: Optional[LSPIO] - progress_manager: Optional[LSPProgress] + codebase: Codebase | None + io: LSPIO | None + 
progress_manager: LSPProgress | None actions: dict[str, CodeAction] def __init__(self, *args: Any, **kwargs: Any) -> None: diff --git a/src/codegen/extensions/mcp/codebase_tools.py b/src/codegen/extensions/mcp/codebase_tools.py index 52a25b1d6..bb3423e73 100644 --- a/src/codegen/extensions/mcp/codebase_tools.py +++ b/src/codegen/extensions/mcp/codebase_tools.py @@ -1,5 +1,5 @@ import json -from typing import Annotated, Optional +from typing import Annotated from mcp.server.fastmcp import FastMCP @@ -18,12 +18,12 @@ @mcp.tool(name="reveal_symbol", description="Reveal the dependencies and usages of a symbol up to N degrees") def reveal_symbol_tool( symbol_name: Annotated[str, "Name of the symbol to inspect"], - target_file: Annotated[Optional[str], "The file path of the file containing the symbol to inspect"], + target_file: Annotated[str | None, "The file path of the file containing the symbol to inspect"], codebase_dir: Annotated[str, "The root directory of your codebase"], codebase_language: Annotated[ProgrammingLanguage, "The language the codebase is written in"], - max_depth: Annotated[Optional[int], "depth up to which symbol information is retrieved"], - collect_dependencies: Annotated[Optional[bool], "includes dependencies of symbol"], - collect_usages: Annotated[Optional[bool], "includes usages of symbol"], + max_depth: Annotated[int | None, "depth up to which symbol information is retrieved"], + collect_dependencies: Annotated[bool | None, "includes dependencies of symbol"], + collect_usages: Annotated[bool | None, "includes usages of symbol"], ): codebase = Codebase(repo_path=codebase_dir, language=codebase_language) result = reveal_symbol( @@ -42,8 +42,8 @@ def search_codebase_tool( query: Annotated[str, "The search query to find in the codebase. When ripgrep is available, this will be passed as a ripgrep pattern. For regex searches, set use_regex=True."], codebase_dir: Annotated[str, "The root directory of your codebase"], codebase_language: Annotated[ProgrammingLanguage, "The language the codebase is written in"], - target_directories: Annotated[Optional[list[str]], "list of directories to search within"] = None, - file_extensions: Annotated[Optional[list[str]], "list of file extensions to search (e.g. ['.py', '.ts'])"] = None, + target_directories: Annotated[list[str] | None, "list of directories to search within"] = None, + file_extensions: Annotated[list[str] | None, "list of file extensions to search (e.g. 
['.py', '.ts'])"] = None, page: Annotated[int, "page number to return (1-based)"] = 1, files_per_page: Annotated[int, "number of files to return per page"] = 10, use_regex: Annotated[bool, "use regex for the search query"] = False, diff --git a/src/codegen/extensions/swebench/utils.py b/src/codegen/extensions/swebench/utils.py index c5054b2d0..f29fdbcc8 100644 --- a/src/codegen/extensions/swebench/utils.py +++ b/src/codegen/extensions/swebench/utils.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from pathlib import Path from pprint import pprint -from typing import Literal, Optional +from typing import Literal from datasets import load_dataset @@ -21,13 +21,13 @@ class SweBenchExample: patch: str test_patch: str problem_statement: str - hints_text: Optional[str] + hints_text: str | None created_at: str version: str fail_to_pass: str - pass_to_pass: Optional[str] - environment_setup_commit: Optional[str] - difficulty: Optional[int] + pass_to_pass: str | None + environment_setup_commit: str | None + difficulty: int | None def load_predictions(paths): diff --git a/src/codegen/extensions/tools/bash.py b/src/codegen/extensions/tools/bash.py index dd9da037d..e99953186 100644 --- a/src/codegen/extensions/tools/bash.py +++ b/src/codegen/extensions/tools/bash.py @@ -3,7 +3,7 @@ import re import shlex import subprocess -from typing import ClassVar, Optional +from typing import ClassVar from pydantic import Field @@ -29,18 +29,18 @@ class RunBashCommandObservation(Observation): """Response from running a bash command.""" - stdout: Optional[str] = Field( + stdout: str | None = Field( default=None, description="Standard output from the command", ) - stderr: Optional[str] = Field( + stderr: str | None = Field( default=None, description="Standard error from the command", ) command: str = Field( description="The command that was executed", ) - pid: Optional[int] = Field( + pid: int | None = Field( default=None, description="Process ID for background commands", ) diff --git a/src/codegen/extensions/tools/create_file.py b/src/codegen/extensions/tools/create_file.py index cc22d3ede..fbb64d374 100644 --- a/src/codegen/extensions/tools/create_file.py +++ b/src/codegen/extensions/tools/create_file.py @@ -1,6 +1,6 @@ """Tool for creating new files.""" -from typing import ClassVar, Optional +from typing import ClassVar from pydantic import Field @@ -23,7 +23,7 @@ class CreateFileObservation(Observation): str_template: ClassVar[str] = "Created file {filepath}" -def create_file(codebase: Codebase, filepath: str, content: str, max_tokens: Optional[int] = None) -> CreateFileObservation: +def create_file(codebase: Codebase, filepath: str, content: str, max_tokens: int | None = None) -> CreateFileObservation: """Create a new file. 
Args: diff --git a/src/codegen/extensions/tools/edit_file.py b/src/codegen/extensions/tools/edit_file.py index b4831b968..c58fb322d 100644 --- a/src/codegen/extensions/tools/edit_file.py +++ b/src/codegen/extensions/tools/edit_file.py @@ -1,6 +1,6 @@ """Tool for editing file contents.""" -from typing import TYPE_CHECKING, ClassVar, Optional +from typing import TYPE_CHECKING, ClassVar from langchain_core.messages import ToolMessage from pydantic import Field @@ -20,7 +20,7 @@ class EditFileObservation(Observation): filepath: str = Field( description="Path to the edited file", ) - diff: Optional[str] = Field( + diff: str | None = Field( default=None, description="Unified diff showing the changes made", ) diff --git a/src/codegen/extensions/tools/github/create_pr_review_comment.py b/src/codegen/extensions/tools/github/create_pr_review_comment.py index 30324bed3..03eda6b84 100644 --- a/src/codegen/extensions/tools/github/create_pr_review_comment.py +++ b/src/codegen/extensions/tools/github/create_pr_review_comment.py @@ -1,6 +1,6 @@ """Tool for creating PR review comments.""" -from typing import ClassVar, Optional +from typing import ClassVar from pydantic import Field @@ -37,7 +37,7 @@ def create_pr_review_comment( commit_sha: str, path: str, line: int, - start_line: Optional[int] = None, + start_line: int | None = None, ) -> PRReviewCommentObservation: """Create an inline review comment on a specific line in a pull request. diff --git a/src/codegen/extensions/tools/link_annotation.py b/src/codegen/extensions/tools/link_annotation.py index 543f0fc3a..119c13e73 100644 --- a/src/codegen/extensions/tools/link_annotation.py +++ b/src/codegen/extensions/tools/link_annotation.py @@ -1,8 +1,8 @@ """Tool for viewing PR contents and modified symbols.""" import re +from collections.abc import Callable from enum import StrEnum -from typing import Callable from codegen.sdk.core.codebase import Codebase diff --git a/src/codegen/extensions/tools/observation.py b/src/codegen/extensions/tools/observation.py index 6cde37317..ad6201d06 100644 --- a/src/codegen/extensions/tools/observation.py +++ b/src/codegen/extensions/tools/observation.py @@ -1,7 +1,7 @@ """Base class for tool observations/responses.""" import json -from typing import Any, ClassVar, Optional +from typing import Any, ClassVar from langchain_core.messages import ToolMessage from pydantic import BaseModel, Field @@ -22,7 +22,7 @@ class Observation(BaseModel): default="success", description="Status of the operation - 'success' or 'error'", ) - error: Optional[str] = Field( + error: str | None = Field( default=None, description="Error message if status is 'error'", ) @@ -61,7 +61,7 @@ def render_as_string(self, max_tokens: int = 8000) -> str: return rendered[:max_tokens] + "\n\n...truncated...\n\n" return rendered - def render(self, tool_call_id: Optional[str] = None) -> ToolMessage | str: + def render(self, tool_call_id: str | None = None) -> ToolMessage | str: """Render the observation as a ToolMessage or string. 
Args: diff --git a/src/codegen/extensions/tools/reflection.py b/src/codegen/extensions/tools/reflection.py index 6e5aad3d6..49c159b12 100644 --- a/src/codegen/extensions/tools/reflection.py +++ b/src/codegen/extensions/tools/reflection.py @@ -1,6 +1,6 @@ """Tool for agent self-reflection and planning.""" -from typing import ClassVar, Optional +from typing import ClassVar from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.output_parsers import StrOutputParser @@ -27,8 +27,8 @@ class ReflectionObservation(Observation): context_summary: str = Field(description="Summary of the current context") findings: str = Field(description="Key information and insights gathered") - challenges: Optional[str] = Field(None, description="Current obstacles or questions") - focus: Optional[str] = Field(None, description="Specific aspect focused on") + challenges: str | None = Field(None, description="Current obstacles or questions") + focus: str | None = Field(None, description="Specific aspect focused on") sections: list[ReflectionSection] = Field(description="Structured reflection sections") str_template: ClassVar[str] = "Reflection on: {focus}" @@ -128,8 +128,8 @@ def perform_reflection( context_summary: str, findings_so_far: str, current_challenges: str = "", - reflection_focus: Optional[str] = None, - codebase: Optional[Codebase] = None, + reflection_focus: str | None = None, + codebase: Codebase | None = None, ) -> ReflectionObservation: """Perform agent reflection to organize thoughts and plan next steps. diff --git a/src/codegen/extensions/tools/replacement_edit.py b/src/codegen/extensions/tools/replacement_edit.py index aa5cd98be..74fe335a2 100644 --- a/src/codegen/extensions/tools/replacement_edit.py +++ b/src/codegen/extensions/tools/replacement_edit.py @@ -2,7 +2,7 @@ import difflib import re -from typing import ClassVar, Optional +from typing import ClassVar from pydantic import Field @@ -18,23 +18,23 @@ class ReplacementEditObservation(Observation): filepath: str = Field( description="Path to the edited file", ) - diff: Optional[str] = Field( + diff: str | None = Field( default=None, description="Unified diff showing the changes made", ) - new_content: Optional[str] = Field( + new_content: str | None = Field( default=None, description="New content with line numbers", ) - message: Optional[str] = Field( + message: str | None = Field( default=None, description="Message describing the result", ) - error: Optional[str] = Field( + error: str | None = Field( default=None, description="Error message if an error occurred", ) - error_pattern: Optional[str] = Field( + error_pattern: str | None = Field( default=None, description="Regex pattern that failed to compile", ) @@ -101,7 +101,7 @@ def replacement_edit( replacement: str, start: int = 1, end: int = -1, - count: Optional[int] = None, + count: int | None = None, flags: re.RegexFlag = re.MULTILINE, ) -> ReplacementEditObservation: """Replace text in a file using regex pattern matching. 
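
The `replacement_edit` signature above takes an optional `count` and a `re.RegexFlag`. As a rough sketch of how such parameters typically map onto the standard library (this is not the tool's actual implementation, and the helper name is hypothetical):

```python
import re

def replace_in_text(text: str, pattern: str, replacement: str,
                    count: int | None = None) -> str:
    # re.sub treats count=0 as "replace all occurrences",
    # so an unset count maps to 0 here.
    return re.sub(pattern, replacement, text,
                  count=0 if count is None else count,
                  flags=re.MULTILINE)

print(replace_in_text("foo bar foo", r"foo", "baz"))     # "baz bar baz"
print(replace_in_text("foo bar foo", r"foo", "baz", 1))  # "baz bar foo"
```
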
diff --git a/src/codegen/extensions/tools/reveal_symbol.py b/src/codegen/extensions/tools/reveal_symbol.py index c91b0a111..3ddd4e045 100644 --- a/src/codegen/extensions/tools/reveal_symbol.py +++ b/src/codegen/extensions/tools/reveal_symbol.py @@ -1,6 +1,6 @@ """Tool for revealing symbol dependencies and usages.""" -from typing import Any, ClassVar, Optional +from typing import Any, ClassVar import tiktoken from pydantic import Field @@ -18,7 +18,7 @@ class SymbolInfo(Observation): """Information about a symbol.""" name: str = Field(description="Name of the symbol") - filepath: Optional[str] = Field(description="Path to the file containing the symbol") + filepath: str | None = Field(description="Path to the file containing the symbol") source: str = Field(description="Source code of the symbol") str_template: ClassVar[str] = "{name} in {filepath}" @@ -27,11 +27,11 @@ class SymbolInfo(Observation): class RevealSymbolObservation(Observation): """Response from revealing symbol dependencies and usages.""" - dependencies: Optional[list[SymbolInfo]] = Field( + dependencies: list[SymbolInfo] | None = Field( default=None, description="List of symbols this symbol depends on", ) - usages: Optional[list[SymbolInfo]] = Field( + usages: list[SymbolInfo] | None = Field( default=None, description="List of symbols that use this symbol", ) @@ -39,7 +39,7 @@ class RevealSymbolObservation(Observation): default=False, description="Whether results were truncated due to token limit", ) - valid_filepaths: Optional[list[str]] = Field( + valid_filepaths: list[str] | None = Field( default=None, description="List of valid filepaths when symbol is ambiguous", ) @@ -115,7 +115,7 @@ def truncate_source(source: str, max_tokens: int) -> str: return "".join(result) -def get_symbol_info(symbol: Symbol, max_tokens: Optional[int] = None) -> SymbolInfo: +def get_symbol_info(symbol: Symbol, max_tokens: int | None = None) -> SymbolInfo: """Get relevant information about a symbol. Args: @@ -137,7 +137,7 @@ def get_symbol_info(symbol: Symbol, max_tokens: Optional[int] = None) -> SymbolI ) -def hop_through_imports(symbol: Symbol, seen_imports: Optional[set[str]] = None) -> Symbol: +def hop_through_imports(symbol: Symbol, seen_imports: set[str] | None = None) -> Symbol: """Follow import chain to find the root symbol, stopping at ExternalModule.""" if seen_imports is None: seen_imports = set() @@ -162,8 +162,8 @@ def hop_through_imports(symbol: Symbol, seen_imports: Optional[set[str]] = None) def get_extended_context( symbol: Symbol, degree: int, - max_tokens: Optional[int] = None, - seen_symbols: Optional[set[Symbol]] = None, + max_tokens: int | None = None, + seen_symbols: set[Symbol] | None = None, current_degree: int = 0, total_tokens: int = 0, collect_dependencies: bool = True, @@ -255,11 +255,11 @@ def under_token_limit() -> bool: def reveal_symbol( codebase: Codebase, symbol_name: str, - filepath: Optional[str] = None, - max_depth: Optional[int] = 1, - max_tokens: Optional[int] = None, - collect_dependencies: Optional[bool] = True, - collect_usages: Optional[bool] = True, + filepath: str | None = None, + max_depth: int | None = 1, + max_tokens: int | None = None, + collect_dependencies: bool | None = True, + collect_usages: bool | None = True, ) -> RevealSymbolObservation: """Reveal the dependencies and usages of a symbol up to N degrees. 
diff --git a/src/codegen/extensions/tools/search_files_by_name.py b/src/codegen/extensions/tools/search_files_by_name.py index b44f6da85..d28df5ba9 100644 --- a/src/codegen/extensions/tools/search_files_by_name.py +++ b/src/codegen/extensions/tools/search_files_by_name.py @@ -1,7 +1,7 @@ import math import shutil import subprocess -from typing import ClassVar, Optional +from typing import ClassVar from pydantic import Field @@ -88,13 +88,12 @@ def search_files_by_name( if files_per_page == math.inf: files_per_page = total_files total_pages = 1 - else: + else: total_pages = (total_files + files_per_page - 1) // files_per_page if total_files > 0 else 1 - - + # Ensure page is within valid range page = min(page, total_pages) - + # Get paginated results start_idx = (page - 1) * files_per_page end_idx = start_idx + files_per_page diff --git a/src/codegen/extensions/tools/semantic_edit.py b/src/codegen/extensions/tools/semantic_edit.py index 97ba927c5..fb3fff339 100644 --- a/src/codegen/extensions/tools/semantic_edit.py +++ b/src/codegen/extensions/tools/semantic_edit.py @@ -2,7 +2,7 @@ import difflib import re -from typing import TYPE_CHECKING, ClassVar, Optional +from typing import TYPE_CHECKING, ClassVar from langchain_core.messages import ToolMessage from langchain_core.output_parsers import StrOutputParser @@ -26,15 +26,15 @@ class SemanticEditObservation(Observation): filepath: str = Field( description="Path to the edited file", ) - diff: Optional[str] = Field( + diff: str | None = Field( default=None, description="Unified diff of changes made to the file", ) - new_content: Optional[str] = Field( + new_content: str | None = Field( default=None, description="New content of the file with line numbers after edits", ) - line_count: Optional[int] = Field( + line_count: int | None = Field( default=None, description="Total number of lines in the edited file", ) diff --git a/src/codegen/extensions/tools/semantic_search.py b/src/codegen/extensions/tools/semantic_search.py index 93cf05212..d2d3b3b00 100644 --- a/src/codegen/extensions/tools/semantic_search.py +++ b/src/codegen/extensions/tools/semantic_search.py @@ -1,6 +1,6 @@ """Semantic search over codebase files.""" -from typing import ClassVar, Optional +from typing import ClassVar from pydantic import Field @@ -51,7 +51,7 @@ def semantic_search( query: str, k: int = 5, preview_length: int = 200, - index_path: Optional[str] = None, + index_path: str | None = None, ) -> SemanticSearchObservation: """Search the codebase using semantic similarity. 
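
The tool observation models touched in these hunks declare optional Pydantic fields. A minimal sketch of the modernized declaration style, assuming pydantic v2 and Python 3.10+; the model and its fields are hypothetical illustrations rather than code from this repository:

```python
# Hypothetical observation-style model showing the str | None Field pattern
# used by the hunks above. Assumes pydantic>=2 and Python 3.10+.
from pydantic import BaseModel, Field

class ExampleObservation(BaseModel):
    filepath: str = Field(description="Path to the file that was touched")
    diff: str | None = Field(default=None, description="Unified diff, when available")
    error: str | None = Field(default=None, description="Error message on failure")

obs = ExampleObservation(filepath="src/app.py")
assert obs.diff is None  # optional fields default to None
```
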
diff --git a/src/codegen/extensions/tools/tool_output_types.py b/src/codegen/extensions/tools/tool_output_types.py index 1678e0c7e..4da6a5a43 100644 --- a/src/codegen/extensions/tools/tool_output_types.py +++ b/src/codegen/extensions/tools/tool_output_types.py @@ -1,6 +1,6 @@ """Type definitions for tool outputs.""" -from typing import Optional, TypedDict +from typing import TypedDict class EditFileArtifacts(TypedDict, total=False): @@ -10,8 +10,8 @@ class EditFileArtifacts(TypedDict, total=False): """ filepath: str # Path to the edited file - diff: Optional[str] # Diff of changes made to the file - error: Optional[str] # Error message (only present on error) + diff: str | None # Diff of changes made to the file + error: str | None # Error message (only present on error) class ViewFileArtifacts(TypedDict, total=False): @@ -22,14 +22,14 @@ class ViewFileArtifacts(TypedDict, total=False): """ filepath: str # Path to the viewed file - start_line: Optional[int] # Starting line number viewed - end_line: Optional[int] # Ending line number viewed - content: Optional[str] # Content of the file - total_lines: Optional[int] # Total number of lines in file - has_more: Optional[bool] # Whether there are more lines to view - max_lines_per_page: Optional[int] # Maximum lines that can be viewed at once - file_size: Optional[int] # Size of file in bytes - error: Optional[str] # Error message (only present on error) + start_line: int | None # Starting line number viewed + end_line: int | None # Ending line number viewed + content: str | None # Content of the file + total_lines: int | None # Total number of lines in file + has_more: bool | None # Whether there are more lines to view + max_lines_per_page: int | None # Maximum lines that can be viewed at once + file_size: int | None # Size of file in bytes + error: str | None # Error message (only present on error) class ListDirectoryArtifacts(TypedDict, total=False): @@ -41,14 +41,14 @@ class ListDirectoryArtifacts(TypedDict, total=False): dirpath: str # Full path to the directory name: str # Name of the directory - files: Optional[list[str]] # List of files in this directory - file_paths: Optional[list[str]] # Full paths to files in this directory - subdirs: Optional[list[str]] # List of subdirectory names - subdir_paths: Optional[list[str]] # Full paths to subdirectories - is_leaf: Optional[bool] # Whether this is a leaf node (at max depth) - depth: Optional[int] # Current depth in the tree - max_depth: Optional[int] # Maximum depth allowed - error: Optional[str] # Error message (only present on error) + files: list[str] | None # List of files in this directory + file_paths: list[str] | None # Full paths to files in this directory + subdirs: list[str] | None # List of subdirectory names + subdir_paths: list[str] | None # Full paths to subdirectories + is_leaf: bool | None # Whether this is a leaf node (at max depth) + depth: int | None # Current depth in the tree + max_depth: int | None # Maximum depth allowed + error: str | None # Error message (only present on error) class SearchMatch(TypedDict, total=False): @@ -74,7 +74,7 @@ class SearchArtifacts(TypedDict, total=False): files_per_page: int # Number of files shown per page matches: list[SearchMatch] # List of matches with file paths and line numbers file_paths: list[str] # List of files containing matches - error: Optional[str] # Error message (only present on error) + error: str | None # Error message (only present on error) class SemanticEditArtifacts(TypedDict, total=False): @@ -85,10 +85,10 @@ 
class SemanticEditArtifacts(TypedDict, total=False): """ filepath: str # Path to the edited file - diff: Optional[str] # Unified diff of changes made to the file - new_content: Optional[str] # New content of the file after edits - line_count: Optional[int] # Total number of lines in the edited file - error: Optional[str] # Error message (only present on error) + diff: str | None # Unified diff of changes made to the file + new_content: str | None # New content of the file after edits + line_count: int | None # Total number of lines in the edited file + error: str | None # Error message (only present on error) class RelaceEditArtifacts(TypedDict, total=False): @@ -99,7 +99,7 @@ class RelaceEditArtifacts(TypedDict, total=False): """ filepath: str # Path to the edited file - diff: Optional[str] # Unified diff of changes made to the file - new_content: Optional[str] # New content of the file after edits - line_count: Optional[int] # Total number of lines in the edited file - error: Optional[str] # Error message (only present on error) + diff: str | None # Unified diff of changes made to the file + new_content: str | None # New content of the file after edits + line_count: int | None # Total number of lines in the edited file + error: str | None # Error message (only present on error) diff --git a/src/codegen/extensions/tools/view_file.py b/src/codegen/extensions/tools/view_file.py index fbfcd8b02..7909f10de 100644 --- a/src/codegen/extensions/tools/view_file.py +++ b/src/codegen/extensions/tools/view_file.py @@ -1,6 +1,6 @@ """Tool for viewing file contents and metadata.""" -from typing import TYPE_CHECKING, ClassVar, Optional +from typing import TYPE_CHECKING, ClassVar from langchain_core.messages import ToolMessage from pydantic import Field @@ -25,23 +25,23 @@ class ViewFileObservation(Observation): raw_content: str = Field( description="Raw content of the file", ) - line_count: Optional[int] = Field( + line_count: int | None = Field( default=None, description="Number of lines in the file", ) - start_line: Optional[int] = Field( + start_line: int | None = Field( default=None, description="Starting line number of the content (1-indexed)", ) - end_line: Optional[int] = Field( + end_line: int | None = Field( default=None, description="Ending line number of the content (1-indexed)", ) - has_more: Optional[bool] = Field( + has_more: bool | None = Field( default=None, description="Whether there are more lines after end_line", ) - max_lines_per_page: Optional[int] = Field( + max_lines_per_page: int | None = Field( default=None, description="Maximum number of lines that can be viewed at once", ) @@ -109,8 +109,8 @@ def view_file( codebase: Codebase, filepath: str, line_numbers: bool = True, - start_line: Optional[int] = None, - end_line: Optional[int] = None, + start_line: int | None = None, + end_line: int | None = None, max_lines: int = 500, ) -> ViewFileObservation: """View the contents and metadata of a file. 
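
Several hunks here (middlewares, codeowner, cache_utils, and earlier the GitHub and Linear event modules) import `Callable` from `collections.abc` instead of `typing`. A minimal sketch of that usage, with hypothetical function names, assuming Python 3.9+ where `collections.abc.Callable` is subscriptable:

```python
# Illustrative only: the Callable relocation applied across these hunks.
from collections.abc import Callable

def apply_twice(fn: Callable[[int], int], value: int) -> int:
    # Callable[[int], int]: a callable taking an int and returning an int.
    return fn(fn(value))

print(apply_twice(lambda x: x + 1, 3))  # 5
```
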
diff --git a/src/codegen/runner/sandbox/middlewares.py b/src/codegen/runner/sandbox/middlewares.py index 8edea49b9..9b2752916 100644 --- a/src/codegen/runner/sandbox/middlewares.py +++ b/src/codegen/runner/sandbox/middlewares.py @@ -1,6 +1,7 @@ import traceback +from collections.abc import Callable from http import HTTPStatus -from typing import Callable, TypeVar +from typing import TypeVar from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request diff --git a/src/codegen/sdk/codebase/multigraph.py b/src/codegen/sdk/codebase/multigraph.py index 2a76fec70..912b1db5b 100644 --- a/src/codegen/sdk/codebase/multigraph.py +++ b/src/codegen/sdk/codebase/multigraph.py @@ -2,11 +2,8 @@ from dataclasses import dataclass, field from typing import Generic, TypeVar -from codegen.sdk import TYPE_CHECKING from codegen.sdk.core.detached_symbols.function_call import FunctionCall - -if TYPE_CHECKING: - from codegen.sdk.core.function import Function +from codegen.sdk.core.function import Function TFunction = TypeVar("TFunction", bound=Function) diff --git a/src/codegen/sdk/core/codeowner.py b/src/codegen/sdk/core/codeowner.py index 8db24cc67..90416cbe2 100644 --- a/src/codegen/sdk/core/codeowner.py +++ b/src/codegen/sdk/core/codeowner.py @@ -1,5 +1,5 @@ -from collections.abc import Iterable, Iterator -from typing import Callable, Generic, Literal +from collections.abc import Callable, Iterable, Iterator +from typing import Generic, Literal from codeowners import CodeOwners as CodeOwnersParser diff --git a/src/codegen/sdk/core/utils/cache_utils.py b/src/codegen/sdk/core/utils/cache_utils.py index 60f7c4dbf..723e8ee8d 100644 --- a/src/codegen/sdk/core/utils/cache_utils.py +++ b/src/codegen/sdk/core/utils/cache_utils.py @@ -1,6 +1,6 @@ import functools -from collections.abc import Iterator -from typing import Callable, Generic, ParamSpec, TypeVar +from collections.abc import Callable, Iterator +from typing import Generic, ParamSpec, TypeVar from codegen.sdk.extensions.utils import lru_cache diff --git a/src/codegen/sdk/system-prompt.txt b/src/codegen/sdk/system-prompt.txt index f72a67db1..698d8d762 100644 --- a/src/codegen/sdk/system-prompt.txt +++ b/src/codegen/sdk/system-prompt.txt @@ -657,7 +657,7 @@ Codegen creates a custom Python environment in `.codegen/.venv`. Configure your ```bash .codegen/.venv/bin/python ``` - + Alternatively, create a `.vscode/settings.json`: ```json { @@ -679,7 +679,7 @@ Codegen creates a custom Python environment in `.codegen/.venv`. Configure your .codegen/.venv/bin/python ``` - + @@ -1630,8 +1630,8 @@ iconType: "solid" - Yes - [by design](/introduction/guiding-principles#python-first-composability). - + Yes - [by design](/introduction/guiding-principles#python-first-composability). + Codegen works like any other python package. It works alongside your IDE, version control system, and other development tools. - Currently, the codebase object can only parse source code files of one language at a time. This means that if you want to work with both Python and TypeScript files, you will need to create two separate codebase objects. + Currently, the codebase object can only parse source code files of one language at a time. This means that if you want to work with both Python and TypeScript files, you will need to create two separate codebase objects. 
## Accessing Code @@ -3407,7 +3407,7 @@ for module, imports in module_imports.items(): Always check if imports resolve to external modules before modification to avoid breaking third-party package imports. - + ## Import Statements vs Imports @@ -3609,7 +3609,7 @@ for exp in file.exports: # Get original and current symbols current = exp.exported_symbol original = exp.resolved_symbol - + print(f"Re-exporting {original.name} from {exp.from_file.filepath}") print(f"Through: {' -> '.join(e.file.filepath for e in exp.export_chain)}") ``` @@ -3659,7 +3659,7 @@ for from_file, exports in file_exports.items(): When managing exports, consider the impact on your module's public API. Not all symbols that can be exported should be exported. - + --- title: "Inheritable Behaviors" @@ -4149,9 +4149,9 @@ If `A` depends on `B`, then `B` is used by `A`. This relationship is tracked in flowchart LR B(BaseClass) - - - + + + A(MyClass) B ---| used by |A A ---|depends on |B @@ -4320,7 +4320,7 @@ class A: def method_a(self): pass class B(A): - def method_b(self): + def method_b(self): self.method_a() class C(B): @@ -5210,7 +5210,7 @@ for attr in class_def.attributes: # Each attribute has an assignment property attr_type = attr.assignment.type # -> TypeAnnotation print(f"{attr.name}: {attr_type.source}") # e.g. "x: int" - + # Set attribute type attr.assignment.set_type("int") @@ -5227,7 +5227,7 @@ Union types ([UnionType](/api-reference/core/UnionType)) can be manipulated as c ```python # Get union type -union_type = function.return_type # -> A | B +union_type = function.return_type # -> A | B print(union_type.symbols) # ["A", "B"] # Add/remove options @@ -6078,13 +6078,13 @@ Here's an example of using flags during code analysis: ```python def analyze_codebase(codebase): - for function in codebase.functions: + for function in codebase.functions: # Check documentation if not function.docstring: function.flag( message="Missing docstring", ) - + # Check error handling if function.is_async and not function.has_try_catch: function.flag( @@ -6794,7 +6794,7 @@ Explore our tutorials to learn how to use Codegen for various code transformatio > Update API calls, handle breaking changes, and manage bulk updates across your codebase. - Convert Flask applications to FastAPI, updating routes and dependencies. - Migrate Python 2 code to Python 3, updating syntax and modernizing APIs. @@ -6827,9 +6827,9 @@ Explore our tutorials to learn how to use Codegen for various code transformatio > Restructure files, enforce naming conventions, and improve project layout. - Split large files, extract shared logic, and manage dependencies. @@ -6927,7 +6927,7 @@ The agent has access to powerful code viewing and manipulation tools powered by - `CreateFileTool`: Create new files - `DeleteFileTool`: Delete files - `RenameFileTool`: Rename files -- `EditFileTool`: Edit files +- `EditFileTool`: Edit files @@ -7434,7 +7434,7 @@ Be explicit about the changes, produce a short summary, and point out possible i Focus on facts and technical details, using code snippets where helpful. 
""" result = agent.run(prompt) - + # Clean up the temporary comment comment.delete() ``` @@ -7615,21 +7615,21 @@ def research(repo_name: Optional[str] = None, query: Optional[str] = None): """Start a code research session.""" # Initialize codebase codebase = initialize_codebase(repo_name) - + # Create and run the agent agent = create_research_agent(codebase) - + # Main research loop while True: if not query: query = Prompt.ask("[bold cyan]Research query[/bold cyan]") - + result = agent.invoke( {"input": query}, config={"configurable": {"thread_id": 1}} ) console.print(Markdown(result["messages"][-1].content)) - + query = None # Clear for next iteration ``` @@ -7677,7 +7677,7 @@ class CustomAnalysisTool(BaseTool): """Custom tool for specialized code analysis.""" name = "custom_analysis" description = "Performs specialized code analysis" - + def _run(self, query: str) -> str: # Custom analysis logic return results @@ -7817,7 +7817,7 @@ def calculate_maintainability_index( ## Line Metrics -Line metrics provide insights into the size, complexity, and maintainability of a codebase. These measurements help determine the scale of a project, identify areas that may need refactoring, and track the growth of the codebase over time. +Line metrics provide insights into the size, complexity, and maintainability of a codebase. These measurements help determine the scale of a project, identify areas that may need refactoring, and track the growth of the codebase over time. ### Lines of Code Lines of Code refers to the total number of lines in the source code, including blank lines and comments. This is accomplished with a simple count of all lines in the source file. @@ -8114,7 +8114,7 @@ from codegen import Codebase # Initialize codebase codebase = Codebase("path/to/posthog/") -# Create a directed graph for representing call relationships +# Create a directed graph for representing call relationships G = nx.DiGraph() # Configuration flags @@ -8136,7 +8136,7 @@ We'll create a function that will recursively traverse the call trace of a funct ```python def create_downstream_call_trace(src_func: Function, depth: int = 0): """Creates call graph by recursively traversing function calls - + Args: src_func (Function): Starting function for call graph depth (int): Current recursion depth @@ -8144,7 +8144,7 @@ def create_downstream_call_trace(src_func: Function, depth: int = 0): # Prevent infinite recursion if MAX_DEPTH <= depth: return - + # External modules are not functions if isinstance(src_func, ExternalModule): return @@ -8154,12 +8154,12 @@ def create_downstream_call_trace(src_func: Function, depth: int = 0): # Skip self-recursive calls if call.name == src_func.name: continue - + # Get called function definition func = call.function_definition if not func: continue - + # Apply configured filters if isinstance(func, ExternalModule) and IGNORE_EXTERNAL_MODULE_CALLS: continue @@ -8173,7 +8173,7 @@ def create_downstream_call_trace(src_func: Function, depth: int = 0): func_name = f"{func.parent_class.name}.{func.name}" if func.is_method else func.name # Add node and edge with metadata - G.add_node(func, name=func_name, + G.add_node(func, name=func_name, color=COLOR_PALETTE.get(func.__class__.__name__)) G.add_edge(src_func, func, **generate_edge_meta(call)) @@ -8188,10 +8188,10 @@ We can enrich our edges with metadata about the function calls: ```python def generate_edge_meta(call: FunctionCall) -> dict: """Generate metadata for call graph edges - + Args: call (FunctionCall): Function call information - + 
Returns: dict: Edge metadata including name and location """ @@ -8210,8 +8210,8 @@ Finally, we can visualize our call graph starting from a specific function: target_class = codebase.get_class('SharingConfigurationViewSet') target_method = target_class.get_method('patch') -# Add root node -G.add_node(target_method, +# Add root node +G.add_node(target_method, name=f"{target_class.name}.{target_method.name}", color=COLOR_PALETTE["StartFunction"]) @@ -8261,7 +8261,7 @@ The core function for building our dependency graph: ```python def create_dependencies_visualization(symbol: Symbol, depth: int = 0): """Creates visualization of symbol dependencies - + Args: symbol (Symbol): Starting symbol to analyze depth (int): Current recursion depth @@ -8269,11 +8269,11 @@ def create_dependencies_visualization(symbol: Symbol, depth: int = 0): # Prevent excessive recursion if depth >= MAX_DEPTH: return - + # Process each dependency for dep in symbol.dependencies: dep_symbol = None - + # Handle different dependency types if isinstance(dep, Symbol): # Direct symbol reference @@ -8284,13 +8284,13 @@ def create_dependencies_visualization(symbol: Symbol, depth: int = 0): if dep_symbol: # Add node with appropriate styling - G.add_node(dep_symbol, - color=COLOR_PALETTE.get(dep_symbol.__class__.__name__, + G.add_node(dep_symbol, + color=COLOR_PALETTE.get(dep_symbol.__class__.__name__, "#f694ff")) - + # Add dependency relationship G.add_edge(symbol, dep_symbol) - + # Recurse unless it's a class (avoid complexity) if not isinstance(dep_symbol, PyClass): create_dependencies_visualization(dep_symbol, depth + 1) @@ -8302,7 +8302,7 @@ Finally, we can visualize our dependency graph starting from a specific symbol: # Get target symbol target_func = codebase.get_function("get_query_runner") -# Add root node +# Add root node G.add_node(target_func, color=COLOR_PALETTE["StartFunction"]) # Generate dependency graph @@ -8345,16 +8345,16 @@ HTTP_METHODS = ["get", "put", "patch", "post", "head", "delete"] def generate_edge_meta(usage: Usage) -> dict: """Generate metadata for graph edges - + Args: usage (Usage): Usage relationship information - + Returns: dict: Edge metadata including name and location """ return { "name": usage.match.source, - "file_path": usage.match.filepath, + "file_path": usage.match.filepath, "start_point": usage.match.start_point, "end_point": usage.match.end_point, "symbol_name": usage.match.__class__.__name__ @@ -8362,10 +8362,10 @@ def generate_edge_meta(usage: Usage) -> dict: def is_http_method(symbol: PySymbol) -> bool: """Check if a symbol is an HTTP endpoint method - + Args: symbol (PySymbol): Symbol to check - + Returns: bool: True if symbol is an HTTP method """ @@ -8379,7 +8379,7 @@ The main function for creating our blast radius visualization: ```python def create_blast_radius_visualization(symbol: PySymbol, depth: int = 0): """Create visualization of symbol usage relationships - + Args: symbol (PySymbol): Starting symbol to analyze depth (int): Current recursion depth @@ -8387,11 +8387,11 @@ def create_blast_radius_visualization(symbol: PySymbol, depth: int = 0): # Prevent excessive recursion if depth >= MAX_DEPTH: return - + # Process each usage of the symbol for usage in symbol.usages: usage_symbol = usage.usage_symbol - + # Determine node color based on type if is_http_method(usage_symbol): color = COLOR_PALETTE.get("HTTP_METHOD") @@ -8401,7 +8401,7 @@ def create_blast_radius_visualization(symbol: PySymbol, depth: int = 0): # Add node and edge to graph G.add_node(usage_symbol, color=color) 
G.add_edge(symbol, usage_symbol, **generate_edge_meta(usage)) - + # Recursively process usage symbol create_blast_radius_visualization(usage_symbol, depth + 1) ``` @@ -8552,7 +8552,7 @@ for call in old_api.call_sites: f"data={call.get_arg_by_parameter_name('input').value}", f"timeout={call.get_arg_by_parameter_name('wait').value}" ] - + # Replace the old call with the new API call.replace(f"new_process_data({', '.join(args)})") ``` @@ -8566,10 +8566,10 @@ When updating chained method calls, like database queries or builder patterns: for execute_call in codebase.function_calls: if execute_call.name != "execute": continue - + # Get the full chain chain = execute_call.call_chain - + # Example: Add .timeout() before .execute() if "timeout" not in {call.name for call in chain}: execute_call.insert_before("timeout(30)") @@ -8588,45 +8588,45 @@ Here's a comprehensive example: ```python def migrate_api_v1_to_v2(codebase): old_api = codebase.get_function("create_user_v1") - + # Document all existing call patterns call_patterns = {} for call in old_api.call_sites: args = [arg.source for arg in call.args] pattern = ", ".join(args) call_patterns[pattern] = call_patterns.get(pattern, 0) + 1 - + print("Found call patterns:") for pattern, count in call_patterns.items(): print(f" {pattern}: {count} occurrences") - + # Create new API version new_api = old_api.copy() new_api.rename("create_user_v2") - + # Update parameter types new_api.get_parameter("email").type = "EmailStr" new_api.get_parameter("role").type = "UserRole" - + # Add new required parameters new_api.add_parameter("tenant_id: UUID") - + # Update all call sites for call in old_api.call_sites: # Get current arguments email_arg = call.get_arg_by_parameter_name("email") role_arg = call.get_arg_by_parameter_name("role") - + # Build new argument list with type conversions new_args = [ f"email=EmailStr({email_arg.value})", f"role=UserRole({role_arg.value})", "tenant_id=get_current_tenant_id()" ] - + # Replace old call with new version call.replace(f"create_user_v2({', '.join(new_args)})") - + # Add deprecation notice to old version old_api.add_decorator('@deprecated("Use create_user_v2 instead")') @@ -8648,10 +8648,10 @@ migrate_api_v1_to_v2(codebase) ```python # First update parameter names param.rename("new_name") - + # Then update types param.type = "new_type" - + # Finally update call sites for call in api.call_sites: # ... update calls @@ -8661,7 +8661,7 @@ migrate_api_v1_to_v2(codebase) ```python # Add new parameter with default api.add_parameter("new_param: str = None") - + # Later make it required api.get_parameter("new_param").remove_default() ``` @@ -8676,7 +8676,7 @@ migrate_api_v1_to_v2(codebase) Remember to test thoroughly after making bulk changes to APIs. While Codegen ensures syntactic correctness, you'll want to verify the semantic correctness of the changes. 
- + --- title: "Organizing Your Codebase" @@ -9240,16 +9240,16 @@ from collections import defaultdict # Create a graph of file dependencies def create_dependency_graph(): G = nx.DiGraph() - + for file in codebase.files: # Add node for this file G.add_node(file.filepath) - + # Add edges for each import for imp in file.imports: if imp.from_file: # Skip external imports G.add_edge(file.filepath, imp.from_file.filepath) - + return G # Create and analyze the graph @@ -9278,18 +9278,18 @@ def break_circular_dependency(cycle): # Get the first two files in the cycle file1 = codebase.get_file(cycle[0]) file2 = codebase.get_file(cycle[1]) - + # Create a shared module for common code shared_dir = "shared" if not codebase.has_directory(shared_dir): codebase.create_directory(shared_dir) - + # Find symbols used by both files shared_symbols = [] for symbol in file1.symbols: if any(usage.file == file2 for usage in symbol.usages): shared_symbols.append(symbol) - + # Move shared symbols to a new file if shared_symbols: shared_file = codebase.create_file(f"{shared_dir}/shared_types.py") @@ -9311,7 +9311,7 @@ def organize_file_imports(file): std_lib_imports = [] third_party_imports = [] local_imports = [] - + for imp in file.imports: if imp.is_standard_library: std_lib_imports.append(imp) @@ -9319,26 +9319,26 @@ def organize_file_imports(file): third_party_imports.append(imp) else: local_imports.append(imp) - + # Sort each group for group in [std_lib_imports, third_party_imports, local_imports]: group.sort(key=lambda x: x.module_name) - + # Remove all existing imports for imp in file.imports: imp.remove() - + # Add imports back in organized groups if std_lib_imports: for imp in std_lib_imports: file.add_import(imp.source) file.insert_after_imports("") # Add newline - + if third_party_imports: for imp in third_party_imports: file.add_import(imp.source) file.insert_after_imports("") # Add newline - + if local_imports: for imp in local_imports: file.add_import(imp.source) @@ -9357,22 +9357,22 @@ from collections import defaultdict def analyze_module_coupling(): coupling_scores = defaultdict(int) - + for file in codebase.files: # Count unique files imported from imported_files = {imp.from_file for imp in file.imports if imp.from_file} coupling_scores[file.filepath] = len(imported_files) - + # Count files that import this file - importing_files = {usage.file for symbol in file.symbols + importing_files = {usage.file for symbol in file.symbols for usage in symbol.usages if usage.file != file} coupling_scores[file.filepath] += len(importing_files) - + # Sort by coupling score - sorted_files = sorted(coupling_scores.items(), - key=lambda x: x[1], + sorted_files = sorted(coupling_scores.items(), + key=lambda x: x[1], reverse=True) - + print("\nšŸ” Module Coupling Analysis:") print("\nMost coupled files:") for filepath, score in sorted_files[:5]: @@ -9390,9 +9390,9 @@ def extract_shared_code(file, min_usages=3): # Find symbols used by multiple files for symbol in file.symbols: # Get unique files using this symbol - using_files = {usage.file for usage in symbol.usages + using_files = {usage.file for usage in symbol.usages if usage.file != file} - + if len(using_files) >= min_usages: # Create appropriate shared module module_name = determine_shared_module(symbol) @@ -9400,7 +9400,7 @@ def extract_shared_code(file, min_usages=3): shared_file = codebase.create_file(f"shared/{module_name}.py") else: shared_file = codebase.get_file(f"shared/{module_name}.py") - + # Move symbol to shared module 
symbol.move_to_file(shared_file, strategy="update_all_imports") @@ -9454,7 +9454,7 @@ if feature_flag_class: # Initialize usage count for all attributes for attr in feature_flag_class.attributes: feature_flag_usage[attr.name] = 0 - + # Get all usages of the FeatureFlag class for usage in feature_flag_class.usages: usage_source = usage.usage_symbol.source if hasattr(usage, 'usage_symbol') else str(usage) @@ -10199,7 +10199,7 @@ Let's break down how this works: if export.is_reexport() and export.is_default_export(): print(f" šŸ”„ Converting default export '{export.name}'") ``` - + The code identifies default exports by checking: 1. If it's a re-export (`is_reexport()`) 2. If it's a default export (`is_default_export()`) @@ -10307,7 +10307,7 @@ for file in codebase.files: print(f"✨ Fixed exports in {target_file.filepath}") -``` +``` --- title: "Creating Documentation" @@ -10396,11 +10396,11 @@ for directory in codebase.directories: # Skip test, sql and alembic directories if any(x in directory.path.lower() for x in ['test', 'sql', 'alembic']): continue - + # Get undecorated functions funcs = [f for f in directory.functions if not f.is_decorated] total = len(funcs) - + # Only analyze dirs with >10 functions if total > 10: documented = sum(1 for f in funcs if f.docstring) @@ -10415,12 +10415,12 @@ for directory in codebase.directories: if dir_stats: lowest_dir = min(dir_stats.items(), key=lambda x: x[1]['coverage']) path, stats = lowest_dir - + print(f"šŸ“‰ Lowest coverage directory: '{path}'") print(f" • Total functions: {stats['total']}") print(f" • Documented: {stats['documented']}") print(f" • Coverage: {stats['coverage']:.1f}%") - + # Print all directory stats for comparison print("\nšŸ“Š All directory coverage rates:") for path, stats in sorted(dir_stats.items(), key=lambda x: x[1]['coverage']): @@ -11208,7 +11208,7 @@ iconType: "solid" -Import loops occur when two or more Python modules depend on each other, creating a circular dependency. While some import cycles can be harmless, others can lead to runtime errors and make code harder to maintain. +Import loops occur when two or more Python modules depend on each other, creating a circular dependency. While some import cycles can be harmless, others can lead to runtime errors and make code harder to maintain. In this tutorial, we'll explore how to identify and fix problematic import cycles using Codegen. @@ -12105,7 +12105,7 @@ Match (s: Func )-[r: CALLS]-> (e:Func) RETURN s, e LIMIT 10 ```cypher Match path = (:(Method|Func)) -[:CALLS*5..10]-> (:(Method|Func)) -Return path +Return path LIMIT 20 ``` @@ -12123,8 +12123,8 @@ iconType: "solid" # AI Impact Analysis -This tutorial shows how to use Codegen's attribution extension to analyze the impact of AI on your -codebase. You'll learn how to identify which parts of your code were written by AI tools like +This tutorial shows how to use Codegen's attribution extension to analyze the impact of AI on your +codebase. You'll learn how to identify which parts of your code were written by AI tools like GitHub Copilot, Devin, or other AI assistants. Note: the code is flexible - you can track CI pipeline bots, or any other contributor you want. @@ -12203,7 +12203,7 @@ for symbol in codebase.symbols: ### Customizing AI Author Detection -By default, the analysis looks for common AI bot names in commit authors. +By default, the analysis looks for common AI bot names in commit authors. 
You can customize this by providing your own list of AI authors: ```python @@ -12244,34 +12244,34 @@ from codegen.shared.enums.programming_language import ProgrammingLanguage def analyze_contributors(codebase): """Analyze contributors to the codebase and their impact.""" print("\nšŸ” Contributor Analysis:") - + # Define which authors are considered AI ai_authors = ['devin[bot]', 'codegen[bot]', 'github-actions[bot]', 'dependabot[bot]'] - + # Add attribution information to all symbols print("Adding attribution information to symbols...") add_attribution_to_symbols(codebase, ai_authors) - + # Collect statistics about contributors contributor_stats = Counter() ai_contributor_stats = Counter() - + print("Analyzing symbol attributions...") for symbol in codebase.symbols: if hasattr(symbol, 'last_editor') and symbol.last_editor: contributor_stats[symbol.last_editor] += 1 - + # Track if this is an AI contributor if any(ai in symbol.last_editor for ai in ai_authors): ai_contributor_stats[symbol.last_editor] += 1 - + # Print top contributors overall print("\nšŸ‘„ Top Contributors by Symbols Authored:") for contributor, count in contributor_stats.most_common(10): is_ai = any(ai in contributor for ai in ai_authors) ai_indicator = "šŸ¤–" if is_ai else "šŸ‘¤" print(f" {ai_indicator} {contributor}: {count} symbols") - + # Print top AI contributors if any if ai_contributor_stats: print("\nšŸ¤– Top AI Contributors:") @@ -12283,24 +12283,23 @@ if os.path.exists(".git"): repo_path = os.getcwd() repo_config = RepoConfig.from_repo_path(repo_path) repo_operator = RepoOperator(repo_config=repo_config) - + project = ProjectConfig.from_repo_operator( repo_operator=repo_operator, programming_language=ProgrammingLanguage.PYTHON ) codebase = Codebase(projects=[project]) - + # Run the contributor analysis analyze_contributors(codebase) ``` ## Conclusion -The attribution extension provides valuable insights into how AI tools are being used in your +The attribution extension provides valuable insights into how AI tools are being used in your development process. 
By understanding which parts of your codebase are authored by AI, you can: - Track the adoption of AI coding assistants in your team - Identify areas where AI is most effective - Ensure appropriate review of AI-generated code - Measure the impact of AI on developer productivity - diff --git a/src/codegen/sdk/types.py b/src/codegen/sdk/types.py index 7f070aa0d..99943886b 100644 --- a/src/codegen/sdk/types.py +++ b/src/codegen/sdk/types.py @@ -1,3 +1 @@ -from typing import TypeAlias - -JSON: TypeAlias = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None +type JSON = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None diff --git a/src/codegen/shared/compilation/exception_utils.py b/src/codegen/shared/compilation/exception_utils.py index 3f1a40400..b19cc6084 100644 --- a/src/codegen/shared/compilation/exception_utils.py +++ b/src/codegen/shared/compilation/exception_utils.py @@ -45,10 +45,7 @@ def get_local_frame(exc_type: type[BaseException], exc_value: BaseException, exc def get_local_frame_context(frame: FrameType): local_vars = {k: v for k, v in frame.f_locals.items() if not k.startswith("__")} - if "print" in local_vars: - del local_vars["print"] - if "codebase" in local_vars: - del local_vars["codebase"] - if "pr_options" in local_vars: - del local_vars["pr_options"] + local_vars.pop("print", None) + local_vars.pop("codebase", None) + local_vars.pop("pr_options", None) return local_vars diff --git a/tests/unit/codegen/agents/test_api_client.py b/tests/unit/codegen/agents/test_api_client.py index ff75e1261..8a3c28d2c 100644 --- a/tests/unit/codegen/agents/test_api_client.py +++ b/tests/unit/codegen/agents/test_api_client.py @@ -89,7 +89,7 @@ def test_sanitize_for_serialization_tuple(self, api_client): def test_sanitize_for_serialization_datetime(self, api_client): """Test sanitization of datetime values""" - dt = datetime.datetime(2022, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) + dt = datetime.datetime(2022, 1, 1, 12, 0, 0, tzinfo=datetime.UTC) assert api_client.sanitize_for_serialization(dt) == "2022-01-01T12:00:00+00:00" date = datetime.date(2022, 1, 1) diff --git a/tests/unit/codegen/extensions/lsp/test_document_symbols.py b/tests/unit/codegen/extensions/lsp/test_document_symbols.py index 7671ae652..14f516f84 100644 --- a/tests/unit/codegen/extensions/lsp/test_document_symbols.py +++ b/tests/unit/codegen/extensions/lsp/test_document_symbols.py @@ -1,5 +1,4 @@ -from collections.abc import Sequence -from typing import cast +from typing import TYPE_CHECKING, cast import pytest from lsprotocol.types import ( @@ -14,6 +13,9 @@ from codegen.sdk.core.codebase import Codebase +if TYPE_CHECKING: + from collections.abc import Sequence + @pytest.mark.parametrize( "original, expected_symbols", @@ -227,7 +229,7 @@ async def test_document_symbols( result = await client.text_document_document_symbol_async(params=DocumentSymbolParams(text_document=TextDocumentIdentifier(uri=f"file://{codebase.repo_path}/test.py"))) assert result is not None - symbols = cast(Sequence[DocumentSymbol], result) + symbols = cast("Sequence[DocumentSymbol]", result) assert len(symbols) == len(expected_symbols) for actual, expected in zip(symbols, expected_symbols): assert actual.name == expected.name diff --git a/tests/unit/codegen/extensions/lsp/test_workspace_sync.py b/tests/unit/codegen/extensions/lsp/test_workspace_sync.py index c3cd38b64..2b7a839c0 100644 --- a/tests/unit/codegen/extensions/lsp/test_workspace_sync.py +++ b/tests/unit/codegen/extensions/lsp/test_workspace_sync.py @@ 
-1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable import pytest from lsprotocol.types import (