|
25 | 25 | from huggingface_hub import hf_hub_download
26 | 26 | from torch import nn
27 | 27 |
|
28 | | -from .models.attention_processor import ( |
29 | | - LORA_ATTENTION_PROCESSORS, |
30 | | - AttnAddedKVProcessor, |
31 | | - AttnAddedKVProcessor2_0, |
32 | | - AttnProcessor, |
33 | | - AttnProcessor2_0, |
34 | | - CustomDiffusionAttnProcessor, |
35 | | - CustomDiffusionXFormersAttnProcessor, |
36 | | - LoRAAttnAddedKVProcessor, |
37 | | - LoRAAttnProcessor, |
38 | | - LoRAAttnProcessor2_0, |
39 | | - LoRALinearLayer, |
40 | | - LoRAXFormersAttnProcessor, |
41 | | - SlicedAttnAddedKVProcessor, |
42 | | - XFormersAttnProcessor, |
43 | | -) |
44 | 28 | from .utils import ( |
45 | 29 | DIFFUSERS_CACHE, |
46 | 30 | HF_HUB_OFFLINE, |
|
83 | 67 | class PatchedLoraProjection(nn.Module): |
84 | 68 | def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None): |
85 | 69 | super().__init__() |
| 70 | + from .models.attention_processor import LoRALinearLayer |
| 71 | + |
86 | 72 | self.regular_linear_layer = regular_linear_layer |
87 | 73 |
|
88 | 74 | device = self.regular_linear_layer.weight.device |
@@ -231,6 +217,17 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict |
231 | 217 | information. |
232 | 218 |
|
233 | 219 | """ |
| 220 | + from .models.attention_processor import ( |
| 221 | + AttnAddedKVProcessor, |
| 222 | + AttnAddedKVProcessor2_0, |
| 223 | + CustomDiffusionAttnProcessor, |
| 224 | + LoRAAttnAddedKVProcessor, |
| 225 | + LoRAAttnProcessor, |
| 226 | + LoRAAttnProcessor2_0, |
| 227 | + LoRAXFormersAttnProcessor, |
| 228 | + SlicedAttnAddedKVProcessor, |
| 229 | + XFormersAttnProcessor, |
| 230 | + ) |
234 | 231 |
|
235 | 232 | cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) |
236 | 233 | force_download = kwargs.pop("force_download", False) |
@@ -423,6 +420,11 @@ def save_attn_procs( |
423 | 420 | `DIFFUSERS_SAVE_MODE`. |
424 | 421 |
|
425 | 422 | """ |
| 423 | + from .models.attention_processor import ( |
| 424 | + CustomDiffusionAttnProcessor, |
| 425 | + CustomDiffusionXFormersAttnProcessor, |
| 426 | + ) |
| 427 | + |
426 | 428 | weight_name = weight_name or deprecate( |
427 | 429 | "weights_name", |
428 | 430 | "0.20.0", |
@@ -1317,6 +1319,17 @@ def unload_lora_weights(self): |
1317 | 1319 | >>> ... |
1318 | 1320 | ``` |
1319 | 1321 | """ |
| 1322 | + from .models.attention_processor import ( |
| 1323 | + LORA_ATTENTION_PROCESSORS, |
| 1324 | + AttnProcessor, |
| 1325 | + AttnProcessor2_0, |
| 1326 | + LoRAAttnAddedKVProcessor, |
| 1327 | + LoRAAttnProcessor, |
| 1328 | + LoRAAttnProcessor2_0, |
| 1329 | + LoRAXFormersAttnProcessor, |
| 1330 | + XFormersAttnProcessor, |
| 1331 | + ) |
| 1332 | + |
1320 | 1333 | unet_attention_classes = {type(processor) for _, processor in self.unet.attn_processors.items()} |
1321 | 1334 |
|
1322 | 1335 | if unet_attention_classes.issubset(LORA_ATTENTION_PROCESSORS): |
|
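Taken together, these hunks apply a single refactor: the `.models.attention_processor` imports are removed from module scope in `loaders.py` and re-added inside the functions and methods that actually use them (`PatchedLoraProjection.__init__`, `load_attn_procs`, `save_attn_procs`, `unload_lora_weights`), so they resolve at call time rather than at import time. The diff does not state the motivation, but deferring imports this way is typically done to break an import cycle or to keep a module importable when a heavy dependency is only needed on some code paths. Below is a minimal, self-contained sketch of the pattern; the module (`json`) and the helper name (`dump_config`) are illustrative stand-ins, not diffusers APIs.

```python
# Sketch of the deferred-import pattern applied in the diff above.
# "json" stands in for .models.attention_processor; dump_config is a
# hypothetical helper, not part of diffusers.


def dump_config(config: dict) -> str:
    # The import lives inside the function body, so it executes the first
    # time dump_config() is called rather than when this module is imported.
    # Python caches the module in sys.modules, so later calls pay only a
    # dictionary lookup, not a re-import.
    import json

    return json.dumps(config, indent=2)


if __name__ == "__main__":
    print(dump_config({"rank": 4, "lora_scale": 1.0, "network_alpha": None}))
```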