diff --git a/comfy_cli/cmdline.py b/comfy_cli/cmdline.py index 247ba95a..ec4439f8 100644 --- a/comfy_cli/cmdline.py +++ b/comfy_cli/cmdline.py @@ -247,6 +247,13 @@ def install( Optional[str], typer.Option(help="Specify commit hash for ComfyUI-Manager"), ] = None, + pr: Annotated[ + Optional[str], + typer.Option( + show_default=False, + help="Install from a specific PR. Supports formats: username:branch, #123, or PR URL", + ), + ] = None, ): check_for_updates() checker = EnvChecker() @@ -338,6 +345,10 @@ def install( ) raise typer.Exit(code=1) + if pr and version not in {None, "nightly"} or commit: + rprint("--pr cannot be used with --version or --commit") + raise typer.Exit(code=1) + install_inner.execute( url, manager_url, @@ -353,6 +364,7 @@ def install( skip_requirement=skip_requirement, fast_deps=fast_deps, manager_commit=manager_commit, + pr=pr, ) rprint(f"ComfyUI is installed at: {comfy_path}") diff --git a/comfy_cli/command/github/pr_info.py b/comfy_cli/command/github/pr_info.py new file mode 100644 index 00000000..d7a4b7f9 --- /dev/null +++ b/comfy_cli/command/github/pr_info.py @@ -0,0 +1,16 @@ +from typing import NamedTuple + + +class PRInfo(NamedTuple): + number: int + head_repo_url: str + head_branch: str + base_repo_url: str + base_branch: str + title: str + user: str + mergeable: bool + + @property + def is_fork(self) -> bool: + return self.head_repo_url != self.base_repo_url diff --git a/comfy_cli/command/install.py b/comfy_cli/command/install.py index abb26229..e208499b 100755 --- a/comfy_cli/command/install.py +++ b/comfy_cli/command/install.py @@ -3,6 +3,7 @@ import subprocess import sys from typing import Dict, List, Optional, TypedDict +from urllib.parse import urlparse import requests import semver @@ -13,8 +14,9 @@ from comfy_cli import constants, ui, utils from comfy_cli.command.custom_nodes.command import update_node_id_cache +from comfy_cli.command.github.pr_info import PRInfo from comfy_cli.constants import GPU_OPTION -from comfy_cli.git_utils import git_checkout_tag +from comfy_cli.git_utils import checkout_pr, git_checkout_tag from comfy_cli.uv import DependencyCompiler from comfy_cli.workspace_manager import WorkspaceManager, check_comfy_repo @@ -175,9 +177,15 @@ def execute( skip_torch_or_directml: bool = False, skip_requirement: bool = False, fast_deps: bool = False, + pr: Optional[str] = None, *args, **kwargs, ): + # Install ComfyUI from a given PR reference. + if pr: + url = handle_pr_checkout(pr, comfy_path) + version = "nightly" + """ Install ComfyUI from a given URL. """ @@ -272,6 +280,66 @@ def execute( rprint("") +def handle_pr_checkout(pr_ref: str, comfy_path: str) -> str: + try: + repo_owner, repo_name, pr_number = parse_pr_reference(pr_ref) + except ValueError as e: + rprint(f"[bold red]Error parsing PR reference: {e}[/bold red]") + raise typer.Exit(code=1) + + try: + if pr_number: + pr_info = fetch_pr_info(repo_owner, repo_name, pr_number) + else: + username, branch = pr_ref.split(":", 1) + pr_info = find_pr_by_branch("comfyanonymous", "ComfyUI", username, branch) + + if not pr_info: + rprint(f"[bold red]PR not found: {pr_ref}[/bold red]") + raise typer.Exit(code=1) + + except Exception as e: + rprint(f"[bold red]Error fetching PR information: {e}[/bold red]") + raise typer.Exit(code=1) + + console.print( + Panel( + f"[bold]PR #{pr_info.number}[/bold]: {pr_info.title}\n" + f"[yellow]Author[/yellow]: {pr_info.user}\n" + f"[yellow]Branch[/yellow]: {pr_info.head_branch}\n" + f"[yellow]Source[/yellow]: {pr_info.head_repo_url}\n" + f"[yellow]Mergeable[/yellow]: {'✓' if pr_info.mergeable else '✗'}", + title="[bold blue]Pull Request Information[/bold blue]", + border_style="blue", + ) + ) + + if not workspace_manager.skip_prompting: + if not ui.prompt_confirm_action(f"Install ComfyUI from PR #{pr_info.number}?", True): + rprint("Aborting...") + raise typer.Exit(code=1) + + parent_path = os.path.abspath(os.path.join(comfy_path, "..")) + + if not os.path.exists(parent_path): + os.makedirs(parent_path, exist_ok=True) + + if not os.path.exists(comfy_path): + rprint(f"Cloning base repository to {comfy_path}...") + clone_comfyui(url=pr_info.base_repo_url, repo_dir=comfy_path) + + rprint(f"Checking out PR #{pr_info.number}: {pr_info.title}") + success = checkout_pr(comfy_path, pr_info) + if not success: + rprint("[bold red]Failed to checkout PR[/bold red]") + raise typer.Exit(code=1) + + rprint(f"[bold green]✓ Successfully checked out PR #{pr_info.number}[/bold green]") + rprint(f"[bold yellow]Note:[/bold yellow] You are now on branch pr-{pr_info.number}") + + return pr_info.base_repo_url + + def validate_version(version: str) -> Optional[str]: """ Validates the version string as 'latest', 'nightly', or a semantically version number. @@ -306,6 +374,21 @@ class GitHubRateLimitError(Exception): """Raised when GitHub API rate limit is exceeded""" +def handle_github_rate_limit(response): + # Check rate limit headers + remaining = int(response.headers.get("x-ratelimit-remaining", 0)) + if remaining == 0: + reset_time = int(response.headers.get("x-ratelimit-reset", 0)) + message = f"Primary rate limit from Github exceeded! Please retry after: {reset_time})" + raise GitHubRateLimitError(message) + + if "retry-after" in response.headers: + wait_seconds = int(response.headers["retry-after"]) + message = f"Rate limit from Github exceeded! Please wait {wait_seconds} seconds before retrying." + rprint(f"[yellow]{message}[/yellow]") + raise GitHubRateLimitError(message) + + def fetch_github_releases(repo_owner: str, repo_name: str) -> List[Dict[str, str]]: """ Fetch the list of releases from the GitHub API. @@ -321,18 +404,7 @@ def fetch_github_releases(repo_owner: str, repo_name: str) -> List[Dict[str, str # Handle rate limiting if response.status_code in (403, 429): - # Check rate limit headers - remaining = int(response.headers.get("x-ratelimit-remaining", 0)) - if remaining == 0: - reset_time = int(response.headers.get("x-ratelimit-reset", 0)) - message = f"Primary rate limit from Github exceeded! Please retry after: {reset_time})" - raise GitHubRateLimitError(message) - - if "retry-after" in response.headers: - wait_seconds = int(response.headers["retry-after"]) - message = f"Rate limit from Github exceeded! Please wait {wait_seconds} seconds before retrying." - rprint(f"[yellow]{message}[/yellow]") - raise GitHubRateLimitError(message) + handle_github_rate_limit(response) response.raise_for_status() return response.json() @@ -459,3 +531,103 @@ def get_latest_release(repo_owner: str, repo_name: str) -> Optional[GithubReleas except requests.RequestException as e: rprint(f"Error fetching latest release: {e}") return None + + +def parse_pr_reference(pr_ref: str) -> tuple[str, str, Optional[int]]: + """ + support formats: + - username:branch-name + - #123 + - https://github.com/comfyanonymous/ComfyUI/pull/123 + + Returns: + (repo_owner, repo_name, pr_number) + """ + pr_ref = pr_ref.strip() + + if pr_ref.startswith("https://github.com/"): + parsed = urlparse(pr_ref) + if "/pull/" in parsed.path: + path_parts = parsed.path.strip("/").split("/") + if len(path_parts) >= 4: + repo_owner = path_parts[0] + repo_name = path_parts[1] + pr_number = int(path_parts[3]) + return repo_owner, repo_name, pr_number + + elif pr_ref.startswith("#"): + pr_number = int(pr_ref[1:]) + return "comfyanonymous", "ComfyUI", pr_number + + elif ":" in pr_ref: + username, branch = pr_ref.split(":", 1) + return username, "ComfyUI", None + + else: + raise ValueError(f"Invalid PR reference format: {pr_ref}") + + +def fetch_pr_info(repo_owner: str, repo_name: str, pr_number: int) -> PRInfo: + url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/pulls/{pr_number}" + + headers = {} + if github_token := os.getenv("GITHUB_TOKEN"): + headers["Authorization"] = f"Bearer {github_token}" + + try: + response = requests.get(url, headers=headers, timeout=10) + + if response is None: + raise Exception(f"Failed to fetch PR #{pr_number}: No response from GitHub API") + + if response.status_code in (403, 429): + handle_github_rate_limit(response) + + response.raise_for_status() + data = response.json() + + return PRInfo( + number=data["number"], + head_repo_url=data["head"]["repo"]["clone_url"], + head_branch=data["head"]["ref"], + base_repo_url=data["base"]["repo"]["clone_url"], + base_branch=data["base"]["ref"], + title=data["title"], + user=data["head"]["repo"]["owner"]["login"], + mergeable=data.get("mergeable", True), + ) + + except requests.RequestException as e: + raise Exception(f"Failed to fetch PR #{pr_number}: {e}") + + +def find_pr_by_branch(repo_owner: str, repo_name: str, username: str, branch: str) -> Optional[PRInfo]: + url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/pulls" + params = {"head": f"{username}:{branch}", "state": "open"} + + headers = {} + if github_token := os.getenv("GITHUB_TOKEN"): + headers["Authorization"] = f"Bearer {github_token}" + + try: + response = requests.get(url, headers=headers, params=params, timeout=10) + response.raise_for_status() + data = response.json() + + if data: + pr_data = data[0] + return PRInfo( + number=pr_data["number"], + head_repo_url=pr_data["head"]["repo"]["clone_url"], + head_branch=pr_data["head"]["ref"], + base_repo_url=pr_data["base"]["repo"]["clone_url"], + base_branch=pr_data["base"]["ref"], + title=pr_data["title"], + user=pr_data["head"]["repo"]["owner"]["login"], + mergeable=pr_data.get("mergeable", True), + ) + + return None + + except requests.RequestException: + return None diff --git a/comfy_cli/git_utils.py b/comfy_cli/git_utils.py index 1acdad5a..54d4fe08 100644 --- a/comfy_cli/git_utils.py +++ b/comfy_cli/git_utils.py @@ -5,9 +5,25 @@ from rich.panel import Panel from rich.text import Text +from comfy_cli.command.github.pr_info import PRInfo + console = Console() +def sanitize_for_local_branch(branch_name: str) -> str: + if not branch_name: + return "unknown" + + sanitized = branch_name.replace("/", "-") + + while "--" in sanitized: + sanitized = sanitized.replace("--", "-") + + sanitized = sanitized.strip("-") + + return sanitized or "unknown" + + def git_checkout_tag(repo_path: str, tag: str) -> bool: """ Checkout a specific Git tag in the given repository. @@ -56,3 +72,79 @@ def git_checkout_tag(repo_path: str, tag: str) -> bool: finally: # Ensure we always return to the original directory os.chdir(original_dir) + + +def checkout_pr(repo_path: str, pr_info: PRInfo) -> bool: + original_dir = os.getcwd() + + try: + os.chdir(repo_path) + + if pr_info.is_fork: + remote_name = f"pr-{pr_info.number}-{pr_info.user}" + + result = subprocess.run(["git", "remote", "get-url", remote_name], capture_output=True, text=True) + + if result.returncode != 0: + subprocess.run( + ["git", "remote", "add", remote_name, pr_info.head_repo_url], + check=True, + capture_output=True, + text=True, + ) + + subprocess.run( + ["git", "fetch", remote_name, pr_info.head_branch], check=True, capture_output=True, text=True + ) + + # fix: "feature/add-support" -> "pr-123-feature-add-support" + sanitized_branch = sanitize_for_local_branch(pr_info.head_branch) + local_branch = f"pr-{pr_info.number}-{sanitized_branch}" + + subprocess.run( + ["git", "checkout", "-B", local_branch, f"{remote_name}/{pr_info.head_branch}"], + check=True, + capture_output=True, + text=True, + ) + + else: + subprocess.run(["git", "fetch", "origin", pr_info.head_branch], check=True, capture_output=True, text=True) + + sanitized_branch = sanitize_for_local_branch(pr_info.head_branch) + local_branch = f"pr-{pr_info.number}-{sanitized_branch}" + + subprocess.run( + ["git", "checkout", "-B", local_branch, f"origin/{pr_info.head_branch}"], + check=True, + capture_output=True, + text=True, + ) + + console.print(f"[bold green]Successfully checked out PR #{pr_info.number}: {pr_info.title}[/bold green]") + console.print(f"[bold yellow]Local branch:[/bold yellow] {local_branch}") + return True + + except subprocess.CalledProcessError as e: + error_message = Text() + error_message.append("Git PR Checkout Error", style="bold red on white") + error_message.append(f"\n\nFailed to checkout PR #{pr_info.number}", style="bold yellow") + error_message.append(f"\nTitle: {pr_info.title}", style="italic") + error_message.append(f"\nBranch: {pr_info.head_branch}", style="italic") + + if e.stderr: + error_message.append("\n\nError output:", style="bold red") + error_message.append(f"\n{e.stderr}", style="italic yellow") + + console.print( + Panel( + error_message, + title="[bold white on red]PR Checkout Failed[/bold white on red]", + border_style="red", + expand=False, + ) + ) + return False + + finally: + os.chdir(original_dir) diff --git a/tests/comfy_cli/command/github/test_pr.py b/tests/comfy_cli/command/github/test_pr.py new file mode 100644 index 00000000..2802922e --- /dev/null +++ b/tests/comfy_cli/command/github/test_pr.py @@ -0,0 +1,376 @@ +import subprocess +from unittest.mock import Mock, patch + +import pytest +import requests +from typer.testing import CliRunner + +from comfy_cli.cmdline import app, g_exclusivity, g_gpu_exclusivity +from comfy_cli.command.install import PRInfo, fetch_pr_info, find_pr_by_branch, handle_pr_checkout, parse_pr_reference +from comfy_cli.git_utils import checkout_pr + + +@pytest.fixture(scope="function") +def runner(): + g_exclusivity.reset_for_testing() + g_gpu_exclusivity.reset_for_testing() + return CliRunner() + + +@pytest.fixture +def sample_pr_info(): + return PRInfo( + number=123, + head_repo_url="https://github.com/jtydhr88/ComfyUI.git", + head_branch="load-3d-nodes", + base_repo_url="https://github.com/comfyanonymous/ComfyUI.git", + base_branch="master", + title="Add 3D node loading support", + user="jtydhr88", + mergeable=True, + ) + + +class TestPRReferenceParsing: + def test_parse_pr_number_format(self): + """Test parsing #123 format""" + repo_owner, repo_name, pr_number = parse_pr_reference("#123") + assert repo_owner == "comfyanonymous" + assert repo_name == "ComfyUI" + assert pr_number == 123 + + def test_parse_user_branch_format(self): + """Test parsing username:branch format""" + repo_owner, repo_name, pr_number = parse_pr_reference("jtydhr88:load-3d-nodes") + assert repo_owner == "jtydhr88" + assert repo_name == "ComfyUI" + assert pr_number is None + + def test_parse_github_url_format(self): + """Test parsing full GitHub PR URL""" + url = "https://github.com/comfyanonymous/ComfyUI/pull/456" + repo_owner, repo_name, pr_number = parse_pr_reference(url) + assert repo_owner == "comfyanonymous" + assert repo_name == "ComfyUI" + assert pr_number == 456 + + def test_parse_invalid_format(self): + """Test parsing invalid format raises ValueError""" + with pytest.raises(ValueError, match="Invalid PR reference format"): + parse_pr_reference("invalid-format") + + def test_parse_empty_string(self): + """Test parsing empty string raises ValueError""" + with pytest.raises(ValueError): + parse_pr_reference("") + + +class TestGitHubAPIIntegration: + """Test GitHub API integration""" + + @patch("requests.get") + def test_fetch_pr_info_success(self, mock_get, sample_pr_info): + """Test successful PR info fetching""" + # Mock API response + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "number": 123, + "title": "Add 3D node loading support", + "head": { + "repo": {"clone_url": "https://github.com/jtydhr88/ComfyUI.git", "owner": {"login": "jtydhr88"}}, + "ref": "load-3d-nodes", + }, + "base": {"repo": {"clone_url": "https://github.com/comfyanonymous/ComfyUI.git"}, "ref": "master"}, + "mergeable": True, + } + mock_get.return_value = mock_response + + result = fetch_pr_info("comfyanonymous", "ComfyUI", 123) + + assert result.number == 123 + assert result.title == "Add 3D node loading support" + assert result.user == "jtydhr88" + assert result.head_branch == "load-3d-nodes" + assert result.mergeable is True + + @patch("requests.get") + def test_fetch_pr_info_not_found(self, mock_get): + """Test PR not found (404)""" + mock_response = Mock() + mock_response.status_code = 404 + mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found") + mock_get.return_value = mock_response + + with pytest.raises(Exception, match="Failed to fetch PR"): + fetch_pr_info("comfyanonymous", "ComfyUI", 999) + + @patch("requests.get") + def test_fetch_pr_info_rate_limit(self, mock_get): + """Test GitHub API rate limit handling""" + mock_response = Mock() + mock_response.status_code = 403 + mock_response.headers = {"x-ratelimit-remaining": "0"} + mock_get.return_value = mock_response + + with pytest.raises(Exception, match="Primary rate limit from Github exceeded!"): + fetch_pr_info("comfyanonymous", "ComfyUI", 123) + + @patch("requests.get") + def test_find_pr_by_branch_success(self, mock_get): + """Test finding PR by branch name""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = [ + { + "number": 123, + "title": "Add 3D node loading support", + "head": { + "repo": {"clone_url": "https://github.com/jtydhr88/ComfyUI.git", "owner": {"login": "jtydhr88"}}, + "ref": "load-3d-nodes", + }, + "base": {"repo": {"clone_url": "https://github.com/comfyanonymous/ComfyUI.git"}, "ref": "master"}, + "mergeable": True, + } + ] + mock_get.return_value = mock_response + + result = find_pr_by_branch("comfyanonymous", "ComfyUI", "jtydhr88", "load-3d-nodes") + + assert result is not None + assert result.number == 123 + assert result.user == "jtydhr88" + + @patch("requests.get") + def test_find_pr_by_branch_not_found(self, mock_get): + """Test branch not found returns None""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = [] # Empty list + mock_get.return_value = mock_response + + result = find_pr_by_branch("comfyanonymous", "ComfyUI", "user", "nonexistent-branch") + assert result is None + + +class TestGitOperations: + """Test Git operations for PR checkout""" + + @patch("subprocess.run") + @patch("os.chdir") + @patch("os.getcwd") + def test_checkout_pr_fork_success(self, mock_getcwd, mock_chdir, mock_subprocess, sample_pr_info): + """Test successful checkout of PR from fork""" + mock_getcwd.return_value = "/original/dir" + + mock_subprocess.side_effect = [ + subprocess.CompletedProcess([], 1), + subprocess.CompletedProcess([], 0), + subprocess.CompletedProcess([], 0), + subprocess.CompletedProcess([], 0), + ] + + result = checkout_pr("/repo/path", sample_pr_info) + + assert result is True + assert mock_subprocess.call_count == 4 + + calls = mock_subprocess.call_args_list + assert "git" in calls[0][0][0] + assert "remote" in calls[1][0][0] + assert "fetch" in calls[2][0][0] + assert "checkout" in calls[3][0][0] + + @patch("subprocess.run") + @patch("os.chdir") + @patch("os.getcwd") + def test_checkout_pr_non_fork_success(self, mock_getcwd, mock_chdir, mock_subprocess): + """Test successful checkout of PR from same repo""" + mock_getcwd.return_value = "/original/dir" + + pr_info = PRInfo( + number=123, + head_repo_url="https://github.com/comfyanonymous/ComfyUI.git", + head_branch="feature-branch", + base_repo_url="https://github.com/comfyanonymous/ComfyUI.git", + base_branch="master", + title="Feature branch", + user="comfyanonymous", + mergeable=True, + ) + + mock_subprocess.side_effect = [ + subprocess.CompletedProcess([], 0), # fetch succeeds + subprocess.CompletedProcess([], 0), # checkout succeeds + ] + + result = checkout_pr("/repo/path", pr_info) + + assert result is True + assert mock_subprocess.call_count == 2 + + @patch("subprocess.run") + @patch("os.chdir") + @patch("os.getcwd") + def test_checkout_pr_git_failure(self, mock_getcwd, mock_chdir, mock_subprocess, sample_pr_info): + """Test Git operation failure""" + mock_getcwd.return_value = "/original/dir" + + error = subprocess.CalledProcessError(1, "git", stderr="Permission denied") + mock_subprocess.side_effect = error + + result = checkout_pr("/repo/path", sample_pr_info) + + assert result is False + + +class TestHandlePRCheckout: + """Test the main PR checkout handler""" + + @patch("comfy_cli.command.install.parse_pr_reference") + @patch("comfy_cli.command.install.fetch_pr_info") + @patch("comfy_cli.command.install.checkout_pr") + @patch("comfy_cli.command.install.clone_comfyui") + @patch("comfy_cli.ui.prompt_confirm_action") + @patch("os.path.exists") + @patch("os.makedirs") + def test_handle_pr_checkout_success( + self, + mock_makedirs, + mock_exists, + mock_confirm, + mock_clone, + mock_checkout, + mock_fetch, + mock_parse, + sample_pr_info, + ): + """Test successful PR checkout handling""" + mock_parse.return_value = ("jtydhr88", "ComfyUI", 123) + mock_fetch.return_value = sample_pr_info + mock_exists.side_effect = [True, False] # Parent exists, repo doesn't + mock_confirm.return_value = True + mock_checkout.return_value = True + + with patch("comfy_cli.command.install.workspace_manager") as mock_ws: + mock_ws.skip_prompting = False + + result = handle_pr_checkout("jtydhr88:load-3d-nodes", "/path/to/comfy") + + assert result == "https://github.com/comfyanonymous/ComfyUI.git" + mock_clone.assert_called_once() + mock_checkout.assert_called_once() + + +class TestCommandLineIntegration: + """Test command line integration""" + + @patch("comfy_cli.command.install.execute") + def test_install_with_pr_parameter(self, mock_execute, runner): + """Test install command with --pr parameter""" + result = runner.invoke(app, ["install", "--pr", "jtydhr88:load-3d-nodes", "--nvidia", "--skip-prompt"]) + + assert "Invalid PR reference format" not in result.stdout + + if mock_execute.called: + call_args = mock_execute.call_args + assert "pr" in call_args.kwargs or len(call_args.args) > 8 + + def test_pr_and_version_conflict(self, runner): + """Test that --pr conflicts with --version""" + result = runner.invoke(app, ["install", "--pr", "#123", "--version", "1.0.0"]) + + assert result.exit_code != 0 + + def test_pr_and_commit_conflict(self, runner): + """Test that --pr conflicts with --commit""" + result = runner.invoke(app, ["install", "--pr", "#123", "--version", "nightly", "--commit", "abc123"]) + + assert result.exit_code != 0 + + +class TestPRInfoDataClass: + """Test PRInfo data class""" + + def test_pr_info_is_fork_true(self): + """Test is_fork property returns True for fork""" + pr_info = PRInfo( + number=123, + head_repo_url="https://github.com/user/ComfyUI.git", + head_branch="branch", + base_repo_url="https://github.com/comfyanonymous/ComfyUI.git", + base_branch="master", + title="Title", + user="user", + mergeable=True, + ) + assert pr_info.is_fork is True + + def test_pr_info_is_fork_false(self): + """Test is_fork property returns False for same repo""" + pr_info = PRInfo( + number=123, + head_repo_url="https://github.com/comfyanonymous/ComfyUI.git", + head_branch="feature", + base_repo_url="https://github.com/comfyanonymous/ComfyUI.git", + base_branch="master", + title="Title", + user="comfyanonymous", + mergeable=True, + ) + assert pr_info.is_fork is False + + +class TestEdgeCases: + """Test edge cases and error conditions""" + + def test_parse_pr_reference_whitespace(self): + """Test parsing with whitespace""" + repo_owner, repo_name, pr_number = parse_pr_reference(" #123 ") + assert repo_owner == "comfyanonymous" + assert repo_name == "ComfyUI" + assert pr_number == 123 + + @patch("requests.get") + def test_fetch_pr_info_with_github_token(self, mock_get): + """Test PR fetching with GitHub token""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "number": 123, + "title": "Test", + "head": {"repo": {"clone_url": "url", "owner": {"login": "user"}}, "ref": "branch"}, + "base": {"repo": {"clone_url": "base_url"}, "ref": "master"}, + "mergeable": True, + } + mock_get.return_value = mock_response + + with patch.dict("os.environ", {"GITHUB_TOKEN": "test-token"}): + fetch_pr_info("owner", "repo", 123) + + call_args = mock_get.call_args + headers = call_args.kwargs.get("headers", {}) + assert "Authorization" in headers + assert headers["Authorization"] == "Bearer test-token" + + @patch("subprocess.run") + @patch("os.chdir") + @patch("os.getcwd") + def test_checkout_pr_remote_already_exists(self, mock_getcwd, mock_chdir, mock_subprocess, sample_pr_info): + """Test checkout when remote already exists""" + mock_getcwd.return_value = "/dir" + + mock_subprocess.side_effect = [ + subprocess.CompletedProcess([], 0), + subprocess.CompletedProcess([], 0), + subprocess.CompletedProcess([], 0), + ] + + result = checkout_pr("/repo", sample_pr_info) + + assert result is True + assert mock_subprocess.call_count == 3 + + +if __name__ == "__main__": + pytest.main([__file__])