diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index d91d263ec9c4..9fcdc5ee2cb0 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1716,9 +1716,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, - text_encoder=accelerator.unwrap_model(text_encoder_one), - text_encoder_2=accelerator.unwrap_model(text_encoder_two), - transformer=accelerator.unwrap_model(transformer), + text_encoder=accelerator.unwrap_model(text_encoder_one, keep_fp32_wrapper=False), + text_encoder_2=accelerator.unwrap_model(text_encoder_two, keep_fp32_wrapper=False), + transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False), revision=args.revision, variant=args.variant, torch_dtype=weight_dtype,