
Commit 23a2cd3

[LoRA] training fix the position of param casting when loading them (#8460)
fix the position of param casting when loading them
1 parent 4edde13 commit 23a2cd3

File tree

2 files changed (+4, -4 lines)

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 2 additions & 2 deletions

@@ -1289,8 +1289,8 @@ def load_model_hook(models, input_dir):
         models = [unet_]
         if args.train_text_encoder:
             models.extend([text_encoder_one_, text_encoder_two_])
-            # only upcast trainable parameters (LoRA) into fp32
-            cast_training_params(models)
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(models)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)

examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 2 additions & 2 deletions

@@ -1363,8 +1363,8 @@ def load_model_hook(models, input_dir):
         models = [unet_]
         if args.train_text_encoder:
             models.extend([text_encoder_one_, text_encoder_two_])
-            # only upcast trainable parameters (LoRA) into fp32
-            cast_training_params(models)
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(models)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
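
Why the dedent matters: cast_training_params upcasts only parameters with requires_grad=True (the LoRA weights) to fp32 and leaves frozen base weights in their low-precision dtype. While the call sat inside the `if args.train_text_encoder:` branch, resuming a UNet-only LoRA run never upcast the loaded LoRA weights, so they stayed in fp16/bf16 under mixed precision. Below is a minimal sketch of the behavior; the local cast_training_params is a simplified stand-in for diffusers.training_utils.cast_training_params, and the Linear layer standing in for the UNet is an illustrative assumption, not code from the training script.

import torch

def cast_training_params(models, dtype=torch.float32):
    # Simplified stand-in: upcast only parameters that will receive
    # gradients; frozen base weights keep their low-precision dtype.
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.data.to(dtype)

unet_ = torch.nn.Linear(4, 4).half()  # hypothetical stand-in for the loaded UNet
unet_.weight.requires_grad_(False)    # frozen base weight
unet_.bias.requires_grad_(True)       # stand-in for a trainable LoRA parameter

train_text_encoder = False            # the configuration the old code mishandled

models = [unet_]
if train_text_encoder:
    pass  # text encoders would be appended here
# After the fix, the cast runs unconditionally rather than only inside the branch above.
cast_training_params(models)

print(unet_.weight.dtype)  # torch.float16 -- frozen weight untouched
print(unet_.bias.dtype)    # torch.float32 -- trainable parameter upcast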
