diff --git a/.env.example b/.env.example index adb32324..d1ce0818 100644 --- a/.env.example +++ b/.env.example @@ -88,3 +88,16 @@ AUTH_SECRET=CHANGE-ME-IN-PRODUCTION # Enable authentication requirement (default: false for migration) # Set to true in production to enforce authentication # AUTH_REQUIRED=false + +# ============================================================================ +# GitHub Integration (Optional - for PR creation) +# ============================================================================ + +# GitHub Personal Access Token with repo scope +# Get yours at: https://github.com/settings/tokens +# Required for: creating PRs, merging PRs, GitHub integration +# GITHUB_TOKEN=ghp_... + +# Target repository in format "owner/repo" +# Example: frankbria/codeframe +# GITHUB_REPO=owner/repo diff --git a/codeframe/core/config.py b/codeframe/core/config.py index 42e2214d..c0bf0ff6 100644 --- a/codeframe/core/config.py +++ b/codeframe/core/config.py @@ -115,6 +115,10 @@ class GlobalConfig(BaseSettings): default_provider: str = "claude" default_model: str = "claude-sonnet-4" + # GitHub Integration (Sprint 11 - PR Management) + github_token: Optional[str] = Field(None, alias="GITHUB_TOKEN") + github_repo: Optional[str] = Field(None, alias="GITHUB_REPO") # Format: "owner/repo" + model_config = SettingsConfigDict( env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="ignore" ) diff --git a/codeframe/git/github_integration.py b/codeframe/git/github_integration.py new file mode 100644 index 00000000..e7f817de --- /dev/null +++ b/codeframe/git/github_integration.py @@ -0,0 +1,329 @@ +"""GitHub API Integration for CodeFRAME. + +Handles GitHub API operations for Pull Request management. +Part of Sprint 11 - GitHub PR Integration. +""" + +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Dict, List, Optional +import logging + +import httpx + +logger = logging.getLogger(__name__) + + +class GitHubAPIError(Exception): + """Exception raised when GitHub API returns an error.""" + + def __init__( + self, + status_code: int, + message: str, + details: Optional[Dict[str, Any]] = None, + ): + self.status_code = status_code + self.message = message + self.details = details + super().__init__(f"GitHub API Error ({status_code}): {message}") + + +@dataclass +class PRDetails: + """Pull Request details from GitHub API.""" + + number: int + url: str + state: str + title: str + body: Optional[str] + created_at: datetime + merged_at: Optional[datetime] + head_branch: str + base_branch: str + + +@dataclass +class MergeResult: + """Result of a PR merge operation.""" + + sha: Optional[str] + merged: bool + message: str + + +class GitHubIntegration: + """GitHub API client for PR operations. + + Provides methods for creating, listing, merging, and closing + pull requests via the GitHub REST API. + """ + + BASE_URL = "https://api.github.com" + + def __init__(self, token: str, repo: str): + """Initialize GitHub integration. + + Args: + token: GitHub Personal Access Token with repo scope + repo: Repository in format "owner/repo" + + Raises: + ValueError: If repo format is invalid + """ + parts = repo.split("/", 1) + if len(parts) != 2 or not parts[0].strip() or not parts[1].strip(): + raise ValueError( + f"Invalid repo format: '{repo}'. Expected 'owner/repo'" + ) + + self.token = token + self.repo = repo + self.owner, self.repo_name = parts[0].strip(), parts[1].strip() + + self._client = httpx.AsyncClient( + headers={ + "Authorization": f"Bearer {token}", + "Accept": "application/vnd.github.v3+json", + "X-GitHub-Api-Version": "2022-11-28", + }, + timeout=30.0, + ) + + async def _make_request( + self, + method: str, + endpoint: str, + json_data: Optional[Dict[str, Any]] = None, + ) -> Any: + """Make an authenticated request to GitHub API. + + Args: + method: HTTP method (GET, POST, PATCH, PUT, DELETE) + endpoint: API endpoint path + json_data: Optional JSON body data + + Returns: + Parsed JSON response + + Raises: + GitHubAPIError: If API returns an error status + """ + url = f"{self.BASE_URL}{endpoint}" + + try: + response = await self._client.request( + method=method, + url=url, + json=json_data, + ) + + if response.status_code >= 400: + try: + error_data = response.json() + message = error_data.get("message", response.text) + details = error_data.get("errors") + except Exception: + message = response.text + details = None + + raise GitHubAPIError( + status_code=response.status_code, + message=message, + details={"errors": details} if details else None, + ) + + # Handle empty responses (204 No Content) + if response.status_code == 204: + return None + + return response.json() + + except httpx.TimeoutException as e: + logger.error(f"GitHub API timeout: {e}") + raise GitHubAPIError( + status_code=408, + message="Request timed out", + ) + except httpx.RequestError as e: + logger.error(f"GitHub API request error: {e}") + raise GitHubAPIError( + status_code=500, + message=f"Request failed: {str(e)}", + ) + + def _parse_pr_response(self, data: Dict[str, Any]) -> PRDetails: + """Parse GitHub PR response into PRDetails object. + + Args: + data: Raw GitHub API response + + Returns: + Parsed PRDetails object + """ + created_at = datetime.fromisoformat( + data["created_at"].replace("Z", "+00:00") + ) + merged_at = None + if data.get("merged_at"): + merged_at = datetime.fromisoformat( + data["merged_at"].replace("Z", "+00:00") + ) + + return PRDetails( + number=data["number"], + url=data["html_url"], + state=data["state"], + title=data["title"], + body=data.get("body"), + created_at=created_at, + merged_at=merged_at, + head_branch=data["head"]["ref"], + base_branch=data["base"]["ref"], + ) + + async def create_pull_request( + self, + branch: str, + title: str, + body: str, + base: str = "main", + ) -> PRDetails: + """Create a new pull request. + + Args: + branch: Head branch with changes + title: PR title + body: PR description + base: Base branch to merge into (default: main) + + Returns: + PRDetails with the created PR info + + Raises: + GitHubAPIError: If PR creation fails + """ + endpoint = f"/repos/{self.owner}/{self.repo_name}/pulls" + + data = await self._make_request( + method="POST", + endpoint=endpoint, + json_data={ + "title": title, + "body": body, + "head": branch, + "base": base, + }, + ) + + logger.info(f"Created PR #{data['number']}: {title}") + return self._parse_pr_response(data) + + async def get_pull_request(self, pr_number: int) -> PRDetails: + """Get pull request details. + + Args: + pr_number: PR number + + Returns: + PRDetails with the PR info + + Raises: + GitHubAPIError: If PR not found or API error + """ + endpoint = f"/repos/{self.owner}/{self.repo_name}/pulls/{pr_number}" + + data = await self._make_request( + method="GET", + endpoint=endpoint, + ) + + return self._parse_pr_response(data) + + async def list_pull_requests( + self, + state: str = "open", + ) -> List[PRDetails]: + """List pull requests for the repository. + + Args: + state: Filter by state (open, closed, all) + + Returns: + List of PRDetails + + Raises: + GitHubAPIError: If API error occurs + """ + endpoint = f"/repos/{self.owner}/{self.repo_name}/pulls" + + data = await self._make_request( + method="GET", + endpoint=f"{endpoint}?state={state}", + ) + + return [self._parse_pr_response(pr) for pr in data] + + async def merge_pull_request( + self, + pr_number: int, + method: str = "squash", + ) -> MergeResult: + """Merge a pull request. + + Args: + pr_number: PR number to merge + method: Merge method (merge, squash, rebase) + + Returns: + MergeResult with merge outcome + + Raises: + GitHubAPIError: If merge fails + """ + endpoint = f"/repos/{self.owner}/{self.repo_name}/pulls/{pr_number}/merge" + + data = await self._make_request( + method="PUT", + endpoint=endpoint, + json_data={ + "merge_method": method, + }, + ) + + logger.info(f"Merged PR #{pr_number} with method '{method}'") + return MergeResult( + sha=data.get("sha"), + merged=data.get("merged", False), + message=data.get("message", ""), + ) + + async def close_pull_request(self, pr_number: int) -> bool: + """Close a pull request without merging. + + Args: + pr_number: PR number to close + + Returns: + True if successfully closed + + Raises: + GitHubAPIError: If close fails + """ + endpoint = f"/repos/{self.owner}/{self.repo_name}/pulls/{pr_number}" + + data = await self._make_request( + method="PATCH", + endpoint=endpoint, + json_data={ + "state": "closed", + }, + ) + + logger.info(f"Closed PR #{pr_number}") + return data.get("state") == "closed" + + async def close(self) -> None: + """Close the HTTP client.""" + await self._client.aclose() diff --git a/codeframe/persistence/database.py b/codeframe/persistence/database.py index ce74d8cb..1152d17b 100644 --- a/codeframe/persistence/database.py +++ b/codeframe/persistence/database.py @@ -34,6 +34,7 @@ CorrectionRepository, ActivityRepository, AuditRepository, + PRRepository, ) if TYPE_CHECKING: @@ -106,6 +107,7 @@ def __init__(self, db_path: Path | str): self.correction_attempts: Optional[CorrectionRepository] = None self.activities: Optional[ActivityRepository] = None self.audit_logs: Optional[AuditRepository] = None + self.pull_requests: Optional[PRRepository] = None def initialize(self) -> None: """Initialize database schema and repositories.""" @@ -151,6 +153,7 @@ def _initialize_repositories(self) -> None: self.correction_attempts = CorrectionRepository(sync_conn=self.conn, async_conn=self._async_conn, database=self, sync_lock=self._sync_lock) self.activities = ActivityRepository(sync_conn=self.conn, async_conn=self._async_conn, database=self, sync_lock=self._sync_lock) self.audit_logs = AuditRepository(sync_conn=self.conn, async_conn=self._async_conn, database=self, sync_lock=self._sync_lock) + self.pull_requests = PRRepository(sync_conn=self.conn, async_conn=self._async_conn, database=self, sync_lock=self._sync_lock) # Backward compatibility properties (maintain old *_repository naming) @property @@ -211,7 +214,8 @@ def _update_repository_async_connections(self) -> None: for repo in [self.projects, self.issues, self.tasks, self.agents, self.blockers, self.memories, self.context_items, self.checkpoints, self.git_branches, self.test_results, self.lint_results, self.code_reviews, self.quality_gates, - self.token_usage, self.correction_attempts, self.activities, self.audit_logs]: + self.token_usage, self.correction_attempts, self.activities, self.audit_logs, + self.pull_requests]: if repo: repo._async_conn = self._async_conn diff --git a/codeframe/persistence/repositories/__init__.py b/codeframe/persistence/repositories/__init__.py index 32110dcc..ab56b0ab 100644 --- a/codeframe/persistence/repositories/__init__.py +++ b/codeframe/persistence/repositories/__init__.py @@ -22,6 +22,7 @@ from codeframe.persistence.repositories.correction_repository import CorrectionRepository from codeframe.persistence.repositories.activity_repository import ActivityRepository from codeframe.persistence.repositories.audit_repository import AuditRepository +from codeframe.persistence.repositories.pr_repository import PRRepository __all__ = [ "BaseRepository", @@ -42,4 +43,5 @@ "CorrectionRepository", "ActivityRepository", "AuditRepository", + "PRRepository", ] diff --git a/codeframe/persistence/repositories/pr_repository.py b/codeframe/persistence/repositories/pr_repository.py new file mode 100644 index 00000000..ada0606b --- /dev/null +++ b/codeframe/persistence/repositories/pr_repository.py @@ -0,0 +1,255 @@ +"""Repository for Pull Request operations. + +Handles database operations for GitHub Pull Request tracking. +Part of Sprint 11 - GitHub PR Integration. +""" + +from datetime import datetime, UTC +from typing import Any, Dict, List, Optional +import logging + +from codeframe.persistence.repositories.base import BaseRepository + +logger = logging.getLogger(__name__) + + +class PRRepository(BaseRepository): + """Repository for pull request database operations.""" + + def create_pr( + self, + project_id: int, + issue_id: Optional[int], + branch_name: str, + title: str, + body: str, + base_branch: str, + head_branch: str, + status: str = "open", + ) -> int: + """Create a new pull request record. + + Args: + project_id: Project ID this PR belongs to + issue_id: Optional associated issue ID + branch_name: Git branch name + title: PR title + body: PR description + base_branch: Target branch (e.g., "main") + head_branch: Source branch with changes + status: Initial status (default: "open") + + Returns: + PR ID + + Raises: + sqlite3.IntegrityError: If project_id doesn't exist + """ + cursor = self._execute( + """ + INSERT INTO pull_requests ( + project_id, issue_id, branch_name, title, body, + base_branch, head_branch, status + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + (project_id, issue_id, branch_name, title, body, base_branch, head_branch, status), + ) + self._commit() + return cursor.lastrowid + + def get_pr(self, pr_id: int) -> Optional[Dict[str, Any]]: + """Get a pull request by its ID. + + Args: + pr_id: Pull request ID + + Returns: + PR dictionary or None if not found + """ + row = self._fetchone( + "SELECT * FROM pull_requests WHERE id = ?", + (pr_id,), + ) + return self._row_to_dict(row) if row else None + + def get_pr_by_number(self, project_id: int, pr_number: int) -> Optional[Dict[str, Any]]: + """Get a pull request by its GitHub PR number. + + Args: + project_id: Project ID + pr_number: GitHub PR number + + Returns: + PR dictionary or None if not found + """ + row = self._fetchone( + """ + SELECT * FROM pull_requests + WHERE project_id = ? AND pr_number = ? + """, + (project_id, pr_number), + ) + return self._row_to_dict(row) if row else None + + def list_prs( + self, project_id: int, status: Optional[str] = None + ) -> List[Dict[str, Any]]: + """List pull requests for a project. + + Args: + project_id: Project ID + status: Optional filter by status (open, merged, closed, draft) + + Returns: + List of PR dictionaries + """ + if status: + rows = self._fetchall( + """ + SELECT * FROM pull_requests + WHERE project_id = ? AND status = ? + ORDER BY created_at DESC + """, + (project_id, status), + ) + else: + rows = self._fetchall( + """ + SELECT * FROM pull_requests + WHERE project_id = ? + ORDER BY created_at DESC + """, + (project_id,), + ) + + return [self._row_to_dict(row) for row in rows] + + def update_pr_github_data( + self, + pr_id: int, + pr_number: int, + pr_url: str, + github_created_at: datetime, + ) -> None: + """Update PR with data from GitHub API response. + + Args: + pr_id: Local PR ID + pr_number: GitHub PR number + pr_url: GitHub PR URL + github_created_at: When PR was created on GitHub + + Raises: + ValueError: If pr_id does not exist + """ + # Ensure datetime is UTC-aware for consistent storage + if github_created_at.tzinfo is None: + github_created_at = github_created_at.replace(tzinfo=UTC) + + cursor = self._execute( + """ + UPDATE pull_requests + SET pr_number = ?, pr_url = ?, github_created_at = ? + WHERE id = ? + """, + (pr_number, pr_url, github_created_at.isoformat(), pr_id), + ) + if cursor.rowcount == 0: + raise ValueError(f"PR id {pr_id} not found") + self._commit() + + def update_pr_status( + self, + pr_id: int, + status: str, + merge_commit_sha: Optional[str] = None, + merged_at: Optional[datetime] = None, + ) -> None: + """Update pull request status. + + Args: + pr_id: PR ID + status: New status (open, merged, closed, draft) + merge_commit_sha: Merge commit SHA (for merged PRs) + merged_at: When PR was merged (auto-set if not provided) + + Raises: + ValueError: If pr_id does not exist + """ + now = datetime.now(UTC) + + if status == "merged": + # Use provided merged_at or current time + if merged_at is None: + merged_at = now + # Ensure datetime is UTC-aware + elif merged_at.tzinfo is None: + merged_at = merged_at.replace(tzinfo=UTC) + + cursor = self._execute( + """ + UPDATE pull_requests + SET status = ?, merge_commit_sha = ?, merged_at = ? + WHERE id = ? + """, + (status, merge_commit_sha, merged_at.isoformat(), pr_id), + ) + elif status == "closed": + cursor = self._execute( + """ + UPDATE pull_requests + SET status = ?, closed_at = ? + WHERE id = ? + """, + (status, now.isoformat(), pr_id), + ) + else: + cursor = self._execute( + """ + UPDATE pull_requests + SET status = ? + WHERE id = ? + """, + (status, pr_id), + ) + + if cursor.rowcount == 0: + raise ValueError(f"PR id {pr_id} not found") + self._commit() + + def get_pr_for_branch( + self, project_id: int, branch_name: str + ) -> Optional[Dict[str, Any]]: + """Find a PR by branch name. + + Args: + project_id: Project ID + branch_name: Git branch name + + Returns: + PR dictionary or None if not found + """ + row = self._fetchone( + """ + SELECT * FROM pull_requests + WHERE project_id = ? AND branch_name = ? + ORDER BY created_at DESC + LIMIT 1 + """, + (project_id, branch_name), + ) + return self._row_to_dict(row) if row else None + + def delete_pr(self, pr_id: int) -> int: + """Delete a pull request record. + + Args: + pr_id: PR ID + + Returns: + Number of rows deleted + """ + cursor = self._execute("DELETE FROM pull_requests WHERE id = ?", (pr_id,)) + self._commit() + return cursor.rowcount diff --git a/codeframe/persistence/schema_manager.py b/codeframe/persistence/schema_manager.py index 32018862..0bbf0671 100644 --- a/codeframe/persistence/schema_manager.py +++ b/codeframe/persistence/schema_manager.py @@ -631,6 +631,31 @@ def _create_checkpoint_git_tables(self, cursor: sqlite3.Cursor) -> None: """ ) + # Pull requests table (Sprint 11 - GitHub PR integration) + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS pull_requests ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + project_id INTEGER NOT NULL REFERENCES projects(id) ON DELETE CASCADE, + issue_id INTEGER REFERENCES issues(id) ON DELETE SET NULL, + branch_name TEXT NOT NULL, + pr_number INTEGER, + pr_url TEXT, + title TEXT NOT NULL, + body TEXT, + base_branch TEXT DEFAULT 'main', + head_branch TEXT NOT NULL, + status TEXT CHECK(status IN ('draft', 'open', 'merged', 'closed')) DEFAULT 'open', + merge_commit_sha TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + merged_at TIMESTAMP, + closed_at TIMESTAMP, + github_created_at TIMESTAMP, + github_updated_at TIMESTAMP + ) + """ + ) + def _create_metrics_audit_tables(self, cursor: sqlite3.Cursor) -> None: """Create metrics, token usage, and audit log tables.""" # Token usage table @@ -777,6 +802,17 @@ def _create_indexes(self, cursor: sqlite3.Cursor) -> None: "CREATE INDEX IF NOT EXISTS idx_checkpoints_project ON checkpoints(project_id, created_at DESC)" ) + # Pull requests indexes (Sprint 11 - GitHub PR integration) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_pull_requests_project ON pull_requests(project_id, status)" + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_pull_requests_issue ON pull_requests(issue_id)" + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_pull_requests_branch ON pull_requests(project_id, branch_name)" + ) + # Audit logs indexes cursor.execute( "CREATE INDEX IF NOT EXISTS idx_audit_logs_user_id ON audit_logs(user_id, timestamp DESC)" diff --git a/codeframe/ui/routers/prs.py b/codeframe/ui/routers/prs.py new file mode 100644 index 00000000..d6345d83 --- /dev/null +++ b/codeframe/ui/routers/prs.py @@ -0,0 +1,422 @@ +"""Pull Request management router. + +This module provides API endpoints for: +- Creating pull requests via GitHub API +- Listing and getting PR details +- Merging and closing PRs + +Part of Sprint 11 - GitHub PR Integration. +""" + +import logging +from typing import Literal, Optional + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field + +from codeframe.core.config import GlobalConfig, load_environment +from codeframe.git.github_integration import GitHubIntegration, GitHubAPIError +from codeframe.persistence.database import Database +from codeframe.ui.dependencies import get_db +from codeframe.ui.shared import manager +from codeframe.ui.websocket_broadcasts import ( + broadcast_pr_created, + broadcast_pr_merged, + broadcast_pr_closed, +) +from codeframe.auth import get_current_user, User + + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/projects/{project_id}/prs", tags=["pull-requests"]) + + +def get_global_config() -> GlobalConfig: + """Get global configuration with GitHub settings.""" + load_environment() + return GlobalConfig() + + +# ============================================================================ +# Request/Response Models +# ============================================================================ + + +class CreatePRRequest(BaseModel): + """Request to create a pull request.""" + + branch: str = Field(..., description="Head branch with changes") + title: str = Field(..., description="PR title") + body: str = Field("", description="PR description") + base: str = Field("main", description="Base branch to merge into") + + +class CreatePRResponse(BaseModel): + """Response after creating a PR.""" + + pr_id: int + pr_number: int + pr_url: str + status: str + + +class MergePRRequest(BaseModel): + """Request to merge a pull request.""" + + method: Literal["squash", "merge", "rebase"] = Field( + "squash", description="Merge method (squash, merge, rebase)" + ) + + +class MergePRResponse(BaseModel): + """Response after merging a PR.""" + + merged: bool + merge_commit_sha: Optional[str] + + +class ClosePRResponse(BaseModel): + """Response after closing a PR.""" + + closed: bool + + +class PRListResponse(BaseModel): + """Response containing list of PRs.""" + + prs: list + total: int + + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def validate_github_config(config: GlobalConfig) -> tuple[str, str]: + """Validate that GitHub is properly configured. + + Returns: + Tuple of (github_token, github_repo) + + Raises: + HTTPException: If GitHub config is missing + """ + if not config.github_token or not config.github_repo: + raise HTTPException( + status_code=400, + detail="GitHub integration not configured. Set GITHUB_TOKEN and GITHUB_REPO environment variables.", + ) + return config.github_token, config.github_repo + + +async def validate_project_access( + project_id: int, + db: Database, + user: User, +) -> dict: + """Validate project exists and user has access. + + Returns: + Project dict + + Raises: + HTTPException: If project not found or access denied + """ + project = db.get_project(project_id) + if not project: + raise HTTPException(status_code=404, detail=f"Project {project_id} not found") + + if not db.user_has_project_access(user.id, project_id): + raise HTTPException(status_code=403, detail="Access denied") + + return project + + +# ============================================================================ +# Endpoints +# ============================================================================ + + +@router.post("", status_code=201, response_model=CreatePRResponse) +async def create_pull_request( + project_id: int, + request: CreatePRRequest, + db: Database = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """Create a new pull request via GitHub API. + + Args: + project_id: Project ID + request: PR creation request with branch, title, body, base + + Returns: + Created PR details + + Raises: + HTTPException: + - 400: GitHub not configured + - 403: Access denied + - 404: Project not found + - 422: GitHub API error + """ + # Validate access + await validate_project_access(project_id, db, current_user) + + # Get GitHub config + config = get_global_config() + github_token, github_repo = validate_github_config(config) + + # Create PR via GitHub API + gh: Optional[GitHubIntegration] = None + try: + gh = GitHubIntegration(token=github_token, repo=github_repo) + pr_details = await gh.create_pull_request( + branch=request.branch, + title=request.title, + body=request.body, + base=request.base, + ) + + # Store in database + pr_id = db.pull_requests.create_pr( + project_id=project_id, + issue_id=None, # Can be linked later + branch_name=request.branch, + title=request.title, + body=request.body, + base_branch=request.base, + head_branch=request.branch, + ) + + # Update with GitHub data + db.pull_requests.update_pr_github_data( + pr_id=pr_id, + pr_number=pr_details.number, + pr_url=pr_details.url, + github_created_at=pr_details.created_at, + ) + + # Broadcast PR created event + await broadcast_pr_created( + manager=manager, + project_id=project_id, + pr_id=pr_id, + pr_number=pr_details.number, + pr_url=pr_details.url, + title=request.title, + branch_name=request.branch, + ) + + logger.info(f"Created PR #{pr_details.number} for project {project_id}") + + return CreatePRResponse( + pr_id=pr_id, + pr_number=pr_details.number, + pr_url=pr_details.url, + status="open", + ) + + except GitHubAPIError as e: + logger.error(f"GitHub API error creating PR: {e}") + raise HTTPException(status_code=422, detail=str(e)) + + finally: + if gh is not None: + await gh.close() + + +@router.get("", response_model=PRListResponse) +async def list_pull_requests( + project_id: int, + status: Optional[str] = None, + db: Database = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """List pull requests for a project. + + Args: + project_id: Project ID + status: Optional filter by status (open, merged, closed, draft) + + Returns: + List of PRs with total count + """ + await validate_project_access(project_id, db, current_user) + + prs = db.pull_requests.list_prs(project_id, status=status) + + return PRListResponse(prs=prs, total=len(prs)) + + +@router.get("/{pr_number}") +async def get_pull_request( + project_id: int, + pr_number: int, + db: Database = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """Get pull request details by PR number. + + Args: + project_id: Project ID + pr_number: GitHub PR number + + Returns: + PR details + + Raises: + HTTPException: 404 if PR not found + """ + await validate_project_access(project_id, db, current_user) + + pr = db.pull_requests.get_pr_by_number(project_id, pr_number) + if not pr: + raise HTTPException(status_code=404, detail=f"PR #{pr_number} not found") + + return pr + + +@router.post("/{pr_number}/merge", response_model=MergePRResponse) +async def merge_pull_request( + project_id: int, + pr_number: int, + request: MergePRRequest, + db: Database = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """Merge a pull request via GitHub API. + + Args: + project_id: Project ID + pr_number: GitHub PR number to merge + request: Merge request with method (squash, merge, rebase) + + Returns: + Merge result with SHA + + Raises: + HTTPException: + - 404: PR not found + - 422: GitHub API error (not mergeable, conflicts, etc.) + """ + await validate_project_access(project_id, db, current_user) + + # Verify PR exists in our database + pr = db.pull_requests.get_pr_by_number(project_id, pr_number) + if not pr: + raise HTTPException(status_code=404, detail=f"PR #{pr_number} not found") + + # Get GitHub config + config = get_global_config() + github_token, github_repo = validate_github_config(config) + + # Merge via GitHub API + gh: Optional[GitHubIntegration] = None + try: + gh = GitHubIntegration(token=github_token, repo=github_repo) + result = await gh.merge_pull_request( + pr_number=pr_number, + method=request.method, + ) + + # Update database only if merge succeeded + if result.merged: + db.pull_requests.update_pr_status( + pr_id=pr["id"], + status="merged", + merge_commit_sha=result.sha, + ) + + # Broadcast PR merged event + if result.sha: + await broadcast_pr_merged( + manager=manager, + project_id=project_id, + pr_number=pr_number, + merge_commit_sha=result.sha, + ) + + logger.info(f"Merged PR #{pr_number} for project {project_id}") + else: + logger.warning(f"PR #{pr_number} merge returned merged=False") + + return MergePRResponse( + merged=result.merged, + merge_commit_sha=result.sha, + ) + + except GitHubAPIError as e: + logger.error(f"GitHub API error merging PR: {e}") + raise HTTPException(status_code=422, detail=str(e)) + + finally: + if gh is not None: + await gh.close() + + +@router.post("/{pr_number}/close", response_model=ClosePRResponse) +async def close_pull_request( + project_id: int, + pr_number: int, + db: Database = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """Close a pull request without merging. + + Args: + project_id: Project ID + pr_number: GitHub PR number to close + + Returns: + Close result + + Raises: + HTTPException: 404 if PR not found + """ + await validate_project_access(project_id, db, current_user) + + # Verify PR exists in our database + pr = db.pull_requests.get_pr_by_number(project_id, pr_number) + if not pr: + raise HTTPException(status_code=404, detail=f"PR #{pr_number} not found") + + # Get GitHub config + config = get_global_config() + github_token, github_repo = validate_github_config(config) + + # Close via GitHub API + gh: Optional[GitHubIntegration] = None + try: + gh = GitHubIntegration(token=github_token, repo=github_repo) + closed = await gh.close_pull_request(pr_number) + + # Update database only if close succeeded + if closed: + db.pull_requests.update_pr_status( + pr_id=pr["id"], + status="closed", + ) + + # Broadcast PR closed event + await broadcast_pr_closed( + manager=manager, + project_id=project_id, + pr_number=pr_number, + ) + + logger.info(f"Closed PR #{pr_number} for project {project_id}") + else: + logger.warning(f"PR #{pr_number} close returned closed=False") + + return ClosePRResponse(closed=closed) + + except GitHubAPIError as e: + logger.error(f"GitHub API error closing PR: {e}") + raise HTTPException(status_code=422, detail=str(e)) + + finally: + if gh is not None: + await gh.close() diff --git a/codeframe/ui/server.py b/codeframe/ui/server.py index c8f60d08..40378c7f 100644 --- a/codeframe/ui/server.py +++ b/codeframe/ui/server.py @@ -28,6 +28,7 @@ lint, metrics, projects, + prs, quality_gates, review, session, @@ -335,6 +336,7 @@ async def test_broadcast(message: dict, project_id: int = None): app.include_router(lint.router) app.include_router(metrics.router) app.include_router(projects.router) +app.include_router(prs.router) app.include_router(quality_gates.router) app.include_router(review.router) app.include_router(session.router) diff --git a/codeframe/ui/websocket_broadcasts.py b/codeframe/ui/websocket_broadcasts.py index d47e76f4..0a346b7c 100644 --- a/codeframe/ui/websocket_broadcasts.py +++ b/codeframe/ui/websocket_broadcasts.py @@ -1011,3 +1011,109 @@ async def broadcast_development_started( ) except Exception as e: logger.error(f"Failed to broadcast development started: {e}") + + +# ============================================================================ +# Sprint 11: Pull Request Management Broadcasts +# ============================================================================ + + +async def broadcast_pr_created( + manager, + project_id: int, + pr_id: int, + pr_number: int, + pr_url: str, + title: str, + branch_name: str, +) -> None: + """ + Broadcast when a new pull request is created. + + Args: + manager: ConnectionManager instance + project_id: Project ID + pr_id: Local PR database ID + pr_number: GitHub PR number + pr_url: GitHub PR URL + title: PR title + branch_name: Source branch name + """ + message = { + "type": "pr_created", + "project_id": project_id, + "pr_id": pr_id, + "pr_number": pr_number, + "pr_url": pr_url, + "title": title, + "branch_name": branch_name, + "timestamp": datetime.now(UTC).isoformat().replace("+00:00", "Z"), + } + + try: + await manager.broadcast(message, project_id=project_id) + logger.debug(f"Broadcast pr_created: PR #{pr_number} for project {project_id}") + except Exception as e: + logger.error(f"Failed to broadcast PR created: {e}") + + +async def broadcast_pr_merged( + manager, + project_id: int, + pr_number: int, + merge_commit_sha: str, +) -> None: + """ + Broadcast when a pull request is merged. + + Args: + manager: ConnectionManager instance + project_id: Project ID + pr_number: GitHub PR number + merge_commit_sha: SHA of the merge commit + """ + message = { + "type": "pr_merged", + "project_id": project_id, + "pr_number": pr_number, + "merge_commit_sha": merge_commit_sha, + "timestamp": datetime.now(UTC).isoformat().replace("+00:00", "Z"), + } + + try: + await manager.broadcast(message, project_id=project_id) + logger.debug(f"Broadcast pr_merged: PR #{pr_number} for project {project_id}") + except Exception as e: + logger.error(f"Failed to broadcast PR merged: {e}") + + +async def broadcast_pr_closed( + manager, + project_id: int, + pr_number: int, + reason: Optional[str] = None, +) -> None: + """ + Broadcast when a pull request is closed without merging. + + Args: + manager: ConnectionManager instance + project_id: Project ID + pr_number: GitHub PR number + reason: Optional reason for closing + """ + message = { + "type": "pr_closed", + "project_id": project_id, + "pr_number": pr_number, + "timestamp": datetime.now(UTC).isoformat().replace("+00:00", "Z"), + } + + if reason: + message["reason"] = reason + + try: + await manager.broadcast(message, project_id=project_id) + logger.debug(f"Broadcast pr_closed: PR #{pr_number} for project {project_id}") + except Exception as e: + logger.error(f"Failed to broadcast PR closed: {e}") diff --git a/tests/unit/test_github_integration.py b/tests/unit/test_github_integration.py new file mode 100644 index 00000000..592ff48d --- /dev/null +++ b/tests/unit/test_github_integration.py @@ -0,0 +1,321 @@ +"""Unit tests for GitHubIntegration (TDD - written before implementation).""" + +import pytest +from datetime import datetime, UTC +from unittest.mock import AsyncMock, patch + +from codeframe.git.github_integration import ( + GitHubIntegration, + PRDetails, + MergeResult, + GitHubAPIError, +) + + +class TestPRDetails: + """Tests for PRDetails data class.""" + + def test_pr_details_creation(self): + """Test creating a PRDetails object.""" + pr = PRDetails( + number=42, + url="https://github.com/owner/repo/pull/42", + state="open", + title="Test PR", + body="Test body", + created_at=datetime.now(UTC), + merged_at=None, + head_branch="feature/test", + base_branch="main", + ) + + assert pr.number == 42 + assert pr.state == "open" + assert pr.title == "Test PR" + assert pr.merged_at is None + + +class TestMergeResult: + """Tests for MergeResult data class.""" + + def test_merge_result_success(self): + """Test creating a successful MergeResult.""" + result = MergeResult( + sha="abc123def456", + merged=True, + message="Pull Request successfully merged", + ) + + assert result.merged is True + assert result.sha == "abc123def456" + + def test_merge_result_failure(self): + """Test creating a failed MergeResult.""" + result = MergeResult( + sha=None, + merged=False, + message="Pull Request is not mergeable", + ) + + assert result.merged is False + assert result.sha is None + + +class TestGitHubIntegration: + """Tests for GitHubIntegration class.""" + + @pytest.fixture + def github(self): + """Create GitHubIntegration instance.""" + return GitHubIntegration( + token="ghp_test_token_12345", + repo="owner/test-repo", + ) + + def test_init_parses_repo_correctly(self, github): + """Test that repo is parsed correctly.""" + assert github.owner == "owner" + assert github.repo_name == "test-repo" + + def test_init_with_invalid_repo_format(self): + """Test that invalid repo format raises error.""" + with pytest.raises(ValueError, match="Invalid repo format"): + GitHubIntegration(token="token", repo="invalid-format") + + @pytest.mark.asyncio + async def test_create_pull_request_success(self, github): + """Test successful PR creation.""" + mock_response = { + "number": 42, + "html_url": "https://github.com/owner/test-repo/pull/42", + "state": "open", + "title": "Test PR", + "body": "Test body", + "created_at": "2024-01-15T10:30:00Z", + "merged_at": None, + "head": {"ref": "feature/test"}, + "base": {"ref": "main"}, + } + + with patch.object(github, "_make_request", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_response + + pr_details = await github.create_pull_request( + branch="feature/test", + title="Test PR", + body="Test body", + base="main", + ) + + assert pr_details.number == 42 + assert pr_details.state == "open" + assert pr_details.title == "Test PR" + + # Verify API was called correctly + mock_request.assert_called_once() + call_kwargs = mock_request.call_args.kwargs + assert call_kwargs["method"] == "POST" + assert "pulls" in call_kwargs["endpoint"] + + @pytest.mark.asyncio + async def test_create_pull_request_api_error(self, github): + """Test PR creation with API error.""" + with patch.object(github, "_make_request", new_callable=AsyncMock) as mock_request: + mock_request.side_effect = GitHubAPIError( + status_code=422, + message="Validation Failed", + ) + + with pytest.raises(GitHubAPIError) as exc_info: + await github.create_pull_request( + branch="feature/test", + title="Test PR", + body="Test body", + ) + + assert exc_info.value.status_code == 422 + + @pytest.mark.asyncio + async def test_get_pull_request_success(self, github): + """Test getting PR details.""" + mock_response = { + "number": 42, + "html_url": "https://github.com/owner/test-repo/pull/42", + "state": "open", + "title": "Test PR", + "body": "Test body", + "created_at": "2024-01-15T10:30:00Z", + "merged_at": None, + "head": {"ref": "feature/test"}, + "base": {"ref": "main"}, + } + + with patch.object(github, "_make_request", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_response + + pr_details = await github.get_pull_request(42) + + assert pr_details.number == 42 + mock_request.assert_called_once() + + @pytest.mark.asyncio + async def test_get_pull_request_not_found(self, github): + """Test getting non-existent PR.""" + with patch.object(github, "_make_request", new_callable=AsyncMock) as mock_request: + mock_request.side_effect = GitHubAPIError( + status_code=404, + message="Not Found", + ) + + with pytest.raises(GitHubAPIError) as exc_info: + await github.get_pull_request(99999) + + assert exc_info.value.status_code == 404 + + @pytest.mark.asyncio + async def test_list_pull_requests_success(self, github): + """Test listing PRs.""" + mock_response = [ + { + "number": 1, + "html_url": "https://github.com/owner/test-repo/pull/1", + "state": "open", + "title": "PR 1", + "body": "Body 1", + "created_at": "2024-01-15T10:30:00Z", + "merged_at": None, + "head": {"ref": "feature/1"}, + "base": {"ref": "main"}, + }, + { + "number": 2, + "html_url": "https://github.com/owner/test-repo/pull/2", + "state": "open", + "title": "PR 2", + "body": "Body 2", + "created_at": "2024-01-16T10:30:00Z", + "merged_at": None, + "head": {"ref": "feature/2"}, + "base": {"ref": "main"}, + }, + ] + + with patch.object(github, "_make_request", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_response + + prs = await github.list_pull_requests(state="open") + + assert len(prs) == 2 + assert prs[0].number == 1 + assert prs[1].number == 2 + + @pytest.mark.asyncio + async def test_merge_pull_request_success(self, github): + """Test successful PR merge.""" + mock_response = { + "sha": "abc123def456", + "merged": True, + "message": "Pull Request successfully merged", + } + + with patch.object(github, "_make_request", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_response + + result = await github.merge_pull_request(42, method="squash") + + assert result.merged is True + assert result.sha == "abc123def456" + + # Verify merge method was passed + mock_request.assert_called_once() + call_kwargs = mock_request.call_args.kwargs + assert "merge" in call_kwargs["endpoint"] + + @pytest.mark.asyncio + async def test_merge_pull_request_not_mergeable(self, github): + """Test merge with non-mergeable PR.""" + with patch.object(github, "_make_request", new_callable=AsyncMock) as mock_request: + mock_request.side_effect = GitHubAPIError( + status_code=405, + message="Pull Request is not mergeable", + ) + + with pytest.raises(GitHubAPIError) as exc_info: + await github.merge_pull_request(42) + + assert exc_info.value.status_code == 405 + + @pytest.mark.asyncio + async def test_close_pull_request_success(self, github): + """Test closing a PR.""" + mock_response = { + "number": 42, + "html_url": "https://github.com/owner/test-repo/pull/42", + "state": "closed", + "title": "Test PR", + "body": "Test body", + "created_at": "2024-01-15T10:30:00Z", + "merged_at": None, + "head": {"ref": "feature/test"}, + "base": {"ref": "main"}, + } + + with patch.object(github, "_make_request", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_response + + result = await github.close_pull_request(42) + + assert result is True + mock_request.assert_called_once() + call_kwargs = mock_request.call_args.kwargs + assert call_kwargs["method"] == "PATCH" + + @pytest.mark.asyncio + async def test_authentication_error(self, github): + """Test handling of authentication errors.""" + with patch.object(github, "_make_request", new_callable=AsyncMock) as mock_request: + mock_request.side_effect = GitHubAPIError( + status_code=401, + message="Bad credentials", + ) + + with pytest.raises(GitHubAPIError) as exc_info: + await github.get_pull_request(42) + + assert exc_info.value.status_code == 401 + + @pytest.mark.asyncio + async def test_rate_limit_error(self, github): + """Test handling of rate limit errors.""" + with patch.object(github, "_make_request", new_callable=AsyncMock) as mock_request: + mock_request.side_effect = GitHubAPIError( + status_code=403, + message="API rate limit exceeded", + ) + + with pytest.raises(GitHubAPIError) as exc_info: + await github.get_pull_request(42) + + assert exc_info.value.status_code == 403 + + +class TestGitHubAPIError: + """Tests for GitHubAPIError exception.""" + + def test_error_message(self): + """Test error message formatting.""" + error = GitHubAPIError(status_code=404, message="Not Found") + + assert "404" in str(error) + assert "Not Found" in str(error) + + def test_error_with_details(self): + """Test error with additional details.""" + error = GitHubAPIError( + status_code=422, + message="Validation Failed", + details={"errors": [{"field": "title", "code": "missing"}]}, + ) + + assert error.details is not None + assert "title" in str(error.details) diff --git a/tests/unit/test_pr_repository.py b/tests/unit/test_pr_repository.py new file mode 100644 index 00000000..8af0b91f --- /dev/null +++ b/tests/unit/test_pr_repository.py @@ -0,0 +1,298 @@ +"""Unit tests for PRRepository (TDD - written before implementation).""" + +import pytest +from datetime import datetime + +from codeframe.persistence.repositories.pr_repository import PRRepository + + +class TestPRRepository: + """Tests for the PRRepository class.""" + + @pytest.fixture + def db(self, tmp_path): + """Create a test database with schema.""" + import sqlite3 + from codeframe.persistence.schema_manager import SchemaManager + + db_path = tmp_path / "test.db" + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + + # Create schema + schema_mgr = SchemaManager(conn) + schema_mgr.create_schema() + + return conn + + @pytest.fixture + def repo(self, db): + """Create PRRepository instance.""" + return PRRepository(sync_conn=db) + + @pytest.fixture + def project_id(self, db): + """Create a test project and return its ID.""" + cursor = db.cursor() + cursor.execute( + """ + INSERT INTO projects (name, description, workspace_path, status, phase) + VALUES ('Test Project', 'A test project', '/tmp/test', 'active', 'active') + """ + ) + db.commit() + return cursor.lastrowid + + @pytest.fixture + def issue_id(self, db, project_id): + """Create a test issue and return its ID.""" + cursor = db.cursor() + cursor.execute( + """ + INSERT INTO issues (project_id, issue_number, title, status, priority) + VALUES (?, 'ISSUE-001', 'Test Issue', 'pending', 1) + """, + (project_id,), + ) + db.commit() + return cursor.lastrowid + + def test_create_pr_returns_id(self, repo, project_id, issue_id): + """Test that create_pr returns the new PR ID.""" + pr_id = repo.create_pr( + project_id=project_id, + issue_id=issue_id, + branch_name="feature/test-branch", + title="Test PR", + body="This is a test PR", + base_branch="main", + head_branch="feature/test-branch", + ) + + assert pr_id is not None + assert isinstance(pr_id, int) + assert pr_id > 0 + + def test_create_pr_without_issue(self, repo, project_id): + """Test creating a PR without an associated issue.""" + pr_id = repo.create_pr( + project_id=project_id, + issue_id=None, + branch_name="feature/no-issue", + title="PR without issue", + body="No associated issue", + base_branch="main", + head_branch="feature/no-issue", + ) + + assert pr_id is not None + assert isinstance(pr_id, int) + + def test_get_pr_by_id(self, repo, project_id, issue_id): + """Test retrieving a PR by its ID.""" + pr_id = repo.create_pr( + project_id=project_id, + issue_id=issue_id, + branch_name="feature/test", + title="Test PR", + body="Test body", + base_branch="main", + head_branch="feature/test", + ) + + pr = repo.get_pr(pr_id) + + assert pr is not None + assert pr["id"] == pr_id + assert pr["project_id"] == project_id + assert pr["issue_id"] == issue_id + assert pr["branch_name"] == "feature/test" + assert pr["title"] == "Test PR" + assert pr["body"] == "Test body" + assert pr["base_branch"] == "main" + assert pr["head_branch"] == "feature/test" + assert pr["status"] == "open" + + def test_get_pr_not_found(self, repo): + """Test that get_pr returns None for non-existent PR.""" + pr = repo.get_pr(99999) + assert pr is None + + def test_update_pr_github_data(self, repo, project_id): + """Test updating PR with GitHub response data.""" + pr_id = repo.create_pr( + project_id=project_id, + issue_id=None, + branch_name="feature/gh-test", + title="GitHub Test PR", + body="Test", + base_branch="main", + head_branch="feature/gh-test", + ) + + github_created_at = datetime.now() + repo.update_pr_github_data( + pr_id=pr_id, + pr_number=42, + pr_url="https://github.com/owner/repo/pull/42", + github_created_at=github_created_at, + ) + + pr = repo.get_pr(pr_id) + assert pr["pr_number"] == 42 + assert pr["pr_url"] == "https://github.com/owner/repo/pull/42" + + def test_get_pr_by_number(self, repo, project_id): + """Test retrieving a PR by its GitHub PR number.""" + pr_id = repo.create_pr( + project_id=project_id, + issue_id=None, + branch_name="feature/numbered", + title="Numbered PR", + body="Test", + base_branch="main", + head_branch="feature/numbered", + ) + repo.update_pr_github_data( + pr_id=pr_id, + pr_number=123, + pr_url="https://github.com/owner/repo/pull/123", + github_created_at=datetime.now(), + ) + + pr = repo.get_pr_by_number(project_id, 123) + + assert pr is not None + assert pr["pr_number"] == 123 + assert pr["id"] == pr_id + + def test_get_pr_by_number_not_found(self, repo, project_id): + """Test that get_pr_by_number returns None for non-existent PR.""" + pr = repo.get_pr_by_number(project_id, 99999) + assert pr is None + + def test_list_prs_all(self, repo, project_id): + """Test listing all PRs for a project.""" + # Create multiple PRs + for i in range(3): + repo.create_pr( + project_id=project_id, + issue_id=None, + branch_name=f"feature/pr-{i}", + title=f"PR {i}", + body=f"Body {i}", + base_branch="main", + head_branch=f"feature/pr-{i}", + ) + + prs = repo.list_prs(project_id) + + assert len(prs) == 3 + + def test_list_prs_by_status(self, repo, project_id): + """Test listing PRs filtered by status.""" + # Create PRs with different statuses + pr_id1 = repo.create_pr( + project_id=project_id, + issue_id=None, + branch_name="feature/open", + title="Open PR", + body="Test", + base_branch="main", + head_branch="feature/open", + ) + + pr_id2 = repo.create_pr( + project_id=project_id, + issue_id=None, + branch_name="feature/merged", + title="Merged PR", + body="Test", + base_branch="main", + head_branch="feature/merged", + ) + repo.update_pr_status(pr_id2, "merged", merge_commit_sha="abc123") + + # List only open PRs + open_prs = repo.list_prs(project_id, status="open") + assert len(open_prs) == 1 + assert open_prs[0]["title"] == "Open PR" + + # List only merged PRs + merged_prs = repo.list_prs(project_id, status="merged") + assert len(merged_prs) == 1 + assert merged_prs[0]["title"] == "Merged PR" + + def test_update_pr_status_to_merged(self, repo, project_id): + """Test updating PR status to merged.""" + pr_id = repo.create_pr( + project_id=project_id, + issue_id=None, + branch_name="feature/merge", + title="To Merge", + body="Test", + base_branch="main", + head_branch="feature/merge", + ) + + repo.update_pr_status(pr_id, "merged", merge_commit_sha="def456") + + pr = repo.get_pr(pr_id) + assert pr["status"] == "merged" + assert pr["merge_commit_sha"] == "def456" + assert pr["merged_at"] is not None + + def test_update_pr_status_to_closed(self, repo, project_id): + """Test updating PR status to closed.""" + pr_id = repo.create_pr( + project_id=project_id, + issue_id=None, + branch_name="feature/close", + title="To Close", + body="Test", + base_branch="main", + head_branch="feature/close", + ) + + repo.update_pr_status(pr_id, "closed") + + pr = repo.get_pr(pr_id) + assert pr["status"] == "closed" + assert pr["closed_at"] is not None + + def test_get_pr_for_branch(self, repo, project_id): + """Test finding a PR by branch name.""" + repo.create_pr( + project_id=project_id, + issue_id=None, + branch_name="feature/unique-branch", + title="Unique Branch PR", + body="Test", + base_branch="main", + head_branch="feature/unique-branch", + ) + + pr = repo.get_pr_for_branch(project_id, "feature/unique-branch") + + assert pr is not None + assert pr["branch_name"] == "feature/unique-branch" + + def test_get_pr_for_branch_not_found(self, repo, project_id): + """Test that get_pr_for_branch returns None for non-existent branch.""" + pr = repo.get_pr_for_branch(project_id, "feature/nonexistent") + assert pr is None + + def test_create_pr_stores_created_at(self, repo, project_id): + """Test that create_pr automatically sets created_at timestamp.""" + pr_id = repo.create_pr( + project_id=project_id, + issue_id=None, + branch_name="feature/timestamp", + title="Timestamp PR", + body="Test", + base_branch="main", + head_branch="feature/timestamp", + ) + + pr = repo.get_pr(pr_id) + assert pr["created_at"] is not None diff --git a/tests/unit/test_pr_router.py b/tests/unit/test_pr_router.py new file mode 100644 index 00000000..adfd5ed6 --- /dev/null +++ b/tests/unit/test_pr_router.py @@ -0,0 +1,339 @@ +"""Unit tests for PR router (TDD - written before implementation).""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from datetime import datetime, UTC + +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from codeframe.ui.routers.prs import router +from codeframe.git.github_integration import PRDetails, MergeResult, GitHubAPIError + + +@pytest.fixture +def app(): + """Create test FastAPI app with PR router.""" + app = FastAPI() + app.include_router(router) + return app + + +@pytest.fixture +def mock_db(): + """Create mock database with PR repository.""" + db = MagicMock() + db.get_project.return_value = {"id": 1, "name": "Test Project"} + db.user_has_project_access.return_value = True + + # Mock PR repository + db.pull_requests = MagicMock() + db.pull_requests.create_pr.return_value = 1 + db.pull_requests.get_pr.return_value = { + "id": 1, + "project_id": 1, + "pr_number": 42, + "pr_url": "https://github.com/owner/repo/pull/42", + "title": "Test PR", + "body": "Test body", + "status": "open", + "branch_name": "feature/test", + "base_branch": "main", + "head_branch": "feature/test", + "created_at": datetime.now(UTC).isoformat(), + } + db.pull_requests.list_prs.return_value = [ + { + "id": 1, + "pr_number": 42, + "title": "PR 1", + "status": "open", + }, + { + "id": 2, + "pr_number": 43, + "title": "PR 2", + "status": "open", + }, + ] + db.pull_requests.get_pr_by_number.return_value = { + "id": 1, + "pr_number": 42, + "title": "Test PR", + "status": "open", + } + + return db + + +@pytest.fixture +def mock_user(): + """Create mock authenticated user.""" + user = MagicMock() + user.id = 1 + user.email = "test@example.com" + return user + + +@pytest.fixture +def mock_github_config(): + """Mock GlobalConfig with GitHub credentials.""" + config = MagicMock() + config.github_token = "ghp_test_token" + config.github_repo = "owner/test-repo" + return config + + +@pytest.fixture +def client(app, mock_db, mock_user, mock_github_config): + """Create test client with dependencies overridden.""" + from codeframe.ui.dependencies import get_db + from codeframe.auth import get_current_user + + # Override dependencies + app.dependency_overrides[get_db] = lambda: mock_db + app.dependency_overrides[get_current_user] = lambda: mock_user + + # Mock the config + with patch("codeframe.ui.routers.prs.get_global_config", return_value=mock_github_config): + yield TestClient(app) + + +class TestCreatePR: + """Tests for POST /api/projects/{project_id}/prs.""" + + def test_create_pr_success(self, client, mock_db): + """Test successful PR creation.""" + mock_pr_details = PRDetails( + number=42, + url="https://github.com/owner/repo/pull/42", + state="open", + title="Test PR", + body="Test body", + created_at=datetime.now(UTC), + merged_at=None, + head_branch="feature/test", + base_branch="main", + ) + + with patch("codeframe.ui.routers.prs.GitHubIntegration") as MockGH: + mock_gh_instance = AsyncMock() + mock_gh_instance.create_pull_request.return_value = mock_pr_details + MockGH.return_value = mock_gh_instance + + response = client.post( + "/api/projects/1/prs", + json={ + "branch": "feature/test", + "title": "Test PR", + "body": "Test body", + "base": "main", + }, + ) + + assert response.status_code == 201 + data = response.json() + assert data["pr_number"] == 42 + assert data["pr_url"] == "https://github.com/owner/repo/pull/42" + assert data["status"] == "open" + + def test_create_pr_project_not_found(self, client, mock_db): + """Test PR creation with non-existent project.""" + mock_db.get_project.return_value = None + + response = client.post( + "/api/projects/999/prs", + json={ + "branch": "feature/test", + "title": "Test PR", + "body": "Test body", + }, + ) + + assert response.status_code == 404 + + def test_create_pr_access_denied(self, client, mock_db): + """Test PR creation without project access.""" + mock_db.user_has_project_access.return_value = False + + response = client.post( + "/api/projects/1/prs", + json={ + "branch": "feature/test", + "title": "Test PR", + "body": "Test body", + }, + ) + + assert response.status_code == 403 + + def test_create_pr_github_not_configured(self, app, mock_db, mock_user): + """Test PR creation when GitHub is not configured.""" + from codeframe.ui.dependencies import get_db + from codeframe.auth import get_current_user + + app.dependency_overrides[get_db] = lambda: mock_db + app.dependency_overrides[get_current_user] = lambda: mock_user + + # Mock missing GitHub config + mock_config = MagicMock() + mock_config.github_token = None + mock_config.github_repo = None + + with patch("codeframe.ui.routers.prs.get_global_config", return_value=mock_config): + client = TestClient(app) + response = client.post( + "/api/projects/1/prs", + json={ + "branch": "feature/test", + "title": "Test PR", + "body": "Test body", + }, + ) + + assert response.status_code == 400 + assert "GitHub" in response.json()["detail"] + + def test_create_pr_github_api_error(self, client, mock_db): + """Test PR creation with GitHub API error.""" + with patch("codeframe.ui.routers.prs.GitHubIntegration") as MockGH: + mock_gh_instance = AsyncMock() + mock_gh_instance.create_pull_request.side_effect = GitHubAPIError( + status_code=422, + message="Validation Failed", + ) + MockGH.return_value = mock_gh_instance + + response = client.post( + "/api/projects/1/prs", + json={ + "branch": "feature/test", + "title": "Test PR", + "body": "Test body", + }, + ) + + assert response.status_code == 422 + + +class TestListPRs: + """Tests for GET /api/projects/{project_id}/prs.""" + + def test_list_prs_success(self, client, mock_db): + """Test listing PRs successfully.""" + response = client.get("/api/projects/1/prs") + + assert response.status_code == 200 + data = response.json() + assert "prs" in data + assert len(data["prs"]) == 2 + assert data["total"] == 2 + + def test_list_prs_with_status_filter(self, client, mock_db): + """Test listing PRs with status filter.""" + mock_db.pull_requests.list_prs.return_value = [ + {"id": 1, "pr_number": 42, "title": "Open PR", "status": "open"} + ] + + response = client.get("/api/projects/1/prs?status=open") + + assert response.status_code == 200 + mock_db.pull_requests.list_prs.assert_called_with(1, status="open") + + def test_list_prs_project_not_found(self, client, mock_db): + """Test listing PRs for non-existent project.""" + mock_db.get_project.return_value = None + + response = client.get("/api/projects/999/prs") + + assert response.status_code == 404 + + +class TestGetPR: + """Tests for GET /api/projects/{project_id}/prs/{pr_number}.""" + + def test_get_pr_success(self, client, mock_db): + """Test getting PR details successfully.""" + response = client.get("/api/projects/1/prs/42") + + assert response.status_code == 200 + data = response.json() + assert data["pr_number"] == 42 + assert data["title"] == "Test PR" + + def test_get_pr_not_found(self, client, mock_db): + """Test getting non-existent PR.""" + mock_db.pull_requests.get_pr_by_number.return_value = None + + response = client.get("/api/projects/1/prs/999") + + assert response.status_code == 404 + + +class TestMergePR: + """Tests for POST /api/projects/{project_id}/prs/{pr_number}/merge.""" + + def test_merge_pr_success(self, client, mock_db): + """Test successful PR merge.""" + mock_merge_result = MergeResult( + sha="abc123def456", + merged=True, + message="Pull Request successfully merged", + ) + + with patch("codeframe.ui.routers.prs.GitHubIntegration") as MockGH: + mock_gh_instance = AsyncMock() + mock_gh_instance.merge_pull_request.return_value = mock_merge_result + MockGH.return_value = mock_gh_instance + + response = client.post( + "/api/projects/1/prs/42/merge", + json={"method": "squash"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["merged"] is True + assert data["merge_commit_sha"] == "abc123def456" + + def test_merge_pr_not_mergeable(self, client, mock_db): + """Test merging non-mergeable PR.""" + with patch("codeframe.ui.routers.prs.GitHubIntegration") as MockGH: + mock_gh_instance = AsyncMock() + mock_gh_instance.merge_pull_request.side_effect = GitHubAPIError( + status_code=405, + message="Pull Request is not mergeable", + ) + MockGH.return_value = mock_gh_instance + + response = client.post( + "/api/projects/1/prs/42/merge", + json={"method": "squash"}, + ) + + assert response.status_code == 422 + + +class TestClosePR: + """Tests for POST /api/projects/{project_id}/prs/{pr_number}/close.""" + + def test_close_pr_success(self, client, mock_db): + """Test closing PR successfully.""" + with patch("codeframe.ui.routers.prs.GitHubIntegration") as MockGH: + mock_gh_instance = AsyncMock() + mock_gh_instance.close_pull_request.return_value = True + MockGH.return_value = mock_gh_instance + + response = client.post("/api/projects/1/prs/42/close") + + assert response.status_code == 200 + data = response.json() + assert data["closed"] is True + + def test_close_pr_not_found(self, client, mock_db): + """Test closing non-existent PR.""" + mock_db.pull_requests.get_pr_by_number.return_value = None + + response = client.post("/api/projects/1/prs/999/close") + + assert response.status_code == 404