Skip to content

Commit f936dcc

Browse files
committed
setup api with multiple classes
1 parent e7ed992 commit f936dcc

File tree

11 files changed

+187
-212
lines changed

11 files changed

+187
-212
lines changed

api/api_helpers.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import json
2+
from PIL import Image
3+
import io
4+
import os
5+
6+
# Assuming the import paths are correct and the methods are defined elsewhere:
7+
from api.websocket_api import queue_prompt, get_history, get_image, upload_image, clear_comfy_cache
8+
from api.open_websocket import open_websocket_connection
9+
10+
def generate_image_by_prompt(prompt, output_path, save_previews=False):
11+
try:
12+
ws, server_address, client_id = open_websocket_connection()
13+
prompt_id = queue_prompt(prompt, client_id, server_address)['prompt_id']
14+
track_progress(prompt, ws, prompt_id)
15+
images = get_images(prompt_id, server_address, save_previews)
16+
save_image(images, output_path, save_previews)
17+
finally:
18+
ws.close()
19+
20+
def generate_image_by_prompt_and_image(prompt, output_path, input_path, filename, save_previews=False):
21+
try:
22+
ws, server_address, client_id = open_websocket_connection()
23+
upload_image(input_path, filename, server_address, client_id)
24+
prompt_id = queue_prompt(prompt, client_id, server_address)['prompt_id']
25+
track_progress(prompt, ws, prompt_id)
26+
images = get_images(prompt_id, server_address, save_previews)
27+
save_image(images, output_path, save_previews)
28+
finally:
29+
ws.close()
30+
31+
def save_image(images, output_path, save_previews):
32+
for itm in images:
33+
directory = os.path.join(output_path, 'temp/') if itm['type'] == 'temp' and save_previews else output_path
34+
os.makedirs(directory, exist_ok=True)
35+
try:
36+
image = Image.open(io.BytesIO(itm['image_data']))
37+
image.save(os.path.join(directory, itm['file_name']))
38+
except Exception as e:
39+
print(f"Failed to save image {itm['file_name']}: {e}")
40+
41+
def track_progress(prompt, ws, prompt_id):
42+
node_ids = list(prompt.keys())
43+
finished_nodes = []
44+
45+
while True:
46+
out = ws.recv()
47+
if isinstance(out, str):
48+
message = json.loads(out)
49+
if message['type'] == 'progress':
50+
data = message['data']
51+
current_step = data['value']
52+
print('In K-Sampler -> Step: ', current_step, ' of: ', data['max'])
53+
if message['type'] == 'execution_cached':
54+
data = message['data']
55+
for itm in data['nodes']:
56+
if itm not in finished_nodes:
57+
finished_nodes.append(itm)
58+
print('Progess: ', len(finished_nodes), '/', len(node_ids), ' Tasks done')
59+
if message['type'] == 'executing':
60+
data = message['data']
61+
if data['node'] not in finished_nodes:
62+
finished_nodes.append(data['node'])
63+
print('Progess: ', len(finished_nodes), '/', len(node_ids), ' Tasks done')
64+
65+
66+
if data['node'] is None and data['prompt_id'] == prompt_id:
67+
break #Execution is done
68+
else:
69+
continue #previews are binary data
70+
return
71+
72+
def get_images(prompt_id, server_address, allow_preview = False):
73+
output_images = []
74+
75+
history = get_history(prompt_id, server_address)[prompt_id]
76+
for node_id in history['outputs']:
77+
node_output = history['outputs'][node_id]
78+
output_data = {}
79+
if 'images' in node_output:
80+
for image in node_output['images']:
81+
if allow_preview and image['type'] == 'temp':
82+
preview_data = get_image(image['filename'], image['subfolder'], image['type'], server_address)
83+
output_data['image_data'] = preview_data
84+
if image['type'] == 'output':
85+
image_data = get_image(image['filename'], image['subfolder'], image['type'], server_address)
86+
output_data['image_data'] = image_data
87+
output_data['file_name'] = image['filename']
88+
output_data['type'] = image['type']
89+
output_images.append(output_data)
90+
91+
return output_images
92+
93+
94+
def clear():
95+
clear_comfy_cache()

api/open_websocket.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
2+
import uuid
3+
4+
def open_websocket_connection():
5+
server_address='127.0.0.1:8188'
6+
client_id=str(uuid.uuid4())
7+
8+
ws = websocket.WebSocket()
9+
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
10+
return ws, server_address, client_id

api/websocket_api.py

Lines changed: 50 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -1,137 +1,56 @@
1-
#This is an example that uses the websockets api to know when a prompt execution is done
2-
#Once the prompt execution is done it downloads the images using the /history endpoint
3-
4-
import uuid
51
import json
62
import urllib.request
73
import urllib.parse
8-
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
9-
from PIL import Image
10-
import io
11-
from utils.helpers.find_node import find_node
124
from requests_toolbelt import MultipartEncoder
135

14-
server_address='127.0.0.1:8188'
15-
client_id=str(uuid.uuid4())
16-
17-
ws = websocket.WebSocket()
18-
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
19-
20-
def setup_image(input_path, name, type="input", overwrite=False):
21-
with open(input_path, 'rb') as file:
22-
multipart_data = MultipartEncoder(
23-
fields={
24-
'image': (name, file, 'image/jpeg'), # Adjust the content-type accordingly
25-
'type': type,
26-
'overwrite': str(overwrite).lower()
27-
}
28-
)
29-
30-
data = multipart_data
31-
headers = {'Content-Type': multipart_data.content_type}
32-
request = urllib.request.Request("http://{}/upload/image".format(server_address), data=data, headers=headers)
33-
with urllib.request.urlopen(request) as response:
34-
return response.read()
35-
36-
def generate_image_by_prompt(prompt, output_path, save_previews=False):
37-
prompt_id = queue_prompt(prompt)['prompt_id']
38-
track_progress(prompt, ws, prompt_id)
39-
images = get_images(ws, prompt_id, save_previews)
40-
save_image(images, output_path, save_previews)
41-
42-
def save_image(images, output_path, save_previews):
43-
for itm in images:
44-
if itm['type'] == 'temp' and save_previews:
45-
image = Image.open(io.BytesIO(itm['image_data']))
46-
image.save(output_path + 'temp/' + itm['file_name'])
47-
else:
48-
image = Image.open(io.BytesIO(itm['image_data']))
49-
image.save(output_path + itm['file_name'])
50-
51-
52-
def track_progress(prompt, ws, prompt_id):
53-
node_ids = list(prompt.keys())
54-
finished_nodes = []
55-
56-
while True:
57-
out = ws.recv()
58-
if isinstance(out, str):
59-
message = json.loads(out)
60-
if message['type'] == 'progress':
61-
data = message['data']
62-
current_step = data['value']
63-
print('In K-Sampler -> Step: ', current_step, ' of: ', data['max'])
64-
if message['type'] == 'execution_cached':
65-
data = message['data']
66-
for itm in data['nodes']:
67-
if itm not in finished_nodes:
68-
finished_nodes.append(itm)
69-
print('Progess: ', len(finished_nodes), '/', len(node_ids), ' Tasks done')
70-
if message['type'] == 'executing':
71-
data = message['data']
72-
if data['node'] not in finished_nodes:
73-
finished_nodes.append(data['node'])
74-
print('Progess: ', len(finished_nodes), '/', len(node_ids), ' Tasks done')
75-
76-
77-
if data['node'] is None and data['prompt_id'] == prompt_id:
78-
break #Execution is done
79-
else:
80-
continue #previews are binary data
81-
return
82-
83-
def queue_prompt(prompt):
84-
p = {"prompt": prompt, "client_id": client_id}
85-
data = json.dumps(p).encode('utf-8')
86-
req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
87-
return json.loads(urllib.request.urlopen(req).read())
88-
89-
def interupt_prompt():
90-
req = urllib.request.Request("http://{}/interrupt".format(server_address), data={})
91-
return json.loads(urllib.request.urlopen(req).read())
92-
93-
def get_image(filename, subfolder, folder_type):
94-
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
95-
url_values = urllib.parse.urlencode(data)
96-
with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
97-
return response.read()
98-
99-
def get_history(prompt_id):
100-
with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
101-
return json.loads(response.read())
102-
103-
def get_images(ws, prompt_id, allow_preview = False):
104-
output_images = []
105-
106-
history = get_history(prompt_id)[prompt_id]
107-
for node_id in history['outputs']:
108-
node_output = history['outputs'][node_id]
109-
output_data = {}
110-
if 'images' in node_output:
111-
for image in node_output['images']:
112-
if allow_preview and image['type'] == 'temp':
113-
preview_data = get_image(image['filename'], image['subfolder'], image['type'])
114-
output_data['image_data'] = preview_data
115-
if image['type'] == 'output':
116-
image_data = get_image(image['filename'], image['subfolder'], image['type'])
117-
output_data['image_data'] = image_data
118-
output_data['file_name'] = image['filename']
119-
output_data['type'] = image['type']
120-
output_images.append(output_data)
121-
122-
return output_images
123-
124-
def get_node_info_by_class(node_class):
125-
with urllib.request.urlopen("http://{}/object_info/{}".format(server_address, node_class)) as response:
126-
return json.loads(response.read())
127-
128-
def clear_comfy_cache(unload_models=False, free_memory=False):
129-
clear_data = {
130-
"unload_models": unload_models,
131-
"free_memory": free_memory
132-
}
133-
data = json.dumps(clear_data).encode('utf-8')
134-
135-
with urllib.request.urlopen("http://{}/free".format(server_address), data=data) as response:
136-
return response.read()
6+
def upload_image(input_path, name, server_address, image_type="input", overwrite=False):
7+
with open(input_path, 'rb') as file:
8+
multipart_data = MultipartEncoder(
9+
fields= {
10+
'image': (name, file, 'image/png'),
11+
'type': image_type,
12+
'overwrite': str(overwrite).lower()
13+
}
14+
)
15+
16+
data = multipart_data
17+
headers = { 'Content-Type': multipart_data.content_type }
18+
request = urllib.request.Request("http://{}/upload/image".format(server_address), data=data, headers=headers)
19+
with urllib.request.urlopen(request) as response:
20+
return response.read()
21+
22+
def queue_prompt(prompt, client_id, server_address):
23+
p = {"prompt": prompt, "client_id": client_id}
24+
headers = {'Content-Type': 'application/json'}
25+
data = json.dumps(p).encode('utf-8')
26+
req = urllib.request.Request("http://{}/prompt".format(server_address), data=data, headers=headers)
27+
return json.loads(urllib.request.urlopen(req).read())
28+
29+
def interupt_prompt(server_address):
30+
req = urllib.request.Request("http://{}/interrupt".format(server_address), data={})
31+
return json.loads(urllib.request.urlopen(req).read())
32+
33+
def get_image(filename, subfolder, folder_type, server_address):
34+
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
35+
url_values = urllib.parse.urlencode(data)
36+
with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
37+
return response.read()
38+
39+
def get_history(prompt_id, server_address):
40+
with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
41+
return json.loads(response.read())
42+
43+
def get_node_info_by_class(node_class, server_address):
44+
with urllib.request.urlopen("http://{}/object_info/{}".format(server_address, node_class)) as response:
45+
return json.loads(response.read())
46+
47+
def clear_comfy_cache(server_address, unload_models=False, free_memory=False):
48+
clear_data = {
49+
"unload_models": unload_models,
50+
"free_memory": free_memory
51+
}
52+
data = json.dumps(clear_data).encode('utf-8')
53+
54+
with urllib.request.urlopen("http://{}/free".format(server_address), data=data) as response:
55+
return response.read()
13756

main.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
11
from utils.actions.prompt_to_image import prompt_to_image
22
from utils.actions.prompt_image_to_image import prompt_image_to_image
33
from utils.actions.load_workflow import load_workflow
4-
from api.websocket_api import clear_comfy_cache
5-
from api.websocket_api import get_image
6-
import time
4+
from api.api_helpers import clear
75
import sys
86

97
def main():
108
try:
119
print("Welcome to the program!")
12-
workflow = load_workflow('./workflows/image_to_image.json')
13-
14-
# prompt_to_image(workflow, 'beautiful woman sitting on a desk in a nice restaurant, candlelight dinner atmosphere, wearing a red dress', save_previews=True)
15-
input_path = './input/ComfyUI_00103_.png'
16-
prompt_image_to_image(workflow, input_path, 'beautiful [white woman], (dark lighting), curly blond hair', save_previews=True)
10+
workflow = load_workflow('./workflows/base_workflow.json')
11+
for iter in range(1, 11):
12+
prompt_to_image(workflow, '(realistic:1.25), beautiful:1.1) mountain landscape with a deep blue lake, photolike, high detail, monoton colors', 'lowres, text, branding, watermark, humans, frames, painting', save_previews=True)
13+
# prompt_to_image(workflow, '(beautiful woman:1.3) sitting on a desk in a nice restaurant with a (glass of wine and plate with salat:0.9), (candlelight dinner atmosphere:1.1), (wearing a red evening dress:1.2), dimmed lighting, cinema, high detail', save_previews=True)
14+
# input_path = './input/ComfyUI_00241_.png'
15+
# prompt_image_to_image(workflow, input_path, '(white woman wearing a black evening dress:1.5), dimmed lighting, cinema, high detail', save_previews=True)
1716
except Exception as e:
1817
print(f"An error occurred: {e}")
1918
exit_program()
@@ -22,13 +21,7 @@ def exit_program():
2221
print("Exiting the program...")
2322
sys.exit(0)
2423

25-
def clear():
26-
clear_comfy_cache(True, True)
27-
28-
def image_by_file():
29-
print(get_image('ComfyUI_00042_.png', '', 'output'))
24+
def clear_comfy():
25+
clear(True, True)
3026

3127
main()
32-
# clear()
33-
34-
# image_by_file()

utils/actions/load_workflow.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
import json
22

33
def load_workflow(workflow_path):
4-
file = open(workflow_path)
5-
workflow = json.load(file)
6-
workflow = json.dumps(workflow)
7-
8-
return workflow
4+
try:
5+
with open(workflow_path, 'r') as file:
6+
workflow = json.load(file)
7+
return json.dumps(workflow)
8+
except FileNotFoundError:
9+
print(f"The file {workflow_path} was not found.")
10+
return None
11+
except json.JSONDecodeError:
12+
print(f"The file {workflow_path} contains invalid JSON.")
13+
return None
Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
1-
from api.websocket_api import generate_image_by_prompt
2-
from api.websocket_api import setup_image
3-
from utils.helpers.find_node import find_node
1+
from api.api_helpers import generate_image_by_prompt_and_image
42
from utils.helpers.randomize_seed import generate_random_15_digit_number
5-
from utils.helpers.replace_key import replace_key
6-
from utils.helpers.find_parent import find_parent_of_key
73
import json
84
def prompt_image_to_image(workflow, input_path, positve_prompt, negative_prompt='', save_previews=False):
95
prompt = json.loads(workflow)
10-
replace_key(prompt, 'seed', generate_random_15_digit_number())
116
id_to_class_type = {id: details['class_type'] for id, details in prompt.items()}
127
k_sampler = [key for key, value in id_to_class_type.items() if value == 'KSampler'][0]
13-
8+
prompt.get(k_sampler)['inputs']['seed'] = generate_random_15_digit_number()
149
postive_input_id = prompt.get(k_sampler)['inputs']['positive'][0]
1510
prompt.get(postive_input_id)['inputs']['text_g'] = positve_prompt
1611
prompt.get(postive_input_id)['inputs']['text_l'] = positve_prompt
@@ -20,14 +15,9 @@ def prompt_image_to_image(workflow, input_path, positve_prompt, negative_prompt=
2015
id_to_class_type.get(negative_input_id)['inputs']['text_g'] = negative_prompt
2116
id_to_class_type.get(negative_input_id)['inputs']['text_l'] = negative_prompt
2217

23-
2418
image_loader = [key for key, value in id_to_class_type.items() if value == 'LoadImage'][0]
2519
filename = input_path.split('/')[-1]
2620
prompt.get(image_loader)['inputs']['image'] = filename
2721

28-
setup_image(input_path, filename)
29-
generate_image_by_prompt(prompt, './output/', save_previews)
30-
31-
32-
22+
generate_image_by_prompt_and_image(prompt, './output/', input_path, filename, save_previews)
3323

0 commit comments

Comments
 (0)