examples/dreambooth/train_dreambooth_lora_sdxl.py (18 changes: 16 additions & 2 deletions)
@@ -62,6 +62,7 @@
from diffusers.optimization import get_scheduler
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
from diffusers.utils import (
_collate_lora_metadata,
check_min_version,
convert_all_state_dict_to_peft,
convert_state_dict_to_diffusers,
@@ -659,6 +660,12 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)

parser.add_argument(
"--lora_alpha",
type=int,
default=4,
help="LoRA alpha to be used for additional scaling.",

parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")

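For context on what the new flag controls: PEFT scales each LoRA update by lora_alpha / r, so hard-wiring alpha to the rank (the previous behavior) always gave a scaling factor of 1.0, while --lora_alpha lets the update strength be tuned independently of the rank. A tiny illustrative sketch of that relationship (not part of the script; it only assumes PEFT's usual scaling rule):

# Illustrative only; assumes PEFT's standard LoRA scaling of lora_alpha / r.
def effective_lora_scaling(rank: int, lora_alpha: int) -> float:
    return lora_alpha / rank

effective_lora_scaling(4, 4)  # 1.0 -- matches the old alpha-equals-rank default
effective_lora_scaling(4, 8)  # 2.0 -- stronger updates at the same rank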
@@ -1202,10 +1209,10 @@ def main(args):
text_encoder_one.gradient_checkpointing_enable()
text_encoder_two.gradient_checkpointing_enable()

def get_lora_config(rank, dropout, use_dora, target_modules):
def get_lora_config(rank, lora_alpha, dropout, use_dora, target_modules):
base_config = {
"r": rank,
"lora_alpha": rank,
"lora_alpha":lora_alpha,
"lora_dropout": dropout,
"init_lora_weights": "gaussian",
"target_modules": target_modules,
@@ -1224,6 +1231,7 @@ def get_lora_config(rank, dropout, use_dora, target_modules):
unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
unet_lora_config = get_lora_config(
rank=args.rank,
lora_alpha=args.lora_alpha,
dropout=args.lora_dropout,
use_dora=args.use_dora,
target_modules=unet_target_modules,
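With the default CLI values (--rank 4, --lora_alpha 4, --lora_dropout 0.0), the call above ends up building a PEFT config for the UNet roughly equivalent to the sketch below. This is only an approximation of what get_lora_config assembles; it assumes the use_dora flag is simply forwarded and omits it here:

from peft import LoraConfig

# Approximate UNet adapter config under the default CLI values; DoRA omitted.
unet_lora_config_sketch = LoraConfig(
    r=4,
    lora_alpha=4,
    lora_dropout=0.0,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)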
@@ -1236,6 +1244,7 @@ def get_lora_config(rank, dropout, use_dora, target_modules):
text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
text_lora_config = get_lora_config(
rank=args.rank,
lora_alpha=args.lora_alpha,
dropout=args.lora_dropout,
use_dora=args.use_dora,
target_modules=text_target_modules,
@@ -1256,10 +1265,12 @@ def save_model_hook(models, weights, output_dir):
unet_lora_layers_to_save = None
text_encoder_one_lora_layers_to_save = None
text_encoder_two_lora_layers_to_save = None
modules_to_save = {}

for model in models:
if isinstance(model, type(unwrap_model(unet))):
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
modules_to_save["transformer"] = model
elif isinstance(model, type(unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
@@ -1279,6 +1290,7 @@ def save_model_hook(models, weights, output_dir):
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
**_collate_lora_metadata(modules_to_save),
)

def load_model_hook(models, input_dir):
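A note on the metadata plumbing in the save hook above: _collate_lora_metadata is imported from diffusers.utils at the top of the file, and the modules_to_save dict feeds it the PEFT-wrapped models so that their adapter configuration (including the new lora_alpha) can be passed to save_lora_weights as extra keyword arguments. The snippet below is a hypothetical illustration of that idea, not the actual diffusers implementation; it assumes each tracked module exposes its LoraConfig under .peft_config, as models do after add_adapter():

# Hypothetical sketch only -- the real helper lives in diffusers.utils.
def collate_lora_metadata_sketch(modules_to_save):
    metadata = {}
    for name, module in modules_to_save.items():
        if module is None:
            continue
        adapter_config = module.peft_config["default"]  # set by add_adapter()
        metadata[f"{name}_lora_adapter_metadata"] = adapter_config.to_dict()
    return metadata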
@@ -1945,6 +1957,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
modules_to_save = {}
unet = unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
@@ -1967,6 +1980,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
**_collate_lora_metadata(modules_to_save),
)
if args.output_kohya_format:
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
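Once training finishes and the weights are written (the hunk above additionally converts them to Kohya format when requested), the LoRA is loaded for inference in the usual way, presumably with the recorded alpha metadata picked up automatically. A minimal usage sketch, where the model id, output path, and prompt are placeholders:

import torch
from diffusers import StableDiffusionXLPipeline

# Minimal inference sketch; paths and prompt are placeholders.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/output_dir")  # finds pytorch_lora_weights.safetensors
image = pipe("a photo of sks dog", num_inference_steps=25).images[0]
image.save("sample.png")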