diff --git a/codegen-on-oss/codegen_on_oss/analyzers/README.md b/codegen-on-oss/codegen_on_oss/analyzers/README.md index e268fbd32..7863c4401 100644 --- a/codegen-on-oss/codegen_on_oss/analyzers/README.md +++ b/codegen-on-oss/codegen_on_oss/analyzers/README.md @@ -1,248 +1,73 @@ -# CodeGen Analyzer +# Analyzers Package -The CodeGen Analyzer module provides comprehensive static analysis capabilities for codebases, focusing on code quality, dependencies, structure, and visualization. It serves as a backend API that can be used by frontend applications to analyze repositories. +This package provides tools for analyzing and modifying code during analysis. -## Architecture +## Transaction Manager -The analyzer system is built with a modular plugin-based architecture: +The `transaction_manager.py` module provides a transaction manager for handling code modifications during analysis. It's responsible for queuing, sorting, and committing transactions in a controlled manner. -``` -analyzers/ -├── api.py # Main API endpoints for frontend integration -├── analyzer.py # Plugin-based analyzer system -├── issues.py # Issue tracking and management -├── code_quality.py # Code quality analysis -├── dependencies.py # Dependency analysis -├── models/ -│ └── analysis_result.py # Data models for analysis results -├── context/ # Code context management -├── visualization/ # Visualization support -└── resolution/ # Issue resolution tools -``` - -## Core Components - -### 1. API Interface (`api.py`) - -The main entry point for frontend applications. Provides REST-like endpoints for: -- Codebase analysis -- PR analysis -- Dependency visualization -- Issue reporting -- Code quality assessment - -### 2. Analyzer System (`analyzer.py`) - -Plugin-based system that coordinates different types of analysis: -- Code quality analysis (complexity, maintainability) -- Dependency analysis (imports, cycles, coupling) -- PR impact analysis -- Type checking and error detection - -### 3. 
Issue Tracking (`issues.py`) - -Comprehensive issue model with: -- Severity levels (critical, error, warning, info) -- Categories (dead code, complexity, dependency, etc.) -- Location information and suggestions -- Filtering and grouping capabilities - -### 4. Dependency Analysis (`dependencies.py`) - -Analysis of codebase dependencies: -- Import dependencies between modules -- Circular dependency detection -- Module coupling analysis -- External dependencies tracking -- Call graphs and class hierarchies - -### 5. Code Quality Analysis (`code_quality.py`) - -Analysis of code quality aspects: -- Dead code detection (unused functions, variables) -- Complexity metrics (cyclomatic, cognitive) -- Parameter checking (types, usage) -- Style issues and maintainability - -## Using the API - -### Setup - -```python -from codegen_on_oss.analyzers.api import CodegenAnalyzerAPI - -# Create API instance with repository -api = CodegenAnalyzerAPI(repo_path="/path/to/repo") -# OR -api = CodegenAnalyzerAPI(repo_url="https://github.com/owner/repo") -``` - -### Analyzing a Codebase - -```python -# Run comprehensive analysis -results = api.analyze_codebase() - -# Run specific analysis types -results = api.analyze_codebase(analysis_types=["code_quality", "dependency"]) +### Key Features -# Force refresh of cached analysis -results = api.analyze_codebase(force_refresh=True) -``` - -### Analyzing a PR - -```python -# Analyze a specific PR -pr_results = api.analyze_pr(pr_number=123) - -# Get PR impact visualization -impact_viz = api.get_pr_impact(pr_number=123, format="json") -``` +- **Transaction Queuing**: Queue up code modifications to be applied later +- **Transaction Sorting**: Sort transactions by priority and position +- **Conflict Resolution**: Detect and resolve conflicts between transactions +- **Transaction Limits**: Set limits on the number of transactions and execution time +- **Bulk Commits**: Commit multiple transactions at once +- **Undo Support**: Revert transactions if 
needed -### Getting Issues +### Usage Example ```python -# Get all issues -all_issues = api.get_issues() +from codegen_on_oss.analyzers.transaction_manager import TransactionManager +from codegen_on_oss.analyzers.transactions import EditTransaction -# Get issues by severity -critical_issues = api.get_issues(severity="critical") -error_issues = api.get_issues(severity="error") +# Create a transaction manager +manager = TransactionManager() -# Get issues by category -dependency_issues = api.get_issues(category="dependency_cycle") -``` +# Set limits +manager.set_max_transactions(100) # Limit to 100 transactions +manager.reset_stopwatch(5) # Limit to 5 seconds -### Getting Visualizations +# Create a transaction +transaction = EditTransaction(start_byte=10, end_byte=20, file=file_obj, new_content="new code") -```python -# Get module dependency graph -module_deps = api.get_module_dependencies(format="json") - -# Get function call graph -call_graph = api.get_function_call_graph( - function_name="main", - depth=3, - format="json" -) - -# Export visualization to file -api.export_visualization(call_graph, format="html", filename="call_graph.html") -``` +# Add the transaction to the queue +manager.add_transaction(transaction) -### Common Analysis Patterns +# Commit all transactions +files_to_commit = manager.to_commit() +diffs = manager.commit(files_to_commit) -```python -# Find dead code -api.analyze_codebase(analysis_types=["code_quality"]) -dead_code = api.get_issues(category="dead_code") +# Or apply a single transaction immediately +manager.apply(transaction) -# Find circular dependencies -api.analyze_codebase(analysis_types=["dependency"]) -circular_deps = api.get_circular_dependencies() +# Or apply all transactions at once +diffs = manager.apply_all() -# Find parameter issues -api.analyze_codebase(analysis_types=["code_quality"]) -param_issues = api.get_parameter_issues() +# Revert all transactions +manager.revert_all() ``` -## REST API Endpoints +### Transaction Types 
-The analyzer can be exposed as REST API endpoints for integration with frontend applications: +The following transaction types are supported: -### Codebase Analysis +- **EditTransaction**: Replace content in a file +- **InsertTransaction**: Insert content at a specific position +- **RemoveTransaction**: Remove content from a file +- **FileAddTransaction**: Add a new file +- **FileRenameTransaction**: Rename a file +- **FileRemoveTransaction**: Remove a file -``` -POST /api/analyze/codebase -{ - "repo_path": "/path/to/repo", - "analysis_types": ["code_quality", "dependency"] -} -``` +### Error Handling -### PR Analysis +The transaction manager can raise the following exceptions: -``` -POST /api/analyze/pr -{ - "repo_path": "/path/to/repo", - "pr_number": 123 -} -``` - -### Visualization +- **MaxTransactionsExceeded**: Raised when the number of transactions exceeds the limit +- **MaxPreviewTimeExceeded**: Raised when the execution time exceeds the limit +- **TransactionError**: Raised when there's a conflict between transactions -``` -POST /api/visualize -{ - "repo_path": "/path/to/repo", - "viz_type": "module_dependencies", - "params": { - "layout": "hierarchical", - "format": "json" - } -} -``` +### Integration with Analyzers -### Issues +The transaction manager is designed to be used with the analyzers package to provide a consistent way to modify code during analysis. It can be integrated with other components of the analyzers package to provide a complete code analysis and modification solution. 
-``` -GET /api/issues?severity=error&category=dependency_cycle -``` - -## Implementation Example - -For a web application exposing these endpoints with Flask: - -```python -from flask import Flask, request, jsonify -from codegen_on_oss.analyzers.api import ( - api_analyze_codebase, - api_analyze_pr, - api_get_visualization, - api_get_static_errors -) - -app = Flask(__name__) - -@app.route("/api/analyze/codebase", methods=["POST"]) -def analyze_codebase(): - data = request.json - result = api_analyze_codebase( - repo_path=data.get("repo_path"), - analysis_types=data.get("analysis_types") - ) - return jsonify(result) - -@app.route("/api/analyze/pr", methods=["POST"]) -def analyze_pr(): - data = request.json - result = api_analyze_pr( - repo_path=data.get("repo_path"), - pr_number=data.get("pr_number") - ) - return jsonify(result) - -@app.route("/api/visualize", methods=["POST"]) -def visualize(): - data = request.json - result = api_get_visualization( - repo_path=data.get("repo_path"), - viz_type=data.get("viz_type"), - params=data.get("params", {}) - ) - return jsonify(result) - -@app.route("/api/issues", methods=["GET"]) -def get_issues(): - repo_path = request.args.get("repo_path") - severity = request.args.get("severity") - category = request.args.get("category") - - api = create_api(repo_path=repo_path) - return jsonify(api.get_issues(severity=severity, category=category)) - -if __name__ == "__main__": - app.run(debug=True) -``` \ No newline at end of file diff --git a/codegen-on-oss/codegen_on_oss/analyzers/transaction_manager.py b/codegen-on-oss/codegen_on_oss/analyzers/transaction_manager.py new file mode 100644 index 000000000..7efd254bd --- /dev/null +++ b/codegen-on-oss/codegen_on_oss/analyzers/transaction_manager.py @@ -0,0 +1,589 @@ +#!/usr/bin/env python3 +""" +Transaction Manager Module for Analyzers + +This module provides a transaction manager for handling code modifications during analysis. 
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)


class MaxTransactionsExceeded(Exception):
    """Raised when the number of transactions exceeds the max_transactions limit.

    Attributes:
        threshold: The limit that was in force when the error was raised.
    """

    def __init__(self, message: str, threshold: int | None = None):
        super().__init__(message)
        self.threshold = threshold


class MaxPreviewTimeExceeded(Exception):
    """Raised when more than the allotted time has passed for previewing transactions.

    Attributes:
        threshold: The time limit (seconds) that was in force when the error was raised.
    """

    def __init__(self, message: str, threshold: int | None = None):
        super().__init__(message)
        self.threshold = threshold


class TransactionError(Exception):
    """Exception raised for transaction-related errors (e.g. unresolved conflicts)."""


class TransactionManager:
    """Queue, order, and commit `Transaction` objects (atomic codebase edits).

    Transactions are grouped per file in ``queued_transactions`` and applied in
    bulk by :meth:`commit`. Optional limits on the number of queued
    transactions and on elapsed time are enforced every time a transaction is
    added.
    """

    def __init__(self) -> None:
        """Initialize an empty manager with no limits configured."""
        # Per-file queues; a queue is only in execution order after sort_transactions().
        self.queued_transactions: dict[Path, list[Transaction]] = {}
        # Callbacks invoked (then discarded) by revert_all() / clear_transactions().
        self.pending_undos: set[Callable[[], None]] = set()
        # Re-entrancy guard for commit(); attribute spelling kept for compatibility.
        self._commiting: bool = False
        self.max_transactions: int | None = None  # None = no limit
        self.stopwatch_start: float | None = None
        self.stopwatch_max_seconds: int | None = None  # None = no limit
        self.session: dict[str, Any] = {}  # Session data for tracking state

    def sort_transactions(self) -> None:
        """Sort every per-file queue in place using ``Transaction._to_sort_key``."""
        for file_transactions in self.queued_transactions.values():
            file_transactions.sort(key=Transaction._to_sort_key)

    def clear_transactions(self) -> None:
        """Drop all queued transactions, run pending undos, and reset limits.

        Should be called between analysis runs to remove any potential
        extraneous transactions. A warning is logged when uncommitted
        transactions are discarded.
        """
        if self.queued_transactions:
            logger.warning("Not all transactions have been committed")
        # Same queue/undo teardown as revert_all(); reuse it instead of duplicating it.
        self.revert_all()
        self.set_max_transactions(None)
        self.reset_stopwatch()

    def _format_transactions(self, transactions: list[Transaction]) -> str:
        """Render transactions as a newline-separated, banner-delimited listing."""
        # Real newlines: the previous "\\n" emitted a literal backslash-n into the text.
        return "\n".join(
            ">" * 100 + f"\n[ID: {t.transaction_id}]: {t.diff_str()}" + "<" * 100
            for t in transactions
        )

    def get_transactions_str(self) -> str:
        """Return a human-readable listing of all queued transactions, grouped by file."""
        return "\n\n\n".join(
            f"{file_path}:\n{self._format_transactions(transactions)}"
            for file_path, transactions in self.queued_transactions.items()
        )

    ####################################################################################################################
    # Transaction Limits
    ####################################################################################################################

    def get_num_transactions(self) -> int:
        """Return the number of transactions currently queued across all files."""
        return sum(len(queue) for queue in self.queued_transactions.values())

    def set_max_transactions(self, max_transactions: int | None = None) -> None:
        """Set the maximum number of transactions allowed (``None`` disables the limit)."""
        self.max_transactions = max_transactions

    def max_transactions_exceeded(self) -> bool:
        """Return True when a limit is set and the queued count has reached it."""
        if self.max_transactions is None:
            return False
        return self.get_num_transactions() >= self.max_transactions

    ####################################################################################################################
    # Stopwatch
    ####################################################################################################################

    def reset_stopwatch(self, max_seconds: int | None = None) -> None:
        """Restart the stopwatch, optionally with a time limit in seconds."""
        self.stopwatch_start = time.time()
        self.stopwatch_max_seconds = max_seconds

    def is_time_exceeded(self) -> bool:
        """Return True when a time limit is set and more than that many seconds elapsed."""
        if self.stopwatch_max_seconds is None or self.stopwatch_start is None:
            return False
        elapsed = time.time() - self.stopwatch_start
        return elapsed > self.stopwatch_max_seconds

    ####################################################################################################################
    # Transaction Creation
    ####################################################################################################################

    def add_file_add_transaction(self, filepath: Path) -> None:
        """Queue a transaction that creates a new file at ``filepath``."""
        self.add_transaction(FileAddTransaction(filepath))

    def add_file_rename_transaction(self, file: Any, new_filepath: str) -> None:
        """Queue a transaction that renames ``file`` to ``new_filepath``."""
        self.add_transaction(FileRenameTransaction(file, new_filepath))

    def add_file_remove_transaction(self, file: Any) -> None:
        """Queue a transaction that removes ``file``."""
        self.add_transaction(FileRemoveTransaction(file))

    def add_transaction(
        self,
        transaction: Transaction,
        dedupe: bool = True,
        solve_conflicts: bool = True,
    ) -> bool:
        """Add a transaction to the queue of its file.

        Args:
            transaction: The transaction to add
            dedupe: Whether to skip a transaction equal to one already queued
            solve_conflicts: Whether to resolve conflicts with existing transactions

        Returns:
            False if the transaction was dropped as a duplicate. True otherwise,
            including when conflict resolution discarded it or replaced it with
            broken-down pieces.

        Raises:
            TransactionError: If a conflict could not be resolved.
            MaxTransactionsExceeded: If the transaction limit has been reached.
            MaxPreviewTimeExceeded: If the stopwatch limit has been exceeded.
        """
        # The per-file queue is created on first use, even if nothing ends up queued.
        file_queue = self.queued_transactions.setdefault(transaction.file_path, [])

        if dedupe and transaction in file_queue:
            logger.debug(f"Transaction already exists in queue: {transaction}")
            return False

        # None means conflict handling has already dealt with the transaction.
        if new_transaction := self._resolve_conflicts(
            transaction, file_queue, solve_conflicts=solve_conflicts
        ):
            file_queue.append(new_transaction)

        self.check_limits()
        return True

    def add(self, transaction: Transaction) -> bool:
        """Alias for add_transaction with default options."""
        return self.add_transaction(transaction)

    def check_limits(self) -> None:
        """Raise if the transaction-count or time limit has been hit."""
        self.check_max_transactions()
        self.check_max_preview_time()

    def check_max_transactions(self) -> None:
        """Raise MaxTransactionsExceeded when the transaction limit has been reached."""
        if self.max_transactions_exceeded():
            logger.info(
                f"Max transactions reached: {self.max_transactions}. Stopping analysis."
            )
            msg = f"Max transactions reached: {self.max_transactions}"
            raise MaxTransactionsExceeded(msg, threshold=self.max_transactions)

    def check_max_preview_time(self) -> None:
        """Raise MaxPreviewTimeExceeded when the stopwatch limit has been exceeded."""
        if self.is_time_exceeded():
            logger.info(
                f"Max preview time exceeded: {self.stopwatch_max_seconds}. Stopping analysis."
            )
            msg = f"Max preview time exceeded: {self.stopwatch_max_seconds}"
            raise MaxPreviewTimeExceeded(msg, threshold=self.stopwatch_max_seconds)

    ####################################################################################################################
    # Commit
    ####################################################################################################################

    def to_commit(self, files: set[Path] | None = None) -> set[Path]:
        """Get paths of files to commit.

        Args:
            files: Optional set of files to filter by

        Returns:
            All queued file paths, or the subset of ``files`` that has a queue
        """
        if files is None:
            return set(self.queued_transactions.keys())
        return files.intersection(self.queued_transactions)

    def commit(self, files: set[Path]) -> list[DiffLite]:
        """Execute the queued transactions of each given file, in sorted order.

        Args:
            files: Set of file paths to commit. Paths without a queue are skipped.

        Returns:
            List of diffs that were committed. At most one ``Modified`` diff is
            reported per file; every other diff is reported individually.
            Returns an empty list when a commit is already in progress.
        """
        if self._commiting:
            logger.warning("Skipping commit, already committing")
            return []

        self._commiting = True
        try:
            diffs: list[DiffLite] = []
            if not self.queued_transactions:
                return diffs

            self.sort_transactions()

            # Use .get() so a path without a queue cannot raise KeyError here;
            # the execution loop below already tolerates such paths.
            if len(files) > 3:
                num_transactions = sum(
                    len(self.queued_transactions.get(file_path, ()))
                    for file_path in files
                )
                logger.info(
                    f"Committing {num_transactions} transactions for {len(files)} files"
                )
            else:
                for file in files:
                    logger.info(
                        f"Committing {len(self.queued_transactions.get(file, ()))} transactions for {file}"
                    )

            for file_path in files:
                file_transactions = self.queued_transactions.pop(file_path, [])
                modified = False
                for transaction in file_transactions:
                    # The diff is taken before execute() so it reflects the pre-change state.
                    diff = transaction.get_diff()
                    if diff.change_type == ChangeType.Modified:
                        # Only the first Modified diff of a file is recorded.
                        if not modified:
                            modified = True
                            diffs.append(diff)
                    else:
                        diffs.append(diff)
                    transaction.execute()

            return diffs
        finally:
            self._commiting = False

    def apply(self, transaction: Transaction) -> None:
        """Queue a single transaction and immediately commit its file.

        Args:
            transaction: The transaction to apply
        """
        self.add_transaction(transaction)
        self.commit({transaction.file_path})

    def apply_all(self) -> list[DiffLite]:
        """Commit every queued transaction.

        Returns:
            List of diffs that were committed
        """
        return self.commit(self.to_commit())

    def revert_all(self) -> None:
        """Discard all queued transactions and run (then clear) all pending undos."""
        self.queued_transactions.clear()
        for undo in self.pending_undos:
            undo()
        self.pending_undos.clear()

    ####################################################################################################################
    # Conflict Resolution
    ####################################################################################################################

    def _resolve_conflicts(
        self,
        transaction: Transaction,
        file_queue: list[Transaction],
        solve_conflicts: bool = True,
    ) -> Transaction | None:
        """Resolve conflicts between the new transaction and existing transactions.

        Args:
            transaction: The new transaction
            file_queue: List of existing transactions for the file
            solve_conflicts: Whether to attempt to resolve conflicts

        Returns:
            The transaction to add, or None if it should not be appended

        Raises:
            TransactionError: If a conflict cannot be resolved.
        """
        try:
            conflicts = self._get_conflicts(transaction)
            if not (solve_conflicts and conflicts):
                return transaction
            return self._handle_conflicts(transaction, file_queue, conflicts)
        except TransactionError:
            logger.exception("Transaction conflict detected")
            # _log_conflict_error raises a TransactionError carrying the full report.
            self._log_conflict_error(transaction, self._get_conflicts(transaction))
            raise

    def _handle_conflicts(
        self,
        transaction: Transaction,
        file_queue: list[Transaction],
        conflicts: list[Transaction],
    ) -> Transaction | None:
        """Handle conflicts between transactions.

        Args:
            transaction: The new transaction
            file_queue: List of existing transactions for the file
            conflicts: List of conflicting transactions

        Returns:
            The transaction to add, or None if it should not be appended

        Raises:
            TransactionError: If an edit involved in the conflict cannot be broken down.
        """
        completely_overlapping = self._get_overlapping_conflicts(transaction)
        if completely_overlapping is not None:
            # An existing removal already covers this range: drop the new transaction.
            if isinstance(completely_overlapping, RemoveTransaction):
                return None
            # An existing edit covers this range: try to split that edit up.
            elif isinstance(completely_overlapping, EditTransaction):
                if self._break_down_transaction(completely_overlapping, file_queue):
                    return transaction
                # NOTE(review): nesting of this raise was reconstructed from
                # whitespace-collapsed source -- confirm it belongs to this branch.
                raise TransactionError()
        else:
            # A new removal supersedes everything it conflicts with.
            if isinstance(transaction, RemoveTransaction):
                for t in conflicts:
                    file_queue.remove(t)
            # A new edit is split up; its pieces are queued by the break-down itself.
            elif isinstance(transaction, EditTransaction):
                if self._break_down_transaction(transaction, file_queue):
                    return None
                raise TransactionError()

        return transaction

    def _break_down_transaction(
        self, to_break: EditTransaction, file_queue: list[Transaction]
    ) -> bool:
        """Break down an edit transaction into smaller transactions.

        Args:
            to_break: The transaction to break down
            file_queue: List of existing transactions for the file

        Returns:
            True if the transaction was broken down, False otherwise
        """
        new_transactions = to_break.break_down()
        if not new_transactions:
            return False

        # Pieces take the place of the original if it was queued, else go at the end.
        try:
            insert_idx = file_queue.index(to_break)
            file_queue.pop(insert_idx)
        except ValueError:
            insert_idx = len(file_queue)

        for new_transaction in new_transactions:
            broken_down = self._resolve_conflicts(
                new_transaction, file_queue, solve_conflicts=True
            )
            if broken_down:
                file_queue.insert(insert_idx, broken_down)

        return True

    def _log_conflict_error(
        self, transaction: Transaction, conflicts: list[Transaction]
    ) -> None:
        """Build a conflict report and raise it as a TransactionError.

        Args:
            transaction: The transaction that caused the conflict
            conflicts: List of conflicting transactions

        Raises:
            TransactionError: Always, with the formatted report as its message.
        """
        msg = (
            f"Potential conflict detected in file {transaction.file_path}!\n"
            "Attempted to perform code modification:\n"
            "\n"
            f"{self._format_transactions([transaction])}\n"
            "\n"
            "That potentially conflicts with the following other modifications:\n"
            "\n"
            f"{self._format_transactions(conflicts)}\n"
            "\n"
            "Aborting!\n"
            "\n"
            f"[Conflict Detected] Potential Modification Conflict in File {transaction.file_path}!"
        )
        raise TransactionError(msg)

    def get_transactions_at_range(
        self,
        file_path: Path,
        start_byte: int,
        end_byte: int,
        transaction_order: TransactionPriority | None = None,
        *,
        combined: bool = False,
    ) -> list[Transaction]:
        """Return the queued transactions that match the given filtering criteria.

        Args:
            file_path: Path to the file
            start_byte: Start byte position
            end_byte: End byte position
            transaction_order: Optional filter by transaction order
            combined: Return a list of transactions which collectively apply to the given range

        Returns:
            List of matching transactions
        """
        matching_transactions: list[Transaction] = []
        if file_path not in self.queued_transactions:
            return matching_transactions

        for t in self.queued_transactions[file_path]:
            if t.start_byte != start_byte:
                continue
            if t.end_byte == end_byte and (
                transaction_order is None or t.transaction_order == transaction_order
            ):
                matching_transactions.append(t)
            elif combined and t.start_byte != t.end_byte:
                # Chain onward from where this transaction ends; zero-length
                # transactions are excluded so the recursion always advances.
                other = self.get_transactions_at_range(
                    t.file_path,
                    t.end_byte,
                    end_byte,
                    transaction_order,
                    combined=combined,
                )
                if other:
                    return [t, *other]

        return matching_transactions

    def get_transaction_containing_range(
        self,
        file_path: Path,
        start_byte: int,
        end_byte: int,
        transaction_order: TransactionPriority | None = None,
    ) -> Transaction | None:
        """Return the tightest queued transaction that contains the given range.

        Args:
            file_path: Path to the file
            start_byte: Start byte position
            end_byte: End byte position
            transaction_order: Optional filter by transaction order

        Returns:
            The containing transaction with the least slack around the range
            (the first such one on ties), or None if no transaction contains it
        """
        if file_path not in self.queued_transactions:
            return None

        smallest_difference = math.inf
        best_fit_transaction: Transaction | None = None
        for t in self.queued_transactions[file_path]:
            if t.start_byte > start_byte or t.end_byte < end_byte:
                continue
            if transaction_order is not None and t.transaction_order != transaction_order:
                continue

            difference = (start_byte - t.start_byte) + (t.end_byte - end_byte)
            if difference == 0:
                return t
            # Only a strictly tighter fit replaces the current best; previously
            # every containing transaction overwrote it, so the last one won.
            if difference < smallest_difference:
                smallest_difference = difference
                best_fit_transaction = t
        return best_fit_transaction

    def _get_conflicts(self, transaction: Transaction) -> list[Transaction]:
        """Return all queued transactions whose byte range overlaps the given one.

        Args:
            transaction: The transaction to check for conflicts

        Returns:
            List of conflicting transactions, in queue order
        """
        conflicts: list[Transaction] = []
        if transaction.file_path not in self.queued_transactions:
            return conflicts

        for t in self.queued_transactions[transaction.file_path]:
            # An equal transaction is not a conflict with itself.
            if t == transaction:
                continue

            if (
                (t.start_byte <= transaction.start_byte < t.end_byte)
                or (t.start_byte < transaction.end_byte <= t.end_byte)
                or (transaction.start_byte <= t.start_byte < transaction.end_byte)
                or (transaction.start_byte < t.end_byte <= transaction.end_byte)
            ):
                conflicts.append(t)

        return conflicts

    def _get_overlapping_conflicts(
        self, transaction: Transaction
    ) -> Transaction | None:
        """Return the first queued transaction whose range fully contains the given one.

        Args:
            transaction: The transaction to check for overlaps

        Returns:
            The overlapping transaction, or None if not found
        """
        if transaction.file_path not in self.queued_transactions:
            return None

        for t in self.queued_transactions[transaction.file_path]:
            if (
                transaction.start_byte >= t.start_byte
                and transaction.end_byte <= t.end_byte
            ):
                return t
        return None
# Define change types for diffs
class ChangeType(IntEnum):
    """Types of changes that can be made to files."""

    Modified = 1
    Removed = 2
    Renamed = 3
    Added = 4


# Simple diff class for tracking changes
class DiffLite:
    """Simple record describing one change made to a file.

    Attributes mirror the constructor arguments one-to-one.
    """

    def __init__(
        self,
        change_type: ChangeType,
        path: Path,
        rename_from: Optional[Path] = None,
        rename_to: Optional[Path] = None,
        old_content: Optional[bytes] = None,
    ):
        self.change_type = change_type
        self.path = path
        self.rename_from = rename_from
        self.rename_to = rename_to
        self.old_content = old_content


class TransactionPriority(IntEnum):
    """Priority levels for different types of transactions (lower sorts first)."""

    Remove = 0  # Remove always has highest priority
    Edit = 1  # Edit comes next
    Insert = 2  # Insert is always the last of the edit operations
    # File operations happen last, since they will mess up all other transactions
    FileAdd = 10
    FileRename = 11
    FileRemove = 12


@runtime_checkable
class ContentFunc(Protocol):
    """A function executed to generate a content block dynamically."""

    def __call__(self) -> str: ...


class Transaction:
    """Base class for all transactions.

    A transaction represents an atomic modification to a file in the codebase.
    Subclasses set ``transaction_order`` and implement ``execute``,
    ``get_diff`` and ``diff_str``.
    """

    start_byte: int
    end_byte: int
    file_path: Path
    priority: Union[int, tuple]
    transaction_order: TransactionPriority
    # Class-wide counter used to hand out increasing transaction ids.
    transaction_counter: int = 0

    def __init__(
        self,
        start_byte: int,
        end_byte: int,
        file_path: Path,
        priority: Union[int, tuple] = 0,
        new_content: Optional[Union[str, Callable[[], str]]] = None,
    ) -> None:
        """Store the byte range, target file, priority and (possibly lazy) content."""
        self.start_byte = start_byte
        assert self.start_byte >= 0
        self.end_byte = end_byte
        self.file_path = file_path
        self.priority = priority
        self._new_content = new_content
        self.transaction_id = Transaction.transaction_counter

        Transaction.transaction_counter += 1

    def __repr__(self) -> str:
        # The previous implementation returned an empty string, which made log
        # lines that interpolate a transaction uninformative.
        return (
            f"<{type(self).__name__} ID: {self.transaction_id} "
            f"at bytes [{self.start_byte}:{self.end_byte}] on {self.file_path}>"
        )

    def __hash__(self):
        # Hash exactly the fields __eq__ compares. Using _new_content (not the
        # new_content property) also keeps lazy content functions unevaluated.
        return hash((self.start_byte, self.end_byte, self.file_path, self.priority, self._new_content))

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            return False

        # Check for everything EXCEPT transaction_id
        return (
            self.start_byte == other.start_byte
            and self.end_byte == other.end_byte
            and self.file_path == other.file_path
            and self.priority == other.priority
            and self._new_content == other._new_content
        )

    @property
    def length(self):
        """Length of the transaction in bytes."""
        return self.end_byte - self.start_byte

    def execute(self):
        """Execute the transaction to modify the file."""
        msg = "Transaction.execute() must be implemented by subclasses"
        raise NotImplementedError(msg)

    def get_diff(self) -> DiffLite:
        """Gets the diff produced by this transaction."""
        msg = "Transaction.get_diff() must be implemented by subclasses"
        raise NotImplementedError(msg)

    def diff_str(self):
        """Human-readable string representation of the change."""
        msg = "Transaction.diff_str() must be implemented by subclasses"
        raise NotImplementedError(msg)

    # Deliberately a plain function: it is passed as ``Transaction._to_sort_key``
    # to ``list.sort(key=...)`` and receives the transaction as its only argument.
    def _to_sort_key(transaction: "Transaction"):
        """Key function for sorting transactions.

        Sort by:
          1. Descending start_byte
          2. Ascending transaction type
          3. Ascending priority
          4. Descending time of transaction
        """
        # Normalise int priorities to 1-tuples so they compare with tuple priorities.
        priority = (transaction.priority,) if isinstance(transaction.priority, int) else transaction.priority

        return -transaction.start_byte, transaction.transaction_order.value, priority, -transaction.transaction_id

    @cached_property
    def new_content(self) -> Optional[str]:
        """Get the new content, evaluating the content function (once) if necessary."""
        return self._new_content() if isinstance(self._new_content, ContentFunc) else self._new_content

    @staticmethod
    def create_new_file(filepath: Union[str, Path], content: str) -> "FileAddTransaction":
        """Create a transaction to add a new file.

        NOTE(review): ``content`` is accepted but not forwarded to the
        transaction -- confirm whether FileAddTransaction should receive it.
        """
        return FileAddTransaction(Path(filepath))

    @staticmethod
    def delete_file(filepath: Union[str, Path]) -> "FileRemoveTransaction":
        """Create a transaction to delete a file."""
        # In a real implementation, this would need a File object
        # For now, we'll create a placeholder implementation
        class FilePlaceholder:
            def __init__(self, path):
                self.path = Path(path)

        return FileRemoveTransaction(FilePlaceholder(filepath))


class RemoveTransaction(Transaction):
    """Transaction to remove content from a file."""

    transaction_order = TransactionPriority.Remove

    exec_func: Optional[Callable[[], None]] = None

    def __init__(self, start_byte: int, end_byte: int, file: Any, priority: int = 0, exec_func: Optional[Callable[[], None]] = None) -> None:
        """Target bytes [start_byte, end_byte) of ``file``; ``exec_func`` runs after the write."""
        super().__init__(start_byte, end_byte, file.path, priority=priority)
        self.file = file
        self.exec_func = exec_func

    def _generate_new_content_bytes(self) -> bytes:
        """Generate the new content bytes after removal."""
        content_bytes = self.file.content_bytes
        return content_bytes[: self.start_byte] + content_bytes[self.end_byte :]

    def execute(self) -> None:
        """Removes the content between start_byte and end_byte."""
        self.file.write_bytes(self._generate_new_content_bytes())
        if self.exec_func:
            self.exec_func()

    def get_diff(self) -> DiffLite:
        """Gets the diff produced by this transaction."""
        return DiffLite(ChangeType.Modified, self.file_path, old_content=self.file.content_bytes)

    def diff_str(self) -> str:
        """Human-readable string representation of the change."""
        diff = "".join(unified_diff(self.file.content.splitlines(True), self._generate_new_content_bytes().decode("utf-8").splitlines(True)))
        return f"Remove {self.length} bytes at bytes ({self.start_byte}, {self.end_byte})\n{diff}"


class InsertTransaction(Transaction):
    """Transaction to insert content into a file."""

    transaction_order = TransactionPriority.Insert

    exec_func: Optional[Callable[[], None]] = None

    def __init__(
        self,
        insert_byte: int,
        file: Any,
        new_content: Union[str, Callable[[], str]],
        *,
        priority: Union[int, tuple] = 0,
        exec_func: Optional[Callable[[], None]] = None,
    ) -> None:
        # An insertion is a zero-length range: start and end are both insert_byte.
        super().__init__(insert_byte, insert_byte, file.path, priority=priority, new_content=new_content)
        self.insert_byte = insert_byte
        self.file = file
        self.exec_func = exec_func

    def _generate_new_content_bytes(self) -> bytes:
        """Generate the new content bytes after insertion."""
        if self.new_content is None:
            raise ValueError("Cannot generate content bytes: new_content is None")
        new_bytes = bytes(self.new_content, encoding="utf-8")
        content_bytes = self.file.content_bytes
        head = content_bytes[: self.insert_byte]
        tail = content_bytes[self.insert_byte :]
        new_content_bytes = head + new_bytes + tail
        return new_content_bytes
+ def execute(self) -> None: + """Inserts new_src at the specified byte_index.""" + self.file.write_bytes(self._generate_new_content_bytes()) + if self.exec_func: + self.exec_func() + + def get_diff(self) -> DiffLite: + """Gets the diff produced by this transaction.""" + return DiffLite(ChangeType.Modified, self.file_path, old_content=self.file.content_bytes) + + def diff_str(self) -> str: + """Human-readable string representation of the change.""" + diff = "".join(unified_diff(self.file.content.splitlines(True), self._generate_new_content_bytes().decode("utf-8").splitlines(True))) + content_length = len(self.new_content) if self.new_content is not None else 0 + return f"Insert {content_length} bytes at bytes ({self.start_byte}, {self.end_byte})\n{diff}" + +class EditTransaction(Transaction): + """Transaction to edit content in a file.""" + transaction_order = TransactionPriority.Edit + new_content: str + + def __init__( + self, + start_byte: int, + end_byte: int, + file: Any, + new_content: str, + priority: int = 0, + ) -> None: + super().__init__(start_byte, end_byte, file.path, priority=priority, new_content=new_content) + self.file = file + + def _generate_new_content_bytes(self) -> bytes: + """Generate the new content bytes after editing.""" + new_bytes = bytes(self.new_content, "utf-8") + content_bytes = self.file.content_bytes + new_content_bytes = content_bytes[: self.start_byte] + new_bytes + content_bytes[self.end_byte :] + return new_content_bytes + + def execute(self) -> None: + """Edits the entirety of this node's source to new_src.""" + self.file.write_bytes(self._generate_new_content_bytes()) + + def get_diff(self) -> DiffLite: + """Gets the diff produced by this transaction.""" + return DiffLite(ChangeType.Modified, self.file_path, old_content=self.file.content_bytes) + + def diff_str(self) -> str: + """Human-readable string representation of the change.""" + diff = "".join(unified_diff(self.file.content.splitlines(True), 
self._generate_new_content_bytes().decode("utf-8").splitlines(True))) + return f"Edit {self.length} bytes at bytes ({self.start_byte}, {self.end_byte}), src: ({self.new_content[:50]})\n{diff}" + + def break_down(self) -> Optional[list[InsertTransaction]]: + """Break down an edit transaction into insert transactions.""" + old = self.file.content_bytes[self.start_byte : self.end_byte] + new = bytes(self.new_content, "utf-8") + if old and old in new: + prefix, suffix = new.split(old, maxsplit=1) + ret = [] + if suffix: + ret.append(InsertTransaction(self.end_byte, self.file, suffix.decode("utf-8"), priority=self.priority)) + if prefix: + ret.append(InsertTransaction(self.start_byte, self.file, prefix.decode("utf-8"), priority=self.priority)) + return ret + return None + +class FileAddTransaction(Transaction): + """Transaction to add a new file.""" + transaction_order = TransactionPriority.FileAdd + + def __init__( + self, + file_path: Path, + priority: int = 0, + ) -> None: + super().__init__(0, 0, file_path, priority=priority) + + def execute(self) -> None: + """Adds a new file.""" + pass # execute is a no-op as the file is immediately added + + def get_diff(self) -> DiffLite: + """Gets the diff produced by this transaction.""" + return DiffLite(ChangeType.Added, self.file_path) + + def diff_str(self) -> str: + """Human-readable string representation of the change.""" + return f"Add file at {self.file_path}" + +class FileRenameTransaction(Transaction): + """Transaction to rename a file.""" + transaction_order = TransactionPriority.FileRename + + def __init__( + self, + file: Any, + new_file_path: str, + priority: int = 0, + ) -> None: + super().__init__(0, 0, file.path, priority=priority, new_content=new_file_path) + self.new_file_path = file.ctx.to_absolute(new_file_path) if hasattr(file, 'ctx') else Path(new_file_path) + self.file = file + + def execute(self) -> None: + """Renames the file.""" + if hasattr(self.file, 'ctx') and hasattr(self.file.ctx, 'io'): + 
self.file.ctx.io.save_files({self.file.path}) + self.file_path.rename(self.new_file_path) + + def get_diff(self) -> DiffLite: + """Gets the diff produced by this transaction.""" + return DiffLite(ChangeType.Renamed, self.file_path, self.file_path, self.new_file_path, + old_content=self.file.content_bytes if hasattr(self.file, 'content_bytes') else None) + + def diff_str(self) -> str: + """Human-readable string representation of the change.""" + return f"Rename file from {self.file_path} to {self.new_file_path}" + +class FileRemoveTransaction(Transaction): + """Transaction to remove a file.""" + transaction_order = TransactionPriority.FileRemove + + def __init__( + self, + file: Any, + priority: int = 0, + ) -> None: + super().__init__(0, 0, file.path, priority=priority) + self.file = file + + def execute(self) -> None: + """Removes the file.""" + if hasattr(self.file, 'ctx') and hasattr(self.file.ctx, 'io'): + self.file.ctx.io.delete_file(self.file.path) + else: + # Fallback for when ctx.io is not available + import os + if os.path.exists(self.file_path): + os.remove(self.file_path) + + def get_diff(self) -> DiffLite: + """Gets the diff produced by this transaction.""" + return DiffLite(ChangeType.Removed, self.file_path, + old_content=self.file.content_bytes if hasattr(self.file, 'content_bytes') else None) + + def diff_str(self) -> str: + """Human-readable string representation of the change.""" + return f"Remove file at {self.file_path}" diff --git a/codegen-on-oss/tests/analyzers/test_transaction_manager.py b/codegen-on-oss/tests/analyzers/test_transaction_manager.py new file mode 100644 index 000000000..0e9d5e4f0 --- /dev/null +++ b/codegen-on-oss/tests/analyzers/test_transaction_manager.py @@ -0,0 +1,467 @@ +#!/usr/bin/env python3 +""" +Tests for the Transaction Manager module in the analyzers package. 
import os
import tempfile
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch

from codegen_on_oss.analyzers.transaction_manager import (
    TransactionManager,
    MaxTransactionsExceeded,
    MaxPreviewTimeExceeded,
    TransactionError,
)
from codegen_on_oss.analyzers.transactions import (
    Transaction,
    EditTransaction,
    InsertTransaction,
    RemoveTransaction,
    FileAddTransaction,
    FileRemoveTransaction,
    FileRenameTransaction,
    TransactionPriority,
    ChangeType,
    DiffLite,
)


class TestTransactionManager(unittest.TestCase):
    """Unit tests for TransactionManager: queuing, limits, commits and conflicts."""

    # ------------------------------------------------------------------ helpers

    def _tmp_path(self, filename):
        """Return the path of ``filename`` inside this test's temporary directory."""
        return Path(os.path.join(self.temp_dir.name, filename))

    def _second_mock_file(self):
        """Build a mock file object that lives at a second, distinct path."""
        other_file = MagicMock()
        other_file.path = self._tmp_path("test_file2.txt")
        other_file.content = "Another test file."
        other_file.content_bytes = b"Another test file."
        return other_file

    def _assert_single_queued(self, expected_cls):
        """Assert exactly one transaction of ``expected_cls`` is queued for the test file."""
        self.assertIn(self.test_file_path, self.manager.queued_transactions)
        self.assertEqual(len(self.manager.queued_transactions[self.test_file_path]), 1)
        self.assertIsInstance(self.manager.queued_transactions[self.test_file_path][0], expected_cls)

    # ----------------------------------------------------------------- fixtures

    def setUp(self):
        """Create a fresh manager, a scratch file on disk and a mock file for it."""
        self.manager = TransactionManager()

        # A real file backs the mock so path-based operations have a target.
        self.temp_dir = tempfile.TemporaryDirectory()
        self.test_file_path = self._tmp_path("test_file.txt")
        self.test_file_path.write_text("This is a test file content.")

        # The mock mirrors what is on disk and records any write_bytes call.
        self.mock_file = MagicMock(
            path=self.test_file_path,
            content="This is a test file content.",
            content_bytes=b"This is a test file content.",
        )
        self.mock_file.write_bytes = MagicMock()

    def tearDown(self):
        """Remove the temporary directory and everything in it."""
        self.temp_dir.cleanup()

    # -------------------------------------------------------------------- tests

    def test_init(self):
        """A new manager starts empty, idle and without limits."""
        manager = self.manager
        self.assertEqual(manager.queued_transactions, {})
        self.assertEqual(manager.pending_undos, set())
        self.assertFalse(manager._commiting)
        self.assertIsNone(manager.max_transactions)
        self.assertIsNone(manager.stopwatch_max_seconds)

    def test_add_transaction(self):
        """Adding a transaction queues it under its file path."""
        edit = EditTransaction(0, 5, self.mock_file, "New")
        added = self.manager.add_transaction(edit)

        self.assertTrue(added)
        self.assertIn(self.test_file_path, self.manager.queued_transactions)
        self.assertEqual(len(self.manager.queued_transactions[self.test_file_path]), 1)
        self.assertEqual(self.manager.queued_transactions[self.test_file_path][0], edit)

    def test_add_duplicate_transaction(self):
        """Adding the same transaction twice is rejected the second time."""
        edit = EditTransaction(0, 5, self.mock_file, "New")
        self.manager.add_transaction(edit)
        added_again = self.manager.add_transaction(edit)

        self.assertFalse(added_again)
        self.assertEqual(len(self.manager.queued_transactions[self.test_file_path]), 1)

    def test_sort_transactions(self):
        """Sorting orders transactions by descending start byte."""
        # Created and queued from the highest start byte to the lowest.
        edit = EditTransaction(10, 15, self.mock_file, "Edit1")
        insert = InsertTransaction(5, self.mock_file, "Insert")
        remove = RemoveTransaction(0, 5, self.mock_file)

        for txn in (edit, insert, remove):
            self.manager.add_transaction(txn)

        self.manager.sort_transactions()

        # Expected: edit at byte 10, insert at byte 5, remove at byte 0.
        ordered = self.manager.queued_transactions[self.test_file_path]
        for position, expected in enumerate((edit, insert, remove)):
            self.assertEqual(ordered[position], expected)

    def test_clear_transactions(self):
        """Clearing empties the queue and runs every pending undo."""
        self.manager.add_transaction(EditTransaction(0, 5, self.mock_file, "New"))

        undo = MagicMock()
        self.manager.pending_undos.add(undo)

        self.manager.clear_transactions()

        self.assertEqual(self.manager.queued_transactions, {})
        self.assertEqual(self.manager.pending_undos, set())
        undo.assert_called_once()

    def test_get_num_transactions(self):
        """The transaction count follows what has been queued."""
        self.assertEqual(self.manager.get_num_transactions(), 0)

        edit = EditTransaction(0, 5, self.mock_file, "Edit1")
        insert = InsertTransaction(5, self.mock_file, "Insert")
        self.manager.add_transaction(edit)
        self.manager.add_transaction(insert)

        self.assertEqual(self.manager.get_num_transactions(), 2)

    def test_set_max_transactions(self):
        """The transaction limit can be set and then cleared again."""
        self.assertIsNone(self.manager.max_transactions)

        self.manager.set_max_transactions(10)
        self.assertEqual(self.manager.max_transactions, 10)

        self.manager.set_max_transactions(None)
        self.assertIsNone(self.manager.max_transactions)

    def test_max_transactions_exceeded(self):
        """The limit reports exceeded once the queue reaches it."""
        self.assertFalse(self.manager.max_transactions_exceeded())

        self.manager.set_max_transactions(2)
        self.assertFalse(self.manager.max_transactions_exceeded())

        edit = EditTransaction(0, 5, self.mock_file, "Edit1")
        insert = InsertTransaction(5, self.mock_file, "Insert")
        self.manager.add_transaction(edit)
        self.manager.add_transaction(insert)

        self.assertTrue(self.manager.max_transactions_exceeded())

    def test_reset_stopwatch(self):
        """Resetting the stopwatch stores the start time and the limit."""
        with patch('time.time') as clock:
            clock.return_value = 100

            self.manager.reset_stopwatch(5)

            self.assertEqual(self.manager.stopwatch_start, 100)
            self.assertEqual(self.manager.stopwatch_max_seconds, 5)

    def test_is_time_exceeded(self):
        """Time is exceeded only once the clock passes the configured limit."""
        with patch('time.time') as clock:
            # Start the stopwatch at t=100 with a five-second budget.
            clock.return_value = 100
            self.manager.reset_stopwatch(5)

            # Four seconds in: still within budget.
            clock.return_value = 104
            self.assertFalse(self.manager.is_time_exceeded())

            # Six seconds in: over budget.
            clock.return_value = 106
            self.assertTrue(self.manager.is_time_exceeded())

            # Without a limit, no amount of elapsed time counts as exceeded.
            self.manager.reset_stopwatch(None)
            clock.return_value = 200
            self.assertFalse(self.manager.is_time_exceeded())

    def test_add_file_transactions(self):
        """File add, rename and remove helpers each queue the matching transaction."""
        # Add
        self.manager.add_file_add_transaction(self.test_file_path)
        self._assert_single_queued(FileAddTransaction)

        self.manager.clear_transactions()

        # Rename
        self.manager.add_file_rename_transaction(self.mock_file, "new_name.txt")
        self._assert_single_queued(FileRenameTransaction)

        self.manager.clear_transactions()

        # Remove
        self.manager.add_file_remove_transaction(self.mock_file)
        self._assert_single_queued(FileRemoveTransaction)

    def test_check_limits(self):
        """Exceeding the count or time limit raises the matching exception."""
        # --- transaction count limit ---
        self.manager.set_max_transactions(1)
        first_edit = EditTransaction(0, 5, self.mock_file, "Edit1")
        self.manager.add_transaction(first_edit)

        with self.assertRaises(MaxTransactionsExceeded):
            over_limit = InsertTransaction(5, self.mock_file, "Insert")
            self.manager.add_transaction(over_limit)

        # Back to a clean, unlimited manager.
        self.manager.clear_transactions()
        self.manager.set_max_transactions(None)

        # --- preview time limit ---
        with patch('time.time') as clock:
            clock.return_value = 100
            self.manager.reset_stopwatch(5)

            # Within the time budget: accepted.
            clock.return_value = 104
            in_time = EditTransaction(0, 5, self.mock_file, "Edit1")
            self.manager.add_transaction(in_time)

            # Past the time budget: rejected.
            clock.return_value = 106
            too_late = InsertTransaction(5, self.mock_file, "Insert")

            with self.assertRaises(MaxPreviewTimeExceeded):
                self.manager.add_transaction(too_late)

    def test_to_commit(self):
        """to_commit returns every queued file, or only the requested ones."""
        self.manager.add_transaction(EditTransaction(0, 5, self.mock_file, "Edit1"))

        other_file = self._second_mock_file()
        self.manager.add_transaction(EditTransaction(0, 5, other_file, "Edit2"))

        # No filter: both files are returned.
        everything = self.manager.to_commit()
        self.assertEqual(len(everything), 2)
        self.assertIn(self.test_file_path, everything)
        self.assertIn(other_file.path, everything)

        # Filtered: only the requested file is returned.
        requested = {self.test_file_path}
        selected = self.manager.to_commit(requested)
        self.assertEqual(len(selected), 1)
        self.assertIn(self.test_file_path, selected)
        self.assertNotIn(other_file.path, selected)

    def test_commit(self):
        """Committing executes the transaction, dequeues it and returns its diff."""
        self.manager.add_transaction(EditTransaction(0, 5, self.mock_file, "New"))

        diffs = self.manager.commit({self.test_file_path})

        # Executed exactly once.
        self.mock_file.write_bytes.assert_called_once()

        # No longer queued.
        self.assertNotIn(self.test_file_path, self.manager.queued_transactions)

        # A single "modified" diff for the test file.
        self.assertEqual(len(diffs), 1)
        only_diff = diffs[0]
        self.assertIsInstance(only_diff, DiffLite)
        self.assertEqual(only_diff.change_type, ChangeType.Modified)
        self.assertEqual(only_diff.path, self.test_file_path)

    def test_apply(self):
        """Applying one transaction executes it and leaves nothing queued."""
        self.manager.apply(EditTransaction(0, 5, self.mock_file, "New"))

        self.mock_file.write_bytes.assert_called_once()
        self.assertNotIn(self.test_file_path, self.manager.queued_transactions)

    def test_apply_all(self):
        """apply_all executes every queued transaction across all files."""
        self.manager.add_transaction(EditTransaction(0, 5, self.mock_file, "Edit1"))

        other_file = self._second_mock_file()
        self.manager.add_transaction(EditTransaction(0, 5, other_file, "Edit2"))

        diffs = self.manager.apply_all()

        # Each file was written exactly once.
        self.mock_file.write_bytes.assert_called_once()
        other_file.write_bytes.assert_called_once()

        # The queue is empty and one diff per file came back.
        self.assertEqual(self.manager.queued_transactions, {})
        self.assertEqual(len(diffs), 2)

    def test_revert_all(self):
        """revert_all drops queued transactions and runs pending undos."""
        self.manager.add_transaction(EditTransaction(0, 5, self.mock_file, "New"))

        undo = MagicMock()
        self.manager.pending_undos.add(undo)

        self.manager.revert_all()

        self.assertEqual(self.manager.queued_transactions, {})
        undo.assert_called_once()

    def test_get_transactions_at_range(self):
        """Range lookup returns exact-span matches, optionally filtered by type."""
        first = EditTransaction(0, 5, self.mock_file, "Edit1")
        second = EditTransaction(5, 10, self.mock_file, "Edit2")
        third = EditTransaction(10, 15, self.mock_file, "Edit3")
        for txn in (first, second, third):
            self.manager.add_transaction(txn)

        # Any type at bytes 0-5.
        found = self.manager.get_transactions_at_range(self.test_file_path, 0, 5)
        self.assertEqual(len(found), 1)
        self.assertEqual(found[0], first)

        # Matching type at bytes 0-5.
        found = self.manager.get_transactions_at_range(self.test_file_path, 0, 5, TransactionPriority.Edit)
        self.assertEqual(len(found), 1)
        self.assertEqual(found[0], first)

        # A different type at bytes 0-5 yields nothing.
        found = self.manager.get_transactions_at_range(self.test_file_path, 0, 5, TransactionPriority.Remove)
        self.assertEqual(len(found), 0)

    def test_get_transaction_containing_range(self):
        """Containment lookup finds the enclosing transaction, filtered by type."""
        enclosing = EditTransaction(0, 10, self.mock_file, "Edit1")
        self.manager.add_transaction(enclosing)

        # Any type.
        hit = self.manager.get_transaction_containing_range(self.test_file_path, 2, 8)
        self.assertEqual(hit, enclosing)

        # Matching type.
        hit = self.manager.get_transaction_containing_range(self.test_file_path, 2, 8, TransactionPriority.Edit)
        self.assertEqual(hit, enclosing)

        # A different type yields None.
        hit = self.manager.get_transaction_containing_range(self.test_file_path, 2, 8, TransactionPriority.Remove)
        self.assertIsNone(hit)

    def test_get_conflicts(self):
        """Overlapping spans conflict; spans that only touch do not."""
        queued = EditTransaction(0, 10, self.mock_file, "Edit1")
        self.manager.add_transaction(queued)

        # Bytes 5-15 overlap the queued 0-10 edit.
        overlapping = EditTransaction(5, 15, self.mock_file, "Edit2")
        conflicts = self.manager._get_conflicts(overlapping)
        self.assertEqual(len(conflicts), 1)
        self.assertEqual(conflicts[0], queued)

        # Bytes 15-20 are clear of it.
        disjoint = EditTransaction(15, 20, self.mock_file, "Edit3")
        conflicts = self.manager._get_conflicts(disjoint)
        self.assertEqual(len(conflicts), 0)

    def test_get_overlapping_conflicts(self):
        """Only a fully enclosed transaction reports an overlapping conflict."""
        outer = EditTransaction(0, 20, self.mock_file, "Edit1")
        self.manager.add_transaction(outer)

        # Bytes 5-15 lie entirely inside 0-20.
        enclosed = EditTransaction(5, 15, self.mock_file, "Edit2")
        conflict = self.manager._get_overlapping_conflicts(enclosed)
        self.assertEqual(conflict, outer)

        # Bytes 15-25 only partly overlap, so there is no enclosing conflict.
        straddling = EditTransaction(15, 25, self.mock_file, "Edit3")
        conflict = self.manager._get_overlapping_conflicts(straddling)
        self.assertIsNone(conflict)

    def test_resolve_conflicts_with_remove(self):
        """An incoming remove wins over a queued edit on the same span."""
        queued_edit = EditTransaction(0, 10, self.mock_file, "Edit1")
        self.manager.add_transaction(queued_edit)

        incoming_remove = RemoveTransaction(0, 10, self.mock_file)

        outcome = self.manager._resolve_conflicts(
            incoming_remove,
            self.manager.queued_transactions[self.test_file_path],
        )

        # The remove is returned and the edit is gone from the queue.
        self.assertEqual(outcome, incoming_remove)
        self.assertEqual(len(self.manager.queued_transactions[self.test_file_path]), 0)

    def test_resolve_conflicts_with_edit(self):
        """An incoming edit is discarded when a remove already covers the span."""
        queued_remove = RemoveTransaction(0, 10, self.mock_file)
        self.manager.add_transaction(queued_remove)

        incoming_edit = EditTransaction(0, 10, self.mock_file, "Edit1")

        outcome = self.manager._resolve_conflicts(
            incoming_edit,
            self.manager.queued_transactions[self.test_file_path],
        )

        # None means the edit was discarded.
        self.assertIsNone(outcome)


if __name__ == '__main__':
    unittest.main()