Skip to content

Commit 6165f07

Browse files
authored
Merge branch 'master' into patch-1
2 parents c556d34 + e666220 commit 6165f07

File tree

10 files changed

+114
-17
lines changed

10 files changed

+114
-17
lines changed

javascript/generationParams.js

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes

let txt2img_gallery, img2img_gallery, modal = undefined;
onUiUpdate(function(){
    if (!txt2img_gallery) {
        txt2img_gallery = attachGalleryListeners("txt2img")
    }
    if (!img2img_gallery) {
        img2img_gallery = attachGalleryListeners("img2img")
    }
    if (!modal) {
        modal = gradioApp().getElementById('lightboxModal')
        modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] });
    }
});

// refresh the generation-info text when the lightbox modal is hidden again
let modalObserver = new MutationObserver(function(mutations) {
    mutations.forEach(function(mutationRecord) {
        let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
        // Parentheses are required: without them `a && b || c` evaluates as
        // `(a && b) || c`, so the click fired on every style mutation while the
        // img2img tab was selected, even with the modal still open.
        if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img'))
            gradioApp().getElementById(selectedTab+"_generation_info_button").click()
    });
});

// wires click/arrow-key listeners onto a tab's gallery; returns the gallery
// element (or null/undefined if it is not in the DOM yet)
function attachGalleryListeners(tab_name) {
    // `let` keeps gallery function-local (it was an accidental implicit global)
    let gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
    gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
    gallery?.addEventListener('keydown', (e) => {
        if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
            gradioApp().getElementById(tab_name+"_generation_info_button").click()
    });
    return gallery;
}

modules/api/api.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from modules.realesrgan_model import get_realesrgan_models
1616
from typing import List
1717

18+
if shared.cmd_opts.deepdanbooru:
19+
from modules.deepbooru import get_deepbooru_tags
20+
1821
def upscaler_to_index(name: str):
1922
try:
2023
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
@@ -220,11 +223,20 @@ def interrogateapi(self, interrogatereq: InterrogateRequest):
220223
if image_b64 is None:
221224
raise HTTPException(status_code=404, detail="Image not found")
222225

223-
img = self.__base64_to_image(image_b64)
226+
img = decode_base64_to_image(image_b64)
227+
img = img.convert('RGB')
224228

225229
# Override object param
226230
with self.queue_lock:
227-
processed = shared.interrogator.interrogate(img)
231+
if interrogatereq.model == "clip":
232+
processed = shared.interrogator.interrogate(img)
233+
elif interrogatereq.model == "deepdanbooru":
234+
if shared.cmd_opts.deepdanbooru:
235+
processed = get_deepbooru_tags(img)
236+
else:
237+
raise HTTPException(status_code=404, detail="Model not found. Add --deepdanbooru when launching for using the model.")
238+
else:
239+
raise HTTPException(status_code=404, detail="Model not found")
228240

229241
return InterrogateResponse(caption=processed)
230242

modules/api/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ class ProgressResponse(BaseModel):
170170

171171
class InterrogateRequest(BaseModel):
172172
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
173+
model: str = Field(default="clip", title="Model", description="The interrogate model used.")
173174

174175
class InterrogateResponse(BaseModel):
175176
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")

modules/ngrok.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
11
from pyngrok import ngrok, conf, exception
22

3-
43
def connect(token, port, region):
4+
account = None
55
if token == None:
66
token = 'None'
7+
else:
8+
if ':' in token:
9+
# token = authtoken:username:password
10+
account = token.split(':')[1] + ':' + token.split(':')[-1]
11+
token = token.split(':')[0]
12+
713
config = conf.PyngrokConfig(
814
auth_token=token, region=region
915
)
1016
try:
11-
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
17+
if account == None:
18+
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
19+
else:
20+
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
1221
except exception.PyngrokNgrokError:
1322
print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
1423
f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')

modules/sd_models.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,21 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
163163
checkpoint_file = checkpoint_info.filename
164164
sd_model_hash = checkpoint_info.hash
165165

166-
if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"):
166+
cache_enabled = shared.opts.sd_checkpoint_cache > 0
167+
168+
if cache_enabled:
167169
sd_vae.restore_base_vae(model)
168-
checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy()
169170

170171
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
171172

172-
if checkpoint_info not in checkpoints_loaded:
173+
if cache_enabled and checkpoint_info in checkpoints_loaded:
174+
# use checkpoint cache
175+
vae_name = sd_vae.get_filename(vae_file) if vae_file else None
176+
vae_message = f" with {vae_name} VAE" if vae_name else ""
177+
print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
178+
model.load_state_dict(checkpoints_loaded[checkpoint_info])
179+
else:
180+
# load from file
173181
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
174182

175183
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
@@ -180,6 +188,10 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
180188
del pl_sd
181189
model.load_state_dict(sd, strict=False)
182190
del sd
191+
192+
if cache_enabled:
193+
# cache newly loaded model
194+
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
183195

184196
if shared.cmd_opts.opt_channelslast:
185197
model.to(memory_format=torch.channels_last)
@@ -199,14 +211,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
199211

200212
model.first_stage_model.to(devices.dtype_vae)
201213

202-
else:
203-
vae_name = sd_vae.get_filename(vae_file) if vae_file else None
204-
vae_message = f" with {vae_name} VAE" if vae_name else ""
205-
print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
206-
model.load_state_dict(checkpoints_loaded[checkpoint_info])
207-
208-
if shared.opts.sd_checkpoint_cache > 0:
209-
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
214+
# clean up cache if limit is reached
215+
if cache_enabled:
216+
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
210217
checkpoints_loaded.popitem(last=False) # LRU
211218

212219
model.sd_model_hash = sd_model_hash

modules/shared.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,8 @@ def options_section(section_identifier, options_dict):
319319

320320
options_templates.update(options_section(('training', "Training"), {
321321
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
322+
"shuffle_tags": OptionInfo(False, "Shuffleing tags by ',' when create texts."),
323+
"tag_drop_out": OptionInfo(0, "Dropout tags when create texts", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.1}),
322324
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
323325
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
324326
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),

modules/textual_inversion/dataset.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,12 @@ def shuffle(self):
9898
def create_text(self, filename_text):
    """Build one training caption from a random template line and the file's tags.

    A template chosen from ``self.lines`` has ``[name]`` replaced by the
    placeholder token and ``[filewords]`` replaced by the comma-separated
    tags parsed from ``filename_text`` (after optional dropout/shuffling
    controlled by ``shared.opts``).
    """
    caption = random.choice(self.lines)
    caption = caption.replace("[name]", self.placeholder_token)
    tags = filename_text.split(',')
    if shared.opts.tag_drop_out != 0:
        # each tag survives independently with probability 1 - tag_drop_out
        tags = [tag for tag in tags if random.random() > shared.opts.tag_drop_out]
    if shared.opts.shuffle_tags:
        random.shuffle(tags)
    return caption.replace("[filewords]", ','.join(tags))
103108

104109
def __len__(self):

modules/ui.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,19 @@ def apply_setting(key, value):
566566
return value
567567

568568

569+
def update_generation_info(args):
    """Return the infotext HTML for the currently selected gallery image.

    ``args`` is a ``[generation_info_json, html_info, img_index]`` triple
    (passed preprocessed from the JS side).  On any failure — malformed
    JSON, missing key, index out of range — the previous ``html_info`` is
    returned unchanged.
    """
    generation_info, html_info, img_index = args
    try:
        infotexts = json.loads(generation_info)["infotexts"]
        if 0 <= img_index < len(infotexts):
            return plaintext_to_html(infotexts[img_index])
    except Exception:
        # deliberately broad: any parse/lookup problem keeps the old text
        pass
    return html_info
580+
581+
569582
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
570583
def refresh():
571584
refresh_method()
@@ -638,6 +651,15 @@ def open_folder(f):
638651
with gr.Group():
639652
html_info = gr.HTML()
640653
generation_info = gr.Textbox(visible=False)
654+
if tabname == 'txt2img' or tabname == 'img2img':
655+
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
656+
generation_info_button.click(
657+
fn=update_generation_info,
658+
_js="(x, y) => [x, y, selected_gallery_index()]",
659+
inputs=[generation_info, html_info],
660+
outputs=[html_info],
661+
preprocess=False
662+
)
641663

642664
save.click(
643665
fn=wrap_gradio_call(save_files),

scripts/prompt_matrix.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ def run(self, p, put_at_start):
8080
grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
8181
grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
8282
processed.images.insert(0, grid)
83+
processed.index_of_first_image = 1
84+
processed.infotexts.insert(0, processed.infotexts[0])
8385

8486
if opts.grid_save:
8587
images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", prompt=original_prompt, seed=processed.seed, grid=True, p=p)

scripts/prompts_from_file.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str):
145145
state.job_count = job_count
146146

147147
images = []
148+
all_prompts = []
149+
infotexts = []
148150
for n, args in enumerate(jobs):
149151
state.job = f"{state.job_no + 1} out of {state.job_count}"
150152

@@ -157,5 +159,7 @@ def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str):
157159

158160
if checkbox_iterate:
159161
p.seed = p.seed + (p.batch_size * p.n_iter)
162+
all_prompts += proc.all_prompts
163+
infotexts += proc.infotexts
160164

161-
return Processed(p, images, p.seed, "")
165+
return Processed(p, images, p.seed, "", all_prompts=all_prompts, infotexts=infotexts)

0 commit comments

Comments
 (0)