|
17 | 17 | import gradio.utils
|
18 | 18 | import numpy as np
|
19 | 19 | from PIL import Image, PngImagePlugin
|
20 |
| - |
| 20 | +from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call |
21 | 21 |
|
22 | 22 | from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
|
23 | 23 | from modules.paths import script_path
|
@@ -158,67 +158,6 @@ def __init__(self, d=None):
|
158 | 158 | return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
159 | 159 |
|
160 | 160 |
|
161 |
| -def wrap_gradio_call(func, extra_outputs=None, add_stats=False): |
162 |
| - def f(*args, extra_outputs_array=extra_outputs, **kwargs): |
163 |
| - run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats |
164 |
| - if run_memmon: |
165 |
| - shared.mem_mon.monitor() |
166 |
| - t = time.perf_counter() |
167 |
| - |
168 |
| - try: |
169 |
| - res = list(func(*args, **kwargs)) |
170 |
| - except Exception as e: |
171 |
| - # When printing out our debug argument list, do not print out more than a MB of text |
172 |
| - max_debug_str_len = 131072 # (1024*1024)/8 |
173 |
| - |
174 |
| - print("Error completing request", file=sys.stderr) |
175 |
| - argStr = f"Arguments: {str(args)} {str(kwargs)}" |
176 |
| - print(argStr[:max_debug_str_len], file=sys.stderr) |
177 |
| - if len(argStr) > max_debug_str_len: |
178 |
| - print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) |
179 |
| - |
180 |
| - print(traceback.format_exc(), file=sys.stderr) |
181 |
| - |
182 |
| - shared.state.job = "" |
183 |
| - shared.state.job_count = 0 |
184 |
| - |
185 |
| - if extra_outputs_array is None: |
186 |
| - extra_outputs_array = [None, ''] |
187 |
| - |
188 |
| - res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"] |
189 |
| - |
190 |
| - shared.state.skipped = False |
191 |
| - shared.state.interrupted = False |
192 |
| - shared.state.job_count = 0 |
193 |
| - |
194 |
| - if not add_stats: |
195 |
| - return tuple(res) |
196 |
| - |
197 |
| - elapsed = time.perf_counter() - t |
198 |
| - elapsed_m = int(elapsed // 60) |
199 |
| - elapsed_s = elapsed % 60 |
200 |
| - elapsed_text = f"{elapsed_s:.2f}s" |
201 |
| - if elapsed_m > 0: |
202 |
| - elapsed_text = f"{elapsed_m}m "+elapsed_text |
203 |
| - |
204 |
| - if run_memmon: |
205 |
| - mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} |
206 |
| - active_peak = mem_stats['active_peak'] |
207 |
| - reserved_peak = mem_stats['reserved_peak'] |
208 |
| - sys_peak = mem_stats['system_peak'] |
209 |
| - sys_total = mem_stats['total'] |
210 |
| - sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2) |
211 |
| - |
212 |
| - vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>" |
213 |
| - else: |
214 |
| - vram_html = '' |
215 |
| - |
216 |
| - # last item is always HTML |
217 |
| - res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>" |
218 |
| - |
219 |
| - return tuple(res) |
220 |
| - |
221 |
| - return f |
222 | 161 |
|
223 | 162 |
|
224 | 163 | def calc_time_left(progress, threshold, label, force_display):
|
@@ -666,7 +605,7 @@ def open_folder(f):
|
666 | 605 | return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info
|
667 | 606 |
|
668 | 607 |
|
669 |
| -def create_ui(wrap_gradio_gpu_call): |
| 608 | +def create_ui(): |
670 | 609 | import modules.img2img
|
671 | 610 | import modules.txt2img
|
672 | 611 |
|
@@ -826,7 +765,7 @@ def create_ui(wrap_gradio_gpu_call):
|
826 | 765 | height,
|
827 | 766 | ]
|
828 | 767 |
|
829 |
| - token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) |
| 768 | + token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) |
830 | 769 |
|
831 | 770 | modules.scripts.scripts_current = modules.scripts.scripts_img2img
|
832 | 771 | modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
|
|
0 commit comments