From e7ef7b607e77a6ece8cd8d33e8d02017c8f92a09 Mon Sep 17 00:00:00 2001 From: xduzhangjiayu Date: Mon, 7 Oct 2024 21:27:52 +0800 Subject: [PATCH 1/2] Fix VAE dtype when accelerate is configured with --mixed_precision="fp16" --- examples/controlnet/train_controlnet_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py index 9ea78370f5e0..828102a47cd0 100644 --- a/examples/controlnet/train_controlnet_sd3.py +++ b/examples/controlnet/train_controlnet_sd3.py @@ -1094,7 +1094,7 @@ def load_model_hook(models, input_dir): weight_dtype = torch.bfloat16 # Move vae, transformer and text_encoder to device and cast to weight_dtype - vae.to(accelerator.device, dtype=torch.float32) + vae.to(accelerator.device, dtype=weight_dtype) transformer.to(accelerator.device, dtype=weight_dtype) text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) From be4c8102d773b4b5377f5ba7c687b5bab8664d79 Mon Sep 17 00:00:00 2001 From: xduzhangjiayu Date: Mon, 7 Oct 2024 22:53:13 +0800 Subject: [PATCH 2/2] Add --upcast_vae argument to upcast the VAE to fp32 --- examples/controlnet/train_controlnet_sd3.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py index 828102a47cd0..4fae8a072c6f 100644 --- a/examples/controlnet/train_controlnet_sd3.py +++ b/examples/controlnet/train_controlnet_sd3.py @@ -357,6 +357,11 @@ def parse_args(input_args=None): action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", ) + parser.add_argument( + "--upcast_vae", + action="store_true", + help="Whether or not to upcast vae to fp32", + ) parser.add_argument( "--learning_rate", type=float, @@ -1094,7 +1099,10 @@ def load_model_hook(models, input_dir): weight_dtype = torch.bfloat16 # Move vae, transformer 
and text_encoder to device and cast to weight_dtype - vae.to(accelerator.device, dtype=weight_dtype) + if args.upcast_vae: + vae.to(accelerator.device, dtype=torch.float32) + else: + vae.to(accelerator.device, dtype=weight_dtype) transformer.to(accelerator.device, dtype=weight_dtype) text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype)