diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py index 9ea78370f5e0..4fae8a072c6f 100644 --- a/examples/controlnet/train_controlnet_sd3.py +++ b/examples/controlnet/train_controlnet_sd3.py @@ -357,6 +357,11 @@ def parse_args(input_args=None): action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", ) + parser.add_argument( + "--upcast_vae", + action="store_true", + help="Whether or not to upcast vae to fp32", + ) parser.add_argument( "--learning_rate", type=float, @@ -1094,7 +1099,10 @@ def load_model_hook(models, input_dir): weight_dtype = torch.bfloat16 # Move vae, transformer and text_encoder to device and cast to weight_dtype - vae.to(accelerator.device, dtype=torch.float32) + if args.upcast_vae: + vae.to(accelerator.device, dtype=torch.float32) + else: + vae.to(accelerator.device, dtype=weight_dtype) transformer.to(accelerator.device, dtype=weight_dtype) text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype)