Skip to content

Commit 3596af0

Browse files
committed
Add API for scripts to add elements anywhere in UI.
1 parent ccd73fc commit 3596af0

File tree

3 files changed

+111
-5
lines changed

3 files changed

+111
-5
lines changed

modules/script_callbacks.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def __init__(self, txt2img_preview_params):
6161
callbacks_before_image_saved=[],
6262
callbacks_image_saved=[],
6363
callbacks_cfg_denoiser=[],
64+
callbacks_before_component=[],
65+
callbacks_after_component=[],
6466
)
6567

6668

@@ -137,6 +139,22 @@ def cfg_denoiser_callback(params: CFGDenoiserParams):
137139
report_exception(c, 'cfg_denoiser_callback')
138140

139141

142+
def before_component_callback(component, **kwargs):
143+
for c in callback_map['callbacks_before_component']:
144+
try:
145+
c.callback(component, **kwargs)
146+
except Exception:
147+
report_exception(c, 'before_component_callback')
148+
149+
150+
def after_component_callback(component, **kwargs):
151+
for c in callback_map['callbacks_after_component']:
152+
try:
153+
c.callback(component, **kwargs)
154+
except Exception:
155+
report_exception(c, 'after_component_callback')
156+
157+
140158
def add_callback(callbacks, fun):
141159
stack = [x for x in inspect.stack() if x.filename != __file__]
142160
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -220,3 +238,20 @@ def on_cfg_denoiser(callback):
220238
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
221239
"""
222240
add_callback(callback_map['callbacks_cfg_denoiser'], callback)
241+
242+
243+
def on_before_component(callback):
244+
"""register a function to be called before a component is created.
245+
The callback is called with arguments:
246+
- component - gradio component that is about to be created.
247+
- **kwargs - args to gradio.components.IOComponent.__init__ function
248+
249+
Use elem_id/label fields of kwargs to figure out which component it is.
250+
This can be useful to inject your own components somewhere in the middle of vanilla UI.
251+
"""
252+
add_callback(callback_map['callbacks_before_component'], callback)
253+
254+
255+
def on_after_component(callback):
256+
"""register a function to be called after a component is created. See on_before_component for more."""
257+
add_callback(callback_map['callbacks_after_component'], callback)

modules/scripts.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ class Script:
1717
args_to = None
1818
alwayson = False
1919

20+
is_txt2img = False
21+
is_img2img = False
22+
2023
"""A gr.Group component that has all script's UI inside it"""
2124
group = None
2225

@@ -93,6 +96,23 @@ def postprocess(self, p, processed, *args):
9396

9497
pass
9598

99+
def before_component(self, component, **kwargs):
100+
"""
101+
Called before a component is created.
102+
Use elem_id/label fields of kwargs to figure out which component it is.
103+
This can be useful to inject your own components somewhere in the middle of vanilla UI.
104+
You can return created components in the ui() function to add them to the list of arguments for your processing functions
105+
"""
106+
107+
pass
108+
109+
def after_component(self, component, **kwargs):
110+
"""
111+
Called after a component is created. Same as above.
112+
"""
113+
114+
pass
115+
96116
def describe(self):
97117
"""unused"""
98118
return ""
@@ -195,12 +215,18 @@ def __init__(self):
195215
self.titles = []
196216
self.infotext_fields = []
197217

198-
def setup_ui(self, is_img2img):
218+
def initialize_scripts(self, is_img2img):
219+
self.scripts.clear()
220+
self.alwayson_scripts.clear()
221+
self.selectable_scripts.clear()
222+
199223
for script_class, path, basedir in scripts_data:
200224
script = script_class()
201225
script.filename = path
226+
script.is_txt2img = not is_img2img
227+
script.is_img2img = is_img2img
202228

203-
visibility = script.show(is_img2img)
229+
visibility = script.show(script.is_img2img)
204230

205231
if visibility == AlwaysVisible:
206232
self.scripts.append(script)
@@ -211,6 +237,7 @@ def setup_ui(self, is_img2img):
211237
self.scripts.append(script)
212238
self.selectable_scripts.append(script)
213239

240+
def setup_ui(self):
214241
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
215242

216243
inputs = [None]
@@ -220,7 +247,7 @@ def create_script_ui(script, inputs, inputs_alwayson):
220247
script.args_from = len(inputs)
221248
script.args_to = len(inputs)
222249

223-
controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
250+
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
224251

225252
if controls is None:
226253
return
@@ -320,6 +347,22 @@ def postprocess(self, p, processed):
320347
print(f"Error running postprocess: {script.filename}", file=sys.stderr)
321348
print(traceback.format_exc(), file=sys.stderr)
322349

350+
def before_component(self, component, **kwargs):
351+
for script in self.scripts:
352+
try:
353+
script.before_component(component, **kwargs)
354+
except Exception:
355+
print(f"Error running before_component: {script.filename}", file=sys.stderr)
356+
print(traceback.format_exc(), file=sys.stderr)
357+
358+
def after_component(self, component, **kwargs):
359+
for script in self.scripts:
360+
try:
361+
script.after_component(component, **kwargs)
362+
except Exception:
363+
print(f"Error running after_component: {script.filename}", file=sys.stderr)
364+
print(traceback.format_exc(), file=sys.stderr)
365+
323366
def reload_sources(self, cache):
324367
for si, script in list(enumerate(self.scripts)):
325368
args_from = script.args_from
@@ -341,6 +384,7 @@ def reload_sources(self, cache):
341384

342385
scripts_txt2img = ScriptRunner()
343386
scripts_img2img = ScriptRunner()
387+
scripts_current: ScriptRunner = None
344388

345389

346390
def reload_script_body_only():
@@ -357,3 +401,22 @@ def reload_scripts():
357401
scripts_txt2img = ScriptRunner()
358402
scripts_img2img = ScriptRunner()
359403

404+
405+
def IOComponent_init(self, *args, **kwargs):
406+
if scripts_current is not None:
407+
scripts_current.before_component(self, **kwargs)
408+
409+
script_callbacks.before_component_callback(self, **kwargs)
410+
411+
res = original_IOComponent_init(self, *args, **kwargs)
412+
413+
script_callbacks.after_component_callback(self, **kwargs)
414+
415+
if scripts_current is not None:
416+
scripts_current.after_component(self, **kwargs)
417+
418+
return res
419+
420+
421+
original_IOComponent_init = gr.components.IOComponent.__init__
422+
gr.components.IOComponent.__init__ = IOComponent_init

modules/ui.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,9 @@ def create_ui(wrap_gradio_gpu_call):
695695

696696
parameters_copypaste.reset()
697697

698+
modules.scripts.scripts_current = modules.scripts.scripts_txt2img
699+
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
700+
698701
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
699702
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
700703
dummy_component = gr.Label(visible=False)
@@ -737,7 +740,7 @@ def create_ui(wrap_gradio_gpu_call):
737740
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
738741

739742
with gr.Group():
740-
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
743+
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
741744

742745
txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples)
743746
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
@@ -846,6 +849,9 @@ def create_ui(wrap_gradio_gpu_call):
846849

847850
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
848851

852+
modules.scripts.scripts_current = modules.scripts.scripts_img2img
853+
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
854+
849855
with gr.Blocks(analytics_enabled=False) as img2img_interface:
850856
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True)
851857

@@ -916,7 +922,7 @@ def create_ui(wrap_gradio_gpu_call):
916922
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
917923

918924
with gr.Group():
919-
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
925+
custom_inputs = modules.scripts.scripts_img2img.setup_ui()
920926

921927
img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples)
922928
parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
@@ -1065,6 +1071,8 @@ def create_ui(wrap_gradio_gpu_call):
10651071
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
10661072
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
10671073

1074+
modules.scripts.scripts_current = None
1075+
10681076
with gr.Blocks(analytics_enabled=False) as extras_interface:
10691077
with gr.Row().style(equal_height=False):
10701078
with gr.Column(variant='panel'):

0 commit comments

Comments
 (0)