1 | 1 | from contextlib import contextmanager |
2 | 2 | from dataclasses import dataclass |
3 | 3 | from math import ceil |
4 | | -from typing import Callable, Optional, Union, Any, Dict |
| 4 | +from typing import Callable, Optional, Union, Any |
5 | 5 | |
6 | 6 | import numpy as np |
7 | 7 | import torch |
8 | | -from diffusers.models.cross_attention import AttnProcessor |
| 8 | +from diffusers.models.cross_attention import CrossAttnProcessor |
9 | 9 | from typing_extensions import TypeAlias |
10 | 10 | |
11 | 11 | from ldm.invoke.globals import Globals |
12 | 12 | from ldm.models.diffusion.cross_attention_control import ( |
13 | 13 | Arguments, |
14 | | - restore_default_cross_attention, |
15 | | - override_cross_attention, |
| 14 | + setup_cross_attention_control_attention_processors, |
16 | 15 | Context, |
17 | 16 | get_cross_attention_modules, |
18 | 17 | CrossAttentionType, |
@@ -84,66 +83,42 @@ def __init__( |
84 | 83 | self.cross_attention_control_context = None |
85 | 84 | self.sequential_guidance = Globals.sequential_guidance |
86 | 85 | |
| 86 | + @classmethod |
87 | 87 | @contextmanager |
88 | 88 | def custom_attention_context( |
89 | | - self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int |
| 89 | + cls, model, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int |
90 | 90 | ): |
91 | | - old_attn_processor = None |
| 91 | + old_attn_processors = None |
92 | 92 | if extra_conditioning_info and ( |
93 | 93 | extra_conditioning_info.wants_cross_attention_control |
94 | 94 | | extra_conditioning_info.has_lora_conditions |
95 | 95 | ): |
96 | | - old_attn_processor = self.override_attention_processors( |
97 | | - extra_conditioning_info, step_count=step_count |
98 | | - ) |
| 96 | + old_attn_processors = model.attn_processors |
| 97 | + # Load lora conditions into the model |
| 98 | + if extra_conditioning_info.has_lora_conditions: |
| 99 | + for condition in extra_conditioning_info.lora_conditions: |
| 100 | + condition(model) |
| 101 | + if extra_conditioning_info.wants_cross_attention_control: |
| 102 | + cross_attention_control_context = Context( |
| 103 | + arguments=extra_conditioning_info.cross_attention_control_args, |
| 104 | + step_count=step_count, |
| 105 | + ) |
| 106 | + setup_cross_attention_control_attention_processors( |
| 107 | + model, |
| 108 | + cross_attention_control_context, |
| 109 | + ) |
99 | 110 | |
100 | 111 | try: |
101 | 112 | yield None |
102 | 113 | finally: |
103 | | - if old_attn_processor is not None: |
104 | | - self.restore_default_cross_attention(old_attn_processor) |
| 114 | + if old_attn_processors is not None: |
| 115 | + model.set_attn_processor(old_attn_processors) |
105 | 116 | if extra_conditioning_info and extra_conditioning_info.has_lora_conditions: |
106 | 117 | for lora_condition in extra_conditioning_info.lora_conditions: |
107 | 118 | lora_condition.unload() |
108 | 119 | # TODO resuscitate attention map saving |
109 | 120 | # self.remove_attention_map_saving() |
110 | 121 | |
111 | | - def override_attention_processors( |
112 | | - self, conditioning: ExtraConditioningInfo, step_count: int |
113 | | - ) -> Dict[str, AttnProcessor]: |
114 | | - """ |
115 | | - setup cross attention .swap control. for diffusers this replaces the attention processor, so |
116 | | - the previous attention processor is returned so that the caller can restore it later. |
117 | | - """ |
118 | | - old_attn_processors = self.model.attn_processors |
119 | | - |
120 | | - # Load lora conditions into the model |
121 | | - if conditioning.has_lora_conditions: |
122 | | - for condition in conditioning.lora_conditions: |
123 | | - condition(self.model) |
124 | | - |
125 | | - if conditioning.wants_cross_attention_control: |
126 | | - self.cross_attention_control_context = Context( |
127 | | - arguments=conditioning.cross_attention_control_args, |
128 | | - step_count=step_count, |
129 | | - ) |
130 | | - override_cross_attention( |
131 | | - self.model, |
132 | | - self.cross_attention_control_context, |
133 | | - is_running_diffusers=self.is_running_diffusers, |
134 | | - ) |
135 | | - return old_attn_processors |
136 | | - |
137 | | - def restore_default_cross_attention( |
138 | | - self, processors_to_restore: Optional[dict[str, "AttnProcessor"]] = None |
139 | | - ): |
140 | | - self.cross_attention_control_context = None |
141 | | - restore_default_cross_attention( |
142 | | - self.model, |
143 | | - is_running_diffusers=self.is_running_diffusers, |
144 | | - processors_to_restore=processors_to_restore, |
145 | | - ) |
146 | | - |
147 | 122 | def setup_attention_map_saving(self, saver: AttentionMapSaver): |
148 | 123 | def callback(slice, dim, offset, slice_size, key): |
149 | 124 | if dim is not None: |
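
A minimal sketch of how a caller might use the refactored context manager after this change. The enclosing class name `InvokeAIDiffuserComponent` and the call-site names `unet`, `conditioning`, `timesteps`, and `denoise_step` are assumptions for illustration, not part of this commit:

```python
# Hypothetical call site (not from this diff): the context manager is now a
# classmethod, so the caller passes the model explicitly instead of relying on
# self.model. The original attn_processors are captured before the override and
# restored, and any LoRA conditions unloaded, when the `with` block exits.
with InvokeAIDiffuserComponent.custom_attention_context(
    unet,                                  # diffusers UNet whose attention processors get swapped (assumed)
    extra_conditioning_info=conditioning,  # ExtraConditioningInfo with .swap and/or LoRA conditions (assumed)
    step_count=len(timesteps),
):
    for t in timesteps:
        latents = denoise_step(latents, t)  # placeholder for the denoising loop
```

Passing `model` in explicitly and restoring via `model.set_attn_processor(old_attn_processors)` keeps the override/restore logic self-contained in the classmethod, so it no longer depends on instance state such as `self.model` or the removed `override_attention_processors` / `restore_default_cross_attention` helpers.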