 from contextlib import contextmanager
 from dataclasses import dataclass
 from math import ceil
-from typing import Callable, Optional, Union, Any, Dict
+from typing import Callable, Optional, Union, Any

 import numpy as np
 import torch
-from diffusers.models.cross_attention import AttnProcessor
+
+from diffusers import UNet2DConditionModel
 from typing_extensions import TypeAlias

 from ldm.invoke.globals import Globals
 from ldm.models.diffusion.cross_attention_control import (
     Arguments,
-    restore_default_cross_attention,
-    override_cross_attention,
+    setup_cross_attention_control_attention_processors,
     Context,
     get_cross_attention_modules,
     CrossAttentionType,
@@ -84,66 +84,45 @@ def __init__(
         self.cross_attention_control_context = None
         self.sequential_guidance = Globals.sequential_guidance

+    @classmethod
     @contextmanager
     def custom_attention_context(
-        self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
+        cls,
+        unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
+        extra_conditioning_info: Optional[ExtraConditioningInfo],
+        step_count: int
     ):
-        old_attn_processor = None
+        old_attn_processors = None
         if extra_conditioning_info and (
             extra_conditioning_info.wants_cross_attention_control
             | extra_conditioning_info.has_lora_conditions
         ):
-            old_attn_processor = self.override_attention_processors(
-                extra_conditioning_info, step_count=step_count
-            )
+            old_attn_processors = unet.attn_processors
+            # Load lora conditions into the model
+            if extra_conditioning_info.has_lora_conditions:
+                for condition in extra_conditioning_info.lora_conditions:
+                    condition()  # target model is stored in condition state for some reason
+            if extra_conditioning_info.wants_cross_attention_control:
+                cross_attention_control_context = Context(
+                    arguments=extra_conditioning_info.cross_attention_control_args,
+                    step_count=step_count,
+                )
+                setup_cross_attention_control_attention_processors(
+                    unet,
+                    cross_attention_control_context,
+                )

         try:
             yield None
         finally:
-            if old_attn_processor is not None:
-                self.restore_default_cross_attention(old_attn_processor)
+            if old_attn_processors is not None:
+                unet.set_attn_processor(old_attn_processors)
             if extra_conditioning_info and extra_conditioning_info.has_lora_conditions:
                 for lora_condition in extra_conditioning_info.lora_conditions:
                     lora_condition.unload()
             # TODO resuscitate attention map saving
             # self.remove_attention_map_saving()

-    def override_attention_processors(
-        self, conditioning: ExtraConditioningInfo, step_count: int
-    ) -> Dict[str, AttnProcessor]:
-        """
-        setup cross attention .swap control. for diffusers this replaces the attention processor, so
-        the previous attention processor is returned so that the caller can restore it later.
-        """
-        old_attn_processors = self.model.attn_processors
-
-        # Load lora conditions into the model
-        if conditioning.has_lora_conditions:
-            for condition in conditioning.lora_conditions:
-                condition(self.model)
-
-        if conditioning.wants_cross_attention_control:
-            self.cross_attention_control_context = Context(
-                arguments=conditioning.cross_attention_control_args,
-                step_count=step_count,
-            )
-            override_cross_attention(
-                self.model,
-                self.cross_attention_control_context,
-                is_running_diffusers=self.is_running_diffusers,
-            )
-        return old_attn_processors
-
-    def restore_default_cross_attention(
-        self, processors_to_restore: Optional[dict[str, "AttnProcessor"]] = None
-    ):
-        self.cross_attention_control_context = None
-        restore_default_cross_attention(
-            self.model,
-            is_running_diffusers=self.is_running_diffusers,
-            processors_to_restore=processors_to_restore,
-        )
-
     def setup_attention_map_saving(self, saver: AttentionMapSaver):
         def callback(slice, dim, offset, slice_size, key):
             if dim is not None:
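For context on the API change: `custom_attention_context` is now a classmethod that takes the target UNet explicitly instead of reading it from the instance's model. Below is a minimal caller-side sketch, not code from this diff. It assumes the surrounding class is `InvokeAIDiffuserComponent` (the class name and module path are not shown in the hunk), and `run_denoising_loop` is a hypothetical stand-in for whatever denoising loop the caller already has.

```python
from diffusers import UNet2DConditionModel

# Assumed import path; the diff only shows the method body, not the module name.
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent


def denoise_with_custom_attention(
    unet: UNet2DConditionModel,
    extra_conditioning_info,  # ExtraConditioningInfo or None
    step_count: int,
):
    # The context manager is a classmethod after this change, so the caller passes
    # the UNet to patch rather than relying on an instance's self.model.
    with InvokeAIDiffuserComponent.custom_attention_context(
        unet,
        extra_conditioning_info,
        step_count=step_count,
    ):
        # Inside the block the UNet's attention processors are overridden for
        # .swap cross-attention control and any LoRA conditions are applied.
        latents = run_denoising_loop(unet, step_count)  # hypothetical caller loop
    # On exit the saved attention processors are restored and LoRA conditions unloaded.
    return latents
```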