Commit 6f6aa1b

fix use_dora
1 parent 76b7d86 commit 6f6aa1b

File tree

1 file changed: +25 -10 lines changed

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 25 additions & 10 deletions
@@ -68,6 +68,7 @@
     convert_state_dict_to_kohya,
     convert_unet_state_dict_to_peft,
     is_wandb_available,
+    is_peft_version
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
@@ -1183,25 +1184,39 @@ def main(args):
             text_encoder_one.gradient_checkpointing_enable()
             text_encoder_two.gradient_checkpointing_enable()
 
+    def get_lora_config(rank, use_dora, target_modules):
+        base_config = {
+            "r": rank,
+            "lora_alpha": rank,
+            "init_lora_weights": "gaussian",
+            "target_modules": target_modules,
+        }
+        if use_dora:
+            if is_peft_version("<", "0.9.0"):
+                raise ValueError(
+                    "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                )
+            base_config["use_dora"] = True
+
+        return LoraConfig(**base_config)
+
     # now we will add new LoRA weights to the attention layers
-    unet_lora_config = LoraConfig(
-        r=args.rank,
+    unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+    unet_lora_config = get_lora_config(
+        rank=args.rank,
         use_dora=args.use_dora,
-        lora_alpha=args.rank,
-        init_lora_weights="gaussian",
-        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+        target_modules=unet_target_modules
     )
     unet.add_adapter(unet_lora_config)
 
     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
-        text_lora_config = LoraConfig(
-            r=args.rank,
+        text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
+        text_lora_config = get_lora_config(
+            rank=args.rank,
             use_dora=args.use_dora,
-            lora_alpha=args.rank,
-            init_lora_weights="gaussian",
-            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+            target_modules=text_target_modules
         )
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
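
For quick local verification, here is a minimal standalone sketch (not part of the commit) of the version-gated helper introduced above. It assumes `peft` >= 0.9.0 and a diffusers version that exposes `is_peft_version` from `diffusers.utils`; the `cfg_plain` / `cfg_dora` names and the rank/target-module values are illustrative only.

# Standalone sketch (illustrative, not part of the commit) of the version-gated
# DoRA configuration pattern added above. LoraConfig and is_peft_version are the
# real peft / diffusers helpers; everything else is example scaffolding.
from peft import LoraConfig

from diffusers.utils import is_peft_version


def get_lora_config(rank, use_dora, target_modules):
    # Shared LoRA settings; DoRA is layered on top only when requested.
    base_config = {
        "r": rank,
        "lora_alpha": rank,
        "init_lora_weights": "gaussian",
        "target_modules": target_modules,
    }
    if use_dora:
        if is_peft_version("<", "0.9.0"):
            raise ValueError(
                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. "
                "Please upgrade your installation of `peft`."
            )
        base_config["use_dora"] = True
    return LoraConfig(**base_config)


# DoRA is only switched on when explicitly requested.
cfg_plain = get_lora_config(rank=4, use_dora=False, target_modules=["to_q", "to_k"])
cfg_dora = get_lora_config(rank=4, use_dora=True, target_modules=["to_q", "to_k"])
print(cfg_plain.use_dora)  # False
print(cfg_dora.use_dora)   # True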
