Skip to content

Commit 1610b32

Browse files
committed
add callback for creating a tab in train UI
1 parent 8011be3 commit 1610b32

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

modules/script_callbacks.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from fastapi import FastAPI
88
from gradio import Blocks
99

10+
1011
def report_exception(c, job):
1112
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
1213
print(traceback.format_exc(), file=sys.stderr)
@@ -45,22 +46,29 @@ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
4546
"""Total number of sampling steps planned"""
4647

4748

49+
class UiTrainTabParams:
50+
def __init__(self, txt2img_preview_params):
51+
self.txt2img_preview_params = txt2img_preview_params
52+
53+
4854
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
4955
callback_map = dict(
5056
callbacks_app_started=[],
5157
callbacks_model_loaded=[],
5258
callbacks_ui_tabs=[],
59+
callbacks_ui_train_tabs=[],
5360
callbacks_ui_settings=[],
5461
callbacks_before_image_saved=[],
5562
callbacks_image_saved=[],
56-
callbacks_cfg_denoiser=[]
63+
callbacks_cfg_denoiser=[],
5764
)
5865

5966

6067
def clear_callbacks():
6168
for callback_list in callback_map.values():
6269
callback_list.clear()
6370

71+
6472
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
6573
for c in callback_map['callbacks_app_started']:
6674
try:
@@ -79,7 +87,7 @@ def model_loaded_callback(sd_model):
7987

8088
def ui_tabs_callback():
8189
res = []
82-
90+
8391
for c in callback_map['callbacks_ui_tabs']:
8492
try:
8593
res += c.callback() or []
@@ -89,6 +97,14 @@ def ui_tabs_callback():
8997
return res
9098

9199

100+
def ui_train_tabs_callback(params: UiTrainTabParams):
101+
for c in callback_map['callbacks_ui_train_tabs']:
102+
try:
103+
c.callback(params)
104+
except Exception:
105+
report_exception(c, 'callbacks_ui_train_tabs')
106+
107+
92108
def ui_settings_callback():
93109
for c in callback_map['callbacks_ui_settings']:
94110
try:
@@ -169,6 +185,13 @@ def on_ui_tabs(callback):
169185
add_callback(callback_map['callbacks_ui_tabs'], callback)
170186

171187

188+
def on_ui_train_tabs(callback):
189+
"""register a function to be called when the UI is creating new tabs for the train tab.
190+
Create your new tabs with gr.Tab.
191+
"""
192+
add_callback(callback_map['callbacks_ui_train_tabs'], callback)
193+
194+
172195
def on_ui_settings(callback):
173196
"""register a function to be called before UI settings are populated; add your settings
174197
by using shared.opts.add_option(shared.OptionInfo(...)) """

modules/ui.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,6 +1270,10 @@ def create_ui(wrap_gradio_gpu_call):
12701270
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
12711271
train_embedding = gr.Button(value="Train Embedding", variant='primary')
12721272

1273+
params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
1274+
1275+
script_callbacks.ui_train_tabs_callback(params)
1276+
12731277
with gr.Column():
12741278
progressbar = gr.HTML(elem_id="ti_progressbar")
12751279
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)

0 commit comments

Comments
 (0)