Commit dd69d54

Move the transformer to accelerator.device with its weight_dtype
By default, the transformer is always transferred to the device in float32, even when the dtype is specified as bfloat16 or float16. This commit fixes that.
1 parent b572635 · commit dd69d54

File tree

1 file changed (+1, -0)


examples/dreambooth/train_dreambooth_flux.py

Lines changed: 1 addition & 0 deletions
@@ -1157,6 +1157,7 @@ def main(args):
             "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
         )
 
+    transformer.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
     if not args.train_text_encoder:
         text_encoder_one.to(accelerator.device, dtype=weight_dtype)
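For context, a minimal sketch of what the added line does. The Accelerator construction and the stand-in Linear module below are illustrative assumptions, not the script's actual setup (the real script builds the Accelerator from its CLI arguments and loads a FluxTransformer2DModel), but weight_dtype is derived from accelerator.mixed_precision the same way:

import torch
from accelerate import Accelerator

# Illustrative setup; in train_dreambooth_flux.py the Accelerator comes
# from the script's argument parsing.
accelerator = Accelerator(mixed_precision="bf16")

# Mirrors how the script maps the mixed-precision flag to a torch dtype.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

# Stand-in for the Flux transformer; torch modules are float32 by default,
# which is why the model landed on the device in float32 before this fix.
transformer = torch.nn.Linear(4, 4)
print(next(transformer.parameters()).dtype)  # torch.float32

# The added line: move the module to the accelerator's device AND cast it.
transformer.to(accelerator.device, dtype=weight_dtype)
print(next(transformer.parameters()).dtype)  # torch.bfloat16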
