Commit 3667eb4

activate LoRAs when generating prompt embeddings; also cleanup attention stuff
1 parent d81584c

8 files changed, +54 -129 lines changed

invokeai/backend/invoke_ai_web_server.py

Lines changed: 1 addition & 2 deletions
@@ -30,7 +30,6 @@
     get_tokens_for_prompt_object,
     get_prompt_structure,
     split_weighted_subprompts,
-    get_tokenizer,
 )
 from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
 from ldm.invoke.generator.inpaint import infill_methods
@@ -1314,7 +1313,7 @@ def image_done(image, seed, first_seed, attention_maps_image=None):
                 None
                 if type(parsed_prompt) is Blend
                 else get_tokens_for_prompt_object(
-                    get_tokenizer(self.generate.model), parsed_prompt
+                    self.generate.model.tokenizer, parsed_prompt
                 )
             )
             attention_maps_image_base64_url = (

ldm/invoke/conditioning.py

Lines changed: 18 additions & 22 deletions
@@ -18,16 +18,6 @@
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ldm.invoke.globals import Globals

-def get_tokenizer(model) -> CLIPTokenizer:
-    # TODO remove legacy ckpt fallback handling
-    return (getattr(model, 'tokenizer', None)  # diffusers
-            or model.cond_stage_model.tokenizer)  # ldm
-
-def get_text_encoder(model) -> Any:
-    # TODO remove legacy ckpt fallback handling
-    return (getattr(model, 'text_encoder', None)  # diffusers
-            or UnsqueezingLDMTransformer(model.cond_stage_model.transformer))  # ldm
-
 class UnsqueezingLDMTransformer:
     def __init__(self, ldm_transformer):
         self.ldm_transformer = ldm_transformer
@@ -41,15 +31,15 @@ def __call__(self, *args, **kwargs):
         return insufficiently_unsqueezed_tensor.unsqueeze(0)


-def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
+def get_uc_and_c_and_ec(prompt_string,
+                        model: StableDiffusionPipeline,
+                        log_tokens=False, skip_normalize_legacy_blend=False):
     # lazy-load any deferred textual inversions.
     # this might take a couple of seconds the first time a textual inversion is used.
     model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)

-    tokenizer = get_tokenizer(model)
-    text_encoder = get_text_encoder(model)
-    compel = Compel(tokenizer=tokenizer,
-                    text_encoder=text_encoder,
+    compel = Compel(tokenizer=model.tokenizer,
+                    text_encoder=model.text_encoder,
                     textual_inversion_manager=model.textual_inversion_manager,
                     dtype_for_device_getter=torch_dtype)

@@ -78,14 +68,20 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
     negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
     negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]

+    tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
     if log_tokens or getattr(Globals, "log_tokenization", False):
-        log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
-
-    c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
-    uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
-
-    tokens_count = get_max_token_count(tokenizer, positive_prompt)
-
+        log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
+
+    # some LoRA models also mess with the text encoder, so they must be active while compel builds conditioning tensors
+    lora_conditioning_ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
+                                                                           lora_conditions=lora_conditions)
+    with InvokeAIDiffuserComponent.custom_attention_context(model,
+                                                            extra_conditioning_info=lora_conditioning_ec,
+                                                            step_count=-1):
+        c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
+        uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
+
+    # now build the "real" ec
     ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
                                                          cross_attention_control_args=options.get(
                                                              'cross_attention_control', None),
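
Note: the key change here is that compel now builds its conditioning tensors inside the new class-level custom_attention_context, so LoRA conditions that also patch the text encoder are loaded before prompt encoding and unloaded afterwards. A minimal sketch of that pattern in isolation (model, lora_conditions, tokens_count and the two prompt objects stand for values computed earlier in get_uc_and_c_and_ec; step_count=-1 matches the hunk above, presumably because no denoising steps run during prompt encoding):

    # Sketch only: activate LoRAs while compel encodes the prompts, then restore
    # the original attention processors and unload the LoRA weights on exit.
    lora_ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
        tokens_count_including_eos_bos=tokens_count,
        lora_conditions=lora_conditions,
    )
    with InvokeAIDiffuserComponent.custom_attention_context(
        model, extra_conditioning_info=lora_ec, step_count=-1
    ):
        c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
        uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)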

ldm/invoke/generator/diffusers_pipeline.py

Lines changed: 3 additions & 2 deletions
@@ -467,8 +467,9 @@ def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps,
         if additional_guidance is None:
             additional_guidance = []
         extra_conditioning_info = conditioning_data.extra
-        with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
-                                                             step_count=len(self.scheduler.timesteps)
+        with InvokeAIDiffuserComponent.custom_attention_context(self.invokeai_diffuser.model,
+                                                                extra_conditioning_info=extra_conditioning_info,
+                                                                step_count=len(self.scheduler.timesteps)
                                                              ):

             yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,

ldm/models/diffusion/cross_attention_control.py

Lines changed: 10 additions & 25 deletions
@@ -288,16 +288,7 @@ def get_invokeai_attention_mem_efficient(self, q, k, v):
         return self.einsum_op_tensor_mem(q, k, v, 32)


-
-def restore_default_cross_attention(model, is_running_diffusers: bool, processors_to_restore: Optional[AttnProcessor]=None):
-    if is_running_diffusers:
-        unet = model
-        unet.set_attn_processor(processors_to_restore or CrossAttnProcessor())
-    else:
-        remove_attention_function(model)
-
-
-def override_cross_attention(model, context: Context, is_running_diffusers = False):
+def setup_cross_attention_control_attention_processors(model, context: Context):
     """
     Inject attention parameters and functions into the passed in model to enable cross attention editing.

@@ -323,22 +314,16 @@ def override_cross_attention(model, context: Context, is_running_diffusers = Fal

     context.cross_attention_mask = mask.to(device)
     context.cross_attention_index_map = indices.to(device)
-    if is_running_diffusers:
-        unet = model
-        old_attn_processors = unet.attn_processors
-        if torch.backends.mps.is_available():
-            # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
-            unet.set_attn_processor(SwapCrossAttnProcessor())
-        else:
-            # try to re-use an existing slice size
-            default_slice_size = 4
-            slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
-            unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
+    unet = model
+    old_attn_processors = unet.attn_processors
+    if torch.backends.mps.is_available():
+        # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
+        unet.set_attn_processor(SwapCrossAttnProcessor())
     else:
-        context.register_cross_attention_modules(model)
-        inject_attention_function(model, context)
-
-
+        # try to re-use an existing slice size
+        default_slice_size = 4
+        slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+        unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))


 def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:

ldm/models/diffusion/ddim.py

Lines changed: 0 additions & 11 deletions
@@ -12,17 +12,6 @@ def __init__(self, model, schedule='linear', device=None, **kwargs):
         self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
                                               model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))

-    def prepare_to_sample(self, t_enc, **kwargs):
-        super().prepare_to_sample(t_enc, **kwargs)
-
-        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
-        all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
-
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = all_timesteps_count)
-        else:
-            self.invokeai_diffuser.restore_default_cross_attention()
-

     # This is the central routine
     @torch.no_grad()

ldm/models/diffusion/ksampler.py

Lines changed: 0 additions & 9 deletions
@@ -38,15 +38,6 @@ def __init__(self, model, threshold = 0, warmup = 0):
                                 model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))


-    def prepare_to_sample(self, t_enc, **kwargs):
-
-        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
-
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = t_enc)
-        else:
-            self.invokeai_diffuser.restore_default_cross_attention()
-

     def forward(self, x, sigma, uncond, cond, cond_scale):
         next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)

ldm/models/diffusion/plms.py

Lines changed: 0 additions & 11 deletions
@@ -14,17 +14,6 @@ class PLMSSampler(Sampler):
     def __init__(self, model, schedule='linear', device=None, **kwargs):
         super().__init__(model,schedule,model.num_timesteps, device)

-    def prepare_to_sample(self, t_enc, **kwargs):
-        super().prepare_to_sample(t_enc, **kwargs)
-
-        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
-        all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
-
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = all_timesteps_count)
-        else:
-            self.invokeai_diffuser.restore_default_cross_attention()
-

     # this is the essential routine
     @torch.no_grad()

ldm/models/diffusion/shared_invokeai_diffusion.py

Lines changed: 22 additions & 47 deletions
@@ -1,18 +1,17 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from math import ceil
-from typing import Callable, Optional, Union, Any, Dict
+from typing import Callable, Optional, Union, Any

 import numpy as np
 import torch
-from diffusers.models.cross_attention import AttnProcessor
+from diffusers.models.cross_attention import CrossAttnProcessor
 from typing_extensions import TypeAlias

 from ldm.invoke.globals import Globals
 from ldm.models.diffusion.cross_attention_control import (
     Arguments,
-    restore_default_cross_attention,
-    override_cross_attention,
+    setup_cross_attention_control_attention_processors,
     Context,
     get_cross_attention_modules,
     CrossAttentionType,
@@ -84,66 +83,42 @@ def __init__(
         self.cross_attention_control_context = None
         self.sequential_guidance = Globals.sequential_guidance

+    @classmethod
     @contextmanager
     def custom_attention_context(
-        self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
+        clss, model, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
     ):
-        old_attn_processor = None
+        old_attn_processors = None
         if extra_conditioning_info and (
             extra_conditioning_info.wants_cross_attention_control
             | extra_conditioning_info.has_lora_conditions
         ):
-            old_attn_processor = self.override_attention_processors(
-                extra_conditioning_info, step_count=step_count
-            )
+            old_attn_processors = model.attn_processors
+            # Load lora conditions into the model
+            if extra_conditioning_info.has_lora_conditions:
+                for condition in extra_conditioning_info.lora_conditions:
+                    condition(model)
+            if extra_conditioning_info.wants_cross_attention_control:
+                cross_attention_control_context = Context(
+                    arguments=extra_conditioning_info.cross_attention_control_args,
+                    step_count=step_count,
+                )
+                setup_cross_attention_control_attention_processors(
+                    model,
+                    cross_attention_control_context,
+                )

         try:
             yield None
         finally:
-            if old_attn_processor is not None:
-                self.restore_default_cross_attention(old_attn_processor)
+            if old_attn_processors is not None:
+                model.set_attn_processor(old_attn_processors)
             if extra_conditioning_info and extra_conditioning_info.has_lora_conditions:
                 for lora_condition in extra_conditioning_info.lora_conditions:
                     lora_condition.unload()
             # TODO resuscitate attention map saving
             # self.remove_attention_map_saving()

-    def override_attention_processors(
-        self, conditioning: ExtraConditioningInfo, step_count: int
-    ) -> Dict[str, AttnProcessor]:
-        """
-        setup cross attention .swap control. for diffusers this replaces the attention processor, so
-        the previous attention processor is returned so that the caller can restore it later.
-        """
-        old_attn_processors = self.model.attn_processors
-
-        # Load lora conditions into the model
-        if conditioning.has_lora_conditions:
-            for condition in conditioning.lora_conditions:
-                condition(self.model)
-
-        if conditioning.wants_cross_attention_control:
-            self.cross_attention_control_context = Context(
-                arguments=conditioning.cross_attention_control_args,
-                step_count=step_count,
-            )
-            override_cross_attention(
-                self.model,
-                self.cross_attention_control_context,
-                is_running_diffusers=self.is_running_diffusers,
-            )
-        return old_attn_processors
-
-    def restore_default_cross_attention(
-        self, processors_to_restore: Optional[dict[str, "AttnProcessor"]] = None
-    ):
-        self.cross_attention_control_context = None
-        restore_default_cross_attention(
-            self.model,
-            is_running_diffusers=self.is_running_diffusers,
-            processors_to_restore=processors_to_restore,
-        )
-
     def setup_attention_map_saving(self, saver: AttentionMapSaver):
         def callback(slice, dim, offset, slice_size, key):
             if dim is not None:
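
Note: custom_attention_context only needs each entry in lora_conditions to be callable with the model and to expose unload(). A hypothetical minimal shape compatible with that usage, for illustration only (the real LoRA condition objects are defined elsewhere in the codebase):

    class ExampleLoRACondition:
        # Hypothetical sketch: __call__(model) patches the model with LoRA
        # weight deltas, unload() reverses the patch. Only this interface is
        # assumed by custom_attention_context above.
        def __call__(self, model):
            ...  # apply LoRA weight deltas to the model's layers

        def unload(self):
            ...  # remove the previously applied deltas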
