diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index a8911ad64e21..eba436a4d51c 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1157,6 +1157,7 @@ def main(args): "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." ) + transformer.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) if not args.train_text_encoder: text_encoder_one.to(accelerator.device, dtype=weight_dtype)