Skip to content

Commit 4afe8de

Browse files
committed
style.
1 parent f75f695 commit 4afe8de

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1294,8 +1294,12 @@ def save_model_hook(models, weights, output_dir):
12941294
for model in models:
12951295
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
12961296
model = unwrap_model(model)
1297+
if args.upcast_before_saving:
1298+
model = model.to(torch.float32)
12971299
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
1298-
elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))): # or text_encoder_two
1300+
elif args.train_text_encoder and isinstance(
1301+
unwrap_model(model), type(unwrap_model(text_encoder_one))
1302+
): # or text_encoder_two
12991303
# both text encoders are of the same class, so we check hidden size to distinguish between the two
13001304
model = unwrap_model(model)
13011305
hidden_size = model.config.hidden_size

0 commit comments

Comments
 (0)