Skip to content

Commit 8ee6f42

Browse files
authored
add PR support (#287)
1 parent be0c4c5 commit 8ee6f42

File tree

5 files changed

+681
-13
lines changed

5 files changed

+681
-13
lines changed

comfy_cli/cmdline.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,13 @@ def install(
247247
Optional[str],
248248
typer.Option(help="Specify commit hash for ComfyUI-Manager"),
249249
] = None,
250+
pr: Annotated[
251+
Optional[str],
252+
typer.Option(
253+
show_default=False,
254+
help="Install from a specific PR. Supports formats: username:branch, #123, or PR URL",
255+
),
256+
] = None,
250257
):
251258
check_for_updates()
252259
checker = EnvChecker()
@@ -338,6 +345,10 @@ def install(
338345
)
339346
raise typer.Exit(code=1)
340347

348+
if pr and version not in {None, "nightly"} or commit:
349+
rprint("--pr cannot be used with --version or --commit")
350+
raise typer.Exit(code=1)
351+
341352
install_inner.execute(
342353
url,
343354
manager_url,
@@ -353,6 +364,7 @@ def install(
353364
skip_requirement=skip_requirement,
354365
fast_deps=fast_deps,
355366
manager_commit=manager_commit,
367+
pr=pr,
356368
)
357369

358370
rprint(f"ComfyUI is installed at: {comfy_path}")
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from typing import NamedTuple
2+
3+
4+
class PRInfo(NamedTuple):
5+
number: int
6+
head_repo_url: str
7+
head_branch: str
8+
base_repo_url: str
9+
base_branch: str
10+
title: str
11+
user: str
12+
mergeable: bool
13+
14+
@property
15+
def is_fork(self) -> bool:
16+
return self.head_repo_url != self.base_repo_url

comfy_cli/command/install.py

Lines changed: 185 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import subprocess
44
import sys
55
from typing import Dict, List, Optional, TypedDict
6+
from urllib.parse import urlparse
67

78
import requests
89
import semver
@@ -13,8 +14,9 @@
1314

1415
from comfy_cli import constants, ui, utils
1516
from comfy_cli.command.custom_nodes.command import update_node_id_cache
17+
from comfy_cli.command.github.pr_info import PRInfo
1618
from comfy_cli.constants import GPU_OPTION
17-
from comfy_cli.git_utils import git_checkout_tag
19+
from comfy_cli.git_utils import checkout_pr, git_checkout_tag
1820
from comfy_cli.uv import DependencyCompiler
1921
from comfy_cli.workspace_manager import WorkspaceManager, check_comfy_repo
2022

@@ -175,9 +177,15 @@ def execute(
175177
skip_torch_or_directml: bool = False,
176178
skip_requirement: bool = False,
177179
fast_deps: bool = False,
180+
pr: Optional[str] = None,
178181
*args,
179182
**kwargs,
180183
):
184+
# Install ComfyUI from a given PR reference.
185+
if pr:
186+
url = handle_pr_checkout(pr, comfy_path)
187+
version = "nightly"
188+
181189
"""
182190
Install ComfyUI from a given URL.
183191
"""
@@ -272,6 +280,66 @@ def execute(
272280
rprint("")
273281

274282

283+
def handle_pr_checkout(pr_ref: str, comfy_path: str) -> str:
284+
try:
285+
repo_owner, repo_name, pr_number = parse_pr_reference(pr_ref)
286+
except ValueError as e:
287+
rprint(f"[bold red]Error parsing PR reference: {e}[/bold red]")
288+
raise typer.Exit(code=1)
289+
290+
try:
291+
if pr_number:
292+
pr_info = fetch_pr_info(repo_owner, repo_name, pr_number)
293+
else:
294+
username, branch = pr_ref.split(":", 1)
295+
pr_info = find_pr_by_branch("comfyanonymous", "ComfyUI", username, branch)
296+
297+
if not pr_info:
298+
rprint(f"[bold red]PR not found: {pr_ref}[/bold red]")
299+
raise typer.Exit(code=1)
300+
301+
except Exception as e:
302+
rprint(f"[bold red]Error fetching PR information: {e}[/bold red]")
303+
raise typer.Exit(code=1)
304+
305+
console.print(
306+
Panel(
307+
f"[bold]PR #{pr_info.number}[/bold]: {pr_info.title}\n"
308+
f"[yellow]Author[/yellow]: {pr_info.user}\n"
309+
f"[yellow]Branch[/yellow]: {pr_info.head_branch}\n"
310+
f"[yellow]Source[/yellow]: {pr_info.head_repo_url}\n"
311+
f"[yellow]Mergeable[/yellow]: {'✓' if pr_info.mergeable else '✗'}",
312+
title="[bold blue]Pull Request Information[/bold blue]",
313+
border_style="blue",
314+
)
315+
)
316+
317+
if not workspace_manager.skip_prompting:
318+
if not ui.prompt_confirm_action(f"Install ComfyUI from PR #{pr_info.number}?", True):
319+
rprint("Aborting...")
320+
raise typer.Exit(code=1)
321+
322+
parent_path = os.path.abspath(os.path.join(comfy_path, ".."))
323+
324+
if not os.path.exists(parent_path):
325+
os.makedirs(parent_path, exist_ok=True)
326+
327+
if not os.path.exists(comfy_path):
328+
rprint(f"Cloning base repository to {comfy_path}...")
329+
clone_comfyui(url=pr_info.base_repo_url, repo_dir=comfy_path)
330+
331+
rprint(f"Checking out PR #{pr_info.number}: {pr_info.title}")
332+
success = checkout_pr(comfy_path, pr_info)
333+
if not success:
334+
rprint("[bold red]Failed to checkout PR[/bold red]")
335+
raise typer.Exit(code=1)
336+
337+
rprint(f"[bold green]✓ Successfully checked out PR #{pr_info.number}[/bold green]")
338+
rprint(f"[bold yellow]Note:[/bold yellow] You are now on branch pr-{pr_info.number}")
339+
340+
return pr_info.base_repo_url
341+
342+
275343
def validate_version(version: str) -> Optional[str]:
276344
"""
277345
Validates the version string as 'latest', 'nightly', or a semantically version number.
@@ -306,6 +374,21 @@ class GitHubRateLimitError(Exception):
306374
"""Raised when GitHub API rate limit is exceeded"""
307375

308376

377+
def handle_github_rate_limit(response):
378+
# Check rate limit headers
379+
remaining = int(response.headers.get("x-ratelimit-remaining", 0))
380+
if remaining == 0:
381+
reset_time = int(response.headers.get("x-ratelimit-reset", 0))
382+
message = f"Primary rate limit from Github exceeded! Please retry after: {reset_time})"
383+
raise GitHubRateLimitError(message)
384+
385+
if "retry-after" in response.headers:
386+
wait_seconds = int(response.headers["retry-after"])
387+
message = f"Rate limit from Github exceeded! Please wait {wait_seconds} seconds before retrying."
388+
rprint(f"[yellow]{message}[/yellow]")
389+
raise GitHubRateLimitError(message)
390+
391+
309392
def fetch_github_releases(repo_owner: str, repo_name: str) -> List[Dict[str, str]]:
310393
"""
311394
Fetch the list of releases from the GitHub API.
@@ -321,18 +404,7 @@ def fetch_github_releases(repo_owner: str, repo_name: str) -> List[Dict[str, str
321404

322405
# Handle rate limiting
323406
if response.status_code in (403, 429):
324-
# Check rate limit headers
325-
remaining = int(response.headers.get("x-ratelimit-remaining", 0))
326-
if remaining == 0:
327-
reset_time = int(response.headers.get("x-ratelimit-reset", 0))
328-
message = f"Primary rate limit from Github exceeded! Please retry after: {reset_time})"
329-
raise GitHubRateLimitError(message)
330-
331-
if "retry-after" in response.headers:
332-
wait_seconds = int(response.headers["retry-after"])
333-
message = f"Rate limit from Github exceeded! Please wait {wait_seconds} seconds before retrying."
334-
rprint(f"[yellow]{message}[/yellow]")
335-
raise GitHubRateLimitError(message)
407+
handle_github_rate_limit(response)
336408

337409
response.raise_for_status()
338410
return response.json()
@@ -459,3 +531,103 @@ def get_latest_release(repo_owner: str, repo_name: str) -> Optional[GithubReleas
459531
except requests.RequestException as e:
460532
rprint(f"Error fetching latest release: {e}")
461533
return None
534+
535+
536+
def parse_pr_reference(pr_ref: str) -> tuple[str, str, Optional[int]]:
537+
"""
538+
support formats:
539+
- username:branch-name
540+
- #123
541+
- https://github.com/comfyanonymous/ComfyUI/pull/123
542+
543+
Returns:
544+
(repo_owner, repo_name, pr_number)
545+
"""
546+
pr_ref = pr_ref.strip()
547+
548+
if pr_ref.startswith("https://github.com/"):
549+
parsed = urlparse(pr_ref)
550+
if "/pull/" in parsed.path:
551+
path_parts = parsed.path.strip("/").split("/")
552+
if len(path_parts) >= 4:
553+
repo_owner = path_parts[0]
554+
repo_name = path_parts[1]
555+
pr_number = int(path_parts[3])
556+
return repo_owner, repo_name, pr_number
557+
558+
elif pr_ref.startswith("#"):
559+
pr_number = int(pr_ref[1:])
560+
return "comfyanonymous", "ComfyUI", pr_number
561+
562+
elif ":" in pr_ref:
563+
username, branch = pr_ref.split(":", 1)
564+
return username, "ComfyUI", None
565+
566+
else:
567+
raise ValueError(f"Invalid PR reference format: {pr_ref}")
568+
569+
570+
def fetch_pr_info(repo_owner: str, repo_name: str, pr_number: int) -> PRInfo:
571+
url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/pulls/{pr_number}"
572+
573+
headers = {}
574+
if github_token := os.getenv("GITHUB_TOKEN"):
575+
headers["Authorization"] = f"Bearer {github_token}"
576+
577+
try:
578+
response = requests.get(url, headers=headers, timeout=10)
579+
580+
if response is None:
581+
raise Exception(f"Failed to fetch PR #{pr_number}: No response from GitHub API")
582+
583+
if response.status_code in (403, 429):
584+
handle_github_rate_limit(response)
585+
586+
response.raise_for_status()
587+
data = response.json()
588+
589+
return PRInfo(
590+
number=data["number"],
591+
head_repo_url=data["head"]["repo"]["clone_url"],
592+
head_branch=data["head"]["ref"],
593+
base_repo_url=data["base"]["repo"]["clone_url"],
594+
base_branch=data["base"]["ref"],
595+
title=data["title"],
596+
user=data["head"]["repo"]["owner"]["login"],
597+
mergeable=data.get("mergeable", True),
598+
)
599+
600+
except requests.RequestException as e:
601+
raise Exception(f"Failed to fetch PR #{pr_number}: {e}")
602+
603+
604+
def find_pr_by_branch(repo_owner: str, repo_name: str, username: str, branch: str) -> Optional[PRInfo]:
605+
url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/pulls"
606+
params = {"head": f"{username}:{branch}", "state": "open"}
607+
608+
headers = {}
609+
if github_token := os.getenv("GITHUB_TOKEN"):
610+
headers["Authorization"] = f"Bearer {github_token}"
611+
612+
try:
613+
response = requests.get(url, headers=headers, params=params, timeout=10)
614+
response.raise_for_status()
615+
data = response.json()
616+
617+
if data:
618+
pr_data = data[0]
619+
return PRInfo(
620+
number=pr_data["number"],
621+
head_repo_url=pr_data["head"]["repo"]["clone_url"],
622+
head_branch=pr_data["head"]["ref"],
623+
base_repo_url=pr_data["base"]["repo"]["clone_url"],
624+
base_branch=pr_data["base"]["ref"],
625+
title=pr_data["title"],
626+
user=pr_data["head"]["repo"]["owner"]["login"],
627+
mergeable=pr_data.get("mergeable", True),
628+
)
629+
630+
return None
631+
632+
except requests.RequestException:
633+
return None

0 commit comments

Comments
 (0)