Skip to content

Commit 0b5dcb3

Browse files
committed
Fix an error that happens when typing into the prompt while switching models; move the queue logic into a separate file
1 parent 0376da1 commit 0b5dcb3

File tree

3 files changed

+104
-91
lines changed

3 files changed

+104
-91
lines changed

modules/call_queue.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import html
2+
import sys
3+
import threading
4+
import traceback
5+
import time
6+
7+
from modules import shared
8+
9+
# Single lock that serializes all GPU-bound work submitted through this module.
queue_lock = threading.Lock()


def wrap_queued_call(func):
    """Return *func* wrapped so that every call runs while holding queue_lock.

    This serializes concurrent callers: only one wrapped call executes at a
    time, which prevents e.g. the token counter from running while a model
    switch holds the lock.
    """
    def serialized(*args, **kwargs):
        with queue_lock:
            return func(*args, **kwargs)

    return serialized
20+
21+
22+
def wrap_gradio_gpu_call(func, extra_outputs=None):
    """Wrap a gradio handler that does GPU work.

    The call is serialized on queue_lock and bracketed by
    shared.state.begin()/end(), then passed through wrap_gradio_call for
    error reporting and run statistics.

    Parameters:
        func: the handler doing the actual work.
        extra_outputs: forwarded to wrap_gradio_call; outputs to substitute
            when func raises.
    """
    def f(*args, **kwargs):

        shared.state.begin()

        try:
            with queue_lock:
                res = func(*args, **kwargs)
        finally:
            # end() must run even when func raises, otherwise the job state
            # is left marked as running forever; the exception still
            # propagates to wrap_gradio_call's handler below.
            shared.state.end()

        return res

    return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
35+
36+
37+
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
    """Wrap a gradio handler with error reporting and optional run statistics.

    On exception, the call's arguments and traceback are logged to stderr and
    the handler's outputs are replaced by `extra_outputs` plus an HTML error
    div. With add_stats=True, an HTML block with elapsed time (and VRAM usage,
    when memory monitoring is enabled) is appended to the last output, which
    is assumed to be an HTML string.
    """
    def f(*args, extra_outputs_array=extra_outputs, **kwargs):
        # VRAM monitoring only when enabled in options, available, and stats were requested.
        run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
        if run_memmon:
            shared.mem_mon.monitor()
        t = time.perf_counter()

        try:
            # gradio handlers return a sequence of outputs; materialize it so
            # the stats code below can append to the last element.
            res = list(func(*args, **kwargs))
        except Exception as e:
            # When printing out our debug argument list, do not print out more than a MB of text
            max_debug_str_len = 131072 # (1024*1024)/8

            print("Error completing request", file=sys.stderr)
            argStr = f"Arguments: {str(args)} {str(kwargs)}"
            print(argStr[:max_debug_str_len], file=sys.stderr)
            if len(argStr) > max_debug_str_len:
                print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)

            print(traceback.format_exc(), file=sys.stderr)

            # Reset job bookkeeping so the UI does not show a stuck job.
            shared.state.job = ""
            shared.state.job_count = 0

            if extra_outputs_array is None:
                extra_outputs_array = [None, '']

            # Substitute outputs: placeholders plus an escaped error message div.
            res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]

        # Clear per-run flags regardless of success or failure.
        shared.state.skipped = False
        shared.state.interrupted = False
        shared.state.job_count = 0

        if not add_stats:
            return tuple(res)

        # Human-readable elapsed time, e.g. "1m 03.25s".
        elapsed = time.perf_counter() - t
        elapsed_m = int(elapsed // 60)
        elapsed_s = elapsed % 60
        elapsed_text = f"{elapsed_s:.2f}s"
        if elapsed_m > 0:
            elapsed_text = f"{elapsed_m}m "+elapsed_text

        if run_memmon:
            # Convert byte counts to MiB, rounding up (ceiling division).
            mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
            active_peak = mem_stats['active_peak']
            reserved_peak = mem_stats['reserved_peak']
            sys_peak = mem_stats['system_peak']
            sys_total = mem_stats['total']
            # max(..., 1) guards against division by zero when total is unknown.
            sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)

            vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
        else:
            vram_html = ''

        # last item is always HTML
        res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"

        return tuple(res)

    return f
98+

modules/ui.py

Lines changed: 3 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import gradio.utils
1818
import numpy as np
1919
from PIL import Image, PngImagePlugin
20-
20+
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
2121

2222
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
2323
from modules.paths import script_path
@@ -158,67 +158,6 @@ def __init__(self, d=None):
158158
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
159159

160160

161-
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
162-
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
163-
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
164-
if run_memmon:
165-
shared.mem_mon.monitor()
166-
t = time.perf_counter()
167-
168-
try:
169-
res = list(func(*args, **kwargs))
170-
except Exception as e:
171-
# When printing out our debug argument list, do not print out more than a MB of text
172-
max_debug_str_len = 131072 # (1024*1024)/8
173-
174-
print("Error completing request", file=sys.stderr)
175-
argStr = f"Arguments: {str(args)} {str(kwargs)}"
176-
print(argStr[:max_debug_str_len], file=sys.stderr)
177-
if len(argStr) > max_debug_str_len:
178-
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
179-
180-
print(traceback.format_exc(), file=sys.stderr)
181-
182-
shared.state.job = ""
183-
shared.state.job_count = 0
184-
185-
if extra_outputs_array is None:
186-
extra_outputs_array = [None, '']
187-
188-
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
189-
190-
shared.state.skipped = False
191-
shared.state.interrupted = False
192-
shared.state.job_count = 0
193-
194-
if not add_stats:
195-
return tuple(res)
196-
197-
elapsed = time.perf_counter() - t
198-
elapsed_m = int(elapsed // 60)
199-
elapsed_s = elapsed % 60
200-
elapsed_text = f"{elapsed_s:.2f}s"
201-
if elapsed_m > 0:
202-
elapsed_text = f"{elapsed_m}m "+elapsed_text
203-
204-
if run_memmon:
205-
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
206-
active_peak = mem_stats['active_peak']
207-
reserved_peak = mem_stats['reserved_peak']
208-
sys_peak = mem_stats['system_peak']
209-
sys_total = mem_stats['total']
210-
sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
211-
212-
vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
213-
else:
214-
vram_html = ''
215-
216-
# last item is always HTML
217-
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
218-
219-
return tuple(res)
220-
221-
return f
222161

223162

224163
def calc_time_left(progress, threshold, label, force_display):
@@ -666,7 +605,7 @@ def open_folder(f):
666605
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info
667606

668607

669-
def create_ui(wrap_gradio_gpu_call):
608+
def create_ui():
670609
import modules.img2img
671610
import modules.txt2img
672611

@@ -826,7 +765,7 @@ def create_ui(wrap_gradio_gpu_call):
826765
height,
827766
]
828767

829-
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
768+
token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
830769

831770
modules.scripts.scripts_current = modules.scripts.scripts_img2img
832771
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)

webui.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from fastapi.middleware.cors import CORSMiddleware
99
from fastapi.middleware.gzip import GZipMiddleware
1010

11+
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
1112
from modules.paths import script_path
1213

1314
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
@@ -32,38 +33,12 @@
3233
import modules.hypernetworks.hypernetwork
3334

3435

35-
queue_lock = threading.Lock()
3636
if cmd_opts.server_name:
3737
server_name = cmd_opts.server_name
3838
else:
3939
server_name = "0.0.0.0" if cmd_opts.listen else None
4040

4141

42-
def wrap_queued_call(func):
43-
def f(*args, **kwargs):
44-
with queue_lock:
45-
res = func(*args, **kwargs)
46-
47-
return res
48-
49-
return f
50-
51-
52-
def wrap_gradio_gpu_call(func, extra_outputs=None):
53-
def f(*args, **kwargs):
54-
55-
shared.state.begin()
56-
57-
with queue_lock:
58-
res = func(*args, **kwargs)
59-
60-
shared.state.end()
61-
62-
return res
63-
64-
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
65-
66-
6742
def initialize():
6843
extensions.list_extensions()
6944
localization.list_localizations(cmd_opts.localizations_dir)
@@ -159,7 +134,7 @@ def webui():
159134
if shared.opts.clean_temp_dir_at_start:
160135
ui_tempdir.cleanup_tmpdr()
161136

162-
shared.demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
137+
shared.demo = modules.ui.create_ui()
163138

164139
app, local_url, share_url = shared.demo.launch(
165140
share=cmd_opts.share,
@@ -189,6 +164,7 @@ def webui():
189164
create_api(app)
190165

191166
modules.script_callbacks.app_started_callback(shared.demo, app)
167+
modules.script_callbacks.app_started_callback(shared.demo, app)
192168

193169
wait_on_server(shared.demo)
194170

0 commit comments

Comments
 (0)