diff --git a/codegen-on-oss/codegen_on_oss/analyzers/analyzer.py b/codegen-on-oss/codegen_on_oss/analyzers/analyzer.py index 4337bba5b..3471380d8 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/analyzer.py +++ b/codegen-on-oss/codegen_on_oss/analyzers/analyzer.py @@ -7,35 +7,37 @@ It serves as the primary API entry point for the analyzer backend. """ -import os -import sys import json import logging +import sys from datetime import datetime -from pathlib import Path -from typing import Dict, List, Set, Tuple, Any, Optional, Union, Type, Callable -from enum import Enum +from typing import Any try: - from codegen.sdk.core.codebase import Codebase from codegen.configs.models.codebase import CodebaseConfig from codegen.configs.models.secrets import SecretsConfig - from codegen.sdk.codebase.config import ProjectConfig - from codegen.git.schemas.repo_config import RepoConfig from codegen.git.repo_operator.repo_operator import RepoOperator + from codegen.git.schemas.repo_config import RepoConfig + from codegen.sdk.codebase.config import ProjectConfig + from codegen.sdk.core.codebase import Codebase from codegen.shared.enums.programming_language import ProgrammingLanguage except ImportError: print("Codegen SDK not found. 
Please install it first.") sys.exit(1) # Import internal modules - these will be replaced with actual imports once implemented -from codegen_on_oss.analyzers.issues import Issue, IssueSeverity, AnalysisType, IssueCategory +from codegen_on_oss.analyzers.issues import ( + AnalysisType, + Issue, + IssueCategory, + IssueSeverity, +) # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) @@ -56,108 +58,115 @@ ".vscode", ] + class AnalyzerRegistry: """Registry of analyzer plugins.""" - + _instance = None - + def __new__(cls): if cls._instance is None: - cls._instance = super(AnalyzerRegistry, cls).__new__(cls) - cls._instance._analyzers = {} + cls._instance = super().__new__(cls) + cls._analyzers = {} return cls._instance - - def register(self, analysis_type: AnalysisType, analyzer_class: Type['AnalyzerPlugin']): + + def register( + self, analysis_type: AnalysisType, analyzer_class: type["AnalyzerPlugin"] + ): """Register an analyzer plugin.""" self._analyzers[analysis_type] = analyzer_class - - def get_analyzer(self, analysis_type: AnalysisType) -> Optional[Type['AnalyzerPlugin']]: + + def get_analyzer( + self, analysis_type: AnalysisType + ) -> type["AnalyzerPlugin"] | None: """Get the analyzer plugin for a specific analysis type.""" return self._analyzers.get(analysis_type) - - def list_analyzers(self) -> Dict[AnalysisType, Type['AnalyzerPlugin']]: + + def list_analyzers(self) -> dict[AnalysisType, type["AnalyzerPlugin"]]: """Get all registered analyzers.""" return self._analyzers.copy() + class AnalyzerPlugin: """Base class for analyzer plugins.""" - - def __init__(self, manager: 'AnalyzerManager'): + + def __init__(self, manager: "AnalyzerManager"): """Initialize the analyzer plugin.""" self.manager = manager self.issues 
= [] - - def analyze(self) -> Dict[str, Any]: + + def analyze(self) -> dict[str, Any]: """Perform analysis using this plugin.""" raise NotImplementedError("Analyzer plugins must implement analyze()") - + def add_issue(self, issue: Issue): """Add an issue to the list.""" self.manager.add_issue(issue) self.issues.append(issue) + class CodeQualityPlugin(AnalyzerPlugin): """Plugin for code quality analysis.""" - - def analyze(self) -> Dict[str, Any]: + + def analyze(self) -> dict[str, Any]: """Perform code quality analysis.""" # This is a simplified placeholder - would import and use code_quality.py result = { "dead_code": self._find_dead_code(), "complexity": self._analyze_complexity(), "maintainability": self._analyze_maintainability(), - "style_issues": self._analyze_style_issues() + "style_issues": self._analyze_style_issues(), } return result - - def _find_dead_code(self) -> Dict[str, Any]: + + def _find_dead_code(self) -> dict[str, Any]: """Find unused code in the codebase.""" # This is a placeholder return {"unused_functions": [], "unused_classes": [], "unused_variables": []} - - def _analyze_complexity(self) -> Dict[str, Any]: + + def _analyze_complexity(self) -> dict[str, Any]: """Analyze code complexity.""" # This is a placeholder return {"complex_functions": [], "average_complexity": 0} - - def _analyze_maintainability(self) -> Dict[str, Any]: + + def _analyze_maintainability(self) -> dict[str, Any]: """Analyze code maintainability.""" # This is a placeholder return {"maintainability_index": {}} - - def _analyze_style_issues(self) -> Dict[str, Any]: + + def _analyze_style_issues(self) -> dict[str, Any]: """Analyze code style issues.""" # This is a placeholder return {"style_violations": []} + class DependencyPlugin(AnalyzerPlugin): """Plugin for dependency analysis.""" - def analyze(self) -> Dict[str, Any]: + def analyze(self) -> dict[str, Any]: """Perform dependency analysis using the DependencyAnalyzer.""" - from codegen_on_oss.analyzers.dependencies 
import DependencyAnalyzer from codegen_on_oss.analyzers.codebase_context import CodebaseContext + from codegen_on_oss.analyzers.dependencies import DependencyAnalyzer # Create context if needed - context = getattr(self.manager, 'base_context', None) - if not context and hasattr(self.manager, 'base_codebase'): + context = getattr(self.manager, "base_context", None) + if not context and hasattr(self.manager, "base_codebase"): try: context = CodebaseContext( codebase=self.manager.base_codebase, base_path=self.manager.repo_path, pr_branch=None, - base_branch=self.manager.base_branch + base_branch=self.manager.base_branch, ) # Save context for future use self.manager.base_context = context - except Exception as e: - logger.error(f"Error initializing context: {e}") + except Exception: + logger.exception("Error initializing context") # Initialize and run the dependency analyzer if context: dependency_analyzer = DependencyAnalyzer( - codebase=self.manager.base_codebase, - context=context + codebase=self.manager.base_codebase, context=context ) # Run analysis @@ -173,43 +182,44 @@ def analyze(self) -> Dict[str, Any]: result = { "import_dependencies": self._analyze_imports(), "circular_dependencies": self._find_circular_dependencies(), - "module_coupling": self._analyze_module_coupling() + "module_coupling": self._analyze_module_coupling(), } return result - def _analyze_imports(self) -> Dict[str, Any]: + def _analyze_imports(self) -> dict[str, Any]: """Fallback import analysis if context initialization failed.""" return {"module_dependencies": [], "external_dependencies": []} - def _find_circular_dependencies(self) -> Dict[str, Any]: + def _find_circular_dependencies(self) -> dict[str, Any]: """Fallback circular dependencies analysis if context initialization failed.""" return {"circular_imports": []} - def _analyze_module_coupling(self) -> Dict[str, Any]: + def _analyze_module_coupling(self) -> dict[str, Any]: """Fallback module coupling analysis if context initialization 
failed.""" return {"high_coupling_modules": []} + class AnalyzerManager: """ Unified manager for codebase analysis. - + This class serves as the main entry point for all analysis operations, coordinating different analyzer plugins and managing results. """ - + def __init__( self, - repo_url: Optional[str] = None, - repo_path: Optional[str] = None, + repo_url: str | None = None, + repo_path: str | None = None, base_branch: str = "main", - pr_number: Optional[int] = None, - language: Optional[str] = None, - file_ignore_list: Optional[List[str]] = None, - config: Optional[Dict[str, Any]] = None + pr_number: int | None = None, + language: str | None = None, + file_ignore_list: list[str] | None = None, + config: dict[str, Any] | None = None, ): """ Initialize the analyzer manager. - + Args: repo_url: URL of the repository to analyze repo_path: Local path to the repository to analyze @@ -224,88 +234,89 @@ def __init__( self.base_branch = base_branch self.pr_number = pr_number self.language = language - + # Use custom ignore list or default global list self.file_ignore_list = file_ignore_list or GLOBAL_FILE_IGNORE_LIST - + # Configuration options self.config = config or {} - + # Codebase and context objects self.base_codebase = None self.pr_codebase = None - + # Analysis results self.issues = [] self.results = {} - + # PR comparison data self.pr_diff = None self.commit_shas = None self.modified_symbols = None self.pr_branch = None - + # Initialize codebase(s) based on provided parameters if repo_url: self._init_from_url(repo_url, language) elif repo_path: self._init_from_path(repo_path, language) - + # If PR number is provided, initialize PR-specific data if self.pr_number is not None and self.base_codebase is not None: self._init_pr_data(self.pr_number) - + # Register default analyzers self._register_default_analyzers() - - def _init_from_url(self, repo_url: str, language: Optional[str] = None): + + def _init_from_url(self, repo_url: str, language: str | None = None): 
"""Initialize codebase from a repository URL.""" try: # Extract repository information - if repo_url.endswith('.git'): + if repo_url.endswith(".git"): repo_url = repo_url[:-4] - - parts = repo_url.rstrip('/').split('/') + + parts = repo_url.rstrip("/").split("/") repo_name = parts[-1] owner = parts[-2] repo_full_name = f"{owner}/{repo_name}" - + # Create temporary directory for cloning import tempfile + tmp_dir = tempfile.mkdtemp(prefix="analyzer_") - + # Set up configuration config = CodebaseConfig( debug=False, allow_external=True, py_resolve_syspath=True, ) - + secrets = SecretsConfig() - + # Determine programming language prog_lang = None if language: prog_lang = ProgrammingLanguage(language.upper()) - + # Initialize the codebase logger.info(f"Initializing codebase from {repo_url}") - + self.base_codebase = Codebase.from_github( repo_full_name=repo_full_name, tmp_dir=tmp_dir, language=prog_lang, config=config, - secrets=secrets + secrets=secrets, ) - + logger.info(f"Successfully initialized codebase from {repo_url}") - - except Exception as e: - logger.error(f"Error initializing codebase from URL: {e}") + + except Exception: + logger.exception("Error initializing codebase from URL") raise - - def _init_from_path(self, repo_path: str, language: Optional[str] = None): + + def _init_from_path(self, repo_path: str, language: str | None = None): """Initialize codebase from a local repository path.""" try: # Set up configuration @@ -314,163 +325,163 @@ def _init_from_path(self, repo_path: str, language: Optional[str] = None): allow_external=True, py_resolve_syspath=True, ) - + secrets = SecretsConfig() - + # Initialize the codebase logger.info(f"Initializing codebase from {repo_path}") - + # Determine programming language prog_lang = None if language: prog_lang = ProgrammingLanguage(language.upper()) - + # Set up repository configuration repo_config = RepoConfig.from_repo_path(repo_path) repo_config.respect_gitignore = False repo_operator = 
RepoOperator(repo_config=repo_config, bot_commit=False) - + # Create project configuration project_config = ProjectConfig( repo_operator=repo_operator, - programming_language=prog_lang if prog_lang else None + programming_language=prog_lang if prog_lang else None, ) - + # Initialize codebase self.base_codebase = Codebase( - projects=[project_config], - config=config, - secrets=secrets + projects=[project_config], config=config, secrets=secrets ) - + logger.info(f"Successfully initialized codebase from {repo_path}") - - except Exception as e: - logger.error(f"Error initializing codebase from path: {e}") + + except Exception: + logger.exception("Error initializing codebase from path") raise - + def _init_pr_data(self, pr_number: int): """Initialize PR-specific data.""" try: logger.info(f"Fetching PR #{pr_number} data") result = self.base_codebase.get_modified_symbols_in_pr(pr_number) - + # Unpack the result tuple if len(result) >= 3: self.pr_diff, self.commit_shas, self.modified_symbols = result[:3] if len(result) >= 4: self.pr_branch = result[3] - + logger.info(f"Found {len(self.modified_symbols)} modified symbols in PR") - + # Initialize PR codebase self._init_pr_codebase() - - except Exception as e: - logger.error(f"Error initializing PR data: {e}") + + except Exception: + logger.exception("Error initializing PR data") raise - + def _init_pr_codebase(self): """Initialize PR codebase by checking out the PR branch.""" if not self.base_codebase or not self.pr_number: logger.error("Base codebase or PR number not initialized") return - + try: # Get PR data if not already fetched if not self.pr_branch: self._init_pr_data(self.pr_number) - + if not self.pr_branch: logger.error("Failed to get PR branch") return - + # Clone the base codebase self.pr_codebase = self.base_codebase - + # Checkout PR branch logger.info(f"Checking out PR branch: {self.pr_branch}") self.pr_codebase.checkout(self.pr_branch) - + logger.info("Successfully initialized PR codebase") - - except 
Exception as e: - logger.error(f"Error initializing PR codebase: {e}") + + except Exception: + logger.exception("Error initializing PR codebase") raise - + def _register_default_analyzers(self): """Register default analyzers.""" registry = AnalyzerRegistry() registry.register(AnalysisType.CODE_QUALITY, CodeQualityPlugin) registry.register(AnalysisType.DEPENDENCY, DependencyPlugin) - + def add_issue(self, issue: Issue): """Add an issue to the list.""" # Check if issue should be skipped if self._should_skip_issue(issue): return - + self.issues.append(issue) - + def _should_skip_issue(self, issue: Issue) -> bool: """Check if an issue should be skipped.""" # Skip issues in ignored files file_path = issue.file - + # Check against ignore list for pattern in self.file_ignore_list: if pattern in file_path: return True - - # Check if the file is a test file - if "test" in file_path.lower() or "tests" in file_path.lower(): - # Skip low-severity issues in test files - if issue.severity in [IssueSeverity.INFO, IssueSeverity.WARNING]: - return True - - return False - - def get_issues(self, severity: Optional[IssueSeverity] = None, category: Optional[IssueCategory] = None) -> List[Issue]: + + # Check if the file is a test file and skip low-severity issues in test files + return bool( + ("test" in file_path.lower() or "tests" in file_path.lower()) + and issue.severity in [IssueSeverity.INFO, IssueSeverity.WARNING] + ) + + def get_issues( + self, + severity: IssueSeverity | None = None, + category: IssueCategory | None = None, + ) -> list[Issue]: """ Get all issues matching the specified criteria. 
- + Args: severity: Optional severity level to filter by category: Optional category to filter by - + Returns: List of matching issues """ filtered_issues = self.issues - + if severity: filtered_issues = [i for i in filtered_issues if i.severity == severity] - + if category: filtered_issues = [i for i in filtered_issues if i.category == category] - + return filtered_issues - + def analyze( - self, - analysis_types: Optional[List[Union[AnalysisType, str]]] = None, - output_file: Optional[str] = None, - output_format: str = "json" - ) -> Dict[str, Any]: + self, + analysis_types: list[AnalysisType | str] | None = None, + output_file: str | None = None, + output_format: str = "json", + ) -> dict[str, Any]: """ Perform analysis on the codebase. - + Args: analysis_types: List of analysis types to perform output_file: Path to save results to output_format: Format of the output file - + Returns: Dictionary containing analysis results """ if not self.base_codebase: - raise ValueError("Codebase not initialized") - + raise ValueError("Base codebase is missing") + # Convert string analysis types to enums if analysis_types: analysis_types = [ @@ -480,78 +491,97 @@ def analyze( else: # Default to code quality and dependency analysis analysis_types = [AnalysisType.CODE_QUALITY, AnalysisType.DEPENDENCY] - + # Initialize results self.results = { "metadata": { "analysis_time": datetime.now().isoformat(), "analysis_types": [t.value for t in analysis_types], - "repo_name": getattr(self.base_codebase.ctx, 'repo_name', None), - "language": str(getattr(self.base_codebase.ctx, 'programming_language', None)), + "repo_name": getattr(self.base_codebase.ctx, "repo_name", None), + "language": str( + getattr(self.base_codebase.ctx, "programming_language", None) + ), }, "summary": {}, - "results": {} + "results": {}, } - + # Reset issues self.issues = [] - + # Run each analyzer registry = AnalyzerRegistry() - + for analysis_type in analysis_types: analyzer_class = 
registry.get_analyzer(analysis_type) - + if analyzer_class: logger.info(f"Running {analysis_type.value} analysis") analyzer = analyzer_class(self) analysis_result = analyzer.analyze() - + # Add results to unified results self.results["results"][analysis_type.value] = analysis_result else: logger.warning(f"No analyzer found for {analysis_type.value}") - + # Add issues to results self.results["issues"] = [issue.to_dict() for issue in self.issues] - + # Add issue statistics self.results["issue_stats"] = { "total": len(self.issues), "by_severity": { - "critical": sum(1 for issue in self.issues if issue.severity == IssueSeverity.CRITICAL), - "error": sum(1 for issue in self.issues if issue.severity == IssueSeverity.ERROR), - "warning": sum(1 for issue in self.issues if issue.severity == IssueSeverity.WARNING), - "info": sum(1 for issue in self.issues if issue.severity == IssueSeverity.INFO), - } + "critical": sum( + 1 + for issue in self.issues + if issue.severity == IssueSeverity.CRITICAL + ), + "error": sum( + 1 for issue in self.issues if issue.severity == IssueSeverity.ERROR + ), + "warning": sum( + 1 + for issue in self.issues + if issue.severity == IssueSeverity.WARNING + ), + "info": sum( + 1 for issue in self.issues if issue.severity == IssueSeverity.INFO + ), + }, } - + # Save results if output file is specified if output_file: self.save_results(output_file, output_format) - + return self.results - - def save_results(self, output_file: str, format: str = "json"): + + def save_results(self, output_file: str, output_format: str = "json"): """ Save analysis results to a file. 
- + Args: output_file: Path to the output file - format: Output format (json, html) + output_format: Format of the output file (json or yaml) """ - if format == "json": - with open(output_file, 'w') as f: + if output_format == "json": + with open(output_file, "w") as f: json.dump(self.results, f, indent=2) - elif format == "html": - self._generate_html_report(output_file) + elif output_format == "yaml": + import yaml + + with open(output_file, "w") as f: + yaml.dump(self.results, f, default_flow_style=False) + elif output_format == "console": + print(json.dumps(self.results, indent=2)) else: # Default to JSON - with open(output_file, 'w') as f: + with open(output_file, "w") as f: json.dump(self.results, f, indent=2) - + logger.info(f"Results saved to {output_file}") - + def _generate_html_report(self, output_file: str): """Generate an HTML report of the analysis results.""" html_content = f""" @@ -578,85 +608,93 @@ def _generate_html_report(self, output_file: str):

Codebase Analysis Report

Summary

-

Repository: {self.results['metadata'].get('repo_name', 'Unknown')}

-

Language: {self.results['metadata'].get('language', 'Unknown')}

-

Analysis Time: {self.results['metadata'].get('analysis_time', 'Unknown')}

-

Analysis Types: {', '.join(self.results['metadata'].get('analysis_types', []))}

+

Repository: {self.results["metadata"].get("repo_name", "Unknown")}

+

Language: {self.results["metadata"].get("language", "Unknown")}

+

Analysis Time: {self.results["metadata"].get("analysis_time", "Unknown")}

+

Analysis Types: {", ".join(self.results["metadata"].get("analysis_types", []))}

Total Issues: {len(self.issues)}

""" - + # Add issues section html_content += """

Issues

""" - + # Add issues by severity for severity in ["critical", "error", "warning", "info"]: - severity_issues = [issue for issue in self.issues if issue.severity.value == severity] - + severity_issues = [ + issue for issue in self.issues if issue.severity.value == severity + ] + if severity_issues: html_content += f"""

{severity.upper()} Issues ({len(severity_issues)})

""" - + for issue in severity_issues: - location = f"{issue.file}:{issue.line}" if issue.line else issue.file - category = f"[{issue.category.value}]" if hasattr(issue, 'category') and issue.category else "" - + location = ( + f"{issue.file}:{issue.line}" if issue.line else issue.file + ) + category = ( + f"[{issue.category.value}]" + if hasattr(issue, "category") and issue.category + else "" + ) + html_content += f"""

{location} {category} {issue.message}

-

{issue.suggestion if hasattr(issue, 'suggestion') else ""}

+

{issue.suggestion if hasattr(issue, "suggestion") else ""}

""" - + html_content += """
""" - + # Add detailed analysis sections html_content += """

Detailed Analysis

""" - - for analysis_type, results in self.results.get('results', {}).items(): + + for analysis_type, results in self.results.get("results", {}).items(): html_content += f"""

{analysis_type}

{json.dumps(results, indent=2)}
""" - + html_content += """
""" - - with open(output_file, 'w') as f: + + with open(output_file, "w") as f: f.write(html_content) - + def generate_report(self, report_type: str = "summary") -> str: """ Generate a report from the analysis results. - + Args: report_type: Type of report to generate (summary, detailed, issues) - + Returns: Report as a string """ if not self.results: raise ValueError("No analysis results available") - + if report_type == "summary": return self._generate_summary_report() elif report_type == "detailed": @@ -665,220 +703,281 @@ def generate_report(self, report_type: str = "summary") -> str: return self._generate_issues_report() else: raise ValueError(f"Unknown report type: {report_type}") - + def _generate_summary_report(self) -> str: """Generate a summary report.""" report = "===== Codebase Analysis Summary Report =====\n\n" - + # Add metadata - report += f"Repository: {self.results['metadata'].get('repo_name', 'Unknown')}\n" + report += ( + f"Repository: {self.results['metadata'].get('repo_name', 'Unknown')}\n" + ) report += f"Language: {self.results['metadata'].get('language', 'Unknown')}\n" report += f"Analysis Time: {self.results['metadata'].get('analysis_time', 'Unknown')}\n" report += f"Analysis Types: {', '.join(self.results['metadata'].get('analysis_types', []))}\n\n" - + # Add issue statistics report += f"Total Issues: {len(self.issues)}\n" report += f"Critical: {self.results['issue_stats']['by_severity'].get('critical', 0)}\n" - report += f"Errors: {self.results['issue_stats']['by_severity'].get('error', 0)}\n" + report += ( + f"Errors: {self.results['issue_stats']['by_severity'].get('error', 0)}\n" + ) report += f"Warnings: {self.results['issue_stats']['by_severity'].get('warning', 0)}\n" - report += f"Info: {self.results['issue_stats']['by_severity'].get('info', 0)}\n\n" - + report += ( + f"Info: {self.results['issue_stats']['by_severity'].get('info', 0)}\n\n" + ) + # Add analysis summaries - for analysis_type, results in self.results.get('results', 
{}).items(): + for analysis_type, results in self.results.get("results", {}).items(): report += f"===== {analysis_type.upper()} Analysis =====\n" - + if analysis_type == "code_quality": if "dead_code" in results: dead_code = results["dead_code"] report += f"Dead Code: {len(dead_code.get('unused_functions', []))} unused functions, " - report += f"{len(dead_code.get('unused_classes', []))} unused classes\n" - + report += ( + f"{len(dead_code.get('unused_classes', []))} unused classes\n" + ) + if "complexity" in results: complexity = results["complexity"] report += f"Complexity: {len(complexity.get('complex_functions', []))} complex functions\n" - + elif analysis_type == "dependency": if "circular_dependencies" in results: circular = results["circular_dependencies"] report += f"Circular Dependencies: {len(circular.get('circular_imports', []))}\n" - + if "module_coupling" in results: coupling = results["module_coupling"] report += f"High Coupling Modules: {len(coupling.get('high_coupling_modules', []))}\n" - + report += "\n" - + return report - + def _generate_detailed_report(self) -> str: """Generate a detailed report.""" report = "===== Codebase Analysis Detailed Report =====\n\n" - + # Add metadata - report += f"Repository: {self.results['metadata'].get('repo_name', 'Unknown')}\n" + report += ( + f"Repository: {self.results['metadata'].get('repo_name', 'Unknown')}\n" + ) report += f"Language: {self.results['metadata'].get('language', 'Unknown')}\n" report += f"Analysis Time: {self.results['metadata'].get('analysis_time', 'Unknown')}\n" report += f"Analysis Types: {', '.join(self.results['metadata'].get('analysis_types', []))}\n\n" - + # Add detailed issue report report += "===== Issues =====\n\n" - + for severity in ["critical", "error", "warning", "info"]: - severity_issues = [issue for issue in self.issues if issue.severity.value == severity] - + severity_issues = [ + issue for issue in self.issues if issue.severity.value == severity + ] + if severity_issues: 
report += f"{severity.upper()} Issues ({len(severity_issues)}):\n" - + for issue in severity_issues: - location = f"{issue.file}:{issue.line}" if issue.line else issue.file - category = f"[{issue.category.value}]" if hasattr(issue, 'category') and issue.category else "" - + location = ( + f"{issue.file}:{issue.line}" if issue.line else issue.file + ) + category = ( + f"[{issue.category.value}]" + if hasattr(issue, "category") and issue.category + else "" + ) + report += f"- {location} {category} {issue.message}\n" - if hasattr(issue, 'suggestion') and issue.suggestion: + if hasattr(issue, "suggestion") and issue.suggestion: report += f" Suggestion: {issue.suggestion}\n" - + report += "\n" - + # Add detailed analysis - for analysis_type, results in self.results.get('results', {}).items(): + for analysis_type, results in self.results.get("results", {}).items(): report += f"===== {analysis_type.upper()} Analysis =====\n\n" - + # Format based on analysis type if analysis_type == "code_quality": # Dead code details if "dead_code" in results: dead_code = results["dead_code"] report += "Dead Code:\n" - - if dead_code.get('unused_functions'): + + if dead_code.get("unused_functions"): report += " Unused Functions:\n" - for func in dead_code.get('unused_functions', [])[:10]: # Limit to 10 + for func in dead_code.get("unused_functions", [])[ + :10 + ]: # Limit to 10 report += f" - {func.get('name')} ({func.get('file')})\n" - - if len(dead_code.get('unused_functions', [])) > 10: + + if len(dead_code.get("unused_functions", [])) > 10: report += f" ... 
and {len(dead_code.get('unused_functions', [])) - 10} more\n" - - if dead_code.get('unused_classes'): + + if dead_code.get("unused_classes"): report += " Unused Classes:\n" - for cls in dead_code.get('unused_classes', [])[:10]: # Limit to 10 + for cls in dead_code.get("unused_classes", [])[ + :10 + ]: # Limit to 10 report += f" - {cls.get('name')} ({cls.get('file')})\n" - - if len(dead_code.get('unused_classes', [])) > 10: + + if len(dead_code.get("unused_classes", [])) > 10: report += f" ... and {len(dead_code.get('unused_classes', [])) - 10} more\n" - + report += "\n" - + # Complexity details if "complexity" in results: complexity = results["complexity"] report += "Code Complexity:\n" - - if complexity.get('complex_functions'): + + if complexity.get("complex_functions"): report += " Complex Functions:\n" - for func in complexity.get('complex_functions', [])[:10]: # Limit to 10 + for func in complexity.get("complex_functions", [])[ + :10 + ]: # Limit to 10 report += f" - {func.get('name')} (Complexity: {func.get('complexity')}, {func.get('file')})\n" - - if len(complexity.get('complex_functions', [])) > 10: + + if len(complexity.get("complex_functions", [])) > 10: report += f" ... and {len(complexity.get('complex_functions', [])) - 10} more\n" - + report += "\n" - + elif analysis_type == "dependency": # Circular dependencies if "circular_dependencies" in results: circular = results["circular_dependencies"] report += "Circular Dependencies:\n" - - if circular.get('circular_imports'): - for i, cycle in enumerate(circular.get('circular_imports', [])[:5]): # Limit to 5 - report += f" Cycle {i+1} (Length: {cycle.get('length')}):\n" - for j, file_path in enumerate(cycle.get('files', [])): - report += f" {j+1}. 
{file_path}\n" - - if len(circular.get('circular_imports', [])) > 5: + + if circular.get("circular_imports"): + for i, cycle in enumerate( + circular.get("circular_imports", [])[:5] + ): # Limit to 5 + report += ( + f" Cycle {i + 1} (Length: {cycle.get('length')}):\n" + ) + for j, file_path in enumerate(cycle.get("files", [])): + report += f" {j + 1}. {file_path}\n" + + if len(circular.get("circular_imports", [])) > 5: report += f" ... and {len(circular.get('circular_imports', [])) - 5} more cycles\n" - + report += "\n" - + # Module coupling if "module_coupling" in results: coupling = results["module_coupling"] report += "Module Coupling:\n" - - if coupling.get('high_coupling_modules'): + + if coupling.get("high_coupling_modules"): report += " High Coupling Modules:\n" - for module in coupling.get('high_coupling_modules', [])[:10]: # Limit to 10 + for module in coupling.get("high_coupling_modules", [])[ + :10 + ]: # Limit to 10 report += f" - {module.get('module')} (Ratio: {module.get('coupling_ratio'):.2f})\n" - - if len(coupling.get('high_coupling_modules', [])) > 10: + + if len(coupling.get("high_coupling_modules", [])) > 10: report += f" ... 
and {len(coupling.get('high_coupling_modules', [])) - 10} more\n" - + report += "\n" - + return report - + def _generate_issues_report(self) -> str: """Generate an issues-focused report.""" report = "===== Codebase Analysis Issues Report =====\n\n" - + # Add issue statistics report += f"Total Issues: {len(self.issues)}\n" report += f"Critical: {self.results['issue_stats']['by_severity'].get('critical', 0)}\n" - report += f"Errors: {self.results['issue_stats']['by_severity'].get('error', 0)}\n" + report += ( + f"Errors: {self.results['issue_stats']['by_severity'].get('error', 0)}\n" + ) report += f"Warnings: {self.results['issue_stats']['by_severity'].get('warning', 0)}\n" - report += f"Info: {self.results['issue_stats']['by_severity'].get('info', 0)}\n\n" - + report += ( + f"Info: {self.results['issue_stats']['by_severity'].get('info', 0)}\n\n" + ) + # Add issues by severity for severity in ["critical", "error", "warning", "info"]: - severity_issues = [issue for issue in self.issues if issue.severity.value == severity] - + severity_issues = [ + issue for issue in self.issues if issue.severity.value == severity + ] + if severity_issues: report += f"{severity.upper()} Issues ({len(severity_issues)}):\n" - + for issue in severity_issues: - location = f"{issue.file}:{issue.line}" if issue.line else issue.file - category = f"[{issue.category.value}]" if hasattr(issue, 'category') and issue.category else "" - + location = ( + f"{issue.file}:{issue.line}" if issue.line else issue.file + ) + category = ( + f"[{issue.category.value}]" + if hasattr(issue, "category") and issue.category + else "" + ) + report += f"- {location} {category} {issue.message}\n" - if hasattr(issue, 'suggestion') and issue.suggestion: + if hasattr(issue, "suggestion") and issue.suggestion: report += f" Suggestion: {issue.suggestion}\n" - + report += "\n" - + return report + def main(): """Command-line entry point.""" import argparse - + parser = argparse.ArgumentParser(description="Unified Codebase 
Analyzer") - + # Repository source options source_group = parser.add_mutually_exclusive_group(required=True) source_group.add_argument("--repo-url", help="URL of the repository to analyze") - source_group.add_argument("--repo-path", help="Local path to the repository to analyze") - + source_group.add_argument( + "--repo-path", help="Local path to the repository to analyze" + ) + # Analysis options - parser.add_argument("--analysis-types", nargs="+", choices=[at.value for at in AnalysisType], - default=["code_quality", "dependency"], - help="Types of analysis to perform") - parser.add_argument("--language", choices=["python", "typescript"], - help="Programming language (auto-detected if not provided)") - parser.add_argument("--base-branch", default="main", - help="Base branch for PR comparison (default: main)") - parser.add_argument("--pr-number", type=int, - help="PR number to analyze") - + parser.add_argument( + "--analysis-types", + nargs="+", + choices=[at.value for at in AnalysisType], + default=["code_quality", "dependency"], + help="Types of analysis to perform", + ) + parser.add_argument( + "--language", + choices=["python", "typescript"], + help="Programming language (auto-detected if not provided)", + ) + parser.add_argument( + "--base-branch", + default="main", + help="Base branch for PR comparison (default: main)", + ) + parser.add_argument("--pr-number", type=int, help="PR number to analyze") + # Output options - parser.add_argument("--output-file", - help="Path to the output file") - parser.add_argument("--output-format", choices=["json", "html", "console"], default="json", - help="Output format") - parser.add_argument("--report-type", choices=["summary", "detailed", "issues"], default="summary", - help="Type of report to generate (default: summary)") - + parser.add_argument("--output-file", help="Path to the output file") + parser.add_argument( + "--output-format", + choices=["json", "yaml", "console"], + default="json", + help="Output format", + ) + 
parser.add_argument( + "--report-type", + choices=["summary", "detailed", "issues"], + default="summary", + help="Type of report to generate (default: summary)", + ) + args = parser.parse_args() - + try: # Initialize the analyzer manager manager = AnalyzerManager( @@ -886,26 +985,28 @@ def main(): repo_path=args.repo_path, language=args.language, base_branch=args.base_branch, - pr_number=args.pr_number + pr_number=args.pr_number, ) - + # Run the analysis manager.analyze( analysis_types=args.analysis_types, output_file=args.output_file, - output_format=args.output_format + output_format=args.output_format, ) - + # Generate and print report if format is console if args.output_format == "console": report = manager.generate_report(args.report_type) print(report) - + except Exception as e: - logger.error(f"Error: {e}") + logger.exception(f"Error: {e}") import traceback + traceback.print_exc() sys.exit(1) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/codegen-on-oss/codegen_on_oss/analyzers/doc_utils/__init__.py b/codegen-on-oss/codegen_on_oss/analyzers/doc_utils/__init__.py new file mode 100644 index 000000000..2c86996c0 --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analyzers/doc_utils/__init__.py @@ -0,0 +1,2 @@ +"""Documentation utilities for the analyzers module.""" + diff --git a/codegen-on-oss/codegen_on_oss/analyzers/doc_utils/schemas.py b/codegen-on-oss/codegen_on_oss/analyzers/doc_utils/schemas.py new file mode 100644 index 000000000..a60dae9a1 --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analyzers/doc_utils/schemas.py @@ -0,0 +1,53 @@ +"""Schema definitions for documentation generation.""" + +from typing import Literal + +from pydantic import BaseModel +from pydantic.fields import Field + + +class ParameterDoc(BaseModel): + """Documentation for a function or method parameter.""" + + name: str = Field(..., description="The name of the parameter") + description: str = Field(..., description="The description of the 
"""Schema definitions for documentation generation (doc_utils/schemas.py).

NOTE(review): reconstructed from a whitespace-mangled patch; the field sets
below match the visible diff content exactly.
"""

import textwrap
from typing import Literal

from pydantic import BaseModel, Field  # public API import, not pydantic.fields


class ParameterDoc(BaseModel):
    """Documentation for a function or method parameter."""

    name: str = Field(..., description="The name of the parameter")
    description: str = Field(..., description="The description of the parameter")
    type: str = Field(..., description="The type of the parameter")
    default: str = Field(default="", description="The default value of the parameter")


class MethodDoc(BaseModel):
    """Documentation for a method or property."""

    name: str = Field(..., description="The name of the method")
    description: str | None = Field(..., description="The description of the method")
    parameters: list[ParameterDoc] = Field(..., description="The parameters of the method")
    return_type: list[str] | None = Field(default=None, description="The return types of the method")
    return_description: str | None = Field(default=None, description="The return description of the method")
    method_type: Literal["method", "property", "attribute"] = Field(..., description="The type of the method")
    code: str = Field(..., description="The signature of the method or attribute")
    path: str = Field(..., description="The path of the method that indicates its parent class //")
    raises: list[dict] | None = Field(..., description="The raises of the method")
    metainfo: dict = Field(..., description="Information about the method's true parent class and path")
    version: str = Field(..., description="The commit hash of the git commit that generated the docs")
    github_url: str = Field(..., description="The github url of the method")


class ClassDoc(BaseModel):
    """Documentation for a class."""

    title: str = Field(..., description="The title of the class")
    description: str = Field(..., description="The description of the class")
    content: str = Field(..., description="The content of the class")
    path: str = Field(..., description="The path of the class")
    inherits_from: list[str] = Field(..., description="The classes that the class inherits from")
    version: str = Field(..., description="The commit hash of the git commit that generated the docs")
    # default_factory is the documented pydantic idiom for mutable defaults;
    # Field(default=[]) relies on pydantic deep-copying the shared list.
    methods: list[MethodDoc] = Field(default_factory=list, description="The methods of the class")
    attributes: list[MethodDoc] = Field(default_factory=list, description="The attributes of the class")
    github_url: str = Field(..., description="The github url of the class")


class GSDocs(BaseModel):
    """Collection of class documentation."""

    classes: list[ClassDoc] = Field(..., description="The classes to document")


# --- doc_utils/utils.py ---------------------------------------------------


def sanitize_docstring_for_markdown(docstring: str | None) -> str:
    """Sanitize a raw docstring for MDX output.

    Dedents continuation lines and strips surrounding triple quotes.

    Args:
        docstring: The docstring to sanitize, or None.

    Returns:
        The sanitized docstring ("" for None).
    """
    if docstring is None:
        return ""
    lines = docstring.splitlines()
    if len(lines) > 1:
        lines[1:] = [textwrap.dedent(line) for line in lines[1:]]
    docstring = "\n".join(lines)
    docstring = docstring.removeprefix('"""')
    return docstring.removesuffix('"""')
"""MDX documentation generation utilities for code analysis.

This module renders MDX pages for classes: front matter, inheritance,
attribute and method sections, parameter/return snippets, and link/type
formatting. Several inline markup strings were corrupted by HTML stripping
in the source patch and have been reconstructed; each is flagged with a
NOTE(review) comment.
"""

import re
import textwrap
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from codegen_on_oss.analyzers.doc_utils.schemas import ClassDoc, MethodDoc, ParameterDoc


def sanitize_docstring_for_markdown(docstring: "str | None") -> str:
    """Sanitize a raw docstring for MDX: dedent continuations, strip quotes."""
    if docstring is None:
        return ""
    lines = docstring.splitlines()
    if len(lines) > 1:
        lines[1:] = [textwrap.dedent(line) for line in lines[1:]]
    docstring = "\n".join(lines)
    docstring = docstring.removeprefix('"""')
    return docstring.removesuffix('"""')


def sanitize_mdx_mintlify_description(content: str) -> str:
    """Escape *content* for a Mintlify MDX ``description`` field.

    Wraps bare ``<...>`` component references in backticks and escapes double
    quotes (no-op when the text already contains escaped quotes).
    """
    content = sanitize_docstring_for_markdown(content)
    # NOTE(review): the original pattern was corrupted in the patch; this is
    # the standard "not already backticked" reconstruction — confirm.
    content = re.sub(r"(?<!`)(<[^>]+>)(?!`)", r"`\1`", content)

    # Escape double quotes; skip when escaped quotes are already present.
    if re.search(r'\\"', content):
        return content
    return re.sub(r'(")', r"\\\1", content)


def sanitize_html_for_mdx(html_string: str) -> str:
    """Escape double quotes so *html_string* is safe inside an MDX attribute.

    NOTE(review): the replacement entity was lost in the patch; ``&quot;`` is
    the standard HTML escape and matches the function's stated intent.
    """
    return html_string.replace('"', "&quot;")


def extract_class_description(docstring: str) -> str:
    """Return the class description, excluding any "Attributes:" section.

    Whitespace is normalized to single spaces.
    """
    if not docstring:
        return ""
    description = docstring.split("Attributes:")[0]
    stripped = (line.strip() for line in description.strip().splitlines())
    return " ".join(filter(None, stripped))


def render_mdx_page_for_class(cls_doc: "ClassDoc") -> str:
    """Render the full MDX page for a single class."""
    return f"""{render_mdx_page_title(cls_doc)}
{render_mdx_inheritence_section(cls_doc)}
{render_mdx_attributes_section(cls_doc)}
{render_mdx_methods_section(cls_doc)}
"""


def render_mdx_page_title(cls_doc: "ClassDoc", icon: "str | None" = None) -> str:
    """Render the MDX front matter and snippet imports for the page."""
    page_desc = cls_doc.description if hasattr(cls_doc, "description") else ""

    return f"""---
title: "{cls_doc.title}"
sidebarTitle: "{cls_doc.title}"
icon: "{icon if icon else ""}"
description: "{sanitize_mdx_mintlify_description(page_desc)}"
---
import {{Parameter}} from '/snippets/Parameter.mdx';
import {{ParameterWrapper}} from '/snippets/ParameterWrapper.mdx';
import {{Return}} from '/snippets/Return.mdx';
import {{HorizontalDivider}} from '/snippets/HorizontalDivider.mdx';
import {{GithubLinkNote}} from '/snippets/GithubLinkNote.mdx';
import {{Attribute}} from '/snippets/Attribute.mdx';


"""


def render_mdx_inheritence_section(cls_doc: "ClassDoc") -> str:
    """Render the "Inherits from" section.

    NOTE: the public name keeps its historical "inheritence" spelling —
    callers (including the test suite) import it under this name.
    """
    parents = cls_doc.inherits_from
    if not parents:
        return ""
    parents_string = ", ".join(parse_link(parent) for parent in parents)
    return f"""### Inherits from
{parents_string}
"""


def render_mdx_attributes_section(cls_doc: "ClassDoc") -> str:
    """Render the "## Attributes" section (attributes plus property methods)."""
    properties = [m for m in cls_doc.methods if m.method_type == "property"]
    sorted_attributes = sorted(cls_doc.attributes + properties, key=lambda a: a.name)
    if not sorted_attributes:
        return ""
    attributes_mdx_string = "\n".join(
        render_mdx_for_attribute(attribute) for attribute in sorted_attributes
    )
    return f"""## Attributes

{attributes_mdx_string}
"""


def render_mdx_methods_section(cls_doc: "ClassDoc") -> str:
    """Render the "## Methods" section (method_type == "method" only)."""
    sorted_methods = sorted(cls_doc.methods, key=lambda m: m.name)
    if not sorted_methods:
        return ""
    methods_mdx_string = "\n".join(
        render_mdx_for_method(method)
        for method in sorted_methods
        if method.method_type == "method"
    )
    return f"""## Methods

{methods_mdx_string}
"""


def render_mdx_for_attribute(attribute: "MethodDoc") -> str:
    """Render the MDX block for a single attribute or property."""
    attribute_docstring = sanitize_mdx_mintlify_description(attribute.description or "")
    return_type = resolve_type_string(attribute.return_type[0]) if attribute.return_type else ""
    if not attribute_docstring:
        attribute_docstring = "\n"
    # NOTE(review): the component markup below was stripped from the patch;
    # reconstructed to match the Attribute/HorizontalDivider snippets imported
    # in render_mdx_page_title — confirm props against the snippet sources.
    return f"""### <span className="text-primary">{attribute.name}</span>
<HorizontalDivider light={{true}} />
<Attribute types={{ [{return_type}] }} description="{attribute_docstring}" />
"""


########################################################################################################################
# METHODS
########################################################################################################################


def format_parameter_for_mdx(parameter: "ParameterDoc") -> str:
    """Render one Parameter snippet for *parameter*."""
    type_string = resolve_type_string(parameter.type)
    # NOTE(review): component markup reconstructed (stripped in the patch);
    # props match the Parameter snippet imported in render_mdx_page_title.
    return f"""
<Parameter
    name="{sanitize_html_for_mdx(parameter.name)}"
    description="{sanitize_html_for_mdx(parameter.description)}"
    types={{ [{type_string}] }}
    defaultValue="{sanitize_html_for_mdx(parameter.default)}"
/>
""".strip()


def format_parameters_for_mdx(parameters: "list[ParameterDoc]") -> str:
    """Render all parameters, one snippet per line."""
    return "\n".join(format_parameter_for_mdx(parameter) for parameter in parameters)


def format_return_for_mdx(return_type: "list[str]", return_description: "str | None" = None) -> str:
    """Render the Return snippet for a method's return type/description."""
    description = sanitize_html_for_mdx(return_description) if return_description else ""
    return_type_str = resolve_type_string(return_type[0]) if return_type else ""
    # NOTE(review): component markup reconstructed from the stripped patch.
    return f"""
<Return return_type={{ [{return_type_str}] }} description="{description}"/>
"""


def render_mdx_for_method(method: "MethodDoc") -> str:
    """Render the MDX block for a single method."""
    description = sanitize_mdx_mintlify_description(method.description or "")
    # NOTE(review): span/GithubLinkNote/ParameterWrapper markup reconstructed
    # from the stripped patch — confirm against the snippet sources.
    mdx_string = f"""### <span className="text-primary">{method.name}</span>
{description}
<GithubLinkNote link="{method.github_url}" />
"""
    if method.parameters:
        mdx_string += f"""
<ParameterWrapper>
{format_parameters_for_mdx(method.parameters)}
</ParameterWrapper>
"""
    if method.return_type:
        mdx_string += f"""
{format_return_for_mdx(method.return_type, method.return_description)}
"""
    return mdx_string


def get_mdx_route_for_class(cls_doc: "ClassDoc") -> str:
    """Return the docs route for *cls_doc*, split by python/typescript/core."""
    lower_class_name = cls_doc.title.lower()
    if lower_class_name.startswith("py"):
        return f"codebase-sdk/python/{cls_doc.title}"
    if lower_class_name.startswith(("ts", "jsx")):
        return f"codebase-sdk/typescript/{cls_doc.title}"
    return f"codebase-sdk/core/{cls_doc.title}"


def format_type_string(type_string: str) -> str:
    """Normalize a ``a|b`` union string to ``a | b``."""
    return " | ".join(part.strip() for part in type_string.split("|"))


def resolve_type_string(type_string: str) -> str:
    """Resolve *type_string* into renderable MDX (link fragment or code span).

    NOTE(review): the inline markup was stripped from the patch; the fragment
    wrapper and code span are reconstructed — confirm the className.
    """
    if "<" in type_string:
        return f"<>{parse_link(type_string, href=True)}</>"
    return f'<code className="text-sm bg-gray-100 px-2 py-0.5 rounded">{format_type_string(type_string)}</code>'


def format_builtin_type_string(type_string: str) -> str:
    """Like format_type_string, but leave non-union strings untouched."""
    if "|" in type_string:
        return " | ".join(part.strip() for part in type_string.split("|"))
    return type_string


def span_type_string_by_pipe(type_string: str) -> str:
    """Wrap each union member in a span for MDX rendering.

    NOTE(review): the span markup was stripped from the patch; reconstructed.
    """
    if "|" in type_string:
        return " | ".join(f"<span>{part.strip()}</span>" for part in type_string.split("|"))
    return type_string


def parse_link(type_string: str, href: bool = False) -> str:
    """Convert ``<path/to/Symbol>`` tokens in *type_string* into links.

    Args:
        type_string: String possibly containing ``<...>`` path tokens.
        href: If True, emit inline anchor/span markup; else a markdown link.

    Returns:
        The string with each token replaced by a link, parts space-joined.
    """
    # Split into angle-bracketed path tokens and plain-text runs.
    parts = [p for p in re.split(r"(<[^>]+>)", type_string) if p]

    result = []
    for part in parts:
        if part.startswith("<") and part.endswith(">"):
            path = part[1:-1]
            symbol = path.split("/")[-1]
            # NOTE(review): href-form anchor markup reconstructed from the
            # stripped patch — confirm attributes against the docs theme.
            link = f'<a href="/{path}">{symbol}</a>' if href else f"[{symbol}](/{path})"
            result.append(link)
        else:
            part = format_builtin_type_string(part)
            if href:
                result.append(f"<span>{part.strip()}</span>")
            else:
                result.append(part.strip())

    return " ".join(result)
"""Tests for the mdx_docs_generation module."""

import unittest

from codegen_on_oss.analyzers.doc_utils.schemas import ClassDoc, MethodDoc, ParameterDoc
from codegen_on_oss.analyzers.mdx_docs_generation import (
    format_parameter_for_mdx,
    format_parameters_for_mdx,
    format_return_for_mdx,
    get_mdx_route_for_class,
    render_mdx_for_attribute,
    render_mdx_for_method,
    render_mdx_inheritence_section,
    render_mdx_page_for_class,
    render_mdx_page_title,
    resolve_type_string,
)


class TestMdxDocsGeneration(unittest.TestCase):
    """Test cases for the mdx_docs_generation module.

    Assertions are markup-agnostic on purpose: they check that the rendered
    MDX contains the fixture's names/descriptions rather than exact component
    tags. The previous assertions included always-true ``assertIn('', ...)``
    checks (vacuous) and expected strings whose markup had been corrupted.
    """

    def setUp(self):
        """Build one parameter/method/attribute/class fixture tree."""
        self.parameter_doc = ParameterDoc(
            name="test_param",
            description="A test parameter",
            type="str",
            default="'default'",
        )

        self.method_doc = MethodDoc(
            name="test_method",
            description="A test method",
            parameters=[self.parameter_doc],
            return_type=["bool"],
            return_description="Returns a boolean",
            method_type="method",
            code="def test_method(test_param: str = 'default') -> bool:",
            path="python/TestClass/test_method",
            raises=[],
            metainfo={},
            version="abc123",
            github_url="https://github.com/example/repo/blob/main/test.py",
        )

        self.attribute_doc = MethodDoc(
            name="test_attribute",
            description="A test attribute",
            parameters=[],
            return_type=["str"],
            return_description=None,
            method_type="attribute",
            code="test_attribute: str",
            path="python/TestClass/test_attribute",
            raises=[],
            metainfo={},
            version="abc123",
            github_url="https://github.com/example/repo/blob/main/test.py",
        )

        self.class_doc = ClassDoc(
            title="TestClass",
            description="A test class",
            content='class TestClass:\n    """A test class"""\n    pass',
            path="python/TestClass",
            inherits_from=["BaseClass"],
            version="abc123",
            methods=[self.method_doc],
            attributes=[self.attribute_doc],
            github_url="https://github.com/example/repo/blob/main/test.py",
        )

    def test_render_mdx_page_title(self):
        """Front matter carries the class title and description."""
        result = render_mdx_page_title(self.class_doc)
        self.assertIn('title: "TestClass"', result)
        self.assertIn('description: "A test class"', result)

    def test_render_mdx_inheritence_section(self):
        """Inheritance section lists the parent class."""
        result = render_mdx_inheritence_section(self.class_doc)
        self.assertIn("### Inherits from", result)
        self.assertIn("BaseClass", result)

    def test_render_mdx_for_attribute(self):
        """Attribute rendering names the attribute and its resolved type."""
        result = render_mdx_for_attribute(self.attribute_doc)
        self.assertIn("test_attribute", result)
        self.assertIn("str", result)

    def test_render_mdx_for_method(self):
        """Method rendering includes the method name and its parameters."""
        result = render_mdx_for_method(self.method_doc)
        self.assertIn("test_method", result)
        self.assertIn("test_param", result)

    def test_format_parameter_for_mdx(self):
        """Parameter rendering carries name and description."""
        result = format_parameter_for_mdx(self.parameter_doc)
        self.assertIn("test_param", result)
        self.assertIn("A test parameter", result)

    def test_format_parameters_for_mdx(self):
        """Parameter list rendering includes each parameter."""
        result = format_parameters_for_mdx([self.parameter_doc])
        self.assertIn("test_param", result)

    def test_format_return_for_mdx(self):
        """Return rendering carries type and description."""
        result = format_return_for_mdx(["bool"], "Returns a boolean")
        self.assertIn("bool", result)
        self.assertIn("Returns a boolean", result)

    def test_get_mdx_route_for_class(self):
        """Routes split by py/ts/core title prefixes."""
        self.assertEqual(
            get_mdx_route_for_class(self.class_doc), "codebase-sdk/core/TestClass"
        )
        py_doc = self.class_doc.model_copy(update={"title": "PyClass"})
        self.assertEqual(get_mdx_route_for_class(py_doc), "codebase-sdk/python/PyClass")
        ts_doc = self.class_doc.model_copy(update={"title": "TSClass"})
        self.assertEqual(
            get_mdx_route_for_class(ts_doc), "codebase-sdk/typescript/TSClass"
        )

    def test_resolve_type_string(self):
        """Simple types and <path/Symbol> tokens both resolve."""
        simple_result = resolve_type_string("str")
        self.assertIn("str", simple_result)
        complex_result = resolve_type_string("<api-reference/core/Symbol>")
        self.assertIn("Symbol", complex_result)

    def test_render_mdx_page_for_class(self):
        """The full page contains every section plus the fixture members."""
        result = render_mdx_page_for_class(self.class_doc)
        self.assertIn('title: "TestClass"', result)
        self.assertIn("### Inherits from", result)
        self.assertIn("## Attributes", result)
        self.assertIn("## Methods", result)
        self.assertIn("test_attribute", result)
        self.assertIn("test_method", result)


if __name__ == "__main__":
    unittest.main()