| 
68 | 68 |     convert_state_dict_to_kohya,  | 
69 | 69 |     convert_unet_state_dict_to_peft,  | 
70 | 70 |     is_wandb_available,  | 
 | 71 | +    is_peft_version,  | 
71 | 72 | )  | 
72 | 73 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card  | 
73 | 74 | from diffusers.utils.import_utils import is_xformers_available  | 
@@ -1183,25 +1184,39 @@ def main(args):  | 
1183 | 1184 |             text_encoder_one.gradient_checkpointing_enable()  | 
1184 | 1185 |             text_encoder_two.gradient_checkpointing_enable()  | 
1185 | 1186 |   | 
 | 1187 | +    def get_lora_config(rank, use_dora, target_modules):  | 
 | 1188 | +        base_config = {  | 
 | 1189 | +            "r": rank,  | 
 | 1190 | +            "lora_alpha": rank,  | 
 | 1191 | +            "init_lora_weights": "gaussian",  | 
 | 1192 | +            "target_modules": target_modules,  | 
 | 1193 | +        }  | 
 | 1194 | +        if use_dora:  | 
 | 1195 | +            if is_peft_version("<", "0.9.0"):  | 
 | 1196 | +                raise ValueError(  | 
 | 1197 | +                    "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."  | 
 | 1198 | +                )  | 
 | 1199 | +            base_config["use_dora"] = True  | 
 | 1200 | +  | 
 | 1201 | +        return LoraConfig(**base_config)  | 
 | 1202 | +  | 
1186 | 1203 |     # now we will add new LoRA weights to the attention layers  | 
1187 |  | -    unet_lora_config = LoraConfig(  | 
1188 |  | -        r=args.rank,  | 
 | 1204 | +    unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]  | 
 | 1205 | +    unet_lora_config = get_lora_config(  | 
 | 1206 | +        rank=args.rank,  | 
1189 | 1207 |         use_dora=args.use_dora,  | 
1190 |  | -        lora_alpha=args.rank,  | 
1191 |  | -        init_lora_weights="gaussian",  | 
1192 |  | -        target_modules=["to_k", "to_q", "to_v", "to_out.0"],  | 
 | 1208 | +        target_modules=unet_target_modules,  | 
1193 | 1209 |     )  | 
1194 | 1210 |     unet.add_adapter(unet_lora_config)  | 
1195 | 1211 |   | 
1196 | 1212 |     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.  | 
1197 | 1213 |     # So, instead, we monkey-patch the forward calls of its attention-blocks.  | 
1198 | 1214 |     if args.train_text_encoder:  | 
1199 |  | -        text_lora_config = LoraConfig(  | 
1200 |  | -            r=args.rank,  | 
 | 1215 | +        text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]  | 
 | 1216 | +        text_lora_config = get_lora_config(  | 
 | 1217 | +            rank=args.rank,  | 
1201 | 1218 |             use_dora=args.use_dora,  | 
1202 |  | -            lora_alpha=args.rank,  | 
1203 |  | -            init_lora_weights="gaussian",  | 
1204 |  | -            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],  | 
 | 1219 | +            target_modules=text_target_modules,  | 
1205 | 1220 |         )  | 
1206 | 1221 |         text_encoder_one.add_adapter(text_lora_config)  | 
1207 | 1222 |         text_encoder_two.add_adapter(text_lora_config)  | 
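
For readers skimming the diff, here is the refactor pulled out as a minimal standalone sketch: the `get_lora_config` helper builds one `LoraConfig` and only enables DoRA when the installed `peft` supports it. The `rank=4` value and the direct calls at the bottom are illustrative assumptions; the target-module names, the `is_peft_version` import, and the version gate mirror what this commit adds.

```python
# Standalone sketch of the helper introduced in this commit. Assumes `peft` is
# installed and that `is_peft_version` is available from diffusers.utils
# (as imported at the top of the script).
from peft import LoraConfig
from diffusers.utils import is_peft_version


def get_lora_config(rank, use_dora, target_modules):
    """Build a LoraConfig, enabling DoRA only when peft >= 0.9.0 is installed."""
    base_config = {
        "r": rank,
        "lora_alpha": rank,
        "init_lora_weights": "gaussian",
        "target_modules": target_modules,
    }
    if use_dora:
        if is_peft_version("<", "0.9.0"):
            raise ValueError(
                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. "
                "Please upgrade your installation of `peft`."
            )
        base_config["use_dora"] = True
    return LoraConfig(**base_config)


# Both the UNet and the text encoders reuse the same helper; only the target
# modules differ. rank=4 and use_dora=False are example values, not defaults.
unet_lora_config = get_lora_config(
    rank=4, use_dora=False, target_modules=["to_k", "to_q", "to_v", "to_out.0"]
)
text_lora_config = get_lora_config(
    rank=4, use_dora=False, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
)
```

Keeping the version check inside the helper means the UNet and both text-encoder adapters fail early with the same error message when DoRA is requested on an older `peft`, instead of duplicating the check at each call site.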