Skip to content

Commit df76a39

Browse files
authored
Fix Prodigy optimizer in SDXL Dreambooth script (#6290)
* Fix ProdigyOPT in SDXL Dreambooth script
* style
* style
1 parent 3369bc8 commit df76a39

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,10 +1144,26 @@ def load_model_hook(models, input_dir):
11441144

11451145
optimizer_class = prodigyopt.Prodigy
11461146

1147+
if args.learning_rate <= 0.1:
1148+
logger.warn(
1149+
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
1150+
)
1151+
if args.train_text_encoder and args.text_encoder_lr:
1152+
logger.warn(
1153+
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
1154+
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
1155+
f"When using prodigy only learning_rate is used as the initial learning rate."
1156+
)
1157+
# changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
1158+
# --learning_rate
1159+
params_to_optimize[1]["lr"] = args.learning_rate
1160+
params_to_optimize[2]["lr"] = args.learning_rate
1161+
11471162
optimizer = optimizer_class(
11481163
params_to_optimize,
11491164
lr=args.learning_rate,
11501165
betas=(args.adam_beta1, args.adam_beta2),
1166+
beta3=args.prodigy_beta3,
11511167
weight_decay=args.adam_weight_decay,
11521168
eps=args.adam_epsilon,
11531169
decouple=args.prodigy_decouple,

0 commit comments

Comments
 (0)