33import subprocess
44import sys
55from typing import Dict , List , Optional , TypedDict
6+ from urllib .parse import urlparse
67
78import requests
89import semver
1314
1415from comfy_cli import constants , ui , utils
1516from comfy_cli .command .custom_nodes .command import update_node_id_cache
17+ from comfy_cli .command .github .pr_info import PRInfo
1618from comfy_cli .constants import GPU_OPTION
17- from comfy_cli .git_utils import git_checkout_tag
19+ from comfy_cli .git_utils import checkout_pr , git_checkout_tag
1820from comfy_cli .uv import DependencyCompiler
1921from comfy_cli .workspace_manager import WorkspaceManager , check_comfy_repo
2022
@@ -175,9 +177,15 @@ def execute(
175177 skip_torch_or_directml : bool = False ,
176178 skip_requirement : bool = False ,
177179 fast_deps : bool = False ,
180+ pr : Optional [str ] = None ,
178181 * args ,
179182 ** kwargs ,
180183):
184+ # Install ComfyUI from a given PR reference.
185+ if pr :
186+ url = handle_pr_checkout (pr , comfy_path )
187+ version = "nightly"
188+
181189 """
182190 Install ComfyUI from a given URL.
183191 """
@@ -272,6 +280,66 @@ def execute(
272280 rprint ("" )
273281
274282
283+ def handle_pr_checkout (pr_ref : str , comfy_path : str ) -> str :
284+ try :
285+ repo_owner , repo_name , pr_number = parse_pr_reference (pr_ref )
286+ except ValueError as e :
287+ rprint (f"[bold red]Error parsing PR reference: { e } [/bold red]" )
288+ raise typer .Exit (code = 1 )
289+
290+ try :
291+ if pr_number :
292+ pr_info = fetch_pr_info (repo_owner , repo_name , pr_number )
293+ else :
294+ username , branch = pr_ref .split (":" , 1 )
295+ pr_info = find_pr_by_branch ("comfyanonymous" , "ComfyUI" , username , branch )
296+
297+ if not pr_info :
298+ rprint (f"[bold red]PR not found: { pr_ref } [/bold red]" )
299+ raise typer .Exit (code = 1 )
300+
301+ except Exception as e :
302+ rprint (f"[bold red]Error fetching PR information: { e } [/bold red]" )
303+ raise typer .Exit (code = 1 )
304+
305+ console .print (
306+ Panel (
307+ f"[bold]PR #{ pr_info .number } [/bold]: { pr_info .title } \n "
308+ f"[yellow]Author[/yellow]: { pr_info .user } \n "
309+ f"[yellow]Branch[/yellow]: { pr_info .head_branch } \n "
310+ f"[yellow]Source[/yellow]: { pr_info .head_repo_url } \n "
311+ f"[yellow]Mergeable[/yellow]: { '✓' if pr_info .mergeable else '✗' } " ,
312+ title = "[bold blue]Pull Request Information[/bold blue]" ,
313+ border_style = "blue" ,
314+ )
315+ )
316+
317+ if not workspace_manager .skip_prompting :
318+ if not ui .prompt_confirm_action (f"Install ComfyUI from PR #{ pr_info .number } ?" , True ):
319+ rprint ("Aborting..." )
320+ raise typer .Exit (code = 1 )
321+
322+ parent_path = os .path .abspath (os .path .join (comfy_path , ".." ))
323+
324+ if not os .path .exists (parent_path ):
325+ os .makedirs (parent_path , exist_ok = True )
326+
327+ if not os .path .exists (comfy_path ):
328+ rprint (f"Cloning base repository to { comfy_path } ..." )
329+ clone_comfyui (url = pr_info .base_repo_url , repo_dir = comfy_path )
330+
331+ rprint (f"Checking out PR #{ pr_info .number } : { pr_info .title } " )
332+ success = checkout_pr (comfy_path , pr_info )
333+ if not success :
334+ rprint ("[bold red]Failed to checkout PR[/bold red]" )
335+ raise typer .Exit (code = 1 )
336+
337+ rprint (f"[bold green]✓ Successfully checked out PR #{ pr_info .number } [/bold green]" )
338+ rprint (f"[bold yellow]Note:[/bold yellow] You are now on branch pr-{ pr_info .number } " )
339+
340+ return pr_info .base_repo_url
341+
342+
275343def validate_version (version : str ) -> Optional [str ]:
276344 """
277345 Validates the version string as 'latest', 'nightly', or a semantically version number.
@@ -306,6 +374,21 @@ class GitHubRateLimitError(Exception):
306374 """Raised when GitHub API rate limit is exceeded"""
307375
308376
377+ def handle_github_rate_limit (response ):
378+ # Check rate limit headers
379+ remaining = int (response .headers .get ("x-ratelimit-remaining" , 0 ))
380+ if remaining == 0 :
381+ reset_time = int (response .headers .get ("x-ratelimit-reset" , 0 ))
382+ message = f"Primary rate limit from Github exceeded! Please retry after: { reset_time } )"
383+ raise GitHubRateLimitError (message )
384+
385+ if "retry-after" in response .headers :
386+ wait_seconds = int (response .headers ["retry-after" ])
387+ message = f"Rate limit from Github exceeded! Please wait { wait_seconds } seconds before retrying."
388+ rprint (f"[yellow]{ message } [/yellow]" )
389+ raise GitHubRateLimitError (message )
390+
391+
309392def fetch_github_releases (repo_owner : str , repo_name : str ) -> List [Dict [str , str ]]:
310393 """
311394 Fetch the list of releases from the GitHub API.
@@ -321,18 +404,7 @@ def fetch_github_releases(repo_owner: str, repo_name: str) -> List[Dict[str, str
321404
322405 # Handle rate limiting
323406 if response .status_code in (403 , 429 ):
324- # Check rate limit headers
325- remaining = int (response .headers .get ("x-ratelimit-remaining" , 0 ))
326- if remaining == 0 :
327- reset_time = int (response .headers .get ("x-ratelimit-reset" , 0 ))
328- message = f"Primary rate limit from Github exceeded! Please retry after: { reset_time } )"
329- raise GitHubRateLimitError (message )
330-
331- if "retry-after" in response .headers :
332- wait_seconds = int (response .headers ["retry-after" ])
333- message = f"Rate limit from Github exceeded! Please wait { wait_seconds } seconds before retrying."
334- rprint (f"[yellow]{ message } [/yellow]" )
335- raise GitHubRateLimitError (message )
407+ handle_github_rate_limit (response )
336408
337409 response .raise_for_status ()
338410 return response .json ()
@@ -459,3 +531,103 @@ def get_latest_release(repo_owner: str, repo_name: str) -> Optional[GithubReleas
459531 except requests .RequestException as e :
460532 rprint (f"Error fetching latest release: { e } " )
461533 return None
534+
535+
536+ def parse_pr_reference (pr_ref : str ) -> tuple [str , str , Optional [int ]]:
537+ """
538+ support formats:
539+ - username:branch-name
540+ - #123
541+ - https://github.com/comfyanonymous/ComfyUI/pull/123
542+
543+ Returns:
544+ (repo_owner, repo_name, pr_number)
545+ """
546+ pr_ref = pr_ref .strip ()
547+
548+ if pr_ref .startswith ("https://github.com/" ):
549+ parsed = urlparse (pr_ref )
550+ if "/pull/" in parsed .path :
551+ path_parts = parsed .path .strip ("/" ).split ("/" )
552+ if len (path_parts ) >= 4 :
553+ repo_owner = path_parts [0 ]
554+ repo_name = path_parts [1 ]
555+ pr_number = int (path_parts [3 ])
556+ return repo_owner , repo_name , pr_number
557+
558+ elif pr_ref .startswith ("#" ):
559+ pr_number = int (pr_ref [1 :])
560+ return "comfyanonymous" , "ComfyUI" , pr_number
561+
562+ elif ":" in pr_ref :
563+ username , branch = pr_ref .split (":" , 1 )
564+ return username , "ComfyUI" , None
565+
566+ else :
567+ raise ValueError (f"Invalid PR reference format: { pr_ref } " )
568+
569+
570+ def fetch_pr_info (repo_owner : str , repo_name : str , pr_number : int ) -> PRInfo :
571+ url = f"https://api.github.com/repos/{ repo_owner } /{ repo_name } /pulls/{ pr_number } "
572+
573+ headers = {}
574+ if github_token := os .getenv ("GITHUB_TOKEN" ):
575+ headers ["Authorization" ] = f"Bearer { github_token } "
576+
577+ try :
578+ response = requests .get (url , headers = headers , timeout = 10 )
579+
580+ if response is None :
581+ raise Exception (f"Failed to fetch PR #{ pr_number } : No response from GitHub API" )
582+
583+ if response .status_code in (403 , 429 ):
584+ handle_github_rate_limit (response )
585+
586+ response .raise_for_status ()
587+ data = response .json ()
588+
589+ return PRInfo (
590+ number = data ["number" ],
591+ head_repo_url = data ["head" ]["repo" ]["clone_url" ],
592+ head_branch = data ["head" ]["ref" ],
593+ base_repo_url = data ["base" ]["repo" ]["clone_url" ],
594+ base_branch = data ["base" ]["ref" ],
595+ title = data ["title" ],
596+ user = data ["head" ]["repo" ]["owner" ]["login" ],
597+ mergeable = data .get ("mergeable" , True ),
598+ )
599+
600+ except requests .RequestException as e :
601+ raise Exception (f"Failed to fetch PR #{ pr_number } : { e } " )
602+
603+
604+ def find_pr_by_branch (repo_owner : str , repo_name : str , username : str , branch : str ) -> Optional [PRInfo ]:
605+ url = f"https://api.github.com/repos/{ repo_owner } /{ repo_name } /pulls"
606+ params = {"head" : f"{ username } :{ branch } " , "state" : "open" }
607+
608+ headers = {}
609+ if github_token := os .getenv ("GITHUB_TOKEN" ):
610+ headers ["Authorization" ] = f"Bearer { github_token } "
611+
612+ try :
613+ response = requests .get (url , headers = headers , params = params , timeout = 10 )
614+ response .raise_for_status ()
615+ data = response .json ()
616+
617+ if data :
618+ pr_data = data [0 ]
619+ return PRInfo (
620+ number = pr_data ["number" ],
621+ head_repo_url = pr_data ["head" ]["repo" ]["clone_url" ],
622+ head_branch = pr_data ["head" ]["ref" ],
623+ base_repo_url = pr_data ["base" ]["repo" ]["clone_url" ],
624+ base_branch = pr_data ["base" ]["ref" ],
625+ title = pr_data ["title" ],
626+ user = pr_data ["head" ]["repo" ]["owner" ]["login" ],
627+ mergeable = pr_data .get ("mergeable" , True ),
628+ )
629+
630+ return None
631+
632+ except requests .RequestException :
633+ return None
0 commit comments