Skip to content

Commit cc1d2ad

Browse files
committed
unify unwrap_model calls in dreambooth script
1 parent 0729c66 commit cc1d2ad

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

examples/dreambooth/train_dreambooth_flux.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1602,7 +1602,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16021602
)
16031603

16041604
# handle guidance
1605-
if accelerator.unwrap_model(transformer).config.guidance_embeds:
1605+
if unwrap_model(transformer).config.guidance_embeds:
16061606
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
16071607
guidance = guidance.expand(model_input.shape[0])
16081608
else:
@@ -1728,9 +1728,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17281728
pipeline = FluxPipeline.from_pretrained(
17291729
args.pretrained_model_name_or_path,
17301730
vae=vae,
1731-
text_encoder=accelerator.unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
1732-
text_encoder_2=accelerator.unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
1733-
transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False),
1731+
text_encoder=unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
1732+
text_encoder_2=unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
1733+
transformer=unwrap_model(transformer, keep_fp32_wrapper=False),
17341734
revision=args.revision,
17351735
variant=args.variant,
17361736
torch_dtype=weight_dtype,

0 commit comments

Comments
 (0)