Commit 9b2917f

mixed precision
1 parent b211eea

1 file changed: 5 additions, 2 deletions


examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 5 additions & 2 deletions
@@ -1661,7 +1661,8 @@ def load_model_hook(models, input_dir):
         for name, param in text_encoder_one.named_parameters():
             if "token_embedding" in name:
                 # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                param.data = param.to(dtype=torch.float32)
+                if args.mixed_precision == "fp16":
+                    param.data = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_one.append(param)
             else:
@@ -1671,7 +1672,8 @@ def load_model_hook(models, input_dir):
         for name, param in text_encoder_two.named_parameters():
             if "shared" in name:
                 # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                param.data = param.to(dtype=torch.float32)
+                if args.mixed_precision == "fp16":
+                    param.data = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_two.append(param)
             else:
@@ -1946,6 +1948,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 lr_scheduler,
             )
         else:
+            print("I SHOULD BE HERE")
             transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
                 transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
             )
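
The gated change is a common pivotal-tuning pattern: trainable embedding parameters are upcast to float32 so their gradient updates don't underflow, while the rest of the frozen model stays in the half-precision weight dtype; after this commit the upcast only runs when fp16 mixed precision is requested. A minimal sketch of that pattern, assuming a toy stand-in model (the Embedding/Linear stack, the mixed_precision variable, and the "0.weight" name filter are illustrative, not the actual Flux text encoders):

import torch

# Toy stand-in for a text encoder whose token embedding is trained via
# pivotal tuning; illustrative only, not the model from the script.
model = torch.nn.Sequential(
    torch.nn.Embedding(num_embeddings=1000, embedding_dim=64),
    torch.nn.Linear(64, 64),
)

mixed_precision = "fp16"  # stand-in for args.mixed_precision
weight_dtype = torch.float16 if mixed_precision == "fp16" else torch.float32
model.to(dtype=weight_dtype)

trainable_params = []
for name, param in model.named_parameters():
    # "0.weight" selects the embedding in this toy model, mirroring the
    # script's "token_embedding" / "shared" name filters
    if "0.weight" in name:
        if mixed_precision == "fp16":
            # fp16 has a narrow exponent range, so small gradient updates
            # to a half-precision parameter can underflow to zero; keeping
            # the trained parameter in float32 avoids that
            param.data = param.to(dtype=torch.float32)
        param.requires_grad = True
        trainable_params.append(param)
    else:
        param.requires_grad = False

Presumably the gate exists because the upcast is unnecessary outside fp16: bf16 shares fp32's exponent range, and full fp32 training needs no conversion at all.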
