Commit e72a6c4

fix missing infotext caused by cond cache

Some generation params, such as TI hashes or Emphasis, are added in sd_hijack / sd_hijack_clip. If conds are fetched from the cache, sd_hijack_clip will not be executed and won't have a chance to add its generation params. The generation params would also be missing in non-low-vram mode, because hijack.extra_generation_params was never read after calculate_hr_conds.
1 parent 023454b commit e72a6c4
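A minimal, self-contained sketch of the failure mode being fixed (illustrative names only, not the repo's API): the generation params are produced as a side effect of computing the conds, so serving the conds from a cache skips that side effect and the infotext entries go missing.

# Illustrative sketch of the bug, not code from the repository.
# compute_conds() returns conds and records an infotext param as a side effect;
# a naive cache keeps only the conds, so the side effect is lost on a cache hit.

extra_generation_params = {}          # infotext entries collected per image
cond_cache = {}                       # prompt -> conds

def compute_conds(prompt: str):
    extra_generation_params["Emphasis"] = "On"   # side effect: infotext param
    return f"conds({prompt})"

def get_conds_naive(prompt: str):
    if prompt in cond_cache:
        return cond_cache[prompt]     # cache hit: compute_conds never runs,
                                      # so "Emphasis" is never re-added
    conds = cond_cache[prompt] = compute_conds(prompt)
    return conds

get_conds_naive("a cat")              # first image: params populated
extra_generation_params.clear()       # params are reset for the next image
get_conds_naive("a cat")              # second image: cache hit
print(extra_generation_params)        # {} -> missing infotext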

4 files changed: +76 −7 lines changed

modules/processing.py

Lines changed: 25 additions & 2 deletions
@@ -187,6 +187,7 @@ class StableDiffusionProcessing:
 
     cached_uc = [None, None]
     cached_c = [None, None]
+    hijack_generation_params_state_list = []
 
     comments: dict = None
     sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
@@ -480,16 +481,36 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
 
         for cache in caches:
             if cache[0] is not None and cached_params == cache[0]:
+                if len(cache) == 3:
+                    generation_params_state, cached_params_2 = cache[2]
+                    if cached_params == cached_params_2:
+                        self.hijack_generation_params_state_list.extend(generation_params_state)
                 return cache[1]
 
         cache = caches[0]
 
         with devices.autocast():
             cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
 
+        generation_params_state = model_hijack.capture_generation_params_state()
+        self.hijack_generation_params_state_list.extend(generation_params_state)
+        if len(cache) == 2:
+            cache.append((generation_params_state, cached_params))
+        else:
+            cache[2] = (generation_params_state, cached_params)
+
         cache[0] = cached_params
         return cache[1]
 
+    def apply_hijack_generation_params(self):
+        self.extra_generation_params.update(model_hijack.extra_generation_params)
+        for func in self.hijack_generation_params_state_list:
+            try:
+                func(self.extra_generation_params)
+            except Exception:
+                errors.report(f"Failed to apply hijack generation params state", exc_info=True)
+        self.hijack_generation_params_state_list.clear()
+
     def setup_conds(self):
         prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
         negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
@@ -502,6 +523,8 @@ def setup_conds(self):
         self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
         self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
 
+        self.apply_hijack_generation_params()
+
     def get_conds(self):
         return self.c, self.uc

@@ -965,8 +988,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
 
                 p.setup_conds()
 
-                p.extra_generation_params.update(model_hijack.extra_generation_params)
-
                 # params.txt should be saved after scripts.process_batch, since the
                 # infotext could be modified by that callback
                 # Example: a wildcard processed by process_batch sets an extra model
@@ -1513,6 +1534,8 @@ def calculate_hr_conds(self):
         self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
         self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
 
+        self.apply_hijack_generation_params()
+
     def setup_conds(self):
         if self.is_hr_pass:
             # if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model
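The processing.py change can be summarized with a toy model of the caching flow (illustrative names only; get_conds_with_toy_caching and capture_state are stand-ins, not functions from the repo): a cache entry now stores the captured hijack state next to the conds, and both hits and misses feed that state into a list which apply_hijack_generation_params() later replays into extra_generation_params.

# Toy model of the new cond-cache flow (illustrative, not the repo's API).
state_list = []          # plays the role of hijack_generation_params_state_list

def get_conds_with_toy_caching(cache: list, key, compute, capture_state):
    # cache layout after this commit: [key, conds, (captured_state, key)]
    if cache and cache[0] == key:
        if len(cache) == 3 and cache[2][1] == key:
            state_list.extend(cache[2][0])    # cache hit: replay stored state
        return cache[1]
    conds = compute()
    state = capture_state()                   # e.g. callables like EmbeddingHashes
    state_list.extend(state)
    cache[:] = [key, conds, (state, key)]     # store state alongside the conds
    return conds

def apply_toy_generation_params(extra_generation_params: dict):
    for func in state_list:
        func(extra_generation_params)         # each captured state writes its params
    state_list.clear()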

modules/sd_hijack.py

Lines changed: 8 additions & 0 deletions
@@ -6,6 +6,7 @@
 from modules.hypernetworks import hypernetwork
 from modules.shared import cmd_opts
 from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
+from modules.util import GenerationParamsState
 
 import ldm.modules.attention
 import ldm.modules.diffusionmodules.model
@@ -321,6 +322,13 @@ def clear_comments(self):
         self.comments = []
         self.extra_generation_params = {}
 
+    def capture_generation_params_state(self):
+        state = []
+        for key in list(self.extra_generation_params):
+            if isinstance(self.extra_generation_params[key], GenerationParamsState):
+                state.append(self.extra_generation_params.pop(key))
+        return state
+
     def get_prompt_lengths(self, text):
         if self.clip is None:
             return "-", "-"
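capture_generation_params_state() pops only the GenerationParamsState values out of extra_generation_params and returns them, leaving plain string/number params in place. A rough standalone illustration of that behavior (the Emphasis subclass and the params dict here are faked for the example):

# Rough standalone illustration of the capture step (not code from the repo).
class GenerationParamsState:
    def __call__(self, extra_generation_params: dict):
        raise NotImplementedError

class Emphasis(GenerationParamsState):
    def __call__(self, extra_generation_params: dict):
        extra_generation_params["Emphasis"] = "On"

params = {"Schedule type": "Karras", "Emphasis": Emphasis()}

# mimic capture: pop only the GenerationParamsState entries
captured = []
for key in list(params):
    if isinstance(params[key], GenerationParamsState):
        captured.append(params.pop(key))

print(params)       # {'Schedule type': 'Karras'}  -> plain values stay in place
infotext = {}
for state in captured:
    state(infotext)  # replayed later, even when the conds came from the cache
print(infotext)      # {'Emphasis': 'On'}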

modules/sd_hijack_clip.py

Lines changed: 28 additions & 5 deletions
@@ -5,6 +5,7 @@
 
 from modules import prompt_parser, devices, sd_hijack, sd_emphasis
 from modules.shared import opts
+from modules.util import GenerationParamsState
 
 
 class PromptChunk:
@@ -27,6 +28,31 @@ def __init__(self):
 are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
 
 
+class EmbeddingHashes(GenerationParamsState):
+    def __init__(self, hashes: list):
+        super().__init__()
+        self.hashes = hashes
+
+    def __call__(self, extra_generation_params):
+        unique_hashes = dict.fromkeys(self.hashes)
+        if existing_ti_hashes := extra_generation_params.get('TI hashes'):
+            unique_hashes.update(dict.fromkeys(existing_ti_hashes.split(', ')))
+        extra_generation_params['TI hashes'] = ', '.join(unique_hashes)
+
+
+class EmphasisMode(GenerationParamsState):
+    def __init__(self, texts):
+        super().__init__()
+        if opts.emphasis != 'Original' and any(x for x in texts if '(' in x or '[' in x):
+            self.emphasis = opts.emphasis
+        else:
+            self.emphasis = None
+
+    def __call__(self, extra_generation_params):
+        if self.emphasis:
+            extra_generation_params['Emphasis'] = self.emphasis
+
+
 class TextConditionalModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -238,12 +264,9 @@ def forward(self, texts):
                 hashes.append(f"{name}: {shorthash}")
 
             if hashes:
-                if self.hijack.extra_generation_params.get("TI hashes"):
-                    hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
-                self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
+                self.hijack.extra_generation_params["TI hashes"] = EmbeddingHashes(hashes)
 
-        if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
-            self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
+        self.hijack.extra_generation_params["Emphasis"] = EmphasisMode(texts)
 
         if self.return_pooled:
             return torch.hstack(zs), zs[0].pooled
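When replayed, EmbeddingHashes merges its hashes with any 'TI hashes' value already present, deduplicating while preserving order via dict.fromkeys. A standalone behavioral sketch (merge_ti_hashes and the hash strings below are made up for illustration):

# How EmbeddingHashes merges hashes on replay (behavioral sketch, not repo code).
def merge_ti_hashes(new_hashes: list, extra_generation_params: dict):
    unique_hashes = dict.fromkeys(new_hashes)                 # ordered de-dup
    if existing := extra_generation_params.get('TI hashes'):
        unique_hashes.update(dict.fromkeys(existing.split(', ')))
    extra_generation_params['TI hashes'] = ', '.join(unique_hashes)

params = {'TI hashes': 'bad-hands: 12345678'}                 # e.g. from the hires pass
merge_ti_hashes(['easynegative: aabbccdd', 'bad-hands: 12345678'], params)
print(params['TI hashes'])   # 'easynegative: aabbccdd, bad-hands: 12345678'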

modules/util.py

Lines changed: 15 additions & 0 deletions
@@ -288,3 +288,18 @@ def compare_sha256(file_path: str, hash_prefix: str) -> bool:
         for chunk in iter(lambda: f.read(blksize), b""):
             hash_sha256.update(chunk)
     return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower())
+
+
+class GenerationParamsState:
+    """A custom class used in StableDiffusionModelHijack for assigning extra_generation_params.
+    Generation params assigned using this class will work properly with StableDiffusionProcessing.get_conds_with_caching();
+    if assigned directly, the generation params will not be populated when the cond cache is used.
+
+    Generation params of this class will be captured (see StableDiffusionModelHijack.capture_generation_params_state), stored with the cond cache, and extracted in StableDiffusionProcessing.apply_hijack_generation_params().
+
+    To use this class, create a subclass with a __call__ method that takes extra_generation_params: dict as input.
+
+    Example usage: sd_hijack_clip.EmbeddingHashes, sd_hijack_clip.EmphasisMode
+    """
+    def __call__(self, extra_generation_params: dict):
+        raise NotImplementedError
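Following the docstring's contract, a consumer would subclass GenerationParamsState and implement __call__. A hypothetical sketch (MySetting and the "My setting" infotext key are invented for illustration; only the base class comes from the repo):

# Hypothetical subclass following the GenerationParamsState contract (not in the repo).
from modules.util import GenerationParamsState

class MySetting(GenerationParamsState):
    def __init__(self, value):
        super().__init__()
        self.value = value

    def __call__(self, extra_generation_params: dict):
        # invoked later by StableDiffusionProcessing.apply_hijack_generation_params()
        extra_generation_params["My setting"] = self.value

# assigned in a hijack-time code path instead of writing the value directly, e.g.:
# self.hijack.extra_generation_params["My setting"] = MySetting("enabled")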
