Commit 96c39b6

Enable LoRAs to patch the text_encoder as well as the unet (#3214)
Load LoRAs during compel's text-embedding encode pass, in case any of the requested LoRAs also want to patch the text encoder. Also generally clean up the attention processor patching code. It's still a mess, but at least now it's a *stateless* mess.
2 parents a9e8005 + 40744ed commit 96c39b6
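
In practice, the patching introduced here is a classmethod context manager that is handed the unet explicitly: conditioning.py wraps compel's prompt encoding in it (so LoRAs that patch the text encoder take effect on the embeddings), and diffusers_pipeline.py wraps the denoising loop in it. A minimal sketch of that call pattern, assuming a helper function and variable names that are not part of this commit:

```python
# Sketch only: encode_with_loras is a hypothetical helper; compel, pipeline and
# the prompt objects are assumed to come from the surrounding InvokeAI code.
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent


def encode_with_loras(compel, pipeline, positive_prompt, negative_prompt, ec):
    # Activate LoRA conditions while compel builds the conditioning tensors,
    # so LoRAs that patch the text encoder influence the embeddings.
    # step_count=-1 mirrors conditioning.py: no denoising steps happen here.
    with InvokeAIDiffuserComponent.custom_attention_context(pipeline.unet,
                                                            extra_conditioning_info=ec,
                                                            step_count=-1):
        c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
        uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
    return uc, c, options
```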

File tree

9 files changed, +77 −136 lines.

invokeai/backend/invoke_ai_web_server.py

Lines changed: 1 addition & 2 deletions
@@ -30,7 +30,6 @@
     get_tokens_for_prompt_object,
     get_prompt_structure,
     split_weighted_subprompts,
-    get_tokenizer,
 )
 from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
 from ldm.invoke.generator.inpaint import infill_methods
@@ -1314,7 +1313,7 @@ def image_done(image, seed, first_seed, attention_maps_image=None):
                 None
                 if type(parsed_prompt) is Blend
                 else get_tokens_for_prompt_object(
-                    get_tokenizer(self.generate.model), parsed_prompt
+                    self.generate.model.tokenizer, parsed_prompt
                 )
             )
             attention_maps_image_base64_url = (

ldm/invoke/conditioning.py

Lines changed: 19 additions & 22 deletions
@@ -15,19 +15,10 @@
 from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser, \
     Conjunction
 from .devices import torch_dtype
+from .generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ldm.invoke.globals import Globals
 
-def get_tokenizer(model) -> CLIPTokenizer:
-    # TODO remove legacy ckpt fallback handling
-    return (getattr(model, 'tokenizer', None) # diffusers
-            or model.cond_stage_model.tokenizer) # ldm
-
-def get_text_encoder(model) -> Any:
-    # TODO remove legacy ckpt fallback handling
-    return (getattr(model, 'text_encoder', None) # diffusers
-            or UnsqueezingLDMTransformer(model.cond_stage_model.transformer)) # ldm
-
 class UnsqueezingLDMTransformer:
     def __init__(self, ldm_transformer):
         self.ldm_transformer = ldm_transformer
@@ -41,15 +32,15 @@ def __call__(self, *args, **kwargs):
         return insufficiently_unsqueezed_tensor.unsqueeze(0)
 
 
-def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
+def get_uc_and_c_and_ec(prompt_string,
+                        model: StableDiffusionGeneratorPipeline,
+                        log_tokens=False, skip_normalize_legacy_blend=False):
     # lazy-load any deferred textual inversions.
     # this might take a couple of seconds the first time a textual inversion is used.
     model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
 
-    tokenizer = get_tokenizer(model)
-    text_encoder = get_text_encoder(model)
-    compel = Compel(tokenizer=tokenizer,
-                    text_encoder=text_encoder,
+    compel = Compel(tokenizer=model.tokenizer,
+                    text_encoder=model.text_encoder,
                     textual_inversion_manager=model.textual_inversion_manager,
                     dtype_for_device_getter=torch_dtype)
 
@@ -78,14 +69,20 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
     negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
     negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
 
+    tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
     if log_tokens or getattr(Globals, "log_tokenization", False):
-        log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
-
-    c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
-    uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
-
-    tokens_count = get_max_token_count(tokenizer, positive_prompt)
-
+        log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
+
+    # some LoRA models also mess with the text encoder, so they must be active while compel builds conditioning tensors
+    lora_conditioning_ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
+                                                                           lora_conditions=lora_conditions)
+    with InvokeAIDiffuserComponent.custom_attention_context(model.unet,
+                                                            extra_conditioning_info=lora_conditioning_ec,
+                                                            step_count=-1):
+        c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
+        uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
+
+    # now build the "real" ec
     ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
                                                          cross_attention_control_args=options.get(
                                                              'cross_attention_control', None),

ldm/invoke/generator/diffusers_pipeline.py

Lines changed: 3 additions & 2 deletions
@@ -467,8 +467,9 @@ def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps,
         if additional_guidance is None:
             additional_guidance = []
         extra_conditioning_info = conditioning_data.extra
-        with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
-                                                             step_count=len(self.scheduler.timesteps)
+        with InvokeAIDiffuserComponent.custom_attention_context(self.invokeai_diffuser.model,
+                                                                extra_conditioning_info=extra_conditioning_info,
+                                                                step_count=len(self.scheduler.timesteps)
                                                                 ):
 
             yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,

ldm/models/diffusion/cross_attention_control.py

Lines changed: 9 additions & 25 deletions
@@ -288,16 +288,7 @@ def get_invokeai_attention_mem_efficient(self, q, k, v):
         return self.einsum_op_tensor_mem(q, k, v, 32)
 
 
-
-def restore_default_cross_attention(model, is_running_diffusers: bool, processors_to_restore: Optional[AttnProcessor]=None):
-    if is_running_diffusers:
-        unet = model
-        unet.set_attn_processor(processors_to_restore or CrossAttnProcessor())
-    else:
-        remove_attention_function(model)
-
-
-def override_cross_attention(model, context: Context, is_running_diffusers = False):
+def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
     """
     Inject attention parameters and functions into the passed in model to enable cross attention editing.
 
@@ -323,22 +314,15 @@ def override_cross_attention(model, context: Context, is_running_diffusers = Fal
 
     context.cross_attention_mask = mask.to(device)
     context.cross_attention_index_map = indices.to(device)
-    if is_running_diffusers:
-        unet = model
-        old_attn_processors = unet.attn_processors
-        if torch.backends.mps.is_available():
-            # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
-            unet.set_attn_processor(SwapCrossAttnProcessor())
-        else:
-            # try to re-use an existing slice size
-            default_slice_size = 4
-            slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
-            unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
+    old_attn_processors = unet.attn_processors
+    if torch.backends.mps.is_available():
+        # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
+        unet.set_attn_processor(SwapCrossAttnProcessor())
     else:
-        context.register_cross_attention_modules(model)
-        inject_attention_function(model, context)
-
-
+        # try to re-use an existing slice size
+        default_slice_size = 4
+        slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+        unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
 
 
 def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:

ldm/models/diffusion/ddim.py

Lines changed: 0 additions & 11 deletions
@@ -12,17 +12,6 @@ def __init__(self, model, schedule='linear', device=None, **kwargs):
         self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
                                                            model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
 
-    def prepare_to_sample(self, t_enc, **kwargs):
-        super().prepare_to_sample(t_enc, **kwargs)
-
-        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
-        all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
-
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = all_timesteps_count)
-        else:
-            self.invokeai_diffuser.restore_default_cross_attention()
-
 
     # This is the central routine
     @torch.no_grad()

ldm/models/diffusion/ksampler.py

Lines changed: 0 additions & 9 deletions
@@ -38,15 +38,6 @@ def __init__(self, model, threshold = 0, warmup = 0):
                                                            model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
 
 
-    def prepare_to_sample(self, t_enc, **kwargs):
-
-        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
-
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = t_enc)
-        else:
-            self.invokeai_diffuser.restore_default_cross_attention()
-
 
     def forward(self, x, sigma, uncond, cond, cond_scale):
         next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)

ldm/models/diffusion/plms.py

Lines changed: 0 additions & 11 deletions
@@ -14,17 +14,6 @@ class PLMSSampler(Sampler):
     def __init__(self, model, schedule='linear', device=None, **kwargs):
         super().__init__(model,schedule,model.num_timesteps, device)
 
-    def prepare_to_sample(self, t_enc, **kwargs):
-        super().prepare_to_sample(t_enc, **kwargs)
-
-        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
-        all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
-
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = all_timesteps_count)
-        else:
-            self.invokeai_diffuser.restore_default_cross_attention()
-
 
     # this is the essential routine
     @torch.no_grad()

ldm/models/diffusion/shared_invokeai_diffusion.py

Lines changed: 26 additions & 47 deletions
@@ -1,18 +1,18 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from math import ceil
-from typing import Callable, Optional, Union, Any, Dict
+from typing import Callable, Optional, Union, Any
 
 import numpy as np
 import torch
-from diffusers.models.cross_attention import AttnProcessor
+
+from diffusers import UNet2DConditionModel
 from typing_extensions import TypeAlias
 
 from ldm.invoke.globals import Globals
 from ldm.models.diffusion.cross_attention_control import (
     Arguments,
-    restore_default_cross_attention,
-    override_cross_attention,
+    setup_cross_attention_control_attention_processors,
     Context,
     get_cross_attention_modules,
     CrossAttentionType,
@@ -84,66 +84,45 @@ def __init__(
         self.cross_attention_control_context = None
         self.sequential_guidance = Globals.sequential_guidance
 
+    @classmethod
     @contextmanager
     def custom_attention_context(
-        self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
+            clss,
+            unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
+            extra_conditioning_info: Optional[ExtraConditioningInfo],
+            step_count: int
     ):
-        old_attn_processor = None
+        old_attn_processors = None
         if extra_conditioning_info and (
             extra_conditioning_info.wants_cross_attention_control
             | extra_conditioning_info.has_lora_conditions
         ):
-            old_attn_processor = self.override_attention_processors(
-                extra_conditioning_info, step_count=step_count
-            )
+            old_attn_processors = unet.attn_processors
+            # Load lora conditions into the model
+            if extra_conditioning_info.has_lora_conditions:
+                for condition in extra_conditioning_info.lora_conditions:
+                    condition()  # target model is stored in condition state for some reason
+            if extra_conditioning_info.wants_cross_attention_control:
+                cross_attention_control_context = Context(
+                    arguments=extra_conditioning_info.cross_attention_control_args,
+                    step_count=step_count,
+                )
+                setup_cross_attention_control_attention_processors(
+                    unet,
+                    cross_attention_control_context,
+                )
 
         try:
             yield None
         finally:
-            if old_attn_processor is not None:
-                self.restore_default_cross_attention(old_attn_processor)
+            if old_attn_processors is not None:
+                unet.set_attn_processor(old_attn_processors)
             if extra_conditioning_info and extra_conditioning_info.has_lora_conditions:
                 for lora_condition in extra_conditioning_info.lora_conditions:
                     lora_condition.unload()
             # TODO resuscitate attention map saving
             # self.remove_attention_map_saving()
 
-    def override_attention_processors(
-        self, conditioning: ExtraConditioningInfo, step_count: int
-    ) -> Dict[str, AttnProcessor]:
-        """
-        setup cross attention .swap control. for diffusers this replaces the attention processor, so
-        the previous attention processor is returned so that the caller can restore it later.
-        """
-        old_attn_processors = self.model.attn_processors
-
-        # Load lora conditions into the model
-        if conditioning.has_lora_conditions:
-            for condition in conditioning.lora_conditions:
-                condition(self.model)
-
-        if conditioning.wants_cross_attention_control:
-            self.cross_attention_control_context = Context(
-                arguments=conditioning.cross_attention_control_args,
-                step_count=step_count,
-            )
-            override_cross_attention(
-                self.model,
-                self.cross_attention_control_context,
-                is_running_diffusers=self.is_running_diffusers,
-            )
-        return old_attn_processors
-
-    def restore_default_cross_attention(
-        self, processors_to_restore: Optional[dict[str, "AttnProcessor"]] = None
-    ):
-        self.cross_attention_control_context = None
-        restore_default_cross_attention(
-            self.model,
-            is_running_diffusers=self.is_running_diffusers,
-            processors_to_restore=processors_to_restore,
-        )
-
     def setup_attention_map_saving(self, saver: AttentionMapSaver):
         def callback(slice, dim, offset, slice_size, key):
             if dim is not None:

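What makes the new approach "stateless" is that custom_attention_context no longer stashes the previous attention processors on the InvokeAIDiffuserComponent instance; it keeps them in a local variable and restores them in a finally block. A self-contained toy illustrating that save/patch/restore pattern (not the real InvokeAI class):

```python
from contextlib import contextmanager


class Patcher:
    """Toy illustration of the stateless patch/restore pattern; not InvokeAI code."""

    @classmethod
    @contextmanager
    def patched(cls, processors: dict, replacement: dict):
        # Keep the previous processors in a local, not on an instance attribute,
        # so nothing leaks between invocations.
        saved = dict(processors)
        processors.clear()
        processors.update(replacement)
        try:
            yield processors
        finally:
            # Restore even if the body raised.
            processors.clear()
            processors.update(saved)


procs = {"down.attn1": "default"}
with Patcher.patched(procs, {"down.attn1": "swap"}):
    assert procs["down.attn1"] == "swap"
assert procs["down.attn1"] == "default"
```
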
ldm/modules/lora_manager.py

Lines changed: 19 additions & 7 deletions
@@ -1,5 +1,7 @@
 import os
 from pathlib import Path
+
+from diffusers import UNet2DConditionModel, StableDiffusionPipeline
 from ldm.invoke.globals import global_lora_models_dir
 from .kohya_lora_manager import KohyaLoraManager
 from typing import Optional, Dict
@@ -8,20 +10,29 @@ class LoraCondition:
     name: str
     weight: float
 
-    def __init__(self, name, weight: float = 1.0, kohya_manager: Optional[KohyaLoraManager]=None):
+    def __init__(self,
+                 name,
+                 weight: float = 1.0,
+                 unet: UNet2DConditionModel=None,  # for diffusers format LoRAs
+                 kohya_manager: Optional[KohyaLoraManager]=None,  # for KohyaLoraManager-compatible LoRAs
+                 ):
         self.name = name
         self.weight = weight
         self.kohya_manager = kohya_manager
+        self.unet = unet
 
-    def __call__(self, model):
+    def __call__(self):
         # TODO: make model able to load from huggingface, rather then just local files
        path = Path(global_lora_models_dir(), self.name)
        if path.is_dir():
-            if model.load_attn_procs:
+            if not self.unet:
+                print(f" ** Unable to load diffusers-format LoRA {self.name}: unet is None")
+                return
+            if self.unet.load_attn_procs:
                 file = Path(path, "pytorch_lora_weights.bin")
                 if file.is_file():
                     print(f">> Loading LoRA: {path}")
-                    model.load_attn_procs(path.absolute().as_posix())
+                    self.unet.load_attn_procs(path.absolute().as_posix())
                 else:
                     print(f" ** Unable to find valid LoRA at: {path}")
             else:
@@ -37,15 +48,16 @@ def unload(self):
             self.kohya_manager.unload_applied_lora(self.name)
 
 class LoraManager:
-    def __init__(self, pipe):
+    def __init__(self, pipe: StableDiffusionPipeline):
         # Kohya class handles lora not generated through diffusers
         self.kohya = KohyaLoraManager(pipe, global_lora_models_dir())
+        self.unet = pipe.unet
 
     def set_loras_conditions(self, lora_weights: list):
         conditions = []
         if len(lora_weights) > 0:
             for lora in lora_weights:
-                conditions.append(LoraCondition(lora.model, lora.weight, self.kohya))
+                conditions.append(LoraCondition(lora.model, lora.weight, self.unet, self.kohya))
 
         if len(conditions) > 0:
             return conditions
@@ -63,4 +75,4 @@ def list_loras(self)->Dict[str, Path]:
                 if suffix in [".ckpt", ".pt", ".safetensors"]:
                     models_found[name]=Path(root,x)
         return models_found
-
+
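
After this change a LoraCondition carries its target unet and is invoked with no arguments: diffusers-format LoRAs load through unet.load_attn_procs, Kohya-format LoRAs still go through the KohyaLoraManager, and custom_attention_context calls unload() on each condition when the context exits. A hedged usage sketch (the LoRA name and weight below are placeholders, not from this commit):

```python
# Illustrative only: "example-lora" and the weight are made-up values; unet and
# kohya_manager are assumed to come from an already-loaded pipeline.
from ldm.modules.lora_manager import LoraCondition


def apply_example_lora(unet, kohya_manager=None):
    # The condition now carries the unet itself, so __call__ takes no arguments.
    condition = LoraCondition("example-lora", weight=0.75,
                              unet=unet, kohya_manager=kohya_manager)
    condition()  # loads via unet.load_attn_procs() for a diffusers-format LoRA
    return condition  # the caller later invokes condition.unload() to release a Kohya-applied LoRA
```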