     convert_state_dict_to_diffusers,
     convert_state_dict_to_kohya,
     convert_unet_state_dict_to_peft,
+    is_peft_version,
     is_wandb_available,
-    is_peft_version
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
@@ -1193,31 +1193,23 @@ def get_lora_config(rank, use_dora, target_modules): |
         }
         if use_dora and is_peft_version("<", "0.9.0"):
             raise ValueError(
-                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                )
+                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+            )
         else:
             base_config["use_dora"] = True
-
+
         return LoraConfig(**base_config)
-
+
     # now we will add new LoRA weights to the attention layers
     unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
-    unet_lora_config = get_lora_config(
-        rank=args.rank,
-        use_dora=args.use_dora,
-        target_modules=unet_target_modules
-    )
+    unet_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=unet_target_modules)
     unet.add_adapter(unet_lora_config)
 
     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
         text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
-        text_lora_config = get_lora_config(
-            rank=args.rank,
-            use_dora=args.use_dora,
-            target_modules=text_target_modules
-        )
+        text_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=text_target_modules)
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
 
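For readers who want the patched logic outside the diff, here is a minimal standalone sketch of what the script ends up doing: build a `LoraConfig` (optionally DoRA-enabled, guarded by a `peft` version check) and attach it with `add_adapter`. The `base_config` entries other than `target_modules` sit above the hunk and are not visible here, so the rank/alpha/init values below are assumptions, as is nesting the version check under `use_dora` so the flag is only set when DoRA is actually requested.

```python
# Sketch only: reconstructs the helper shown in the hunk above. Fields of
# `base_config` other than `target_modules` are assumed, not taken from the PR.
from peft import LoraConfig

from diffusers.utils import is_peft_version


def get_lora_config(rank, use_dora, target_modules):
    base_config = {
        "r": rank,                        # LoRA rank (mirrors args.rank in the script)
        "lora_alpha": rank,               # common choice: alpha == rank (assumption)
        "init_lora_weights": "gaussian",
        "target_modules": target_modules,
    }
    if use_dora:
        # DoRA needs peft >= 0.9.0, mirroring the check in the hunk above.
        if is_peft_version("<", "0.9.0"):
            raise ValueError(
                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
            )
        base_config["use_dora"] = True    # only enable DoRA when it was requested
    return LoraConfig(**base_config)


# Usage as in the patched script (target-module names taken from the hunk above):
unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
unet_lora_config = get_lora_config(rank=4, use_dora=True, target_modules=unet_target_modules)
# unet.add_adapter(unet_lora_config)  # `unet` comes from the surrounding training script
```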
|
|