Skip to content

Commit c96aafb

Browse files
Add comfy run command (#80)
* Add `comfy run` command * Swap match for ifs * reformat * fix typo * improve: workflow path handling fix: typo * Support non background ComfyUI instance --------- Co-authored-by: Dr.Lt.Data <dr.lt.data@gmail.com>
1 parent e427b56 commit c96aafb

File tree

5 files changed

+350
-10
lines changed

5 files changed

+350
-10
lines changed

comfy_cli/cmdline.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from comfy_cli import constants, env_checker, logging, tracking, ui, utils
1919
from comfy_cli.command import custom_nodes
20+
from comfy_cli.command import run as run_inner
2021
from comfy_cli.command import install as install_inner
2122
from comfy_cli.command.models import models as models_command
2223
from comfy_cli.config_manager import ConfigManager
@@ -355,12 +356,55 @@ def update(
355356
custom_nodes.command.update_node_id_cache()
356357

357358

358-
# @app.command(help="Run workflow file")
359-
# @tracking.track_command()
360-
# def run(
361-
# workflow_file: Annotated[str, typer.Option(help="Path to the workflow file.")],
362-
# ):
363-
# run_inner.execute(workflow_file)
359+
@app.command(
360+
help="Run API workflow file using the ComfyUI launched by `comfy launch --background`"
361+
)
362+
@tracking.track_command()
363+
def run(
364+
workflow: Annotated[str, typer.Option(help="Path to the workflow API json file.")],
365+
wait: Annotated[
366+
Optional[bool],
367+
typer.Option(help="If the command should wait until execution completes."),
368+
] = True,
369+
verbose: Annotated[
370+
Optional[bool],
371+
typer.Option(help="Enables verbose output of the execution process."),
372+
] = False,
373+
host: Annotated[
374+
Optional[str],
375+
typer.Option(
376+
help="The IP/hostname where the ComfyUI instance is running, e.g. 127.0.0.1 or localhost."
377+
),
378+
] = None,
379+
port: Annotated[
380+
Optional[int],
381+
typer.Option(help="The port where the ComfyUI instance is running, e.g. 8188."),
382+
] = None,
383+
):
384+
config = ConfigManager()
385+
386+
if host:
387+
s = host.split(":")
388+
host = s[0]
389+
if not port and len(s) == 2:
390+
port = int(s[1])
391+
392+
local_paths = False
393+
if config.background:
394+
if not host:
395+
host = config.background[0]
396+
local_paths = True
397+
if port:
398+
local_paths = False
399+
else:
400+
port = config.background[1]
401+
402+
if not host:
403+
host = "127.0.0.1"
404+
if not port:
405+
port = 8188
406+
407+
run_inner.execute(workflow, host, port, wait, verbose, local_paths)
364408

365409

366410
def validate_comfyui(_env_checker):

comfy_cli/command/run.py

Lines changed: 296 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,296 @@
1-
def execute(workflow_name: str):
2-
print(f"Executing workflow: {workflow_name}")
1+
import json
2+
import os
3+
import sys
4+
import time
5+
import typer
6+
import uuid
7+
import urllib.error
8+
import urllib.parse
9+
from rich.progress import BarColumn, Progress, TimeElapsedColumn, Column, Table
10+
from urllib import request
11+
from websocket import WebSocket
12+
from rich import print as pprint
13+
from comfy_cli.env_checker import check_comfy_server_running
14+
from comfy_cli.workspace_manager import WorkspaceManager
15+
from datetime import timedelta
16+
17+
workspace_manager = WorkspaceManager()
18+
19+
20+
def load_api_workflow(file: str):
21+
with open(file, encoding="utf-8") as f:
22+
workflow = json.load(f)
23+
# Check for litegraph properties to ensure this isnt a UI workflow file
24+
if "nodes" in workflow and "links" in workflow:
25+
return None
26+
27+
# Try validating the first entry to ensure it has a node class property
28+
node_id = next(iter(workflow))
29+
node = workflow[node_id]
30+
if "class_type" not in node:
31+
return None
32+
33+
return workflow
34+
35+
36+
def execute(workflow: str, host, port, wait=True, verbose=False, local_paths=False):
37+
workflow_name = os.path.abspath(os.path.expanduser(workflow))
38+
if not os.path.isfile(workflow):
39+
pprint(
40+
f"[bold red]Specified workflow file not found: {workflow}[/bold red]",
41+
file=sys.stderr,
42+
)
43+
raise typer.Exit(code=1)
44+
45+
workflow = load_api_workflow(workflow)
46+
47+
if not workflow:
48+
pprint(
49+
"[bold red]Specified workflow does not appear to be an API workflow json file[/bold red]"
50+
)
51+
raise typer.Exit(code=1)
52+
53+
if not check_comfy_server_running(port, host):
54+
pprint(
55+
f"[bold red]ComfyUI not running on specified address ({host}:{port})[/bold red]"
56+
)
57+
raise typer.Exit(code=1)
58+
59+
progress = None
60+
start = time.time()
61+
if wait:
62+
pprint(f"Executing workflow: {workflow_name}")
63+
progress = ExecutionProgress()
64+
progress.start()
65+
else:
66+
print(f"Queuing workflow: {workflow_name}")
67+
68+
execution = WorkflowExecution(workflow, host, port, verbose, progress, local_paths)
69+
70+
try:
71+
if wait:
72+
execution.connect()
73+
execution.queue()
74+
if wait:
75+
execution.watch_execution()
76+
end = time.time()
77+
progress.stop()
78+
progress = None
79+
80+
if len(execution.outputs):
81+
pprint("[bold green]\nOutputs:[/bold green]")
82+
83+
for f in execution.outputs:
84+
pprint(f)
85+
86+
elapsed = timedelta(seconds=end - start)
87+
pprint(
88+
f"[bold green]\nWorkflow execution completed ({elapsed})[/bold green]"
89+
)
90+
else:
91+
pprint("[bold green]Workflow queued[/bold green]")
92+
finally:
93+
if progress:
94+
progress.stop()
95+
96+
97+
class ExecutionProgress(Progress):
98+
def get_renderables(self):
99+
table_columns = (
100+
(
101+
Column(no_wrap=True)
102+
if isinstance(_column, str)
103+
else _column.get_table_column().copy()
104+
)
105+
for _column in self.columns
106+
)
107+
108+
for task in self.tasks:
109+
percent = "[progress.percentage]{task.percentage:>3.0f}%".format(task=task)
110+
if task.fields.get("progress_type") == "overall":
111+
overall_table = Table.grid(
112+
*table_columns, padding=(0, 1), expand=self.expand
113+
)
114+
overall_table.add_row(
115+
BarColumn().render(task), percent, TimeElapsedColumn().render(task)
116+
)
117+
yield overall_table
118+
else:
119+
yield self.make_tasks_table([task])
120+
121+
122+
class WorkflowExecution:
123+
def __init__(self, workflow, host, port, verbose, progress, local_paths):
124+
self.workflow = workflow
125+
self.host = host
126+
self.port = port
127+
self.verbose = verbose
128+
self.local_paths = local_paths
129+
self.client_id = str(uuid.uuid4())
130+
self.outputs = []
131+
self.progress = progress
132+
self.remaining_nodes = set(self.workflow.keys())
133+
self.total_nodes = len(self.remaining_nodes)
134+
if progress:
135+
self.overall_task = self.progress.add_task(
136+
"", total=self.total_nodes, progress_type="overall"
137+
)
138+
self.current_node = None
139+
self.progress_task = None
140+
self.progress_node = None
141+
self.prompt_id = None
142+
143+
def connect(self):
144+
self.ws = WebSocket()
145+
self.ws.connect(f"ws://{self.host}:{self.port}/ws?clientId={self.client_id}")
146+
147+
def queue(self):
148+
data = {"prompt": self.workflow, "client_id": self.client_id}
149+
req = request.Request(
150+
f"http://{self.host}:{self.port}/prompt", json.dumps(data).encode("utf-8")
151+
)
152+
try:
153+
resp = request.urlopen(req)
154+
body = json.loads(resp.read())
155+
156+
self.prompt_id = body["prompt_id"]
157+
except urllib.error.HTTPError as e:
158+
message = "An unknown error occurred"
159+
if e.status == 500:
160+
# This is normally just the generic internal server error
161+
message = e.read().decode()
162+
elif e.status == 400:
163+
# Bad Request - workflow failed validation on the server
164+
body = json.loads(e.read())
165+
if body["node_errors"].keys():
166+
message = json.dumps(body["node_errors"], indent=2)
167+
168+
self.progress.stop()
169+
170+
pprint(f"[bold red]Error running workflow\n{message}[/bold red]")
171+
raise typer.Exit(code=1)
172+
173+
def watch_execution(self):
174+
self.ws.settimeout(30)
175+
while True:
176+
message = self.ws.recv()
177+
if isinstance(message, str):
178+
message = json.loads(message)
179+
if not self.on_message(message):
180+
break
181+
182+
def update_overall_progress(self):
183+
self.progress.update(
184+
self.overall_task, completed=self.total_nodes - len(self.remaining_nodes)
185+
)
186+
187+
def get_node_title(self, node_id):
188+
node = self.workflow[node_id]
189+
if "_meta" in node and "title" in node["_meta"]:
190+
return node["_meta"]["title"]
191+
return node["class_type"]
192+
193+
def log_node(self, type, node_id):
194+
if not self.verbose:
195+
return
196+
197+
node = self.workflow[node_id]
198+
class_type = node["class_type"]
199+
title = self.get_node_title(node_id)
200+
201+
if title != class_type:
202+
title += f"[bright_black] - {class_type}[/]"
203+
title += f"[bright_black] ({node_id})[/]"
204+
205+
pprint(f"{type} : {title}")
206+
207+
def format_image_path(self, img):
208+
filename = img["filename"]
209+
subfolder = img["subfolder"]
210+
output_type = img["type"] or "output"
211+
212+
if self.local_paths:
213+
if subfolder:
214+
filename = os.path.join(subfolder, filename)
215+
216+
filename = os.path.join(
217+
workspace_manager.get_workspace_path()[0], output_type, filename
218+
)
219+
return filename
220+
221+
query = urllib.parse.urlencode(img)
222+
return f"http://{self.host}:{self.port}/view?{query}"
223+
224+
def on_message(self, message):
225+
data = message["data"] if "data" in message else {}
226+
# Skip any messages that aren't about our prompt
227+
if "prompt_id" not in data or data["prompt_id"] != self.prompt_id:
228+
return True
229+
230+
if message["type"] == "executing":
231+
return self.on_executing(data)
232+
elif message["type"] == "execution_cached":
233+
self.on_cached(data)
234+
elif message["type"] == "progress":
235+
self.on_progress(data)
236+
elif message["type"] == "executed":
237+
self.on_executed(data)
238+
elif message["type"] == "execution_error":
239+
self.on_error(data)
240+
241+
return True
242+
243+
def on_executing(self, data):
244+
if self.progress_task:
245+
self.progress.remove_task(self.progress_task)
246+
self.progress_task = None
247+
248+
if data["node"] is None:
249+
return False
250+
else:
251+
if self.current_node:
252+
self.remaining_nodes.discard(self.current_node)
253+
self.update_overall_progress()
254+
self.current_node = data["node"]
255+
self.log_node("Executing", data["node"])
256+
return True
257+
258+
def on_cached(self, data):
259+
nodes = data["nodes"]
260+
for n in nodes:
261+
self.remaining_nodes.discard(n)
262+
self.log_node("Cached", n)
263+
self.update_overall_progress()
264+
265+
def on_progress(self, data):
266+
node = data["node"]
267+
if self.progress_node != node:
268+
self.progress_node = node
269+
if self.progress_task:
270+
self.progress.remove_task(self.progress_task)
271+
272+
self.progress_task = self.progress.add_task(
273+
self.get_node_title(node), total=data["max"], progress_type="node"
274+
)
275+
self.progress.update(self.progress_task, completed=data["value"])
276+
277+
def on_executed(self, data):
278+
self.remaining_nodes.discard(data["node"])
279+
self.update_overall_progress()
280+
281+
if "output" not in data:
282+
return
283+
284+
output = data["output"]
285+
286+
if "images" not in output:
287+
return
288+
289+
for img in output["images"]:
290+
self.outputs.append(self.format_image_path(img))
291+
292+
def on_error(self, data):
293+
pprint(
294+
f"[bold red]Error running workflow\n{json.dumps(data, indent=2)}[/bold red]"
295+
)
296+
raise typer.Exit(code=1)

comfy_cli/env_checker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@ def format_python_version(version_info):
3636
)
3737

3838

39-
def check_comfy_server_running(port=8188):
39+
def check_comfy_server_running(port=8188, host="localhost"):
4040
"""
4141
Checks if the Comfy server is running by making a GET request to the /history endpoint.
4242
4343
Returns:
4444
bool: True if the Comfy server is running, False otherwise.
4545
"""
4646
try:
47-
response = requests.get(f"http://localhost:{port}/history")
47+
response = requests.get(f"http://{host}:{port}/history")
4848
return response.status_code == 200
4949
except requests.exceptions.RequestException:
5050
return False

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ dependencies = [
3333
"pathspec",
3434
"httpx",
3535
"packaging",
36+
"websocket-client"
3637
]
3738

3839
classifiers = [

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ pathspec
1111
httpx
1212
packaging
1313
charset-normalizer>=3.0.0
14+
websocket-client

0 commit comments

Comments
 (0)