|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import math |
15 | | -from typing import Dict, Optional, Tuple, Union |
| 15 | +from typing import Optional, Tuple, Union |
16 | 16 |
|
17 | 17 | import torch |
18 | 18 | import torch.nn as nn |
|
22 | 22 | from ...loaders.single_file_model import FromOriginalModelMixin |
23 | 23 | from ...utils import deprecate |
24 | 24 | from ...utils.accelerate_utils import apply_forward_hook |
| 25 | +from ..attention import AttentionMixin |
25 | 26 | from ..attention_processor import ( |
26 | 27 | ADDED_KV_ATTENTION_PROCESSORS, |
27 | 28 | CROSS_ATTENTION_PROCESSORS, |
28 | 29 | Attention, |
29 | | - AttentionProcessor, |
30 | 30 | AttnAddedKVProcessor, |
31 | 31 | AttnProcessor, |
32 | 32 | FusedAttnProcessor2_0, |
|
36 | 36 | from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder |
37 | 37 |
|
38 | 38 |
|
39 | | -class AutoencoderKLFlux2(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): |
| 39 | +class AutoencoderKLFlux2( |
| 40 | + ModelMixin, AutoencoderMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin |
| 41 | +): |
40 | 42 | r""" |
41 | 43 | A VAE model with KL loss for encoding images into latents and decoding latent representations into images. |
42 | 44 |
|
@@ -154,66 +156,6 @@ def __init__( |
154 | 156 | self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) |
155 | 157 | self.tile_overlap_factor = 0.25 |
156 | 158 |
|
157 | | - @property |
158 | | - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors |
159 | | - def attn_processors(self) -> Dict[str, AttentionProcessor]: |
160 | | - r""" |
161 | | - Returns: |
162 | | - `dict` of attention processors: A dictionary containing all attention processors used in the model, |
163 | | - indexed by its weight name. |
164 | | - """ |
165 | | - # set recursively |
166 | | - processors = {} |
167 | | - |
168 | | - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): |
169 | | - if hasattr(module, "get_processor"): |
170 | | - processors[f"{name}.processor"] = module.get_processor() |
171 | | - |
172 | | - for sub_name, child in module.named_children(): |
173 | | - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) |
174 | | - |
175 | | - return processors |
176 | | - |
177 | | - for name, module in self.named_children(): |
178 | | - fn_recursive_add_processors(name, module, processors) |
179 | | - |
180 | | - return processors |
181 | | - |
182 | | - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor |
183 | | - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): |
184 | | - r""" |
185 | | - Sets the attention processor to use to compute attention. |
186 | | -
|
187 | | - Parameters: |
188 | | - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): |
189 | | - The instantiated processor class or a dictionary of processor classes that will be set as the processor |
190 | | - for **all** `Attention` layers. |
191 | | -
|
192 | | - If `processor` is a dict, the key needs to define the path to the corresponding cross attention |
193 | | - processor. This is strongly recommended when setting trainable attention processors. |
194 | | -
|
195 | | - """ |
196 | | - count = len(self.attn_processors.keys()) |
197 | | - |
198 | | - if isinstance(processor, dict) and len(processor) != count: |
199 | | - raise ValueError( |
200 | | - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" |
201 | | - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." |
202 | | - ) |
203 | | - |
204 | | - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): |
205 | | - if hasattr(module, "set_processor"): |
206 | | - if not isinstance(processor, dict): |
207 | | - module.set_processor(processor) |
208 | | - else: |
209 | | - module.set_processor(processor.pop(f"{name}.processor")) |
210 | | - |
211 | | - for sub_name, child in module.named_children(): |
212 | | - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) |
213 | | - |
214 | | - for name, module in self.named_children(): |
215 | | - fn_recursive_attn_processor(name, module, processor) |
216 | | - |
217 | 159 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor |
218 | 160 | def set_default_attn_processor(self): |
219 | 161 | """ |
|
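
For context, the helpers deleted above are not lost: they are now inherited from the `AttentionMixin` base class added to `AutoencoderKLFlux2` in this diff, so the public processor API stays the same for callers. Below is a minimal usage sketch; the top-level `AutoencoderKLFlux2` import and the no-argument construction are illustrative assumptions, not taken from this diff.

```python
# Hedged sketch: after this change, attn_processors / set_attn_processor are
# provided by AttentionMixin rather than being defined on the VAE itself, so
# downstream calls look the same as before.
from diffusers import AutoencoderKLFlux2  # assumed top-level export
from diffusers.models.attention_processor import AttnProcessor

vae = AutoencoderKLFlux2()  # assumed: the default config is instantiable

# Dict of every attention processor, keyed by weight name (each key ends in
# ".processor", as in the removed attn_processors implementation).
processors = vae.attn_processors

# Either apply a single processor instance to all Attention layers ...
vae.set_attn_processor(AttnProcessor())

# ... or pass a dict with one entry per layer, using the same keys as
# `attn_processors` (the removed implementation required the counts to match).
vae.set_attn_processor({name: AttnProcessor() for name in processors})
```

The dict form uses the same `*.processor` keys that the removed `set_attn_processor` expected; the mixin version presumably keeps that contract, since the deleted code was itself a copy of the shared `UNet2DConditionModel` helpers that the mixin now centralizes.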