Skip to content

Commit b17f9bf

Browse files
committed
fix mixed precision training as proposed in #9565 for full dreambooth as well
1 parent faa95af commit b17f9bf

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

examples/dreambooth/train_dreambooth_flux.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def log_validation(
161161
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
162162
f" {args.validation_prompt}."
163163
)
164-
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
164+
pipeline = pipeline.to(accelerator.device)
165165
pipeline.set_progress_bar_config(disable=True)
166166

167167
# run inference
@@ -1580,7 +1580,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
15801580
)
15811581

15821582
# handle guidance
1583-
if transformer.config.guidance_embeds:
1583+
if accelerator.unwrap_model(transformer).config.guidance_embeds:
15841584
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
15851585
guidance = guidance.expand(model_input.shape[0])
15861586
else:
@@ -1694,6 +1694,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16941694
# create pipeline
16951695
if not args.train_text_encoder:
16961696
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
1697+
text_encoder_one.to(weight_dtype)
1698+
text_encoder_two.to(weight_dtype)
16971699
else: # even when training the text encoder we're only training text encoder one
16981700
text_encoder_two = text_encoder_cls_two.from_pretrained(
16991701
args.pretrained_model_name_or_path,

0 commit comments

Comments
 (0)