88import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
99from PIL import Image
1010import io
11+ from utils .helpers .find_node import find_node
1112
1213
1314server_address = '127.0.0.1:8188'
1819
1920def generate_image_by_prompt (prompt , output_path , save_previews = False ):
2021 prompt_id = queue_prompt (prompt )['prompt_id' ]
22+ track_progress (prompt , ws , prompt_id )
2123 images = get_images (ws , prompt_id , save_previews )
2224 save_image (images , output_path , save_previews )
2325
@@ -30,6 +32,38 @@ def save_image(images, output_path, save_previews):
3032 image = Image .open (io .BytesIO (itm ['image_data' ]))
3133 image .save (output_path + itm ['file_name' ])
3234
35+
36+ def track_progress (prompt , ws , prompt_id ):
37+ node_ids = list (prompt .keys ())
38+ finished_nodes = []
39+
40+ while True :
41+ out = ws .recv ()
42+ if isinstance (out , str ):
43+ message = json .loads (out )
44+ if message ['type' ] == 'progress' :
45+ data = message ['data' ]
46+ current_step = data ['value' ]
47+ print ('In K-Sampler -> Step: ' , current_step , ' of: ' , data ['max' ])
48+ if message ['type' ] == 'execution_cached' :
49+ data = message ['data' ]
50+ for itm in data ['nodes' ]:
51+ if itm not in finished_nodes :
52+ finished_nodes .append (itm )
53+ print ('Progess: ' , len (finished_nodes ), '/' , len (node_ids ), ' Tasks done' )
54+ if message ['type' ] == 'executing' :
55+ data = message ['data' ]
56+ if data ['node' ] not in finished_nodes :
57+ finished_nodes .append (data ['node' ])
58+ print ('Progess: ' , len (finished_nodes ), '/' , len (node_ids ), ' Tasks done' )
59+
60+
61+ if data ['node' ] is None and data ['prompt_id' ] == prompt_id :
62+ break #Execution is done
63+ else :
64+ continue #previews are binary data
65+ return
66+
3367def queue_prompt (prompt ):
3468 p = {"prompt" : prompt , "client_id" : client_id }
3569 data = json .dumps (p ).encode ('utf-8' )
@@ -52,17 +86,6 @@ def get_history(prompt_id):
5286
5387def get_images (ws , prompt_id , allow_preview = False ):
5488 output_images = []
55- while True :
56- out = ws .recv ()
57- if isinstance (out , str ):
58- message = json .loads (out )
59- if message ['type' ] == 'executing' :
60- data = message ['data' ]
61- if data ['node' ] is None and data ['prompt_id' ] == prompt_id :
62- break #Execution is done
63- else :
64- continue #previews are binary data
65-
6689
6790 history = get_history (prompt_id )[prompt_id ]
6891 for node_id in history ['outputs' ]:
@@ -82,5 +105,17 @@ def get_images(ws, prompt_id, allow_preview = False):
82105
83106 return output_images
84107
108+ def get_node_info_by_class (node_class ):
109+ with urllib .request .urlopen ("http://{}/object_info/{}" .format (server_address , node_class )) as response :
110+ return json .loads (response .read ())
111+
112+ def clear_comfy_cache (unload_models = False , free_memory = False ):
113+ clear_data = {
114+ "unload_models" : unload_models ,
115+ "free_memory" : free_memory
116+ }
117+ data = json .dumps (clear_data ).encode ('utf-8' )
85118
119+ with urllib .request .urlopen ("http://{}/free" .format (server_address ), data = data ) as response :
120+ return response .read ()
86121
0 commit comments