-#This is an example that uses the websockets api to know when a prompt execution is done
-#Once the prompt execution is done it downloads the images using the /history endpoint
-
-import uuid
 import json
 import urllib.request
 import urllib.parse
-import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
-from PIL import Image
-import io
-from utils.helpers.find_node import find_node
 from requests_toolbelt import MultipartEncoder
 
-server_address='127.0.0.1:8188'
-client_id=str(uuid.uuid4())
-
-ws = websocket.WebSocket()
-ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
-
-def setup_image(input_path, name, type="input", overwrite=False):
-    with open(input_path, 'rb') as file:
-        multipart_data = MultipartEncoder(
-            fields={
-                'image': (name, file, 'image/jpeg'), # Adjust the content-type accordingly
-                'type': type,
-                'overwrite': str(overwrite).lower()
-            }
-        )
-
-        data = multipart_data
-        headers = {'Content-Type': multipart_data.content_type}
-        request = urllib.request.Request("http://{}/upload/image".format(server_address), data=data, headers=headers)
-        with urllib.request.urlopen(request) as response:
-            return response.read()
-
-def generate_image_by_prompt(prompt, output_path, save_previews=False):
-    prompt_id = queue_prompt(prompt)['prompt_id']
-    track_progress(prompt, ws, prompt_id)
-    images = get_images(ws, prompt_id, save_previews)
-    save_image(images, output_path, save_previews)
-
-def save_image(images, output_path, save_previews):
-    for itm in images:
-        if itm['type'] == 'temp' and save_previews:
-            image = Image.open(io.BytesIO(itm['image_data']))
-            image.save(output_path + 'temp/' + itm['file_name'])
-        else:
-            image = Image.open(io.BytesIO(itm['image_data']))
-            image.save(output_path + itm['file_name'])
-
-
-def track_progress(prompt, ws, prompt_id):
-    node_ids = list(prompt.keys())
-    finished_nodes = []
-
-    while True:
-        out = ws.recv()
-        if isinstance(out, str):
-            message = json.loads(out)
-            if message['type'] == 'progress':
-                data = message['data']
-                current_step = data['value']
-                print('In K-Sampler -> Step: ', current_step, ' of: ', data['max'])
-            if message['type'] == 'execution_cached':
-                data = message['data']
-                for itm in data['nodes']:
-                    if itm not in finished_nodes:
-                        finished_nodes.append(itm)
-                        print('Progess: ', len(finished_nodes), '/', len(node_ids), ' Tasks done')
-            if message['type'] == 'executing':
-                data = message['data']
-                if data['node'] not in finished_nodes:
-                    finished_nodes.append(data['node'])
-                    print('Progess: ', len(finished_nodes), '/', len(node_ids), ' Tasks done')
-
-
-                if data['node'] is None and data['prompt_id'] == prompt_id:
-                    break #Execution is done
-        else:
-            continue #previews are binary data
-    return
-
-def queue_prompt(prompt):
-    p = {"prompt": prompt, "client_id": client_id}
-    data = json.dumps(p).encode('utf-8')
-    req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
-    return json.loads(urllib.request.urlopen(req).read())
-
-def interupt_prompt():
-    req = urllib.request.Request("http://{}/interrupt".format(server_address), data={})
-    return json.loads(urllib.request.urlopen(req).read())
-
-def get_image(filename, subfolder, folder_type):
-    data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
-    url_values = urllib.parse.urlencode(data)
-    with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
-        return response.read()
-
-def get_history(prompt_id):
-    with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
-        return json.loads(response.read())
-
-def get_images(ws, prompt_id, allow_preview = False):
-    output_images = []
-
-    history = get_history(prompt_id)[prompt_id]
-    for node_id in history['outputs']:
-        node_output = history['outputs'][node_id]
-        output_data = {}
-        if 'images' in node_output:
-            for image in node_output['images']:
-                if allow_preview and image['type'] == 'temp':
-                    preview_data = get_image(image['filename'], image['subfolder'], image['type'])
-                    output_data['image_data'] = preview_data
-                if image['type'] == 'output':
-                    image_data = get_image(image['filename'], image['subfolder'], image['type'])
-                    output_data['image_data'] = image_data
-        output_data['file_name'] = image['filename']
-        output_data['type'] = image['type']
-        output_images.append(output_data)
-
-    return output_images
-
-def get_node_info_by_class(node_class):
-    with urllib.request.urlopen("http://{}/object_info/{}".format(server_address, node_class)) as response:
-        return json.loads(response.read())
-
-def clear_comfy_cache(unload_models=False, free_memory=False):
-    clear_data = {
-        "unload_models": unload_models,
-        "free_memory": free_memory
-    }
-    data = json.dumps(clear_data).encode('utf-8')
-
-    with urllib.request.urlopen("http://{}/free".format(server_address), data=data) as response:
-        return response.read()
+def upload_image(input_path, name, server_address, image_type="input", overwrite=False):
+    with open(input_path, 'rb') as file:
+        multipart_data = MultipartEncoder(
+            fields={
+                'image': (name, file, 'image/png'),
+                'type': image_type,
+                'overwrite': str(overwrite).lower()
+            }
+        )
+
+        # Encode the multipart body to bytes so urllib can send it with a proper Content-Length
+        data = multipart_data.to_string()
+        headers = {'Content-Type': multipart_data.content_type}
+        request = urllib.request.Request("http://{}/upload/image".format(server_address), data=data, headers=headers)
+        with urllib.request.urlopen(request) as response:
+            return response.read()
+
+def queue_prompt(prompt, client_id, server_address):
+    p = {"prompt": prompt, "client_id": client_id}
+    headers = {'Content-Type': 'application/json'}
+    data = json.dumps(p).encode('utf-8')
+    req = urllib.request.Request("http://{}/prompt".format(server_address), data=data, headers=headers)
+    return json.loads(urllib.request.urlopen(req).read())
+
+def interupt_prompt(server_address):
+    # urllib needs a bytes body, not a dict; /interrupt replies with an empty body, so return the raw bytes
+    req = urllib.request.Request("http://{}/interrupt".format(server_address), data=b'')
+    return urllib.request.urlopen(req).read()
+
+def get_image(filename, subfolder, folder_type, server_address):
+    data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
+    url_values = urllib.parse.urlencode(data)
+    with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
+        return response.read()
+
+def get_history(prompt_id, server_address):
+    with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
+        return json.loads(response.read())
+
+def get_node_info_by_class(node_class, server_address):
+    with urllib.request.urlopen("http://{}/object_info/{}".format(server_address, node_class)) as response:
+        return json.loads(response.read())
+
+def clear_comfy_cache(server_address, unload_models=False, free_memory=False):
+    clear_data = {
+        "unload_models": unload_models,
+        "free_memory": free_memory
+    }
+    data = json.dumps(clear_data).encode('utf-8')
+
+    with urllib.request.urlopen("http://{}/free".format(server_address), data=data) as response:
+        return response.read()
 
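After this change the helpers take server_address (and, for queue_prompt, a client_id) as explicit arguments instead of module-level globals, so a caller imports them and wires them together itself. A minimal usage sketch, assuming a local ComfyUI server at 127.0.0.1:8188 and a workflow exported in API format to a hypothetical workflow_api.json; since the websocket-based progress tracking was removed, it simply polls /history until the prompt shows up there:

import json
import time
import uuid

server_address = "127.0.0.1:8188"   # assumed local ComfyUI instance
client_id = str(uuid.uuid4())

# Load a workflow exported via "Save (API Format)" (hypothetical file name)
with open("workflow_api.json") as f:
    prompt = json.load(f)

# Queue the prompt, then poll /history until the finished execution appears
prompt_id = queue_prompt(prompt, client_id, server_address)['prompt_id']
while prompt_id not in get_history(prompt_id, server_address):
    time.sleep(1)

# Download every output image recorded for this prompt
outputs = get_history(prompt_id, server_address)[prompt_id]['outputs']
for node_output in outputs.values():
    for image in node_output.get('images', []):
        if image['type'] == 'output':
            image_data = get_image(image['filename'], image['subfolder'], image['type'], server_address)
            with open(image['filename'], 'wb') as out_file:
                out_file.write(image_data)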