|
| 1 | +#This is an example that uses the websockets api to know when a prompt execution is done |
| 2 | +#Once the prompt execution is done it downloads the images using the /history endpoint |
| 3 | + |
| 4 | +import uuid |
| 5 | +import json |
| 6 | +import urllib.request |
| 7 | +import urllib.parse |
| 8 | +import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client) |
| 9 | +from PIL import Image |
| 10 | +import io |
| 11 | + |
| 12 | + |
| 13 | +server_address='127.0.0.1:8188' |
| 14 | +client_id=str(uuid.uuid4()) |
| 15 | + |
| 16 | +ws = websocket.WebSocket() |
| 17 | +ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id)) |
| 18 | + |
| 19 | +def generate_image_by_prompt(prompt, output_path, save_previews=False): |
| 20 | + prompt_id = queue_prompt(prompt)['prompt_id'] |
| 21 | + images = get_images(ws, prompt_id, save_previews) |
| 22 | + save_image(images, output_path, save_previews) |
| 23 | + |
| 24 | +def save_image(images, output_path, save_previews): |
| 25 | + for itm in images: |
| 26 | + if itm['type'] == 'temp' and save_previews: |
| 27 | + image = Image.open(io.BytesIO(itm['image_data'])) |
| 28 | + image.save(output_path + 'temp/' + itm['file_name']) |
| 29 | + else: |
| 30 | + image = Image.open(io.BytesIO(itm['image_data'])) |
| 31 | + image.save(output_path + itm['file_name']) |
| 32 | + |
| 33 | +def queue_prompt(prompt): |
| 34 | + p = {"prompt": prompt, "client_id": client_id} |
| 35 | + data = json.dumps(p).encode('utf-8') |
| 36 | + req = urllib.request.Request("http://{}/prompt".format(server_address), data=data) |
| 37 | + return json.loads(urllib.request.urlopen(req).read()) |
| 38 | + |
| 39 | +def interupt_prompt(): |
| 40 | + req = urllib.request.Request("http://{}/interrupt".format(server_address), data={}) |
| 41 | + return json.loads(urllib.request.urlopen(req).read()) |
| 42 | + |
| 43 | +def get_image(filename, subfolder, folder_type): |
| 44 | + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} |
| 45 | + url_values = urllib.parse.urlencode(data) |
| 46 | + with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response: |
| 47 | + return response.read() |
| 48 | + |
| 49 | +def get_history(prompt_id): |
| 50 | + with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response: |
| 51 | + return json.loads(response.read()) |
| 52 | + |
| 53 | +def get_images(ws, prompt_id, allow_preview = False): |
| 54 | + output_images = [] |
| 55 | + while True: |
| 56 | + out = ws.recv() |
| 57 | + if isinstance(out, str): |
| 58 | + message = json.loads(out) |
| 59 | + if message['type'] == 'executing': |
| 60 | + data = message['data'] |
| 61 | + if data['node'] is None and data['prompt_id'] == prompt_id: |
| 62 | + break #Execution is done |
| 63 | + else: |
| 64 | + continue #previews are binary data |
| 65 | + |
| 66 | + |
| 67 | + history = get_history(prompt_id)[prompt_id] |
| 68 | + for node_id in history['outputs']: |
| 69 | + node_output = history['outputs'][node_id] |
| 70 | + output_data = {} |
| 71 | + if 'images' in node_output: |
| 72 | + for image in node_output['images']: |
| 73 | + if allow_preview and image['type'] == 'temp': |
| 74 | + preview_data = get_image(image['filename'], image['subfolder'], image['type']) |
| 75 | + output_data['image_data'] = preview_data |
| 76 | + if image['type'] == 'output': |
| 77 | + image_data = get_image(image['filename'], image['subfolder'], image['type']) |
| 78 | + output_data['image_data'] = image_data |
| 79 | + output_data['file_name'] = image['filename'] |
| 80 | + output_data['type'] = image['type'] |
| 81 | + output_images.append(output_data) |
| 82 | + |
| 83 | + return output_images |
| 84 | + |
| 85 | + |
| 86 | + |
0 commit comments