Commit 4d58444

fix issues and further cleanup
1 parent 3667eb4

File tree

5 files changed: +34 −18 lines


ldm/invoke/conditioning.py

Lines changed: 3 additions & 2 deletions
@@ -15,6 +15,7 @@
 from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser, \
     Conjunction
 from .devices import torch_dtype
+from .generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ldm.invoke.globals import Globals
 

@@ -32,7 +33,7 @@ def __call__(self, *args, **kwargs):
 
 
 def get_uc_and_c_and_ec(prompt_string,
-                        model: StableDiffusionPipeline,
+                        model: StableDiffusionGeneratorPipeline,
                         log_tokens=False, skip_normalize_legacy_blend=False):
     # lazy-load any deferred textual inversions.
     # this might take a couple of seconds the first time a textual inversion is used.

@@ -75,7 +76,7 @@ def get_uc_and_c_and_ec(prompt_string,
     # some LoRA models also mess with the text encoder, so they must be active while compel builds conditioning tensors
     lora_conditioning_ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
                                                                            lora_conditions=lora_conditions)
-    with InvokeAIDiffuserComponent.custom_attention_context(model,
+    with InvokeAIDiffuserComponent.custom_attention_context(model.unet,
                                                             extra_conditioning_info=lora_conditioning_ec,
                                                             step_count=-1):
         c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
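The call-site change above means get_uc_and_c_and_ec now expects InvokeAI's pipeline subclass but hands only its UNet to the attention context. A minimal sketch of the new call shape, assuming `pipeline` is an already-loaded StableDiffusionGeneratorPipeline and that the function returns the unconditioned, conditioned, and extra-conditioning values its name suggests:

# Hypothetical call site (names assumed, not part of this diff):
uc, c, ec = get_uc_and_c_and_ec(
    "a watercolor lighthouse at dusk",
    model=pipeline,   # the full pipeline goes in...
    log_tokens=True,
)
# ...but internally only pipeline.unet reaches custom_attention_context.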

ldm/models/diffusion/cross_attention_control.py

Lines changed: 1 addition & 2 deletions
@@ -288,7 +288,7 @@ def get_invokeai_attention_mem_efficient(self, q, k, v):
         return self.einsum_op_tensor_mem(q, k, v, 32)
 
 
-def setup_cross_attention_control_attention_processors(model, context: Context):
+def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
     """
     Inject attention parameters and functions into the passed in model to enable cross attention editing.
 

@@ -314,7 +314,6 @@ def setup_cross_attention_control_attention_processors(model, context: Context):
 
     context.cross_attention_mask = mask.to(device)
     context.cross_attention_index_map = indices.to(device)
-    unet = model
     old_attn_processors = unet.attn_processors
     if torch.backends.mps.is_available():
         # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
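Dropping the `unet = model` alias works because the function now receives the UNet directly and relies on the standard diffusers save/restore pattern for attention processors. A minimal sketch of that pattern, assuming `unet` is a loaded UNet2DConditionModel and `editing_processors` is a hypothetical dict of replacement processors:

old_attn_processors = unet.attn_processors        # dict: module name -> processor
try:
    unet.set_attn_processor(editing_processors)   # install the editing processors
    ...                                           # run cross-attention-controlled steps
finally:
    unet.set_attn_processor(old_attn_processors)  # put the originals back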

ldm/models/diffusion/shared_invokeai_diffusion.py

Lines changed: 10 additions & 6 deletions
@@ -5,7 +5,8 @@
 
 import numpy as np
 import torch
-from diffusers.models.cross_attention import CrossAttnProcessor
+
+from diffusers import UNet2DConditionModel
 from typing_extensions import TypeAlias
 
 from ldm.invoke.globals import Globals

@@ -86,33 +87,36 @@ def __init__(
     @classmethod
     @contextmanager
     def custom_attention_context(
-        clss, model, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
+        clss,
+        unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
+        extra_conditioning_info: Optional[ExtraConditioningInfo],
+        step_count: int
     ):
         old_attn_processors = None
         if extra_conditioning_info and (
             extra_conditioning_info.wants_cross_attention_control
             | extra_conditioning_info.has_lora_conditions
         ):
-            old_attn_processors = model.attn_processors
+            old_attn_processors = unet.attn_processors
             # Load lora conditions into the model
             if extra_conditioning_info.has_lora_conditions:
                 for condition in extra_conditioning_info.lora_conditions:
-                    condition(model)
+                    condition()  # target model is stored in condition state for some reason
             if extra_conditioning_info.wants_cross_attention_control:
                 cross_attention_control_context = Context(
                     arguments=extra_conditioning_info.cross_attention_control_args,
                     step_count=step_count,
                 )
                 setup_cross_attention_control_attention_processors(
-                    model,
+                    unet,
                     cross_attention_control_context,
                 )
 
         try:
             yield None
         finally:
             if old_attn_processors is not None:
-                model.set_attn_processor(old_attn_processors)
+                unet.set_attn_processor(old_attn_processors)
             if extra_conditioning_info and extra_conditioning_info.has_lora_conditions:
                 for lora_condition in extra_conditioning_info.lora_conditions:
                     lora_condition.unload()
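With the narrowed signature, callers pass the UNet rather than the whole model. A usage sketch, assuming `pipeline` is a loaded StableDiffusionGeneratorPipeline and `conditioning_info` is an already-built ExtraConditioningInfo (both names hypothetical):

with InvokeAIDiffuserComponent.custom_attention_context(
    pipeline.unet,                              # now the UNet, not the pipeline
    extra_conditioning_info=conditioning_info,
    step_count=30,
):
    run_denoising_loop()  # hypothetical; on exit the original attention
                          # processors are restored and LoRA conditions unloaded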

ldm/modules/kohya_lora_manager.py

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ def forward(self, lora, input_h, output):
         rebuild1 = torch.einsum('i j k l, j r, i p -> p r k l', self.t1, self.w1_b, self.w1_a)
         rebuild2 = torch.einsum('i j k l, j r, i p -> p r k l', self.t2, self.w2_b, self.w2_a)
         weight = rebuild1 * rebuild2
-
+
         bias = self.bias if self.bias is not None else 0
         return output + op(
             *input_h,
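The two einsums in this hunk rebuild full convolution weights from low-rank LoHA-style factors. A quick standalone shape check of that contraction, with illustrative dimensions only:

import torch

# t is the (i, j, k, l) low-rank core; w_b (j, r) and w_a (i, p) expand the
# two rank axes back out, giving a conv-style weight of shape (p, r, k, l).
i, j, k, l, r, p = 4, 4, 3, 3, 8, 16
t = torch.randn(i, j, k, l)
w_b = torch.randn(j, r)
w_a = torch.randn(i, p)
rebuild = torch.einsum('i j k l, j r, i p -> p r k l', t, w_b, w_a)
print(rebuild.shape)  # torch.Size([16, 8, 3, 3])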

ldm/modules/lora_manager.py

Lines changed: 19 additions & 7 deletions
@@ -1,5 +1,7 @@
 import os
 from pathlib import Path
+
+from diffusers import UNet2DConditionModel, StableDiffusionPipeline
 from ldm.invoke.globals import global_lora_models_dir
 from .kohya_lora_manager import KohyaLoraManager
 from typing import Optional, Dict

@@ -8,20 +10,29 @@ class LoraCondition:
     name: str
     weight: float
 
-    def __init__(self, name, weight: float = 1.0, kohya_manager: Optional[KohyaLoraManager]=None):
+    def __init__(self,
+                 name,
+                 weight: float = 1.0,
+                 unet: UNet2DConditionModel=None,  # for diffusers format LoRAs
+                 kohya_manager: Optional[KohyaLoraManager]=None,  # for KohyaLoraManager-compatible LoRAs
+                 ):
         self.name = name
         self.weight = weight
         self.kohya_manager = kohya_manager
+        self.unet = unet
 
-    def __call__(self, model):
+    def __call__(self):
         # TODO: make model able to load from huggingface, rather then just local files
         path = Path(global_lora_models_dir(), self.name)
         if path.is_dir():
-            if model.load_attn_procs:
+            if not self.unet:
+                print(f" ** Unable to load diffusers-format LoRA {self.name}: unet is None")
+                return
+            if self.unet.load_attn_procs:
                 file = Path(path, "pytorch_lora_weights.bin")
                 if file.is_file():
                     print(f">> Loading LoRA: {path}")
-                    model.load_attn_procs(path.absolute().as_posix())
+                    self.unet.load_attn_procs(path.absolute().as_posix())
                 else:
                     print(f" ** Unable to find valid LoRA at: {path}")
             else:

@@ -37,15 +48,16 @@ def unload(self):
             self.kohya_manager.unload_applied_lora(self.name)
 
 class LoraManager:
-    def __init__(self, pipe):
+    def __init__(self, pipe: StableDiffusionPipeline):
         # Kohya class handles lora not generated through diffusers
         self.kohya = KohyaLoraManager(pipe, global_lora_models_dir())
+        self.unet = pipe.unet
 
     def set_loras_conditions(self, lora_weights: list):
         conditions = []
         if len(lora_weights) > 0:
             for lora in lora_weights:
-                conditions.append(LoraCondition(lora.model, lora.weight, self.kohya))
+                conditions.append(LoraCondition(lora.model, lora.weight, self.unet, self.kohya))
 
         if len(conditions) > 0:
             return conditions

@@ -63,4 +75,4 @@ def list_loras(self)->Dict[str, Path]:
             if suffix in [".ckpt", ".pt", ".safetensors"]:
                 models_found[name]=Path(root,x)
         return models_found
-
+
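Taken together, LoraManager now captures the UNet at construction and threads it into every LoraCondition, so a condition can later be applied without re-passing the model. A hedged end-to-end sketch, with `pipe` assumed to be a loaded StableDiffusionPipeline and the weight entries faked with SimpleNamespace purely for illustration:

from types import SimpleNamespace

manager = LoraManager(pipe)
lora_weights = [SimpleNamespace(model="my_lora", weight=0.8)]  # hypothetical LoRA name
conditions = manager.set_loras_conditions(lora_weights)
if conditions:
    for condition in conditions:
        condition()  # no model argument: the unet/kohya targets live in the condition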

0 commit comments