Skip to content

Commit 95293ea

Browse files
committed
first steps ws api
1 parent 54f592b commit 95293ea

File tree

15 files changed

+291
-0
lines changed

15 files changed

+291
-0
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
output*
2+
__pycache__/

api/__init__.py

Whitespace-only changes.

api/websocket_api.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#This is an example that uses the websockets api to know when a prompt execution is done
2+
#Once the prompt execution is done it downloads the images using the /history endpoint
3+
4+
import uuid
5+
import json
6+
import urllib.request
7+
import urllib.parse
8+
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
9+
from PIL import Image
10+
import io
11+
12+
13+
server_address='127.0.0.1:8188'
14+
client_id=str(uuid.uuid4())
15+
16+
ws = websocket.WebSocket()
17+
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
18+
19+
def generate_image_by_prompt(prompt, output_path, save_previews=False):
20+
prompt_id = queue_prompt(prompt)['prompt_id']
21+
images = get_images(ws, prompt_id, save_previews)
22+
save_image(images, output_path, save_previews)
23+
24+
def save_image(images, output_path, save_previews):
25+
for itm in images:
26+
if itm['type'] == 'temp' and save_previews:
27+
image = Image.open(io.BytesIO(itm['image_data']))
28+
image.save(output_path + 'temp/' + itm['file_name'])
29+
else:
30+
image = Image.open(io.BytesIO(itm['image_data']))
31+
image.save(output_path + itm['file_name'])
32+
33+
def queue_prompt(prompt):
34+
p = {"prompt": prompt, "client_id": client_id}
35+
data = json.dumps(p).encode('utf-8')
36+
req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
37+
return json.loads(urllib.request.urlopen(req).read())
38+
39+
def interupt_prompt():
40+
req = urllib.request.Request("http://{}/interrupt".format(server_address), data={})
41+
return json.loads(urllib.request.urlopen(req).read())
42+
43+
def get_image(filename, subfolder, folder_type):
44+
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
45+
url_values = urllib.parse.urlencode(data)
46+
with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
47+
return response.read()
48+
49+
def get_history(prompt_id):
50+
with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
51+
return json.loads(response.read())
52+
53+
def get_images(ws, prompt_id, allow_preview = False):
54+
output_images = []
55+
while True:
56+
out = ws.recv()
57+
if isinstance(out, str):
58+
message = json.loads(out)
59+
if message['type'] == 'executing':
60+
data = message['data']
61+
if data['node'] is None and data['prompt_id'] == prompt_id:
62+
break #Execution is done
63+
else:
64+
continue #previews are binary data
65+
66+
67+
history = get_history(prompt_id)[prompt_id]
68+
for node_id in history['outputs']:
69+
node_output = history['outputs'][node_id]
70+
output_data = {}
71+
if 'images' in node_output:
72+
for image in node_output['images']:
73+
if allow_preview and image['type'] == 'temp':
74+
preview_data = get_image(image['filename'], image['subfolder'], image['type'])
75+
output_data['image_data'] = preview_data
76+
if image['type'] == 'output':
77+
image_data = get_image(image['filename'], image['subfolder'], image['type'])
78+
output_data['image_data'] = image_data
79+
output_data['file_name'] = image['filename']
80+
output_data['type'] = image['type']
81+
output_images.append(output_data)
82+
83+
return output_images
84+
85+
86+

main.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from utils.actions.prompt_to_image import prompt_to_image
2+
from utils.actions.load_workflow import load_workflow
3+
from utils.actions.interrupt_prompt import interrupt
4+
import time
5+
import sys
6+
7+
def main():
8+
try:
9+
print("Welcome to the program!")
10+
workflow = load_workflow('./workflows/test_workflow.json')
11+
12+
prompt_to_image(workflow, 'a tiny little green and blue monster wearing a coat and a little hat, very cute with sharp teeths', True)
13+
except Exception as e:
14+
print(f"An error occurred: {e}")
15+
exit_program()
16+
17+
def exit_program():
18+
print("Exiting the program...")
19+
interrupt()
20+
sys.exit(0)
21+
22+
main()

requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Pillow==10.0.0
2+
Pillow==10.2.0
3+
websocket_client==1.7.0

utils/__init__.py

Whitespace-only changes.

utils/actions/__init__.py

Whitespace-only changes.

utils/actions/interrupt_prompt.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from api.websocket_api import interupt_prompt
2+
def interrupt():
3+
interupt_prompt()

utils/actions/load_workflow.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import json
2+
3+
def load_workflow(workflow_path):
4+
file = open(workflow_path)
5+
workflow = json.load(file)
6+
workflow = json.dumps(workflow)
7+
8+
return workflow

utils/actions/prompt_to_image.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from api.websocket_api import generate_image_by_prompt
2+
from utils.helpers.find_node import find_node
3+
from utils.helpers.randomize_seed import generate_random_15_digit_number
4+
from utils.helpers.replace_key import replace_key
5+
import json
6+
def prompt_to_image(workflow, positve_prompt, negative_prompt='', save_previews=False):
7+
prompt = json.loads(workflow)
8+
replace_key(prompt, 'seed', generate_random_15_digit_number())
9+
postive_prompt_id = find_node(prompt, 'positive')[0]
10+
positive_prompt_node = find_node(prompt, postive_prompt_id)
11+
12+
if negative_prompt != '':
13+
negative_prompt_id = find_node(prompt, 'negative')[0]
14+
negative_prompt_node = find_node(prompt, negative_prompt_id)
15+
negative_prompt_node['inputs']['text'] = negative_prompt
16+
17+
positive_prompt_node['inputs']['text'] = positve_prompt
18+
19+
generate_image_by_prompt(prompt, './output/', save_previews)
20+
21+
22+
23+

0 commit comments

Comments
 (0)